diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 499613625..9b89b69ae 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -2,7 +2,7 @@ name: build
on:
push:
- branches: ["main"]
+ branches: [ "main" ]
jobs:
build:
@@ -16,10 +16,5 @@ jobs:
with:
go-version-file: go.mod
- - name: fmt
- run: |
- make fmt
- git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
-
- name: Build
run: make build-all
diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml
index 55bf77e00..be1c10c52 100644
--- a/.github/workflows/pr.yml
+++ b/.github/workflows/pr.yml
@@ -24,48 +24,9 @@ jobs:
with:
version: v2.10.1
- # TODO: Remove once linter is properly configured
- fmt-check:
- name: Formatting
- runs-on: ubuntu-latest
- steps:
- - name: Checkout
- uses: actions/checkout@v6
-
- - name: Setup Go
- uses: actions/setup-go@v6
- with:
- go-version-file: go.mod
-
- - name: Check formatting
- run: |
- make fmt
- git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
-
- # TODO: Remove once linter is properly configured
- vet:
- name: Vet
- runs-on: ubuntu-latest
- needs: fmt-check
- steps:
- - name: Checkout
- uses: actions/checkout@v6
-
- - name: Setup Go
- uses: actions/setup-go@v6
- with:
- go-version-file: go.mod
-
- - name: Run go generate
- run: go generate ./...
-
- - name: Run go vet
- run: go vet ./...
-
test:
name: Tests
runs-on: ubuntu-latest
- needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v6
diff --git a/.gitignore b/.gitignore
index ce30d749e..3ff195fbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,7 +10,7 @@ build/
*.out
/picoclaw
/picoclaw-test
-cmd/picoclaw/workspace
+cmd/**/workspace
# Picoclaw specific
diff --git a/.golangci.yaml b/.golangci.yaml
index 80e54ac1c..d0ba90716 100644
--- a/.golangci.yaml
+++ b/.golangci.yaml
@@ -28,9 +28,7 @@ linters:
- wsl_v5
# TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step)
- - bodyclose
- contextcheck
- - dogsled
- embeddedstructfieldcheck
- errcheck
- errchkjson
@@ -45,33 +43,24 @@ linters:
- gocritic
- gocyclo
- godox
- - goprintffuncname
- gosec
- - govet
- ineffassign
- lll
- maintidx
- - misspell
- mnd
- modernize
- - nakedret
- nestif
- nilnil
- paralleltest
- perfsprint
- - prealloc
- - predeclared
- revive
- staticcheck
- tagalign
- testifylint
- thelper
- unparam
- - unused
- usestdlibvars
- usetesting
- - wastedassign
- - whitespace
settings:
errcheck:
check-type-assertions: true
@@ -153,6 +142,9 @@ linters:
- gocognit
- gocyclo
path: _test\.go$
+ - linters:
+ - nolintlint
+ path: 'pkg/tools/(i2c\.go|spi\.go)$'
issues:
max-issues-per-linter: 0
@@ -160,12 +152,11 @@ issues:
formatters:
enable:
+ - gci
+ - gofmt
+ - gofumpt
- goimports
- # TODO: Disabled, because they are failing at the moment, we should fix them and enable (step by step)
- # - gci
- # - gofmt
- # - gofumpt
- # - golines
+ - golines
settings:
gci:
sections:
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 2c47f7d86..69bf1fae3 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -15,10 +15,10 @@ builds:
- stdjson
ldflags:
- -s -w
- - -X main.version={{ .Version }}
- - -X main.gitCommit={{ .ShortCommit }}
- - -X main.buildTime={{ .Date }}
- - -X main.goVersion={{ .Env.GOVERSION }}
+ - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.version={{ .Version }}
+ - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.gitCommit={{ .ShortCommit }}
+ - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.buildTime={{ .Date }}
+ - -X github.com/sipeed/picoclaw/cmd/picoclaw/internal.goVersion={{ .Env.GOVERSION }}
goos:
- linux
- windows
@@ -28,9 +28,10 @@ builds:
- amd64
- arm64
- riscv64
- - s390x
- - mips64
+ - loong64
- arm
+ goarm:
+ - "7"
main: ./cmd/picoclaw
ignore:
- goos: windows
@@ -67,6 +68,25 @@ archives:
- goos: windows
formats: [zip]
+nfpms:
+ - id: picoclaw
+ package_name: picoclaw
+ file_name_template: >-
+ {{ .PackageName }}_
+ {{- if eq .Arch "amd64" }}x86_64
+ {{- else if eq .Arch "arm64" }}aarch64
+ {{- else if eq .Arch "arm" }}armv{{ .Arm }}
+ {{- else }}{{ .Arch }}{{ end }}
+ vendor: picoclaw
+ homepage: https://github.com/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw
+ maintainer: picoclaw contributors
+ description: picoclaw - a tool for managing and running tasks
+ license: MIT
+ formats:
+ - rpm
+ - deb
+ bindir: /usr/bin
+
changelog:
sort: asc
filters:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..88227f493
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,302 @@
+# Contributing to PicoClaw
+
+Thank you for your interest in contributing to PicoClaw! This project is a community-driven effort to build the lightweight and versatile personal AI assistant. We welcome contributions of all kinds: bug fixes, features, documentation, translations, and testing.
+
+PicoClaw itself was substantially developed with AI assistance — we embrace this approach and have built our contribution process around it.
+
+## Table of Contents
+
+- [Code of Conduct](#code-of-conduct)
+- [Ways to Contribute](#ways-to-contribute)
+- [Getting Started](#getting-started)
+- [Development Setup](#development-setup)
+- [Making Changes](#making-changes)
+- [AI-Assisted Contributions](#ai-assisted-contributions)
+- [Pull Request Process](#pull-request-process)
+- [Branch Strategy](#branch-strategy)
+- [Code Review](#code-review)
+- [Communication](#communication)
+
+---
+
+## Code of Conduct
+
+We are committed to maintaining a welcoming and respectful community. Be kind, constructive, and assume good faith. Harassment or discrimination of any kind will not be tolerated.
+
+---
+
+## Ways to Contribute
+
+- **Bug reports** — Open an issue using the bug report template.
+- **Feature requests** — Open an issue using the feature request template; discuss before implementing.
+- **Code** — Fix bugs or implement features. See the workflow below.
+- **Documentation** — Improve READMEs, docs, inline comments, or translations.
+- **Testing** — Run PicoClaw on new hardware, channels, or LLM providers and report your results.
+
+For substantial new features, please open an issue first to discuss the design before writing code. This prevents wasted effort and ensures alignment with the project's direction.
+
+---
+
+## Getting Started
+
+1. **Fork** the repository on GitHub.
+2. **Clone** your fork locally:
+ ```bash
+ git clone https://github.com//picoclaw.git
+ cd picoclaw
+ ```
+3. Add the upstream remote:
+ ```bash
+ git remote add upstream https://github.com/sipeed/picoclaw.git
+ ```
+
+---
+
+## Development Setup
+
+### Prerequisites
+
+- Go 1.25 or later
+- `make`
+
+### Build
+
+```bash
+make build # Build binary (runs go generate first)
+make generate # Run go generate only
+make check # Full pre-commit check: deps + fmt + vet + test
+```
+
+### Running Tests
+
+```bash
+make test # Run all tests
+go test -run TestName -v ./pkg/session/ # Run a single test
+go test -bench=. -benchmem -run='^$' ./... # Run benchmarks
+```
+
+### Code Style
+
+```bash
+make fmt # Format code
+make vet # Static analysis
+make lint # Full linter run
+```
+
+All CI checks must pass before a PR can be merged. Run `make check` locally before pushing to catch issues early.
+
+---
+
+## Making Changes
+
+### Branching
+
+Always branch off `main` and target `main` in your PR. Never push directly to `main` or any `release/*` branch:
+
+```bash
+git checkout main
+git pull upstream main
+git checkout -b your-feature-branch
+```
+
+Use descriptive branch names, e.g. `fix/telegram-timeout`, `feat/ollama-provider`, `docs/contributing-guide`.
+
+### Commits
+
+- Write clear, concise commit messages in English.
+- Use the imperative mood: "Add retry logic" not "Added retry logic".
+- Reference the related issue when relevant: `Fix session leak (#123)`.
+- Keep commits focused. One logical change per commit is preferred.
+- For minor cleanups or typo fixes, squash them into a single commit before opening a PR.
+- Refer to https://www.conventionalcommits.org/zh-hans/v1.0.0/
+
+### Keeping Up to Date
+
+Rebase your branch onto upstream `main` before opening a PR:
+
+```bash
+git fetch upstream
+git rebase upstream/main
+```
+
+---
+
+## AI-Assisted Contributions
+
+PicoClaw was built with substantial AI assistance, and we fully embrace AI-assisted development. However, contributors must understand their responsibilities when using AI tools.
+
+### Disclosure Is Required
+
+Every PR must disclose AI involvement using the PR template's **🤖 AI Code Generation** section. There are three levels:
+
+| Level | Description |
+|---|---|
+| 🤖 Fully AI-generated | AI wrote the code; contributor reviewed and validated it |
+| 🛠️ Mostly AI-generated | AI produced the draft; contributor made significant modifications |
+| 👨💻 Mostly Human-written | Contributor led; AI provided suggestions or none at all |
+
+Honest disclosure is expected. There is no stigma attached to any level — what matters is the quality of the contribution.
+
+### You Are Responsible for What You Submit
+
+Using AI to generate code does not reduce your responsibility as the contributor. Before opening a PR with AI-generated code, you must:
+
+- **Read and understand** every line of the generated code.
+- **Test it** in a real environment (see the Test Environment section of the PR template).
+- **Check for security issues** — AI models can generate subtly insecure code (e.g., path traversal, injection, credential exposure). Review carefully.
+- **Verify correctness** — AI-generated logic can be plausible-sounding but wrong. Validate the behavior, not just the syntax.
+
+PRs where it is clear the contributor has not read or tested the AI-generated code will be closed without review.
+
+### AI-Generated Code Quality Standards
+
+AI-generated contributions are held to the **same quality bar** as human-written code:
+
+- It must pass all CI checks (`make check`).
+- It must be idiomatic Go and consistent with the existing codebase style.
+- It must not introduce unnecessary abstractions, dead code, or over-engineering.
+- It must include or update tests where appropriate.
+
+### Security Review
+
+AI-generated code requires extra security scrutiny. Pay special attention to:
+
+- File path handling and sandbox escapes (see commit `244eb0b` for a real example)
+- External input validation in channel handlers and tool implementations
+- Credential or secret handling
+- Command execution (`exec.Command`, shell invocations)
+
+If you are unsure whether a piece of AI-generated code is safe, say so in the PR — reviewers will help.
+
+---
+
+## Pull Request Process
+
+### Before Opening a PR
+
+- [ ] Run `make check` and ensure it passes locally.
+- [ ] Fill in the PR template completely, including the AI disclosure section.
+- [ ] Link any related issue(s) in the PR description.
+- [ ] Keep the PR focused. Avoid bundling unrelated changes together.
+
+### PR Template Sections
+
+The PR template asks for:
+
+- **Description** — What does this change do and why?
+- **Type of Change** — Bug fix, feature, docs, or refactor.
+- **AI Code Generation** — Disclosure of AI involvement (required).
+- **Related Issue** — Link to the issue this addresses.
+- **Technical Context** — Reference URLs and reasoning (skip for pure docs PRs).
+- **Test Environment** — Hardware, OS, model/provider, and channels used for testing.
+- **Evidence** — Optional logs or screenshots demonstrating the change works.
+- **Checklist** — Self-review confirmation.
+
+### PR Size
+
+Prefer small, reviewable PRs. A PR that changes 200 lines across 5 files is much easier to review than one that changes 2000 lines across 30 files. If your feature is large, consider splitting it into a series of smaller, logically complete PRs.
+
+---
+
+## Branch Strategy
+
+### Long-Lived Branches
+
+- **`main`** — the active development branch. All feature PRs target `main`. The branch is protected: direct pushes are not permitted, and at least one maintainer approval is required before merging.
+- **`release/x.y`** — stable release branches, cut from `main` when a version is ready to ship. These branches are more strictly protected than `main`.
+
+### Requirements to Merge into `main`
+
+A PR can only be merged when all of the following are satisfied:
+
+1. **CI passes** — All GitHub Actions workflows (lint, test, build) must be green.
+2. **Reviewer approval** — At least one maintainer has approved the PR.
+3. **No unresolved review comments** — All review threads must be resolved.
+4. **PR template is complete** — Including AI disclosure and test environment.
+
+### Who Can Merge
+
+Only maintainers can merge PRs. Contributors cannot merge their own PRs, even if they have write access.
+
+### Merge Strategy
+
+We use **squash merge** for most PRs to keep the `main` history clean and readable. Each merged PR becomes a single commit referencing the PR number, e.g.:
+
+```
+feat: Add Ollama provider support (#491)
+```
+
+If a PR consists of multiple independent, well-separated commits that tell a clear story, a regular merge may be used at the maintainer's discretion.
+
+### Release Branches
+
+When a version is ready, maintainers cut a `release/x.y` branch from `main`. After that point:
+
+- **New features are not backported.** The release branch receives no new functionality after it is cut.
+- **Security fixes and critical bug fixes are cherry-picked.** If a fix in `main` qualifies (security vulnerability, data loss, crash), maintainers will cherry-pick the relevant commit(s) onto the affected `release/x.y` branch and issue a patch release.
+
+If you believe a fix in `main` should be backported to a release branch, note it in the PR description or open a separate issue. The decision rests with the maintainers.
+
+Release branches have stricter protections than `main` and are never directly pushed to under any circumstances.
+
+---
+
+## Code Review
+
+### For Contributors
+
+- Respond to review comments within a reasonable time. If you need more time, say so.
+- When you update a PR in response to feedback, briefly note what changed (e.g., "Updated to use `sync.RWMutex` as suggested").
+- If you disagree with feedback, engage respectfully. Explain your reasoning; reviewers can be wrong too.
+- Do not force-push after a review has started — it makes it harder for reviewers to see what changed. Use additional commits instead; the maintainer will squash on merge.
+
+### For Reviewers
+
+Review for:
+
+1. **Correctness** — Does the code do what it claims? Are there edge cases?
+2. **Security** — Especially for AI-generated code, tool implementations, and channel handlers.
+3. **Architecture** — Is the approach consistent with the existing design?
+4. **Simplicity** — Is there a simpler solution? Does this add unnecessary complexity?
+5. **Tests** — Are the changes covered by tests? Are existing tests still meaningful?
+
+Be constructive and specific. "This could have a race condition if two goroutines call this concurrently — consider using a mutex here" is better than "this looks wrong".
+
+
+### Reviewer List
+Once your PR is submitted, you can reach out to the assigned reviewers listed in the following table.
+
+|Function| Reviewer|
+|--- |--- |
+|Provider|@yinwm |
+|Channel |@yinwm |
+|Agent |@lxowalle|
+|Tools |@lxowalle|
+|SKill ||
+|MCP ||
+|Optimization|@lxowalle|
+|Security||
+|AI CI |@imguoguo|
+|UX ||
+|Document||
+
+---
+
+## Communication
+
+- **GitHub Issues** — Bug reports, feature requests, design discussions.
+- **GitHub Discussions** — General questions, ideas, community conversation.
+- **Pull Request comments** — Code-specific feedback.
+- **Wechat&Discord** — We will invite you when you have at least one merged PR
+
+When in doubt, open an issue before writing code. It costs little and prevents wasted effort.
+
+---
+
+## A Note on the Project's AI-Driven Origin
+
+PicoClaw's architecture was substantially designed and implemented with AI assistance, guided by human oversight. If you find something that looks odd or over-engineered, it may be an artifact of that process — opening an issue to discuss it is always welcome.
+
+We believe AI-assisted development done responsibly produces great results. We also believe humans must remain accountable for what they ship. These two beliefs are not in conflict.
+
+Thank you for contributing!
diff --git a/CONTRIBUTING.zh.md b/CONTRIBUTING.zh.md
new file mode 100644
index 000000000..01a1abfd5
--- /dev/null
+++ b/CONTRIBUTING.zh.md
@@ -0,0 +1,303 @@
+# 参与贡献 PicoClaw
+
+感谢你对 PicoClaw 的关注!本项目是一个社区驱动的开源项目,目标是构建 轻量灵活,人人可用 的个人AI助手。我们欢迎一切形式的贡献:Bug 修复、新功能、文档、翻译和测试。
+
+PicoClaw 本身在很大程度上是借助 AI 辅助开发的——我们拥抱这种方式,并围绕它构建了贡献流程。
+
+## 目录
+
+- [行为准则](#行为准则)
+- [贡献方式](#贡献方式)
+- [快速开始](#快速开始)
+- [开发环境配置](#开发环境配置)
+- [提交修改](#提交修改)
+- [AI 辅助贡献](#ai-辅助贡献)
+- [Pull Request 流程](#pull-request-流程)
+- [分支策略](#分支策略)
+- [代码审查](#代码审查)
+- [沟通渠道](#沟通渠道)
+
+---
+
+## 行为准则
+
+我们致力于维护一个友好、互相尊重的社区环境。请保持善意、建设性的态度,并善意地理解他人。任何形式的骚扰或歧视均不被接受。
+
+---
+
+## 贡献方式
+
+- **Bug 反馈** — 使用 Bug 报告模板提交 Issue。
+- **功能建议** — 使用功能请求模板提交 Issue,建议在开始实现前先进行讨论。
+- **代码贡献** — 修复 Bug 或实现新功能,参见下方工作流程。
+- **文档改进** — 完善 README、文档、代码注释或翻译。
+- **测试与验证** — 在新硬件、新渠道或新 LLM 提供商上运行 PicoClaw 并反馈结果。
+
+对于较大的新功能,请先提交 Issue 讨论设计方案,再动手写代码。这能避免无效投入,也确保与项目方向保持一致。
+
+---
+
+## 快速开始
+
+1. 在 GitHub 上 **Fork** 本仓库。
+2. 将你的 Fork **克隆**到本地:
+ ```bash
+ git clone https://github.com/<你的用户名>/picoclaw.git
+ cd picoclaw
+ ```
+3. 添加上游远程仓库:
+ ```bash
+ git remote add upstream https://github.com/sipeed/picoclaw.git
+ ```
+
+---
+
+## 开发环境配置
+
+### 前置依赖
+
+- Go 1.25 或更高版本
+- `make`
+
+### 构建
+
+```bash
+make build # 构建二进制文件(会先执行 go generate)
+make generate # 仅执行 go generate
+make check # 完整的提交前检查:deps + fmt + vet + test
+```
+
+### 运行测试
+
+```bash
+make test # 运行所有测试
+go test -run TestName -v ./pkg/session/ # 运行单个测试
+go test -bench=. -benchmem -run='^$' ./... # 运行基准测试
+```
+
+### 代码风格
+
+```bash
+make fmt # 格式化代码
+make vet # 静态分析
+make lint # 完整的 lint 检查
+```
+
+所有 CI 检查通过后 PR 才能被合并。推送代码前请先在本地运行 `make check`,提前发现问题。
+
+---
+
+## 提交修改
+
+### 分支管理
+
+始终从 `main` 分支切出,并在 PR 中以 `main` 为目标分支。不要直接向 `main` 或任何 `release/*` 分支推送代码:
+
+```bash
+git checkout main
+git pull upstream main
+git checkout -b 你的功能分支名
+```
+
+请使用描述性的分支名,例如:`fix/telegram-timeout`、`feat/ollama-provider`、`docs/contributing-guide`。
+
+### Commit 规范
+
+- 使用英文撰写清晰、简洁的 commit 信息。
+- 使用祈使句:写 "Add retry logic",而不是 "Added retry logic"。
+- 有关联 Issue 时请引用:`Fix session leak (#123)`。
+- 保持 commit 专注,每个 commit 只做一件事。
+- 对于小的清理或拼写修正,提 PR 前请将其合并为一个 commit。
+- 按照 https://www.conventionalcommits.org/zh-hans/v1.0.0/ 规范来撰写
+
+### 保持与上游同步
+
+提 PR 前,请将你的分支变基到上游 `main`:
+
+```bash
+git fetch upstream
+git rebase upstream/main
+```
+
+---
+
+## AI 辅助贡献
+
+PicoClaw 在很大程度上借助 AI 辅助开发,我们完全拥抱这种开发方式。但贡献者必须清楚地了解自己在使用 AI 工具时所承担的责任。
+
+### 必须披露 AI 使用情况
+
+每个 PR 都必须通过 PR 模板中的 **🤖 AI 代码生成** 部分披露 AI 参与情况,共分三个级别:
+
+| 级别 | 说明 |
+|---|---|
+| 🤖 完全由 AI 生成 | AI 编写代码,贡献者负责审查和验证 |
+| 🛠️ 主要由 AI 生成 | AI 起草,贡献者做了较大修改 |
+| 👨💻 主要由人工编写 | 贡献者主导,AI 仅提供辅助或未使用 AI |
+
+我们期望你诚实填写。三种级别均可接受,没有任何歧视——重要的是贡献的质量。
+
+### 你对提交的代码负全责
+
+使用 AI 生成代码并不能减轻你作为贡献者的责任。在提交含有 AI 生成代码的 PR 之前,你必须:
+
+- **逐行阅读并理解**生成的代码。
+- **在真实环境中测试**(参见 PR 模板中的测试环境部分)。
+- **检查安全问题** — AI 模型可能生成存在安全隐患的代码(如路径穿越、注入攻击、凭据泄露等),请仔细审查。
+- **验证正确性** — AI 生成的逻辑可能听起来合理但实际上是错误的,请验证行为,而不仅仅是语法。
+
+如果明显可以看出贡献者没有阅读或测试 AI 生成的代码,该 PR 将被直接关闭,不予审查。
+
+### AI 生成代码的质量标准
+
+AI 生成的代码与人工编写的代码遵循**相同的质量要求**:
+
+- 必须通过所有 CI 检查(`make check`)。
+- 必须符合 Go 惯用写法,并与现有代码库的风格保持一致。
+- 不得引入不必要的抽象、死代码或过度设计。
+- 须在适当的地方包含或更新测试。
+
+### 安全审查
+
+AI 生成的代码需要格外仔细的安全审查。请特别关注以下方面:
+
+- 文件路径处理与沙箱逃逸(项目历史中的 commit `244eb0b` 就是真实案例)
+- channel 处理器和 tool 实现中的外部输入校验
+- 凭据或密钥的处理
+- 命令执行(`exec.Command`、shell 调用等)
+
+如果你不确定某段 AI 生成代码是否安全,请在 PR 中说明——审查者会帮助判断。
+
+---
+
+## Pull Request 流程
+
+### 提 PR 前的检查
+
+- [ ] 在本地运行 `make check` 并确认通过。
+- [ ] 完整填写 PR 模板,包括 AI 披露部分。
+- [ ] 在 PR 描述中关联相关 Issue。
+- [ ] 保持 PR 专注,避免将不相关的修改混在一起。
+
+### PR 模板各部分说明
+
+PR 模板要求填写:
+
+- **描述** — 这个改动做了什么,为什么要做?
+- **变更类型** — Bug 修复、新功能、文档或重构。
+- **AI 代码生成** — AI 参与情况披露(必填)。
+- **关联 Issue** — 此 PR 解决的 Issue 链接。
+- **技术背景** — 参考链接和设计理由(纯文档类 PR 可跳过)。
+- **测试环境** — 用于测试的硬件、操作系统、模型/提供商和渠道。
+- **验证证据** — 可选的日志或截图,用于证明改动有效。
+- **检查清单** — 自我审查确认。
+
+### PR 规模
+
+请尽量提交小而易于审查的 PR。一个涉及 5 个文件共 200 行改动的 PR,远比涉及 30 个文件共 2000 行改动的 PR 容易审查。如果你的功能较大,可以考虑将其拆分为一系列逻辑完整的小 PR。
+
+---
+
+## 分支策略
+
+### 长期分支
+
+- **`main`** — 活跃开发分支。所有功能 PR 均以 `main` 为目标。该分支受保护:禁止直接推送,合并前必须获得至少一名维护者的批准。
+- **`release/x.y`** — 稳定发布分支,在某个版本准备发布时从 `main` 切出。这些分支的保护级别高于 `main`。
+
+### 合并到 `main` 的前提条件
+
+PR 必须同时满足以下所有条件,才能被合并:
+
+1. **CI 全部通过** — 所有 GitHub Actions 工作流(lint、test、build)均为绿色。
+2. **获得审查者批准** — 至少一名维护者已批准该 PR。
+3. **无未解决的审查意见** — 所有审查讨论线程均已关闭。
+4. **PR 模板填写完整** — 包括 AI 披露和测试环境信息。
+
+### 谁可以合并
+
+只有维护者才能合并 PR。贡献者不能合并自己的 PR,即使拥有写权限也不行。
+
+### 合并策略
+
+为保持 `main` 历史清晰可读,我们对大多数 PR 使用 **Squash Merge**。每个合并的 PR 变为一个包含 PR 编号的单独 commit,例如:
+
+```
+feat: Add Ollama provider support (#491)
+```
+
+如果一个 PR 包含多个独立、结构清晰、能讲述完整故事的 commit,维护者可视情况使用普通 merge。
+
+### Release 分支
+
+当某个版本准备就绪时,维护者会从 `main` 切出 `release/x.y` 分支。此后:
+
+- **新功能不会被回溯(backport)。** Release 分支切出后,不再接收任何新功能。
+- **安全修复和关键 Bug 修复会被 cherry-pick 进来。** 若 `main` 上的某个修复属于安全漏洞、数据丢失或崩溃类问题,维护者会将相关 commit cherry-pick 到受影响的 `release/x.y` 分支,并发布补丁版本。
+
+如果你认为 `main` 上的某个修复应该被回溯到某个 release 分支,请在 PR 描述中注明,或单独开一个 Issue 说明。最终决定由维护者做出。
+
+Release 分支的保护级别高于 `main`,在任何情况下均不允许直接推送。
+
+---
+
+## 代码审查
+
+### 对贡献者的建议
+
+- 在合理时间内回复审查意见。如果需要更多时间,请告知。
+- 更新 PR 以响应反馈时,简要说明改动内容(例如:"按建议改用了 `sync.RWMutex`")。
+- 如果你不同意某条反馈,请礼貌地阐述你的理由——审查者也可能有判断失误的时候。
+- 审查开始后请不要 force push——这会让审查者难以追踪变化。请使用额外的 commit,维护者在合并时会进行 squash。
+
+### 对审查者的建议
+
+审查重点:
+
+1. **正确性** — 代码是否实现了其声称的功能?是否存在边界情况?
+2. **安全性** — 对 AI 生成代码、tool 实现和 channel 处理器尤其需要关注。
+3. **架构** — 实现方式是否与现有设计一致?
+4. **简洁性** — 是否有更简单的方案?是否引入了不必要的复杂度?
+5. **测试** — 改动是否有测试覆盖?现有测试是否仍然有意义?
+
+请给出建设性且具体的反馈。"如果两个 goroutine 同时调用这个函数可能会有竞态条件,建议在这里加一个 mutex" 远比 "这里看起来有问题" 更有帮助。
+
+### 审查者列表
+提交对应PR后,可以参考下表联系对应的审查人员沟通
+
+|Function| Reviewer|
+|--- |--- |
+|Provider|@yinwm |
+|Channel |@yinwm |
+|Agent |@lxowalle|
+|Tools |@lxowalle|
+|SKill ||
+|MCP ||
+|Optimization|@lxowalle|
+|Security||
+|AI CI |@imguoguo|
+|UX ||
+|Document||
+
+
+
+---
+
+## 沟通渠道
+
+- **GitHub Issues** — Bug 报告、功能建议、设计讨论。
+- **GitHub Discussions** — 一般性问题、想法交流、社区讨论。
+- **Pull Request 评论** — 与具体代码相关的反馈。
+- **Wechat&Discord** — 当你有至少一个已合并的PR后,我们会邀请你加入开发者交流群
+
+有疑问时,请先开 Issue 讨论,再动手写代码。这几乎没有成本,却能避免大量无效投入。
+
+---
+
+## 关于本项目的 AI 驱动起源
+
+PicoClaw 的架构在人工监督下,经由 AI 辅助完成了大量设计和实现工作。如果你发现某处看起来奇怪或过度设计,这可能是该过程留下的痕迹——欢迎提 Issue 讨论。
+
+我们相信,负责任地使用 AI 辅助开发能产生优秀的成果。我们同样相信,人类必须对自己提交的内容负责。这两点并不矛盾。
+
+感谢你的贡献!
diff --git a/Dockerfile b/Dockerfile
index 0360cfda6..480244127 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,7 +1,7 @@
# ============================================================
# Stage 1: Build the picoclaw binary
# ============================================================
-FROM golang:1.26.0-alpine AS builder
+FROM golang:1.25-alpine AS builder
RUN apk add --no-cache git make
diff --git a/Makefile b/Makefile
index ff280e3e4..c59c414f3 100644
--- a/Makefile
+++ b/Makefile
@@ -11,16 +11,21 @@ VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
-LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION) -s -w"
+INTERNAL=github.com/sipeed/picoclaw/cmd/picoclaw/internal
+LDFLAGS=-ldflags "-X $(INTERNAL).version=$(VERSION) -X $(INTERNAL).gitCommit=$(GIT_COMMIT) -X $(INTERNAL).buildTime=$(BUILD_TIME) -X $(INTERNAL).goVersion=$(GO_VERSION) -s -w"
# Go variables
-GO?=go
+GO?=CGO_ENABLED=0 go
GOFLAGS?=-v -tags stdjson
+# Golangci-lint
+GOLANGCI_LINT?=golangci-lint
+
# Installation
INSTALL_PREFIX?=$(HOME)/.local
INSTALL_BIN_DIR=$(INSTALL_PREFIX)/bin
INSTALL_MAN_DIR=$(INSTALL_PREFIX)/share/man/man1
+INSTALL_TMP_SUFFIX=.new
# Workspace and Skills
PICOCLAW_HOME?=$(HOME)/.picoclaw
@@ -88,6 +93,7 @@ build-all: generate
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
+ GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR)
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
@echo "All builds complete"
@@ -96,8 +102,10 @@ build-all: generate
install: build
@echo "Installing $(BINARY_NAME)..."
@mkdir -p $(INSTALL_BIN_DIR)
- @cp $(BUILD_DIR)/$(BINARY_NAME) $(INSTALL_BIN_DIR)/$(BINARY_NAME)
- @chmod +x $(INSTALL_BIN_DIR)/$(BINARY_NAME)
+ # Copy binary with temporary suffix to ensure atomic update
+ @cp $(BUILD_DIR)/$(BINARY_NAME) $(INSTALL_BIN_DIR)/$(BINARY_NAME)$(INSTALL_TMP_SUFFIX)
+ @chmod +x $(INSTALL_BIN_DIR)/$(BINARY_NAME)$(INSTALL_TMP_SUFFIX)
+ @mv -f $(INSTALL_BIN_DIR)/$(BINARY_NAME)$(INSTALL_TMP_SUFFIX) $(INSTALL_BIN_DIR)/$(BINARY_NAME)
@echo "Installed binary to $(INSTALL_BIN_DIR)/$(BINARY_NAME)"
@echo "Installation complete!"
@@ -126,13 +134,21 @@ clean:
vet:
@$(GO) vet ./...
-## fmt: Format Go code
+## test: Test Go code
test:
@$(GO) test ./...
## fmt: Format Go code
fmt:
- @$(GO) fmt ./...
+ @$(GOLANGCI_LINT) fmt
+
+## lint: Run linters
+lint:
+ @$(GOLANGCI_LINT) run
+
+## fix: Fix linting issues
+fix:
+ @$(GOLANGCI_LINT) run --fix
## deps: Download dependencies
deps:
@@ -159,7 +175,7 @@ help:
@echo " make [target]"
@echo ""
@echo "Targets:"
- @grep -E '^## ' $(MAKEFILE_LIST) | sed 's/## / /'
+ @grep -E '^## ' $(MAKEFILE_LIST) | sort | awk -F': ' '{printf " %-16s %s\n", substr($$1, 4), $$2}'
@echo ""
@echo "Examples:"
@echo " make build # Build for current platform"
diff --git a/README.fr.md b/README.fr.md
index 21913f6ba..f59807739 100644
--- a/README.fr.md
+++ b/README.fr.md
@@ -50,7 +50,7 @@
## 📢 Actualités
-2026-02-16 🎉 PicoClaw a atteint 12K étoiles en une semaine ! Merci à tous pour votre soutien ! PicoClaw grandit plus vite que nous ne l'avions jamais imaginé. Vu le volume élevé de PR, nous avons un besoin urgent de mainteneurs communautaires. Nos rôles de bénévoles et notre feuille de route sont officiellement publiés [ici](docs/picoclaw_community_roadmap_260216.md) — nous avons hâte de vous accueillir !
+2026-02-16 🎉 PicoClaw a atteint 12K étoiles en une semaine ! Merci à tous pour votre soutien ! PicoClaw grandit plus vite que nous ne l'avions jamais imaginé. Vu le volume élevé de PR, nous avons un besoin urgent de mainteneurs communautaires. Nos rôles de bénévoles et notre feuille de route sont officiellement publiés [ici](docs/ROADMAP.md) — nous avons hâte de vous accueillir !
2026-02-13 🎉 PicoClaw a atteint 5000 étoiles en 4 jours ! Merci à la communauté ! Nous finalisons la **Feuille de Route du Projet** et mettons en place le **Groupe de Développeurs** pour accélérer le développement de PicoClaw.
🚀 **Appel à l'action :** Soumettez vos demandes de fonctionnalités dans les GitHub Discussions. Nous les examinerons et les prioriserons lors de notre prochaine réunion hebdomadaire.
@@ -171,6 +171,10 @@ vim config/config.json # Configurez DISCORD_BOT_TOKEN, clés API, etc.
# 3. Compiler & Démarrer
docker compose --profile gateway up -d
+> [!TIP]
+> **Utilisateurs Docker** : Par défaut, le Gateway écoute sur `127.0.0.1`, ce qui n'est pas accessible depuis l'hôte. Si vous avez besoin d'accéder aux endpoints de santé ou d'exposer des ports, définissez `PICOCLAW_GATEWAY_HOST=0.0.0.0` dans votre environnement ou mettez à jour `config.json`.
+
+
# 4. Voir les logs
docker compose logs -f picoclaw-gateway
@@ -212,19 +216,24 @@ picoclaw onboard
```json
{
+ "model_list": [
+ {
+ "model_name": "gpt4",
+ "model": "openai/gpt-5.2",
+ "api_key": "sk-your-openai-key",
+ "api_base": "https://api.openai.com/v1"
+ }
+ ],
"agents": {
"defaults": {
- "workspace": "~/.picoclaw/workspace",
- "model": "glm-4.7",
- "max_tokens": 8192,
- "temperature": 0.7,
- "max_tool_iterations": 20
+ "model_name": "gpt4"
}
},
- "providers": {
- "openrouter": {
- "api_key": "xxx",
- "api_base": "https://openrouter.ai/api/v1"
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "VOTRE_TOKEN_BOT",
+ "allow_from": ["VOTRE_USER_ID"]
}
},
"tools": {
@@ -262,7 +271,7 @@ Et voilà ! Vous avez un assistant IA fonctionnel en 2 minutes.
## 💬 Applications de Chat
-Discutez avec votre PicoClaw via Telegram, Discord, DingTalk ou LINE
+Discutez avec votre PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom
| Canal | Configuration |
| ------------ | -------------------------------------- |
@@ -271,6 +280,7 @@ Discutez avec votre PicoClaw via Telegram, Discord, DingTalk ou LINE
| **QQ** | Facile (AppID + AppSecret) |
| **DingTalk** | Moyen (identifiants de l'application) |
| **LINE** | Moyen (identifiants + URL de webhook) |
+| **WeCom** | Moyen (CorpID + configuration webhook) |
Telegram (Recommandé)
@@ -289,7 +299,7 @@ Discutez avec votre PicoClaw via Telegram, Discord, DingTalk ou LINE
"telegram": {
"enabled": true,
"token": "VOTRE_TOKEN_BOT",
- "allowFrom": ["VOTRE_USER_ID"]
+ "allow_from": ["VOTRE_USER_ID"]
}
}
}
@@ -332,7 +342,7 @@ picoclaw gateway
"discord": {
"enabled": true,
"token": "VOTRE_TOKEN_BOT",
- "allowFrom": ["VOTRE_USER_ID"]
+ "allow_from": ["VOTRE_USER_ID"]
}
}
}
@@ -470,6 +480,87 @@ picoclaw gateway
+
+WeCom (WeChat Work)
+
+PicoClaw prend en charge deux types d'intégration WeCom :
+
+**Option 1 : WeCom Bot (Robot Intelligent)** - Configuration plus facile, prend en charge les discussions de groupe
+**Option 2 : WeCom App (Application Personnalisée)** - Plus de fonctionnalités, messagerie proactive
+
+Voir le [Guide de Configuration WeCom App](docs/wecom-app-configuration.md) pour des instructions détaillées.
+
+**Configuration Rapide - WeCom Bot :**
+
+**1. Créer un bot**
+
+* Accédez à la Console d'Administration WeCom → Discussion de Groupe → Ajouter un Bot de Groupe
+* Copiez l'URL du webhook (format : `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`)
+
+**2. Configurer**
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**Configuration Rapide - WeCom App :**
+
+**1. Créer une application**
+
+* Accédez à la Console d'Administration WeCom → Gestion des Applications → Créer une Application
+* Copiez l'**AgentId** et le **Secret**
+* Accédez à la page "Mon Entreprise", copiez le **CorpID**
+
+**2. Configurer la réception des messages**
+
+* Dans les détails de l'application, cliquez sur "Recevoir les Messages" → "Configurer l'API"
+* Définissez l'URL sur `http://your-server:18792/webhook/wecom-app`
+* Générez le **Token** et l'**EncodingAESKey**
+
+**3. Configurer**
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**4. Lancer**
+
+```bash
+picoclaw gateway
+```
+
+> **Note** : WeCom App nécessite l'ouverture du port 18792 pour les callbacks webhook. Utilisez un proxy inverse pour HTTPS en production.
+
+
+
##
Rejoignez le Réseau Social d'Agents
Connectez PicoClaw au Réseau Social d'Agents simplement en envoyant un seul message via le CLI ou n'importe quelle application de chat intégrée.
@@ -683,6 +774,8 @@ Le sous-agent a accès aux outils (message, web_search, etc.) et peut communique
| `anthropic` (À tester) | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `openai` (À tester) | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` (À tester) | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
+| `qwen` | LLM (Alibaba Qwen) | [dashscope.aliyuncs.com](https://dashscope.aliyuncs.com/compatible-mode/v1) |
+| `cerebras` | LLM (Cerebras) | [cerebras.ai](https://api.cerebras.ai/v1) |
| `groq` | LLM + **Transcription vocale** (Whisper) | [console.groq.com](https://console.groq.com) |
@@ -1005,7 +1098,7 @@ Ajoutez la clé dans `~/.picoclaw/config.json` si vous utilisez Brave :
"tools": {
"web": {
"brave": {
- "enabled": true,
+ "enabled": false,
"api_key": "VOTRE_CLE_API_BRAVE",
"max_results": 5
},
diff --git a/README.ja.md b/README.ja.md
index c0e40883d..5a7bb8542 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -133,6 +133,10 @@ vim config/config.json # DISCORD_BOT_TOKEN, プロバイダーの API キ
# 3. ビルドと起動
docker compose --profile gateway up -d
+> [!TIP]
+> **Docker ユーザー**: デフォルトでは、Gateway は `127.0.0.1` でリッスンしており、ホストからアクセスできません。ヘルスチェックエンドポイントにアクセスしたり、ポートを公開したりする必要がある場合は、環境変数で `PICOCLAW_GATEWAY_HOST=0.0.0.0` を設定するか、`config.json` を更新してください。
+
+
# 4. ログ確認
docker compose logs -f picoclaw-gateway
@@ -162,7 +166,7 @@ docker compose --profile gateway up -d
> [!TIP]
> `~/.picoclaw/config.json` に API キーを設定してください。
> API キーの取得先: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
-> Web 検索は **任意** です - 無料の [Brave Search API](https://brave.com/search/api) (月 2000 クエリ無料)
+> Web 検索は **任意** です - 無料の [Tavily API](https://tavily.com) (月 1000 クエリ無料) または [Brave Search API](https://brave.com/search/api) (月 2000 クエリ無料)
**1. 初期化**
@@ -174,19 +178,24 @@ picoclaw onboard
```json
{
+ "model_list": [
+ {
+ "model_name": "gpt4",
+ "model": "openai/gpt-5.2",
+ "api_key": "sk-your-openai-key",
+ "api_base": "https://api.openai.com/v1"
+ }
+ ],
"agents": {
"defaults": {
- "workspace": "~/.picoclaw/workspace",
- "model": "glm-4.7",
- "max_tokens": 8192,
- "temperature": 0.7,
- "max_tool_iterations": 20
+ "model_name": "gpt4"
}
},
- "providers": {
- "openrouter": {
- "api_key": "xxx",
- "api_base": "https://openrouter.ai/api/v1"
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_TELEGRAM_BOT_TOKEN",
+ "allow_from": []
}
},
"tools": {
@@ -194,6 +203,11 @@ picoclaw onboard
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
+ },
+ "tavily": {
+ "enabled": false,
+ "api_key": "YOUR_TAVILY_API_KEY",
+ "max_results": 5
}
},
"cron": {
@@ -209,12 +223,12 @@ picoclaw onboard
**3. API キーの取得**
-- **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) · [Qwen](https://dashscope.console.aliyun.com)
-- **Web 検索**(任意): [Brave Search](https://brave.com/search/api) - 無料枠あり(月 2000 リクエスト)
+- **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
+- **Web 検索**(任意): [Tavily](https://tavily.com) - AI エージェント向けに最適化 (月 1000 リクエスト) · [Brave Search](https://brave.com/search/api) - 無料枠あり(月 2000 リクエスト)
> **注意**: 完全な設定テンプレートは `config.example.json` を参照してください。
-**3. チャット**
+**4. チャット**
```bash
picoclaw agent -m "What is 2+2?"
@@ -226,7 +240,7 @@ picoclaw agent -m "What is 2+2?"
## 💬 チャットアプリ
-Telegram、Discord、QQ、DingTalk、LINE で PicoClaw と会話できます
+Telegram、Discord、QQ、DingTalk、LINE、WeCom で PicoClaw と会話できます
| チャネル | セットアップ |
|---------|------------|
@@ -235,6 +249,7 @@ Telegram、Discord、QQ、DingTalk、LINE で PicoClaw と会話できます
| **QQ** | 簡単(AppID + AppSecret) |
| **DingTalk** | 普通(アプリ認証情報) |
| **LINE** | 普通(認証情報 + Webhook URL) |
+| **WeCom** | 普通(CorpID + Webhook設定) |
Telegram(推奨)
@@ -430,6 +445,87 @@ picoclaw gateway
+
+WeCom (企業微信)
+
+PicoClaw は2種類の WeCom 統合をサポートしています:
+
+**オプション1: WeCom Bot (智能ロボット)** - 簡単な設定、グループチャット対応
+**オプション2: WeCom App (自作アプリ)** - より多機能、アクティブメッセージング対応
+
+詳細な設定手順は [WeCom App Configuration Guide](docs/wecom-app-configuration.md) を参照してください。
+
+**クイックセットアップ - WeCom Bot:**
+
+**1. ボットを作成**
+
+* WeCom 管理コンソール → グループチャット → グループボットを追加
+* Webhook URL をコピー(形式: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`)
+
+**2. 設定**
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**クイックセットアップ - WeCom App:**
+
+**1. アプリを作成**
+
+* WeCom 管理コンソール → アプリ管理 → アプリを作成
+* **AgentId** と **Secret** をコピー
+* "マイ会社" ページで **CorpID** をコピー
+
+**2. メッセージ受信を設定**
+
+* アプリ詳細で "メッセージを受信" → "APIを設定" をクリック
+* URL を `http://your-server:18792/webhook/wecom-app` に設定
+* **Token** と **EncodingAESKey** を生成
+
+**3. 設定**
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**4. 起動**
+
+```bash
+picoclaw gateway
+```
+
+> **注意**: WeCom App は Webhook コールバック用にポート 18792 を開放する必要があります。本番環境では HTTPS 用のリバースプロキシを使用してください。
+
+
+
## ⚙️ 設定
設定ファイル: `~/.picoclaw/config.json`
@@ -682,10 +778,10 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
},
"providers": {
"openrouter": {
- "apiKey": "sk-or-v1-xxx"
+ "api_key": "sk-or-v1-xxx"
},
"groq": {
- "apiKey": "gsk_xxx"
+ "api_key": "gsk_xxx"
}
},
"channels": {
@@ -704,17 +800,17 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
},
"feishu": {
"enabled": false,
- "appId": "cli_xxx",
- "appSecret": "xxx",
- "encryptKey": "",
- "verificationToken": "",
+ "app_id": "cli_xxx",
+ "app_secret": "xxx",
+ "encrypt_key": "",
+ "verification_token": "",
"allow_from": []
}
},
"tools": {
"web": {
"search": {
- "apiKey": "BSA..."
+ "api_key": "BSA..."
}
},
"cron": {
@@ -913,15 +1009,20 @@ Discord: https://discord.gg/V4sAZ9XWpN
検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。
Web 検索を有効にするには:
-1. [https://brave.com/search/api](https://brave.com/search/api) で無料の API キーを取得(月 2000 クエリ無料)
+1. [https://tavily.com](https://tavily.com) (月 1000 クエリ無料) または [https://brave.com/search/api](https://brave.com/search/api) で無料の API キーを取得(月 2000 クエリ無料)
2. `~/.picoclaw/config.json` に追加:
```json
{
"tools": {
"web": {
- "search": {
+ "brave": {
+ "enabled": true,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
+ },
+ "duckduckgo": {
+ "enabled": true,
+ "max_results": 5
}
}
}
@@ -946,5 +1047,6 @@ Web 検索を有効にするには:
| **Zhipu** | 月 200K トークン | 中国ユーザー向け最適 |
| **Qwen** | 無料枠あり | 通義千問 (Qwen) |
| **Brave Search** | 月 2000 クエリ | Web 検索機能 |
+| **Tavily** | 月 1000 クエリ | AI エージェント検索最適化 |
| **Groq** | 無料枠あり | 高速推論(Llama, Mixtral) |
| **Cerebras** | 無料枠あり | 高速推論(Llama, Qwen など) |
diff --git a/README.md b/README.md
index 1f7200d15..4584d9a16 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,13 @@
+
+
+
- [中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | **English**
+[中文](README.zh.md) | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | **English**
+
---
@@ -42,16 +46,17 @@
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
>
> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**.
+>
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (10–20MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state.
-
## 📢 News
-2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/picoclaw_community_roadmap_260216.md) —we can’t wait to have you on board!
-2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
+2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/ROADMAP.md) —we can’t wait to have you on board!
+
+2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs & issues coming in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClaw,Let's Go!
@@ -100,9 +105,12 @@
### 📱 Run on old Android Phones
+
Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw. Quick Start:
+
1. **Install Termux** (Available on F-Droid or Google Play).
2. **Execute cmds**
+
```bash
# Note: Replace v0.1.1 with the latest version from the Releases page
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
@@ -110,6 +118,7 @@ chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
+
And then follow the instructions in the "Quick Start" section to complete the configuration!
@@ -165,6 +174,10 @@ vim config/config.json # Set DISCORD_BOT_TOKEN, API keys, etc.
# 3. Build & Start
docker compose --profile gateway up -d
+> [!TIP]
+> **Docker Users**: By default, the Gateway listens on `127.0.0.1` which is not accessible from the host. If you need to access the health endpoints or expose ports, set `PICOCLAW_GATEWAY_HOST=0.0.0.0` in your environment or update `config.json`.
+
+
# 4. Check logs
docker compose logs -f picoclaw-gateway
@@ -194,7 +207,7 @@ docker compose --profile gateway up -d
> [!TIP]
> Set your API key in `~/.picoclaw/config.json`.
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
-> Web search is **optional** - [Brave Search API](https://brave.com/search/api) ($5/1000 queries, ~$5-6/month), [SearXNG](https://github.com/searxng/searxng) (free, self-hosted), or use built-in DuckDuckGo fallback.
+> Web Search is **optional** - get free [Tavily API](https://tavily.com) (1000 free queries/month), [SearXNG](https://github.com/searxng/searxng) (free, self-hosted) or [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
**1. Initialize**
@@ -209,7 +222,7 @@ picoclaw onboard
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
- "model": "gpt4",
+ "model_name": "gpt4",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
@@ -234,6 +247,11 @@ picoclaw onboard
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
+ "tavily": {
+ "enabled": false,
+ "api_key": "YOUR_TAVILY_API_KEY",
+ "max_results": 5
+ },
"duckduckgo": {
"enabled": true,
"max_results": 5
@@ -253,7 +271,7 @@ picoclaw onboard
}
```
-> **New**: The `model_list` configuration format allows zero-code provider addition. See [Model Configuration](#-model-configuration) for details.
+> **New**: The `model_list` configuration format allows zero-code provider addition. See [Model Configuration](#model-configuration-model_list) for details.
**3. Get API Keys**
@@ -262,6 +280,7 @@ picoclaw onboard
* [Brave Search](https://brave.com/search/api) - Paid ($5/1000 queries, ~$5-6/month)
* [Perplexity](https://www.perplexity.ai) - AI-powered search with chat interface
* [SearXNG](https://github.com/searxng/searxng) - Self-hosted metasearch engine (free, no API key needed)
+ * [Tavily](https://tavily.com) - Optimized for AI Agents (1000 requests/month)
* DuckDuckGo - Built-in fallback (no API key required)
> **Note**: See `config.example.json` for a complete configuration template.
@@ -278,7 +297,7 @@ That's it! You have a working AI assistant in 2 minutes.
## 💬 Chat Apps
-Talk to your picoclaw through Telegram, Discord, DingTalk, or LINE
+Talk to your picoclaw through Telegram, Discord, DingTalk, LINE, or WeCom
| Channel | Setup |
| ------------ | ---------------------------------- |
@@ -287,6 +306,7 @@ Talk to your picoclaw through Telegram, Discord, DingTalk, or LINE
| **QQ** | Easy (AppID + AppSecret) |
| **DingTalk** | Medium (app credentials) |
| **LINE** | Medium (credentials + webhook URL) |
+| **WeCom** | Medium (CorpID + webhook setup) |
Telegram (Recommended)
@@ -336,7 +356,6 @@ picoclaw gateway
* (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
**3. Get your User ID**
-
* Discord Settings → Advanced → enable **Developer Mode**
* Right-click your avatar → **Copy User ID**
@@ -348,7 +367,8 @@ picoclaw gateway
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
- "allow_from": ["YOUR_USER_ID"]
+ "allow_from": ["YOUR_USER_ID"],
+ "mention_only": false
}
}
}
@@ -361,6 +381,10 @@ picoclaw gateway
* Bot Permissions: `Send Messages`, `Read Message History`
* Open the generated invite URL and add the bot to your server
+**Optional: Mention-only mode**
+
+Set `"mention_only": true` to make the bot respond only when @-mentioned. Useful for shared servers where you want the bot to respond only when explicitly called.
+
**6. Run**
```bash
@@ -426,14 +450,13 @@ picoclaw gateway
}
```
-> Set `allow_from` to empty to allow all users, or specify QQ numbers to restrict access.
+> Set `allow_from` to empty to allow all users, or specify DingTalk user IDs to restrict access.
**3. Run**
```bash
picoclaw gateway
```
-
@@ -486,6 +509,86 @@ picoclaw gateway
+
+WeCom (企业微信)
+
+PicoClaw supports two types of WeCom integration:
+
+**Option 1: WeCom Bot (智能机器人)** - Easier setup, supports group chats
+**Option 2: WeCom App (自建应用)** - More features, proactive messaging
+
+See [WeCom App Configuration Guide](docs/wecom-app-configuration.md) for detailed setup instructions.
+
+**Quick Setup - WeCom Bot:**
+
+**1. Create a bot**
+
+* Go to WeCom Admin Console → Group Chat → Add Group Bot
+* Copy the webhook URL (format: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`)
+
+**2. Configure**
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**Quick Setup - WeCom App:**
+
+**1. Create an app**
+
+* Go to WeCom Admin Console → App Management → Create App
+* Copy **AgentId** and **Secret**
+* Go to "My Company" page, copy **CorpID**
+**2. Configure receive message**
+
+* In App details, click "Receive Message" → "Set API"
+* Set URL to `http://your-server:18792/webhook/wecom-app`
+* Generate **Token** and **EncodingAESKey**
+
+**3. Configure**
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**4. Run**
+
+```bash
+picoclaw gateway
+```
+
+> **Note**: WeCom App requires opening port 18792 for webhook callbacks. Use a reverse proxy for HTTPS.
+
+
+
##
Join the Agent Social Network
Connect Picoclaw to the Agent Social Network simply by sending a single message via the CLI or any integrated Chat App.
@@ -532,23 +635,23 @@ PicoClaw runs in a sandboxed environment by default. The agent can only access f
}
```
-| Option | Default | Description |
-|--------|---------|-------------|
-| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent |
-| `restrict_to_workspace` | `true` | Restrict file/command access to workspace |
+| Option | Default | Description |
+| ----------------------- | ----------------------- | ----------------------------------------- |
+| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent |
+| `restrict_to_workspace` | `true` | Restrict file/command access to workspace |
#### Protected Tools
When `restrict_to_workspace: true`, the following tools are sandboxed:
-| Tool | Function | Restriction |
-|------|----------|-------------|
-| `read_file` | Read files | Only files within workspace |
-| `write_file` | Write files | Only files within workspace |
-| `list_dir` | List directories | Only directories within workspace |
-| `edit_file` | Edit files | Only files within workspace |
-| `append_file` | Append to files | Only files within workspace |
-| `exec` | Execute commands | Command paths must be within workspace |
+| Tool | Function | Restriction |
+| ------------- | ---------------- | -------------------------------------- |
+| `read_file` | Read files | Only files within workspace |
+| `write_file` | Write files | Only files within workspace |
+| `list_dir` | List directories | Only directories within workspace |
+| `edit_file` | Edit files | Only files within workspace |
+| `append_file` | Append to files | Only files within workspace |
+| `exec` | Execute commands | Command paths must be within workspace |
#### Additional Exec Protection
@@ -601,11 +704,11 @@ export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
The `restrict_to_workspace` setting applies consistently across all execution paths:
-| Execution Path | Security Boundary |
-|----------------|-------------------|
-| Main Agent | `restrict_to_workspace` ✅ |
+| Execution Path | Security Boundary |
+| ---------------- | ---------------------------- |
+| Main Agent | `restrict_to_workspace` ✅ |
| Subagent / Spawn | Inherits same restriction ✅ |
-| Heartbeat tasks | Inherits same restriction ✅ |
+| Heartbeat tasks | Inherits same restriction ✅ |
All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks.
@@ -631,21 +734,23 @@ For long-running tasks (web search, API calls), use the `spawn` tool to create a
# Periodic Tasks
## Quick Tasks (respond directly)
+
- Report current time
## Long Tasks (use spawn for async)
+
- Search the web for AI news and summarize
- Check email and report important messages
```
**Key behaviors:**
-| Feature | Description |
-|---------|-------------|
-| **spawn** | Creates async subagent, doesn't block heartbeat |
-| **Independent context** | Subagent has its own context, no session history |
-| **message tool** | Subagent communicates with user directly via message tool |
-| **Non-blocking** | After spawning, heartbeat continues to next task |
+| Feature | Description |
+| ----------------------- | --------------------------------------------------------- |
+| **spawn** | Creates async subagent, doesn't block heartbeat |
+| **Independent context** | Subagent has its own context, no session history |
+| **message tool** | Subagent communicates with user directly via message tool |
+| **Non-blocking** | After spawning, heartbeat continues to next task |
#### How Subagent Communication Works
@@ -676,10 +781,10 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
}
```
-| Option | Default | Description |
-|--------|---------|-------------|
-| `enabled` | `true` | Enable/disable heartbeat |
-| `interval` | `30` | Check interval in minutes (min: 5) |
+| Option | Default | Description |
+| ---------- | ------- | ---------------------------------- |
+| `enabled` | `true` | Enable/disable heartbeat |
+| `interval` | `30` | Check interval in minutes (min: 5) |
**Environment variables:**
@@ -691,17 +796,17 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
> [!NOTE]
> Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
-| Provider | Purpose | Get API Key |
-| -------------------------- | --------------------------------------- | ------------------------------------------------------ |
-| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
-| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) |
-| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
-| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
-| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
-| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
+| Provider | Purpose | Get API Key |
+| -------------------------- | --------------------------------------- | -------------------------------------------------------------------- |
+| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
+| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](https://bigmodel.cn) |
+| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
+| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
+| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
+| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `qwen` | LLM (Qwen direct) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
-| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
-| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) |
+| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
+| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) |
### Model Configuration (model_list)
@@ -716,25 +821,25 @@ This design also enables **multi-agent support** with flexible provider selectio
#### 📋 All Supported Vendors
-| Vendor | `model` Prefix | Default API Base | Protocol | API Key |
-|--------|----------------|------------------|----------|---------|
-| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) |
-| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) |
-| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
-| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [Get Key](https://platform.deepseek.com) |
-| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [Get Key](https://aistudio.google.com/api-keys) |
-| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [Get Key](https://console.groq.com) |
-| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [Get Key](https://platform.moonshot.cn) |
-| **通义千问 (Qwen)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [Get Key](https://dashscope.console.aliyun.com) |
-| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) |
-| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) |
-| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) |
-| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local |
-| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) |
-| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) |
-| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
-| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
-| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+| Vendor | `model` Prefix | Default API Base | Protocol | API Key |
+| ------------------- | ----------------- | --------------------------------------------------- | --------- | ---------------------------------------------------------------- |
+| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [Get Key](https://platform.openai.com) |
+| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [Get Key](https://console.anthropic.com) |
+| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [Get Key](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
+| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [Get Key](https://platform.deepseek.com) |
+| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [Get Key](https://aistudio.google.com/api-keys) |
+| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [Get Key](https://console.groq.com) |
+| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [Get Key](https://platform.moonshot.cn) |
+| **通义千问 (Qwen)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [Get Key](https://dashscope.console.aliyun.com) |
+| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [Get Key](https://build.nvidia.com) |
+| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | Local (no key needed) |
+| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [Get Key](https://openrouter.ai/keys) |
+| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | Local |
+| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [Get Key](https://cerebras.ai) |
+| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [Get Key](https://console.volcengine.com) |
+| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
+| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
+| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
#### Basic Configuration
@@ -768,6 +873,7 @@ This design also enables **multi-agent support** with flexible provider selectio
#### Vendor-Specific Examples
**OpenAI**
+
```json
{
"model_name": "gpt-5.2",
@@ -777,6 +883,7 @@ This design also enables **multi-agent support** with flexible provider selectio
```
**智谱 AI (GLM)**
+
```json
{
"model_name": "glm-4.7",
@@ -786,6 +893,7 @@ This design also enables **multi-agent support** with flexible provider selectio
```
**DeepSeek**
+
```json
{
"model_name": "deepseek-chat",
@@ -794,17 +902,20 @@ This design also enables **multi-agent support** with flexible provider selectio
}
```
-**Anthropic (with OAuth)**
+**Anthropic (with API key)**
+
```json
{
"model_name": "claude-sonnet-4.6",
"model": "anthropic/claude-sonnet-4.6",
- "auth_method": "oauth"
+ "api_key": "sk-ant-your-key"
}
```
-> Run `picoclaw auth login --provider anthropic` to set up OAuth credentials.
+
+> Run `picoclaw auth login --provider anthropic` to paste your API token.
**Ollama (local)**
+
```json
{
"model_name": "llama3",
@@ -813,6 +924,7 @@ This design also enables **multi-agent support** with flexible provider selectio
```
**Custom Proxy/API**
+
```json
{
"model_name": "my-custom-model",
@@ -850,6 +962,7 @@ Configure multiple endpoints for the same model name—PicoClaw will automatical
The old `providers` configuration is **deprecated** but still supported for backward compatibility.
**Old Config (deprecated):**
+
```json
{
"providers": {
@@ -868,6 +981,7 @@ The old `providers` configuration is **deprecated** but still supported for back
```
**New Config (recommended):**
+
```json
{
"model_list": [
@@ -1042,19 +1156,19 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
PRs welcome! The codebase is intentionally small and readable. 🤗
-Roadmap coming soon...
+See our full [Community Roadmap](https://github.com/sipeed/picoclaw/blob/main/ROADMAP.md).
-Developer group building, Entry Requirement: At least 1 Merged PR.
+Developer group building, join after your first merged PR!
User Groups:
-discord:
+discord:
## 🐛 Troubleshooting
-### Web search says "API 配置问题"
+### Web search says "API key configuration issue"
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
diff --git a/README.pt-br.md b/README.pt-br.md
index 44f27813c..0115b7f89 100644
--- a/README.pt-br.md
+++ b/README.pt-br.md
@@ -50,7 +50,7 @@
## 📢 Novidades
-2026-02-16 🎉 PicoClaw atingiu 12K stars em uma semana! Obrigado a todos pelo apoio! O PicoClaw está crescendo mais rápido do que jamais imaginamos. Dado o alto volume de PRs, precisamos urgentemente de maintainers da comunidade. Nossos papéis de voluntários e roadmap foram publicados oficialmente [aqui](docs/picoclaw_community_roadmap_260216.md) — estamos ansiosos para ter você a bordo!
+2026-02-16 🎉 PicoClaw atingiu 12K stars em uma semana! Obrigado a todos pelo apoio! O PicoClaw está crescendo mais rápido do que jamais imaginamos. Dado o alto volume de PRs, precisamos urgentemente de maintainers da comunidade. Nossos papéis de voluntários e roadmap foram publicados oficialmente [aqui](docs/ROADMAP.md) — estamos ansiosos para ter você a bordo!
2026-02-13 🎉 PicoClaw atingiu 5000 stars em 4 dias! Obrigado à comunidade! Estamos finalizando o **Roadmap do Projeto** e configurando o **Grupo de Desenvolvedores** para acelerar o desenvolvimento do PicoClaw.
@@ -172,6 +172,10 @@ vim config/config.json # Configure DISCORD_BOT_TOKEN, API keys, etc.
# 3. Build & Iniciar
docker compose --profile gateway up -d
+> [!TIP]
+> **Usuários Docker**: Por padrão, o Gateway ouve em `127.0.0.1`, o que não é acessível a partir do host. Se você precisar acessar os endpoints de integridade ou expor portas, defina `PICOCLAW_GATEWAY_HOST=0.0.0.0` em seu ambiente ou atualize o `config.json`.
+
+
# 4. Ver logs
docker compose logs -f picoclaw-gateway
@@ -213,19 +217,17 @@ picoclaw onboard
```json
{
+ "model_list": [
+ {
+ "model_name": "gpt4",
+ "model": "openai/gpt-5.2",
+ "api_key": "sk-your-openai-key",
+ "api_base": "https://api.openai.com/v1"
+ }
+ ],
"agents": {
"defaults": {
- "workspace": "~/.picoclaw/workspace",
- "model": "glm-4.7",
- "max_tokens": 8192,
- "temperature": 0.7,
- "max_tool_iterations": 20
- }
- },
- "providers": {
- "openrouter": {
- "api_key": "xxx",
- "api_base": "https://openrouter.ai/api/v1"
+ "model_name": "gpt4"
}
},
"tools": {
@@ -263,7 +265,7 @@ Pronto! Você tem um assistente de IA funcionando em 2 minutos.
## 💬 Integração com Apps de Chat
-Converse com seu PicoClaw via Telegram, Discord, DingTalk ou LINE.
+Converse com seu PicoClaw via Telegram, Discord, DingTalk, LINE ou WeCom.
| Canal | Nível de Configuração |
| --- | --- |
@@ -272,6 +274,7 @@ Converse com seu PicoClaw via Telegram, Discord, DingTalk ou LINE.
| **QQ** | Fácil (AppID + AppSecret) |
| **DingTalk** | Médio (credenciais do app) |
| **LINE** | Médio (credenciais + webhook URL) |
+| **WeCom** | Médio (CorpID + configuração webhook) |
Telegram (Recomendado)
@@ -290,7 +293,7 @@ Converse com seu PicoClaw via Telegram, Discord, DingTalk ou LINE.
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
- "allowFrom": ["YOUR_USER_ID"]
+ "allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -333,7 +336,7 @@ picoclaw gateway
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
- "allowFrom": ["YOUR_USER_ID"]
+ "allow_from": ["YOUR_USER_ID"]
}
}
}
@@ -471,6 +474,87 @@ picoclaw gateway
+
+WeCom (WeChat Work)
+
+O PicoClaw suporta dois tipos de integração WeCom:
+
+**Opção 1: WeCom Bot (Robô Inteligente)** - Configuração mais fácil, suporta chats em grupo
+**Opção 2: WeCom App (Aplicativo Personalizado)** - Mais recursos, mensagens proativas
+
+Veja o [Guia de Configuração WeCom App](docs/wecom-app-configuration.md) para instruções detalhadas.
+
+**Configuração Rápida - WeCom Bot:**
+
+**1. Criar um bot**
+
+* Acesse o Console de Administração WeCom → Chat em Grupo → Adicionar Bot de Grupo
+* Copie a URL do webhook (formato: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`)
+
+**2. Configurar**
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**Configuração Rápida - WeCom App:**
+
+**1. Criar um aplicativo**
+
+* Acesse o Console de Administração WeCom → Gerenciamento de Aplicativos → Criar Aplicativo
+* Copie o **AgentId** e o **Secret**
+* Acesse a página "Minha Empresa", copie o **CorpID**
+
+**2. Configurar recebimento de mensagens**
+
+* Nos detalhes do aplicativo, clique em "Receber Mensagens" → "Configurar API"
+* Defina a URL como `http://your-server:18792/webhook/wecom-app`
+* Gere o **Token** e o **EncodingAESKey**
+
+**3. Configurar**
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**4. Executar**
+
+```bash
+picoclaw gateway
+```
+
+> **Nota**: O WeCom App requer a abertura da porta 18792 para callbacks de webhook. Use um proxy reverso para HTTPS em produção.
+
+
+
##
Junte-se a Rede Social de Agentes
Conecte o PicoClaw a Rede Social de Agentes simplesmente enviando uma única mensagem via CLI ou qualquer App de Chat integrado.
@@ -684,6 +768,8 @@ O subagente tem acesso às ferramentas (message, web_search, etc.) e pode se com
| `anthropic` (Em teste) | LLM (Claude direto) | [console.anthropic.com](https://console.anthropic.com) |
| `openai` (Em teste) | LLM (GPT direto) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` (Em teste) | LLM (DeepSeek direto) | [platform.deepseek.com](https://platform.deepseek.com) |
+| `qwen` | Alibaba Qwen | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
+| `cerebras` | Cerebras | [cerebras.ai](https://cerebras.ai) |
| `groq` | LLM + **Transcrição de voz** (Whisper) | [console.groq.com](https://console.groq.com) |
@@ -1006,7 +1092,7 @@ Adicione a key em `~/.picoclaw/config.json` se usar o Brave:
"tools": {
"web": {
"brave": {
- "enabled": true,
+ "enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
@@ -1037,3 +1123,4 @@ Isso acontece quando outra instância do bot está em execução. Certifique-se
| **Zhipu** | 200K tokens/mês | Melhor para usuários chineses |
| **Brave Search** | 2000 consultas/mês | Funcionalidade de busca web |
| **Groq** | Plano gratuito disponível | Inferência ultra-rápida (Llama, Mixtral) |
+| **Cerebras** | Plano gratuito disponível | Inferência ultra-rápida (Llama 3.3 70B) |
diff --git a/README.vi.md b/README.vi.md
index 08fa3dccd..015bc264e 100644
--- a/README.vi.md
+++ b/README.vi.md
@@ -50,7 +50,7 @@
## 📢 Tin tức
-2026-02-16 🎉 PicoClaw đạt 12K stars chỉ trong một tuần! Cảm ơn tất cả mọi người! PicoClaw đang phát triển nhanh hơn chúng tôi tưởng tượng. Do số lượng PR tăng cao, chúng tôi cấp thiết cần maintainer từ cộng đồng. Các vai trò tình nguyện viên và roadmap đã được công bố [tại đây](docs/picoclaw_community_roadmap_260216.md) — rất mong đón nhận sự tham gia của bạn!
+2026-02-16 🎉 PicoClaw đạt 12K stars chỉ trong một tuần! Cảm ơn tất cả mọi người! PicoClaw đang phát triển nhanh hơn chúng tôi tưởng tượng. Do số lượng PR tăng cao, chúng tôi cấp thiết cần maintainer từ cộng đồng. Các vai trò tình nguyện viên và roadmap đã được công bố [tại đây](docs/ROADMAP.md) — rất mong đón nhận sự tham gia của bạn!
2026-02-13 🎉 PicoClaw đạt 5000 stars trong 4 ngày! Cảm ơn cộng đồng! Chúng tôi đang hoàn thiện **Lộ trình dự án (Roadmap)** và thiết lập **Nhóm phát triển** để đẩy nhanh tốc độ phát triển PicoClaw.
🚀 **Kêu gọi hành động:** Vui lòng gửi yêu cầu tính năng tại GitHub Discussions. Chúng tôi sẽ xem xét và ưu tiên trong cuộc họp hàng tuần.
@@ -152,6 +152,10 @@ vim config/config.json # Thiết lập DISCORD_BOT_TOKEN, API keys, v.v.
# 3. Build & Khởi động
docker compose --profile gateway up -d
+> [!TIP]
+> **Người dùng Docker**: Theo mặc định, Gateway lắng nghe trên `127.0.0.1`, không thể truy cập từ máy chủ. Nếu bạn cần truy cập các endpoint kiểm tra sức khỏe hoặc mở cổng, hãy đặt `PICOCLAW_GATEWAY_HOST=0.0.0.0` trong môi trường của bạn hoặc cập nhật `config.json`.
+
+
# 4. Xem logs
docker compose logs -f picoclaw-gateway
@@ -193,32 +197,24 @@ picoclaw onboard
```json
{
+ "model_list": [
+ {
+ "model_name": "gpt4",
+ "model": "openai/gpt-5.2",
+ "api_key": "sk-your-openai-key",
+ "api_base": "https://api.openai.com/v1"
+ }
+ ],
"agents": {
"defaults": {
- "workspace": "~/.picoclaw/workspace",
- "model": "glm-4.7",
- "max_tokens": 8192,
- "temperature": 0.7,
- "max_tool_iterations": 20
+ "model_name": "gpt4"
}
},
- "providers": {
- "openrouter": {
- "api_key": "xxx",
- "api_base": "https://openrouter.ai/api/v1"
- }
- },
- "tools": {
- "web": {
- "brave": {
- "enabled": false,
- "api_key": "YOUR_BRAVE_API_KEY",
- "max_results": 5
- },
- "duckduckgo": {
- "enabled": true,
- "max_results": 5
- }
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "YOUR_TELEGRAM_BOT_TOKEN",
+ "allow_from": []
}
}
}
@@ -243,7 +239,7 @@ Vậy là xong! Bạn đã có một trợ lý AI hoạt động chỉ trong 2 p
## 💬 Tích hợp ứng dụng Chat
-Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk hoặc LINE.
+Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk, LINE hoặc WeCom.
| Kênh | Mức độ thiết lập |
| --- | --- |
@@ -252,6 +248,7 @@ Trò chuyện với PicoClaw qua Telegram, Discord, DingTalk hoặc LINE.
| **QQ** | Dễ (AppID + AppSecret) |
| **DingTalk** | Trung bình (app credentials) |
| **LINE** | Trung bình (credentials + webhook URL) |
+| **WeCom** | Trung bình (CorpID + cấu hình webhook) |
Telegram (Khuyên dùng)
@@ -451,6 +448,87 @@ picoclaw gateway
+
+WeCom (WeChat Work)
+
+PicoClaw hỗ trợ hai loại tích hợp WeCom:
+
+**Tùy chọn 1: WeCom Bot (Robot Thông minh)** - Thiết lập dễ dàng hơn, hỗ trợ chat nhóm
+**Tùy chọn 2: WeCom App (Ứng dụng Tự xây dựng)** - Nhiều tính năng hơn, nhắn tin chủ động
+
+Xem [Hướng dẫn Cấu hình WeCom App](docs/wecom-app-configuration.md) để biết hướng dẫn chi tiết.
+
+**Thiết lập Nhanh - WeCom Bot:**
+
+**1. Tạo bot**
+
+* Truy cập Bảng điều khiển Quản trị WeCom → Chat Nhóm → Thêm Bot Nhóm
+* Sao chép URL webhook (định dạng: `https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx`)
+
+**2. Cấu hình**
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**Thiết lập Nhanh - WeCom App:**
+
+**1. Tạo ứng dụng**
+
+* Truy cập Bảng điều khiển Quản trị WeCom → Quản lý Ứng dụng → Tạo Ứng dụng
+* Sao chép **AgentId** và **Secret**
+* Truy cập trang "Công ty của tôi", sao chép **CorpID**
+
+**2. Cấu hình nhận tin nhắn**
+
+* Trong chi tiết ứng dụng, nhấp vào "Nhận Tin nhắn" → "Thiết lập API"
+* Đặt URL thành `http://your-server:18792/webhook/wecom-app`
+* Tạo **Token** và **EncodingAESKey**
+
+**3. Cấu hình**
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": []
+ }
+ }
+}
+```
+
+**4. Chạy**
+
+```bash
+picoclaw gateway
+```
+
+> **Lưu ý**: WeCom App yêu cầu mở cổng 18792 cho callback webhook. Sử dụng proxy ngược cho HTTPS trong môi trường sản xuất.
+
+
+
##
Tham gia Mạng xã hội Agent
Kết nối PicoClaw với Mạng xã hội Agent chỉ bằng cách gửi một tin nhắn qua CLI hoặc bất kỳ ứng dụng Chat nào đã tích hợp.
@@ -665,6 +743,8 @@ Subagent có quyền truy cập các công cụ (message, web_search, v.v.) và
| `openai` (Đang thử nghiệm) | LLM (GPT trực tiếp) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` (Đang thử nghiệm) | LLM (DeepSeek trực tiếp) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Chuyển giọng nói** (Whisper) | [console.groq.com](https://console.groq.com) |
+| `qwen` | LLM (Qwen trực tiếp) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
+| `cerebras` | LLM (Cerebras trực tiếp) | [cerebras.ai](https://cerebras.ai) |
Cấu hình Zhipu
@@ -983,7 +1063,7 @@ Thêm key vào `~/.picoclaw/config.json` nếu dùng Brave:
"tools": {
"web": {
"brave": {
- "enabled": true,
+ "enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
diff --git a/README.zh.md b/README.zh.md
index 4827e66ea..4f4bde46a 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -14,7 +14,8 @@
- **中文** | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md)
+**中文** | [日本語](README.ja.md) | [Português](README.pt-br.md) | [Tiếng Việt](README.vi.md) | [Français](README.fr.md) | [English](README.md)
+
---
@@ -42,15 +43,16 @@
> [!CAUTION]
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
-> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
-> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
-> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
-> * **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
-> * **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化.
-
+>
+> - **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
+> - **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
+> - **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
+> - **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
+> - **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化.
## 📢 新闻 (News)
-2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/picoclaw_community_roadmap_260216.md), 期待你的参与!
+
+2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/ROADMAP.md), 期待你的参与!
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars!** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
@@ -69,12 +71,12 @@
🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。
-| | OpenClaw | NanoBot | **PicoClaw** |
-| --- | --- | --- | --- |
-| **语言** | TypeScript | Python | **Go** |
-| **RAM** | >1GB | >100MB | **< 10MB** |
-| **启动时间**(0.8GHz core) | >500s | >30s | **<1s** |
-| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板****低至 $10** |
+| | OpenClaw | NanoBot | **PicoClaw** |
+| ------------------------------ | ------------- | ------------------------ | -------------------------------------- |
+| **语言** | TypeScript | Python | **Go** |
+| **RAM** | >1GB | >100MB | **< 10MB** |
+| **启动时间**(0.8GHz core) | >500s | >30s | **<1s** |
+| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板****低至 $10** |
@@ -101,9 +103,12 @@
### 📱 在手机上轻松运行
+
picoclaw 可以将你10年前的老旧手机废物利用,变身成为你的AI助理!快速指南:
+
1. 先去应用商店下载安装Termux
2. 打开后执行指令
+
```bash
# 注意: 下面的v0.1.1 可以换为你实际看到的最新版本
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
@@ -111,19 +116,17 @@ chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
-然后跟随下面的“快速开始”章节继续配置picoclaw即可使用!
+
+然后跟随下面的“快速开始”章节继续配置picoclaw即可使用!
-
-
-
### 🐜 创新的低占用部署
PicoClaw 几乎可以部署在任何 Linux 设备上!
-* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。
-* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。
-* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。
+- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。
+- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。
+- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。
[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4)
@@ -170,6 +173,9 @@ vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等
# 3. 构建并启动
docker compose --profile gateway up -d
+> [!TIP]
+**Docker 用户**: 默认情况下, Gateway监听 `127.0.0.1`,这使得这个端口未暴露到容器外。如果你需要通过端口映射访问健康检查接口, 请在环境变量中设置 `PICOCLAW_GATEWAY_HOST=0.0.0.0` 或修改 `config.json`。
+
# 4. 查看日志
docker compose logs -f picoclaw-gateway
@@ -202,7 +208,7 @@ docker compose --profile gateway up -d
> [!TIP]
> 在 `~/.picoclaw/config.json` 中设置您的 API Key。
> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
-> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询)
+> 网络搜索是 **可选的** - 获取免费的 [Tavily API](https://tavily.com) (每月 1000 次免费查询) 或 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询)
**1. 初始化 (Initialize)**
@@ -218,7 +224,7 @@ picoclaw onboard
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
- "model": "gpt4",
+ "model_name": "gpt4",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
@@ -238,9 +244,15 @@ picoclaw onboard
],
"tools": {
"web": {
- "search": {
+ "brave": {
+ "enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
+ },
+ "tavily": {
+ "enabled": false,
+ "api_key": "YOUR_TAVILY_API_KEY",
+ "max_results": 5
}
},
"cron": {
@@ -248,15 +260,14 @@ picoclaw onboard
}
}
}
-
```
-> **新功能**: `model_list` 配置格式支持零代码添加 provider。详见[模型配置](#-模型配置-model_list)章节。
+> **新功能**: `model_list` 配置格式支持零代码添加 provider。详见[模型配置](#模型配置-model_list)章节。
**3. 获取 API Key**
* **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
-* **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月)
+* **网络搜索** (可选): [Tavily](https://tavily.com) - 专为 AI Agent 优化 (1000 请求/月) · [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月)
> **注意**: 完整的配置模板请参考 `config.example.json`。
@@ -273,176 +284,28 @@ picoclaw agent -m "2+2 等于几?"
## 💬 聊天应用集成 (Chat Apps)
-通过 Telegram, Discord 或钉钉与您的 PicoClaw 对话。
+PicoClaw 支持多种聊天平台,使您的 Agent 能够连接到任何地方。
-| 渠道 | 设置难度 |
-| --- | --- |
-| **Telegram** | 简单 (仅需 token) |
-| **Discord** | 简单 (bot token + intents) |
-| **QQ** | 简单 (AppID + AppSecret) |
-| **钉钉 (DingTalk)** | 中等 (app credentials) |
+### 核心渠道
-
-Telegram (推荐)
-
-**1. 创建机器人**
-
-* 打开 Telegram,搜索 `@BotFather`
-* 发送 `/newbot`,按照提示操作
-* 复制 token
-
-**2. 配置**
-
-```json
-{
- "channels": {
- "telegram": {
- "enabled": true,
- "token": "YOUR_BOT_TOKEN",
- "allow_from": ["YOUR_USER_ID"]
- }
- }
-}
-
-```
-
-> 从 Telegram 上的 `@userinfobot` 获取您的用户 ID。
-
-**3. 运行**
-
-```bash
-picoclaw gateway
-
-```
-
-
-
-
-Discord
-
-**1. 创建机器人**
-
-* 前往 [https://discord.com/developers/applications](https://discord.com/developers/applications)
-* Create an application → Bot → Add Bot
-* 复制 bot token
-
-**2. 开启 Intents**
-
-* 在 Bot 设置中,开启 **MESSAGE CONTENT INTENT**
-* (可选) 如果计划基于成员数据使用白名单,开启 **SERVER MEMBERS INTENT**
-
-**3. 获取您的 User ID**
-
-* Discord 设置 → Advanced → 开启 **Developer Mode**
-* 右键点击您的头像 → **Copy User ID**
-
-**4. 配置**
-
-```json
-{
- "channels": {
- "discord": {
- "enabled": true,
- "token": "YOUR_BOT_TOKEN",
- "allow_from": ["YOUR_USER_ID"]
- }
- }
-}
-
-```
-
-**5. 邀请机器人**
-
-* OAuth2 → URL Generator
-* Scopes: `bot`
-* Bot Permissions: `Send Messages`, `Read Message History`
-* 打开生成的邀请 URL,将机器人添加到您的服务器
-
-**6. 运行**
-
-```bash
-picoclaw gateway
-
-```
-
-
-
-
-QQ
-
-**1. 创建机器人**
-
-* 前往 [QQ 开放平台](https://q.qq.com/#)
-* 创建应用 → 获取 **AppID** 和 **AppSecret**
-
-**2. 配置**
-
-```json
-{
- "channels": {
- "qq": {
- "enabled": true,
- "app_id": "YOUR_APP_ID",
- "app_secret": "YOUR_APP_SECRET",
- "allow_from": []
- }
- }
-}
-
-```
-
-> 将 `allow_from` 设为空以允许所有用户,或指定 QQ 号以限制访问。
-
-**3. 运行**
-
-```bash
-picoclaw gateway
-
-```
-
-
-
-
-钉钉 (DingTalk)
-
-**1. 创建机器人**
-
-* 前往 [开放平台](https://open.dingtalk.com/)
-* 创建内部应用
-* 复制 Client ID 和 Client Secret
-
-**2. 配置**
-
-```json
-{
- "channels": {
- "dingtalk": {
- "enabled": true,
- "client_id": "YOUR_CLIENT_ID",
- "client_secret": "YOUR_CLIENT_SECRET",
- "allow_from": []
- }
- }
-}
-
-```
-
-> 将 `allow_from` 设为空以允许所有用户,或指定 ID 以限制访问。
-
-**3. 运行**
-
-```bash
-picoclaw gateway
-
-```
-
-
+| 渠道 | 设置难度 | 特性说明 | 文档链接 |
+| -------------------- | ----------- | ----------------------------------------- | --------------------------------------------------------------------------------------------------------------- |
+| **Telegram** | ⭐ 简单 | 推荐,支持语音转文字,长轮询无需公网 | [查看文档](docs/channels/telegram/README.zh.md) |
+| **Discord** | ⭐ 简单 | Socket Mode,支持群组/私信,Bot 生态成熟 | [查看文档](docs/channels/discord/README.zh.md) |
+| **Slack** | ⭐ 简单 | **Socket Mode** (无需公网 IP),企业级支持 | [查看文档](docs/channels/slack/README.zh.md) |
+| **QQ** | ⭐⭐ 中等 | 官方机器人 API,适合国内社群 | [查看文档](docs/channels/qq/README.zh.md) |
+| **钉钉 (DingTalk)** | ⭐⭐ 中等 | Stream 模式无需公网,企业办公首选 | [查看文档](docs/channels/dingtalk/README.zh.md) |
+| **企业微信 (WeCom)** | ⭐⭐⭐ 较难 | 支持群机器人(Webhook)和自建应用(API) | [Bot 文档](docs/channels/wecom/wecom_bot/README.zh.md) / [App 文档](docs/channels/wecom/wecom_app/README.zh.md) |
+| **飞书 (Feishu)** | ⭐⭐⭐ 较难 | 企业级协作,功能丰富 | [查看文档](docs/channels/feishu/README.zh.md) |
+| **Line** | ⭐⭐⭐ 较难 | 需要 HTTPS Webhook | [查看文档](docs/channels/line/README.zh.md) |
+| **OneBot** | ⭐⭐ 中等 | 兼容 NapCat/Go-CQHTTP,社区生态丰富 | [查看文档](docs/channels/onebot/README.zh.md) |
+| **MaixCam** | ⭐ 简单 | 专为 AI 摄像头设计的硬件集成通道 | [查看文档](docs/channels/maixcam/README.zh.md) |
##
加入 Agent 社交网络
只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。
-**阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai**](https://clawdchat.ai)
+\*\*阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai](https://clawdchat.ai)
## ⚙️ 配置详解
@@ -478,7 +341,6 @@ PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md`
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
-
```
Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。
@@ -491,22 +353,23 @@ Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具
# Periodic Tasks
## Quick Tasks (respond directly)
+
- Report current time
## Long Tasks (use spawn for async)
+
- Search the web for AI news and summarize
- Check email and report important messages
-
```
**关键行为:**
-| 特性 | 描述 |
-| --- | --- |
-| **spawn** | 创建异步子 Agent,不阻塞主心跳进程 |
-| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 |
+| 特性 | 描述 |
+| ---------------- | ---------------------------------------- |
+| **spawn** | 创建异步子 Agent,不阻塞主心跳进程 |
+| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 |
| **message tool** | 子 Agent 通过 message 工具直接与用户通信 |
-| **非阻塞** | spawn 后,心跳继续处理下一个任务 |
+| **非阻塞** | spawn 后,心跳继续处理下一个任务 |
#### 子 Agent 通信原理
@@ -536,35 +399,34 @@ Agent 读取 HEARTBEAT.md
"interval": 30
}
}
-
```
-| 选项 | 默认值 | 描述 |
-| --- | --- | --- |
-| `enabled` | `true` | 启用/禁用心跳 |
-| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) |
+| 选项 | 默认值 | 描述 |
+| ---------- | ------ | ---------------------------- |
+| `enabled` | `true` | 启用/禁用心跳 |
+| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) |
**环境变量:**
-* `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用
-* `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔
+- `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用
+- `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔
### 提供商 (Providers)
> [!NOTE]
> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。
-| 提供商 | 用途 | 获取 API Key |
-| --- | --- | --- |
-| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) |
-| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) |
-| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) |
-| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) |
-| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) |
-| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) |
-| `qwen` | LLM (通义千问) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
-| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) |
-| `cerebras` | LLM (Cerebras 直连) | [cerebras.ai](https://cerebras.ai) |
+| 提供商 | 用途 | 获取 API Key |
+| -------------------- | ---------------------------- | -------------------------------------------------------------------- |
+| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) |
+| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) |
+| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) |
+| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) |
+| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) |
+| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) |
+| `qwen` | LLM (通义千问) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
+| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) |
+| `cerebras` | LLM (Cerebras 直连) | [cerebras.ai](https://cerebras.ai) |
### 模型配置 (model_list)
@@ -579,25 +441,25 @@ Agent 读取 HEARTBEAT.md
#### 📋 所有支持的厂商
-| 厂商 | `model` 前缀 | 默认 API Base | 协议 | 获取 API Key |
-|------|-------------|---------------|------|--------------|
-| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [获取密钥](https://platform.openai.com) |
-| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [获取密钥](https://console.anthropic.com) |
-| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [获取密钥](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
-| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [获取密钥](https://platform.deepseek.com) |
-| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [获取密钥](https://aistudio.google.com/api-keys) |
-| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [获取密钥](https://console.groq.com) |
-| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [获取密钥](https://platform.moonshot.cn) |
-| **通义千问 (Qwen)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [获取密钥](https://dashscope.console.aliyun.com) |
-| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [获取密钥](https://build.nvidia.com) |
-| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | 本地(无需密钥) |
-| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [获取密钥](https://openrouter.ai/keys) |
-| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | 本地 |
-| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [获取密钥](https://cerebras.ai) |
-| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [获取密钥](https://console.volcengine.com) |
-| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
-| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
-| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+| 厂商 | `model` 前缀 | 默认 API Base | 协议 | 获取 API Key |
+| ------------------- | ----------------- | --------------------------------------------------- | --------- | ----------------------------------------------------------------- |
+| **OpenAI** | `openai/` | `https://api.openai.com/v1` | OpenAI | [获取密钥](https://platform.openai.com) |
+| **Anthropic** | `anthropic/` | `https://api.anthropic.com/v1` | Anthropic | [获取密钥](https://console.anthropic.com) |
+| **智谱 AI (GLM)** | `zhipu/` | `https://open.bigmodel.cn/api/paas/v4` | OpenAI | [获取密钥](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) |
+| **DeepSeek** | `deepseek/` | `https://api.deepseek.com/v1` | OpenAI | [获取密钥](https://platform.deepseek.com) |
+| **Google Gemini** | `gemini/` | `https://generativelanguage.googleapis.com/v1beta` | OpenAI | [获取密钥](https://aistudio.google.com/api-keys) |
+| **Groq** | `groq/` | `https://api.groq.com/openai/v1` | OpenAI | [获取密钥](https://console.groq.com) |
+| **Moonshot** | `moonshot/` | `https://api.moonshot.cn/v1` | OpenAI | [获取密钥](https://platform.moonshot.cn) |
+| **通义千问 (Qwen)** | `qwen/` | `https://dashscope.aliyuncs.com/compatible-mode/v1` | OpenAI | [获取密钥](https://dashscope.console.aliyun.com) |
+| **NVIDIA** | `nvidia/` | `https://integrate.api.nvidia.com/v1` | OpenAI | [获取密钥](https://build.nvidia.com) |
+| **Ollama** | `ollama/` | `http://localhost:11434/v1` | OpenAI | 本地(无需密钥) |
+| **OpenRouter** | `openrouter/` | `https://openrouter.ai/api/v1` | OpenAI | [获取密钥](https://openrouter.ai/keys) |
+| **VLLM** | `vllm/` | `http://localhost:8000/v1` | OpenAI | 本地 |
+| **Cerebras** | `cerebras/` | `https://api.cerebras.ai/v1` | OpenAI | [获取密钥](https://cerebras.ai) |
+| **火山引擎** | `volcengine/` | `https://ark.cn-beijing.volces.com/api/v3` | OpenAI | [获取密钥](https://console.volcengine.com) |
+| **神算云** | `shengsuanyun/` | `https://router.shengsuanyun.com/api/v1` | OpenAI | - |
+| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
+| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
#### 基础配置示例
@@ -631,6 +493,7 @@ Agent 读取 HEARTBEAT.md
#### 各厂商配置示例
**OpenAI**
+
```json
{
"model_name": "gpt-5.2",
@@ -640,6 +503,7 @@ Agent 读取 HEARTBEAT.md
```
**智谱 AI (GLM)**
+
```json
{
"model_name": "glm-4.7",
@@ -649,6 +513,7 @@ Agent 读取 HEARTBEAT.md
```
**DeepSeek**
+
```json
{
"model_name": "deepseek-chat",
@@ -658,6 +523,7 @@ Agent 读取 HEARTBEAT.md
```
**Anthropic (使用 OAuth)**
+
```json
{
"model_name": "claude-sonnet-4.6",
@@ -665,9 +531,11 @@ Agent 读取 HEARTBEAT.md
"auth_method": "oauth"
}
```
+
> 运行 `picoclaw auth login --provider anthropic` 来设置 OAuth 凭证。
**Ollama (本地)**
+
```json
{
"model_name": "llama3",
@@ -676,6 +544,7 @@ Agent 读取 HEARTBEAT.md
```
**自定义代理/API**
+
```json
{
"model_name": "my-custom-model",
@@ -713,6 +582,7 @@ Agent 读取 HEARTBEAT.md
旧的 `providers` 配置格式**已弃用**,但为向后兼容仍支持。
**旧配置(已弃用):**
+
```json
{
"providers": {
@@ -731,6 +601,7 @@ Agent 读取 HEARTBEAT.md
```
**新配置(推荐):**
+
```json
{
"model_list": [
@@ -755,7 +626,7 @@ Agent 读取 HEARTBEAT.md
**1. 获取 API key 和 base URL**
-* 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
+- 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. 配置**
@@ -774,10 +645,9 @@ Agent 读取 HEARTBEAT.md
"zhipu": {
"api_key": "Your API Key",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
- },
- },
+ }
+ }
}
-
```
**3. 运行**
@@ -838,8 +708,14 @@ picoclaw agent -m "你好"
},
"tools": {
"web": {
- "search": {
- "api_key": "BSA..."
+ "brave": {
+ "enabled": false,
+ "api_key": "YOUR_BRAVE_API_KEY",
+ "max_results": 5
+ },
+ "duckduckgo": {
+ "enabled": true,
+ "max_results": 5
}
},
"cron": {
@@ -851,30 +727,29 @@ picoclaw agent -m "你好"
"interval": 30
}
}
-
```
## CLI 命令行参考
-| 命令 | 描述 |
-| --- | --- |
-| `picoclaw onboard` | 初始化配置和工作区 |
-| `picoclaw agent -m "..."` | 与 Agent 对话 |
-| `picoclaw agent` | 交互式聊天模式 |
-| `picoclaw gateway` | 启动网关 (Gateway) |
-| `picoclaw status` | 显示状态 |
-| `picoclaw cron list` | 列出所有定时任务 |
-| `picoclaw cron add ...` | 添加定时任务 |
+| 命令 | 描述 |
+| ------------------------- | ------------------ |
+| `picoclaw onboard` | 初始化配置和工作区 |
+| `picoclaw agent -m "..."` | 与 Agent 对话 |
+| `picoclaw agent` | 交互式聊天模式 |
+| `picoclaw gateway` | 启动网关 (Gateway) |
+| `picoclaw status` | 显示状态 |
+| `picoclaw cron list` | 列出所有定时任务 |
+| `picoclaw cron add ...` | 添加定时任务 |
### 定时任务 / 提醒 (Scheduled Tasks)
PicoClaw 通过 `cron` 工具支持定时提醒和重复任务:
-* **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次
-* **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发
-* **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式
+- **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次
+- **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发
+- **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式
任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。
@@ -888,7 +763,7 @@ PicoClaw 通过 `cron` 工具支持定时提醒和重复任务:
用户群组:
-Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
+Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
@@ -900,24 +775,27 @@ Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
启用网络搜索:
-1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询)
+1. 在 [https://tavily.com](https://tavily.com) (1000 次免费) 或 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (2000 次免费)
2. 添加到 `~/.picoclaw/config.json`:
+
```json
{
"tools": {
"web": {
- "search": {
+ "brave": {
+ "enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
+ },
+ "duckduckgo": {
+ "enabled": true,
+ "max_results": 5
}
}
}
}
-
```
-
-
### 遇到内容过滤错误 (Content Filtering Errors)
某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。
@@ -935,5 +813,5 @@ Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) |
| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 |
| **Brave Search** | 2000 次查询/月 | 网络搜索功能 |
+| **Tavily** | 1000 次查询/月 | AI Agent 搜索优化 |
| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) |
-| **Cerebras** | 提供免费层级 | 极速推理 (Llama, Qwen 等) |
\ No newline at end of file
diff --git a/assets/wechat.png b/assets/wechat.png
index 8fc41ea7d..776c07885 100644
Binary files a/assets/wechat.png and b/assets/wechat.png differ
diff --git a/cmd/picoclaw/cmd_cron.go b/cmd/picoclaw/cmd_cron.go
deleted file mode 100644
index 8c42bde06..000000000
--- a/cmd/picoclaw/cmd_cron.go
+++ /dev/null
@@ -1,227 +0,0 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
-
-import (
- "fmt"
- "os"
- "path/filepath"
- "time"
-
- "github.com/sipeed/picoclaw/pkg/cron"
-)
-
-func cronCmd() {
- if len(os.Args) < 3 {
- cronHelp()
- return
- }
-
- subcommand := os.Args[2]
-
- // Load config to get workspace path
- cfg, err := loadConfig()
- if err != nil {
- fmt.Printf("Error loading config: %v\n", err)
- return
- }
-
- cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json")
-
- switch subcommand {
- case "list":
- cronListCmd(cronStorePath)
- case "add":
- cronAddCmd(cronStorePath)
- case "remove":
- if len(os.Args) < 4 {
- fmt.Println("Usage: picoclaw cron remove ")
- return
- }
- cronRemoveCmd(cronStorePath, os.Args[3])
- case "enable":
- cronEnableCmd(cronStorePath, false)
- case "disable":
- cronEnableCmd(cronStorePath, true)
- default:
- fmt.Printf("Unknown cron command: %s\n", subcommand)
- cronHelp()
- }
-}
-
-func cronHelp() {
- fmt.Println("\nCron commands:")
- fmt.Println(" list List all scheduled jobs")
- fmt.Println(" add Add a new scheduled job")
- fmt.Println(" remove Remove a job by ID")
- fmt.Println(" enable Enable a job")
- fmt.Println(" disable Disable a job")
- fmt.Println()
- fmt.Println("Add options:")
- fmt.Println(" -n, --name Job name")
- fmt.Println(" -m, --message Message for agent")
- fmt.Println(" -e, --every Run every N seconds")
- fmt.Println(" -c, --cron Cron expression (e.g. '0 9 * * *')")
- fmt.Println(" -d, --deliver Deliver response to channel")
- fmt.Println(" --to Recipient for delivery")
- fmt.Println(" --channel Channel for delivery")
-}
-
-func cronListCmd(storePath string) {
- cs := cron.NewCronService(storePath, nil)
- jobs := cs.ListJobs(true) // Show all jobs, including disabled
-
- if len(jobs) == 0 {
- fmt.Println("No scheduled jobs.")
- return
- }
-
- fmt.Println("\nScheduled Jobs:")
- fmt.Println("----------------")
- for _, job := range jobs {
- var schedule string
- if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil {
- schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000)
- } else if job.Schedule.Kind == "cron" {
- schedule = job.Schedule.Expr
- } else {
- schedule = "one-time"
- }
-
- nextRun := "scheduled"
- if job.State.NextRunAtMS != nil {
- nextTime := time.UnixMilli(*job.State.NextRunAtMS)
- nextRun = nextTime.Format("2006-01-02 15:04")
- }
-
- status := "enabled"
- if !job.Enabled {
- status = "disabled"
- }
-
- fmt.Printf(" %s (%s)\n", job.Name, job.ID)
- fmt.Printf(" Schedule: %s\n", schedule)
- fmt.Printf(" Status: %s\n", status)
- fmt.Printf(" Next run: %s\n", nextRun)
- }
-}
-
-func cronAddCmd(storePath string) {
- name := ""
- message := ""
- var everySec *int64
- cronExpr := ""
- deliver := false
- channel := ""
- to := ""
-
- args := os.Args[3:]
- for i := 0; i < len(args); i++ {
- switch args[i] {
- case "-n", "--name":
- if i+1 < len(args) {
- name = args[i+1]
- i++
- }
- case "-m", "--message":
- if i+1 < len(args) {
- message = args[i+1]
- i++
- }
- case "-e", "--every":
- if i+1 < len(args) {
- var sec int64
- fmt.Sscanf(args[i+1], "%d", &sec)
- everySec = &sec
- i++
- }
- case "-c", "--cron":
- if i+1 < len(args) {
- cronExpr = args[i+1]
- i++
- }
- case "-d", "--deliver":
- deliver = true
- case "--to":
- if i+1 < len(args) {
- to = args[i+1]
- i++
- }
- case "--channel":
- if i+1 < len(args) {
- channel = args[i+1]
- i++
- }
- }
- }
-
- if name == "" {
- fmt.Println("Error: --name is required")
- return
- }
-
- if message == "" {
- fmt.Println("Error: --message is required")
- return
- }
-
- if everySec == nil && cronExpr == "" {
- fmt.Println("Error: Either --every or --cron must be specified")
- return
- }
-
- var schedule cron.CronSchedule
- if everySec != nil {
- everyMS := *everySec * 1000
- schedule = cron.CronSchedule{
- Kind: "every",
- EveryMS: &everyMS,
- }
- } else {
- schedule = cron.CronSchedule{
- Kind: "cron",
- Expr: cronExpr,
- }
- }
-
- cs := cron.NewCronService(storePath, nil)
- job, err := cs.AddJob(name, schedule, message, deliver, channel, to)
- if err != nil {
- fmt.Printf("Error adding job: %v\n", err)
- return
- }
-
- fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID)
-}
-
-func cronRemoveCmd(storePath, jobID string) {
- cs := cron.NewCronService(storePath, nil)
- if cs.RemoveJob(jobID) {
- fmt.Printf("✓ Removed job %s\n", jobID)
- } else {
- fmt.Printf("✗ Job %s not found\n", jobID)
- }
-}
-
-func cronEnableCmd(storePath string, disable bool) {
- if len(os.Args) < 4 {
- fmt.Println("Usage: picoclaw cron enable/disable ")
- return
- }
-
- jobID := os.Args[3]
- cs := cron.NewCronService(storePath, nil)
- enabled := !disable
-
- job := cs.EnableJob(jobID, enabled)
- if job != nil {
- status := "enabled"
- if disable {
- status = "disabled"
- }
- fmt.Printf("✓ Job '%s' %s\n", job.Name, status)
- } else {
- fmt.Printf("✗ Job %s not found\n", jobID)
- }
-}
diff --git a/cmd/picoclaw/cmd_migrate.go b/cmd/picoclaw/cmd_migrate.go
deleted file mode 100644
index 86d4903ef..000000000
--- a/cmd/picoclaw/cmd_migrate.go
+++ /dev/null
@@ -1,81 +0,0 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
-
-import (
- "fmt"
- "os"
-
- "github.com/sipeed/picoclaw/pkg/migrate"
-)
-
-func migrateCmd() {
- if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
- migrateHelp()
- return
- }
-
- opts := migrate.Options{}
-
- args := os.Args[2:]
- for i := 0; i < len(args); i++ {
- switch args[i] {
- case "--dry-run":
- opts.DryRun = true
- case "--config-only":
- opts.ConfigOnly = true
- case "--workspace-only":
- opts.WorkspaceOnly = true
- case "--force":
- opts.Force = true
- case "--refresh":
- opts.Refresh = true
- case "--openclaw-home":
- if i+1 < len(args) {
- opts.OpenClawHome = args[i+1]
- i++
- }
- case "--picoclaw-home":
- if i+1 < len(args) {
- opts.PicoClawHome = args[i+1]
- i++
- }
- default:
- fmt.Printf("Unknown flag: %s\n", args[i])
- migrateHelp()
- os.Exit(1)
- }
- }
-
- result, err := migrate.Run(opts)
- if err != nil {
- fmt.Printf("Error: %v\n", err)
- os.Exit(1)
- }
-
- if !opts.DryRun {
- migrate.PrintSummary(result)
- }
-}
-
-func migrateHelp() {
- fmt.Println("\nMigrate from OpenClaw to PicoClaw")
- fmt.Println()
- fmt.Println("Usage: picoclaw migrate [options]")
- fmt.Println()
- fmt.Println("Options:")
- fmt.Println(" --dry-run Show what would be migrated without making changes")
- fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
- fmt.Println(" --config-only Only migrate config, skip workspace files")
- fmt.Println(" --workspace-only Only migrate workspace files, skip config")
- fmt.Println(" --force Skip confirmation prompts")
- fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
- fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
- fmt.Println()
- fmt.Println("Examples:")
- fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
- fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
- fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
- fmt.Println(" picoclaw migrate --force Migrate without confirmation")
-}
diff --git a/cmd/picoclaw/internal/agent/command.go b/cmd/picoclaw/internal/agent/command.go
new file mode 100644
index 000000000..47262fc85
--- /dev/null
+++ b/cmd/picoclaw/internal/agent/command.go
@@ -0,0 +1,30 @@
+package agent
+
+import (
+ "github.com/spf13/cobra"
+)
+
+func NewAgentCommand() *cobra.Command {
+ var (
+ message string
+ sessionKey string
+ model string
+ debug bool
+ )
+
+ cmd := &cobra.Command{
+ Use: "agent",
+ Short: "Interact with the agent directly",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return agentCmd(message, sessionKey, model, debug)
+ },
+ }
+
+ cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
+ cmd.Flags().StringVarP(&message, "message", "m", "", "Send a single message (non-interactive mode)")
+ cmd.Flags().StringVarP(&sessionKey, "session", "s", "cli:default", "Session key")
+ cmd.Flags().StringVarP(&model, "model", "", "", "Model to use")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/agent/command_test.go b/cmd/picoclaw/internal/agent/command_test.go
new file mode 100644
index 000000000..1457d6a49
--- /dev/null
+++ b/cmd/picoclaw/internal/agent/command_test.go
@@ -0,0 +1,33 @@
+package agent
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAgentCommand(t *testing.T) {
+ cmd := NewAgentCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "agent", cmd.Use)
+ assert.Equal(t, "Interact with the agent directly", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 0)
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ assert.True(t, cmd.HasFlags())
+
+ assert.NotNil(t, cmd.Flags().Lookup("debug"))
+ assert.NotNil(t, cmd.Flags().Lookup("message"))
+ assert.NotNil(t, cmd.Flags().Lookup("session"))
+ assert.NotNil(t, cmd.Flags().Lookup("model"))
+}
diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/internal/agent/helpers.go
similarity index 62%
rename from cmd/picoclaw/cmd_agent.go
rename to cmd/picoclaw/internal/agent/helpers.go
index cee9f68ec..746e9755e 100644
--- a/cmd/picoclaw/cmd_agent.go
+++ b/cmd/picoclaw/internal/agent/helpers.go
@@ -1,7 +1,4 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
+package agent
import (
"bufio"
@@ -13,59 +10,41 @@ import (
"strings"
"github.com/chzyer/readline"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
-func agentCmd() {
- message := ""
- sessionKey := "cli:default"
- modelOverride := ""
-
- args := os.Args[2:]
- for i := 0; i < len(args); i++ {
- switch args[i] {
- case "--debug", "-d":
- logger.SetLevel(logger.DEBUG)
- fmt.Println("🔍 Debug mode enabled")
- case "-m", "--message":
- if i+1 < len(args) {
- message = args[i+1]
- i++
- }
- case "-s", "--session":
- if i+1 < len(args) {
- sessionKey = args[i+1]
- i++
- }
- case "--model", "-model":
- if i+1 < len(args) {
- modelOverride = args[i+1]
- i++
- }
- }
+func agentCmd(message, sessionKey, model string, debug bool) error {
+ if sessionKey == "" {
+ sessionKey = "cli:default"
}
- cfg, err := loadConfig()
+ if debug {
+ logger.SetLevel(logger.DEBUG)
+ fmt.Println("🔍 Debug mode enabled")
+ }
+
+ cfg, err := internal.LoadConfig()
if err != nil {
- fmt.Printf("Error loading config: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error loading config: %w", err)
}
- if modelOverride != "" {
- cfg.Agents.Defaults.Model = modelOverride
+ if model != "" {
+ cfg.Agents.Defaults.ModelName = model
}
provider, modelID, err := providers.CreateProvider(cfg)
if err != nil {
- fmt.Printf("Error creating provider: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error creating provider: %w", err)
}
+
// Use the resolved model ID from provider creation
if modelID != "" {
- cfg.Agents.Defaults.Model = modelID
+ cfg.Agents.Defaults.ModelName = modelID
}
msgBus := bus.NewMessageBus()
@@ -74,28 +53,30 @@ func agentCmd() {
// Print agent startup info (only for interactive mode)
startupInfo := agentLoop.GetStartupInfo()
logger.InfoCF("agent", "Agent initialized",
- map[string]interface{}{
- "tools_count": startupInfo["tools"].(map[string]interface{})["count"],
- "skills_total": startupInfo["skills"].(map[string]interface{})["total"],
- "skills_available": startupInfo["skills"].(map[string]interface{})["available"],
+ map[string]any{
+ "tools_count": startupInfo["tools"].(map[string]any)["count"],
+ "skills_total": startupInfo["skills"].(map[string]any)["total"],
+ "skills_available": startupInfo["skills"].(map[string]any)["available"],
})
if message != "" {
ctx := context.Background()
response, err := agentLoop.ProcessDirect(ctx, message, sessionKey)
if err != nil {
- fmt.Printf("Error: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error processing message: %w", err)
}
- fmt.Printf("\n%s %s\n", logo, response)
- } else {
- fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo)
- interactiveMode(agentLoop, sessionKey)
+ fmt.Printf("\n%s %s\n", internal.Logo, response)
+ return nil
}
+
+ fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", internal.Logo)
+ interactiveMode(agentLoop, sessionKey)
+
+ return nil
}
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
- prompt := fmt.Sprintf("%s You: ", logo)
+ prompt := fmt.Sprintf("%s You: ", internal.Logo)
rl, err := readline.NewEx(&readline.Config{
Prompt: prompt,
@@ -104,7 +85,6 @@ func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
InterruptPrompt: "^C",
EOFPrompt: "exit",
})
-
if err != nil {
fmt.Printf("Error initializing readline: %v\n", err)
fmt.Println("Falling back to simple input mode...")
@@ -141,14 +121,14 @@ func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
continue
}
- fmt.Printf("\n%s %s\n\n", logo, response)
+ fmt.Printf("\n%s %s\n\n", internal.Logo, response)
}
}
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
reader := bufio.NewReader(os.Stdin)
for {
- fmt.Print(fmt.Sprintf("%s You: ", logo))
+ fmt.Print(fmt.Sprintf("%s You: ", internal.Logo))
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
@@ -176,6 +156,6 @@ func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
continue
}
- fmt.Printf("\n%s %s\n\n", logo, response)
+ fmt.Printf("\n%s %s\n\n", internal.Logo, response)
}
}
diff --git a/cmd/picoclaw/internal/auth/command.go b/cmd/picoclaw/internal/auth/command.go
new file mode 100644
index 000000000..12a0a3a8c
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/command.go
@@ -0,0 +1,22 @@
+package auth
+
+import "github.com/spf13/cobra"
+
+func NewAuthCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "auth",
+ Short: "Manage authentication (login, logout, status)",
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return cmd.Help()
+ },
+ }
+
+ cmd.AddCommand(
+ newLoginCommand(),
+ newLogoutCommand(),
+ newStatusCommand(),
+ newModelsCommand(),
+ )
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/auth/command_test.go b/cmd/picoclaw/internal/auth/command_test.go
new file mode 100644
index 000000000..48dc704dd
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/command_test.go
@@ -0,0 +1,55 @@
+package auth
+
+import (
+ "slices"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAuthCommand(t *testing.T) {
+ cmd := NewAuthCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "auth", cmd.Use)
+ assert.Equal(t, "Manage authentication (login, logout, status)", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 0)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ assert.False(t, cmd.HasFlags())
+ assert.True(t, cmd.HasSubCommands())
+
+ allowedCommands := []string{
+ "login",
+ "logout",
+ "status",
+ "models",
+ }
+
+ subcommands := cmd.Commands()
+ assert.Len(t, subcommands, len(allowedCommands))
+
+ for _, subcmd := range subcommands {
+ found := slices.Contains(allowedCommands, subcmd.Name())
+ assert.True(t, found, "unexpected subcommand %q", subcmd.Name())
+
+ assert.Len(t, subcmd.Aliases, 0)
+ assert.False(t, subcmd.Hidden)
+
+ assert.False(t, subcmd.HasSubCommands())
+
+ assert.Nil(t, subcmd.Run)
+ assert.NotNil(t, subcmd.RunE)
+
+ assert.Nil(t, subcmd.PersistentPreRun)
+ assert.Nil(t, subcmd.PersistentPostRun)
+ }
+}
diff --git a/cmd/picoclaw/cmd_auth.go b/cmd/picoclaw/internal/auth/helpers.go
similarity index 64%
rename from cmd/picoclaw/cmd_auth.go
rename to cmd/picoclaw/internal/auth/helpers.go
index 5bed7f116..633ce8740 100644
--- a/cmd/picoclaw/cmd_auth.go
+++ b/cmd/picoclaw/internal/auth/helpers.go
@@ -1,7 +1,4 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
+package auth
import (
"encoding/json"
@@ -12,92 +9,28 @@ import (
"strings"
"time"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
-const supportedProvidersMsg = "Supported providers: openai, anthropic, google-antigravity"
-
-func authCmd() {
- if len(os.Args) < 3 {
- authHelp()
- return
- }
-
- switch os.Args[2] {
- case "login":
- authLoginCmd()
- case "logout":
- authLogoutCmd()
- case "status":
- authStatusCmd()
- case "models":
- authModelsCmd()
- default:
- fmt.Printf("Unknown auth command: %s\n", os.Args[2])
- authHelp()
- }
-}
-
-func authHelp() {
- fmt.Println("\nAuth commands:")
- fmt.Println(" login Login via OAuth or paste token")
- fmt.Println(" logout Remove stored credentials")
- fmt.Println(" status Show current auth status")
- fmt.Println(" models List available Antigravity models")
- fmt.Println()
- fmt.Println("Login options:")
- fmt.Println(" --provider Provider to login with (openai, anthropic, google-antigravity)")
- fmt.Println(" --device-code Use device code flow (for headless environments)")
- fmt.Println()
- fmt.Println("Examples:")
- fmt.Println(" picoclaw auth login --provider openai")
- fmt.Println(" picoclaw auth login --provider openai --device-code")
- fmt.Println(" picoclaw auth login --provider anthropic")
- fmt.Println(" picoclaw auth login --provider google-antigravity")
- fmt.Println(" picoclaw auth models")
- fmt.Println(" picoclaw auth logout --provider openai")
- fmt.Println(" picoclaw auth status")
-}
-
-func authLoginCmd() {
- provider := ""
- useDeviceCode := false
-
- args := os.Args[3:]
- for i := 0; i < len(args); i++ {
- switch args[i] {
- case "--provider", "-p":
- if i+1 < len(args) {
- provider = args[i+1]
- i++
- }
- case "--device-code":
- useDeviceCode = true
- }
- }
-
- if provider == "" {
- fmt.Println("Error: --provider is required")
- fmt.Println(supportedProvidersMsg)
- return
- }
+const supportedProvidersMsg = "supported providers: openai, anthropic, google-antigravity"
+func authLoginCmd(provider string, useDeviceCode bool) error {
switch provider {
case "openai":
- authLoginOpenAI(useDeviceCode)
+ return authLoginOpenAI(useDeviceCode)
case "anthropic":
- authLoginPasteToken(provider)
+ return authLoginPasteToken(provider)
case "google-antigravity", "antigravity":
- authLoginGoogleAntigravity()
+ return authLoginGoogleAntigravity()
default:
- fmt.Printf("Unsupported provider: %s\n", provider)
- fmt.Println(supportedProvidersMsg)
+ return fmt.Errorf("unsupported provider: %s (%s)", provider, supportedProvidersMsg)
}
}
-func authLoginOpenAI(useDeviceCode bool) {
+func authLoginOpenAI(useDeviceCode bool) error {
cfg := auth.OpenAIOAuthConfig()
var cred *auth.AuthCredential
@@ -110,16 +43,14 @@ func authLoginOpenAI(useDeviceCode bool) {
}
if err != nil {
- fmt.Printf("Login failed: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("login failed: %w", err)
}
- if err := auth.SetCredential("openai", cred); err != nil {
- fmt.Printf("Failed to save credentials: %v\n", err)
- os.Exit(1)
+ if err = auth.SetCredential("openai", cred); err != nil {
+ return fmt.Errorf("failed to save credentials: %w", err)
}
- appCfg, err := loadConfig()
+ appCfg, err := internal.LoadConfig()
if err == nil {
// Update Providers (legacy format)
appCfg.Providers.OpenAI.AuthMethod = "oauth"
@@ -144,10 +75,10 @@ func authLoginOpenAI(useDeviceCode bool) {
}
// Update default model to use OpenAI
- appCfg.Agents.Defaults.Model = "gpt-5.2"
+ appCfg.Agents.Defaults.ModelName = "gpt-5.2"
- if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
- fmt.Printf("Warning: could not update config: %v\n", err)
+ if err = config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
+ return fmt.Errorf("could not update config: %w", err)
}
}
@@ -156,15 +87,16 @@ func authLoginOpenAI(useDeviceCode bool) {
fmt.Printf("Account: %s\n", cred.AccountID)
}
fmt.Println("Default model set to: gpt-5.2")
+
+ return nil
}
-func authLoginGoogleAntigravity() {
+func authLoginGoogleAntigravity() error {
cfg := auth.GoogleAntigravityOAuthConfig()
cred, err := auth.LoginBrowser(cfg)
if err != nil {
- fmt.Printf("Login failed: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("login failed: %w", err)
}
cred.Provider = "google-antigravity"
@@ -188,12 +120,11 @@ func authLoginGoogleAntigravity() {
fmt.Printf("Project: %s\n", projectID)
}
- if err := auth.SetCredential("google-antigravity", cred); err != nil {
- fmt.Printf("Failed to save credentials: %v\n", err)
- os.Exit(1)
+ if err = auth.SetCredential("google-antigravity", cred); err != nil {
+ return fmt.Errorf("failed to save credentials: %w", err)
}
- appCfg, err := loadConfig()
+ appCfg, err := internal.LoadConfig()
if err == nil {
// Update Providers (legacy format, for backward compatibility)
appCfg.Providers.Antigravity.AuthMethod = "oauth"
@@ -218,9 +149,9 @@ func authLoginGoogleAntigravity() {
}
// Update default model
- appCfg.Agents.Defaults.Model = "gemini-flash"
+ appCfg.Agents.Defaults.ModelName = "gemini-flash"
- if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
+ if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
@@ -228,6 +159,8 @@ func authLoginGoogleAntigravity() {
fmt.Println("\n✓ Google Antigravity login successful!")
fmt.Println("Default model set to: gemini-flash")
fmt.Println("Try it: picoclaw agent -m \"Hello world\"")
+
+ return nil
}
func fetchGoogleUserEmail(accessToken string) (string, error) {
@@ -258,19 +191,17 @@ func fetchGoogleUserEmail(accessToken string) (string, error) {
return userInfo.Email, nil
}
-func authLoginPasteToken(provider string) {
+func authLoginPasteToken(provider string) error {
cred, err := auth.LoginPasteToken(provider, os.Stdin)
if err != nil {
- fmt.Printf("Login failed: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("login failed: %w", err)
}
- if err := auth.SetCredential(provider, cred); err != nil {
- fmt.Printf("Failed to save credentials: %v\n", err)
- os.Exit(1)
+ if err = auth.SetCredential(provider, cred); err != nil {
+ return fmt.Errorf("failed to save credentials: %w", err)
}
- appCfg, err := loadConfig()
+ appCfg, err := internal.LoadConfig()
if err == nil {
switch provider {
case "anthropic":
@@ -292,7 +223,7 @@ func authLoginPasteToken(provider string) {
})
}
// Update default model
- appCfg.Agents.Defaults.Model = "claude-sonnet-4.6"
+ appCfg.Agents.Defaults.ModelName = "claude-sonnet-4.6"
case "openai":
appCfg.Providers.OpenAI.AuthMethod = "token"
// Update ModelList
@@ -312,38 +243,29 @@ func authLoginPasteToken(provider string) {
})
}
// Update default model
- appCfg.Agents.Defaults.Model = "gpt-5.2"
+ appCfg.Agents.Defaults.ModelName = "gpt-5.2"
}
- if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
- fmt.Printf("Warning: could not update config: %v\n", err)
+ if err := config.SaveConfig(internal.GetConfigPath(), appCfg); err != nil {
+ return fmt.Errorf("could not update config: %w", err)
}
}
fmt.Printf("Token saved for %s!\n", provider)
- fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.Model)
-}
-func authLogoutCmd() {
- provider := ""
-
- args := os.Args[3:]
- for i := 0; i < len(args); i++ {
- switch args[i] {
- case "--provider", "-p":
- if i+1 < len(args) {
- provider = args[i+1]
- i++
- }
- }
+ if appCfg != nil {
+ fmt.Printf("Default model set to: %s\n", appCfg.Agents.Defaults.GetModelName())
}
+ return nil
+}
+
+func authLogoutCmd(provider string) error {
if provider != "" {
if err := auth.DeleteCredential(provider); err != nil {
- fmt.Printf("Failed to remove credentials: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("failed to remove credentials: %w", err)
}
- appCfg, err := loadConfig()
+ appCfg, err := internal.LoadConfig()
if err == nil {
// Clear AuthMethod in ModelList
for i := range appCfg.ModelList {
@@ -371,44 +293,46 @@ func authLogoutCmd() {
case "google-antigravity", "antigravity":
appCfg.Providers.Antigravity.AuthMethod = ""
}
- config.SaveConfig(getConfigPath(), appCfg)
+ config.SaveConfig(internal.GetConfigPath(), appCfg)
}
fmt.Printf("Logged out from %s\n", provider)
- } else {
- if err := auth.DeleteAllCredentials(); err != nil {
- fmt.Printf("Failed to remove credentials: %v\n", err)
- os.Exit(1)
- }
- appCfg, err := loadConfig()
- if err == nil {
- // Clear all AuthMethods in ModelList
- for i := range appCfg.ModelList {
- appCfg.ModelList[i].AuthMethod = ""
- }
- // Clear all AuthMethods in Providers (legacy)
- appCfg.Providers.OpenAI.AuthMethod = ""
- appCfg.Providers.Anthropic.AuthMethod = ""
- appCfg.Providers.Antigravity.AuthMethod = ""
- config.SaveConfig(getConfigPath(), appCfg)
- }
-
- fmt.Println("Logged out from all providers")
+ return nil
}
+
+ if err := auth.DeleteAllCredentials(); err != nil {
+ return fmt.Errorf("failed to remove credentials: %w", err)
+ }
+
+ appCfg, err := internal.LoadConfig()
+ if err == nil {
+ // Clear all AuthMethods in ModelList
+ for i := range appCfg.ModelList {
+ appCfg.ModelList[i].AuthMethod = ""
+ }
+ // Clear all AuthMethods in Providers (legacy)
+ appCfg.Providers.OpenAI.AuthMethod = ""
+ appCfg.Providers.Anthropic.AuthMethod = ""
+ appCfg.Providers.Antigravity.AuthMethod = ""
+ config.SaveConfig(internal.GetConfigPath(), appCfg)
+ }
+
+ fmt.Println("Logged out from all providers")
+
+ return nil
}
-func authStatusCmd() {
+func authStatusCmd() error {
store, err := auth.LoadStore()
if err != nil {
- fmt.Printf("Error loading auth store: %v\n", err)
- return
+ return fmt.Errorf("failed to load auth store: %w", err)
}
if len(store.Credentials) == 0 {
fmt.Println("No authenticated providers.")
fmt.Println("Run: picoclaw auth login --provider ")
- return
+ return nil
}
fmt.Println("\nAuthenticated Providers:")
@@ -437,14 +361,16 @@ func authStatusCmd() {
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
}
}
+
+ return nil
}
-func authModelsCmd() {
+func authModelsCmd() error {
cred, err := auth.GetCredential("google-antigravity")
if err != nil || cred == nil {
- fmt.Println("Not logged in to Google Antigravity.")
- fmt.Println("Run: picoclaw auth login --provider google-antigravity")
- return
+ return fmt.Errorf(
+ "not logged in to Google Antigravity.\nrun: picoclaw auth login --provider google-antigravity",
+ )
}
// Refresh token if needed
@@ -459,21 +385,18 @@ func authModelsCmd() {
projectID := cred.ProjectID
if projectID == "" {
- fmt.Println("No project ID stored. Try logging in again.")
- return
+ return fmt.Errorf("no project id stored. Try logging in again")
}
fmt.Printf("Fetching models for project: %s\n\n", projectID)
models, err := providers.FetchAntigravityModels(cred.AccessToken, projectID)
if err != nil {
- fmt.Printf("Error fetching models: %v\n", err)
- return
+ return fmt.Errorf("error fetching models: %w", err)
}
if len(models) == 0 {
- fmt.Println("No models available.")
- return
+ return fmt.Errorf("no models available")
}
fmt.Println("Available Antigravity Models:")
@@ -489,6 +412,8 @@ func authModelsCmd() {
}
fmt.Printf(" %s %s\n", status, name)
}
+
+ return nil
}
// isAntigravityModel checks if a model string belongs to antigravity provider
diff --git a/cmd/picoclaw/internal/auth/login.go b/cmd/picoclaw/internal/auth/login.go
new file mode 100644
index 000000000..9a6d28d2f
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/login.go
@@ -0,0 +1,25 @@
+package auth
+
+import "github.com/spf13/cobra"
+
+func newLoginCommand() *cobra.Command {
+ var (
+ provider string
+ useDeviceCode bool
+ )
+
+ cmd := &cobra.Command{
+ Use: "login",
+ Short: "Login via OAuth or paste token",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return authLoginCmd(provider, useDeviceCode)
+ },
+ }
+
+ cmd.Flags().StringVarP(&provider, "provider", "p", "", "Provider to login with (openai, anthropic)")
+ cmd.Flags().BoolVar(&useDeviceCode, "device-code", false, "Use device code flow (for headless environments)")
+ _ = cmd.MarkFlagRequired("provider")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/auth/login_test.go b/cmd/picoclaw/internal/auth/login_test.go
new file mode 100644
index 000000000..d6a03c25b
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/login_test.go
@@ -0,0 +1,29 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/spf13/cobra"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewLoginSubCommand(t *testing.T) {
+ cmd := newLoginCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "Login via OAuth or paste token", cmd.Short)
+
+ assert.True(t, cmd.HasFlags())
+
+ assert.NotNil(t, cmd.Flags().Lookup("device-code"))
+
+ providerFlag := cmd.Flags().Lookup("provider")
+ require.NotNil(t, providerFlag)
+
+ val, found := providerFlag.Annotations[cobra.BashCompOneRequiredFlag]
+ require.True(t, found)
+ require.NotEmpty(t, val)
+ assert.Equal(t, "true", val[0])
+}
diff --git a/cmd/picoclaw/internal/auth/logout.go b/cmd/picoclaw/internal/auth/logout.go
new file mode 100644
index 000000000..384667524
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/logout.go
@@ -0,0 +1,20 @@
+package auth
+
+import "github.com/spf13/cobra"
+
+func newLogoutCommand() *cobra.Command {
+ var provider string
+
+ cmd := &cobra.Command{
+ Use: "logout",
+ Short: "Remove stored credentials",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return authLogoutCmd(provider)
+ },
+ }
+
+ cmd.Flags().StringVarP(&provider, "provider", "p", "", "Provider to logout from (openai, anthropic); empty = all")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/auth/logout_test.go b/cmd/picoclaw/internal/auth/logout_test.go
new file mode 100644
index 000000000..c0f3a5e92
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/logout_test.go
@@ -0,0 +1,20 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewLogoutSubcommand(t *testing.T) {
+ cmd := newLogoutCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "Remove stored credentials", cmd.Short)
+
+ assert.True(t, cmd.HasFlags())
+
+ assert.NotNil(t, cmd.Flags().Lookup("provider"))
+}
diff --git a/cmd/picoclaw/internal/auth/models.go b/cmd/picoclaw/internal/auth/models.go
new file mode 100644
index 000000000..cabe6822c
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/models.go
@@ -0,0 +1,15 @@
+package auth
+
+import "github.com/spf13/cobra"
+
+func newModelsCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "models",
+ Short: "Show available models",
+ RunE: func(_ *cobra.Command, _ []string) error {
+ return authModelsCmd()
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/auth/models_test.go b/cmd/picoclaw/internal/auth/models_test.go
new file mode 100644
index 000000000..26ca67787
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/models_test.go
@@ -0,0 +1,19 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewModelsCommand(t *testing.T) {
+ cmd := newModelsCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "models", cmd.Use)
+ assert.Equal(t, "Show available models", cmd.Short)
+
+ assert.False(t, cmd.HasFlags())
+}
diff --git a/cmd/picoclaw/internal/auth/status.go b/cmd/picoclaw/internal/auth/status.go
new file mode 100644
index 000000000..ca3007d12
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/status.go
@@ -0,0 +1,16 @@
+package auth
+
+import "github.com/spf13/cobra"
+
+func newStatusCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "status",
+ Short: "Show current auth status",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return authStatusCmd()
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/auth/status_test.go b/cmd/picoclaw/internal/auth/status_test.go
new file mode 100644
index 000000000..7748ba502
--- /dev/null
+++ b/cmd/picoclaw/internal/auth/status_test.go
@@ -0,0 +1,18 @@
+package auth
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewStatusSubcommand(t *testing.T) {
+ cmd := newStatusCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "Show current auth status", cmd.Short)
+
+ assert.False(t, cmd.HasFlags())
+}
diff --git a/cmd/picoclaw/internal/cron/add.go b/cmd/picoclaw/internal/cron/add.go
new file mode 100644
index 000000000..947557d5a
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/add.go
@@ -0,0 +1,64 @@
+package cron
+
+import (
+ "fmt"
+
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/cron"
+)
+
+func newAddCommand(storePath func() string) *cobra.Command {
+ var (
+ name string
+ message string
+ every int64
+ cronExp string
+ deliver bool
+ channel string
+ to string
+ )
+
+ cmd := &cobra.Command{
+ Use: "add",
+ Short: "Add a new scheduled job",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ if every <= 0 && cronExp == "" {
+ return fmt.Errorf("either --every or --cron must be specified")
+ }
+
+ var schedule cron.CronSchedule
+ if every > 0 {
+ everyMS := every * 1000
+ schedule = cron.CronSchedule{Kind: "every", EveryMS: &everyMS}
+ } else {
+ schedule = cron.CronSchedule{Kind: "cron", Expr: cronExp}
+ }
+
+ cs := cron.NewCronService(storePath(), nil)
+ job, err := cs.AddJob(name, schedule, message, deliver, channel, to)
+ if err != nil {
+ return fmt.Errorf("error adding job: %w", err)
+ }
+
+ fmt.Printf("✓ Added job '%s' (%s)\n", job.Name, job.ID)
+
+ return nil
+ },
+ }
+
+ cmd.Flags().StringVarP(&name, "name", "n", "", "Job name")
+ cmd.Flags().StringVarP(&message, "message", "m", "", "Message for agent")
+ cmd.Flags().Int64VarP(&every, "every", "e", 0, "Run every N seconds")
+ cmd.Flags().StringVarP(&cronExp, "cron", "c", "", "Cron expression (e.g. '0 9 * * *')")
+ cmd.Flags().BoolVarP(&deliver, "deliver", "d", false, "Deliver response to channel")
+ cmd.Flags().StringVar(&to, "to", "", "Recipient for delivery")
+ cmd.Flags().StringVar(&channel, "channel", "", "Channel for delivery")
+
+ _ = cmd.MarkFlagRequired("name")
+ _ = cmd.MarkFlagRequired("message")
+ cmd.MarkFlagsMutuallyExclusive("every", "cron")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/cron/add_test.go b/cmd/picoclaw/internal/cron/add_test.go
new file mode 100644
index 000000000..09701fab5
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/add_test.go
@@ -0,0 +1,57 @@
+package cron
+
+import (
+ "testing"
+
+ "github.com/spf13/cobra"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewAddSubcommand(t *testing.T) {
+ fn := func() string { return "" }
+ cmd := newAddCommand(fn)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "add", cmd.Use)
+ assert.Equal(t, "Add a new scheduled job", cmd.Short)
+
+ assert.True(t, cmd.HasFlags())
+
+ assert.NotNil(t, cmd.Flags().Lookup("every"))
+ assert.NotNil(t, cmd.Flags().Lookup("cron"))
+ assert.NotNil(t, cmd.Flags().Lookup("deliver"))
+ assert.NotNil(t, cmd.Flags().Lookup("to"))
+ assert.NotNil(t, cmd.Flags().Lookup("channel"))
+
+ nameFlag := cmd.Flags().Lookup("name")
+ require.NotNil(t, nameFlag)
+
+ messageFlag := cmd.Flags().Lookup("message")
+ require.NotNil(t, messageFlag)
+
+ val, found := nameFlag.Annotations[cobra.BashCompOneRequiredFlag]
+ require.True(t, found)
+ require.NotEmpty(t, val)
+ assert.Equal(t, "true", val[0])
+
+ val, found = messageFlag.Annotations[cobra.BashCompOneRequiredFlag]
+ require.True(t, found)
+ require.NotEmpty(t, val)
+ assert.Equal(t, "true", val[0])
+}
+
+func TestNewAddCommandEveryAndCronMutuallyExclusive(t *testing.T) {
+ cmd := newAddCommand(func() string { return "testing" })
+
+ cmd.SetArgs([]string{
+ "--name", "job",
+ "--message", "hello",
+ "--every", "10",
+ "--cron", "0 9 * * *",
+ })
+
+ err := cmd.Execute()
+ require.Error(t, err)
+}
diff --git a/cmd/picoclaw/internal/cron/command.go b/cmd/picoclaw/internal/cron/command.go
new file mode 100644
index 000000000..39f8ccf28
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/command.go
@@ -0,0 +1,44 @@
+package cron
+
+import (
+ "fmt"
+ "path/filepath"
+
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+)
+
+func NewCronCommand() *cobra.Command {
+ var storePath string
+
+ cmd := &cobra.Command{
+ Use: "cron",
+ Aliases: []string{"c"},
+ Short: "Manage scheduled tasks",
+ Args: cobra.NoArgs,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return cmd.Help()
+ },
+ // Resolve storePath at execution time so it reflects the current config
+ // and is shared across all subcommands.
+ PersistentPreRunE: func(_ *cobra.Command, _ []string) error {
+ cfg, err := internal.LoadConfig()
+ if err != nil {
+ return fmt.Errorf("error loading config: %w", err)
+ }
+ storePath = filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json")
+ return nil
+ },
+ }
+
+ cmd.AddCommand(
+ newListCommand(func() string { return storePath }),
+ newAddCommand(func() string { return storePath }),
+ newRemoveCommand(func() string { return storePath }),
+ newEnableCommand(func() string { return storePath }),
+ newDisableCommand(func() string { return storePath }),
+ )
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/cron/command_test.go b/cmd/picoclaw/internal/cron/command_test.go
new file mode 100644
index 000000000..af2ac83ae
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/command_test.go
@@ -0,0 +1,58 @@
+package cron
+
+import (
+ "slices"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewCronCommand(t *testing.T) {
+ cmd := NewCronCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "Manage scheduled tasks", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 1)
+ assert.True(t, cmd.HasAlias("c"))
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.NotNil(t, cmd.PersistentPreRunE)
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ assert.True(t, cmd.HasSubCommands())
+
+ allowedCommands := []string{
+ "list",
+ "add",
+ "remove",
+ "enable",
+ "disable",
+ }
+
+ subcommands := cmd.Commands()
+ assert.Len(t, subcommands, len(allowedCommands))
+
+ for _, subcmd := range subcommands {
+ found := slices.Contains(allowedCommands, subcmd.Name())
+ assert.True(t, found, "unexpected subcommand %q", subcmd.Name())
+
+ assert.Len(t, subcmd.Aliases, 0)
+ assert.False(t, subcmd.Hidden)
+
+ assert.False(t, subcmd.HasSubCommands())
+
+ assert.Nil(t, subcmd.Run)
+ assert.NotNil(t, subcmd.RunE)
+
+ assert.Nil(t, subcmd.PersistentPreRun)
+ assert.Nil(t, subcmd.PersistentPostRun)
+ }
+}
diff --git a/cmd/picoclaw/internal/cron/disable.go b/cmd/picoclaw/internal/cron/disable.go
new file mode 100644
index 000000000..a3670fd50
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/disable.go
@@ -0,0 +1,16 @@
+package cron
+
+import "github.com/spf13/cobra"
+
+func newDisableCommand(storePath func() string) *cobra.Command {
+ return &cobra.Command{
+ Use: "disable",
+ Short: "Disable a job",
+ Args: cobra.ExactArgs(1),
+ Example: `picoclaw cron disable 1`,
+ RunE: func(_ *cobra.Command, args []string) error {
+ cronSetJobEnabled(storePath(), args[0], false)
+ return nil
+ },
+ }
+}
diff --git a/cmd/picoclaw/internal/cron/disable_test.go b/cmd/picoclaw/internal/cron/disable_test.go
new file mode 100644
index 000000000..e5d2ff844
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/disable_test.go
@@ -0,0 +1,20 @@
+package cron
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestDisableSubcommand(t *testing.T) {
+ fn := func() string { return "" }
+ cmd := newDisableCommand(fn)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "disable", cmd.Use)
+ assert.Equal(t, "Disable a job", cmd.Short)
+
+ assert.True(t, cmd.HasExample())
+}
diff --git a/cmd/picoclaw/internal/cron/enable.go b/cmd/picoclaw/internal/cron/enable.go
new file mode 100644
index 000000000..7f8b05233
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/enable.go
@@ -0,0 +1,16 @@
+package cron
+
+import "github.com/spf13/cobra"
+
+func newEnableCommand(storePath func() string) *cobra.Command {
+ return &cobra.Command{
+ Use: "enable",
+ Short: "Enable a job",
+ Args: cobra.ExactArgs(1),
+ Example: `picoclaw cron enable 1`,
+ RunE: func(_ *cobra.Command, args []string) error {
+ cronSetJobEnabled(storePath(), args[0], true)
+ return nil
+ },
+ }
+}
diff --git a/cmd/picoclaw/internal/cron/enable_test.go b/cmd/picoclaw/internal/cron/enable_test.go
new file mode 100644
index 000000000..85a2e01aa
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/enable_test.go
@@ -0,0 +1,20 @@
+package cron
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestEnableSubcommand(t *testing.T) {
+ fn := func() string { return "" }
+ cmd := newEnableCommand(fn)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "enable", cmd.Use)
+ assert.Equal(t, "Enable a job", cmd.Short)
+
+ assert.True(t, cmd.HasExample())
+}
diff --git a/cmd/picoclaw/internal/cron/helpers.go b/cmd/picoclaw/internal/cron/helpers.go
new file mode 100644
index 000000000..88bdf1bf7
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/helpers.go
@@ -0,0 +1,66 @@
+package cron
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/cron"
+)
+
+func cronListCmd(storePath string) {
+ cs := cron.NewCronService(storePath, nil)
+ jobs := cs.ListJobs(true) // Show all jobs, including disabled
+
+ if len(jobs) == 0 {
+ fmt.Println("No scheduled jobs.")
+ return
+ }
+
+ fmt.Println("\nScheduled Jobs:")
+ fmt.Println("----------------")
+ for _, job := range jobs {
+ var schedule string
+ if job.Schedule.Kind == "every" && job.Schedule.EveryMS != nil {
+ schedule = fmt.Sprintf("every %ds", *job.Schedule.EveryMS/1000)
+ } else if job.Schedule.Kind == "cron" {
+ schedule = job.Schedule.Expr
+ } else {
+ schedule = "one-time"
+ }
+
+ nextRun := "scheduled"
+ if job.State.NextRunAtMS != nil {
+ nextTime := time.UnixMilli(*job.State.NextRunAtMS)
+ nextRun = nextTime.Format("2006-01-02 15:04")
+ }
+
+ status := "enabled"
+ if !job.Enabled {
+ status = "disabled"
+ }
+
+ fmt.Printf(" %s (%s)\n", job.Name, job.ID)
+ fmt.Printf(" Schedule: %s\n", schedule)
+ fmt.Printf(" Status: %s\n", status)
+ fmt.Printf(" Next run: %s\n", nextRun)
+ }
+}
+
+func cronRemoveCmd(storePath, jobID string) {
+ cs := cron.NewCronService(storePath, nil)
+ if cs.RemoveJob(jobID) {
+ fmt.Printf("✓ Removed job %s\n", jobID)
+ } else {
+ fmt.Printf("✗ Job %s not found\n", jobID)
+ }
+}
+
+func cronSetJobEnabled(storePath, jobID string, enabled bool) {
+ cs := cron.NewCronService(storePath, nil)
+ job := cs.EnableJob(jobID, enabled)
+ if job != nil {
+ fmt.Printf("✓ Job '%s' enabled\n", job.Name)
+ } else {
+ fmt.Printf("✗ Job %s not found\n", jobID)
+ }
+}
diff --git a/cmd/picoclaw/internal/cron/list.go b/cmd/picoclaw/internal/cron/list.go
new file mode 100644
index 000000000..854eb1a44
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/list.go
@@ -0,0 +1,17 @@
+package cron
+
+import "github.com/spf13/cobra"
+
+func newListCommand(storePath func() string) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "list",
+ Short: "List all scheduled jobs",
+ Args: cobra.NoArgs,
+ RunE: func(_ *cobra.Command, _ []string) error {
+ cronListCmd(storePath())
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/cron/list_test.go b/cmd/picoclaw/internal/cron/list_test.go
new file mode 100644
index 000000000..0b9d1bd59
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/list_test.go
@@ -0,0 +1,17 @@
+package cron
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewListSubcommand(t *testing.T) {
+ fn := func() string { return "" }
+ cmd := newListCommand(fn)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "List all scheduled jobs", cmd.Short)
+}
diff --git a/cmd/picoclaw/internal/cron/remove.go b/cmd/picoclaw/internal/cron/remove.go
new file mode 100644
index 000000000..5f1d1a04b
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/remove.go
@@ -0,0 +1,18 @@
+package cron
+
+import "github.com/spf13/cobra"
+
+func newRemoveCommand(storePath func() string) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "remove",
+ Short: "Remove a job by ID",
+ Args: cobra.ExactArgs(1),
+ Example: `picoclaw cron remove 1`,
+ RunE: func(_ *cobra.Command, args []string) error {
+ cronRemoveCmd(storePath(), args[0])
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/cron/remove_test.go b/cmd/picoclaw/internal/cron/remove_test.go
new file mode 100644
index 000000000..36121f370
--- /dev/null
+++ b/cmd/picoclaw/internal/cron/remove_test.go
@@ -0,0 +1,19 @@
+package cron
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewRemoveSubcommand(t *testing.T) {
+ fn := func() string { return "" }
+ cmd := newRemoveCommand(fn)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "Remove a job by ID", cmd.Short)
+
+ assert.True(t, cmd.HasExample())
+}
diff --git a/cmd/picoclaw/internal/gateway/command.go b/cmd/picoclaw/internal/gateway/command.go
new file mode 100644
index 000000000..66a56f9ce
--- /dev/null
+++ b/cmd/picoclaw/internal/gateway/command.go
@@ -0,0 +1,23 @@
+package gateway
+
+import (
+ "github.com/spf13/cobra"
+)
+
+func NewGatewayCommand() *cobra.Command {
+ var debug bool
+
+ cmd := &cobra.Command{
+ Use: "gateway",
+ Aliases: []string{"g"},
+ Short: "Start picoclaw gateway",
+ Args: cobra.NoArgs,
+ RunE: func(_ *cobra.Command, _ []string) error {
+ return gatewayCmd(debug)
+ },
+ }
+
+ cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/gateway/command_test.go b/cmd/picoclaw/internal/gateway/command_test.go
new file mode 100644
index 000000000..4d591ea67
--- /dev/null
+++ b/cmd/picoclaw/internal/gateway/command_test.go
@@ -0,0 +1,31 @@
+package gateway
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewGatewayCommand(t *testing.T) {
+ cmd := NewGatewayCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "gateway", cmd.Use)
+ assert.Equal(t, "Start picoclaw gateway", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 1)
+ assert.True(t, cmd.HasAlias("g"))
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.True(t, cmd.HasFlags())
+ assert.NotNil(t, cmd.Flags().Lookup("debug"))
+}
diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/internal/gateway/helpers.go
similarity index 78%
rename from cmd/picoclaw/cmd_gateway.go
rename to cmd/picoclaw/internal/gateway/helpers.go
index 1f1bf5491..a06625dc9 100644
--- a/cmd/picoclaw/cmd_gateway.go
+++ b/cmd/picoclaw/internal/gateway/helpers.go
@@ -1,17 +1,17 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
+package gateway
import (
"context"
+ "errors"
"fmt"
"net/http"
"os"
"os/signal"
"path/filepath"
+ "strings"
"time"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -27,31 +27,25 @@ import (
"github.com/sipeed/picoclaw/pkg/voice"
)
-func gatewayCmd() {
- // Check for --debug flag
- args := os.Args[2:]
- for _, arg := range args {
- if arg == "--debug" || arg == "-d" {
- logger.SetLevel(logger.DEBUG)
- fmt.Println("🔍 Debug mode enabled")
- break
- }
+func gatewayCmd(debug bool) error {
+ if debug {
+ logger.SetLevel(logger.DEBUG)
+ fmt.Println("🔍 Debug mode enabled")
}
- cfg, err := loadConfig()
+ cfg, err := internal.LoadConfig()
if err != nil {
- fmt.Printf("Error loading config: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error loading config: %w", err)
}
provider, modelID, err := providers.CreateProvider(cfg)
if err != nil {
- fmt.Printf("Error creating provider: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error creating provider: %w", err)
}
+
// Use the resolved model ID from provider creation
if modelID != "" {
- cfg.Agents.Defaults.Model = modelID
+ cfg.Agents.Defaults.ModelName = modelID
}
msgBus := bus.NewMessageBus()
@@ -60,8 +54,8 @@ func gatewayCmd() {
// Print agent startup info
fmt.Println("\n📦 Agent Status:")
startupInfo := agentLoop.GetStartupInfo()
- toolsInfo := startupInfo["tools"].(map[string]interface{})
- skillsInfo := startupInfo["skills"].(map[string]interface{})
+ toolsInfo := startupInfo["tools"].(map[string]any)
+ skillsInfo := startupInfo["skills"].(map[string]any)
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
fmt.Printf(" • Skills: %d/%d available\n",
skillsInfo["available"],
@@ -69,7 +63,7 @@ func gatewayCmd() {
// Log to file as well
logger.InfoCF("agent", "Agent initialized",
- map[string]interface{}{
+ map[string]any{
"tools_count": toolsInfo["count"],
"skills_total": skillsInfo["total"],
"skills_available": skillsInfo["available"],
@@ -77,7 +71,14 @@ func gatewayCmd() {
// Setup cron tool and service
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
- cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout, cfg)
+ cronService := setupCronTool(
+ agentLoop,
+ msgBus,
+ cfg.WorkspacePath(),
+ cfg.Agents.Defaults.RestrictToWorkspace,
+ execTimeout,
+ cfg,
+ )
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -91,7 +92,8 @@ func gatewayCmd() {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
- response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
+ var response string
+ response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
@@ -105,16 +107,24 @@ func gatewayCmd() {
channelManager, err := channels.NewManager(cfg, msgBus)
if err != nil {
- fmt.Printf("Error creating channel manager: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("error creating channel manager: %w", err)
}
// Inject channel manager into agent loop for command handling
agentLoop.SetChannelManager(channelManager)
var transcriber *voice.GroqTranscriber
- if cfg.Providers.Groq.APIKey != "" {
- transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey)
+ groqAPIKey := cfg.Providers.Groq.APIKey
+ if groqAPIKey == "" {
+ for _, mc := range cfg.ModelList {
+ if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" {
+ groqAPIKey = mc.APIKey
+ break
+ }
+ }
+ }
+ if groqAPIKey != "" {
+ transcriber = voice.NewGroqTranscriber(groqAPIKey)
logger.InfoC("voice", "Groq voice transcription enabled")
}
@@ -180,8 +190,8 @@ func gatewayCmd() {
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
go func() {
- if err := healthServer.Start(); err != nil && err != http.ErrServerClosed {
- logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()})
+ if err := healthServer.Start(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()})
}
}()
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
@@ -193,6 +203,9 @@ func gatewayCmd() {
<-sigChan
fmt.Println("\nShutting down...")
+ if cp, ok := provider.(providers.StatefulProvider); ok {
+ cp.Close()
+ }
cancel()
healthServer.Stop(context.Background())
deviceService.Stop()
@@ -201,9 +214,18 @@ func gatewayCmd() {
agentLoop.Stop()
channelManager.StopAll(ctx)
fmt.Println("✓ Gateway stopped")
+
+ return nil
}
-func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, cfg *config.Config) *cron.CronService {
+func setupCronTool(
+ agentLoop *agent.AgentLoop,
+ msgBus *bus.MessageBus,
+ workspace string,
+ restrict bool,
+ execTimeout time.Duration,
+ cfg *config.Config,
+) *cron.CronService {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go
new file mode 100644
index 000000000..1f52df5dd
--- /dev/null
+++ b/cmd/picoclaw/internal/helpers.go
@@ -0,0 +1,52 @@
+package internal
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+const Logo = "🦞"
+
+var (
+ version = "dev"
+ gitCommit string
+ buildTime string
+ goVersion string
+)
+
+func GetConfigPath() string {
+ home, _ := os.UserHomeDir()
+ return filepath.Join(home, ".picoclaw", "config.json")
+}
+
+func LoadConfig() (*config.Config, error) {
+ return config.LoadConfig(GetConfigPath())
+}
+
+// FormatVersion returns the version string with optional git commit
+func FormatVersion() string {
+ v := version
+ if gitCommit != "" {
+ v += fmt.Sprintf(" (git: %s)", gitCommit)
+ }
+ return v
+}
+
+// FormatBuildInfo returns build time and go version info
+func FormatBuildInfo() (string, string) {
+ build := buildTime
+ goVer := goVersion
+ if goVer == "" {
+ goVer = runtime.Version()
+ }
+ return build, goVer
+}
+
+// GetVersion returns the version string
+func GetVersion() string {
+ return version
+}
diff --git a/cmd/picoclaw/internal/helpers_test.go b/cmd/picoclaw/internal/helpers_test.go
new file mode 100644
index 000000000..9342d141d
--- /dev/null
+++ b/cmd/picoclaw/internal/helpers_test.go
@@ -0,0 +1,97 @@
+package internal
+
+import (
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGetConfigPath(t *testing.T) {
+ t.Setenv("HOME", "/tmp/home")
+
+ got := GetConfigPath()
+ want := filepath.Join("/tmp/home", ".picoclaw", "config.json")
+
+ assert.Equal(t, want, got)
+}
+
+func TestFormatVersion_NoGitCommit(t *testing.T) {
+ oldVersion, oldGit := version, gitCommit
+ t.Cleanup(func() { version, gitCommit = oldVersion, oldGit })
+
+ version = "1.2.3"
+ gitCommit = ""
+
+ assert.Equal(t, "1.2.3", FormatVersion())
+}
+
+func TestFormatVersion_WithGitCommit(t *testing.T) {
+ oldVersion, oldGit := version, gitCommit
+ t.Cleanup(func() { version, gitCommit = oldVersion, oldGit })
+
+ version = "1.2.3"
+ gitCommit = "abc123"
+
+ assert.Equal(t, "1.2.3 (git: abc123)", FormatVersion())
+}
+
+func TestFormatBuildInfo_UsesBuildTimeAndGoVersion_WhenSet(t *testing.T) {
+ oldBuildTime, oldGoVersion := buildTime, goVersion
+ t.Cleanup(func() { buildTime, goVersion = oldBuildTime, oldGoVersion })
+
+ buildTime = "2026-02-20T00:00:00Z"
+ goVersion = "go1.23.0"
+
+ build, goVer := FormatBuildInfo()
+
+ assert.Equal(t, buildTime, build)
+ assert.Equal(t, goVersion, goVer)
+}
+
+func TestFormatBuildInfo_EmptyBuildTime_ReturnsEmptyBuild(t *testing.T) {
+ oldBuildTime, oldGoVersion := buildTime, goVersion
+ t.Cleanup(func() { buildTime, goVersion = oldBuildTime, oldGoVersion })
+
+ buildTime = ""
+ goVersion = "go1.23.0"
+
+ build, goVer := FormatBuildInfo()
+
+ assert.Empty(t, build)
+ assert.Equal(t, goVersion, goVer)
+}
+
+func TestFormatBuildInfo_EmptyGoVersion_FallsBackToRuntimeVersion(t *testing.T) {
+ oldBuildTime, oldGoVersion := buildTime, goVersion
+ t.Cleanup(func() { buildTime, goVersion = oldBuildTime, oldGoVersion })
+
+ buildTime = "x"
+ goVersion = ""
+
+ build, goVer := FormatBuildInfo()
+
+ assert.Equal(t, "x", build)
+ assert.Equal(t, runtime.Version(), goVer)
+}
+
+func TestGetConfigPath_Windows(t *testing.T) {
+ if runtime.GOOS != "windows" {
+ t.Skip("windows-specific HOME behavior varies; run on windows")
+ }
+
+ testUserProfilePath := `C:\Users\Test`
+ t.Setenv("USERPROFILE", testUserProfilePath)
+
+ got := GetConfigPath()
+ want := filepath.Join(testUserProfilePath, ".picoclaw", "config.json")
+
+ require.True(t, strings.EqualFold(got, want), "GetConfigPath() = %q, want %q", got, want)
+}
+
+func TestGetVersion(t *testing.T) {
+ assert.Equal(t, "dev", GetVersion())
+}
diff --git a/cmd/picoclaw/internal/migrate/command.go b/cmd/picoclaw/internal/migrate/command.go
new file mode 100644
index 000000000..fb1cee164
--- /dev/null
+++ b/cmd/picoclaw/internal/migrate/command.go
@@ -0,0 +1,48 @@
+package migrate
+
+import (
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/migrate"
+)
+
+func NewMigrateCommand() *cobra.Command {
+ var opts migrate.Options
+
+ cmd := &cobra.Command{
+ Use: "migrate",
+ Short: "Migrate from OpenClaw to PicoClaw",
+ Args: cobra.NoArgs,
+ Example: ` picoclaw migrate
+ picoclaw migrate --dry-run
+ picoclaw migrate --refresh
+ picoclaw migrate --force`,
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ result, err := migrate.Run(opts)
+ if err != nil {
+ return err
+ }
+ if !opts.DryRun {
+ migrate.PrintSummary(result)
+ }
+ return nil
+ },
+ }
+
+ cmd.Flags().BoolVar(&opts.DryRun, "dry-run", false,
+ "Show what would be migrated without making changes")
+ cmd.Flags().BoolVar(&opts.Refresh, "refresh", false,
+ "Re-sync workspace files from OpenClaw (repeatable)")
+ cmd.Flags().BoolVar(&opts.ConfigOnly, "config-only", false,
+ "Only migrate config, skip workspace files")
+ cmd.Flags().BoolVar(&opts.WorkspaceOnly, "workspace-only", false,
+ "Only migrate workspace files, skip config")
+ cmd.Flags().BoolVar(&opts.Force, "force", false,
+ "Skip confirmation prompts")
+ cmd.Flags().StringVar(&opts.OpenClawHome, "openclaw-home", "",
+ "Override OpenClaw home directory (default: ~/.openclaw)")
+ cmd.Flags().StringVar(&opts.PicoClawHome, "picoclaw-home", "",
+ "Override PicoClaw home directory (default: ~/.picoclaw)")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/migrate/command_test.go b/cmd/picoclaw/internal/migrate/command_test.go
new file mode 100644
index 000000000..1948aa327
--- /dev/null
+++ b/cmd/picoclaw/internal/migrate/command_test.go
@@ -0,0 +1,38 @@
+package migrate
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewMigrateCommand(t *testing.T) {
+ cmd := NewMigrateCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "migrate", cmd.Use)
+ assert.Equal(t, "Migrate from OpenClaw to PicoClaw", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 0)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ assert.True(t, cmd.HasFlags())
+
+ assert.NotNil(t, cmd.Flags().Lookup("dry-run"))
+ assert.NotNil(t, cmd.Flags().Lookup("refresh"))
+ assert.NotNil(t, cmd.Flags().Lookup("config-only"))
+ assert.NotNil(t, cmd.Flags().Lookup("workspace-only"))
+ assert.NotNil(t, cmd.Flags().Lookup("force"))
+ assert.NotNil(t, cmd.Flags().Lookup("openclaw-home"))
+ assert.NotNil(t, cmd.Flags().Lookup("picoclaw-home"))
+}
diff --git a/cmd/picoclaw/internal/onboard/command.go b/cmd/picoclaw/internal/onboard/command.go
new file mode 100644
index 000000000..ec1012959
--- /dev/null
+++ b/cmd/picoclaw/internal/onboard/command.go
@@ -0,0 +1,24 @@
+package onboard
+
+import (
+ "embed"
+
+ "github.com/spf13/cobra"
+)
+
+//go:generate cp -r ../../../../workspace .
+//go:embed workspace
+var embeddedFiles embed.FS
+
+func NewOnboardCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "onboard",
+ Aliases: []string{"o"},
+ Short: "Initialize picoclaw configuration and workspace",
+ Run: func(cmd *cobra.Command, args []string) {
+ onboard()
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/onboard/command_test.go b/cmd/picoclaw/internal/onboard/command_test.go
new file mode 100644
index 000000000..bc799a079
--- /dev/null
+++ b/cmd/picoclaw/internal/onboard/command_test.go
@@ -0,0 +1,29 @@
+package onboard
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewOnboardCommand(t *testing.T) {
+ cmd := NewOnboardCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "onboard", cmd.Use)
+ assert.Equal(t, "Initialize picoclaw configuration and workspace", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 1)
+ assert.True(t, cmd.HasAlias("o"))
+
+ assert.NotNil(t, cmd.Run)
+ assert.Nil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ assert.False(t, cmd.HasFlags())
+ assert.False(t, cmd.HasSubCommands())
+}
diff --git a/cmd/picoclaw/cmd_onboard.go b/cmd/picoclaw/internal/onboard/helpers.go
similarity index 83%
rename from cmd/picoclaw/cmd_onboard.go
rename to cmd/picoclaw/internal/onboard/helpers.go
index 6e61e3267..4db8bdc8b 100644
--- a/cmd/picoclaw/cmd_onboard.go
+++ b/cmd/picoclaw/internal/onboard/helpers.go
@@ -1,24 +1,17 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
+package onboard
import (
- "embed"
"fmt"
"io/fs"
"os"
"path/filepath"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/config"
)
-//go:generate cp -r ../../workspace .
-//go:embed workspace
-var embeddedFiles embed.FS
-
func onboard() {
- configPath := getConfigPath()
+ configPath := internal.GetConfigPath()
if _, err := os.Stat(configPath); err == nil {
fmt.Printf("Config already exists at %s\n", configPath)
@@ -40,7 +33,7 @@ func onboard() {
workspace := cfg.WorkspacePath()
createWorkspaceTemplates(workspace)
- fmt.Printf("%s picoclaw is ready!\n", logo)
+ fmt.Printf("%s picoclaw is ready!\n", internal.Logo)
fmt.Println("\nNext steps:")
fmt.Println(" 1. Add your API key to", configPath)
fmt.Println("")
@@ -53,9 +46,16 @@ func onboard() {
fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"")
}
+func createWorkspaceTemplates(workspace string) {
+ err := copyEmbeddedToTarget(workspace)
+ if err != nil {
+ fmt.Printf("Error copying workspace templates: %v\n", err)
+ }
+}
+
func copyEmbeddedToTarget(targetDir string) error {
// Ensure target directory exists
- if err := os.MkdirAll(targetDir, 0755); err != nil {
+ if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("Failed to create target directory: %w", err)
}
@@ -85,12 +85,12 @@ func copyEmbeddedToTarget(targetDir string) error {
targetPath := filepath.Join(targetDir, new_path)
// Ensure target file's directory exists
- if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
+ if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("Failed to create directory %s: %w", filepath.Dir(targetPath), err)
}
// Write file
- if err := os.WriteFile(targetPath, data, 0644); err != nil {
+ if err := os.WriteFile(targetPath, data, 0o644); err != nil {
return fmt.Errorf("Failed to write file %s: %w", targetPath, err)
}
@@ -99,10 +99,3 @@ func copyEmbeddedToTarget(targetDir string) error {
return err
}
-
-func createWorkspaceTemplates(workspace string) {
- err := copyEmbeddedToTarget(workspace)
- if err != nil {
- fmt.Printf("Error copying workspace templates: %v\n", err)
- }
-}
diff --git a/cmd/picoclaw/internal/skills/command.go b/cmd/picoclaw/internal/skills/command.go
new file mode 100644
index 000000000..7f8bd011d
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/command.go
@@ -0,0 +1,79 @@
+package skills
+
+import (
+ "fmt"
+ "path/filepath"
+
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+type deps struct {
+ workspace string
+ installer *skills.SkillInstaller
+ skillsLoader *skills.SkillsLoader
+}
+
+func NewSkillsCommand() *cobra.Command {
+ var d deps
+
+ cmd := &cobra.Command{
+ Use: "skills",
+ Short: "Manage skills",
+ PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
+ cfg, err := internal.LoadConfig()
+ if err != nil {
+ return fmt.Errorf("error loading config: %w", err)
+ }
+
+ d.workspace = cfg.WorkspacePath()
+ d.installer = skills.NewSkillInstaller(d.workspace)
+
+ // get global config directory and builtin skills directory
+ globalDir := filepath.Dir(internal.GetConfigPath())
+ globalSkillsDir := filepath.Join(globalDir, "skills")
+ builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills")
+ d.skillsLoader = skills.NewSkillsLoader(d.workspace, globalSkillsDir, builtinSkillsDir)
+
+ return nil
+ },
+ RunE: func(cmd *cobra.Command, _ []string) error {
+ return cmd.Help()
+ },
+ }
+
+ installerFn := func() (*skills.SkillInstaller, error) {
+ if d.installer == nil {
+ return nil, fmt.Errorf("skills installer is not initialized")
+ }
+ return d.installer, nil
+ }
+
+ loaderFn := func() (*skills.SkillsLoader, error) {
+ if d.skillsLoader == nil {
+ return nil, fmt.Errorf("skills loader is not initialized")
+ }
+ return d.skillsLoader, nil
+ }
+
+ workspaceFn := func() (string, error) {
+ if d.workspace == "" {
+ return "", fmt.Errorf("workspace is not initialized")
+ }
+ return d.workspace, nil
+ }
+
+ cmd.AddCommand(
+ newListCommand(loaderFn),
+ newInstallCommand(installerFn),
+ newInstallBuiltinCommand(workspaceFn),
+ newListBuiltinCommand(),
+ newRemoveCommand(installerFn),
+ newSearchCommand(installerFn),
+ newShowCommand(loaderFn),
+ )
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/command_test.go b/cmd/picoclaw/internal/skills/command_test.go
new file mode 100644
index 000000000..0917d1384
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/command_test.go
@@ -0,0 +1,28 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewSkillsCommand(t *testing.T) {
+ cmd := NewSkillsCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "skills", cmd.Use)
+ assert.Equal(t, "Manage skills", cmd.Short)
+
+ assert.Len(t, cmd.Aliases, 0)
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.NotNil(t, cmd.PersistentPreRunE)
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+}
diff --git a/cmd/picoclaw/cmd_skills.go b/cmd/picoclaw/internal/skills/helpers.go
similarity index 57%
rename from cmd/picoclaw/cmd_skills.go
rename to cmd/picoclaw/internal/skills/helpers.go
index 9ea38dcf6..439b81a4f 100644
--- a/cmd/picoclaw/cmd_skills.go
+++ b/cmd/picoclaw/internal/skills/helpers.go
@@ -1,37 +1,20 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
+package skills
import (
"context"
"fmt"
+ "io"
"os"
"path/filepath"
"strings"
"time"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+ "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/skills"
+ "github.com/sipeed/picoclaw/pkg/utils"
)
-func skillsHelp() {
- fmt.Println("\nSkills commands:")
- fmt.Println(" list List installed skills")
- fmt.Println(" install Install skill from GitHub")
- fmt.Println(" install-builtin Install all builtin skills to workspace")
- fmt.Println(" list-builtin List available builtin skills")
- fmt.Println(" remove Remove installed skill")
- fmt.Println(" search Search available skills")
- fmt.Println(" show Show skill details")
- fmt.Println()
- fmt.Println("Examples:")
- fmt.Println(" picoclaw skills list")
- fmt.Println(" picoclaw skills install sipeed/picoclaw-skills/weather")
- fmt.Println(" picoclaw skills install-builtin")
- fmt.Println(" picoclaw skills list-builtin")
- fmt.Println(" picoclaw skills remove weather")
-}
-
func skillsListCmd(loader *skills.SkillsLoader) {
allSkills := loader.ListSkills()
@@ -50,25 +33,87 @@ func skillsListCmd(loader *skills.SkillsLoader) {
}
}
-func skillsInstallCmd(installer *skills.SkillInstaller) {
- if len(os.Args) < 4 {
- fmt.Println("Usage: picoclaw skills install ")
- fmt.Println("Example: picoclaw skills install sipeed/picoclaw-skills/weather")
- return
- }
-
- repo := os.Args[3]
+func skillsInstallCmd(installer *skills.SkillInstaller, repo string) error {
fmt.Printf("Installing skill from %s...\n", repo)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := installer.InstallFromGitHub(ctx, repo); err != nil {
- fmt.Printf("✗ Failed to install skill: %v\n", err)
- os.Exit(1)
+ return fmt.Errorf("failed to install skill: %w", err)
}
- fmt.Printf("✓ Skill '%s' installed successfully!\n", filepath.Base(repo))
+ fmt.Printf("\u2713 Skill '%s' installed successfully!\n", filepath.Base(repo))
+
+ return nil
+}
+
+// skillsInstallFromRegistry installs a skill from a named registry (e.g. clawhub).
+func skillsInstallFromRegistry(cfg *config.Config, registryName, slug string) error {
+ err := utils.ValidateSkillIdentifier(registryName)
+ if err != nil {
+ return fmt.Errorf("✗ invalid registry name: %w", err)
+ }
+
+ err = utils.ValidateSkillIdentifier(slug)
+ if err != nil {
+ return fmt.Errorf("✗ invalid slug: %w", err)
+ }
+
+ fmt.Printf("Installing skill '%s' from %s registry...\n", slug, registryName)
+
+ registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
+ MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
+ ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
+ })
+
+ registry := registryMgr.GetRegistry(registryName)
+ if registry == nil {
+ return fmt.Errorf("✗ registry '%s' not found or not enabled. check your config.json.", registryName)
+ }
+
+ workspace := cfg.WorkspacePath()
+ targetDir := filepath.Join(workspace, "skills", slug)
+
+ if _, err = os.Stat(targetDir); err == nil {
+ return fmt.Errorf("\u2717 skill '%s' already installed at %s", slug, targetDir)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+ defer cancel()
+
+ if err = os.MkdirAll(filepath.Join(workspace, "skills"), 0o755); err != nil {
+ return fmt.Errorf("\u2717 failed to create skills directory: %v", err)
+ }
+
+ result, err := registry.DownloadAndInstall(ctx, slug, "", targetDir)
+ if err != nil {
+ rmErr := os.RemoveAll(targetDir)
+ if rmErr != nil {
+ fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr)
+ }
+ return fmt.Errorf("✗ failed to install skill: %w", err)
+ }
+
+ if result.IsMalwareBlocked {
+ rmErr := os.RemoveAll(targetDir)
+ if rmErr != nil {
+ fmt.Printf("\u2717 Failed to remove partial install: %v\n", rmErr)
+ }
+
+ return fmt.Errorf("\u2717 Skill '%s' is flagged as malicious and cannot be installed.\n", slug)
+ }
+
+ if result.IsSuspicious {
+ fmt.Printf("\u26a0\ufe0f Warning: skill '%s' is flagged as suspicious.\n", slug)
+ }
+
+ fmt.Printf("\u2713 Skill '%s' v%s installed successfully!\n", slug, result.Version)
+ if result.Summary != "" {
+ fmt.Printf(" %s\n", result.Summary)
+ }
+
+ return nil
}
func skillsRemoveCmd(installer *skills.SkillInstaller, skillName string) {
@@ -104,7 +149,7 @@ func skillsInstallBuiltinCmd(workspace string) {
continue
}
- if err := os.MkdirAll(workspacePath, 0755); err != nil {
+ if err := os.MkdirAll(workspacePath, 0o755); err != nil {
fmt.Printf("✗ Failed to create directory for %s: %v\n", skillName, err)
continue
}
@@ -119,7 +164,7 @@ func skillsInstallBuiltinCmd(workspace string) {
}
func skillsListBuiltinCmd() {
- cfg, err := loadConfig()
+ cfg, err := internal.LoadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
return
@@ -214,3 +259,37 @@ func skillsShowCmd(loader *skills.SkillsLoader, skillName string) {
fmt.Println("----------------------")
fmt.Println(content)
}
+
+func copyDirectory(src, dst string) error {
+ return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ return err
+ }
+
+ relPath, err := filepath.Rel(src, path)
+ if err != nil {
+ return err
+ }
+
+ dstPath := filepath.Join(dst, relPath)
+
+ if info.IsDir() {
+ return os.MkdirAll(dstPath, info.Mode())
+ }
+
+ srcFile, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ defer srcFile.Close()
+
+ dstFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
+ if err != nil {
+ return err
+ }
+ defer dstFile.Close()
+
+ _, err = io.Copy(dstFile, srcFile)
+ return err
+ })
+}
diff --git a/cmd/picoclaw/internal/skills/install.go b/cmd/picoclaw/internal/skills/install.go
new file mode 100644
index 000000000..a30f68632
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/install.go
@@ -0,0 +1,58 @@
+package skills
+
+import (
+ "fmt"
+
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func newInstallCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command {
+ var registry string
+
+ cmd := &cobra.Command{
+ Use: "install",
+ Short: "Install skill from GitHub",
+ Example: `
+picoclaw skills install sipeed/picoclaw-skills/weather
+picoclaw skills install --registry clawhub github
+`,
+ Args: func(cmd *cobra.Command, args []string) error {
+ if registry != "" {
+ if len(args) != 2 {
+ return fmt.Errorf("when --registry is set, exactly 2 arguments are required: ")
+ }
+ return nil
+ }
+
+ if len(args) != 1 {
+ return fmt.Errorf("exactly 1 argument is required: ")
+ }
+
+ return nil
+ },
+ RunE: func(_ *cobra.Command, args []string) error {
+ installer, err := installerFn()
+ if err != nil {
+ return err
+ }
+
+ if registry != "" {
+ cfg, err := internal.LoadConfig()
+ if err != nil {
+ return err
+ }
+
+ return skillsInstallFromRegistry(cfg, args[0], args[1])
+ }
+
+ return skillsInstallCmd(installer, args[0])
+ },
+ }
+
+ cmd.Flags().StringVar(®istry, "registry", "", "Install from registry: --registry ")
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/install_test.go b/cmd/picoclaw/internal/skills/install_test.go
new file mode 100644
index 000000000..97787a986
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/install_test.go
@@ -0,0 +1,28 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewInstallSubcommand(t *testing.T) {
+ cmd := newInstallCommand(nil)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "install", cmd.Use)
+ assert.Equal(t, "Install skill from GitHub", cmd.Short)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.True(t, cmd.HasFlags())
+ assert.NotNil(t, cmd.Flags().Lookup("registry"))
+
+ assert.Len(t, cmd.Aliases, 0)
+}
diff --git a/cmd/picoclaw/internal/skills/installbuiltin.go b/cmd/picoclaw/internal/skills/installbuiltin.go
new file mode 100644
index 000000000..d4b7c6a9f
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/installbuiltin.go
@@ -0,0 +1,21 @@
+package skills
+
+import "github.com/spf13/cobra"
+
+func newInstallBuiltinCommand(workspaceFn func() (string, error)) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "install-builtin",
+ Short: "Install all builtin skills to workspace",
+ Example: `picoclaw skills install-builtin`,
+ RunE: func(_ *cobra.Command, _ []string) error {
+ workspace, err := workspaceFn()
+ if err != nil {
+ return err
+ }
+ skillsInstallBuiltinCmd(workspace)
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/installbuiltin_test.go b/cmd/picoclaw/internal/skills/installbuiltin_test.go
new file mode 100644
index 000000000..ea65907e3
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/installbuiltin_test.go
@@ -0,0 +1,27 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewInstallbuiltinSubcommand(t *testing.T) {
+ cmd := newInstallBuiltinCommand(nil)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "install-builtin", cmd.Use)
+ assert.Equal(t, "Install all builtin skills to workspace", cmd.Short)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Len(t, cmd.Aliases, 0)
+}
diff --git a/cmd/picoclaw/internal/skills/list.go b/cmd/picoclaw/internal/skills/list.go
new file mode 100644
index 000000000..7d89ff8ed
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/list.go
@@ -0,0 +1,25 @@
+package skills
+
+import (
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func newListCommand(loaderFn func() (*skills.SkillsLoader, error)) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "list",
+ Short: "List installed skills",
+ Example: `picoclaw skills list`,
+ RunE: func(_ *cobra.Command, _ []string) error {
+ loader, err := loaderFn()
+ if err != nil {
+ return err
+ }
+ skillsListCmd(loader)
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/list_test.go b/cmd/picoclaw/internal/skills/list_test.go
new file mode 100644
index 000000000..9947ce7aa
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/list_test.go
@@ -0,0 +1,27 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewListSubcommand(t *testing.T) {
+ cmd := newListCommand(nil)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "list", cmd.Use)
+ assert.Equal(t, "List installed skills", cmd.Short)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Len(t, cmd.Aliases, 0)
+}
diff --git a/cmd/picoclaw/internal/skills/listbuiltin.go b/cmd/picoclaw/internal/skills/listbuiltin.go
new file mode 100644
index 000000000..a3efb8d83
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/listbuiltin.go
@@ -0,0 +1,16 @@
+package skills
+
+import "github.com/spf13/cobra"
+
+func newListBuiltinCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "list-builtin",
+ Short: "List available builtin skills",
+ Example: `picoclaw skills list-builtin`,
+ Run: func(_ *cobra.Command, _ []string) {
+ skillsListBuiltinCmd()
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/listbuiltin_test.go b/cmd/picoclaw/internal/skills/listbuiltin_test.go
new file mode 100644
index 000000000..d4f45a436
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/listbuiltin_test.go
@@ -0,0 +1,26 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewListbuiltinSubcommand(t *testing.T) {
+ cmd := newListBuiltinCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "list-builtin", cmd.Use)
+ assert.Equal(t, "List available builtin skills", cmd.Short)
+
+ assert.NotNil(t, cmd.Run)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Len(t, cmd.Aliases, 0)
+}
diff --git a/cmd/picoclaw/internal/skills/remove.go b/cmd/picoclaw/internal/skills/remove.go
new file mode 100644
index 000000000..cd7d3a8b4
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/remove.go
@@ -0,0 +1,27 @@
+package skills
+
+import (
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func newRemoveCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "remove",
+ Aliases: []string{"rm", "uninstall"},
+ Short: "Remove installed skill",
+ Args: cobra.ExactArgs(1),
+ Example: `picoclaw skills remove weather`,
+ RunE: func(_ *cobra.Command, args []string) error {
+ installer, err := installerFn()
+ if err != nil {
+ return err
+ }
+ skillsRemoveCmd(installer, args[0])
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/remove_test.go b/cmd/picoclaw/internal/skills/remove_test.go
new file mode 100644
index 000000000..b4c79760c
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/remove_test.go
@@ -0,0 +1,29 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewRemoveSubcommand(t *testing.T) {
+ cmd := newRemoveCommand(nil)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "remove", cmd.Use)
+ assert.Equal(t, "Remove installed skill", cmd.Short)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Len(t, cmd.Aliases, 2)
+ assert.True(t, cmd.HasAlias("rm"))
+ assert.True(t, cmd.HasAlias("uninstall"))
+}
diff --git a/cmd/picoclaw/internal/skills/search.go b/cmd/picoclaw/internal/skills/search.go
new file mode 100644
index 000000000..53bc99109
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/search.go
@@ -0,0 +1,24 @@
+package skills
+
+import (
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func newSearchCommand(installerFn func() (*skills.SkillInstaller, error)) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "search",
+ Short: "Search available skills",
+ RunE: func(_ *cobra.Command, _ []string) error {
+ installer, err := installerFn()
+ if err != nil {
+ return err
+ }
+ skillsSearchCmd(installer)
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/search_test.go b/cmd/picoclaw/internal/skills/search_test.go
new file mode 100644
index 000000000..19f63a9ff
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/search_test.go
@@ -0,0 +1,25 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewSearchSubcommand(t *testing.T) {
+ cmd := newSearchCommand(nil)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "search", cmd.Use)
+ assert.Equal(t, "Search available skills", cmd.Short)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.False(t, cmd.HasSubCommands())
+ assert.False(t, cmd.HasFlags())
+
+ assert.Len(t, cmd.Aliases, 0)
+}
diff --git a/cmd/picoclaw/internal/skills/show.go b/cmd/picoclaw/internal/skills/show.go
new file mode 100644
index 000000000..e484f3f28
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/show.go
@@ -0,0 +1,26 @@
+package skills
+
+import (
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func newShowCommand(loaderFn func() (*skills.SkillsLoader, error)) *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "show",
+ Short: "Show skill details",
+ Args: cobra.ExactArgs(1),
+ Example: `picoclaw skills show weather`,
+ RunE: func(_ *cobra.Command, args []string) error {
+ loader, err := loaderFn()
+ if err != nil {
+ return err
+ }
+ skillsShowCmd(loader, args[0])
+ return nil
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/skills/show_test.go b/cmd/picoclaw/internal/skills/show_test.go
new file mode 100644
index 000000000..5858d2790
--- /dev/null
+++ b/cmd/picoclaw/internal/skills/show_test.go
@@ -0,0 +1,27 @@
+package skills
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewShowSubcommand(t *testing.T) {
+ cmd := newShowCommand(nil)
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "show", cmd.Use)
+ assert.Equal(t, "Show skill details", cmd.Short)
+
+ assert.Nil(t, cmd.Run)
+ assert.NotNil(t, cmd.RunE)
+
+ assert.True(t, cmd.HasExample())
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Len(t, cmd.Aliases, 0)
+}
diff --git a/cmd/picoclaw/internal/status/command.go b/cmd/picoclaw/internal/status/command.go
new file mode 100644
index 000000000..9303ae2ec
--- /dev/null
+++ b/cmd/picoclaw/internal/status/command.go
@@ -0,0 +1,18 @@
+package status
+
+import (
+ "github.com/spf13/cobra"
+)
+
+func NewStatusCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "status",
+ Aliases: []string{"s"},
+ Short: "Show picoclaw status",
+ Run: func(cmd *cobra.Command, args []string) {
+ statusCmd()
+ },
+ }
+
+ return cmd
+}
diff --git a/cmd/picoclaw/internal/status/command_test.go b/cmd/picoclaw/internal/status/command_test.go
new file mode 100644
index 000000000..974b4ea3d
--- /dev/null
+++ b/cmd/picoclaw/internal/status/command_test.go
@@ -0,0 +1,29 @@
+package status
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewStatusCommand(t *testing.T) {
+ cmd := NewStatusCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "status", cmd.Use)
+
+ assert.Len(t, cmd.Aliases, 1)
+ assert.True(t, cmd.HasAlias("s"))
+
+ assert.Equal(t, "Show picoclaw status", cmd.Short)
+
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.NotNil(t, cmd.Run)
+ assert.Nil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+}
diff --git a/cmd/picoclaw/cmd_status.go b/cmd/picoclaw/internal/status/helpers.go
similarity index 88%
rename from cmd/picoclaw/cmd_status.go
rename to cmd/picoclaw/internal/status/helpers.go
index 07296784e..ab28f4885 100644
--- a/cmd/picoclaw/cmd_status.go
+++ b/cmd/picoclaw/internal/status/helpers.go
@@ -1,27 +1,25 @@
-// PicoClaw - Ultra-lightweight personal AI agent
-// License: MIT
-
-package main
+package status
import (
"fmt"
"os"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/auth"
)
func statusCmd() {
- cfg, err := loadConfig()
+ cfg, err := internal.LoadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
return
}
- configPath := getConfigPath()
+ configPath := internal.GetConfigPath()
- fmt.Printf("%s picoclaw Status\n", logo)
- fmt.Printf("Version: %s\n", formatVersion())
- build, _ := formatBuildInfo()
+ fmt.Printf("%s picoclaw Status\n", internal.Logo)
+ fmt.Printf("Version: %s\n", internal.FormatVersion())
+ build, _ := internal.FormatBuildInfo()
if build != "" {
fmt.Printf("Build: %s\n", build)
}
@@ -41,7 +39,7 @@ func statusCmd() {
}
if _, err := os.Stat(configPath); err == nil {
- fmt.Printf("Model: %s\n", cfg.Agents.Defaults.Model)
+ fmt.Printf("Model: %s\n", cfg.Agents.Defaults.GetModelName())
hasOpenRouter := cfg.Providers.OpenRouter.APIKey != ""
hasAnthropic := cfg.Providers.Anthropic.APIKey != ""
diff --git a/cmd/picoclaw/internal/version/command.go b/cmd/picoclaw/internal/version/command.go
new file mode 100644
index 000000000..1cf686671
--- /dev/null
+++ b/cmd/picoclaw/internal/version/command.go
@@ -0,0 +1,33 @@
+package version
+
+import (
+ "fmt"
+
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+)
+
+func NewVersionCommand() *cobra.Command {
+ cmd := &cobra.Command{
+ Use: "version",
+ Aliases: []string{"v"},
+ Short: "Show version information",
+ Run: func(_ *cobra.Command, _ []string) {
+ printVersion()
+ },
+ }
+
+ return cmd
+}
+
+func printVersion() {
+ fmt.Printf("%s picoclaw %s\n", internal.Logo, internal.FormatVersion())
+ build, goVer := internal.FormatBuildInfo()
+ if build != "" {
+ fmt.Printf(" Build: %s\n", build)
+ }
+ if goVer != "" {
+ fmt.Printf(" Go: %s\n", goVer)
+ }
+}
diff --git a/cmd/picoclaw/internal/version/command_test.go b/cmd/picoclaw/internal/version/command_test.go
new file mode 100644
index 000000000..f08a4d1ea
--- /dev/null
+++ b/cmd/picoclaw/internal/version/command_test.go
@@ -0,0 +1,31 @@
+package version
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewVersionCommand(t *testing.T) {
+ cmd := NewVersionCommand()
+
+ require.NotNil(t, cmd)
+
+ assert.Equal(t, "version", cmd.Use)
+
+ assert.Len(t, cmd.Aliases, 1)
+ assert.True(t, cmd.HasAlias("v"))
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Equal(t, "Show version information", cmd.Short)
+
+ assert.False(t, cmd.HasSubCommands())
+
+ assert.NotNil(t, cmd.Run)
+ assert.Nil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+}
diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go
index ce9389417..6db69c990 100644
--- a/cmd/picoclaw/main.go
+++ b/cmd/picoclaw/main.go
@@ -8,192 +8,49 @@ package main
import (
"fmt"
- "io"
"os"
- "path/filepath"
- "runtime"
- "github.com/sipeed/picoclaw/pkg/config"
- "github.com/sipeed/picoclaw/pkg/skills"
+ "github.com/spf13/cobra"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/agent"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/auth"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/cron"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/gateway"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/migrate"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/onboard"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/skills"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/status"
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal/version"
)
-var (
- version = "dev"
- gitCommit string
- buildTime string
- goVersion string
-)
+func NewPicoclawCommand() *cobra.Command {
+ short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, internal.GetVersion())
-const logo = "🦞"
-
-// formatVersion returns the version string with optional git commit
-func formatVersion() string {
- v := version
- if gitCommit != "" {
- v += fmt.Sprintf(" (git: %s)", gitCommit)
+ cmd := &cobra.Command{
+ Use: "picoclaw",
+ Short: short,
+ Example: "picoclaw list",
}
- return v
-}
-// formatBuildInfo returns build time and go version info
-func formatBuildInfo() (build string, goVer string) {
- if buildTime != "" {
- build = buildTime
- }
- goVer = goVersion
- if goVer == "" {
- goVer = runtime.Version()
- }
- return
-}
+ cmd.AddCommand(
+ onboard.NewOnboardCommand(),
+ agent.NewAgentCommand(),
+ auth.NewAuthCommand(),
+ gateway.NewGatewayCommand(),
+ status.NewStatusCommand(),
+ cron.NewCronCommand(),
+ migrate.NewMigrateCommand(),
+ skills.NewSkillsCommand(),
+ version.NewVersionCommand(),
+ )
-func printVersion() {
- fmt.Printf("%s picoclaw %s\n", logo, formatVersion())
- build, goVer := formatBuildInfo()
- if build != "" {
- fmt.Printf(" Build: %s\n", build)
- }
- if goVer != "" {
- fmt.Printf(" Go: %s\n", goVer)
- }
-}
-
-func copyDirectory(src, dst string) error {
- return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
- if err != nil {
- return err
- }
-
- relPath, err := filepath.Rel(src, path)
- if err != nil {
- return err
- }
-
- dstPath := filepath.Join(dst, relPath)
-
- if info.IsDir() {
- return os.MkdirAll(dstPath, info.Mode())
- }
-
- srcFile, err := os.Open(path)
- if err != nil {
- return err
- }
- defer srcFile.Close()
-
- dstFile, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
- if err != nil {
- return err
- }
- defer dstFile.Close()
-
- _, err = io.Copy(dstFile, srcFile)
- return err
- })
+ return cmd
}
func main() {
- if len(os.Args) < 2 {
- printHelp()
- os.Exit(1)
- }
-
- command := os.Args[1]
-
- switch command {
- case "onboard":
- onboard()
- case "agent":
- agentCmd()
- case "gateway":
- gatewayCmd()
- case "status":
- statusCmd()
- case "migrate":
- migrateCmd()
- case "auth":
- authCmd()
- case "cron":
- cronCmd()
- case "skills":
- if len(os.Args) < 3 {
- skillsHelp()
- return
- }
-
- subcommand := os.Args[2]
-
- cfg, err := loadConfig()
- if err != nil {
- fmt.Printf("Error loading config: %v\n", err)
- os.Exit(1)
- }
-
- workspace := cfg.WorkspacePath()
- installer := skills.NewSkillInstaller(workspace)
- // 获取全局配置目录和内置 skills 目录
- globalDir := filepath.Dir(getConfigPath())
- globalSkillsDir := filepath.Join(globalDir, "skills")
- builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills")
- skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir)
-
- switch subcommand {
- case "list":
- skillsListCmd(skillsLoader)
- case "install":
- skillsInstallCmd(installer)
- case "remove", "uninstall":
- if len(os.Args) < 4 {
- fmt.Println("Usage: picoclaw skills remove ")
- return
- }
- skillsRemoveCmd(installer, os.Args[3])
- case "install-builtin":
- skillsInstallBuiltinCmd(workspace)
- case "list-builtin":
- skillsListBuiltinCmd()
- case "search":
- skillsSearchCmd(installer)
- case "show":
- if len(os.Args) < 4 {
- fmt.Println("Usage: picoclaw skills show ")
- return
- }
- skillsShowCmd(skillsLoader, os.Args[3])
- default:
- fmt.Printf("Unknown skills command: %s\n", subcommand)
- skillsHelp()
- }
- case "version", "--version", "-v":
- printVersion()
- default:
- fmt.Printf("Unknown command: %s\n", command)
- printHelp()
+ cmd := NewPicoclawCommand()
+ if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}
-
-func printHelp() {
- fmt.Printf("%s picoclaw - Personal AI Assistant v%s\n\n", logo, version)
- fmt.Println("Usage: picoclaw ")
- fmt.Println()
- fmt.Println("Commands:")
- fmt.Println(" onboard Initialize picoclaw configuration and workspace")
- fmt.Println(" agent Interact with the agent directly")
- fmt.Println(" auth Manage authentication (login, logout, status)")
- fmt.Println(" gateway Start picoclaw gateway")
- fmt.Println(" status Show picoclaw status")
- fmt.Println(" cron Manage scheduled tasks")
- fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
- fmt.Println(" skills Manage skills (install, list, remove)")
- fmt.Println(" version Show version information")
-}
-
-func getConfigPath() string {
- home, _ := os.UserHomeDir()
- return filepath.Join(home, ".picoclaw", "config.json")
-}
-
-func loadConfig() (*config.Config, error) {
- return config.LoadConfig(getConfigPath())
-}
diff --git a/cmd/picoclaw/main_test.go b/cmd/picoclaw/main_test.go
new file mode 100644
index 000000000..3740ba358
--- /dev/null
+++ b/cmd/picoclaw/main_test.go
@@ -0,0 +1,56 @@
+package main
+
+import (
+ "fmt"
+ "slices"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/sipeed/picoclaw/cmd/picoclaw/internal"
+)
+
+func TestNewPicoclawCommand(t *testing.T) {
+ cmd := NewPicoclawCommand()
+
+ require.NotNil(t, cmd)
+
+ short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, internal.GetVersion())
+
+ assert.Equal(t, "picoclaw", cmd.Use)
+ assert.Equal(t, short, cmd.Short)
+
+ assert.True(t, cmd.HasSubCommands())
+ assert.True(t, cmd.HasAvailableSubCommands())
+
+ assert.False(t, cmd.HasFlags())
+
+ assert.Nil(t, cmd.Run)
+ assert.Nil(t, cmd.RunE)
+
+ assert.Nil(t, cmd.PersistentPreRun)
+ assert.Nil(t, cmd.PersistentPostRun)
+
+ allowedCommands := []string{
+ "agent",
+ "auth",
+ "cron",
+ "gateway",
+ "migrate",
+ "onboard",
+ "skills",
+ "status",
+ "version",
+ }
+
+ subcommands := cmd.Commands()
+ assert.Len(t, subcommands, len(allowedCommands))
+
+ for _, subcmd := range subcommands {
+ found := slices.Contains(allowedCommands, subcmd.Name())
+ assert.True(t, found, "unexpected subcommand %q", subcmd.Name())
+
+ assert.False(t, subcmd.Hidden)
+ }
+}
diff --git a/config/config.example.json b/config/config.example.json
index e046f7b76..605f9dc1d 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -3,7 +3,7 @@
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true,
- "model": "gpt4",
+ "model_name": "gpt4",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
@@ -57,7 +57,8 @@
"discord": {
"enabled": false,
"token": "YOUR_DISCORD_BOT_TOKEN",
- "allow_from": []
+ "allow_from": [],
+ "mention_only": false
},
"qq": {
"enabled": false,
@@ -112,6 +113,32 @@
"reconnect_interval": 5,
"group_trigger_prefix": [],
"allow_from": []
+ },
+ "wecom": {
+ "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats",
+ "enabled": false,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": [],
+ "reply_timeout": 5
+ },
+ "wecom_app": {
+ "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md",
+ "enabled": false,
+ "corp_id": "YOUR_CORP_ID",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_43_CHAR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": [],
+ "reply_timeout": 5
}
},
"providers": {
@@ -169,6 +196,10 @@
"volcengine": {
"api_key": "",
"api_base": ""
+ },
+ "mistral": {
+ "api_key": "",
+ "api_base": "https://api.mistral.ai/v1"
}
},
"tools": {
@@ -191,7 +222,8 @@
"enabled": false,
"base_url": "http://localhost:8888",
"max_results": 5
- }
+ },
+ "proxy": ""
},
"cron": {
"exec_timeout_minutes": 5
@@ -199,6 +231,17 @@
"exec": {
"enable_deny_patterns": false,
"custom_deny_patterns": []
+ },
+ "skills": {
+ "registries": {
+ "clawhub": {
+ "enabled": true,
+ "base_url": "https://clawhub.ai",
+ "search_path": "/api/v1/search",
+ "skills_path": "/api/v1/skills",
+ "download_path": "/api/v1/download"
+ }
+ }
}
},
"heartbeat": {
@@ -210,7 +253,7 @@
"monitor_usb": true
},
"gateway": {
- "host": "0.0.0.0",
+ "host": "127.0.0.1",
"port": 18790
}
}
diff --git a/docker-compose.yml b/docker-compose.yml
index 32e8ee339..c268b01cd 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -10,6 +10,9 @@ services:
container_name: picoclaw-agent
profiles:
- agent
+ # Uncomment to access host network; leave commented unless needed.
+ #extra_hosts:
+ # - "host.docker.internal:host-gateway"
volumes:
- ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
- picoclaw-workspace:/home/picoclaw/.picoclaw/workspace
@@ -29,6 +32,9 @@ services:
restart: unless-stopped
profiles:
- gateway
+ # Uncomment to access host network; leave commented unless needed.
+ #extra_hosts:
+ # - "host.docker.internal:host-gateway"
volumes:
# Configuration file
- ./config/config.json:/home/picoclaw/.picoclaw/config.json:ro
diff --git a/docs/ANTIGRAVITY_AUTH.md b/docs/ANTIGRAVITY_AUTH.md
index 5d68de427..89261d899 100644
--- a/docs/ANTIGRAVITY_AUTH.md
+++ b/docs/ANTIGRAVITY_AUTH.md
@@ -378,7 +378,7 @@ const antigravityPlugin = {
description: "OAuth flow for Google Antigravity (Cloud Code Assist)",
configSchema: emptyPluginConfigSchema(),
- register(api: OpenClawPluginApi) {
+ register(api: PicoClawPluginApi) {
api.registerProvider({
id: "google-antigravity",
label: "Google Antigravity",
@@ -405,7 +405,7 @@ const antigravityPlugin = {
```typescript
type ProviderAuthContext = {
- config: OpenClawConfig;
+ config: PicoClawConfig;
agentDir?: string;
workspaceDir?: string;
prompter: WizardPrompter; // UI prompts/notifications
@@ -426,7 +426,7 @@ type ProviderAuthResult = {
profileId: string;
credential: AuthProfileCredential;
}>;
- configPatch?: Partial;
+ configPatch?: Partial;
defaultModel?: string;
notes?: string[];
};
@@ -438,10 +438,9 @@ type ProviderAuthResult = {
### 1. Required Environment/Dependencies
-- Node.js ≥ 22
-- OpenClaw plugin-sdk
-- crypto module (built-in)
-- http module (built-in)
+- Go ≥ 1.21
+- PicoClaw codebase (`pkg/providers/` and `pkg/auth/`)
+- `crypto` and `net/http` standard library packages
### 2. Required Headers for API Calls
@@ -572,36 +571,40 @@ Each SSE message (`data: {...}`) is wrapped in a `response` field:
## Configuration
-### openclaw.json Configuration
+### config.json Configuration
-```json5
+```json
{
- agents: {
- defaults: {
- model: {
- primary: "google-antigravity/claude-opus-4-6-thinking",
- },
- },
- },
+ "model_list": [
+ {
+ "model_name": "gemini-flash",
+ "model": "antigravity/gemini-3-flash",
+ "auth_method": "oauth"
+ }
+ ],
+ "agents": {
+ "defaults": {
+ "model": "gemini-flash"
+ }
+ }
}
```
### Auth Profile Storage
-Auth profiles are stored in `~/.openclaw/agent/auth-profiles.json`:
+Auth profiles are stored in `~/.picoclaw/auth.json`:
```json
{
- "version": 1,
- "profiles": {
- "google-antigravity:user@example.com": {
- "type": "oauth",
+ "credentials": {
+ "google-antigravity": {
+ "access_token": "ya29...",
+ "refresh_token": "1//...",
+ "expires_at": "2026-01-01T00:00:00Z",
"provider": "google-antigravity",
- "access": "ya29...",
- "refresh": "1//...",
- "expires": 1704067200000,
+ "auth_method": "oauth",
"email": "user@example.com",
- "projectId": "my-project-id"
+ "project_id": "my-project-id"
}
}
}
@@ -611,277 +614,85 @@ Auth profiles are stored in `~/.openclaw/agent/auth-profiles.json`:
## Creating a New Provider in PicoClaw
+PicoClaw providers are implemented as Go packages under `pkg/providers/`. To add a new provider:
+
### Step-by-Step Implementation
-#### 1. Create Plugin Structure
+#### 1. Create Provider File
+
+Create a new Go file in `pkg/providers/`:
```
-extensions/
-└── your-provider-auth/
- ├── openclaw.plugin.json
- ├── package.json
- ├── README.md
- └── index.ts
+pkg/providers/
+└── your_provider.go
```
-#### 2. Define Plugin Manifest
+#### 2. Implement the Provider Interface
-**openclaw.plugin.json:**
-```json
-{
- "id": "your-provider-auth",
- "providers": ["your-provider"],
- "configSchema": {
- "type": "object",
- "additionalProperties": false,
- "properties": {}
- }
+Your provider must implement the `Provider` interface defined in `pkg/providers/types.go`:
+
+```go
+package providers
+
+type YourProvider struct {
+ apiKey string
+ apiBase string
}
-```
-**package.json:**
-```json
-{
- "name": "@openclaw/your-provider-auth",
- "version": "1.0.0",
- "private": true,
- "description": "Your Provider OAuth plugin",
- "type": "module"
-}
-```
-
-#### 3. Implement OAuth Flow
-
-```typescript
-import {
- buildOauthProviderAuthResult,
- emptyPluginConfigSchema,
- type OpenClawPluginApi,
- type ProviderAuthContext,
-} from "openclaw/plugin-sdk";
-
-const YOUR_CLIENT_ID = "your-client-id";
-const YOUR_CLIENT_SECRET = "your-client-secret";
-const AUTH_URL = "https://provider.com/oauth/authorize";
-const TOKEN_URL = "https://provider.com/oauth/token";
-const REDIRECT_URI = "http://localhost:PORT/oauth-callback";
-
-async function loginYourProvider(params: {
- isRemote: boolean;
- openUrl: (url: string) => Promise;
- prompt: (message: string) => Promise;
- note: (message: string, title?: string) => Promise;
- log: (message: string) => void;
- progress: { update: (msg: string) => void; stop: (msg?: string) => void };
-}) {
- // 1. Generate PKCE
- const { verifier, challenge } = generatePkce();
- const state = randomBytes(16).toString("hex");
-
- // 2. Build auth URL
- const authUrl = buildAuthUrl({ challenge, state });
-
- // 3. Start callback server (if not remote)
- const callbackServer = !params.isRemote
- ? await startCallbackServer({ timeoutMs: 5 * 60 * 1000 })
- : null;
-
- // 4. Open browser or show URL
- if (callbackServer) {
- await params.openUrl(authUrl);
- const callback = await callbackServer.waitForCallback();
- code = callback.searchParams.get("code");
- } else {
- await params.note(`Auth URL: ${authUrl}`, "OAuth");
- const input = await params.prompt("Paste redirect URL:");
- const parsed = parseCallbackInput(input);
- code = parsed.code;
- }
-
- // 5. Exchange code for tokens
- const tokens = await exchangeCode({ code, verifier });
-
- // 6. Fetch additional user data
- const email = await fetchUserEmail(tokens.access);
-
- return { ...tokens, email };
-}
-```
-
-#### 4. Register Provider
-
-```typescript
-const yourProviderPlugin = {
- id: "your-provider-auth",
- name: "Your Provider Auth",
- description: "OAuth for Your Provider",
- configSchema: emptyPluginConfigSchema(),
-
- register(api: OpenClawPluginApi) {
- api.registerProvider({
- id: "your-provider",
- label: "Your Provider",
- docsPath: "/providers/models",
- aliases: ["yp"],
-
- auth: [
- {
- id: "oauth",
- label: "OAuth Login",
- hint: "Browser-based authentication",
- kind: "oauth",
-
- run: async (ctx: ProviderAuthContext) => {
- const spin = ctx.prompter.progress("Starting OAuth...");
-
- try {
- const result = await loginYourProvider({
- isRemote: ctx.isRemote,
- openUrl: ctx.openUrl,
- prompt: async (msg) => String(await ctx.prompter.text({ message: msg })),
- note: ctx.prompter.note,
- log: (msg) => ctx.runtime.log(msg),
- progress: spin,
- });
-
- return buildOauthProviderAuthResult({
- providerId: "your-provider",
- defaultModel: "your-provider/model-name",
- access: result.access,
- refresh: result.refresh,
- expires: result.expires,
- email: result.email,
- notes: ["Provider-specific notes"],
- });
- } catch (err) {
- spin.stop("OAuth failed");
- throw err;
- }
- },
- },
- ],
- });
- },
-};
-
-export default yourProviderPlugin;
-```
-
-#### 5. Implement Usage Tracking (Optional)
-
-```typescript
-// src/infra/provider-usage.fetch.your-provider.ts
-export async function fetchYourProviderUsage(
- token: string,
- timeoutMs: number,
- fetchFn: typeof fetch
-): Promise {
- // Fetch usage data from provider API
- const response = await fetchFn("https://api.provider.com/usage", {
- headers: { Authorization: `Bearer ${token}` },
- });
-
- const data = await response.json();
-
- return {
- provider: "your-provider",
- displayName: "Your Provider",
- windows: [
- { label: "Credits", usedPercent: data.usedPercent },
- ],
- plan: data.planName,
- };
-}
-```
-
-#### 6. Register Usage Fetcher
-
-```typescript
-// src/infra/provider-usage.load.ts
-case "your-provider":
- return await fetchYourProviderUsage(auth.token, timeoutMs, fetchFn);
-```
-
-#### 7. Add Provider to Type Definitions
-
-```typescript
-// src/infra/provider-usage.types.ts
-export type SupportedProvider =
- | "anthropic"
- | "github-copilot"
- | "google-gemini-cli"
- | "google-antigravity"
- | "your-provider" // Add here
- | "minimax"
- | "openai-codex";
-```
-
-#### 8. Add Auth Choice Handler
-
-```typescript
-// src/commands/auth-choice.apply.your-provider.ts
-import { applyAuthChoicePluginProvider } from "./auth-choice.apply.plugin-provider.js";
-
-export async function applyAuthChoiceYourProvider(
- params: ApplyAuthChoiceParams
-): Promise {
- return await applyAuthChoicePluginProvider(params, {
- authChoice: "your-provider",
- pluginId: "your-provider-auth",
- providerId: "your-provider",
- methodId: "oauth",
- label: "Your Provider",
- });
-}
-```
-
-#### 9. Export from Main Index
-
-```typescript
-// src/commands/auth-choice.apply.ts
-import { applyAuthChoiceYourProvider } from "./auth-choice.apply.your-provider.js";
-
-// In the switch statement:
-case "your-provider":
- return await applyAuthChoiceYourProvider(params);
-```
-
-### Helper Utilities
-
-#### PKCE Generation
-```typescript
-function generatePkce(): { verifier: string; challenge: string } {
- const verifier = randomBytes(32).toString("hex");
- const challenge = createHash("sha256").update(verifier).digest("base64url");
- return { verifier, challenge };
-}
-```
-
-#### Callback Server
-```typescript
-async function startCallbackServer(params: { timeoutMs: number }) {
- const port = 51121; // Your port
-
- const server = createServer((request, response) => {
- const url = new URL(request.url!, `http://localhost:${port}`);
-
- if (url.pathname === "/oauth-callback") {
- response.writeHead(200, { "Content-Type": "text/html" });
- response.end("Authentication complete
");
- resolveCallback(url);
- server.close();
+func NewYourProvider(apiKey, apiBase, proxy string) *YourProvider {
+ if apiBase == "" {
+ apiBase = "https://api.your-provider.com/v1"
}
- });
-
- await new Promise((resolve, reject) => {
- server.listen(port, "127.0.0.1", resolve);
- server.once("error", reject);
- });
-
- return {
- waitForCallback: () => callbackPromise,
- close: () => new Promise((resolve) => server.close(resolve)),
- };
+ return &YourProvider{apiKey: apiKey, apiBase: apiBase}
+}
+
+func (p *YourProvider) Chat(ctx context.Context, messages []Message, tools []Tool, cb StreamCallback) error {
+ // Implement chat completion with streaming
+}
+```
+
+#### 3. Register in the Factory
+
+Add your provider to the protocol switch in `pkg/providers/factory.go`:
+
+```go
+case "your-provider":
+ return NewYourProvider(sel.apiKey, sel.apiBase, sel.proxy), nil
+```
+
+#### 4. Add Default Config (Optional)
+
+Add a default entry in `pkg/config/defaults.go`:
+
+```go
+{
+ ModelName: "your-model",
+ Model: "your-provider/model-name",
+ APIKey: "",
+},
+```
+
+#### 5. Add Auth Support (Optional)
+
+If your provider requires OAuth or special authentication, add a case to `cmd/picoclaw/cmd_auth.go`:
+
+```go
+case "your-provider":
+ authLoginYourProvider()
+```
+
+#### 6. Configure via `config.json`
+
+```json
+{
+ "model_list": [
+ {
+ "model_name": "your-model",
+ "model": "your-provider/model-name",
+ "api_key": "your-api-key",
+ "api_base": "https://api.your-provider.com/v1"
+ }
+ ]
}
```
@@ -892,33 +703,27 @@ async function startCallbackServer(params: { timeoutMs: number }) {
### CLI Commands
```bash
-# Enable the plugin
-openclaw plugins enable your-provider-auth
+# Authenticate with a provider
+picoclaw auth login --provider your-provider
-# Restart gateway
-openclaw gateway restart
+# List models (for Antigravity)
+picoclaw auth models
-# Authenticate
-openclaw models auth login --provider your-provider --set-default
+# Start the gateway
+picoclaw gateway
-# List models
-openclaw models list
-
-# Set model
-openclaw models set your-provider/model-name
-
-# Check usage
-openclaw models usage
+# Run an agent with a specific model
+picoclaw agent -m "Hello" --model your-model
```
### Environment Variables for Testing
```bash
-# Test specific providers only
-export OPENCLAW_LIVE_PROVIDERS="your-provider,google-antigravity"
+# Override default model
+export PICOCLAW_AGENTS_DEFAULTS_MODEL=your-model
-# Test with specific models
-export OPENCLAW_LIVE_GATEWAY_MODELS="your-provider/model-name"
+# Override provider settings
+export PICOCLAW_MODEL_LIST='[{"model_name":"your-model","model":"your-provider/model-name","api_key":"..."}]'
```
---
@@ -926,16 +731,16 @@ export OPENCLAW_LIVE_GATEWAY_MODELS="your-provider/model-name"
## References
- **Source Files:**
- - `extensions/google-antigravity-auth/index.ts` - Full OAuth implementation
- - `src/infra/provider-usage.fetch.antigravity.ts` - Usage fetching
- - `src/agents/pi-embedded-runner/google.ts` - Model sanitization
- - `src/agents/model-forward-compat.ts` - Forward compatibility
- - `src/plugin-sdk/provider-auth-result.ts` - Auth result builder
- - `src/plugins/types.ts` - Plugin type definitions
+ - `pkg/providers/antigravity_provider.go` - Antigravity provider implementation
+ - `pkg/auth/oauth.go` - OAuth flow implementation
+ - `pkg/auth/store.go` - Auth credential storage (`~/.picoclaw/auth.json`)
+ - `pkg/providers/factory.go` - Provider factory and protocol routing
+ - `pkg/providers/types.go` - Provider interface definitions
+ - `cmd/picoclaw/cmd_auth.go` - Auth CLI commands
- **Documentation:**
- - `docs/concepts/model-providers.md` - Provider overview
- - `docs/concepts/usage-tracking.md` - Usage tracking
+ - `docs/ANTIGRAVITY_USAGE.md` - Antigravity usage guide
+ - `docs/migration/model-list-migration.md` - Migration guide
---
@@ -987,7 +792,7 @@ Some models might show up in the available models list but return an empty respo
## Troubleshooting
### "Token expired"
-- Refresh OAuth tokens: `openclaw models auth login --provider google-antigravity`
+- Refresh OAuth tokens: `picoclaw auth login --provider antigravity`
### "Gemini for Google Cloud is not enabled"
- Enable the API in your Google Cloud Console
@@ -998,5 +803,5 @@ Some models might show up in the available models list but return an empty respo
### Models not appearing in list
- Verify OAuth authentication completed successfully
-- Check auth profile storage: `~/.openclaw/agent/auth-profiles.json`
-- Ensure the plugin is enabled: `openclaw plugins list`
+- Check auth profile storage: `~/.picoclaw/auth.json`
+- Re-run `picoclaw auth login --provider antigravity`
diff --git a/docs/ANTIGRAVITY_USAGE.md b/docs/ANTIGRAVITY_USAGE.md
index 8bf1fdfdb..e8194b6bc 100644
--- a/docs/ANTIGRAVITY_USAGE.md
+++ b/docs/ANTIGRAVITY_USAGE.md
@@ -47,14 +47,12 @@ picoclaw agent -m "Hello" --model claude-opus-4-6-thinking
If you are deploying via Coolify or Docker, follow these steps to test:
-1. **Branch**: Use the `feat/antigravity-provider` branch.
-2. **Environment Variables**:
- * `PICOCLAW_AGENTS_DEFAULTS_PROVIDER=antigravity`
- * `PICOCLAW_AGENTS_DEFAULTS_MODEL=gemini-3-flash`
-3. **Authentication persistence**:
+1. **Environment Variables**:
+ * `PICOCLAW_AGENTS_DEFAULTS_MODEL=gemini-flash`
+2. **Authentication persistence**:
If you've logged in locally, you can copy your credentials to the server:
```bash
- scp ~/.picoclaw/auth-profiles.json user@your-server:~/.picoclaw/
+ scp ~/.picoclaw/auth.json user@your-server:~/.picoclaw/
```
*Alternatively*, run the `auth login` command once on the server if you have terminal access.
diff --git a/docs/channels/dingtalk/README.zh.md b/docs/channels/dingtalk/README.zh.md
new file mode 100644
index 000000000..1e445d0b0
--- /dev/null
+++ b/docs/channels/dingtalk/README.zh.md
@@ -0,0 +1,33 @@
+# 钉钉
+
+钉钉是阿里巴巴的企业通讯平台,在中国职场中广受欢迎。它采用流式 SDK 来维持持久连接。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "dingtalk": {
+ "enabled": true,
+ "client_id": "YOUR_CLIENT_ID",
+ "client_secret": "YOUR_CLIENT_SECRET",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ------------- | ------ | ---- | -------------------------------- |
+| enabled | bool | 是 | 是否启用钉钉频道 |
+| client_id | string | 是 | 钉钉应用的 Client ID |
+| client_secret | string | 是 | 钉钉应用的 Client Secret |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+
+## 设置流程
+
+1. 前往 [钉钉开放平台](https://open.dingtalk.com/)
+2. 创建一个企业内部应用
+3. 从应用设置中获取 Client ID 和 Client Secret
+4. 配置OAuth和事件订阅(如需要)
+5. 将 Client ID 和 Client Secret 填入配置文件中
diff --git a/docs/channels/discord/README.zh.md b/docs/channels/discord/README.zh.md
new file mode 100644
index 000000000..5b597eced
--- /dev/null
+++ b/docs/channels/discord/README.zh.md
@@ -0,0 +1,35 @@
+# Discord
+
+Discord 是一个专为社区设计的免费语音、视频和文本聊天应用。PicoClaw 通过 Discord Bot API 连接到 Discord 服务器,支持接收和发送消息。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "discord": {
+ "enabled": true,
+ "token": "YOUR_BOT_TOKEN",
+ "allow_from": ["YOUR_USER_ID"],
+ "mention_only": false
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ------------ | ------ | ---- | -------------------------------- |
+| enabled | bool | 是 | 是否启用 Discord 频道 |
+| token | string | 是 | Discord 机器人 Token |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+| mention_only | bool | 否 | 是否仅响应提及机器人的消息 |
+
+## 设置流程
+
+1. 前往 [Discord 开发者门户](https://discord.com/developers/applications) 创建一个新的应用
+2. 启用 Intents:
+ - Message Content Intent
+ - Server Members Intent
+3. 获取 Bot Token
+4. 将 Bot Token 填入配置文件中
+5. 邀请机器人加入服务器并授予必要权限(例如发送消息、读取消息历史等)
diff --git a/docs/channels/feishu/README.zh.md b/docs/channels/feishu/README.zh.md
new file mode 100644
index 000000000..310827723
--- /dev/null
+++ b/docs/channels/feishu/README.zh.md
@@ -0,0 +1,37 @@
+# 飞书
+
+飞书(国际版名称:Lark)是字节跳动旗下的企业协作平台。它通过事件驱动的 Webhook 同时支持中国和全球市场。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "feishu": {
+ "enabled": true,
+ "app_id": "cli_xxx",
+ "app_secret": "xxx",
+ "encrypt_key": "",
+ "verification_token": "",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ------------------ | ------ | ---- | -------------------------------- |
+| enabled | bool | 是 | 是否启用飞书频道 |
+| app_id | string | 是 | 飞书应用的 App ID(以cli\_开头) |
+| app_secret | string | 是 | 飞书应用的 App Secret |
+| encrypt_key | string | 否 | 事件回调加密密钥 |
+| verification_token | string | 否 | 用于Webhook事件验证的Token |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+
+## 设置流程
+
+1. 前往 [飞书开放平台](https://open.feishu.cn/)创建应用程序
+2. 获取 App ID 和 App Secret
+3. 配置事件订阅和Webhook URL
+4. 设置加密(可选,生产环境建议启用)
+5. 将 App ID、App Secret、Encrypt Key 和 Verification Token(如果启用加密) 填入配置文件中
diff --git a/docs/channels/line/README.zh.md b/docs/channels/line/README.zh.md
new file mode 100644
index 000000000..fd3aa80da
--- /dev/null
+++ b/docs/channels/line/README.zh.md
@@ -0,0 +1,41 @@
+# Line
+
+PicoClaw 通过 LINE Messaging API 配合 Webhook 回调功能实现对 LINE 的支持。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "line": {
+ "enabled": true,
+ "channel_secret": "YOUR_CHANNEL_SECRET",
+ "channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18791,
+ "webhook_path": "/webhook/line",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| -------------------- | ------ | ---- | ------------------------------------------ |
+| enabled | bool | 是 | 是否启用 LINE Channel |
+| channel_secret | string | 是 | LINE Messaging API 的 Channel Secret |
+| channel_access_token | string | 是 | LINE Messaging API 的 Channel Access Token |
+| webhook_host | string | 是 | Webhook 监听的主机地址 (通常为 0.0.0.0) |
+| webhook_port | int | 是 | Webhook 监听的端口 (默认为 18791) |
+| webhook_path | string | 是 | Webhook 的路径 (默认为 /webhook/line) |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+
+## 设置流程
+
+1. 前往 [LINE Developers Console](https://developers.line.biz/console/) 创建一个服务提供商和一个 Messaging API Channel
+2. 获取 Channel Secret 和 Channel Access Token
+3. 配置Webhook:
+ - Line要求Webhook必须使用HTTPS协议,因此需要部署一个支持HTTPS的服务器,或者使用反向代理工具如ngrok将本地服务器暴露到公网
+ - 将 Webhook URL 设置为 `https://your-domain.com/webhook/line`
+ - 启用 Webhook 并验证 URL
+4. 将 Channel Secret 和 Channel Access Token 填入配置文件中
diff --git a/docs/channels/maixcam/README.zh.md b/docs/channels/maixcam/README.zh.md
new file mode 100644
index 000000000..8d53d4bef
--- /dev/null
+++ b/docs/channels/maixcam/README.zh.md
@@ -0,0 +1,31 @@
+# MaixCam
+
+MaixCam 是专用于连接矽速科技 MaixCAM 与 MaixCAM2 AI 摄像设备的通道。它采用 TCP 套接字实现双向通信,支持边缘 AI 部署场景。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "maixcam": {
+ "enabled": true,
+ "server_address": "0.0.0.0:8899",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| -------------- | ------ | ---- | -------------------------------- |
+| enabled | bool | 是 | 是否启用 MaixCam 频道 |
+| server_address | string | 是 | TCP 服务器监听地址和端口 |
+| allow_from | array | 否 | 设备ID白名单,空表示允许所有设备 |
+
+## 使用场景
+
+MaixCam 通道使 PicoClaw 能够作为边缘设备的 AI 后端运行:
+
+- **智能监控** :MaixCAM 发送图像帧,PicoClaw 通过视觉模型进行分析
+- **物联网控制** :设备发送传感器数据,PicoClaw 协调响应
+- **离线AI** :在本地网络部署 PicoClaw 实现低延迟推理
diff --git a/docs/channels/onebot/README.zh.md b/docs/channels/onebot/README.zh.md
new file mode 100644
index 000000000..6195f1c98
--- /dev/null
+++ b/docs/channels/onebot/README.zh.md
@@ -0,0 +1,31 @@
+# OneBot
+
+OneBot 是一个面向 QQ 机器人的开放协议标准,为多种 QQ 机器人实现(例如 go-cqhttp、Mirai)提供了统一的接口。它使用 WebSocket 进行通信。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "onebot": {
+ "enabled": true,
+ "ws_url": "ws://localhost:8080",
+ "access_token": "",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ------------ | ------ | ---- | -------------------------------- |
+| enabled | bool | 是 | 是否启用 OneBot 频道 |
+| ws_url | string | 是 | OneBot 服务器的 WebSocket URL |
+| access_token | string | 否 | 连接 OneBot 服务器的访问令牌 |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+
+## 设置流程
+
+1. 部署一个 OneBot 兼容的实现(例如napcat)
+2. 配置 OneBot 实现以启用 WebSocket 服务并设置访问令牌(如果需要)
+3. 将 WebSocket URL 和访问令牌填入配置文件中
diff --git a/docs/channels/qq/README.zh.md b/docs/channels/qq/README.zh.md
new file mode 100644
index 000000000..bd774960f
--- /dev/null
+++ b/docs/channels/qq/README.zh.md
@@ -0,0 +1,32 @@
+# QQ
+
+PicoClaw 通过 QQ 开放平台的官方机器人 API 提供对 QQ 的支持。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "qq": {
+ "enabled": true,
+ "app_id": "YOUR_APP_ID",
+ "app_secret": "YOUR_APP_SECRET",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------- | ------ | ---- | -------------------------------- |
+| enabled | bool | 是 | 是否启用 QQ Channel |
+| app_id | string | 是 | QQ 机器人应用的 App ID |
+| app_secret | string | 是 | QQ 机器人应用的 App Secret |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+
+## 设置流程
+
+1. 前往 [QQ 开放平台](https://q.qq.com/) 创建一个机器人
+2. 通过仪表盘获取 App ID 和 App Secret
+3. 开启机器人沙箱模式, 将用户和群添加到沙箱中
+4. 将 App ID 和 App Secret 填入配置文件中
diff --git a/docs/channels/slack/README.zh.md b/docs/channels/slack/README.zh.md
new file mode 100644
index 000000000..58ebcb566
--- /dev/null
+++ b/docs/channels/slack/README.zh.md
@@ -0,0 +1,33 @@
+# Slack
+
+Slack 是全球领先的企业级即时通讯平台。PicoClaw 采用 Slack 的 Socket Mode 实现实时双向通信,无需配置公开的 Webhook 端点。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "slack": {
+ "enabled": true,
+ "bot_token": "xoxb-...",
+ "app_token": "xapp-...",
+ "allow_from": []
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------- | ------ | ---- | -------------------------------------------------------- |
+| enabled | bool | 是 | 是否启用 Slack 频道 |
+| bot_token | string | 是 | Slack 机器人的 Bot User OAuth Token (以 xoxb- 开头) |
+| app_token | string | 是 | Slack 应用的 Socket Mode App Level Token (以 xapp- 开头) |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+
+## 设置流程
+
+1. 前往 [Slack API](https://api.slack.com/) 创建一个新的 Slack 应用
+2. 启用 Socket Mode 并获取 App Level Token
+3. 添加 Bot Token Scopes(例如`chat:write`、`im:history`等)
+4. 安装应用到工作区并获取 Bot User OAuth Token
+5. 将 Bot Token 和 App Token 填入配置文件中
diff --git a/docs/channels/telegram/README.zh.md b/docs/channels/telegram/README.zh.md
new file mode 100644
index 000000000..d453c68fa
--- /dev/null
+++ b/docs/channels/telegram/README.zh.md
@@ -0,0 +1,33 @@
+# Telegram
+
+Telegram Channel 通过 Telegram 机器人 API 使用长轮询实现基于机器人的通信。它支持文本消息、媒体附件(照片、语音、音频、文档)、通过 Groq Whisper 进行语音转录以及内置命令处理器。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "telegram": {
+ "enabled": true,
+ "token": "123456789:ABCdefGHIjklMNOpqrsTUVwxyz",
+ "allow_from": ["123456789"],
+ "proxy": ""
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------- | ------ | ---- | --------------------------------------------------------- |
+| enabled | bool | 是 | 是否启用 Telegram 频道 |
+| token | string | 是 | Telegram 机器人 API Token |
+| allow_from | array | 否 | 用户ID白名单,空表示允许所有用户 |
+| proxy | string | 否 | 连接 Telegram API 的代理 URL (例如 http://127.0.0.1:7890) |
+
+## 设置流程
+
+1. 在 Telegram 中搜索 `@BotFather`
+2. 发送 `/newbot` 命令并按照提示创建新机器人
+3. 获取 HTTP API Token
+4. 将 Token 填入配置文件中
+5. (可选) 配置 `allow_from` 以限制允许互动的用户 ID (可通过 `@userinfobot` 获取 ID)
diff --git a/docs/channels/wecom/wecom_app/README.zh.md b/docs/channels/wecom/wecom_app/README.zh.md
new file mode 100644
index 000000000..1e6a0e2b3
--- /dev/null
+++ b/docs/channels/wecom/wecom_app/README.zh.md
@@ -0,0 +1,47 @@
+# 企业微信自建应用
+
+企业微信自建应用是指企业在企业微信中创建的应用,主要用于企业内部使用。通过企业微信自建应用,企业可以实现与员工的高效沟通和协作,提高工作效率。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx",
+ "corp_secret": "YOUR_CORP_SECRET",
+ "agent_id": 1000002,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": [],
+ "reply_timeout": 5
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------------- | ------ | ---- | ---------------------------------------- |
+| corp_id | string | 是 | 企业 ID |
+| corp_secret | string | 是 | 应用程序密钥 |
+| agent_id | int | 是 | 应用程序代理 ID |
+| token | string | 是 | 回调验证令牌 |
+| encoding_aes_key | string | 是 | 43 字符 AES 密钥 |
+| webhook_host | string | 否 | HTTP 服务器绑定地址 |
+| webhook_port | int | 否 | HTTP 服务器端口(默认:18792) |
+| webhook_path | string | 否 | Webhook 路径(默认:/webhook/wecom-app) |
+| allow_from | array | 否 | 用户 ID 白名单 |
+| reply_timeout | int | 否 | 回复超时时间(秒) |
+
+## 设置流程
+
+1. 登录 [企业微信管理后台](https://work.weixin.qq.com/)
+2. 进入“应用管理” -> “创建应用”
+3. 获取企业 ID (CorpID) 和应用 Secret
+4. 在应用设置中配置“接收消息”,获取 Token 和 EncodingAESKey
+5. 设置回调 URL 为 `http://:/webhook/wecom-app`
+6. 将 CorpID, Secret, AgentID 等信息填入配置文件
diff --git a/docs/channels/wecom/wecom_bot/README.zh.md b/docs/channels/wecom/wecom_bot/README.zh.md
new file mode 100644
index 000000000..c4bb1c87e
--- /dev/null
+++ b/docs/channels/wecom/wecom_bot/README.zh.md
@@ -0,0 +1,41 @@
+# 企业微信机器人
+
+企业微信机器人是企业微信提供的一种快速接入方式,可以通过 Webhook URL 接收消息。
+
+## 配置
+
+```json
+{
+ "channels": {
+ "wecom": {
+ "enabled": true,
+ "token": "YOUR_TOKEN",
+ "encoding_aes_key": "YOUR_ENCODING_AES_KEY",
+ "webhook_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY",
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18793,
+ "webhook_path": "/webhook/wecom",
+ "allow_from": [],
+ "reply_timeout": 5
+ }
+ }
+}
+```
+
+| 字段 | 类型 | 必填 | 描述 |
+| ---------------- | ------ | ---- | -------------------------------------------- |
+| token | string | 是 | 签名验证代币 |
+| encoding_aes_key | string | 是 | 用于解密的 43 字符 AES 密钥 |
+| webhook_url | string | 是 | 用于发送回复的企业微信群聊机器人 Webhook URL |
+| webhook_host | string | 否 | HTTP 服务器绑定地址(默认:0.0.0.0) |
+| webhook_port | int | 否 | HTTP 服务器端口(默认:18793) |
+| webhook_path | string | 否 | Webhook 端点路径(默认:/webhook/wecom) |
+| allow_from | array | 否 | 用户 ID 白名单(空值 = 允许所有用户) |
+| reply_timeout | int | 否 | 回复超时时间(单位:秒,默认值:5) |
+
+## 设置流程
+
+1. 在企业微信群中添加机器人
+2. 获取 Webhook URL
+3. (如需接收消息) 在机器人配置页面设置接收消息的 API 地址(回调地址)以及 Token 和 EncodingAESKey
+4. 将相关信息填入配置文件
diff --git a/docs/design/provider-refactoring-tests.md b/docs/design/provider-refactoring-tests.md
index fc6429278..060be9ba8 100644
--- a/docs/design/provider-refactoring-tests.md
+++ b/docs/design/provider-refactoring-tests.md
@@ -1,7 +1,5 @@
# Provider Architecture Refactoring - Test Suite Summary
-> PRD: `tasks/prd-provider-refactoring.md`
-
This document summarizes the complete test suite designed for the Provider architecture refactoring.
## Test File Structure
@@ -12,10 +10,8 @@ pkg/
│ ├── model_config_test.go # US-001, US-002: ModelConfig struct and GetModelConfig tests
│ └── migration_test.go # US-003: Backward compatibility and migration tests
├── providers/
-│ ├── registry_test.go # US-006: Load balancing tests
-│ ├── integration_test.go # E2E integration tests
-│ └── factory/
-│ └── factory_test.go # US-004, US-005: Provider factory tests
+│ ├── factory_test.go # US-004, US-005: Provider factory tests
+│ └── factory_provider_test.go # Factory provider integration tests
```
---
@@ -122,7 +118,6 @@ go test ./pkg/... -race
# Run specific package tests
go test ./pkg/config -v
go test ./pkg/providers -v
-go test ./pkg/providers/factory -v
# Run E2E tests
go test ./pkg/providers -run TestE2E -v
diff --git a/docs/migration/model-list-migration.md b/docs/migration/model-list-migration.md
index 0682bae1a..589dfc043 100644
--- a/docs/migration/model-list-migration.md
+++ b/docs/migration/model-list-migration.md
@@ -85,6 +85,7 @@ The `model` field uses a protocol prefix format: `[protocol/]model-identifier`
| `openai/` | OpenAI API (default) | `openai/gpt-5.2` |
| `anthropic/` | Anthropic API | `anthropic/claude-opus-4` |
| `antigravity/` | Google via Antigravity OAuth | `antigravity/gemini-2.0-flash` |
+| `gemini/` | Google Gemini API | `gemini/gemini-2.0-flash-exp` |
| `claude-cli/` | Claude CLI (local) | `claude-cli/claude-sonnet-4.6` |
| `codex-cli/` | Codex CLI (local) | `codex-cli/codex-4` |
| `github-copilot/` | GitHub Copilot | `github-copilot/gpt-4o` |
@@ -93,6 +94,13 @@ The `model` field uses a protocol prefix format: `[protocol/]model-identifier`
| `deepseek/` | DeepSeek API | `deepseek/deepseek-chat` |
| `cerebras/` | Cerebras API | `cerebras/llama-3.3-70b` |
| `qwen/` | Alibaba Qwen | `qwen/qwen-max` |
+| `zhipu/` | Zhipu AI | `zhipu/glm-4` |
+| `nvidia/` | NVIDIA NIM | `nvidia/llama-3.1-nemotron-70b` |
+| `ollama/` | Ollama (local) | `ollama/llama3` |
+| `vllm/` | vLLM (local) | `vllm/my-model` |
+| `moonshot/` | Moonshot AI | `moonshot/moonshot-v1-8k` |
+| `shengsuanyun/` | ShengSuanYun | `shengsuanyun/deepseek-v3` |
+| `volcengine/` | Volcengine | `volcengine/doubao-pro-32k` |
**Note**: If no prefix is specified, `openai/` is used as the default.
diff --git a/docs/picoclaw_community_roadmap_260216.md b/docs/picoclaw_community_roadmap_260216.md
deleted file mode 100644
index cfcc30f17..000000000
--- a/docs/picoclaw_community_roadmap_260216.md
+++ /dev/null
@@ -1,112 +0,0 @@
-## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal
-
-**Hello, PicoClaw Community!**
-
-First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, we’ve seen a non-stop stream of high-quality PRs!
-
-PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum.
-
-This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**.
-
-### 🎁 Community Perks
-
-To show our appreciation, developers who officially join our community operations will receive:
-
-* **Exclusive AI Hardware:** Our upcoming, unreleased AI device.
-* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers).
-
-### 🎥 Calling All Content Creators!
-
-Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**.
-
-* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**.
-* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM.
-We will be rewarding high-quality content creators with the same perks as our community developers!
-
----
-
-## 🛠️ Urgent Volunteer Roles
-
-We are looking for experts in the following areas:
-
-1. **Issue/PR Reviewers**
-* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging.
-* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles.
-
-
-2. **Resource Optimization Experts**
-* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean.
-* **Focus:** Analyzing resource growth between releases and trimming redundancy.
-* **Priority:** **RAM usage optimization** > Binary size reduction.
-
-
-3. **Security Audit & Bug Fixes**
-* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management.
-* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes.
-
-
-4. **Documentation & DX (Developer Experience)**
-* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow.
-* **Focus:** Creating clear, user-friendly documentation for both setup and development.
-
-
-5. **AI-Powered CI/CD Optimization**
-* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it.
-* **Focus:** Automating builds with AI and exploring AI-driven issue resolution.
-
-**How to Apply:** > If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role.
-Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors!
-
----
-
-## 📍 The Roadmap
-
-Interested in a specific feature? You can "claim" these tasks and start building:
-
-###
-* **Provider:**
- * **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days)
- * You can still submit code; Daming will merge it into the new implementation.
-* **Channels:**
- * Support for OneBot, additional platforms
- * attachments (images, audio, video, files).
-* **Skills:**
- * Implementing `find_skill` to discover tools via [openclaw/skills](https://github.com/openclaw/skills) and other platforms.
-* **Operations:** * MCP Support.
- * Android operations (e.g., botdrop).
- * Browser automation via CDP or ActionBook.
-
-
-* **Multi-Agent Ecosystem:**
- * **Basic Model-Agnet** S
- * **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens).
- * **Swarm Mode.**
- * **AIEOS Integration.**
-
-
-* **Branding:**
- * **Logo**: We need a cute logo! We’re leaning toward a **Mantis Shrimp**—small, but packs a legendary punch!
-
-
-We have officially created these tasks as GitHub Issues, all marked with the roadmap tag.
-This list will be updated continuously as we progress.
-If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue!
-
----
-
-## 🤝 How to Join
-
-**Everything is open to your creativity!** If you have a wild idea, just PR it.
-
-1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw.
-2. **The Application Track:** If you haven’t submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject:
-> `[Apply Join PicoClaw Dev Group] + Your GitHub Account`
-> Include the role you're interested in and any evidence of your development experience.
-
-
-
-### Looking Ahead
-
-Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential.
-
-**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎
diff --git a/docs/tools_configuration.md b/docs/tools_configuration.md
index 8777ddbd6..8aba1aa91 100644
--- a/docs/tools_configuration.md
+++ b/docs/tools_configuration.md
@@ -9,8 +9,8 @@ PicoClaw's tools configuration is located in the `tools` field of `config.json`.
"tools": {
"web": { ... },
"exec": { ... },
- "approval": { ... },
- "cron": { ... }
+ "cron": { ... },
+ "skills": { ... }
}
}
```
@@ -83,25 +83,12 @@ By default, PicoClaw blocks the following dangerous commands:
"custom_deny_patterns": [
"\\brm\\s+-r\\b",
"\\bkillall\\s+python"
- ],
+ ]
}
}
}
```
-## Approval Tool
-
-The approval tool controls permissions for dangerous operations.
-
-| Config | Type | Default | Description |
-|--------|------|---------|-------------|
-| `enabled` | bool | true | Enable approval functionality |
-| `write_file` | bool | true | Require approval for file writes |
-| `edit_file` | bool | true | Require approval for file edits |
-| `append_file` | bool | true | Require approval for file appends |
-| `exec` | bool | true | Require approval for command execution |
-| `timeout_minutes` | int | 5 | Approval timeout in minutes |
-
## Cron Tool
The cron tool is used for scheduling periodic tasks.
@@ -110,6 +97,40 @@ The cron tool is used for scheduling periodic tasks.
|--------|------|---------|-------------|
| `exec_timeout_minutes` | int | 5 | Execution timeout in minutes, 0 means no limit |
+## Skills Tool
+
+The skills tool configures skill discovery and installation via registries like ClawHub.
+
+### Registries
+
+| Config | Type | Default | Description |
+|--------|------|---------|-------------|
+| `registries.clawhub.enabled` | bool | true | Enable ClawHub registry |
+| `registries.clawhub.base_url` | string | `https://clawhub.ai` | ClawHub base URL |
+| `registries.clawhub.search_path` | string | `/api/v1/search` | Search API path |
+| `registries.clawhub.skills_path` | string | `/api/v1/skills` | Skills API path |
+| `registries.clawhub.download_path` | string | `/api/v1/download` | Download API path |
+
+### Configuration Example
+
+```json
+{
+ "tools": {
+ "skills": {
+ "registries": {
+ "clawhub": {
+ "enabled": true,
+ "base_url": "https://clawhub.ai",
+ "search_path": "/api/v1/search",
+ "skills_path": "/api/v1/skills",
+ "download_path": "/api/v1/download"
+ }
+ }
+ }
+ }
+}
+```
+
## Environment Variables
All configuration options can be overridden via environment variables with the format `PICOCLAW_TOOLS__`:
diff --git a/docs/wecom-app-configuration.md b/docs/wecom-app-configuration.md
new file mode 100644
index 000000000..3b17d37a7
--- /dev/null
+++ b/docs/wecom-app-configuration.md
@@ -0,0 +1,117 @@
+# 企业微信自建应用 (WeCom App) 配置指南
+
+本文档介绍如何在 PicoClaw 中配置企业微信自建应用 (wecom-app) 通道。
+
+## 功能特性
+
+| 功能 | 支持状态 |
+|------|---------|
+| 被动接收消息 | ✅ |
+| 主动发送消息 | ✅ |
+| 私聊 | ✅ |
+| 群聊 | ❌ |
+
+## 配置步骤
+
+### 1. 企业微信后台配置
+
+1. 登录 [企业微信管理后台](https://work.weixin.qq.com/wework_admin)
+2. 进入"应用管理" → 选择自建应用
+3. 记录以下信息:
+ - **AgentId**: 应用详情页显示
+ - **Secret**: 点击"查看"获取
+4. 进入"我的企业"页面,记录 **企业ID** (CorpID)
+
+### 2. 接收消息配置
+
+1. 在应用详情页,点击"接收消息"的"设置API接收"
+2. 填写以下信息:
+ - **URL**: `http://your-server:18792/webhook/wecom-app`
+ - **Token**: 随机生成或自定义(用于签名验证)
+ - **EncodingAESKey**: 点击"随机生成"生成43字符的密钥
+3. 点击"保存"时,企业微信会发送验证请求
+
+### 3. PicoClaw 配置
+
+在 `config.json` 中添加以下配置:
+
+```json
+{
+ "channels": {
+ "wecom_app": {
+ "enabled": true,
+ "corp_id": "wwxxxxxxxxxxxxxxxx", // 企业ID
+ "corp_secret": "xxxxxxxxxxxxxxxxxxxxxxxx", // 应用Secret
+ "agent_id": 1000002, // 应用AgentId
+ "token": "your_token", // 接收消息配置的Token
+ "encoding_aes_key": "your_encoding_aes_key", // 接收消息配置的EncodingAESKey
+ "webhook_host": "0.0.0.0",
+ "webhook_port": 18792,
+ "webhook_path": "/webhook/wecom-app",
+ "allow_from": [],
+ "reply_timeout": 5
+ }
+ }
+}
+```
+
+## 常见问题
+
+### 1. 回调URL验证失败
+
+**症状**: 企业微信保存API接收消息时提示验证失败
+
+**检查项**:
+- 确认服务器防火墙已开放 18792 端口
+- 确认 `corp_id`、`token`、`encoding_aes_key` 配置正确
+- 查看 PicoClaw 日志是否有请求到达
+
+### 2. 中文消息解密失败
+
+**症状**: 发送中文消息时出现 `invalid padding size` 错误
+
+**原因**: 企业微信使用非标准的 PKCS7 填充(32字节块大小)
+
+**解决**: 确保使用最新版本的 PicoClaw,已修复此问题。
+
+### 3. 端口冲突
+
+**症状**: 启动时提示端口已被占用
+
+**解决**: 修改 `webhook_port` 为其他端口,如 18794
+
+## 技术细节
+
+### 加密算法
+
+- **算法**: AES-256-CBC
+- **密钥**: EncodingAESKey Base64解码后的32字节
+- **IV**: AESKey的前16字节
+- **填充**: PKCS7(块大小为32字节,非标准16字节)
+- **消息格式**: XML
+
+### 消息结构
+
+解密后的消息格式:
+```
+random(16B) + msg_len(4B) + msg + receiveid
+```
+
+其中 `receiveid` 对于自建应用是 `corp_id`。
+
+## 调试
+
+启用调试模式查看详细日志:
+
+```bash
+picoclaw gateway --debug
+```
+
+关键日志标识:
+- `wecom_app`: WeCom App 通道相关日志
+- `wecom_common`: 加密解密相关日志
+
+## 参考文档
+
+- [企业微信官方文档 - 接收消息](https://developer.work.weixin.qq.com/document/path/96211)
+- [企业微信官方加解密库](https://github.com/sbzhu/weworkapi_golang)
diff --git a/go.mod b/go.mod
index 1f88639c8..98e20d07d 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
github.com/slack-go/slack v0.17.3
+ github.com/spf13/cobra v1.10.2
github.com/stretchr/testify v1.11.1
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
@@ -22,7 +23,9 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/spf13/pflag v1.0.10 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 0e95bf5cd..abbb11cd6 100644
--- a/go.sum
+++ b/go.sum
@@ -25,6 +25,7 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
+github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -72,6 +73,8 @@ github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
+github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
+github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
@@ -108,8 +111,14 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
+github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
+github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
+github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
+github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
+github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
+github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -151,6 +160,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
+go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
diff --git a/pkg/agent/context.go b/pkg/agent/context.go
index 27e3ef9dc..b7c6e1108 100644
--- a/pkg/agent/context.go
+++ b/pkg/agent/context.go
@@ -1,24 +1,38 @@
package agent
import (
+ "errors"
"fmt"
+ "io/fs"
"os"
"path/filepath"
"runtime"
"strings"
+ "sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
- "github.com/sipeed/picoclaw/pkg/tools"
)
type ContextBuilder struct {
workspace string
skillsLoader *skills.SkillsLoader
memory *MemoryStore
- tools *tools.ToolRegistry // Direct reference to tool registry
+
+ // Cache for system prompt to avoid rebuilding on every call.
+ // This fixes issue #607: repeated reprocessing of the entire context.
+ // The cache auto-invalidates when workspace source files change (mtime check).
+ systemPromptMutex sync.RWMutex
+ cachedSystemPrompt string
+ cachedAt time.Time // max observed mtime across tracked paths at cache build time
+
+ // existedAtCache tracks which source file paths existed the last time the
+ // cache was built. This lets sourceFilesChanged detect files that are newly
+ // created (didn't exist at cache time, now exist) or deleted (existed at
+ // cache time, now gone) — both of which should trigger a cache rebuild.
+ existedAtCache map[string]bool
}
func getGlobalConfigDir() string {
@@ -43,67 +57,29 @@ func NewContextBuilder(workspace string) *ContextBuilder {
}
}
-// SetToolsRegistry sets the tools registry for dynamic tool summary generation.
-func (cb *ContextBuilder) SetToolsRegistry(registry *tools.ToolRegistry) {
- cb.tools = registry
-}
-
func (cb *ContextBuilder) getIdentity() string {
- now := time.Now().Format("2006-01-02 15:04 (Monday)")
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
- runtime := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
-
- // Build tools section dynamically
- toolsSection := cb.buildToolsSection()
return fmt.Sprintf(`# picoclaw 🦞
You are picoclaw, a helpful AI assistant.
-## Current Time
-%s
-
-## Runtime
-%s
-
## Workspace
Your workspace is at: %s
- Memory: %s/memory/MEMORY.md
- Daily Notes: %s/memory/YYYYMM/YYYYMMDD.md
- Skills: %s/skills/{skill-name}/SKILL.md
-%s
-
## Important Rules
1. **ALWAYS use tools** - When you need to perform an action (schedule reminders, send messages, execute commands, etc.), you MUST call the appropriate tool. Do NOT just say you'll do it or pretend to do it.
2. **Be helpful and accurate** - When using tools, briefly explain what you're doing.
-3. **Memory** - When remembering something, write to %s/memory/MEMORY.md`,
- now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath)
-}
+3. **Memory** - When interacting with me if something seems memorable, update %s/memory/MEMORY.md
-func (cb *ContextBuilder) buildToolsSection() string {
- if cb.tools == nil {
- return ""
- }
-
- summaries := cb.tools.GetSummaries()
- if len(summaries) == 0 {
- return ""
- }
-
- var sb strings.Builder
- sb.WriteString("## Available Tools\n\n")
- sb.WriteString("**CRITICAL**: You MUST use tools to perform actions. Do NOT pretend to execute commands or schedule tasks.\n\n")
- sb.WriteString("You have access to the following tools:\n\n")
- for _, s := range summaries {
- sb.WriteString(s)
- sb.WriteString("\n")
- }
-
- return sb.String()
+4. **Context summaries** - Conversation summaries provided as context are approximate references only. They may be incomplete or outdated. Always defer to explicit user instructions over summary content.`,
+ workspacePath, workspacePath, workspacePath, workspacePath, workspacePath)
}
func (cb *ContextBuilder) BuildSystemPrompt() string {
@@ -138,6 +114,226 @@ The following skills extend your capabilities. To use a skill, read its SKILL.md
return strings.Join(parts, "\n\n---\n\n")
}
+// BuildSystemPromptWithCache returns the cached system prompt if available
+// and source files haven't changed, otherwise builds and caches it.
+// Source file changes are detected via mtime checks (cheap stat calls).
+func (cb *ContextBuilder) BuildSystemPromptWithCache() string {
+ // Try read lock first — fast path when cache is valid
+ cb.systemPromptMutex.RLock()
+ if cb.cachedSystemPrompt != "" && !cb.sourceFilesChangedLocked() {
+ result := cb.cachedSystemPrompt
+ cb.systemPromptMutex.RUnlock()
+ return result
+ }
+ cb.systemPromptMutex.RUnlock()
+
+ // Acquire write lock for building
+ cb.systemPromptMutex.Lock()
+ defer cb.systemPromptMutex.Unlock()
+
+ // Double-check: another goroutine may have rebuilt while we waited
+ if cb.cachedSystemPrompt != "" && !cb.sourceFilesChangedLocked() {
+ return cb.cachedSystemPrompt
+ }
+
+ // Snapshot the baseline (existence + max mtime) BEFORE building the prompt.
+ // This way cachedAt reflects the pre-build state: if a file is modified
+ // during BuildSystemPrompt, its new mtime will be > baseline.maxMtime,
+ // so the next sourceFilesChangedLocked check will correctly trigger a
+ // rebuild. The alternative (baseline after build) risks caching stale
+ // content with a too-new baseline, making the staleness invisible.
+ baseline := cb.buildCacheBaseline()
+ prompt := cb.BuildSystemPrompt()
+ cb.cachedSystemPrompt = prompt
+ cb.cachedAt = baseline.maxMtime
+ cb.existedAtCache = baseline.existed
+
+ logger.DebugCF("agent", "System prompt cached",
+ map[string]any{
+ "length": len(prompt),
+ })
+
+ return prompt
+}
+
+// InvalidateCache clears the cached system prompt.
+// Normally not needed because the cache auto-invalidates via mtime checks,
+// but this is useful for tests or explicit reload commands.
+func (cb *ContextBuilder) InvalidateCache() {
+ cb.systemPromptMutex.Lock()
+ defer cb.systemPromptMutex.Unlock()
+
+ cb.cachedSystemPrompt = ""
+ cb.cachedAt = time.Time{}
+ cb.existedAtCache = nil
+
+ logger.DebugCF("agent", "System prompt cache invalidated", nil)
+}
+
+// sourcePaths returns the workspace source file paths tracked for cache
+// invalidation (bootstrap files + memory). The skills directory is handled
+// separately in sourceFilesChangedLocked because it requires both directory-
+// level and recursive file-level mtime checks.
+func (cb *ContextBuilder) sourcePaths() []string {
+ return []string{
+ filepath.Join(cb.workspace, "AGENTS.md"),
+ filepath.Join(cb.workspace, "SOUL.md"),
+ filepath.Join(cb.workspace, "USER.md"),
+ filepath.Join(cb.workspace, "IDENTITY.md"),
+ filepath.Join(cb.workspace, "memory", "MEMORY.md"),
+ }
+}
+
+// cacheBaseline holds the file existence snapshot and the latest observed
+// mtime across all tracked paths. Used as the cache reference point.
+type cacheBaseline struct {
+ existed map[string]bool
+ maxMtime time.Time
+}
+
+// buildCacheBaseline records which tracked paths currently exist and computes
+// the latest mtime across all tracked files + skills directory contents.
+// Called under write lock when the cache is built.
+func (cb *ContextBuilder) buildCacheBaseline() cacheBaseline {
+ skillsDir := filepath.Join(cb.workspace, "skills")
+
+ // All paths whose existence we track: source files + skills dir.
+ allPaths := append(cb.sourcePaths(), skillsDir)
+
+ existed := make(map[string]bool, len(allPaths))
+ var maxMtime time.Time
+
+ for _, p := range allPaths {
+ info, err := os.Stat(p)
+ existed[p] = err == nil
+ if err == nil && info.ModTime().After(maxMtime) {
+ maxMtime = info.ModTime()
+ }
+ }
+
+ // Walk skills files to capture their mtimes too.
+ // Use os.Stat (not d.Info) to match the stat method used in
+ // fileChangedSince / skillFilesModifiedSince for consistency.
+ _ = filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr == nil && !d.IsDir() {
+ if info, err := os.Stat(path); err == nil && info.ModTime().After(maxMtime) {
+ maxMtime = info.ModTime()
+ }
+ }
+ return nil
+ })
+
+ // If no tracked files exist yet (empty workspace), maxMtime is zero.
+ // Use a very old non-zero time so that:
+ // 1. cachedAt.IsZero() won't trigger perpetual rebuilds.
+ // 2. Any real file created afterwards has mtime > cachedAt, so it
+ // will be detected by fileChangedSince (unlike time.Now() which
+ // could race with a file whose mtime <= Now).
+ if maxMtime.IsZero() {
+ maxMtime = time.Unix(1, 0)
+ }
+
+ return cacheBaseline{existed: existed, maxMtime: maxMtime}
+}
+
+// sourceFilesChangedLocked checks whether any workspace source file has been
+// modified, created, or deleted since the cache was last built.
+//
+// IMPORTANT: The caller MUST hold at least a read lock on systemPromptMutex.
+// Go's sync.RWMutex is not reentrant, so this function must NOT acquire the
+// lock itself (it would deadlock when called from BuildSystemPromptWithCache
+// which already holds RLock or Lock).
+func (cb *ContextBuilder) sourceFilesChangedLocked() bool {
+ if cb.cachedAt.IsZero() {
+ return true
+ }
+
+ // Check tracked source files (bootstrap + memory).
+ for _, p := range cb.sourcePaths() {
+ if cb.fileChangedSince(p) {
+ return true
+ }
+ }
+
+ // --- Skills directory (handled separately from sourcePaths) ---
+ //
+ // 1. Creation/deletion: tracked via existedAtCache, same as bootstrap files.
+ skillsDir := filepath.Join(cb.workspace, "skills")
+ if cb.fileChangedSince(skillsDir) {
+ return true
+ }
+
+ // 2. Structural changes (add/remove entries inside the dir) are reflected
+ // in the directory's own mtime, which fileChangedSince already checks.
+ //
+ // 3. Content-only edits to files inside skills/ do NOT update the parent
+ // directory mtime on most filesystems, so we recursively walk to check
+ // individual file mtimes at any nesting depth.
+ if skillFilesModifiedSince(skillsDir, cb.cachedAt) {
+ return true
+ }
+
+ return false
+}
+
+// fileChangedSince returns true if a tracked source file has been modified,
+// newly created, or deleted since the cache was built.
+//
+// Four cases:
+// - existed at cache time, exists now -> check mtime
+// - existed at cache time, gone now -> changed (deleted)
+// - absent at cache time, exists now -> changed (created)
+// - absent at cache time, gone now -> no change
+func (cb *ContextBuilder) fileChangedSince(path string) bool {
+ // Defensive: if existedAtCache was never initialized, treat as changed
+ // so the cache rebuilds rather than silently serving stale data.
+ if cb.existedAtCache == nil {
+ return true
+ }
+
+ existedBefore := cb.existedAtCache[path]
+ info, err := os.Stat(path)
+ existsNow := err == nil
+
+ if existedBefore != existsNow {
+ return true // file was created or deleted
+ }
+ if !existsNow {
+ return false // didn't exist before, doesn't exist now
+ }
+ return info.ModTime().After(cb.cachedAt)
+}
+
+// errWalkStop is a sentinel error used to stop filepath.WalkDir early.
+// Using a dedicated error (instead of fs.SkipAll) makes the early-exit
+// intent explicit and avoids the nilerr linter warning that would fire
+// if the callback returned nil when its err parameter is non-nil.
+var errWalkStop = errors.New("walk stop")
+
+// skillFilesModifiedSince recursively walks the skills directory and checks
+// whether any file was modified after t. This catches content-only edits at
+// any nesting depth (e.g. skills/name/docs/extra.md) that don't update
+// parent directory mtimes.
+func skillFilesModifiedSince(skillsDir string, t time.Time) bool {
+ changed := false
+ err := filepath.WalkDir(skillsDir, func(path string, d fs.DirEntry, walkErr error) error {
+ if walkErr == nil && !d.IsDir() {
+ if info, statErr := os.Stat(path); statErr == nil && info.ModTime().After(t) {
+ changed = true
+ return errWalkStop // stop walking
+ }
+ }
+ return nil
+ })
+ // errWalkStop is expected (early exit on first changed file).
+ // os.IsNotExist means the skills dir doesn't exist yet — not an error.
+ // Any other error is unexpected and worth logging.
+ if err != nil && !errors.Is(err, errWalkStop) && !os.IsNotExist(err) {
+ logger.DebugCF("agent", "skills walk error", map[string]any{"error": err.Error()})
+ }
+ return changed
+}
+
func (cb *ContextBuilder) LoadBootstrapFiles() string {
bootstrapFiles := []string{
"AGENTS.md",
@@ -146,58 +342,130 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string {
"IDENTITY.md",
}
- var result string
+ var sb strings.Builder
for _, filename := range bootstrapFiles {
filePath := filepath.Join(cb.workspace, filename)
if data, err := os.ReadFile(filePath); err == nil {
- result += fmt.Sprintf("## %s\n\n%s\n\n", filename, string(data))
+ fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data)
}
}
- return result
+ return sb.String()
}
-func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary string, currentMessage string, media []string, channel, chatID string) []providers.Message {
- messages := []providers.Message{}
+// buildDynamicContext returns a short dynamic context string with per-request info.
+// This changes every request (time, session) so it is NOT part of the cached prompt.
+// LLM-side KV cache reuse is achieved by each provider adapter's native mechanism:
+// - Anthropic: per-block cache_control (ephemeral) on the static SystemParts block
+// - OpenAI / Codex: prompt_cache_key for prefix-based caching
+//
+// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
+// See: https://platform.openai.com/docs/guides/prompt-caching
+func (cb *ContextBuilder) buildDynamicContext(channel, chatID string) string {
+ now := time.Now().Format("2006-01-02 15:04 (Monday)")
+ rt := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version())
- systemPrompt := cb.BuildSystemPrompt()
+ var sb strings.Builder
+ fmt.Fprintf(&sb, "## Current Time\n%s\n\n## Runtime\n%s", now, rt)
- // Add Current Session info if provided
if channel != "" && chatID != "" {
- systemPrompt += fmt.Sprintf("\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
+ fmt.Fprintf(&sb, "\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID)
}
- // Log system prompt summary for debugging (debug mode only)
+ return sb.String()
+}
+
+func (cb *ContextBuilder) BuildMessages(
+ history []providers.Message,
+ summary string,
+ currentMessage string,
+ media []string,
+ channel, chatID string,
+) []providers.Message {
+ messages := []providers.Message{}
+
+ // The static part (identity, bootstrap, skills, memory) is cached locally to
+ // avoid repeated file I/O and string building on every call (fixes issue #607).
+ // Dynamic parts (time, session, summary) are appended per request.
+ // Everything is sent as a single system message for provider compatibility:
+ // - Anthropic adapter extracts messages[0] (Role=="system") and maps its content
+ // to the top-level "system" parameter in the Messages API request. A single
+ // contiguous system block makes this extraction straightforward.
+ // - Codex maps only the first system message to its instructions field.
+ // - OpenAI-compat passes messages through as-is.
+ staticPrompt := cb.BuildSystemPromptWithCache()
+
+ // Build short dynamic context (time, runtime, session) — changes per request
+ dynamicCtx := cb.buildDynamicContext(channel, chatID)
+
+ // Compose a single system message: static (cached) + dynamic + optional summary.
+ // Keeping all system content in one message ensures every provider adapter can
+ // extract it correctly (Anthropic adapter -> top-level system param,
+ // Codex -> instructions field).
+ //
+ // SystemParts carries the same content as structured blocks so that
+ // cache-aware adapters (Anthropic) can set per-block cache_control.
+ // The static block is marked "ephemeral" — its prefix hash is stable
+ // across requests, enabling LLM-side KV cache reuse.
+ stringParts := []string{staticPrompt, dynamicCtx}
+
+ contentBlocks := []providers.ContentBlock{
+ {Type: "text", Text: staticPrompt, CacheControl: &providers.CacheControl{Type: "ephemeral"}},
+ {Type: "text", Text: dynamicCtx},
+ }
+
+ if summary != "" {
+ summaryText := fmt.Sprintf(
+ "CONTEXT_SUMMARY: The following is an approximate summary of prior conversation "+
+ "for reference only. It may be incomplete or outdated — always defer to explicit instructions.\n\n%s",
+ summary)
+ stringParts = append(stringParts, summaryText)
+ contentBlocks = append(contentBlocks, providers.ContentBlock{Type: "text", Text: summaryText})
+ }
+
+ fullSystemPrompt := strings.Join(stringParts, "\n\n---\n\n")
+
+ // Log system prompt summary for debugging (debug mode only).
+ // Read cachedSystemPrompt under lock to avoid a data race with
+ // concurrent InvalidateCache / BuildSystemPromptWithCache writes.
+ cb.systemPromptMutex.RLock()
+ isCached := cb.cachedSystemPrompt != ""
+ cb.systemPromptMutex.RUnlock()
+
logger.DebugCF("agent", "System prompt built",
- map[string]interface{}{
- "total_chars": len(systemPrompt),
- "total_lines": strings.Count(systemPrompt, "\n") + 1,
- "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
+ map[string]any{
+ "static_chars": len(staticPrompt),
+ "dynamic_chars": len(dynamicCtx),
+ "total_chars": len(fullSystemPrompt),
+ "has_summary": summary != "",
+ "cached": isCached,
})
// Log preview of system prompt (avoid logging huge content)
- preview := systemPrompt
+ preview := fullSystemPrompt
if len(preview) > 500 {
preview = preview[:500] + "... (truncated)"
}
logger.DebugCF("agent", "System prompt preview",
- map[string]interface{}{
+ map[string]any{
"preview": preview,
})
- if summary != "" {
- systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary
- }
-
history = sanitizeHistoryForProvider(history)
+ // Single system message containing all context — compatible with all providers.
+ // SystemParts enables cache-aware adapters to set per-block cache_control;
+ // Content is the concatenated fallback for adapters that don't read SystemParts.
messages = append(messages, providers.Message{
- Role: "system",
- Content: systemPrompt,
+ Role: "system",
+ Content: fullSystemPrompt,
+ SystemParts: contentBlocks,
})
+ // Add conversation history
messages = append(messages, history...)
+ // Add current user message
if strings.TrimSpace(currentMessage) != "" {
messages = append(messages, providers.Message{
Role: "user",
@@ -216,14 +484,33 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
sanitized := make([]providers.Message, 0, len(history))
for _, msg := range history {
switch msg.Role {
+ case "system":
+ // Drop system messages from history. BuildMessages always
+ // constructs its own single system message (static + dynamic +
+ // summary); extra system messages would break providers that
+ // only accept one (Anthropic, Codex).
+ logger.DebugCF("agent", "Dropping system message from history", map[string]any{})
+ continue
+
case "tool":
if len(sanitized) == 0 {
- logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]interface{}{})
+ logger.DebugCF("agent", "Dropping orphaned leading tool message", map[string]any{})
continue
}
- last := sanitized[len(sanitized)-1]
- if last.Role != "assistant" || len(last.ToolCalls) == 0 {
- logger.DebugCF("agent", "Dropping orphaned tool message", map[string]interface{}{})
+ // Walk backwards to find the nearest assistant message,
+ // skipping over any preceding tool messages (multi-tool-call case).
+ foundAssistant := false
+ for i := len(sanitized) - 1; i >= 0; i-- {
+ if sanitized[i].Role == "tool" {
+ continue
+ }
+ if sanitized[i].Role == "assistant" && len(sanitized[i].ToolCalls) > 0 {
+ foundAssistant = true
+ }
+ break
+ }
+ if !foundAssistant {
+ logger.DebugCF("agent", "Dropping orphaned tool message", map[string]any{})
continue
}
sanitized = append(sanitized, msg)
@@ -231,12 +518,16 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
case "assistant":
if len(msg.ToolCalls) > 0 {
if len(sanitized) == 0 {
- logger.DebugCF("agent", "Dropping assistant tool-call turn at history start", map[string]interface{}{})
+ logger.DebugCF("agent", "Dropping assistant tool-call turn at history start", map[string]any{})
continue
}
prev := sanitized[len(sanitized)-1]
if prev.Role != "user" && prev.Role != "tool" {
- logger.DebugCF("agent", "Dropping assistant tool-call turn with invalid predecessor", map[string]interface{}{"prev_role": prev.Role})
+ logger.DebugCF(
+ "agent",
+ "Dropping assistant tool-call turn with invalid predecessor",
+ map[string]any{"prev_role": prev.Role},
+ )
continue
}
}
@@ -250,7 +541,10 @@ func sanitizeHistoryForProvider(history []providers.Message) []providers.Message
return sanitized
}
-func (cb *ContextBuilder) AddToolResult(messages []providers.Message, toolCallID, toolName, result string) []providers.Message {
+func (cb *ContextBuilder) AddToolResult(
+ messages []providers.Message,
+ toolCallID, toolName, result string,
+) []providers.Message {
messages = append(messages, providers.Message{
Role: "tool",
Content: result,
@@ -259,7 +553,11 @@ func (cb *ContextBuilder) AddToolResult(messages []providers.Message, toolCallID
return messages
}
-func (cb *ContextBuilder) AddAssistantMessage(messages []providers.Message, content string, toolCalls []map[string]interface{}) []providers.Message {
+func (cb *ContextBuilder) AddAssistantMessage(
+ messages []providers.Message,
+ content string,
+ toolCalls []map[string]any,
+) []providers.Message {
msg := providers.Message{
Role: "assistant",
Content: content,
@@ -269,33 +567,14 @@ func (cb *ContextBuilder) AddAssistantMessage(messages []providers.Message, cont
return messages
}
-func (cb *ContextBuilder) loadSkills() string {
- allSkills := cb.skillsLoader.ListSkills()
- if len(allSkills) == 0 {
- return ""
- }
-
- var skillNames []string
- for _, s := range allSkills {
- skillNames = append(skillNames, s.Name)
- }
-
- content := cb.skillsLoader.LoadSkillsForContext(skillNames)
- if content == "" {
- return ""
- }
-
- return "# Skill Definitions\n\n" + content
-}
-
// GetSkillsInfo returns information about loaded skills.
-func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} {
+func (cb *ContextBuilder) GetSkillsInfo() map[string]any {
allSkills := cb.skillsLoader.ListSkills()
skillNames := make([]string, 0, len(allSkills))
for _, s := range allSkills {
skillNames = append(skillNames, s.Name)
}
- return map[string]interface{}{
+ return map[string]any{
"total": len(allSkills),
"available": len(allSkills),
"names": skillNames,
diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go
new file mode 100644
index 000000000..ba70d4c0d
--- /dev/null
+++ b/pkg/agent/context_cache_test.go
@@ -0,0 +1,513 @@
+package agent
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// setupWorkspace creates a temporary workspace with standard directories and optional files.
+// Returns the tmpDir path; caller should defer os.RemoveAll(tmpDir).
+func setupWorkspace(t *testing.T, files map[string]string) string {
+ t.Helper()
+ tmpDir, err := os.MkdirTemp("", "picoclaw-test-*")
+ if err != nil {
+ t.Fatal(err)
+ }
+ os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
+ os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
+ for name, content := range files {
+ dir := filepath.Dir(filepath.Join(tmpDir, name))
+ os.MkdirAll(dir, 0o755)
+ if err := os.WriteFile(filepath.Join(tmpDir, name), []byte(content), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ }
+ return tmpDir
+}
+
+// TestSingleSystemMessage verifies that BuildMessages always produces exactly one
+// system message regardless of summary/history variations.
+// Fix: multiple system messages break Anthropic (top-level system param) and
+// Codex (only reads last system message as instructions).
+func TestSingleSystemMessage(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "IDENTITY.md": "# Identity\nTest agent.",
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ tests := []struct {
+ name string
+ history []providers.Message
+ summary string
+ message string
+ }{
+ {
+ name: "no summary, no history",
+ summary: "",
+ message: "hello",
+ },
+ {
+ name: "with summary",
+ summary: "Previous conversation discussed X",
+ message: "hello",
+ },
+ {
+ name: "with history and summary",
+ history: []providers.Message{
+ {Role: "user", Content: "hi"},
+ {Role: "assistant", Content: "hello"},
+ },
+ summary: strings.Repeat("Long summary text. ", 50),
+ message: "new message",
+ },
+ {
+ name: "system message in history is filtered",
+ history: []providers.Message{
+ {Role: "system", Content: "stale system prompt from previous session"},
+ {Role: "user", Content: "hi"},
+ {Role: "assistant", Content: "hello"},
+ },
+ summary: "",
+ message: "new message",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ msgs := cb.BuildMessages(tt.history, tt.summary, tt.message, nil, "test", "chat1")
+
+ systemCount := 0
+ for _, m := range msgs {
+ if m.Role == "system" {
+ systemCount++
+ }
+ }
+ if systemCount != 1 {
+ t.Errorf("expected exactly 1 system message, got %d", systemCount)
+ }
+ if msgs[0].Role != "system" {
+ t.Errorf("first message should be system, got %s", msgs[0].Role)
+ }
+ if msgs[len(msgs)-1].Role != "user" {
+ t.Errorf("last message should be user, got %s", msgs[len(msgs)-1].Role)
+ }
+
+ // System message must contain identity (static) and time (dynamic)
+ sys := msgs[0].Content
+ if !strings.Contains(sys, "picoclaw") {
+ t.Error("system message missing identity")
+ }
+ if !strings.Contains(sys, "Current Time") {
+ t.Error("system message missing dynamic time context")
+ }
+
+ // Summary handling
+ if tt.summary != "" {
+ if !strings.Contains(sys, "CONTEXT_SUMMARY:") {
+ t.Error("summary present but CONTEXT_SUMMARY prefix missing")
+ }
+ if !strings.Contains(sys, tt.summary[:20]) {
+ t.Error("summary content not found in system message")
+ }
+ } else {
+ if strings.Contains(sys, "CONTEXT_SUMMARY:") {
+ t.Error("CONTEXT_SUMMARY should not appear without summary")
+ }
+ }
+ })
+ }
+}
+
+// TestMtimeAutoInvalidation verifies that the cache detects source file changes
+// via mtime without requiring explicit InvalidateCache().
+// Fix: original implementation had no auto-invalidation — edits to bootstrap files,
+// memory, or skills were invisible until process restart.
+func TestMtimeAutoInvalidation(t *testing.T) {
+ tests := []struct {
+ name string
+ file string // relative path inside workspace
+ contentV1 string
+ contentV2 string
+ checkField string // substring to verify in rebuilt prompt
+ }{
+ {
+ name: "bootstrap file change",
+ file: "IDENTITY.md",
+ contentV1: "# Original Identity",
+ contentV2: "# Updated Identity",
+ checkField: "Updated Identity",
+ },
+ {
+ name: "memory file change",
+ file: "memory/MEMORY.md",
+ contentV1: "# Memory\nUser likes Go.",
+ contentV2: "# Memory\nUser likes Rust.",
+ checkField: "User likes Rust",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{tt.file: tt.contentV1})
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ sp1 := cb.BuildSystemPromptWithCache()
+
+ // Overwrite file and set future mtime to ensure detection.
+ // Use 2s offset for filesystem mtime resolution safety (some FS
+ // have 1s or coarser granularity, especially in CI containers).
+ fullPath := filepath.Join(tmpDir, tt.file)
+ os.WriteFile(fullPath, []byte(tt.contentV2), 0o644)
+ future := time.Now().Add(2 * time.Second)
+ os.Chtimes(fullPath, future, future)
+
+ // Verify sourceFilesChangedLocked detects the mtime change
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatalf("sourceFilesChangedLocked() should detect %s change", tt.file)
+ }
+
+ // Should auto-rebuild without explicit InvalidateCache()
+ sp2 := cb.BuildSystemPromptWithCache()
+ if sp1 == sp2 {
+ t.Errorf("cache not rebuilt after %s change", tt.file)
+ }
+ if !strings.Contains(sp2, tt.checkField) {
+ t.Errorf("rebuilt prompt missing expected content %q", tt.checkField)
+ }
+ })
+ }
+
+ // Skills directory mtime change
+ t.Run("skills dir change", func(t *testing.T) {
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ _ = cb.BuildSystemPromptWithCache() // populate cache
+
+ // Touch skills directory (simulate new skill installed)
+ skillsDir := filepath.Join(tmpDir, "skills")
+ future := time.Now().Add(2 * time.Second)
+ os.Chtimes(skillsDir, future, future)
+
+ // Verify sourceFilesChangedLocked detects it (cache is rebuilt)
+ // We confirm by checking internal state: a second call should rebuild.
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Error("sourceFilesChangedLocked() should detect skills dir mtime change")
+ }
+ })
+}
+
+// TestExplicitInvalidateCache verifies that InvalidateCache() forces a rebuild
+// even when source files haven't changed (useful for tests and reload commands).
+func TestExplicitInvalidateCache(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "IDENTITY.md": "# Test Identity",
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ sp1 := cb.BuildSystemPromptWithCache()
+ cb.InvalidateCache()
+ sp2 := cb.BuildSystemPromptWithCache()
+
+ if sp1 != sp2 {
+ t.Error("prompt should be identical after invalidate+rebuild when files unchanged")
+ }
+
+ // Verify cachedAt was reset
+ cb.InvalidateCache()
+ cb.systemPromptMutex.RLock()
+ if !cb.cachedAt.IsZero() {
+ t.Error("cachedAt should be zero after InvalidateCache()")
+ }
+ cb.systemPromptMutex.RUnlock()
+}
+
+// TestCacheStability verifies that the static prompt is stable across repeated calls
+// when no files change (regression test for issue #607).
+func TestCacheStability(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "IDENTITY.md": "# Identity\nContent",
+ "SOUL.md": "# Soul\nContent",
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ results := make([]string, 5)
+ for i := range results {
+ results[i] = cb.BuildSystemPromptWithCache()
+ }
+ for i := 1; i < len(results); i++ {
+ if results[i] != results[0] {
+ t.Errorf("cached prompt changed between call 0 and %d", i)
+ }
+ }
+
+ // Static prompt must NOT contain per-request data
+ if strings.Contains(results[0], "Current Time") {
+ t.Error("static cached prompt should not contain time (added dynamically)")
+ }
+}
+
+// TestNewFileCreationInvalidatesCache verifies that creating a source file that
+// did not exist when the cache was built triggers a cache rebuild.
+// This catches the "from nothing to something" edge case that the old
+// modifiedSince (return false on stat error) would miss.
+func TestNewFileCreationInvalidatesCache(t *testing.T) {
+ tests := []struct {
+ name string
+ file string // relative path inside workspace
+ content string
+ checkField string // substring to verify in rebuilt prompt
+ }{
+ {
+ name: "new bootstrap file",
+ file: "SOUL.md",
+ content: "# Soul\nBe kind and helpful.",
+ checkField: "Be kind and helpful",
+ },
+ {
+ name: "new memory file",
+ file: "memory/MEMORY.md",
+ content: "# Memory\nUser prefers dark mode.",
+ checkField: "User prefers dark mode",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Start with an empty workspace (no bootstrap/memory files)
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ // Populate cache — file does not exist yet
+ sp1 := cb.BuildSystemPromptWithCache()
+ if strings.Contains(sp1, tt.checkField) {
+ t.Fatalf("prompt should not contain %q before file is created", tt.checkField)
+ }
+
+ // Create the file after cache was built
+ fullPath := filepath.Join(tmpDir, tt.file)
+ os.MkdirAll(filepath.Dir(fullPath), 0o755)
+ if err := os.WriteFile(fullPath, []byte(tt.content), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ // Set future mtime to guarantee detection
+ future := time.Now().Add(2 * time.Second)
+ os.Chtimes(fullPath, future, future)
+
+ // Cache should auto-invalidate because file went from absent -> present
+ sp2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp2, tt.checkField) {
+ t.Errorf("cache not invalidated on new file creation: expected %q in prompt", tt.checkField)
+ }
+ })
+ }
+}
+
+// TestSkillFileContentChange verifies that modifying a skill file's content
+// (not just the directory structure) invalidates the cache.
+// This is the scenario where directory mtime alone is insufficient — on most
+// filesystems, editing a file inside a directory does NOT update the parent
+// directory's mtime.
+func TestSkillFileContentChange(t *testing.T) {
+ skillMD := `---
+name: test-skill
+description: "A test skill"
+---
+# Test Skill v1
+Original content.`
+
+ tmpDir := setupWorkspace(t, map[string]string{
+ "skills/test-skill/SKILL.md": skillMD,
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ // Populate cache
+ sp1 := cb.BuildSystemPromptWithCache()
+ _ = sp1 // cache is warm
+
+ // Modify the skill file content (without touching the skills/ directory)
+ updatedSkillMD := `---
+name: test-skill
+description: "An updated test skill"
+---
+# Test Skill v2
+Updated content.`
+
+ skillPath := filepath.Join(tmpDir, "skills", "test-skill", "SKILL.md")
+ if err := os.WriteFile(skillPath, []byte(updatedSkillMD), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ // Set future mtime on the skill file only (NOT the directory)
+ future := time.Now().Add(2 * time.Second)
+ os.Chtimes(skillPath, future, future)
+
+ // Verify that sourceFilesChangedLocked detects the content change
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Error("sourceFilesChangedLocked() should detect skill file content change")
+ }
+
+ // Verify cache is actually rebuilt with new content
+ sp2 := cb.BuildSystemPromptWithCache()
+ if sp1 == sp2 && strings.Contains(sp1, "test-skill") {
+ // If the skill appeared in the prompt and the prompt didn't change,
+ // the cache was not invalidated.
+ t.Error("cache should be invalidated when skill file content changes")
+ }
+}
+
+// TestConcurrentBuildSystemPromptWithCache verifies that multiple goroutines
+// can safely call BuildSystemPromptWithCache concurrently without producing
+// empty results, panics, or data races.
+// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache
+func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "IDENTITY.md": "# Identity\nConcurrency test agent.",
+ "SOUL.md": "# Soul\nBe helpful.",
+ "memory/MEMORY.md": "# Memory\nUser prefers Go.",
+ "skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo",
+ })
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ const goroutines = 20
+ const iterations = 50
+
+ var wg sync.WaitGroup
+ errs := make(chan string, goroutines*iterations)
+
+ for g := 0; g < goroutines; g++ {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ result := cb.BuildSystemPromptWithCache()
+ if result == "" {
+ errs <- "empty prompt returned"
+ return
+ }
+ if !strings.Contains(result, "picoclaw") {
+ errs <- "prompt missing identity"
+ return
+ }
+
+ // Also exercise BuildMessages concurrently
+ msgs := cb.BuildMessages(nil, "", "hello", nil, "test", "chat")
+ if len(msgs) < 2 {
+ errs <- "BuildMessages returned fewer than 2 messages"
+ return
+ }
+ if msgs[0].Role != "system" {
+ errs <- "first message not system"
+ return
+ }
+
+ // Occasionally invalidate to exercise the write path
+ if i%10 == 0 {
+ cb.InvalidateCache()
+ }
+ }
+ }(g)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for errMsg := range errs {
+ t.Errorf("concurrent access error: %s", errMsg)
+ }
+}
+
+// BenchmarkBuildMessagesWithCache measures caching performance.
+
+// TestEmptyWorkspaceBaselineDetectsNewFiles verifies that when the cache is
+// built on an empty workspace (no tracked files exist), creating a file
+// afterwards still triggers cache invalidation. This validates the
+// time.Unix(1, 0) fallback for maxMtime: any real file's mtime is after epoch,
+// so fileChangedSince correctly detects the absent -> present transition AND
+// the mtime comparison succeeds even without artificially inflated Chtimes.
+func TestEmptyWorkspaceBaselineDetectsNewFiles(t *testing.T) {
+ // Empty workspace: no bootstrap files, no memory, no skills content.
+ tmpDir := setupWorkspace(t, nil)
+ defer os.RemoveAll(tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ // Build cache — all tracked files are absent, maxMtime falls back to epoch.
+ sp1 := cb.BuildSystemPromptWithCache()
+
+ // Create a bootstrap file with natural mtime (no Chtimes manipulation).
+ // The file's mtime should be the current wall-clock time, which is
+ // strictly after time.Unix(1, 0).
+ soulPath := filepath.Join(tmpDir, "SOUL.md")
+ if err := os.WriteFile(soulPath, []byte("# Soul\nNewly created."), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ // Cache should detect the new file via existedAtCache (absent -> present).
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("sourceFilesChangedLocked should detect newly created file on empty workspace")
+ }
+
+ sp2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(sp2, "Newly created") {
+ t.Error("rebuilt prompt should contain new file content")
+ }
+ if sp1 == sp2 {
+ t.Error("cache should have been invalidated after file creation")
+ }
+}
+
+// BenchmarkBuildMessagesWithCache measures caching performance.
+func BenchmarkBuildMessagesWithCache(b *testing.B) {
+ tmpDir, _ := os.MkdirTemp("", "picoclaw-bench-*")
+ defer os.RemoveAll(tmpDir)
+
+ os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
+ os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
+ for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} {
+ os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644)
+ }
+
+ cb := NewContextBuilder(tmpDir)
+ history := []providers.Message{
+ {Role: "user", Content: "previous message"},
+ {Role: "assistant", Content: "previous response"},
+ }
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ _ = cb.BuildMessages(history, "summary", "new message", nil, "cli", "test")
+ }
+}
diff --git a/pkg/agent/context_test.go b/pkg/agent/context_test.go
new file mode 100644
index 000000000..e023c9c30
--- /dev/null
+++ b/pkg/agent/context_test.go
@@ -0,0 +1,209 @@
+package agent
+
+import (
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+func msg(role, content string) providers.Message {
+ return providers.Message{Role: role, Content: content}
+}
+
+func assistantWithTools(toolIDs ...string) providers.Message {
+ calls := make([]providers.ToolCall, len(toolIDs))
+ for i, id := range toolIDs {
+ calls[i] = providers.ToolCall{ID: id, Type: "function"}
+ }
+ return providers.Message{Role: "assistant", ToolCalls: calls}
+}
+
+func toolResult(id string) providers.Message {
+ return providers.Message{Role: "tool", Content: "result", ToolCallID: id}
+}
+
+func TestSanitizeHistoryForProvider_EmptyHistory(t *testing.T) {
+ result := sanitizeHistoryForProvider(nil)
+ if len(result) != 0 {
+ t.Fatalf("expected empty, got %d messages", len(result))
+ }
+
+ result = sanitizeHistoryForProvider([]providers.Message{})
+ if len(result) != 0 {
+ t.Fatalf("expected empty, got %d messages", len(result))
+ }
+}
+
+func TestSanitizeHistoryForProvider_SingleToolCall(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "hello"),
+ assistantWithTools("A"),
+ toolResult("A"),
+ msg("assistant", "done"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 4 {
+ t.Fatalf("expected 4 messages, got %d", len(result))
+ }
+ assertRoles(t, result, "user", "assistant", "tool", "assistant")
+}
+
+func TestSanitizeHistoryForProvider_MultiToolCalls(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "do two things"),
+ assistantWithTools("A", "B"),
+ toolResult("A"),
+ toolResult("B"),
+ msg("assistant", "both done"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 5 {
+ t.Fatalf("expected 5 messages, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant")
+}
+
+func TestSanitizeHistoryForProvider_AssistantToolCallAfterPlainAssistant(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "hi"),
+ msg("assistant", "thinking"),
+ assistantWithTools("A"),
+ toolResult("A"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 2 {
+ t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user", "assistant")
+}
+
+func TestSanitizeHistoryForProvider_OrphanedLeadingTool(t *testing.T) {
+ history := []providers.Message{
+ toolResult("A"),
+ msg("user", "hello"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user")
+}
+
+func TestSanitizeHistoryForProvider_ToolAfterUserDropped(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "hello"),
+ toolResult("A"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user")
+}
+
+func TestSanitizeHistoryForProvider_ToolAfterAssistantNoToolCalls(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "hello"),
+ msg("assistant", "hi"),
+ toolResult("A"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 2 {
+ t.Fatalf("expected 2 messages, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user", "assistant")
+}
+
+func TestSanitizeHistoryForProvider_AssistantToolCallAtStart(t *testing.T) {
+ history := []providers.Message{
+ assistantWithTools("A"),
+ toolResult("A"),
+ msg("user", "hello"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 1 {
+ t.Fatalf("expected 1 message, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user")
+}
+
+func TestSanitizeHistoryForProvider_MultiToolCallsThenNewRound(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "do two things"),
+ assistantWithTools("A", "B"),
+ toolResult("A"),
+ toolResult("B"),
+ msg("assistant", "done"),
+ msg("user", "hi"),
+ assistantWithTools("C"),
+ toolResult("C"),
+ msg("assistant", "done again"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 9 {
+ t.Fatalf("expected 9 messages, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "user", "assistant", "tool", "assistant")
+}
+
+func TestSanitizeHistoryForProvider_ConsecutiveMultiToolRounds(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "start"),
+ assistantWithTools("A", "B"),
+ toolResult("A"),
+ toolResult("B"),
+ assistantWithTools("C", "D"),
+ toolResult("C"),
+ toolResult("D"),
+ msg("assistant", "all done"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 8 {
+ t.Fatalf("expected 8 messages, got %d: %+v", len(result), roles(result))
+ }
+ assertRoles(t, result, "user", "assistant", "tool", "tool", "assistant", "tool", "tool", "assistant")
+}
+
+func TestSanitizeHistoryForProvider_PlainConversation(t *testing.T) {
+ history := []providers.Message{
+ msg("user", "hello"),
+ msg("assistant", "hi"),
+ msg("user", "how are you"),
+ msg("assistant", "fine"),
+ }
+
+ result := sanitizeHistoryForProvider(history)
+ if len(result) != 4 {
+ t.Fatalf("expected 4 messages, got %d", len(result))
+ }
+ assertRoles(t, result, "user", "assistant", "user", "assistant")
+}
+
+func roles(msgs []providers.Message) []string {
+ r := make([]string, len(msgs))
+ for i, m := range msgs {
+ r[i] = m.Role
+ }
+ return r
+}
+
+func assertRoles(t *testing.T, msgs []providers.Message, expected ...string) {
+ t.Helper()
+ if len(msgs) != len(expected) {
+ t.Fatalf("role count mismatch: got %v, want %v", roles(msgs), expected)
+ }
+ for i, exp := range expected {
+ if msgs[i].Role != exp {
+ t.Errorf("message[%d]: got role %q, want %q", i, msgs[i].Role, exp)
+ }
+ }
+}
diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go
index 37b253685..a6fd365c7 100644
--- a/pkg/agent/instance.go
+++ b/pkg/agent/instance.go
@@ -41,7 +41,7 @@ func NewAgentInstance(
provider providers.LLMProvider,
) *AgentInstance {
workspace := resolveAgentWorkspace(agentCfg, defaults)
- os.MkdirAll(workspace, 0755)
+ os.MkdirAll(workspace, 0o755)
model := resolveAgentModel(agentCfg, defaults)
fallbacks := resolveAgentFallbacks(agentCfg, defaults)
@@ -59,7 +59,6 @@ func NewAgentInstance(
sessionsManager := session.NewSessionManager(sessionsDir)
contextBuilder := NewContextBuilder(workspace)
- contextBuilder.SetToolsRegistry(toolsRegistry)
agentID := routing.DefaultAgentID
agentName := ""
@@ -133,7 +132,7 @@ func resolveAgentModel(agentCfg *config.AgentConfig, defaults *config.AgentDefau
if agentCfg != nil && agentCfg.Model != nil && strings.TrimSpace(agentCfg.Model.Primary) != "" {
return strings.TrimSpace(agentCfg.Model.Primary)
}
- return defaults.Model
+ return defaults.GetModelName()
}
// resolveAgentFallbacks resolves the fallback models for an agent.
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 4cdc1fe90..37591fa79 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -23,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
+ "github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
@@ -79,7 +80,12 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
}
// registerSharedTools registers tools that are shared across all agents (web, message, spawn).
-func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *AgentRegistry, provider providers.LLMProvider) {
+func registerSharedTools(
+ cfg *config.Config,
+ msgBus *bus.MessageBus,
+ registry *AgentRegistry,
+ provider providers.LLMProvider,
+) {
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
if !ok {
@@ -91,6 +97,10 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
+ TavilyAPIKey: cfg.Tools.Web.Tavily.APIKey,
+ TavilyBaseURL: cfg.Tools.Web.Tavily.BaseURL,
+ TavilyMaxResults: cfg.Tools.Web.Tavily.MaxResults,
+ TavilyEnabled: cfg.Tools.Web.Tavily.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
@@ -99,10 +109,11 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
SearXNGBaseURL: cfg.Tools.Web.SearXNG.BaseURL,
SearXNGMaxResults: cfg.Tools.Web.SearXNG.MaxResults,
SearXNGEnabled: cfg.Tools.Web.SearXNG.Enabled,
+ Proxy: cfg.Tools.Web.Proxy,
}); searchTool != nil {
agent.Tools.Register(searchTool)
}
- agent.Tools.Register(tools.NewWebFetchTool(50000))
+ agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy))
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
agent.Tools.Register(tools.NewI2CTool())
@@ -120,6 +131,18 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
})
agent.Tools.Register(messageTool)
+ // Skill discovery and installation tools
+ registryMgr := skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
+ MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
+ ClawHub: skills.ClawHubConfig(cfg.Tools.Skills.Registries.ClawHub),
+ })
+ searchCache := skills.NewSearchCache(
+ cfg.Tools.Skills.SearchCache.MaxSize,
+ time.Duration(cfg.Tools.Skills.SearchCache.TTLSeconds)*time.Second,
+ )
+ agent.Tools.Register(tools.NewFindSkillsTool(registryMgr, searchCache))
+ agent.Tools.Register(tools.NewInstallSkillTool(registryMgr, agent.Workspace))
+
// Spawn tool with allowlist checker
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace, msgBus)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
@@ -129,9 +152,6 @@ func registerSharedTools(cfg *config.Config, msgBus *bus.MessageBus, registry *A
return registry.CanSpawnSubagent(currentAgentID, targetAgentID)
})
agent.Tools.Register(spawnTool)
-
- // Update context builder with the complete tools registry
- agent.ContextBuilder.SetToolsRegistry(agent.Tools)
}
}
@@ -219,7 +239,10 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
-func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) {
+func (al *AgentLoop) ProcessDirectWithChannel(
+ ctx context.Context,
+ content, sessionKey, channel, chatID string,
+) (string, error) {
msg := bus.InboundMessage{
Channel: channel,
SenderID: "cron",
@@ -256,7 +279,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
- map[string]interface{}{
+ map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"sender_id": msg.SenderID,
@@ -295,7 +318,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
logger.InfoCF("agent", "Routed message",
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"session_key": sessionKey,
"matched_by": route.MatchedBy,
@@ -318,7 +341,7 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
}
logger.InfoCF("agent", "Processing system message",
- map[string]interface{}{
+ map[string]any{
"sender_id": msg.SenderID,
"chat_id": msg.ChatID,
})
@@ -343,7 +366,7 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
// Skip internal channels - only log, don't send to user
if constants.IsInternalChannel(originChannel) {
logger.InfoCF("agent", "Subagent completed (internal channel)",
- map[string]interface{}{
+ map[string]any{
"sender_id": msg.SenderID,
"content_len": len(content),
"channel": originChannel,
@@ -376,7 +399,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
- logger.WarnCF("agent", "Failed to record last channel", map[string]interface{}{"error": err.Error()})
+ logger.WarnCF("agent", "Failed to record last channel", map[string]any{"error": err.Error()})
}
}
}
@@ -438,7 +461,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
// 9. Log response
responsePreview := utils.Truncate(finalContent, 120)
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"session_key": opts.SessionKey,
"iterations": iteration,
@@ -449,7 +472,12 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
}
// runLLMIteration executes the LLM call loop with tool handling.
-func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance, messages []providers.Message, opts processOptions) (string, int, error) {
+func (al *AgentLoop) runLLMIteration(
+ ctx context.Context,
+ agent *AgentInstance,
+ messages []providers.Message,
+ opts processOptions,
+) (string, int, error) {
iteration := 0
var finalContent string
@@ -457,7 +485,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
iteration++
logger.DebugCF("agent", "LLM iteration",
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"max": agent.MaxIterations,
@@ -468,7 +496,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
// Log LLM request details
logger.DebugCF("agent", "LLM request",
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": agent.Model,
@@ -481,7 +509,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
// Log full messages (detailed)
logger.DebugCF("agent", "Full LLM request",
- map[string]interface{}{
+ map[string]any{
"iteration": iteration,
"messages_json": formatMessagesForLog(messages),
"tools_json": formatToolsForLog(providerToolDefs),
@@ -495,9 +523,10 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
if len(agent.Candidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(ctx, agent.Candidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
- return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]interface{}{
- "max_tokens": agent.MaxTokens,
- "temperature": agent.Temperature,
+ return agent.Provider.Chat(ctx, messages, providerToolDefs, model, map[string]any{
+ "max_tokens": agent.MaxTokens,
+ "temperature": agent.Temperature,
+ "prompt_cache_key": agent.ID,
})
},
)
@@ -507,13 +536,14 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
if fbResult.Provider != "" && len(fbResult.Attempts) > 0 {
logger.InfoCF("agent", fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
- map[string]interface{}{"agent_id": agent.ID, "iteration": iteration})
+ map[string]any{"agent_id": agent.ID, "iteration": iteration})
}
return fbResult.Response, nil
}
- return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]interface{}{
- "max_tokens": agent.MaxTokens,
- "temperature": agent.Temperature,
+ return agent.Provider.Chat(ctx, messages, providerToolDefs, agent.Model, map[string]any{
+ "max_tokens": agent.MaxTokens,
+ "temperature": agent.Temperature,
+ "prompt_cache_key": agent.ID,
})
}
@@ -532,7 +562,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
strings.Contains(errMsg, "length")
if isContextError && retry < maxRetries {
- logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
+ logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]any{
"error": err.Error(),
"retry": retry,
})
@@ -559,7 +589,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"error": err.Error(),
@@ -571,7 +601,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"content_chars": len(finalContent),
@@ -590,7 +620,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("agent", "LLM requested tool calls",
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"tools": toolNames,
"count": len(normalizedToolCalls),
@@ -599,8 +629,9 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
// Build assistant message with tool calls
assistantMsg := providers.Message{
- Role: "assistant",
- Content: response.Content,
+ Role: "assistant",
+ Content: response.Content,
+ ReasoningContent: response.ReasoningContent,
}
for _, tc := range normalizedToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
@@ -634,7 +665,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
- map[string]interface{}{
+ map[string]any{
"agent_id": agent.ID,
"tool": tc.Name,
"iteration": iteration,
@@ -649,14 +680,21 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
- map[string]interface{}{
+ map[string]any{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
- toolResult := agent.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
+ toolResult := agent.Tools.ExecuteWithContext(
+ ctx,
+ tc.Name,
+ tc.Arguments,
+ opts.Channel,
+ opts.ChatID,
+ asyncCallback,
+ )
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
@@ -666,7 +704,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, agent *AgentInstance,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
- map[string]interface{}{
+ map[string]any{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
@@ -724,13 +762,7 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
if _, loading := al.summarizing.LoadOrStore(summarizeKey, true); !loading {
go func() {
defer al.summarizing.Delete(summarizeKey)
- if !constants.IsInternalChannel(channel) {
- al.bus.PublishOutbound(bus.OutboundMessage{
- Channel: channel,
- ChatID: chatID,
- Content: "Memory threshold reached. Optimizing conversation history...",
- })
- }
+ logger.Debug("Memory threshold reached. Optimizing conversation history...")
al.summarizeSession(agent, sessionKey)
}()
}
@@ -764,11 +796,14 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
droppedCount := mid
keptConversation := conversation[mid:]
- newHistory := make([]providers.Message, 0)
+ newHistory := make([]providers.Message, 0, 1+len(keptConversation)+1)
// Append compression note to the original system prompt instead of adding a new system message
// This avoids having two consecutive system messages which some APIs (like Zhipu) reject
- compressionNote := fmt.Sprintf("\n\n[System Note: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
+ compressionNote := fmt.Sprintf(
+ "\n\n[System Note: Emergency compression dropped %d oldest messages due to context limit]",
+ droppedCount,
+ )
enhancedSystemPrompt := history[0]
enhancedSystemPrompt.Content = enhancedSystemPrompt.Content + compressionNote
newHistory = append(newHistory, enhancedSystemPrompt)
@@ -780,7 +815,7 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
agent.Sessions.SetHistory(sessionKey, newHistory)
agent.Sessions.Save(sessionKey)
- logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
+ logger.WarnCF("agent", "Forced compression executed", map[string]any{
"session_key": sessionKey,
"dropped_msgs": droppedCount,
"new_count": len(newHistory),
@@ -788,8 +823,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
}
// GetStartupInfo returns information about loaded tools and skills for logging.
-func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
- info := make(map[string]interface{})
+func (al *AgentLoop) GetStartupInfo() map[string]any {
+ info := make(map[string]any)
agent := al.registry.GetDefaultAgent()
if agent == nil {
@@ -798,7 +833,7 @@ func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
// Tools info
toolsList := agent.Tools.List()
- info["tools"] = map[string]interface{}{
+ info["tools"] = map[string]any{
"count": len(toolsList),
"names": toolsList,
}
@@ -807,7 +842,7 @@ func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
info["skills"] = agent.ContextBuilder.GetSkillsInfo()
// Agents info
- info["agents"] = map[string]interface{}{
+ info["agents"] = map[string]any{
"count": len(al.registry.ListAgentIDs()),
"ids": al.registry.ListAgentIDs(),
}
@@ -821,49 +856,49 @@ func formatMessagesForLog(messages []providers.Message) string {
return "[]"
}
- var result string
- result += "[\n"
+ var sb strings.Builder
+ sb.WriteString("[\n")
for i, msg := range messages {
- result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
+ fmt.Fprintf(&sb, " [%d] Role: %s\n", i, msg.Role)
if len(msg.ToolCalls) > 0 {
- result += " ToolCalls:\n"
+ sb.WriteString(" ToolCalls:\n")
for _, tc := range msg.ToolCalls {
- result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
+ fmt.Fprintf(&sb, " - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
if tc.Function != nil {
- result += fmt.Sprintf(" Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
+ fmt.Fprintf(&sb, " Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
}
}
}
if msg.Content != "" {
content := utils.Truncate(msg.Content, 200)
- result += fmt.Sprintf(" Content: %s\n", content)
+ fmt.Fprintf(&sb, " Content: %s\n", content)
}
if msg.ToolCallID != "" {
- result += fmt.Sprintf(" ToolCallID: %s\n", msg.ToolCallID)
+ fmt.Fprintf(&sb, " ToolCallID: %s\n", msg.ToolCallID)
}
- result += "\n"
+ sb.WriteString("\n")
}
- result += "]"
- return result
+ sb.WriteString("]")
+ return sb.String()
}
// formatToolsForLog formats tool definitions for logging
-func formatToolsForLog(tools []providers.ToolDefinition) string {
- if len(tools) == 0 {
+func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
+ if len(toolDefs) == 0 {
return "[]"
}
- var result string
- result += "[\n"
- for i, tool := range tools {
- result += fmt.Sprintf(" [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
- result += fmt.Sprintf(" Description: %s\n", tool.Function.Description)
+ var sb strings.Builder
+ sb.WriteString("[\n")
+ for i, tool := range toolDefs {
+ fmt.Fprintf(&sb, " [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
+ fmt.Fprintf(&sb, " Description: %s\n", tool.Function.Description)
if len(tool.Function.Parameters) > 0 {
- result += fmt.Sprintf(" Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
+ fmt.Fprintf(&sb, " Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
}
}
- result += "]"
- return result
+ sb.WriteString("]")
+ return sb.String()
}
// summarizeSession summarizes the conversation history for a session.
@@ -912,11 +947,22 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
s1, _ := al.summarizeBatch(ctx, agent, part1, "")
s2, _ := al.summarizeBatch(ctx, agent, part2, "")
- mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2)
- resp, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, agent.Model, map[string]interface{}{
- "max_tokens": 1024,
- "temperature": 0.3,
- })
+ mergePrompt := fmt.Sprintf(
+ "Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s",
+ s1,
+ s2,
+ )
+ resp, err := agent.Provider.Chat(
+ ctx,
+ []providers.Message{{Role: "user", Content: mergePrompt}},
+ nil,
+ agent.Model,
+ map[string]any{
+ "max_tokens": 1024,
+ "temperature": 0.3,
+ "prompt_cache_key": agent.ID,
+ },
+ )
if err == nil {
finalSummary = resp.Content
} else {
@@ -938,20 +984,36 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
}
// summarizeBatch summarizes a batch of messages.
-func (al *AgentLoop) summarizeBatch(ctx context.Context, agent *AgentInstance, batch []providers.Message, existingSummary string) (string, error) {
- prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n"
+func (al *AgentLoop) summarizeBatch(
+ ctx context.Context,
+ agent *AgentInstance,
+ batch []providers.Message,
+ existingSummary string,
+) (string, error) {
+ var sb strings.Builder
+ sb.WriteString("Provide a concise summary of this conversation segment, preserving core context and key points.\n")
if existingSummary != "" {
- prompt += "Existing context: " + existingSummary + "\n"
+ sb.WriteString("Existing context: ")
+ sb.WriteString(existingSummary)
+ sb.WriteString("\n")
}
- prompt += "\nCONVERSATION:\n"
+ sb.WriteString("\nCONVERSATION:\n")
for _, m := range batch {
- prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content)
+ fmt.Fprintf(&sb, "%s: %s\n", m.Role, m.Content)
}
+ prompt := sb.String()
- response, err := agent.Provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, agent.Model, map[string]interface{}{
- "max_tokens": 1024,
- "temperature": 0.3,
- })
+ response, err := agent.Provider.Chat(
+ ctx,
+ []providers.Message{{Role: "user", Content: prompt}},
+ nil,
+ agent.Model,
+ map[string]any{
+ "max_tokens": 1024,
+ "temperature": 0.3,
+ "prompt_cache_key": agent.ID,
+ },
+ )
if err != nil {
return "", err
}
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go
index 360685eca..4414398b1 100644
--- a/pkg/agent/loop_test.go
+++ b/pkg/agent/loop_test.go
@@ -171,7 +171,7 @@ func TestToolRegistry_ToolRegistration(t *testing.T) {
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
// (actual tool retrieval is tested in tools package tests)
info := al.GetStartupInfo()
- toolsInfo := info["tools"].(map[string]interface{})
+ toolsInfo := info["tools"].(map[string]any)
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
@@ -246,7 +246,7 @@ func TestToolRegistry_GetDefinitions(t *testing.T) {
al.RegisterTool(testTool)
info := al.GetStartupInfo()
- toolsInfo := info["tools"].(map[string]interface{})
+ toolsInfo := info["tools"].(map[string]any)
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
@@ -293,7 +293,7 @@ func TestAgentLoop_GetStartupInfo(t *testing.T) {
t.Fatal("Expected 'tools' key in startup info")
}
- toolsMap, ok := toolsInfo.(map[string]interface{})
+ toolsMap, ok := toolsInfo.(map[string]any)
if !ok {
t.Fatal("Expected 'tools' to be a map")
}
@@ -349,7 +349,13 @@ type simpleMockProvider struct {
response string
}
-func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
+func (m *simpleMockProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
@@ -371,14 +377,14 @@ func (m *mockCustomTool) Description() string {
return "Mock custom tool for testing"
}
-func (m *mockCustomTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (m *mockCustomTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{},
+ "properties": map[string]any{},
}
}
-func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
+func (m *mockCustomTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
return tools.SilentResult("Custom tool executed")
}
@@ -396,14 +402,14 @@ func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
-func (m *mockContextualTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (m *mockContextualTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{},
+ "properties": map[string]any{},
}
}
-func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
+func (m *mockContextualTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
@@ -523,7 +529,13 @@ type failFirstMockProvider struct {
successResp string
}
-func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
+func (m *failFirstMockProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
m.currentCall++
if m.currentCall <= m.failures {
return nil, m.failError
@@ -588,7 +600,13 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
// Call ProcessDirectWithChannel
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
- response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
+ response, err := al.ProcessDirectWithChannel(
+ context.Background(),
+ "Trigger message",
+ sessionKey,
+ "test",
+ "test-chat",
+ )
if err != nil {
t.Fatalf("Expected success after retry, got error: %v", err)
}
diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go
index 3f6896f91..dd5f4441c 100644
--- a/pkg/agent/memory.go
+++ b/pkg/agent/memory.go
@@ -10,6 +10,7 @@ import (
"fmt"
"os"
"path/filepath"
+ "strings"
"time"
)
@@ -29,7 +30,7 @@ func NewMemoryStore(workspace string) *MemoryStore {
memoryFile := filepath.Join(memoryDir, "MEMORY.md")
// Ensure memory directory exists
- os.MkdirAll(memoryDir, 0755)
+ os.MkdirAll(memoryDir, 0o755)
return &MemoryStore{
workspace: workspace,
@@ -57,7 +58,7 @@ func (ms *MemoryStore) ReadLongTerm() string {
// WriteLongTerm writes content to the long-term memory file (MEMORY.md).
func (ms *MemoryStore) WriteLongTerm(content string) error {
- return os.WriteFile(ms.memoryFile, []byte(content), 0644)
+ return os.WriteFile(ms.memoryFile, []byte(content), 0o644)
}
// ReadToday reads today's daily note.
@@ -77,7 +78,7 @@ func (ms *MemoryStore) AppendToday(content string) error {
// Ensure month directory exists
monthDir := filepath.Dir(todayFile)
- os.MkdirAll(monthDir, 0755)
+ os.MkdirAll(monthDir, 0o755)
var existingContent string
if data, err := os.ReadFile(todayFile); err == nil {
@@ -94,13 +95,14 @@ func (ms *MemoryStore) AppendToday(content string) error {
newContent = existingContent + "\n" + content
}
- return os.WriteFile(todayFile, []byte(newContent), 0644)
+ return os.WriteFile(todayFile, []byte(newContent), 0o644)
}
// GetRecentDailyNotes returns daily notes from the last N days.
// Contents are joined with "---" separator.
func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
- var notes []string
+ var sb strings.Builder
+ first := true
for i := 0; i < days; i++ {
date := time.Now().AddDate(0, 0, -i)
@@ -109,53 +111,41 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
if data, err := os.ReadFile(filePath); err == nil {
- notes = append(notes, string(data))
+ if !first {
+ sb.WriteString("\n\n---\n\n")
+ }
+ sb.Write(data)
+ first = false
}
}
- if len(notes) == 0 {
- return ""
- }
-
- // Join with separator
- var result string
- for i, note := range notes {
- if i > 0 {
- result += "\n\n---\n\n"
- }
- result += note
- }
- return result
+ return sb.String()
}
// GetMemoryContext returns formatted memory context for the agent prompt.
// Includes long-term memory and recent daily notes.
func (ms *MemoryStore) GetMemoryContext() string {
- var parts []string
-
- // Long-term memory
longTerm := ms.ReadLongTerm()
- if longTerm != "" {
- parts = append(parts, "## Long-term Memory\n\n"+longTerm)
- }
-
- // Recent daily notes (last 3 days)
recentNotes := ms.GetRecentDailyNotes(3)
- if recentNotes != "" {
- parts = append(parts, "## Recent Daily Notes\n\n"+recentNotes)
- }
- if len(parts) == 0 {
+ if longTerm == "" && recentNotes == "" {
return ""
}
- // Join parts with separator
- var result string
- for i, part := range parts {
- if i > 0 {
- result += "\n\n---\n\n"
- }
- result += part
+ var sb strings.Builder
+
+ if longTerm != "" {
+ sb.WriteString("## Long-term Memory\n\n")
+ sb.WriteString(longTerm)
}
- return fmt.Sprintf("# Memory\n\n%s", result)
+
+ if recentNotes != "" {
+ if longTerm != "" {
+ sb.WriteString("\n\n---\n\n")
+ }
+ sb.WriteString("## Recent Daily Notes\n\n")
+ sb.WriteString(recentNotes)
+ }
+
+ return sb.String()
}
diff --git a/pkg/agent/mock_provider_test.go b/pkg/agent/mock_provider_test.go
index ccbecbafe..4962810dc 100644
--- a/pkg/agent/mock_provider_test.go
+++ b/pkg/agent/mock_provider_test.go
@@ -8,7 +8,13 @@ import (
type mockProvider struct{}
-func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
+func (m *mockProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go
index 4cf5a6fca..77b846832 100644
--- a/pkg/agent/registry.go
+++ b/pkg/agent/registry.go
@@ -42,7 +42,7 @@ func NewAgentRegistry(
instance := NewAgentInstance(ac, &cfg.Agents.Defaults, cfg, provider)
registry.agents[id] = instance
logger.InfoCF("agent", "Registered agent",
- map[string]interface{}{
+ map[string]any{
"agent_id": id,
"name": ac.Name,
"workspace": instance.Workspace,
diff --git a/pkg/agent/registry_test.go b/pkg/agent/registry_test.go
index f196d7fb7..518bb441f 100644
--- a/pkg/agent/registry_test.go
+++ b/pkg/agent/registry_test.go
@@ -10,7 +10,13 @@ import (
type mockRegistryProvider struct{}
-func (m *mockRegistryProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
+func (m *mockRegistryProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ options map[string]any,
+) (*providers.LLMResponse, error) {
return &providers.LLMResponse{Content: "mock", FinishReason: "stop"}, nil
}
diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go
index 4376f24d4..ba757ffd4 100644
--- a/pkg/auth/oauth.go
+++ b/pkg/auth/oauth.go
@@ -44,7 +44,9 @@ func OpenAIOAuthConfig() OAuthProviderConfig {
// Client credentials are the same ones used by OpenCode/pi-ai for Cloud Code Assist access.
func GoogleAntigravityOAuthConfig() OAuthProviderConfig {
// These are the same client credentials used by the OpenCode antigravity plugin.
- clientID := decodeBase64("MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==")
+ clientID := decodeBase64(
+ "MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
+ )
clientSecret := decodeBase64("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=")
return OAuthProviderConfig{
Issuer: "https://accounts.google.com/o/oauth2/v2",
@@ -129,8 +131,13 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
- fmt.Printf("Wait! If you are in a headless environment (like Coolify/VPS) and cannot reach localhost:%d,\n", cfg.Port)
- fmt.Println("please complete the login in your local browser and then PASTE the final redirect URL (or just the code) here.")
+ fmt.Printf(
+ "Wait! If you are in a headless environment (like Coolify/VPS) and cannot reach localhost:%d,\n",
+ cfg.Port,
+ )
+ fmt.Println(
+ "please complete the login in your local browser and then PASTE the final redirect URL (or just the code) here.",
+ )
fmt.Println("Waiting for authentication (browser or manual paste)...")
// Start manual input in a goroutine
@@ -149,7 +156,7 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case manualInput := <-manualCh:
if manualInput == "" {
- return nil, fmt.Errorf("manual input cancelled")
+ return nil, fmt.Errorf("manual input canceled")
}
// Extract code from URL if it's a full URL
code := manualInput
@@ -253,8 +260,11 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
deviceResp.Interval = 5
}
- fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
- cfg.Issuer, deviceResp.UserCode)
+ fmt.Printf(
+ "\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
+ cfg.Issuer,
+ deviceResp.UserCode,
+ )
deadline := time.After(15 * time.Minute)
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
@@ -491,15 +501,15 @@ func extractAccountID(token string) string {
return accountID
}
- if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
+ if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]any); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
}
- if orgs, ok := claims["organizations"].([]interface{}); ok {
+ if orgs, ok := claims["organizations"].([]any); ok {
for _, org := range orgs {
- if orgMap, ok := org.(map[string]interface{}); ok {
+ if orgMap, ok := org.(map[string]any); ok {
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
return accountID
}
@@ -510,7 +520,7 @@ func extractAccountID(token string) string {
return ""
}
-func parseJWTClaims(token string) (map[string]interface{}, error) {
+func parseJWTClaims(token string) (map[string]any, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("token is not a JWT")
@@ -529,7 +539,7 @@ func parseJWTClaims(token string) (map[string]interface{}, error) {
return nil, err
}
- var claims map[string]interface{}
+ var claims map[string]any
if err := json.Unmarshal(decoded, &claims); err != nil {
return nil, err
}
diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go
index 5deb17805..0cb589069 100644
--- a/pkg/auth/oauth_test.go
+++ b/pkg/auth/oauth_test.go
@@ -10,7 +10,7 @@ import (
"testing"
)
-func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
+func makeJWTForClaims(t *testing.T, claims map[string]any) string {
t.Helper()
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
@@ -89,7 +89,7 @@ func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
}
func TestParseTokenResponse(t *testing.T) {
- resp := map[string]interface{}{
+ resp := map[string]any{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
@@ -120,8 +120,8 @@ func TestParseTokenResponse(t *testing.T) {
}
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
- idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
- resp := map[string]interface{}{
+ idToken := makeJWTForClaims(t, map[string]any{"chatgpt_account_id": "acc-id-from-id-token"})
+ resp := map[string]any{
"access_token": "opaque-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
@@ -139,9 +139,9 @@ func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
}
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
- token := makeJWTForClaims(t, map[string]interface{}{
- "organizations": []interface{}{
- map[string]interface{}{"id": "org_from_orgs"},
+ token := makeJWTForClaims(t, map[string]any{
+ "organizations": []any{
+ map[string]any{"id": "org_from_orgs"},
},
})
@@ -160,7 +160,7 @@ func TestParseTokenResponseNoAccessToken(t *testing.T) {
func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) {
idToken := makeJWTWithAccountID("acc-from-id")
- resp := map[string]interface{}{
+ resp := map[string]any{
"access_token": "not-a-jwt",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
@@ -180,7 +180,9 @@ func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) {
func makeJWTWithAccountID(accountID string) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
- payload := base64.RawURLEncoding.EncodeToString([]byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`))
+ payload := base64.RawURLEncoding.EncodeToString(
+ []byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`),
+ )
return header + "." + payload + ".sig"
}
@@ -201,7 +203,7 @@ func TestExchangeCodeForTokens(t *testing.T) {
return
}
- resp := map[string]interface{}{
+ resp := map[string]any{
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"expires_in": 3600,
@@ -240,7 +242,7 @@ func TestRefreshAccessToken(t *testing.T) {
return
}
- resp := map[string]interface{}{
+ resp := map[string]any{
"access_token": "refreshed-access-token",
"refresh_token": "refreshed-refresh-token",
"expires_in": 3600,
@@ -290,7 +292,7 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- resp := map[string]interface{}{
+ resp := map[string]any{
"access_token": "new-access-token-only",
"expires_in": 3600,
}
diff --git a/pkg/auth/store.go b/pkg/auth/store.go
index 785d5858e..64708421b 100644
--- a/pkg/auth/store.go
+++ b/pkg/auth/store.go
@@ -64,7 +64,7 @@ func LoadStore() (*AuthStore, error) {
func SaveStore(store *AuthStore) error {
path := authFilePath()
dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0755); err != nil {
+ if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
@@ -72,7 +72,7 @@ func SaveStore(store *AuthStore) error {
if err != nil {
return err
}
- return os.WriteFile(path, data, 0600)
+ return os.WriteFile(path, data, 0o600)
}
func GetCredential(provider string) (*AuthCredential, error) {
diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go
index d96b460a1..f6793cfce 100644
--- a/pkg/auth/store_test.go
+++ b/pkg/auth/store_test.go
@@ -108,7 +108,7 @@ func TestStoreFilePermissions(t *testing.T) {
t.Fatalf("Stat() error: %v", err)
}
perm := info.Mode().Perm()
- if perm != 0600 {
+ if perm != 0o600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}
diff --git a/pkg/channels/base.go b/pkg/channels/base.go
index 4925099a3..cd6419ebb 100644
--- a/pkg/channels/base.go
+++ b/pkg/channels/base.go
@@ -17,14 +17,14 @@ type Channel interface {
}
type BaseChannel struct {
- config interface{}
+ config any
bus *bus.MessageBus
running bool
name string
allowList []string
}
-func NewBaseChannel(name string, config interface{}, bus *bus.MessageBus, allowList []string) *BaseChannel {
+func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []string) *BaseChannel {
return &BaseChannel{
config: config,
bus: bus,
diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go
index 79cc85219..662fba3b7 100644
--- a/pkg/channels/dingtalk.go
+++ b/pkg/channels/dingtalk.go
@@ -10,6 +10,7 @@ import (
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
+
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -108,7 +109,7 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
}
- logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
+ logger.DebugCF("dingtalk", "Sending message", map[string]any{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
@@ -120,12 +121,15 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
// This is called by the Stream SDK when a new message arrives
// IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error)
-func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
+func (c *DingTalkChannel) onChatBotMessageReceived(
+ ctx context.Context,
+ data *chatbot.BotCallbackDataModel,
+) ([]byte, error) {
// Extract message content from Text field
content := data.Text.Content
if content == "" {
// Try to extract from Content interface{} if Text is empty
- if contentMap, ok := data.Content.(map[string]interface{}); ok {
+ if contentMap, ok := data.Content.(map[string]any); ok {
if textContent, ok := contentMap["content"].(string); ok {
content = textContent
}
@@ -163,7 +167,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
metadata["peer_id"] = data.ConversationId
}
- logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
+ logger.DebugCF("dingtalk", "Received message", map[string]any{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
@@ -192,7 +196,6 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c
titleBytes,
contentBytes,
)
-
if err != nil {
return fmt.Errorf("failed to send reply: %w", err)
}
diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go
index 9ddec662c..f6faa3373 100644
--- a/pkg/channels/discord.go
+++ b/pkg/channels/discord.go
@@ -4,10 +4,12 @@ import (
"context"
"fmt"
"os"
+ "strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
+
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -28,6 +30,7 @@ type DiscordChannel struct {
ctx context.Context
typingMu sync.Mutex
typingStop map[string]chan struct{} // chatID → stop signal
+ botUserID string // stored for mention checking
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -63,6 +66,14 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
+
+ // Get bot user ID before opening session to avoid race condition
+ botUser, err := c.session.User("@me")
+ if err != nil {
+ return fmt.Errorf("failed to get bot user: %w", err)
+ }
+ c.botUserID = botUser.ID
+
c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil {
@@ -71,10 +82,6 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
c.setRunning(true)
- botUser, err := c.session.User("@me")
- if err != nil {
- return fmt.Errorf("failed to get bot user: %w", err)
- }
logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username,
"user_id": botUser.ID,
@@ -131,7 +138,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
}
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
- // 使用传入的 ctx 进行超时控制
+ // Use the passed ctx for timeout control
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
@@ -152,7 +159,7 @@ func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content strin
}
}
-// appendContent 安全地追加内容到现有文本
+// appendContent safely appends content to existing text
func appendContent(content, suffix string) string {
if content == "" {
return suffix
@@ -169,7 +176,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
- // 检查白名单,避免为被拒绝的用户下载附件和转录
+ // Check allowlist first to avoid downloading attachments and transcribing for rejected users
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
@@ -177,6 +184,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
+ // If configured to only respond to mentions, check if bot is mentioned
+ // Skip this check for DMs (GuildID is empty) - DMs should always be responded to
+ if c.config.MentionOnly && m.GuildID != "" {
+ isMentioned := false
+ for _, mention := range m.Mentions {
+ if mention.ID == c.botUserID {
+ isMentioned = true
+ break
+ }
+ }
+ if !isMentioned {
+ logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{
+ "user_id": m.Author.ID,
+ })
+ return
+ }
+ }
+
senderID := m.Author.ID
senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -184,10 +209,11 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
}
content := m.Content
+ content = c.stripBotMention(content)
mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
- // 确保临时文件在函数返回时被清理
+ // Ensure temp files are cleaned up when function returns
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
@@ -207,11 +233,11 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
if localPath != "" {
localFiles = append(localFiles, localPath)
- transcribedText := ""
+ var transcribedText string
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
result, err := c.transcriber.Transcribe(ctx, localPath)
- cancel() // 立即释放context资源,避免在for循环中泄漏
+ cancel() // Release context resources immediately to avoid leaks in for loop
if err != nil {
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
@@ -296,7 +322,7 @@ func (c *DiscordChannel) startTyping(chatID string) {
go func() {
if err := c.session.ChannelTyping(chatID); err != nil {
- logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err})
+ logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
}
ticker := time.NewTicker(8 * time.Second)
defer ticker.Stop()
@@ -311,7 +337,7 @@ func (c *DiscordChannel) startTyping(chatID string) {
return
case <-ticker.C:
if err := c.session.ChannelTyping(chatID); err != nil {
- logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err})
+ logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err})
}
}
}
@@ -333,3 +359,15 @@ func (c *DiscordChannel) downloadAttachment(url, filename string) string {
LoggerPrefix: "discord",
})
}
+
+// stripBotMention removes the bot mention from the message content.
+// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname).
+func (c *DiscordChannel) stripBotMention(text string) string {
+ if c.botUserID == "" {
+ return text
+ }
+ // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID>
+ text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "")
+ text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "")
+ return strings.TrimSpace(text)
+}
diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu_32.go
index 4e60fbc11..5109b8195 100644
--- a/pkg/channels/feishu_32.go
+++ b/pkg/channels/feishu_32.go
@@ -17,7 +17,9 @@ type FeishuChannel struct {
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
- return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config")
+ return nil, errors.New(
+ "feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config",
+ )
}
// Start is a stub method to satisfy the Channel interface
diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu_64.go
index 9e15fa3a7..42e74980f 100644
--- a/pkg/channels/feishu_64.go
+++ b/pkg/channels/feishu_64.go
@@ -65,7 +65,7 @@ func (c *FeishuChannel) Start(ctx context.Context) error {
go func() {
if err := wsClient.Start(runCtx); err != nil {
- logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]interface{}{
+ logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{
"error": err.Error(),
})
}
@@ -121,7 +121,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg)
}
- logger.DebugCF("feishu", "Feishu message sent", map[string]interface{}{
+ logger.DebugCF("feishu", "Feishu message sent", map[string]any{
"chat_id": msg.ChatID,
})
@@ -174,7 +174,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
metadata["peer_id"] = chatID
}
- logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
+ logger.InfoCF("feishu", "Feishu message received", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 80),
diff --git a/pkg/channels/line.go b/pkg/channels/line.go
index 9f7d2bde0..44134996f 100644
--- a/pkg/channels/line.go
+++ b/pkg/channels/line.go
@@ -75,11 +75,11 @@ func (c *LINEChannel) Start(ctx context.Context) error {
// Fetch bot profile to get bot's userId for mention detection
if err := c.fetchBotInfo(); err != nil {
- logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]interface{}{
+ logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]any{
"error": err.Error(),
})
} else {
- logger.InfoCF("line", "Bot info fetched", map[string]interface{}{
+ logger.InfoCF("line", "Bot info fetched", map[string]any{
"bot_user_id": c.botUserID,
"basic_id": c.botBasicID,
"display_name": c.botDisplayName,
@@ -100,12 +100,12 @@ func (c *LINEChannel) Start(ctx context.Context) error {
}
go func() {
- logger.InfoCF("line", "LINE webhook server listening", map[string]interface{}{
+ logger.InfoCF("line", "LINE webhook server listening", map[string]any{
"addr": addr,
"path": path,
})
if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- logger.ErrorCF("line", "Webhook server error", map[string]interface{}{
+ logger.ErrorCF("line", "Webhook server error", map[string]any{
"error": err.Error(),
})
}
@@ -162,7 +162,7 @@ func (c *LINEChannel) Stop(ctx context.Context) error {
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
if err := c.httpServer.Shutdown(shutdownCtx); err != nil {
- logger.ErrorCF("line", "Webhook server shutdown error", map[string]interface{}{
+ logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{
"error": err.Error(),
})
}
@@ -182,7 +182,7 @@ func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
- logger.ErrorCF("line", "Failed to read request body", map[string]interface{}{
+ logger.ErrorCF("line", "Failed to read request body", map[string]any{
"error": err.Error(),
})
http.Error(w, "Bad request", http.StatusBadRequest)
@@ -200,7 +200,7 @@ func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) {
Events []lineEvent `json:"events"`
}
if err := json.Unmarshal(body, &payload); err != nil {
- logger.ErrorCF("line", "Failed to parse webhook payload", map[string]interface{}{
+ logger.ErrorCF("line", "Failed to parse webhook payload", map[string]any{
"error": err.Error(),
})
http.Error(w, "Bad request", http.StatusBadRequest)
@@ -266,7 +266,7 @@ type lineMentionee struct {
func (c *LINEChannel) processEvent(event lineEvent) {
if event.Type != "message" {
- logger.DebugCF("line", "Ignoring non-message event", map[string]interface{}{
+ logger.DebugCF("line", "Ignoring non-message event", map[string]any{
"type": event.Type,
})
return
@@ -278,7 +278,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
var msg lineMessage
if err := json.Unmarshal(event.Message, &msg); err != nil {
- logger.ErrorCF("line", "Failed to parse message", map[string]interface{}{
+ logger.ErrorCF("line", "Failed to parse message", map[string]any{
"error": err.Error(),
})
return
@@ -286,7 +286,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
// In group chats, only respond when the bot is mentioned
if isGroup && !c.isBotMentioned(msg) {
- logger.DebugCF("line", "Ignoring group message without mention", map[string]interface{}{
+ logger.DebugCF("line", "Ignoring group message without mention", map[string]any{
"chat_id": chatID,
})
return
@@ -312,7 +312,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
- logger.DebugCF("line", "Failed to cleanup temp file", map[string]interface{}{
+ logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
@@ -374,7 +374,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
metadata["peer_id"] = senderID
}
- logger.DebugCF("line", "Received message", map[string]interface{}{
+ logger.DebugCF("line", "Received message", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
"message_type": msg.Type,
@@ -505,7 +505,7 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
tokenEntry := entry.(replyTokenEntry)
if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil {
- logger.DebugCF("line", "Message sent via Reply API", map[string]interface{}{
+ logger.DebugCF("line", "Message sent via Reply API", map[string]any{
"chat_id": msg.ChatID,
"quoted": quoteToken != "",
})
@@ -533,7 +533,7 @@ func buildTextMessage(content, quoteToken string) map[string]string {
// sendReply sends a message using the LINE Reply API.
func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error {
- payload := map[string]interface{}{
+ payload := map[string]any{
"replyToken": replyToken,
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
}
@@ -543,7 +543,7 @@ func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteT
// sendPush sends a message using the LINE Push API.
func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error {
- payload := map[string]interface{}{
+ payload := map[string]any{
"to": to,
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
}
@@ -553,19 +553,19 @@ func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken stri
// sendLoading sends a loading animation indicator to the chat.
func (c *LINEChannel) sendLoading(chatID string) {
- payload := map[string]interface{}{
+ payload := map[string]any{
"chatId": chatID,
"loadingSeconds": 60,
}
if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil {
- logger.DebugCF("line", "Failed to send loading indicator", map[string]interface{}{
+ logger.DebugCF("line", "Failed to send loading indicator", map[string]any{
"error": err.Error(),
})
}
}
// callAPI makes an authenticated POST request to the LINE API.
-func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload interface{}) error {
+func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) error {
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam.go
index 95da0547c..34ce62b20 100644
--- a/pkg/channels/maixcam.go
+++ b/pkg/channels/maixcam.go
@@ -21,10 +21,10 @@ type MaixCamChannel struct {
}
type MaixCamMessage struct {
- Type string `json:"type"`
- Tips string `json:"tips"`
- Timestamp float64 `json:"timestamp"`
- Data map[string]interface{} `json:"data"`
+ Type string `json:"type"`
+ Tips string `json:"tips"`
+ Timestamp float64 `json:"timestamp"`
+ Data map[string]any `json:"data"`
}
func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) {
@@ -49,7 +49,7 @@ func (c *MaixCamChannel) Start(ctx context.Context) error {
c.listener = listener
c.setRunning(true)
- logger.InfoCF("maixcam", "MaixCam server listening", map[string]interface{}{
+ logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{
"host": c.config.Host,
"port": c.config.Port,
})
@@ -71,14 +71,14 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) {
conn, err := c.listener.Accept()
if err != nil {
if c.running {
- logger.ErrorCF("maixcam", "Failed to accept connection", map[string]interface{}{
+ logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{
"error": err.Error(),
})
}
return
}
- logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]interface{}{
+ logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]any{
"remote_addr": conn.RemoteAddr().String(),
})
@@ -112,7 +112,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) {
var msg MaixCamMessage
if err := decoder.Decode(&msg); err != nil {
if err.Error() != "EOF" {
- logger.ErrorCF("maixcam", "Failed to decode message", map[string]interface{}{
+ logger.ErrorCF("maixcam", "Failed to decode message", map[string]any{
"error": err.Error(),
})
}
@@ -133,14 +133,14 @@ func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) {
case "status":
c.handleStatusUpdate(msg)
default:
- logger.WarnCF("maixcam", "Unknown message type", map[string]interface{}{
+ logger.WarnCF("maixcam", "Unknown message type", map[string]any{
"type": msg.Type,
})
}
}
func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
- logger.InfoCF("maixcam", "", map[string]interface{}{
+ logger.InfoCF("maixcam", "", map[string]any{
"timestamp": msg.Timestamp,
"data": msg.Data,
})
@@ -178,7 +178,7 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
}
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
- logger.InfoCF("maixcam", "Status update from MaixCam", map[string]interface{}{
+ logger.InfoCF("maixcam", "Status update from MaixCam", map[string]any{
"status": msg.Data,
})
}
@@ -216,7 +216,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return fmt.Errorf("no connected MaixCam devices")
}
- response := map[string]interface{}{
+ response := map[string]any{
"type": "command",
"timestamp": float64(0),
"message": msg.Content,
@@ -231,7 +231,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
var sendErr error
for conn := range c.clients {
if _, err := conn.Write(data); err != nil {
- logger.ErrorCF("maixcam", "Failed to send to client", map[string]interface{}{
+ logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{
"client": conn.RemoteAddr().String(),
"error": err.Error(),
})
diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go
index 7f6abc4cb..75edaf49e 100644
--- a/pkg/channels/manager.go
+++ b/pkg/channels/manager.go
@@ -50,7 +50,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize Telegram channel")
telegram, err := NewTelegramChannel(m.config, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -63,7 +63,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize WhatsApp channel")
whatsapp, err := NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -76,7 +76,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize Feishu channel")
feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -89,7 +89,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize Discord channel")
discord, err := NewDiscordChannel(m.config.Channels.Discord, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -102,7 +102,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize MaixCam channel")
maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -115,7 +115,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize QQ channel")
qq, err := NewQQChannel(m.config.Channels.QQ, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -128,7 +128,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize DingTalk channel")
dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -141,7 +141,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize Slack channel")
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -154,7 +154,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize LINE channel")
line, err := NewLINEChannel(m.config.Channels.LINE, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -167,7 +167,7 @@ func (m *Manager) initChannels() error {
logger.DebugC("channels", "Attempting to initialize OneBot channel")
onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus)
if err != nil {
- logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{
"error": err.Error(),
})
} else {
@@ -176,7 +176,33 @@ func (m *Manager) initChannels() error {
}
}
- logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
+ if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" {
+ logger.DebugC("channels", "Attempting to initialize WeCom channel")
+ wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus)
+ if err != nil {
+ logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{
+ "error": err.Error(),
+ })
+ } else {
+ m.channels["wecom"] = wecom
+ logger.InfoC("channels", "WeCom channel enabled successfully")
+ }
+ }
+
+ if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
+ logger.DebugC("channels", "Attempting to initialize WeCom App channel")
+ wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus)
+ if err != nil {
+ logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{
+ "error": err.Error(),
+ })
+ } else {
+ m.channels["wecom_app"] = wecomApp
+ logger.InfoC("channels", "WeCom App channel enabled successfully")
+ }
+ }
+
+ logger.InfoCF("channels", "Channel initialization completed", map[string]any{
"enabled_channels": len(m.channels),
})
@@ -200,11 +226,11 @@ func (m *Manager) StartAll(ctx context.Context) error {
go m.dispatchOutbound(dispatchCtx)
for name, channel := range m.channels {
- logger.InfoCF("channels", "Starting channel", map[string]interface{}{
+ logger.InfoCF("channels", "Starting channel", map[string]any{
"channel": name,
})
if err := channel.Start(ctx); err != nil {
- logger.ErrorCF("channels", "Failed to start channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Failed to start channel", map[string]any{
"channel": name,
"error": err.Error(),
})
@@ -227,11 +253,11 @@ func (m *Manager) StopAll(ctx context.Context) error {
}
for name, channel := range m.channels {
- logger.InfoCF("channels", "Stopping channel", map[string]interface{}{
+ logger.InfoCF("channels", "Stopping channel", map[string]any{
"channel": name,
})
if err := channel.Stop(ctx); err != nil {
- logger.ErrorCF("channels", "Error stopping channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Error stopping channel", map[string]any{
"channel": name,
"error": err.Error(),
})
@@ -266,14 +292,14 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
m.mu.RUnlock()
if !exists {
- logger.WarnCF("channels", "Unknown channel for outbound message", map[string]interface{}{
+ logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{
"channel": msg.Channel,
})
continue
}
if err := channel.Send(ctx, msg); err != nil {
- logger.ErrorCF("channels", "Error sending message to channel", map[string]interface{}{
+ logger.ErrorCF("channels", "Error sending message to channel", map[string]any{
"channel": msg.Channel,
"error": err.Error(),
})
@@ -289,13 +315,13 @@ func (m *Manager) GetChannel(name string) (Channel, bool) {
return channel, ok
}
-func (m *Manager) GetStatus() map[string]interface{} {
+func (m *Manager) GetStatus() map[string]any {
m.mu.RLock()
defer m.mu.RUnlock()
- status := make(map[string]interface{})
+ status := make(map[string]any)
for name, channel := range m.channels {
- status[name] = map[string]interface{}{
+ status[name] = map[string]any{
"enabled": true,
"running": channel.IsRunning(),
}
diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot.go
index 06186f783..4576a11ce 100644
--- a/pkg/channels/onebot.go
+++ b/pkg/channels/onebot.go
@@ -87,14 +87,14 @@ type oneBotSender struct {
}
type oneBotAPIRequest struct {
- Action string `json:"action"`
- Params interface{} `json:"params"`
- Echo string `json:"echo,omitempty"`
+ Action string `json:"action"`
+ Params any `json:"params"`
+ Echo string `json:"echo,omitempty"`
}
type oneBotMessageSegment struct {
- Type string `json:"type"`
- Data map[string]interface{} `json:"data"`
+ Type string `json:"type"`
+ Data map[string]any `json:"data"`
}
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
@@ -117,13 +117,13 @@ func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) {
go func() {
- _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{
+ _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{
"message_id": messageID,
"emoji_id": emojiID,
"set": set,
}, 5*time.Second)
if err != nil {
- logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{
+ logger.DebugCF("onebot", "Failed to set emoji like", map[string]any{
"message_id": messageID,
"error": err.Error(),
})
@@ -136,14 +136,14 @@ func (c *OneBotChannel) Start(ctx context.Context) error {
return fmt.Errorf("OneBot ws_url not configured")
}
- logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{
+ logger.InfoCF("onebot", "Starting OneBot channel", map[string]any{
"ws_url": c.config.WSUrl,
})
c.ctx, c.cancel = context.WithCancel(ctx)
if err := c.connect(); err != nil {
- logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{
+ logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]any{
"error": err.Error(),
})
} else {
@@ -174,7 +174,10 @@ func (c *OneBotChannel) connect() error {
header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
}
- conn, _, err := dialer.Dial(c.config.WSUrl, header)
+ conn, resp, err := dialer.Dial(c.config.WSUrl, header)
+ if resp != nil {
+ resp.Body.Close()
+ }
if err != nil {
return err
}
@@ -208,7 +211,7 @@ func (c *OneBotChannel) pinger(conn *websocket.Conn) {
err := conn.WriteMessage(websocket.PingMessage, nil)
c.writeMu.Unlock()
if err != nil {
- logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{
+ logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]any{
"error": err.Error(),
})
return
@@ -220,7 +223,7 @@ func (c *OneBotChannel) pinger(conn *websocket.Conn) {
func (c *OneBotChannel) fetchSelfID() {
resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second)
if err != nil {
- logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{
+ logger.WarnCF("onebot", "Failed to get_login_info", map[string]any{
"error": err.Error(),
})
return
@@ -250,7 +253,7 @@ func (c *OneBotChannel) fetchSelfID() {
}
if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 {
atomic.StoreInt64(&c.selfID, uid)
- logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{
+ logger.InfoCF("onebot", "Bot self ID retrieved", map[string]any{
"self_id": uid,
"nickname": info.Nickname,
})
@@ -258,12 +261,12 @@ func (c *OneBotChannel) fetchSelfID() {
}
}
- logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]interface{}{
+ logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]any{
"response": string(resp),
})
}
-func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) {
+func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.Duration) (json.RawMessage, error) {
c.mu.Lock()
conn := c.conn
c.mu.Unlock()
@@ -310,7 +313,7 @@ func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeou
case <-time.After(timeout):
return nil, fmt.Errorf("API request %s timed out after %v", action, timeout)
case <-c.ctx.Done():
- return nil, fmt.Errorf("context cancelled")
+ return nil, fmt.Errorf("context canceled")
}
}
@@ -332,7 +335,7 @@ func (c *OneBotChannel) reconnectLoop() {
if conn == nil {
logger.InfoC("onebot", "Attempting to reconnect...")
if err := c.connect(); err != nil {
- logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{
+ logger.ErrorCF("onebot", "Reconnect failed", map[string]any{
"error": err.Error(),
})
} else {
@@ -405,7 +408,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
c.writeMu.Unlock()
if err != nil {
- logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{
+ logger.ErrorCF("onebot", "Failed to send message", map[string]any{
"error": err.Error(),
})
return err
@@ -427,20 +430,20 @@ func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMes
if msgID, ok := lastMsgID.(string); ok && msgID != "" {
segments = append(segments, oneBotMessageSegment{
Type: "reply",
- Data: map[string]interface{}{"id": msgID},
+ Data: map[string]any{"id": msgID},
})
}
}
segments = append(segments, oneBotMessageSegment{
Type: "text",
- Data: map[string]interface{}{"text": content},
+ Data: map[string]any{"text": content},
})
return segments
}
-func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
+func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, any, error) {
chatID := msg.ChatID
segments := c.buildMessageSegments(chatID, msg.Content)
@@ -458,7 +461,7 @@ func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, inter
if err != nil {
return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID)
}
- return action, map[string]interface{}{idKey: id, "message": segments}, nil
+ return action, map[string]any{idKey: id, "message": segments}, nil
}
func (c *OneBotChannel) listen() {
@@ -478,7 +481,7 @@ func (c *OneBotChannel) listen() {
default:
_, message, err := conn.ReadMessage()
if err != nil {
- logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
+ logger.ErrorCF("onebot", "WebSocket read error", map[string]any{
"error": err.Error(),
})
c.mu.Lock()
@@ -494,14 +497,14 @@ func (c *OneBotChannel) listen() {
var raw oneBotRawEvent
if err := json.Unmarshal(message, &raw); err != nil {
- logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{
+ logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]any{
"error": err.Error(),
"payload": string(message),
})
continue
}
- logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{
+ logger.DebugCF("onebot", "WebSocket event", map[string]any{
"length": len(message),
"post_type": raw.PostType,
"sub_type": raw.SubType,
@@ -518,7 +521,7 @@ func (c *OneBotChannel) listen() {
default:
}
} else {
- logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{
+ logger.DebugCF("onebot", "Received API response (no waiter)", map[string]any{
"echo": raw.Echo,
"status": string(raw.Status),
})
@@ -527,7 +530,7 @@ func (c *OneBotChannel) listen() {
}
if isAPIResponse(raw.Status) {
- logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{
+ logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]any{
"status": string(raw.Status),
})
continue
@@ -594,7 +597,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
return parseMessageResult{Text: s, IsBotMentioned: mentioned}
}
- var segments []map[string]interface{}
+ var segments []map[string]any
if err := json.Unmarshal(raw, &segments); err != nil {
return parseMessageResult{}
}
@@ -608,7 +611,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
for _, seg := range segments {
segType, _ := seg["type"].(string)
- data, _ := seg["data"].(map[string]interface{})
+ data, _ := seg["data"].(map[string]any)
switch segType {
case "text":
@@ -662,7 +665,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
result, err := c.transcriber.Transcribe(tctx, localPath)
tcancel()
if err != nil {
- logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{
+ logger.WarnCF("onebot", "Voice transcription failed", map[string]any{
"error": err.Error(),
})
textParts = append(textParts, "[voice (transcription failed)]")
@@ -695,7 +698,6 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64)
textParts = append(textParts, "[forward message]")
default:
-
}
}
@@ -713,7 +715,7 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
case "message":
if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 {
if !c.IsAllowed(strconv.FormatInt(userID, 10)) {
- logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{
+ logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{
"user_id": userID,
})
return
@@ -722,7 +724,7 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
c.handleMessage(raw)
case "message_sent":
- logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{
+ logger.DebugCF("onebot", "Bot sent message event", map[string]any{
"message_type": raw.MessageType,
"message_id": parseJSONString(raw.MessageID),
})
@@ -734,18 +736,18 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
c.handleNoticeEvent(raw)
case "request":
- logger.DebugCF("onebot", "Request event received", map[string]interface{}{
+ logger.DebugCF("onebot", "Request event received", map[string]any{
"sub_type": raw.SubType,
})
case "":
- logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
+ logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]any{
"echo": raw.Echo,
"status": raw.Status,
})
default:
- logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
+ logger.DebugCF("onebot", "Unknown post_type", map[string]any{
"post_type": raw.PostType,
})
}
@@ -753,14 +755,14 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
if raw.MetaEventType == "lifecycle" {
- logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType})
+ logger.InfoCF("onebot", "Lifecycle event", map[string]any{"sub_type": raw.SubType})
} else if raw.MetaEventType != "heartbeat" {
logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil)
}
}
func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) {
- fields := map[string]interface{}{
+ fields := map[string]any{
"notice_type": raw.NoticeType,
"sub_type": raw.SubType,
"group_id": parseJSONString(raw.GroupID),
@@ -780,7 +782,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
// Parse fields from raw event
userID, err := parseJSONInt64(raw.UserID)
if err != nil {
- logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{
+ logger.WarnCF("onebot", "Failed to parse user_id", map[string]any{
"error": err.Error(),
"raw": string(raw.UserID),
})
@@ -817,7 +819,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
var sender oneBotSender
if len(raw.Sender) > 0 {
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
- logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{
+ logger.WarnCF("onebot", "Failed to parse sender", map[string]any{
"error": err.Error(),
"sender": string(raw.Sender),
})
@@ -829,7 +831,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
defer func() {
for _, f := range parsed.LocalFiles {
if err := os.Remove(f); err != nil {
- logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{
+ logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{
"path": f,
"error": err.Error(),
})
@@ -839,14 +841,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
}
if c.isDuplicate(messageID) {
- logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
+ logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{
"message_id": messageID,
})
return
}
if content == "" {
- logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
+ logger.DebugCF("onebot", "Received empty message, ignoring", map[string]any{
"message_id": messageID,
})
return
@@ -889,7 +891,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned)
if !triggered {
- logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
+ logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{
"sender": senderID,
"group": groupIDStr,
"is_mentioned": isBotMentioned,
@@ -900,7 +902,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
content = strippedContent
default:
- logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
+ logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]any{
"type": raw.MessageType,
"message_id": messageID,
"user_id": userID,
@@ -908,7 +910,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
return
}
- logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{
+ logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]any{
"sender": senderID,
"chat_id": chatID,
"message_id": messageID,
@@ -961,7 +963,10 @@ func truncate(s string, n int) string {
return string(runes[:n]) + "..."
}
-func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) {
+func (c *OneBotChannel) checkGroupTrigger(
+ content string,
+ isBotMentioned bool,
+) (triggered bool, strippedContent string) {
if isBotMentioned {
return true, strings.TrimSpace(content)
}
diff --git a/pkg/channels/qq.go b/pkg/channels/qq.go
index 79907df83..b10776db6 100644
--- a/pkg/channels/qq.go
+++ b/pkg/channels/qq.go
@@ -47,47 +47,47 @@ func (c *QQChannel) Start(ctx context.Context) error {
logger.InfoC("qq", "Starting QQ bot (WebSocket mode)")
- // 创建 token source
+ // create token source
credentials := &token.QQBotCredentials{
AppID: c.config.AppID,
AppSecret: c.config.AppSecret,
}
c.tokenSource = token.NewQQBotTokenSource(credentials)
- // 创建子 context
+ // create child context
c.ctx, c.cancel = context.WithCancel(ctx)
- // 启动自动刷新 token 协程
+ // start auto-refresh token goroutine
if err := token.StartRefreshAccessToken(c.ctx, c.tokenSource); err != nil {
return fmt.Errorf("failed to start token refresh: %w", err)
}
- // 初始化 OpenAPI 客户端
+ // initialize OpenAPI client
c.api = botgo.NewOpenAPI(c.config.AppID, c.tokenSource).WithTimeout(5 * time.Second)
- // 注册事件处理器
+ // register event handlers
intent := event.RegisterHandlers(
c.handleC2CMessage(),
c.handleGroupATMessage(),
)
- // 获取 WebSocket 接入点
+ // get WebSocket endpoint
wsInfo, err := c.api.WS(c.ctx, nil, "")
if err != nil {
return fmt.Errorf("failed to get websocket info: %w", err)
}
- logger.InfoCF("qq", "Got WebSocket info", map[string]interface{}{
+ logger.InfoCF("qq", "Got WebSocket info", map[string]any{
"shards": wsInfo.Shards,
})
- // 创建并保存 sessionManager
+ // create and save sessionManager
c.sessionManager = botgo.NewSessionManager()
- // 在 goroutine 中启动 WebSocket 连接,避免阻塞
+ // start WebSocket connection in goroutine to avoid blocking
go func() {
if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil {
- logger.ErrorCF("qq", "WebSocket session error", map[string]interface{}{
+ logger.ErrorCF("qq", "WebSocket session error", map[string]any{
"error": err.Error(),
})
c.setRunning(false)
@@ -116,15 +116,15 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
return fmt.Errorf("QQ bot not running")
}
- // 构造消息
+ // construct message
msgToCreate := &dto.MessageToCreate{
Content: msg.Content,
}
- // C2C 消息发送
+ // send C2C message
_, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate)
if err != nil {
- logger.ErrorCF("qq", "Failed to send C2C message", map[string]interface{}{
+ logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{
"error": err.Error(),
})
return err
@@ -133,15 +133,15 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
return nil
}
-// handleC2CMessage 处理 QQ 私聊消息
+// handleC2CMessage handles QQ private messages
func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error {
- // 去重检查
+ // deduplication check
if c.isDuplicate(data.ID) {
return nil
}
- // 提取用户信息
+ // extract user info
var senderID string
if data.Author != nil && data.Author.ID != "" {
senderID = data.Author.ID
@@ -150,19 +150,19 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
return nil
}
- // 提取消息内容
+ // extract message content
content := data.Content
if content == "" {
logger.DebugC("qq", "Received empty message, ignoring")
return nil
}
- logger.InfoCF("qq", "Received C2C message", map[string]interface{}{
+ logger.InfoCF("qq", "Received C2C message", map[string]any{
"sender": senderID,
"length": len(content),
})
- // 转发到消息总线
+ // forward to message bus
metadata := map[string]string{
"message_id": data.ID,
"peer_kind": "direct",
@@ -175,15 +175,15 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
}
}
-// handleGroupATMessage 处理群@消息
+// handleGroupATMessage handles group @messages
func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error {
- // 去重检查
+ // deduplication check
if c.isDuplicate(data.ID) {
return nil
}
- // 提取用户信息
+ // extract user info
var senderID string
if data.Author != nil && data.Author.ID != "" {
senderID = data.Author.ID
@@ -192,20 +192,20 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
return nil
}
- // 提取消息内容(去掉 @ 机器人部分)
+ // extract message content (remove @bot part)
content := data.Content
if content == "" {
logger.DebugC("qq", "Received empty group message, ignoring")
return nil
}
- logger.InfoCF("qq", "Received group AT message", map[string]interface{}{
+ logger.InfoCF("qq", "Received group AT message", map[string]any{
"sender": senderID,
"group": data.GroupID,
"length": len(content),
})
- // 转发到消息总线(使用 GroupID 作为 ChatID)
+ // forward to message bus (use GroupID as ChatID)
metadata := map[string]string{
"message_id": data.ID,
"group_id": data.GroupID,
@@ -219,7 +219,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
}
}
-// isDuplicate 检查消息是否重复
+// isDuplicate checks if message is duplicate
func (c *QQChannel) isDuplicate(messageID string) bool {
c.mu.Lock()
defer c.mu.Unlock()
@@ -230,9 +230,9 @@ func (c *QQChannel) isDuplicate(messageID string) bool {
c.processedIDs[messageID] = true
- // 简单清理:限制 map 大小
+ // simple cleanup: limit map size
if len(c.processedIDs) > 10000 {
- // 清空一半
+ // clear half
count := 0
for id := range c.processedIDs {
if count >= 5000 {
diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go
index 0060972ed..cfb731b16 100644
--- a/pkg/channels/slack.go
+++ b/pkg/channels/slack.go
@@ -75,7 +75,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
c.botUserID = authResp.UserID
c.teamID = authResp.TeamID
- logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
+ logger.InfoCF("slack", "Slack bot connected", map[string]any{
"bot_user_id": c.botUserID,
"team": authResp.Team,
})
@@ -85,7 +85,7 @@ func (c *SlackChannel) Start(ctx context.Context) error {
go func() {
if err := c.socketClient.RunContext(c.ctx); err != nil {
if c.ctx.Err() == nil {
- logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
+ logger.ErrorCF("slack", "Socket Mode connection error", map[string]any{
"error": err.Error(),
})
}
@@ -140,7 +140,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error
})
}
- logger.DebugCF("slack", "Message sent", map[string]interface{}{
+ logger.DebugCF("slack", "Message sent", map[string]any{
"channel_id": channelID,
"thread_ts": threadTS,
})
@@ -200,9 +200,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
return
}
- // 检查白名单,避免为被拒绝的用户下载附件
+ // check allowlist to avoid downloading attachments for rejected users
if !c.IsAllowed(ev.User) {
- logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
+ logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{
"user_id": ev.User,
})
return
@@ -232,13 +232,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
content = c.stripBotMention(content)
var mediaPaths []string
- localFiles := []string{} // 跟踪需要清理的本地文件
+ localFiles := []string{} // track local files that need cleanup
- // 确保临时文件在函数返回时被清理
+ // ensure temp files are cleaned up when function returns
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
- logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
+ logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
@@ -261,7 +261,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
result, err := c.transcriber.Transcribe(ctx, localPath)
if err != nil {
- logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
+ logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()})
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
} else {
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
@@ -293,7 +293,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
"team_id": c.teamID,
}
- logger.DebugCF("slack", "Received message", map[string]interface{}{
+ logger.DebugCF("slack", "Received message", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
@@ -309,7 +309,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
}
if !c.IsAllowed(ev.User) {
- logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
+ logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{
"user_id": ev.User,
})
return
@@ -375,7 +375,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
}
if !c.IsAllowed(cmd.UserID) {
- logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
+ logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{
"user_id": cmd.UserID,
})
return
@@ -400,7 +400,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
"team_id": c.teamID,
}
- logger.DebugCF("slack", "Slash command received", map[string]interface{}{
+ logger.DebugCF("slack", "Slash command received", map[string]any{
"sender_id": senderID,
"command": cmd.Command,
"text": utils.Truncate(content, 50),
@@ -415,7 +415,7 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string {
downloadURL = file.URLPrivate
}
if downloadURL == "" {
- logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
+ logger.ErrorCF("slack", "No download URL for file", map[string]any{"file_id": file.ID})
return ""
}
@@ -439,5 +439,5 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
if len(parts) > 1 {
threadTS = parts[1]
}
- return
+ return channelID, threadTS
}
diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go
index 20bbf6830..524494849 100644
--- a/pkg/channels/telegram.go
+++ b/pkg/channels/telegram.go
@@ -11,10 +11,9 @@ import (
"sync"
"time"
- th "github.com/mymmrac/telego/telegohandler"
-
"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegohandler"
+ th "github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -127,7 +126,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}, th.AnyMessage())
c.setRunning(true)
- logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
+ logger.InfoCF("telegram", "Telegram bot connected", map[string]any{
"username": c.bot.Username(),
})
@@ -140,6 +139,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return nil
}
+
func (c *TelegramChannel) Stop(ctx context.Context) error {
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
@@ -182,7 +182,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
tgMsg.ParseMode = telego.ModeHTML
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
- logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
+ logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{
"error": err.Error(),
})
tgMsg.ParseMode = ""
@@ -208,9 +208,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
}
- // 检查白名单,避免为被拒绝的用户下载附件
+ // check allowlist to avoid downloading attachments for rejected users
if !c.IsAllowed(senderID) {
- logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
+ logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{
"user_id": senderID,
})
return nil
@@ -221,13 +221,13 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
content := ""
mediaPaths := []string{}
- localFiles := []string{} // 跟踪需要清理的本地文件
+ localFiles := []string{} // track local files that need cleanup
- // 确保临时文件在函数返回时被清理
+ // ensure temp files are cleaned up when function returns
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
- logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
+ logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
@@ -265,21 +265,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath)
- transcribedText := ""
+ var transcribedText string
if c.transcriber != nil && c.transcriber.IsAvailable() {
- ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+ transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
- result, err := c.transcriber.Transcribe(ctx, voicePath)
+ result, err := c.transcriber.Transcribe(transcriberCtx, voicePath)
if err != nil {
- logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
+ logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{
"error": err.Error(),
"path": voicePath,
})
transcribedText = "[voice (transcription failed)]"
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
- logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
+ logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{
"text": result.Text,
})
}
@@ -322,7 +322,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
content = "[empty message]"
}
- logger.DebugCF("telegram", "Received message", map[string]interface{}{
+ logger.DebugCF("telegram", "Received message", map[string]any{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
@@ -331,7 +331,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
// Thinking indicator
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil {
- logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
+ logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{
"error": err.Error(),
})
}
@@ -378,7 +378,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
- logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
+ logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{
"error": err.Error(),
})
return ""
@@ -393,7 +393,7 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st
}
url := c.bot.FileDownloadURL(file.FilePath)
- logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
+ logger.DebugCF("telegram", "File URL", map[string]any{"url": url})
// Use FilePath as filename for better identification
filename := file.FilePath + ext
@@ -405,7 +405,7 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
- logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
+ logger.ErrorCF("telegram", "Failed to get file", map[string]any{
"error": err.Error(),
})
return ""
@@ -463,7 +463,11 @@ func markdownToTelegramHTML(text string) string {
for i, code := range codeBlocks.codes {
escaped := escapeHTML(code)
- text = strings.ReplaceAll(text, fmt.Sprintf("\x00CB%d\x00", i), fmt.Sprintf("%s
", escaped))
+ text = strings.ReplaceAll(
+ text,
+ fmt.Sprintf("\x00CB%d\x00", i),
+ fmt.Sprintf("%s
", escaped),
+ )
}
return text
diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram_commands.go
index df245e156..f28434f46 100644
--- a/pkg/channels/telegram_commands.go
+++ b/pkg/channels/telegram_commands.go
@@ -6,6 +6,7 @@ import (
"strings"
"github.com/mymmrac/telego"
+
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -35,6 +36,7 @@ func commandArgs(text string) string {
}
return strings.TrimSpace(parts[1])
}
+
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
msg := `/start - Start the bot
/help - Show this help message
@@ -79,7 +81,7 @@ func (c *cmd) Show(ctx context.Context, message telego.Message) error {
switch args {
case "model":
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
- c.config.Agents.Defaults.Model,
+ c.config.Agents.Defaults.GetModelName(),
c.config.Agents.Defaults.Provider)
case "channel":
response = "Current Channel: telegram"
@@ -96,6 +98,7 @@ func (c *cmd) Show(ctx context.Context, message telego.Message) error {
})
return err
}
+
func (c *cmd) List(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
@@ -117,7 +120,7 @@ func (c *cmd) List(ctx context.Context, message telego.Message) error {
provider = "configured default"
}
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
- c.config.Agents.Defaults.Model, provider)
+ c.config.Agents.Defaults.GetModelName(), provider)
case "channels":
var enabled []string
diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom.go
new file mode 100644
index 000000000..f8daf89de
--- /dev/null
+++ b/pkg/channels/wecom.go
@@ -0,0 +1,605 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// WeCom Bot (企业微信智能机器人) channel implementation
+// Uses webhook callback mode for receiving messages and webhook API for sending replies
+
+package channels
+
+import (
+ "bytes"
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "io"
+ "net/http"
+ "sort"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人)
+// Uses webhook callback mode - simpler than WeCom App but only supports passive replies
+type WeComBotChannel struct {
+ *BaseChannel
+ config config.WeComConfig
+ server *http.Server
+ ctx context.Context
+ cancel context.CancelFunc
+ processedMsgs map[string]bool // Message deduplication: msg_id -> processed
+ msgMu sync.RWMutex
+}
+
+// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT)
+type WeComBotMessage struct {
+ MsgID string `json:"msgid"`
+ AIBotID string `json:"aibotid"`
+ ChatID string `json:"chatid"` // Session ID, only present for group chats
+ ChatType string `json:"chattype"` // "single" for DM, "group" for group chat
+ From struct {
+ UserID string `json:"userid"`
+ } `json:"from"`
+ ResponseURL string `json:"response_url"`
+ MsgType string `json:"msgtype"` // text, image, voice, file, mixed
+ Text struct {
+ Content string `json:"content"`
+ } `json:"text"`
+ Image struct {
+ URL string `json:"url"`
+ } `json:"image"`
+ Voice struct {
+ Content string `json:"content"` // Voice to text content
+ } `json:"voice"`
+ File struct {
+ URL string `json:"url"`
+ } `json:"file"`
+ Mixed struct {
+ MsgItem []struct {
+ MsgType string `json:"msgtype"`
+ Text struct {
+ Content string `json:"content"`
+ } `json:"text"`
+ Image struct {
+ URL string `json:"url"`
+ } `json:"image"`
+ } `json:"msg_item"`
+ } `json:"mixed"`
+ Quote struct {
+ MsgType string `json:"msgtype"`
+ Text struct {
+ Content string `json:"content"`
+ } `json:"text"`
+ } `json:"quote"`
+}
+
+// WeComBotReplyMessage represents the reply message structure
+type WeComBotReplyMessage struct {
+ MsgType string `json:"msgtype"`
+ Text struct {
+ Content string `json:"content"`
+ } `json:"text,omitempty"`
+}
+
+// NewWeComBotChannel creates a new WeCom Bot channel instance
+func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) {
+ if cfg.Token == "" || cfg.WebhookURL == "" {
+ return nil, fmt.Errorf("wecom token and webhook_url are required")
+ }
+
+ base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom)
+
+ return &WeComBotChannel{
+ BaseChannel: base,
+ config: cfg,
+ processedMsgs: make(map[string]bool),
+ }, nil
+}
+
+// Name returns the channel name
+func (c *WeComBotChannel) Name() string {
+ return "wecom"
+}
+
+// Start initializes the WeCom Bot channel with HTTP webhook server
+func (c *WeComBotChannel) Start(ctx context.Context) error {
+ logger.InfoC("wecom", "Starting WeCom Bot channel...")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // Setup HTTP server for webhook
+ mux := http.NewServeMux()
+ webhookPath := c.config.WebhookPath
+ if webhookPath == "" {
+ webhookPath = "/webhook/wecom"
+ }
+ mux.HandleFunc(webhookPath, c.handleWebhook)
+
+ // Health check endpoint
+ mux.HandleFunc("/health/wecom", c.handleHealth)
+
+ addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
+ c.server = &http.Server{
+ Addr: addr,
+ Handler: mux,
+ }
+
+ c.setRunning(true)
+ logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{
+ "address": addr,
+ "path": webhookPath,
+ })
+
+ // Start server in goroutine
+ go func() {
+ if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ logger.ErrorCF("wecom", "HTTP server error", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ return nil
+}
+
+// Stop gracefully stops the WeCom Bot channel
+func (c *WeComBotChannel) Stop(ctx context.Context) error {
+ logger.InfoC("wecom", "Stopping WeCom Bot channel...")
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ if c.server != nil {
+ shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+ c.server.Shutdown(shutdownCtx)
+ }
+
+ c.setRunning(false)
+ logger.InfoC("wecom", "WeCom Bot channel stopped")
+ return nil
+}
+
+// Send sends a message to WeCom user via webhook API
+// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message
+// For delayed responses, we use the webhook URL
+func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return fmt.Errorf("wecom channel not running")
+ }
+
+ logger.DebugCF("wecom", "Sending message via webhook", map[string]any{
+ "chat_id": msg.ChatID,
+ "preview": utils.Truncate(msg.Content, 100),
+ })
+
+ return c.sendWebhookReply(ctx, msg.ChatID, msg.Content)
+}
+
+// handleWebhook handles incoming webhook requests from WeCom
+func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+
+ if r.Method == http.MethodGet {
+ // Handle verification request
+ c.handleVerification(ctx, w, r)
+ return
+ }
+
+ if r.Method == http.MethodPost {
+ // Handle message callback
+ c.handleMessageCallback(ctx, w, r)
+ return
+ }
+
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+}
+
+// handleVerification handles the URL verification request from WeCom
+func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ query := r.URL.Query()
+ msgSignature := query.Get("msg_signature")
+ timestamp := query.Get("timestamp")
+ nonce := query.Get("nonce")
+ echostr := query.Get("echostr")
+
+ if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
+ http.Error(w, "Missing parameters", http.StatusBadRequest)
+ return
+ }
+
+ // Verify signature
+ if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
+ logger.WarnC("wecom", "Signature verification failed")
+ http.Error(w, "Invalid signature", http.StatusForbidden)
+ return
+ }
+
+ // Decrypt echostr
+ // For AIBOT (智能机器人), receiveid should be empty string ""
+ // Reference: https://developer.work.weixin.qq.com/document/path/101033
+ decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "")
+ if err != nil {
+ logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Remove BOM and whitespace as per WeCom documentation
+ // The response must be plain text without quotes, BOM, or newlines
+ decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
+ decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
+ w.Write([]byte(decryptedEchoStr))
+}
+
+// handleMessageCallback handles incoming messages from WeCom
+func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ query := r.URL.Query()
+ msgSignature := query.Get("msg_signature")
+ timestamp := query.Get("timestamp")
+ nonce := query.Get("nonce")
+
+ if msgSignature == "" || timestamp == "" || nonce == "" {
+ http.Error(w, "Missing parameters", http.StatusBadRequest)
+ return
+ }
+
+ // Read request body
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, "Failed to read body", http.StatusBadRequest)
+ return
+ }
+ defer r.Body.Close()
+
+ // Parse XML to get encrypted message
+ var encryptedMsg struct {
+ XMLName xml.Name `xml:"xml"`
+ ToUserName string `xml:"ToUserName"`
+ Encrypt string `xml:"Encrypt"`
+ AgentID string `xml:"AgentID"`
+ }
+
+ if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
+ logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Invalid XML", http.StatusBadRequest)
+ return
+ }
+
+ // Verify signature
+ if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
+ logger.WarnC("wecom", "Message signature verification failed")
+ http.Error(w, "Invalid signature", http.StatusForbidden)
+ return
+ }
+
+ // Decrypt message
+ // For AIBOT (智能机器人), receiveid should be empty string ""
+ // Reference: https://developer.work.weixin.qq.com/document/path/101033
+ decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "")
+ if err != nil {
+ logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Parse decrypted JSON message (AIBOT uses JSON format)
+ var msg WeComBotMessage
+ if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
+ logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Invalid message format", http.StatusBadRequest)
+ return
+ }
+
+ // Process the message asynchronously with context
+ go c.processMessage(ctx, msg)
+
+ // Return success response immediately
+ // WeCom Bot requires response within configured timeout (default 5 seconds)
+ w.Write([]byte("success"))
+}
+
+// processMessage processes the received message
+func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) {
+ // Skip unsupported message types
+ if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" &&
+ msg.MsgType != "mixed" {
+ logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{
+ "msg_type": msg.MsgType,
+ })
+ return
+ }
+
+ // Message deduplication: Use msg_id to prevent duplicate processing
+ msgID := msg.MsgID
+ c.msgMu.Lock()
+ if c.processedMsgs[msgID] {
+ c.msgMu.Unlock()
+ logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{
+ "msg_id": msgID,
+ })
+ return
+ }
+ c.processedMsgs[msgID] = true
+ c.msgMu.Unlock()
+
+ // Clean up old messages periodically (keep last 1000)
+ if len(c.processedMsgs) > 1000 {
+ c.msgMu.Lock()
+ c.processedMsgs = make(map[string]bool)
+ c.msgMu.Unlock()
+ }
+
+ senderID := msg.From.UserID
+
+ // Determine if this is a group chat or direct message
+ // ChatType: "single" for DM, "group" for group chat
+ isGroupChat := msg.ChatType == "group"
+
+ var chatID, peerKind, peerID string
+ if isGroupChat {
+ // Group chat: use ChatID as chatID and peer_id
+ chatID = msg.ChatID
+ peerKind = "group"
+ peerID = msg.ChatID
+ } else {
+ // Direct message: use senderID as chatID and peer_id
+ chatID = senderID
+ peerKind = "direct"
+ peerID = senderID
+ }
+
+ // Extract content based on message type
+ var content string
+ switch msg.MsgType {
+ case "text":
+ content = msg.Text.Content
+ case "voice":
+ content = msg.Voice.Content // Voice to text content
+ case "mixed":
+ // For mixed messages, concatenate text items
+ for _, item := range msg.Mixed.MsgItem {
+ if item.MsgType == "text" {
+ content += item.Text.Content
+ }
+ }
+ case "image", "file":
+ // For image and file, we don't have text content
+ content = ""
+ }
+
+ // Build metadata
+ metadata := map[string]string{
+ "msg_type": msg.MsgType,
+ "msg_id": msg.MsgID,
+ "platform": "wecom",
+ "peer_kind": peerKind,
+ "peer_id": peerID,
+ "response_url": msg.ResponseURL,
+ }
+ if isGroupChat {
+ metadata["chat_id"] = msg.ChatID
+ metadata["sender_id"] = senderID
+ }
+
+ logger.DebugCF("wecom", "Received message", map[string]any{
+ "sender_id": senderID,
+ "msg_type": msg.MsgType,
+ "peer_kind": peerKind,
+ "is_group_chat": isGroupChat,
+ "preview": utils.Truncate(content, 50),
+ })
+
+ // Handle the message through the base channel
+ c.HandleMessage(senderID, chatID, content, nil, metadata)
+}
+
+// sendWebhookReply sends a reply using the webhook URL
+func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error {
+ reply := WeComBotReplyMessage{
+ MsgType: "text",
+ }
+ reply.Text.Content = content
+
+ jsonData, err := json.Marshal(reply)
+ if err != nil {
+ return fmt.Errorf("failed to marshal reply: %w", err)
+ }
+
+ // Use configurable timeout (default 5 seconds)
+ timeout := c.config.ReplyTimeout
+ if timeout <= 0 {
+ timeout = 5
+ }
+
+ reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send webhook reply: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to read response: %w", err)
+ }
+
+ // Check response
+ var result struct {
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ }
+ if err := json.Unmarshal(body, &result); err != nil {
+ return fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ if result.ErrCode != 0 {
+ return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode)
+ }
+
+ return nil
+}
+
+// handleHealth handles health check requests
+func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
+ status := map[string]any{
+ "status": "ok",
+ "running": c.IsRunning(),
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(status)
+}
+
+// WeCom common utilities for both WeCom Bot and WeCom App
+// The following functions were moved from wecom_common.go
+
+// WeComVerifySignature verifies the message signature for WeCom
+// This is a common function used by both WeCom Bot and WeCom App
+func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
+ if token == "" {
+ return true // Skip verification if token is not set
+ }
+
+ // Sort parameters
+ params := []string{token, timestamp, nonce, msgEncrypt}
+ sort.Strings(params)
+
+ // Concatenate
+ str := strings.Join(params, "")
+
+ // SHA1 hash
+ hash := sha1.Sum([]byte(str))
+ expectedSignature := fmt.Sprintf("%x", hash)
+
+ return expectedSignature == msgSignature
+}
+
+// WeComDecryptMessage decrypts the encrypted message using AES
+// This is a common function used by both WeCom Bot and WeCom App
+// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id
+func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
+ return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
+}
+
+// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid
+// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification.
+func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
+ if encodingAESKey == "" {
+ // No encryption, return as is (base64 decode)
+ decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
+ if err != nil {
+ return "", err
+ }
+ return string(decoded), nil
+ }
+
+ // Decode AES key (base64)
+ aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
+ if err != nil {
+ return "", fmt.Errorf("failed to decode AES key: %w", err)
+ }
+
+ // Decode encrypted message
+ cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
+ if err != nil {
+ return "", fmt.Errorf("failed to decode message: %w", err)
+ }
+
+ // AES decrypt
+ block, err := aes.NewCipher(aesKey)
+ if err != nil {
+ return "", fmt.Errorf("failed to create cipher: %w", err)
+ }
+
+ if len(cipherText) < aes.BlockSize {
+ return "", fmt.Errorf("ciphertext too short")
+ }
+
+ // IV is the first 16 bytes of AESKey
+ iv := aesKey[:aes.BlockSize]
+ mode := cipher.NewCBCDecrypter(block, iv)
+ plainText := make([]byte, len(cipherText))
+ mode.CryptBlocks(plainText, cipherText)
+
+ // Remove PKCS7 padding
+ plainText, err = pkcs7UnpadWeCom(plainText)
+ if err != nil {
+ return "", fmt.Errorf("failed to unpad: %w", err)
+ }
+
+ // Parse message structure
+ // Format: random(16) + msg_len(4) + msg + receiveid
+ if len(plainText) < 20 {
+ return "", fmt.Errorf("decrypted message too short")
+ }
+
+ msgLen := binary.BigEndian.Uint32(plainText[16:20])
+ if int(msgLen) > len(plainText)-20 {
+ return "", fmt.Errorf("invalid message length")
+ }
+
+ msg := plainText[20 : 20+msgLen]
+
+ // Verify receiveid if provided
+ if receiveid != "" && len(plainText) > 20+int(msgLen) {
+ actualReceiveID := string(plainText[20+msgLen:])
+ if actualReceiveID != receiveid {
+ return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
+ }
+ }
+
+ return string(msg), nil
+}
+
+// pkcs7UnpadWeCom removes PKCS7 padding with validation
+// WeCom uses block size of 32 (not standard AES block size of 16)
+const wecomBlockSize = 32
+
+func pkcs7UnpadWeCom(data []byte) ([]byte, error) {
+ if len(data) == 0 {
+ return data, nil
+ }
+ padding := int(data[len(data)-1])
+ // WeCom uses 32-byte block size for PKCS7 padding
+ if padding == 0 || padding > wecomBlockSize {
+ return nil, fmt.Errorf("invalid padding size: %d", padding)
+ }
+ if padding > len(data) {
+ return nil, fmt.Errorf("padding size larger than data")
+ }
+ // Verify all padding bytes
+ for i := 0; i < padding; i++ {
+ if data[len(data)-1-i] != byte(padding) {
+ return nil, fmt.Errorf("invalid padding byte at position %d", i)
+ }
+ }
+ return data[:len(data)-padding], nil
+}
diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom_app.go
new file mode 100644
index 000000000..302603445
--- /dev/null
+++ b/pkg/channels/wecom_app.go
@@ -0,0 +1,584 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// WeCom App (企业微信自建应用) channel implementation
+// Supports receiving messages via webhook callback and sending messages proactively
+
+package channels
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+const (
+ wecomAPIBase = "https://qyapi.weixin.qq.com"
+)
+
+// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用)
+type WeComAppChannel struct {
+ *BaseChannel
+ config config.WeComAppConfig
+ server *http.Server
+ accessToken string
+ tokenExpiry time.Time
+ tokenMu sync.RWMutex
+ ctx context.Context
+ cancel context.CancelFunc
+ processedMsgs map[string]bool // Message deduplication: msg_id -> processed
+ msgMu sync.RWMutex
+}
+
+// WeComXMLMessage represents the XML message structure from WeCom
+type WeComXMLMessage struct {
+ XMLName xml.Name `xml:"xml"`
+ ToUserName string `xml:"ToUserName"`
+ FromUserName string `xml:"FromUserName"`
+ CreateTime int64 `xml:"CreateTime"`
+ MsgType string `xml:"MsgType"`
+ Content string `xml:"Content"`
+ MsgId int64 `xml:"MsgId"`
+ AgentID int64 `xml:"AgentID"`
+ PicUrl string `xml:"PicUrl"`
+ MediaId string `xml:"MediaId"`
+ Format string `xml:"Format"`
+ ThumbMediaId string `xml:"ThumbMediaId"`
+ LocationX float64 `xml:"Location_X"`
+ LocationY float64 `xml:"Location_Y"`
+ Scale int `xml:"Scale"`
+ Label string `xml:"Label"`
+ Title string `xml:"Title"`
+ Description string `xml:"Description"`
+ Url string `xml:"Url"`
+ Event string `xml:"Event"`
+ EventKey string `xml:"EventKey"`
+}
+
+// WeComTextMessage represents text message for sending
+type WeComTextMessage struct {
+ ToUser string `json:"touser"`
+ MsgType string `json:"msgtype"`
+ AgentID int64 `json:"agentid"`
+ Text struct {
+ Content string `json:"content"`
+ } `json:"text"`
+ Safe int `json:"safe,omitempty"`
+}
+
+// WeComMarkdownMessage represents markdown message for sending
+type WeComMarkdownMessage struct {
+ ToUser string `json:"touser"`
+ MsgType string `json:"msgtype"`
+ AgentID int64 `json:"agentid"`
+ Markdown struct {
+ Content string `json:"content"`
+ } `json:"markdown"`
+}
+
+// WeComImageMessage represents image message for sending
+type WeComImageMessage struct {
+ ToUser string `json:"touser"`
+ MsgType string `json:"msgtype"`
+ AgentID int64 `json:"agentid"`
+ Image struct {
+ MediaID string `json:"media_id"`
+ } `json:"image"`
+}
+
+// WeComAccessTokenResponse represents the access token API response
+type WeComAccessTokenResponse struct {
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ AccessToken string `json:"access_token"`
+ ExpiresIn int `json:"expires_in"`
+}
+
+// WeComSendMessageResponse represents the send message API response
+type WeComSendMessageResponse struct {
+ ErrCode int `json:"errcode"`
+ ErrMsg string `json:"errmsg"`
+ InvalidUser string `json:"invaliduser"`
+ InvalidParty string `json:"invalidparty"`
+ InvalidTag string `json:"invalidtag"`
+}
+
+// PKCS7Padding adds PKCS7 padding
+type PKCS7Padding struct{}
+
+// NewWeComAppChannel creates a new WeCom App channel instance
+func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) {
+ if cfg.CorpID == "" || cfg.CorpSecret == "" || cfg.AgentID == 0 {
+ return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required")
+ }
+
+ base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom)
+
+ return &WeComAppChannel{
+ BaseChannel: base,
+ config: cfg,
+ processedMsgs: make(map[string]bool),
+ }, nil
+}
+
+// Name returns the channel name
+func (c *WeComAppChannel) Name() string {
+ return "wecom_app"
+}
+
+// Start initializes the WeCom App channel with HTTP webhook server
+func (c *WeComAppChannel) Start(ctx context.Context) error {
+ logger.InfoC("wecom_app", "Starting WeCom App channel...")
+
+ c.ctx, c.cancel = context.WithCancel(ctx)
+
+ // Get initial access token
+ if err := c.refreshAccessToken(); err != nil {
+ logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{
+ "error": err.Error(),
+ })
+ }
+
+ // Start token refresh goroutine
+ go c.tokenRefreshLoop()
+
+ // Setup HTTP server for webhook
+ mux := http.NewServeMux()
+ webhookPath := c.config.WebhookPath
+ if webhookPath == "" {
+ webhookPath = "/webhook/wecom-app"
+ }
+ mux.HandleFunc(webhookPath, c.handleWebhook)
+
+ // Health check endpoint
+ mux.HandleFunc("/health/wecom-app", c.handleHealth)
+
+ addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
+ c.server = &http.Server{
+ Addr: addr,
+ Handler: mux,
+ }
+
+ c.setRunning(true)
+ logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{
+ "address": addr,
+ "path": webhookPath,
+ })
+
+ // Start server in goroutine
+ go func() {
+ if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+ logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }()
+
+ return nil
+}
+
+// Stop gracefully stops the WeCom App channel
+func (c *WeComAppChannel) Stop(ctx context.Context) error {
+ logger.InfoC("wecom_app", "Stopping WeCom App channel...")
+
+ if c.cancel != nil {
+ c.cancel()
+ }
+
+ if c.server != nil {
+ shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+ c.server.Shutdown(shutdownCtx)
+ }
+
+ c.setRunning(false)
+ logger.InfoC("wecom_app", "WeCom App channel stopped")
+ return nil
+}
+
+// Send sends a message to WeCom user proactively using access token
+func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
+ if !c.IsRunning() {
+ return fmt.Errorf("wecom_app channel not running")
+ }
+
+ accessToken := c.getAccessToken()
+ if accessToken == "" {
+ return fmt.Errorf("no valid access token available")
+ }
+
+ logger.DebugCF("wecom_app", "Sending message", map[string]any{
+ "chat_id": msg.ChatID,
+ "preview": utils.Truncate(msg.Content, 100),
+ })
+
+ return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content)
+}
+
+// handleWebhook handles incoming webhook requests from WeCom
+func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+
+ // Log all incoming requests for debugging
+ logger.DebugCF("wecom_app", "Received webhook request", map[string]any{
+ "method": r.Method,
+ "url": r.URL.String(),
+ "path": r.URL.Path,
+ "query": r.URL.RawQuery,
+ })
+
+ if r.Method == http.MethodGet {
+ // Handle verification request
+ c.handleVerification(ctx, w, r)
+ return
+ }
+
+ if r.Method == http.MethodPost {
+ // Handle message callback
+ c.handleMessageCallback(ctx, w, r)
+ return
+ }
+
+ logger.WarnCF("wecom_app", "Method not allowed", map[string]any{
+ "method": r.Method,
+ })
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+}
+
+// handleVerification handles the URL verification request from WeCom
+func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ query := r.URL.Query()
+ msgSignature := query.Get("msg_signature")
+ timestamp := query.Get("timestamp")
+ nonce := query.Get("nonce")
+ echostr := query.Get("echostr")
+
+ logger.DebugCF("wecom_app", "Handling verification request", map[string]any{
+ "msg_signature": msgSignature,
+ "timestamp": timestamp,
+ "nonce": nonce,
+ "echostr": echostr,
+ "corp_id": c.config.CorpID,
+ })
+
+ if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" {
+ logger.ErrorC("wecom_app", "Missing parameters in verification request")
+ http.Error(w, "Missing parameters", http.StatusBadRequest)
+ return
+ }
+
+ // Verify signature
+ if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) {
+ logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{
+ "token": c.config.Token,
+ "msg_signature": msgSignature,
+ "timestamp": timestamp,
+ "nonce": nonce,
+ })
+ http.Error(w, "Invalid signature", http.StatusForbidden)
+ return
+ }
+
+ logger.DebugC("wecom_app", "Signature verification passed")
+
+ // Decrypt echostr with CorpID verification
+ // For WeCom App (自建应用), receiveid should be corp_id
+ logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{
+ "encoding_aes_key": c.config.EncodingAESKey,
+ "corp_id": c.config.CorpID,
+ })
+ decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID)
+ if err != nil {
+ logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{
+ "error": err.Error(),
+ "encoding_aes_key": c.config.EncodingAESKey,
+ "corp_id": c.config.CorpID,
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{
+ "decrypted": decryptedEchoStr,
+ })
+
+ // Remove BOM and whitespace as per WeCom documentation
+ // The response must be plain text without quotes, BOM, or newlines
+ decryptedEchoStr = strings.TrimSpace(decryptedEchoStr)
+ decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM
+ w.Write([]byte(decryptedEchoStr))
+}
+
+// handleMessageCallback handles incoming messages from WeCom
+func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) {
+ query := r.URL.Query()
+ msgSignature := query.Get("msg_signature")
+ timestamp := query.Get("timestamp")
+ nonce := query.Get("nonce")
+
+ if msgSignature == "" || timestamp == "" || nonce == "" {
+ http.Error(w, "Missing parameters", http.StatusBadRequest)
+ return
+ }
+
+ // Read request body
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, "Failed to read body", http.StatusBadRequest)
+ return
+ }
+ defer r.Body.Close()
+
+ // Parse XML to get encrypted message
+ var encryptedMsg struct {
+ XMLName xml.Name `xml:"xml"`
+ ToUserName string `xml:"ToUserName"`
+ Encrypt string `xml:"Encrypt"`
+ AgentID string `xml:"AgentID"`
+ }
+
+ if err = xml.Unmarshal(body, &encryptedMsg); err != nil {
+ logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Invalid XML", http.StatusBadRequest)
+ return
+ }
+
+ // Verify signature
+ if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) {
+ logger.WarnC("wecom_app", "Message signature verification failed")
+ http.Error(w, "Invalid signature", http.StatusForbidden)
+ return
+ }
+
+ // Decrypt message with CorpID verification
+ // For WeCom App (自建应用), receiveid should be corp_id
+ decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID)
+ if err != nil {
+ logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Decryption failed", http.StatusInternalServerError)
+ return
+ }
+
+ // Parse decrypted XML message
+ var msg WeComXMLMessage
+ if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil {
+ logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{
+ "error": err.Error(),
+ })
+ http.Error(w, "Invalid message format", http.StatusBadRequest)
+ return
+ }
+
+ // Process the message with context
+ go c.processMessage(ctx, msg)
+
+ // Return success response immediately
+ // WeCom App requires response within configured timeout (default 5 seconds)
+ w.Write([]byte("success"))
+}
+
+// processMessage processes the received message
+func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) {
+ // Skip non-text messages for now (can be extended)
+ if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" {
+ logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{
+ "msg_type": msg.MsgType,
+ })
+ return
+ }
+
+ // Message deduplication: Use msg_id to prevent duplicate processing
+ // As per WeCom documentation, use msg_id for deduplication
+ msgID := fmt.Sprintf("%d", msg.MsgId)
+ c.msgMu.Lock()
+ if c.processedMsgs[msgID] {
+ c.msgMu.Unlock()
+ logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{
+ "msg_id": msgID,
+ })
+ return
+ }
+ c.processedMsgs[msgID] = true
+ c.msgMu.Unlock()
+
+ // Clean up old messages periodically (keep last 1000)
+ if len(c.processedMsgs) > 1000 {
+ c.msgMu.Lock()
+ c.processedMsgs = make(map[string]bool)
+ c.msgMu.Unlock()
+ }
+
+ senderID := msg.FromUserName
+ chatID := senderID // WeCom App uses user ID as chat ID for direct messages
+
+ // Build metadata
+ // WeCom App only supports direct messages (private chat)
+ metadata := map[string]string{
+ "msg_type": msg.MsgType,
+ "msg_id": fmt.Sprintf("%d", msg.MsgId),
+ "agent_id": fmt.Sprintf("%d", msg.AgentID),
+ "platform": "wecom_app",
+ "media_id": msg.MediaId,
+ "create_time": fmt.Sprintf("%d", msg.CreateTime),
+ "peer_kind": "direct",
+ "peer_id": senderID,
+ }
+
+ content := msg.Content
+
+ logger.DebugCF("wecom_app", "Received message", map[string]any{
+ "sender_id": senderID,
+ "msg_type": msg.MsgType,
+ "preview": utils.Truncate(content, 50),
+ })
+
+ // Handle the message through the base channel
+ c.HandleMessage(senderID, chatID, content, nil, metadata)
+}
+
+// tokenRefreshLoop periodically refreshes the access token
+func (c *WeComAppChannel) tokenRefreshLoop() {
+ ticker := time.NewTicker(5 * time.Minute)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-c.ctx.Done():
+ return
+ case <-ticker.C:
+ if err := c.refreshAccessToken(); err != nil {
+ logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{
+ "error": err.Error(),
+ })
+ }
+ }
+ }
+}
+
+// refreshAccessToken gets a new access token from WeCom API
+func (c *WeComAppChannel) refreshAccessToken() error {
+ apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
+ wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret))
+
+ resp, err := http.Get(apiURL)
+ if err != nil {
+ return fmt.Errorf("failed to request access token: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to read response: %w", err)
+ }
+
+ var tokenResp WeComAccessTokenResponse
+ if err := json.Unmarshal(body, &tokenResp); err != nil {
+ return fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ if tokenResp.ErrCode != 0 {
+ return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode)
+ }
+
+ c.tokenMu.Lock()
+ c.accessToken = tokenResp.AccessToken
+ c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early
+ c.tokenMu.Unlock()
+
+ logger.DebugC("wecom_app", "Access token refreshed successfully")
+ return nil
+}
+
+// getAccessToken returns the current valid access token
+func (c *WeComAppChannel) getAccessToken() string {
+ c.tokenMu.RLock()
+ defer c.tokenMu.RUnlock()
+
+ if time.Now().After(c.tokenExpiry) {
+ return ""
+ }
+
+ return c.accessToken
+}
+
+// sendTextMessage sends a text message to a user
+func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error {
+ apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken)
+
+ msg := WeComTextMessage{
+ ToUser: userID,
+ MsgType: "text",
+ AgentID: c.config.AgentID,
+ }
+ msg.Text.Content = content
+
+ jsonData, err := json.Marshal(msg)
+ if err != nil {
+ return fmt.Errorf("failed to marshal message: %w", err)
+ }
+
+ // Use configurable timeout (default 5 seconds)
+ timeout := c.config.ReplyTimeout
+ if timeout <= 0 {
+ timeout = 5
+ }
+
+ reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
+ defer cancel()
+
+ req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ client := &http.Client{Timeout: time.Duration(timeout) * time.Second}
+ resp, err := client.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send message: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return fmt.Errorf("failed to read response: %w", err)
+ }
+
+ var sendResp WeComSendMessageResponse
+ if err := json.Unmarshal(body, &sendResp); err != nil {
+ return fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ if sendResp.ErrCode != 0 {
+ return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode)
+ }
+
+ return nil
+}
+
+// handleHealth handles health check requests
+func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) {
+ status := map[string]any{
+ "status": "ok",
+ "running": c.IsRunning(),
+ "has_token": c.getAccessToken() != "",
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(status)
+}
diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom_app_test.go
new file mode 100644
index 000000000..abf15c52b
--- /dev/null
+++ b/pkg/channels/wecom_app_test.go
@@ -0,0 +1,1104 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// WeCom App (企业微信自建应用) channel tests
+
+package channels
+
+import (
+ "bytes"
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "sort"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// generateTestAESKeyApp generates a valid test AES key for WeCom App
+func generateTestAESKeyApp() string {
+ // AES key needs to be 32 bytes (256 bits) for AES-256
+ key := make([]byte, 32)
+ for i := range key {
+ key[i] = byte(i + 1)
+ }
+ // Return base64 encoded key without padding
+ return base64.StdEncoding.EncodeToString(key)[:43]
+}
+
+// encryptTestMessageApp encrypts a message for testing WeCom App
+func encryptTestMessageApp(message, aesKey string) (string, error) {
+ // Decode AES key
+ key, err := base64.StdEncoding.DecodeString(aesKey + "=")
+ if err != nil {
+ return "", err
+ }
+
+ // Prepare message: random(16) + msg_len(4) + msg + corp_id
+ random := make([]byte, 0, 16)
+ for i := 0; i < 16; i++ {
+ random = append(random, byte(i+1))
+ }
+
+ msgBytes := []byte(message)
+ corpID := []byte("test_corp_id")
+
+ msgLen := uint32(len(msgBytes))
+ lenBytes := make([]byte, 4)
+ binary.BigEndian.PutUint32(lenBytes, msgLen)
+
+ plainText := append(random, lenBytes...)
+ plainText = append(plainText, msgBytes...)
+ plainText = append(plainText, corpID...)
+
+ // PKCS7 padding
+ blockSize := aes.BlockSize
+ padding := blockSize - len(plainText)%blockSize
+ padText := bytes.Repeat([]byte{byte(padding)}, padding)
+ plainText = append(plainText, padText...)
+
+ // Encrypt
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return "", err
+ }
+
+ mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
+ cipherText := make([]byte, len(plainText))
+ mode.CryptBlocks(cipherText, plainText)
+
+ return base64.StdEncoding.EncodeToString(cipherText), nil
+}
+
+// generateSignatureApp generates a signature for testing WeCom App
+func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string {
+ params := []string{token, timestamp, nonce, msgEncrypt}
+ sort.Strings(params)
+ str := strings.Join(params, "")
+ hash := sha1.Sum([]byte(str))
+ return fmt.Sprintf("%x", hash)
+}
+
+func TestNewWeComAppChannel(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("missing corp_id", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ }
+ _, err := NewWeComAppChannel(cfg, msgBus)
+ if err == nil {
+ t.Error("expected error for missing corp_id, got nil")
+ }
+ })
+
+ t.Run("missing corp_secret", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "",
+ AgentID: 1000002,
+ }
+ _, err := NewWeComAppChannel(cfg, msgBus)
+ if err == nil {
+ t.Error("expected error for missing corp_secret, got nil")
+ }
+ })
+
+ t.Run("missing agent_id", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 0,
+ }
+ _, err := NewWeComAppChannel(cfg, msgBus)
+ if err == nil {
+ t.Error("expected error for missing agent_id, got nil")
+ }
+ })
+
+ t.Run("valid config", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ AllowFrom: []string{"user1", "user2"},
+ }
+ ch, err := NewWeComAppChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if ch.Name() != "wecom_app" {
+ t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app")
+ }
+ if ch.IsRunning() {
+ t.Error("new channel should not be running")
+ }
+ })
+}
+
+func TestWeComAppChannelIsAllowed(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("empty allowlist allows all", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ AllowFrom: []string{},
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+ if !ch.IsAllowed("any_user") {
+ t.Error("empty allowlist should allow all users")
+ }
+ })
+
+ t.Run("allowlist restricts users", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ AllowFrom: []string{"allowed_user"},
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+ if !ch.IsAllowed("allowed_user") {
+ t.Error("allowed user should pass allowlist check")
+ }
+ if ch.IsAllowed("blocked_user") {
+ t.Error("non-allowed user should be blocked")
+ }
+ })
+}
+
+func TestWeComAppVerifySignature(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ Token: "test_token",
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ t.Run("valid signature", func(t *testing.T) {
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ msgEncrypt := "test_message"
+ expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt)
+
+ if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
+ t.Error("valid signature should pass verification")
+ }
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ msgEncrypt := "test_message"
+
+ if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
+ t.Error("invalid signature should fail verification")
+ }
+ })
+
+ t.Run("empty token skips verification", func(t *testing.T) {
+ cfgEmpty := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ Token: "",
+ }
+ chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus)
+
+ if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
+ t.Error("empty token should skip verification and return true")
+ }
+ })
+}
+
+func TestWeComAppDecryptMessage(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("decrypt without AES key", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ EncodingAESKey: "",
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ // Without AES key, message should be base64 decoded only
+ plainText := "hello world"
+ encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
+
+ result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != plainText {
+ t.Errorf("decryptMessage() = %q, want %q", result, plainText)
+ }
+ })
+
+ t.Run("decrypt with AES key", func(t *testing.T) {
+ aesKey := generateTestAESKeyApp()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ EncodingAESKey: aesKey,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ originalMsg := "Hello"
+ encrypted, err := encryptTestMessageApp(originalMsg, aesKey)
+ if err != nil {
+ t.Fatalf("failed to encrypt test message: %v", err)
+ }
+
+ result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != originalMsg {
+ t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg)
+ }
+ })
+
+ t.Run("invalid base64", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ EncodingAESKey: "",
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
+ if err == nil {
+ t.Error("expected error for invalid base64, got nil")
+ }
+ })
+
+ t.Run("invalid AES key", func(t *testing.T) {
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ EncodingAESKey: "invalid_key",
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
+ if err == nil {
+ t.Error("expected error for invalid AES key, got nil")
+ }
+ })
+
+ t.Run("ciphertext too short", func(t *testing.T) {
+ aesKey := generateTestAESKeyApp()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ EncodingAESKey: aesKey,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ // Encrypt a very short message that results in ciphertext less than block size
+ shortData := make([]byte, 8)
+ _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey)
+ if err == nil {
+ t.Error("expected error for short ciphertext, got nil")
+ }
+ })
+}
+
+func TestWeComAppPKCS7Unpad(t *testing.T) {
+ tests := []struct {
+ name string
+ input []byte
+ expected []byte
+ }{
+ {
+ name: "empty input",
+ input: []byte{},
+ expected: []byte{},
+ },
+ {
+ name: "valid padding 3 bytes",
+ input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
+ expected: []byte("hello"),
+ },
+ {
+ name: "valid padding 16 bytes (full block)",
+ input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
+ expected: []byte("123456789012345"),
+ },
+ {
+ name: "invalid padding larger than data",
+ input: []byte{20},
+ expected: nil, // should return error
+ },
+ {
+ name: "invalid padding zero",
+ input: append([]byte("test"), byte(0)),
+ expected: nil, // should return error
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := pkcs7UnpadWeCom(tt.input)
+ if tt.expected == nil {
+ // This case should return an error
+ if err == nil {
+ t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result)
+ }
+ return
+ }
+ if err != nil {
+ t.Errorf("pkcs7Unpad() unexpected error: %v", err)
+ return
+ }
+ if !bytes.Equal(result, tt.expected) {
+ t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestWeComAppHandleVerification(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ aesKey := generateTestAESKeyApp()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ Token: "test_token",
+ EncodingAESKey: aesKey,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ t.Run("valid verification request", func(t *testing.T) {
+ echostr := "test_echostr_123"
+ encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey)
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr)
+
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
+ nil,
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleVerification(context.Background(), w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ if w.Body.String() != echostr {
+ t.Errorf("response body = %q, want %q", w.Body.String(), echostr)
+ }
+ })
+
+ t.Run("missing parameters", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleVerification(context.Background(), w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
+ }
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ echostr := "test_echostr"
+ encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey)
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
+ nil,
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleVerification(context.Background(), w, req)
+
+ if w.Code != http.StatusForbidden {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
+ }
+ })
+}
+
+func TestWeComAppHandleMessageCallback(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ aesKey := generateTestAESKeyApp()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ Token: "test_token",
+ EncodingAESKey: aesKey,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ t.Run("valid message callback", func(t *testing.T) {
+ // Create XML message
+ xmlMsg := WeComXMLMessage{
+ ToUserName: "corp_id",
+ FromUserName: "user123",
+ CreateTime: 1234567890,
+ MsgType: "text",
+ Content: "Hello World",
+ MsgId: 123456,
+ AgentID: 1000002,
+ }
+ xmlData, _ := xml.Marshal(xmlMsg)
+
+ // Encrypt message
+ encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey)
+
+ // Create encrypted XML wrapper
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: encrypted,
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignatureApp("test_token", timestamp, nonce, encrypted)
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ if w.Body.String() != "success" {
+ t.Errorf("response body = %q, want %q", w.Body.String(), "success")
+ }
+ })
+
+ t.Run("missing parameters", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=sig", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
+ }
+ })
+
+ t.Run("invalid XML", func(t *testing.T) {
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignatureApp("test_token", timestamp, nonce, "")
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ strings.NewReader("invalid xml"),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
+ }
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: "encrypted_data",
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusForbidden {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
+ }
+ })
+}
+
+func TestWeComAppProcessMessage(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ t.Run("process text message", func(t *testing.T) {
+ msg := WeComXMLMessage{
+ ToUserName: "corp_id",
+ FromUserName: "user123",
+ CreateTime: 1234567890,
+ MsgType: "text",
+ Content: "Hello World",
+ MsgId: 123456,
+ AgentID: 1000002,
+ }
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("process image message", func(t *testing.T) {
+ msg := WeComXMLMessage{
+ ToUserName: "corp_id",
+ FromUserName: "user123",
+ CreateTime: 1234567890,
+ MsgType: "image",
+ PicUrl: "https://example.com/image.jpg",
+ MediaId: "media_123",
+ MsgId: 123456,
+ AgentID: 1000002,
+ }
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("process voice message", func(t *testing.T) {
+ msg := WeComXMLMessage{
+ ToUserName: "corp_id",
+ FromUserName: "user123",
+ CreateTime: 1234567890,
+ MsgType: "voice",
+ MediaId: "media_123",
+ Format: "amr",
+ MsgId: 123456,
+ AgentID: 1000002,
+ }
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("skip unsupported message type", func(t *testing.T) {
+ msg := WeComXMLMessage{
+ ToUserName: "corp_id",
+ FromUserName: "user123",
+ CreateTime: 1234567890,
+ MsgType: "video",
+ MsgId: 123456,
+ AgentID: 1000002,
+ }
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("process event message", func(t *testing.T) {
+ msg := WeComXMLMessage{
+ ToUserName: "corp_id",
+ FromUserName: "user123",
+ CreateTime: 1234567890,
+ MsgType: "event",
+ Event: "subscribe",
+ MsgId: 123456,
+ AgentID: 1000002,
+ }
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+}
+
+func TestWeComAppHandleWebhook(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ Token: "test_token",
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ t.Run("GET request calls verification", func(t *testing.T) {
+ echostr := "test_echostr"
+ encoded := base64.StdEncoding.EncodeToString([]byte(echostr))
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignatureApp("test_token", timestamp, nonce, encoded)
+
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded,
+ nil,
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleWebhook(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ })
+
+ t.Run("POST request calls message callback", func(t *testing.T) {
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: base64.StdEncoding.EncodeToString([]byte("test")),
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt)
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleWebhook(w, req)
+
+ // Should not be method not allowed
+ if w.Code == http.StatusMethodNotAllowed {
+ t.Error("POST request should not return Method Not Allowed")
+ }
+ })
+
+ t.Run("unsupported method", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleWebhook(w, req)
+
+ if w.Code != http.StatusMethodNotAllowed {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed)
+ }
+ })
+}
+
+func TestWeComAppHandleHealth(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleHealth(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+
+ contentType := w.Header().Get("Content-Type")
+ if contentType != "application/json" {
+ t.Errorf("Content-Type = %q, want %q", contentType, "application/json")
+ }
+
+ body := w.Body.String()
+ if !strings.Contains(body, "status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") {
+ t.Errorf("response body should contain status, running, and has_token fields, got: %s", body)
+ }
+}
+
+func TestWeComAppAccessToken(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComAppConfig{
+ CorpID: "test_corp_id",
+ CorpSecret: "test_secret",
+ AgentID: 1000002,
+ }
+ ch, _ := NewWeComAppChannel(cfg, msgBus)
+
+ t.Run("get empty access token initially", func(t *testing.T) {
+ token := ch.getAccessToken()
+ if token != "" {
+ t.Errorf("getAccessToken() = %q, want empty string", token)
+ }
+ })
+
+ t.Run("set and get access token", func(t *testing.T) {
+ ch.tokenMu.Lock()
+ ch.accessToken = "test_token_123"
+ ch.tokenExpiry = time.Now().Add(1 * time.Hour)
+ ch.tokenMu.Unlock()
+
+ token := ch.getAccessToken()
+ if token != "test_token_123" {
+ t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123")
+ }
+ })
+
+ t.Run("expired token returns empty", func(t *testing.T) {
+ ch.tokenMu.Lock()
+ ch.accessToken = "expired_token"
+ ch.tokenExpiry = time.Now().Add(-1 * time.Hour)
+ ch.tokenMu.Unlock()
+
+ token := ch.getAccessToken()
+ if token != "" {
+ t.Errorf("getAccessToken() = %q, want empty string for expired token", token)
+ }
+ })
+}
+
+func TestWeComAppMessageStructures(t *testing.T) {
+ t.Run("WeComTextMessage structure", func(t *testing.T) {
+ msg := WeComTextMessage{
+ ToUser: "user123",
+ MsgType: "text",
+ AgentID: 1000002,
+ }
+ msg.Text.Content = "Hello World"
+
+ if msg.ToUser != "user123" {
+ t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123")
+ }
+ if msg.MsgType != "text" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
+ }
+ if msg.AgentID != 1000002 {
+ t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002)
+ }
+ if msg.Text.Content != "Hello World" {
+ t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
+ }
+
+ // Test JSON marshaling
+ jsonData, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal JSON: %v", err)
+ }
+
+ var unmarshaled WeComTextMessage
+ err = json.Unmarshal(jsonData, &unmarshaled)
+ if err != nil {
+ t.Fatalf("failed to unmarshal JSON: %v", err)
+ }
+
+ if unmarshaled.ToUser != msg.ToUser {
+ t.Errorf("JSON round-trip failed for ToUser")
+ }
+ })
+
+ t.Run("WeComMarkdownMessage structure", func(t *testing.T) {
+ msg := WeComMarkdownMessage{
+ ToUser: "user123",
+ MsgType: "markdown",
+ AgentID: 1000002,
+ }
+ msg.Markdown.Content = "# Hello\nWorld"
+
+ if msg.Markdown.Content != "# Hello\nWorld" {
+ t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld")
+ }
+
+ // Test JSON marshaling
+ jsonData, err := json.Marshal(msg)
+ if err != nil {
+ t.Fatalf("failed to marshal JSON: %v", err)
+ }
+
+ if !bytes.Contains(jsonData, []byte("markdown")) {
+ t.Error("JSON should contain 'markdown' field")
+ }
+ })
+
+ t.Run("WeComAccessTokenResponse structure", func(t *testing.T) {
+ jsonData := `{
+ "errcode": 0,
+ "errmsg": "ok",
+ "access_token": "test_access_token",
+ "expires_in": 7200
+ }`
+
+ var resp WeComAccessTokenResponse
+ err := json.Unmarshal([]byte(jsonData), &resp)
+ if err != nil {
+ t.Fatalf("failed to unmarshal JSON: %v", err)
+ }
+
+ if resp.ErrCode != 0 {
+ t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0)
+ }
+ if resp.ErrMsg != "ok" {
+ t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok")
+ }
+ if resp.AccessToken != "test_access_token" {
+ t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token")
+ }
+ if resp.ExpiresIn != 7200 {
+ t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200)
+ }
+ })
+
+ t.Run("WeComSendMessageResponse structure", func(t *testing.T) {
+ jsonData := `{
+ "errcode": 0,
+ "errmsg": "ok",
+ "invaliduser": "",
+ "invalidparty": "",
+ "invalidtag": ""
+ }`
+
+ var resp WeComSendMessageResponse
+ err := json.Unmarshal([]byte(jsonData), &resp)
+ if err != nil {
+ t.Fatalf("failed to unmarshal JSON: %v", err)
+ }
+
+ if resp.ErrCode != 0 {
+ t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0)
+ }
+ if resp.ErrMsg != "ok" {
+ t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok")
+ }
+ })
+}
+
+func TestWeComAppXMLMessageStructure(t *testing.T) {
+ xmlData := `
+
+
+
+ 1234567890
+
+
+ 1234567890123456
+ 1000002
+`
+
+ var msg WeComXMLMessage
+ err := xml.Unmarshal([]byte(xmlData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal XML: %v", err)
+ }
+
+ if msg.ToUserName != "corp_id" {
+ t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id")
+ }
+ if msg.FromUserName != "user123" {
+ t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123")
+ }
+ if msg.CreateTime != 1234567890 {
+ t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890)
+ }
+ if msg.MsgType != "text" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
+ }
+ if msg.Content != "Hello World" {
+ t.Errorf("Content = %q, want %q", msg.Content, "Hello World")
+ }
+ if msg.MsgId != 1234567890123456 {
+ t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456)
+ }
+ if msg.AgentID != 1000002 {
+ t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002)
+ }
+}
+
+func TestWeComAppXMLMessageImage(t *testing.T) {
+ xmlData := `
+
+
+
+ 1234567890
+
+
+
+ 1234567890123456
+ 1000002
+`
+
+ var msg WeComXMLMessage
+ err := xml.Unmarshal([]byte(xmlData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal XML: %v", err)
+ }
+
+ if msg.MsgType != "image" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "image")
+ }
+ if msg.PicUrl != "https://example.com/image.jpg" {
+ t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg")
+ }
+ if msg.MediaId != "media_123" {
+ t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123")
+ }
+}
+
+func TestWeComAppXMLMessageVoice(t *testing.T) {
+ xmlData := `
+
+
+
+ 1234567890
+
+
+
+ 1234567890123456
+ 1000002
+`
+
+ var msg WeComXMLMessage
+ err := xml.Unmarshal([]byte(xmlData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal XML: %v", err)
+ }
+
+ if msg.MsgType != "voice" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice")
+ }
+ if msg.Format != "amr" {
+ t.Errorf("Format = %q, want %q", msg.Format, "amr")
+ }
+}
+
+func TestWeComAppXMLMessageLocation(t *testing.T) {
+ xmlData := `
+
+
+
+ 1234567890
+
+ 39.9042
+ 116.4074
+ 16
+
+ 1234567890123456
+ 1000002
+`
+
+ var msg WeComXMLMessage
+ err := xml.Unmarshal([]byte(xmlData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal XML: %v", err)
+ }
+
+ if msg.MsgType != "location" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "location")
+ }
+ if msg.LocationX != 39.9042 {
+ t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042)
+ }
+ if msg.LocationY != 116.4074 {
+ t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074)
+ }
+ if msg.Scale != 16 {
+ t.Errorf("Scale = %d, want %d", msg.Scale, 16)
+ }
+ if msg.Label != "Beijing" {
+ t.Errorf("Label = %q, want %q", msg.Label, "Beijing")
+ }
+}
+
+func TestWeComAppXMLMessageLink(t *testing.T) {
+ xmlData := `
+
+
+
+ 1234567890
+
+
+
+
+ 1234567890123456
+ 1000002
+`
+
+ var msg WeComXMLMessage
+ err := xml.Unmarshal([]byte(xmlData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal XML: %v", err)
+ }
+
+ if msg.MsgType != "link" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "link")
+ }
+ if msg.Title != "Link Title" {
+ t.Errorf("Title = %q, want %q", msg.Title, "Link Title")
+ }
+ if msg.Description != "Link Description" {
+ t.Errorf("Description = %q, want %q", msg.Description, "Link Description")
+ }
+ if msg.Url != "https://example.com" {
+ t.Errorf("Url = %q, want %q", msg.Url, "https://example.com")
+ }
+}
+
+func TestWeComAppXMLMessageEvent(t *testing.T) {
+ xmlData := `
+
+
+
+ 1234567890
+
+
+
+ 1000002
+`
+
+ var msg WeComXMLMessage
+ err := xml.Unmarshal([]byte(xmlData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal XML: %v", err)
+ }
+
+ if msg.MsgType != "event" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "event")
+ }
+ if msg.Event != "subscribe" {
+ t.Errorf("Event = %q, want %q", msg.Event, "subscribe")
+ }
+ if msg.EventKey != "event_key_123" {
+ t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123")
+ }
+}
diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom_test.go
new file mode 100644
index 000000000..8afa7e8c3
--- /dev/null
+++ b/pkg/channels/wecom_test.go
@@ -0,0 +1,785 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// WeCom Bot (企业微信智能机器人) channel tests
+
+package channels
+
+import (
+ "bytes"
+ "context"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/binary"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "sort"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+// generateTestAESKey generates a valid test AES key
+func generateTestAESKey() string {
+ // AES key needs to be 32 bytes (256 bits) for AES-256
+ key := make([]byte, 32)
+ for i := range key {
+ key[i] = byte(i)
+ }
+ // Return base64 encoded key without padding
+ return base64.StdEncoding.EncodeToString(key)[:43]
+}
+
+// encryptTestMessage encrypts a message for testing (AIBOT JSON format)
+func encryptTestMessage(message, aesKey string) (string, error) {
+ // Decode AES key
+ key, err := base64.StdEncoding.DecodeString(aesKey + "=")
+ if err != nil {
+ return "", err
+ }
+
+ // Prepare message: random(16) + msg_len(4) + msg + receiveid
+ random := make([]byte, 0, 16)
+ for i := 0; i < 16; i++ {
+ random = append(random, byte(i))
+ }
+
+ msgBytes := []byte(message)
+ receiveID := []byte("test_aibot_id")
+
+ msgLen := uint32(len(msgBytes))
+ lenBytes := make([]byte, 4)
+ binary.BigEndian.PutUint32(lenBytes, msgLen)
+
+ plainText := append(random, lenBytes...)
+ plainText = append(plainText, msgBytes...)
+ plainText = append(plainText, receiveID...)
+
+ // PKCS7 padding
+ blockSize := aes.BlockSize
+ padding := blockSize - len(plainText)%blockSize
+ padText := bytes.Repeat([]byte{byte(padding)}, padding)
+ plainText = append(plainText, padText...)
+
+ // Encrypt
+ block, err := aes.NewCipher(key)
+ if err != nil {
+ return "", err
+ }
+
+ mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize])
+ cipherText := make([]byte, len(plainText))
+ mode.CryptBlocks(cipherText, plainText)
+
+ return base64.StdEncoding.EncodeToString(cipherText), nil
+}
+
+// generateSignature generates a signature for testing
+func generateSignature(token, timestamp, nonce, msgEncrypt string) string {
+ params := []string{token, timestamp, nonce, msgEncrypt}
+ sort.Strings(params)
+ str := strings.Join(params, "")
+ hash := sha1.Sum([]byte(str))
+ return fmt.Sprintf("%x", hash)
+}
+
+func TestNewWeComBotChannel(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("missing token", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ _, err := NewWeComBotChannel(cfg, msgBus)
+ if err == nil {
+ t.Error("expected error for missing token, got nil")
+ }
+ })
+
+ t.Run("missing webhook_url", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "",
+ }
+ _, err := NewWeComBotChannel(cfg, msgBus)
+ if err == nil {
+ t.Error("expected error for missing webhook_url, got nil")
+ }
+ })
+
+ t.Run("valid config", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ AllowFrom: []string{"user1", "user2"},
+ }
+ ch, err := NewWeComBotChannel(cfg, msgBus)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if ch.Name() != "wecom" {
+ t.Errorf("Name() = %q, want %q", ch.Name(), "wecom")
+ }
+ if ch.IsRunning() {
+ t.Error("new channel should not be running")
+ }
+ })
+}
+
+func TestWeComBotChannelIsAllowed(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("empty allowlist allows all", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ AllowFrom: []string{},
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+ if !ch.IsAllowed("any_user") {
+ t.Error("empty allowlist should allow all users")
+ }
+ })
+
+ t.Run("allowlist restricts users", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ AllowFrom: []string{"allowed_user"},
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+ if !ch.IsAllowed("allowed_user") {
+ t.Error("allowed user should pass allowlist check")
+ }
+ if ch.IsAllowed("blocked_user") {
+ t.Error("non-allowed user should be blocked")
+ }
+ })
+}
+
+func TestWeComBotVerifySignature(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ t.Run("valid signature", func(t *testing.T) {
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ msgEncrypt := "test_message"
+ expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt)
+
+ if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) {
+ t.Error("valid signature should pass verification")
+ }
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ msgEncrypt := "test_message"
+
+ if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) {
+ t.Error("invalid signature should fail verification")
+ }
+ })
+
+ t.Run("empty token skips verification", func(t *testing.T) {
+ // Create a channel manually with empty token to test the behavior
+ cfgEmpty := config.WeComConfig{
+ Token: "",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ chEmpty := &WeComBotChannel{
+ config: cfgEmpty,
+ }
+
+ if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") {
+ t.Error("empty token should skip verification and return true")
+ }
+ })
+}
+
+func TestWeComBotDecryptMessage(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+
+ t.Run("decrypt without AES key", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ EncodingAESKey: "",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ // Without AES key, message should be base64 decoded only
+ plainText := "hello world"
+ encoded := base64.StdEncoding.EncodeToString([]byte(plainText))
+
+ result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != plainText {
+ t.Errorf("decryptMessage() = %q, want %q", result, plainText)
+ }
+ })
+
+ t.Run("decrypt with AES key", func(t *testing.T) {
+ aesKey := generateTestAESKey()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ EncodingAESKey: aesKey,
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ originalMsg := "Hello"
+ encrypted, err := encryptTestMessage(originalMsg, aesKey)
+ if err != nil {
+ t.Fatalf("failed to encrypt test message: %v", err)
+ }
+
+ result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if result != originalMsg {
+ t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg)
+ }
+ })
+
+ t.Run("invalid base64", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ EncodingAESKey: "",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey)
+ if err == nil {
+ t.Error("expected error for invalid base64, got nil")
+ }
+ })
+
+ t.Run("invalid AES key", func(t *testing.T) {
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ EncodingAESKey: "invalid_key",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey)
+ if err == nil {
+ t.Error("expected error for invalid AES key, got nil")
+ }
+ })
+}
+
+func TestWeComBotPKCS7Unpad(t *testing.T) {
+ tests := []struct {
+ name string
+ input []byte
+ expected []byte
+ }{
+ {
+ name: "empty input",
+ input: []byte{},
+ expected: []byte{},
+ },
+ {
+ name: "valid padding 3 bytes",
+ input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...),
+ expected: []byte("hello"),
+ },
+ {
+ name: "valid padding 16 bytes (full block)",
+ input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...),
+ expected: []byte("123456789012345"),
+ },
+ {
+ name: "invalid padding larger than data",
+ input: []byte{20},
+ expected: nil, // should return error
+ },
+ {
+ name: "invalid padding zero",
+ input: append([]byte("test"), byte(0)),
+ expected: nil, // should return error
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := pkcs7UnpadWeCom(tt.input)
+ if tt.expected == nil {
+ // This case should return an error
+ if err == nil {
+ t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result)
+ }
+ return
+ }
+ if err != nil {
+ t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err)
+ return
+ }
+ if !bytes.Equal(result, tt.expected) {
+ t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestWeComBotHandleVerification(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ aesKey := generateTestAESKey()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ EncodingAESKey: aesKey,
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ t.Run("valid verification request", func(t *testing.T) {
+ echostr := "test_echostr_123"
+ encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr)
+
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
+ nil,
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleVerification(context.Background(), w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ if w.Body.String() != echostr {
+ t.Errorf("response body = %q, want %q", w.Body.String(), echostr)
+ }
+ })
+
+ t.Run("missing parameters", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleVerification(context.Background(), w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
+ }
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ echostr := "test_echostr"
+ encryptedEchostr, _ := encryptTestMessage(echostr, aesKey)
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr,
+ nil,
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleVerification(context.Background(), w, req)
+
+ if w.Code != http.StatusForbidden {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
+ }
+ })
+}
+
+func TestWeComBotHandleMessageCallback(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ aesKey := generateTestAESKey()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ EncodingAESKey: aesKey,
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ t.Run("valid direct message callback", func(t *testing.T) {
+ // Create JSON message for direct chat (single)
+ jsonMsg := `{
+ "msgid": "test_msg_id_123",
+ "aibotid": "test_aibot_id",
+ "chattype": "single",
+ "from": {"userid": "user123"},
+ "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ "msgtype": "text",
+ "text": {"content": "Hello World"}
+ }`
+
+ // Encrypt message
+ encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
+
+ // Create encrypted XML wrapper
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: encrypted,
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignature("test_token", timestamp, nonce, encrypted)
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ if w.Body.String() != "success" {
+ t.Errorf("response body = %q, want %q", w.Body.String(), "success")
+ }
+ })
+
+ t.Run("valid group message callback", func(t *testing.T) {
+ // Create JSON message for group chat
+ jsonMsg := `{
+ "msgid": "test_msg_id_456",
+ "aibotid": "test_aibot_id",
+ "chatid": "group_chat_id_123",
+ "chattype": "group",
+ "from": {"userid": "user456"},
+ "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ "msgtype": "text",
+ "text": {"content": "Hello Group"}
+ }`
+
+ // Encrypt message
+ encrypted, _ := encryptTestMessage(jsonMsg, aesKey)
+
+ // Create encrypted XML wrapper
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: encrypted,
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignature("test_token", timestamp, nonce, encrypted)
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ if w.Body.String() != "success" {
+ t.Errorf("response body = %q, want %q", w.Body.String(), "success")
+ }
+ })
+
+ t.Run("missing parameters", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
+ }
+ })
+
+ t.Run("invalid XML", func(t *testing.T) {
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignature("test_token", timestamp, nonce, "")
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ strings.NewReader("invalid xml"),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusBadRequest {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest)
+ }
+ })
+
+ t.Run("invalid signature", func(t *testing.T) {
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: "encrypted_data",
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleMessageCallback(context.Background(), w, req)
+
+ if w.Code != http.StatusForbidden {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden)
+ }
+ })
+}
+
+func TestWeComBotProcessMessage(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ t.Run("process direct text message", func(t *testing.T) {
+ msg := WeComBotMessage{
+ MsgID: "test_msg_id_123",
+ AIBotID: "test_aibot_id",
+ ChatType: "single",
+ ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ MsgType: "text",
+ }
+ msg.From.UserID = "user123"
+ msg.Text.Content = "Hello World"
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("process group text message", func(t *testing.T) {
+ msg := WeComBotMessage{
+ MsgID: "test_msg_id_456",
+ AIBotID: "test_aibot_id",
+ ChatID: "group_chat_id_123",
+ ChatType: "group",
+ ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ MsgType: "text",
+ }
+ msg.From.UserID = "user456"
+ msg.Text.Content = "Hello Group"
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("process voice message", func(t *testing.T) {
+ msg := WeComBotMessage{
+ MsgID: "test_msg_id_789",
+ AIBotID: "test_aibot_id",
+ ChatType: "single",
+ ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ MsgType: "voice",
+ }
+ msg.From.UserID = "user123"
+ msg.Voice.Content = "Voice message text"
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+
+ t.Run("skip unsupported message type", func(t *testing.T) {
+ msg := WeComBotMessage{
+ MsgID: "test_msg_id_000",
+ AIBotID: "test_aibot_id",
+ ChatType: "single",
+ ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ MsgType: "video",
+ }
+ msg.From.UserID = "user123"
+
+ // Should not panic
+ ch.processMessage(context.Background(), msg)
+ })
+}
+
+func TestWeComBotHandleWebhook(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ t.Run("GET request calls verification", func(t *testing.T) {
+ echostr := "test_echostr"
+ encoded := base64.StdEncoding.EncodeToString([]byte(echostr))
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignature("test_token", timestamp, nonce, encoded)
+
+ req := httptest.NewRequest(
+ http.MethodGet,
+ "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded,
+ nil,
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleWebhook(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+ })
+
+ t.Run("POST request calls message callback", func(t *testing.T) {
+ encryptedWrapper := struct {
+ XMLName xml.Name `xml:"xml"`
+ Encrypt string `xml:"Encrypt"`
+ }{
+ Encrypt: base64.StdEncoding.EncodeToString([]byte("test")),
+ }
+ wrapperData, _ := xml.Marshal(encryptedWrapper)
+
+ timestamp := "1234567890"
+ nonce := "test_nonce"
+ signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt)
+
+ req := httptest.NewRequest(
+ http.MethodPost,
+ "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce,
+ bytes.NewReader(wrapperData),
+ )
+ w := httptest.NewRecorder()
+
+ ch.handleWebhook(w, req)
+
+ // Should not be method not allowed
+ if w.Code == http.StatusMethodNotAllowed {
+ t.Error("POST request should not return Method Not Allowed")
+ }
+ })
+
+ t.Run("unsupported method", func(t *testing.T) {
+ req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleWebhook(w, req)
+
+ if w.Code != http.StatusMethodNotAllowed {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed)
+ }
+ })
+}
+
+func TestWeComBotHandleHealth(t *testing.T) {
+ msgBus := bus.NewMessageBus()
+ cfg := config.WeComConfig{
+ Token: "test_token",
+ WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ }
+ ch, _ := NewWeComBotChannel(cfg, msgBus)
+
+ req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil)
+ w := httptest.NewRecorder()
+
+ ch.handleHealth(w, req)
+
+ if w.Code != http.StatusOK {
+ t.Errorf("status code = %d, want %d", w.Code, http.StatusOK)
+ }
+
+ contentType := w.Header().Get("Content-Type")
+ if contentType != "application/json" {
+ t.Errorf("Content-Type = %q, want %q", contentType, "application/json")
+ }
+
+ body := w.Body.String()
+ if !strings.Contains(body, "status") || !strings.Contains(body, "running") {
+ t.Errorf("response body should contain status and running fields, got: %s", body)
+ }
+}
+
+func TestWeComBotReplyMessage(t *testing.T) {
+ msg := WeComBotReplyMessage{
+ MsgType: "text",
+ }
+ msg.Text.Content = "Hello World"
+
+ if msg.MsgType != "text" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
+ }
+ if msg.Text.Content != "Hello World" {
+ t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
+ }
+}
+
+func TestWeComBotMessageStructure(t *testing.T) {
+ jsonData := `{
+ "msgid": "test_msg_id_123",
+ "aibotid": "test_aibot_id",
+ "chatid": "group_chat_id_123",
+ "chattype": "group",
+ "from": {"userid": "user123"},
+ "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test",
+ "msgtype": "text",
+ "text": {"content": "Hello World"}
+ }`
+
+ var msg WeComBotMessage
+ err := json.Unmarshal([]byte(jsonData), &msg)
+ if err != nil {
+ t.Fatalf("failed to unmarshal JSON: %v", err)
+ }
+
+ if msg.MsgID != "test_msg_id_123" {
+ t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123")
+ }
+ if msg.AIBotID != "test_aibot_id" {
+ t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id")
+ }
+ if msg.ChatID != "group_chat_id_123" {
+ t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123")
+ }
+ if msg.ChatType != "group" {
+ t.Errorf("ChatType = %q, want %q", msg.ChatType, "group")
+ }
+ if msg.From.UserID != "user123" {
+ t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123")
+ }
+ if msg.MsgType != "text" {
+ t.Errorf("MsgType = %q, want %q", msg.MsgType, "text")
+ }
+ if msg.Text.Content != "Hello World" {
+ t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World")
+ }
+}
diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go
index 065424e0c..2dc4017ac 100644
--- a/pkg/channels/whatsapp.go
+++ b/pkg/channels/whatsapp.go
@@ -41,7 +41,10 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error {
dialer := websocket.DefaultDialer
dialer.HandshakeTimeout = 10 * time.Second
- conn, _, err := dialer.Dial(c.url, nil)
+ conn, resp, err := dialer.Dial(c.url, nil)
+ if resp != nil {
+ resp.Body.Close()
+ }
if err != nil {
return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err)
}
@@ -86,7 +89,7 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("whatsapp connection not established")
}
- payload := map[string]interface{}{
+ payload := map[string]any{
"type": "message",
"to": msg.ChatID,
"content": msg.Content,
@@ -126,7 +129,7 @@ func (c *WhatsAppChannel) listen(ctx context.Context) {
continue
}
- var msg map[string]interface{}
+ var msg map[string]any
if err := json.Unmarshal(message, &msg); err != nil {
log.Printf("Failed to unmarshal WhatsApp message: %v", err)
continue
@@ -144,7 +147,7 @@ func (c *WhatsAppChannel) listen(ctx context.Context) {
}
}
-func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
+func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
senderID, ok := msg["from"].(string)
if !ok {
return
@@ -161,7 +164,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
}
var mediaPaths []string
- if mediaData, ok := msg["media"].([]interface{}); ok {
+ if mediaData, ok := msg["media"].([]any); ok {
mediaPaths = make([]string, 0, len(mediaData))
for _, m := range mediaData {
if path, ok := m.(string); ok {
diff --git a/pkg/config/config.go b/pkg/config/config.go
index 87a1186a8..f7c78136b 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -26,7 +26,7 @@ func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
}
// Try []interface{} to handle mixed types
- var raw []interface{}
+ var raw []any
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
@@ -167,16 +167,26 @@ type SessionConfig struct {
}
type AgentDefaults struct {
- Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
- RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
- Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
- Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
+ Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
+ RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
+ Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
+ ModelName string `json:"model_name,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL_NAME"`
+ Model string `json:"model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` // Deprecated: use model_name instead
ModelFallbacks []string `json:"model_fallbacks,omitempty"`
- ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
+ ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
- MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
- Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
- MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
+ MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
+ Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
+ MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
+}
+
+// GetModelName returns the effective model name for the agent defaults.
+// It prefers the new "model_name" field but falls back to "model" for backward compatibility.
+func (d *AgentDefaults) GetModelName() string {
+ if d.ModelName != "" {
+ return d.ModelName
+ }
+ return d.Model
}
type ChannelsConfig struct {
@@ -190,90 +200,119 @@ type ChannelsConfig struct {
Slack SlackConfig `json:"slack"`
LINE LINEConfig `json:"line"`
OneBot OneBotConfig `json:"onebot"`
+ WeCom WeComConfig `json:"wecom"`
+ WeComApp WeComAppConfig `json:"wecom_app"`
}
type WhatsAppConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
}
type TelegramConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
- Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
- Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
+ Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
}
type FeishuConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
- AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
- AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
- EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
+ AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
+ AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
+ EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
}
type DiscordConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
- Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
+ MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"`
}
type MaixCamConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
- Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
- Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
+ Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
+ Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
}
type QQConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
- AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
+ AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
}
type DingTalkConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
- ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
+ ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
}
type SlackConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
- BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
- AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
+ BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
+ AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type LINEConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
- ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
+ ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"`
- WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"`
- WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"`
- WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
+ WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"`
+ WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
}
type OneBotConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
- WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
- AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
- ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
+ WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
+ AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
+ ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
- AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
+}
+
+type WeComConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"`
+ EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"`
+ WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"`
+ WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"`
+ WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"`
+ ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"`
+}
+
+type WeComAppConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"`
+ CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"`
+ CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"`
+ AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"`
+ Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"`
+ EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"`
+ WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"`
+ WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"`
+ WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"`
+ AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"`
+ ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"`
}
type HeartbeatConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
}
type DevicesConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_DEVICES_ENABLED"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_DEVICES_ENABLED"`
MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"`
}
@@ -295,6 +334,7 @@ type ProvidersConfig struct {
GitHubCopilot ProviderConfig `json:"github_copilot"`
Antigravity ProviderConfig `json:"antigravity"`
Qwen ProviderConfig `json:"qwen"`
+ Mistral ProviderConfig `json:"mistral"`
}
// IsEmpty checks if all provider configs are empty (no API keys or API bases set)
@@ -316,7 +356,8 @@ func (p ProvidersConfig) IsEmpty() bool {
p.VolcEngine.APIKey == "" && p.VolcEngine.APIBase == "" &&
p.GitHubCopilot.APIKey == "" && p.GitHubCopilot.APIBase == "" &&
p.Antigravity.APIKey == "" && p.Antigravity.APIBase == "" &&
- p.Qwen.APIKey == "" && p.Qwen.APIBase == ""
+ p.Qwen.APIKey == "" && p.Qwen.APIBase == "" &&
+ p.Mistral.APIKey == "" && p.Mistral.APIBase == ""
}
// MarshalJSON implements custom JSON marshaling for ProvidersConfig
@@ -330,11 +371,11 @@ func (p ProvidersConfig) MarshalJSON() ([]byte, error) {
}
type ProviderConfig struct {
- APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
- APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
- Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
- AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
- ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
+ APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
+ APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
+ Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
+ AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
+ ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` // only for Github Copilot, `stdio` or `grpc`
}
type OpenAIProviderConfig struct {
@@ -384,19 +425,26 @@ type GatewayConfig struct {
}
type BraveConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
- APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
+ APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
}
+type TavilyConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_TAVILY_ENABLED"`
+ APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_TAVILY_API_KEY"`
+ BaseURL string `json:"base_url" env:"PICOCLAW_TOOLS_WEB_TAVILY_BASE_URL"`
+ MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_TAVILY_MAX_RESULTS"`
+}
+
type DuckDuckGoConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
}
type PerplexityConfig struct {
- Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
- APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
+ Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
+ APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
}
@@ -408,9 +456,13 @@ type SearXNGConfig struct {
type WebToolsConfig struct {
Brave BraveConfig `json:"brave"`
+ Tavily TavilyConfig `json:"tavily"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
Perplexity PerplexityConfig `json:"perplexity"`
SearXNG SearXNGConfig `json:"searxng"`
+ // Proxy is an optional proxy URL for web tools (http/https/socks5/socks5h).
+ // For authenticated proxies, prefer HTTP_PROXY/HTTPS_PROXY env vars instead of embedding credentials in config.
+ Proxy string `json:"proxy,omitempty" env:"PICOCLAW_TOOLS_WEB_PROXY"`
}
type CronToolsConfig struct {
@@ -423,9 +475,37 @@ type ExecConfig struct {
}
type ToolsConfig struct {
- Web WebToolsConfig `json:"web"`
- Cron CronToolsConfig `json:"cron"`
- Exec ExecConfig `json:"exec"`
+ Web WebToolsConfig `json:"web"`
+ Cron CronToolsConfig `json:"cron"`
+ Exec ExecConfig `json:"exec"`
+ Skills SkillsToolsConfig `json:"skills"`
+}
+
+type SkillsToolsConfig struct {
+ Registries SkillsRegistriesConfig `json:"registries"`
+ MaxConcurrentSearches int `json:"max_concurrent_searches" env:"PICOCLAW_SKILLS_MAX_CONCURRENT_SEARCHES"`
+ SearchCache SearchCacheConfig `json:"search_cache"`
+}
+
+type SearchCacheConfig struct {
+ MaxSize int `json:"max_size" env:"PICOCLAW_SKILLS_SEARCH_CACHE_MAX_SIZE"`
+ TTLSeconds int `json:"ttl_seconds" env:"PICOCLAW_SKILLS_SEARCH_CACHE_TTL_SECONDS"`
+}
+
+type SkillsRegistriesConfig struct {
+ ClawHub ClawHubRegistryConfig `json:"clawhub"`
+}
+
+type ClawHubRegistryConfig struct {
+ Enabled bool `json:"enabled" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_ENABLED"`
+ BaseURL string `json:"base_url" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_BASE_URL"`
+ AuthToken string `json:"auth_token" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_AUTH_TOKEN"`
+ SearchPath string `json:"search_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SEARCH_PATH"`
+ SkillsPath string `json:"skills_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_SKILLS_PATH"`
+ DownloadPath string `json:"download_path" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_DOWNLOAD_PATH"`
+ Timeout int `json:"timeout" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_TIMEOUT"`
+ MaxZipSize int `json:"max_zip_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_ZIP_SIZE"`
+ MaxResponseSize int `json:"max_response_size" env:"PICOCLAW_SKILLS_REGISTRIES_CLAWHUB_MAX_RESPONSE_SIZE"`
}
func LoadConfig(path string) (*Config, error) {
@@ -439,6 +519,20 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
+ // Pre-scan the JSON to check how many model_list entries the user provided.
+ // Go's JSON decoder reuses existing slice backing-array elements rather than
+ // zero-initializing them, so fields absent from the user's JSON (e.g. api_base)
+ // would silently inherit values from the DefaultConfig template at the same
+ // index position. We only reset cfg.ModelList when the user actually provides
+ // entries; when count is 0 we keep DefaultConfig's built-in list as fallback.
+ var tmp Config
+ if err := json.Unmarshal(data, &tmp); err != nil {
+ return nil, err
+ }
+ if len(tmp.ModelList) > 0 {
+ cfg.ModelList = nil
+ }
+
if err := json.Unmarshal(data, cfg); err != nil {
return nil, err
}
@@ -467,11 +561,11 @@ func SaveConfig(path string, cfg *Config) error {
}
dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0755); err != nil {
+ if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
- return os.WriteFile(path, data, 0600)
+ return os.WriteFile(path, data, 0o600)
}
func (c *Config) WorkspacePath() string {
@@ -586,7 +680,8 @@ func (c *Config) HasProvidersConfig() bool {
v.VolcEngine.APIKey != "" || v.VolcEngine.APIBase != "" ||
v.GitHubCopilot.APIKey != "" || v.GitHubCopilot.APIBase != "" ||
v.Antigravity.APIKey != "" || v.Antigravity.APIBase != "" ||
- v.Qwen.APIKey != "" || v.Qwen.APIBase != ""
+ v.Qwen.APIKey != "" || v.Qwen.APIBase != "" ||
+ v.Mistral.APIKey != "" || v.Mistral.APIBase != ""
}
// ValidateModelList validates all ModelConfig entries in the model_list.
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index 7e706d8ce..223ac798d 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -55,7 +55,7 @@ func TestAgentModelConfig_MarshalObject(t *testing.T) {
if err != nil {
t.Fatalf("marshal: %v", err)
}
- var result map[string]interface{}
+ var result map[string]any
json.Unmarshal(data, &result)
if result["primary"] != "claude-opus" {
t.Errorf("primary = %v", result["primary"])
@@ -246,7 +246,7 @@ func TestDefaultConfig_Temperature(t *testing.T) {
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
- if cfg.Gateway.Host != "0.0.0.0" {
+ if cfg.Gateway.Host != "127.0.0.1" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
@@ -319,7 +319,7 @@ func TestSaveConfig_FilePermissions(t *testing.T) {
}
perm := info.Mode().Perm()
- if perm != 0600 {
+ if perm != 0o600 {
t.Errorf("config file has permission %04o, want 0600", perm)
}
}
@@ -343,7 +343,7 @@ func TestConfig_Complete(t *testing.T) {
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
- if cfg.Gateway.Host != "0.0.0.0" {
+ if cfg.Gateway.Host != "127.0.0.1" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
@@ -392,3 +392,24 @@ func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
t.Fatal("OpenAI codex web search should be false when disabled in config file")
}
}
+
+func TestLoadConfig_WebToolsProxy(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "config.json")
+ configJSON := `{
+ "agents": {"defaults":{"workspace":"./workspace","model":"gpt4","max_tokens":8192,"max_tool_iterations":20}},
+ "model_list": [{"model_name":"gpt4","model":"openai/gpt-5.2","api_key":"x"}],
+ "tools": {"web":{"proxy":"http://127.0.0.1:7890"}}
+}`
+ if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
+ t.Fatalf("os.WriteFile() error: %v", err)
+ }
+
+ cfg, err := LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error: %v", err)
+ }
+ if cfg.Tools.Web.Proxy != "http://127.0.0.1:7890" {
+ t.Fatalf("Tools.Web.Proxy = %q, want %q", cfg.Tools.Web.Proxy, "http://127.0.0.1:7890")
+ }
+}
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index a8c5bee58..47c51fc9d 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -43,9 +43,10 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
Discord: DiscordConfig{
- Enabled: false,
- Token: "",
- AllowFrom: FlexibleStringSlice{},
+ Enabled: false,
+ Token: "",
+ AllowFrom: FlexibleStringSlice{},
+ MentionOnly: false,
},
MaixCam: MaixCamConfig{
Enabled: false,
@@ -88,6 +89,30 @@ func DefaultConfig() *Config {
GroupTriggerPrefix: []string{},
AllowFrom: FlexibleStringSlice{},
},
+ WeCom: WeComConfig{
+ Enabled: false,
+ Token: "",
+ EncodingAESKey: "",
+ WebhookURL: "",
+ WebhookHost: "0.0.0.0",
+ WebhookPort: 18793,
+ WebhookPath: "/webhook/wecom",
+ AllowFrom: FlexibleStringSlice{},
+ ReplyTimeout: 5,
+ },
+ WeComApp: WeComAppConfig{
+ Enabled: false,
+ CorpID: "",
+ CorpSecret: "",
+ AgentID: 0,
+ Token: "",
+ EncodingAESKey: "",
+ WebhookHost: "0.0.0.0",
+ WebhookPort: 18792,
+ WebhookPath: "/webhook/wecom-app",
+ AllowFrom: FlexibleStringSlice{},
+ ReplyTimeout: 5,
+ },
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{WebSearch: true},
@@ -230,6 +255,14 @@ func DefaultConfig() *Config {
APIKey: "ollama",
},
+ // Mistral AI - https://console.mistral.ai/api-keys
+ {
+ ModelName: "mistral-small",
+ Model: "mistral/mistral-small-latest",
+ APIBase: "https://api.mistral.ai/v1",
+ APIKey: "",
+ },
+
// VLLM (local) - http://localhost:8000
{
ModelName: "local-model",
@@ -239,11 +272,12 @@ func DefaultConfig() *Config {
},
},
Gateway: GatewayConfig{
- Host: "0.0.0.0",
+ Host: "127.0.0.1",
Port: 18790,
},
Tools: ToolsConfig{
Web: WebToolsConfig{
+ Proxy: "",
Brave: BraveConfig{
Enabled: false,
APIKey: "",
@@ -270,6 +304,19 @@ func DefaultConfig() *Config {
Exec: ExecConfig{
EnableDenyPatterns: true,
},
+ Skills: SkillsToolsConfig{
+ Registries: SkillsRegistriesConfig{
+ ClawHub: ClawHubRegistryConfig{
+ Enabled: true,
+ BaseURL: "https://clawhub.ai",
+ },
+ },
+ MaxConcurrentSearches: 2,
+ SearchCache: SearchCacheConfig{
+ MaxSize: 50,
+ TTLSeconds: 300,
+ },
+ },
},
Heartbeat: HeartbeatConfig{
Enabled: true,
diff --git a/pkg/config/migration.go b/pkg/config/migration.go
index 689e2312f..70e1de438 100644
--- a/pkg/config/migration.go
+++ b/pkg/config/migration.go
@@ -41,7 +41,7 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
// Get user's configured provider and model
userProvider := strings.ToLower(cfg.Agents.Defaults.Provider)
- userModel := cfg.Agents.Defaults.Model
+ userModel := cfg.Agents.Defaults.GetModelName()
p := cfg.Providers
@@ -324,6 +324,22 @@ func ConvertProvidersToModelList(cfg *Config) []ModelConfig {
}, true
},
},
+ {
+ providerNames: []string{"mistral"},
+ protocol: "mistral",
+ buildConfig: func(p ProvidersConfig) (ModelConfig, bool) {
+ if p.Mistral.APIKey == "" && p.Mistral.APIBase == "" {
+ return ModelConfig{}, false
+ }
+ return ModelConfig{
+ ModelName: "mistral",
+ Model: "mistral/mistral-small-latest",
+ APIKey: p.Mistral.APIKey,
+ APIBase: p.Mistral.APIBase,
+ Proxy: p.Mistral.Proxy,
+ }, true
+ },
+ },
}
// Process each provider migration
diff --git a/pkg/config/migration_test.go b/pkg/config/migration_test.go
index b9a333f9e..42165cb71 100644
--- a/pkg/config/migration_test.go
+++ b/pkg/config/migration_test.go
@@ -131,14 +131,15 @@ func TestConvertProvidersToModelList_AllProviders(t *testing.T) {
GitHubCopilot: ProviderConfig{ConnectMode: "grpc"},
Antigravity: ProviderConfig{AuthMethod: "oauth"},
Qwen: ProviderConfig{APIKey: "key17"},
+ Mistral: ProviderConfig{APIKey: "key18"},
},
}
result := ConvertProvidersToModelList(cfg)
- // All 17 providers should be converted
- if len(result) != 17 {
- t.Errorf("len(result) = %d, want 17", len(result))
+ // All 18 providers should be converted
+ if len(result) != 18 {
+ t.Errorf("len(result) = %d, want 18", len(result))
}
}
@@ -361,7 +362,10 @@ func TestConvertProvidersToModelList_ProviderNameAliases(t *testing.T) {
Agents: AgentsConfig{
Defaults: AgentDefaults{
Provider: tt.providerAlias,
- Model: strings.TrimPrefix(tt.expectedModel, tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1]),
+ Model: strings.TrimPrefix(
+ tt.expectedModel,
+ tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1],
+ ),
},
},
Providers: ProvidersConfig{},
@@ -382,7 +386,10 @@ func TestConvertProvidersToModelList_ProviderNameAliases(t *testing.T) {
}
// Need to fix the model name in config
- cfg.Agents.Defaults.Model = strings.TrimPrefix(tt.expectedModel, tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1])
+ cfg.Agents.Defaults.Model = strings.TrimPrefix(
+ tt.expectedModel,
+ tt.expectedModel[:strings.Index(tt.expectedModel, "/")+1],
+ )
result := ConvertProvidersToModelList(cfg)
if len(result) != 1 {
@@ -515,7 +522,11 @@ func TestBuildModelWithProtocol_AlreadyHasPrefix(t *testing.T) {
func TestBuildModelWithProtocol_DifferentPrefix(t *testing.T) {
result := buildModelWithProtocol("anthropic", "openrouter/claude-sonnet-4.6")
if result != "openrouter/claude-sonnet-4.6" {
- t.Errorf("buildModelWithProtocol(anthropic, openrouter/claude-sonnet-4.6) = %q, want %q", result, "openrouter/claude-sonnet-4.6")
+ t.Errorf(
+ "buildModelWithProtocol(anthropic, openrouter/claude-sonnet-4.6) = %q, want %q",
+ result,
+ "openrouter/claude-sonnet-4.6",
+ )
}
}
diff --git a/pkg/config/model_config_test.go b/pkg/config/model_config_test.go
index 3c411dc0f..99eea2782 100644
--- a/pkg/config/model_config_test.go
+++ b/pkg/config/model_config_test.go
@@ -6,6 +6,7 @@
package config
import (
+ "encoding/json"
"strings"
"sync"
"testing"
@@ -114,6 +115,137 @@ func TestGetModelConfig_Concurrent(t *testing.T) {
}
}
+func TestAgentDefaults_GetModelName_BackwardCompat(t *testing.T) {
+ tests := []struct {
+ name string
+ defaults AgentDefaults
+ wantName string
+ }{
+ {
+ name: "new model_name field only",
+ defaults: AgentDefaults{ModelName: "new-model"},
+ wantName: "new-model",
+ },
+ {
+ name: "old model field only",
+ defaults: AgentDefaults{Model: "legacy-model"},
+ wantName: "legacy-model",
+ },
+ {
+ name: "both fields - model_name takes precedence",
+ defaults: AgentDefaults{ModelName: "new-model", Model: "old-model"},
+ wantName: "new-model",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.defaults.GetModelName(); got != tt.wantName {
+ t.Errorf("GetModelName() = %q, want %q", got, tt.wantName)
+ }
+ })
+ }
+}
+
+func TestAgentDefaults_JSON_BackwardCompat(t *testing.T) {
+ tests := []struct {
+ name string
+ json string
+ wantName string
+ }{
+ {
+ name: "new model_name field",
+ json: `{"model_name": "gpt4"}`,
+ wantName: "gpt4",
+ },
+ {
+ name: "old model field",
+ json: `{"model": "gpt4"}`,
+ wantName: "gpt4",
+ },
+ {
+ name: "both fields - model_name wins",
+ json: `{"model_name": "new", "model": "old"}`,
+ wantName: "new",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var defaults AgentDefaults
+ if err := json.Unmarshal([]byte(tt.json), &defaults); err != nil {
+ t.Fatalf("Unmarshal error: %v", err)
+ }
+ if got := defaults.GetModelName(); got != tt.wantName {
+ t.Errorf("GetModelName() = %q, want %q", got, tt.wantName)
+ }
+ })
+ }
+}
+
+func TestFullConfig_JSON_BackwardCompat(t *testing.T) {
+ // Test complete config with both old and new formats
+ oldFormat := `{
+ "agents": {
+ "defaults": {
+ "workspace": "~/.picoclaw/workspace",
+ "model": "gpt4",
+ "max_tokens": 4096
+ }
+ },
+ "model_list": [
+ {
+ "model_name": "gpt4",
+ "model": "openai/gpt-4o",
+ "api_key": "test-key"
+ }
+ ]
+ }`
+
+ newFormat := `{
+ "agents": {
+ "defaults": {
+ "workspace": "~/.picoclaw/workspace",
+ "model_name": "gpt4",
+ "max_tokens": 4096
+ }
+ },
+ "model_list": [
+ {
+ "model_name": "gpt4",
+ "model": "openai/gpt-4o",
+ "api_key": "test-key"
+ }
+ ]
+ }`
+
+ for name, jsonStr := range map[string]string{
+ "old format (model)": oldFormat,
+ "new format (model_name)": newFormat,
+ } {
+ t.Run(name, func(t *testing.T) {
+ cfg := &Config{}
+ if err := json.Unmarshal([]byte(jsonStr), cfg); err != nil {
+ t.Fatalf("Unmarshal error: %v", err)
+ }
+
+ // Check that GetModelName returns correct value
+ if got := cfg.Agents.Defaults.GetModelName(); got != "gpt4" {
+ t.Errorf("GetModelName() = %q, want %q", got, "gpt4")
+ }
+
+ // Check that GetModelConfig works
+ modelCfg, err := cfg.GetModelConfig("gpt4")
+ if err != nil {
+ t.Fatalf("GetModelConfig error: %v", err)
+ }
+ if modelCfg.Model != "openai/gpt-4o" {
+ t.Errorf("Model = %q, want %q", modelCfg.Model, "openai/gpt-4o")
+ }
+ })
+ }
+}
+
func TestModelConfig_Validate(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/cron/service.go b/pkg/cron/service.go
index 9f62c743b..e699a44b5 100644
--- a/pkg/cron/service.go
+++ b/pkg/cron/service.go
@@ -331,7 +331,7 @@ func (cs *CronService) loadStore() error {
func (cs *CronService) saveStoreUnsafe() error {
dir := filepath.Dir(cs.storePath)
- if err := os.MkdirAll(dir, 0755); err != nil {
+ if err := os.MkdirAll(dir, 0o755); err != nil {
return err
}
@@ -340,10 +340,16 @@ func (cs *CronService) saveStoreUnsafe() error {
return err
}
- return os.WriteFile(cs.storePath, data, 0600)
+ return os.WriteFile(cs.storePath, data, 0o600)
}
-func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
+func (cs *CronService) AddJob(
+ name string,
+ schedule CronSchedule,
+ message string,
+ deliver bool,
+ channel, to string,
+) (*CronJob, error) {
cs.mu.Lock()
defer cs.mu.Unlock()
@@ -465,7 +471,7 @@ func (cs *CronService) ListJobs(includeDisabled bool) []CronJob {
return enabled
}
-func (cs *CronService) Status() map[string]interface{} {
+func (cs *CronService) Status() map[string]any {
cs.mu.RLock()
defer cs.mu.RUnlock()
@@ -476,7 +482,7 @@ func (cs *CronService) Status() map[string]interface{} {
}
}
- return map[string]interface{}{
+ return map[string]any{
"enabled": cs.running,
"jobs": len(cs.store.Jobs),
"nextWakeAtMS": cs.getNextWakeMS(),
diff --git a/pkg/cron/service_test.go b/pkg/cron/service_test.go
index 53d69f6a9..1a0dd1829 100644
--- a/pkg/cron/service_test.go
+++ b/pkg/cron/service_test.go
@@ -28,7 +28,7 @@ func TestSaveStore_FilePermissions(t *testing.T) {
}
perm := info.Mode().Perm()
- if perm != 0600 {
+ if perm != 0o600 {
t.Errorf("cron store has permission %04o, want 0600", perm)
}
}
diff --git a/pkg/devices/service.go b/pkg/devices/service.go
index 05a254729..1541d3c57 100644
--- a/pkg/devices/service.go
+++ b/pkg/devices/service.go
@@ -63,14 +63,14 @@ func (s *Service) Start(ctx context.Context) error {
for _, src := range s.sources {
eventCh, err := src.Start(s.ctx)
if err != nil {
- logger.ErrorCF("devices", "Failed to start source", map[string]interface{}{
+ logger.ErrorCF("devices", "Failed to start source", map[string]any{
"kind": src.Kind(),
"error": err.Error(),
})
continue
}
go s.handleEvents(src.Kind(), eventCh)
- logger.InfoCF("devices", "Device source started", map[string]interface{}{
+ logger.InfoCF("devices", "Device source started", map[string]any{
"kind": src.Kind(),
})
}
@@ -115,7 +115,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
lastChannel := s.state.GetLastChannel()
if lastChannel == "" {
- logger.DebugCF("devices", "No last channel, skipping notification", map[string]interface{}{
+ logger.DebugCF("devices", "No last channel, skipping notification", map[string]any{
"event": ev.FormatMessage(),
})
return
@@ -133,7 +133,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
Content: msg,
})
- logger.InfoCF("devices", "Device notification sent", map[string]interface{}{
+ logger.InfoCF("devices", "Device notification sent", map[string]any{
"kind": ev.Kind,
"action": ev.Action,
"to": platform,
diff --git a/pkg/devices/sources/usb_linux.go b/pkg/devices/sources/usb_linux.go
index 1f6c068b3..2bb38941f 100644
--- a/pkg/devices/sources/usb_linux.go
+++ b/pkg/devices/sources/usb_linux.go
@@ -35,9 +35,8 @@ var usbClassToCapability = map[string]string{
}
type USBMonitor struct {
- cmd *exec.Cmd
- cancel context.CancelFunc
- mu sync.Mutex
+ cmd *exec.Cmd
+ mu sync.Mutex
}
func NewUSBMonitor() *USBMonitor {
@@ -115,7 +114,7 @@ func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, err
}
if err := scanner.Err(); err != nil {
- logger.ErrorCF("devices", "udevadm scan error", map[string]interface{}{"error": err.Error()})
+ logger.ErrorCF("devices", "udevadm scan error", map[string]any{"error": err.Error()})
}
cmd.Wait()
}()
diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go
index dfdaef58b..e05a9fdbf 100644
--- a/pkg/heartbeat/service.go
+++ b/pkg/heartbeat/service.go
@@ -166,7 +166,7 @@ func (hs *HeartbeatService) executeHeartbeat() {
}
if handler == nil {
- hs.logError("Heartbeat handler not configured")
+ hs.logErrorf("Heartbeat handler not configured")
return
}
@@ -175,25 +175,25 @@ func (hs *HeartbeatService) executeHeartbeat() {
channel, chatID := hs.parseLastChannel(lastChannel)
// Debug log for channel resolution
- hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
+ hs.logInfof("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
result := handler(prompt, channel, chatID)
if result == nil {
- hs.logInfo("Heartbeat handler returned nil result")
+ hs.logInfof("Heartbeat handler returned nil result")
return
}
// Handle different result types
if result.IsError {
- hs.logError("Heartbeat error: %s", result.ForLLM)
+ hs.logErrorf("Heartbeat error: %s", result.ForLLM)
return
}
if result.Async {
- hs.logInfo("Async task started: %s", result.ForLLM)
+ hs.logInfof("Async task started: %s", result.ForLLM)
logger.InfoCF("heartbeat", "Async heartbeat task started",
- map[string]interface{}{
+ map[string]any{
"message": result.ForLLM,
})
return
@@ -201,7 +201,7 @@ func (hs *HeartbeatService) executeHeartbeat() {
// Check if silent
if result.Silent {
- hs.logInfo("Heartbeat OK - silent")
+ hs.logInfof("Heartbeat OK - silent")
return
}
@@ -212,7 +212,7 @@ func (hs *HeartbeatService) executeHeartbeat() {
hs.sendResponse(result.ForLLM)
}
- hs.logInfo("Heartbeat completed: %s", result.ForLLM)
+ hs.logInfof("Heartbeat completed: %s", result.ForLLM)
}
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
@@ -225,7 +225,7 @@ func (hs *HeartbeatService) buildPrompt() string {
hs.createDefaultHeartbeatTemplate()
return ""
}
- hs.logError("Error reading HEARTBEAT.md: %v", err)
+ hs.logErrorf("Error reading HEARTBEAT.md: %v", err)
return ""
}
@@ -275,10 +275,10 @@ This file contains tasks for the heartbeat service to check periodically.
Add your heartbeat tasks below this line:
`
- if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
- hs.logError("Failed to create default HEARTBEAT.md: %v", err)
+ if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0o644); err != nil {
+ hs.logErrorf("Failed to create default HEARTBEAT.md: %v", err)
} else {
- hs.logInfo("Created default HEARTBEAT.md template")
+ hs.logInfof("Created default HEARTBEAT.md template")
}
}
@@ -289,14 +289,14 @@ func (hs *HeartbeatService) sendResponse(response string) {
hs.mu.RUnlock()
if msgBus == nil {
- hs.logInfo("No message bus configured, heartbeat result not sent")
+ hs.logInfof("No message bus configured, heartbeat result not sent")
return
}
// Get last channel from state
lastChannel := hs.state.GetLastChannel()
if lastChannel == "" {
- hs.logInfo("No last channel recorded, heartbeat result not sent")
+ hs.logInfof("No last channel recorded, heartbeat result not sent")
return
}
@@ -313,7 +313,7 @@ func (hs *HeartbeatService) sendResponse(response string) {
Content: response,
})
- hs.logInfo("Heartbeat result sent to %s", platform)
+ hs.logInfof("Heartbeat result sent to %s", platform)
}
// parseLastChannel parses the last channel string into platform and userID.
@@ -326,7 +326,7 @@ func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, user
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
parts := strings.SplitN(lastChannel, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
- hs.logError("Invalid last channel format: %s", lastChannel)
+ hs.logErrorf("Invalid last channel format: %s", lastChannel)
return "", ""
}
@@ -334,27 +334,27 @@ func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, user
// Skip internal channels
if constants.IsInternalChannel(platform) {
- hs.logInfo("Skipping internal channel: %s", platform)
+ hs.logInfof("Skipping internal channel: %s", platform)
return "", ""
}
return platform, userID
}
-// logInfo logs an informational message to the heartbeat log
-func (hs *HeartbeatService) logInfo(format string, args ...any) {
- hs.log("INFO", format, args...)
+// logInfof logs an informational message to the heartbeat log
+func (hs *HeartbeatService) logInfof(format string, args ...any) {
+ hs.logf("INFO", format, args...)
}
-// logError logs an error message to the heartbeat log
-func (hs *HeartbeatService) logError(format string, args ...any) {
- hs.log("ERROR", format, args...)
+// logErrorf logs an error message to the heartbeat log
+func (hs *HeartbeatService) logErrorf(format string, args ...any) {
+ hs.logf("ERROR", format, args...)
}
-// log writes a message to the heartbeat log file
-func (hs *HeartbeatService) log(level, format string, args ...any) {
+// logf writes a message to the heartbeat log file
+func (hs *HeartbeatService) logf(level, format string, args ...any) {
logFile := filepath.Join(hs.workspace, "heartbeat.log")
- f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
+ f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return
}
diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go
index a2b59e350..a7aef8c3a 100644
--- a/pkg/heartbeat/service_test.go
+++ b/pkg/heartbeat/service_test.go
@@ -37,7 +37,7 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
})
// Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
+ os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
// Execute heartbeat directly (internal method for testing)
hs.executeHeartbeat()
@@ -68,7 +68,7 @@ func TestExecuteHeartbeat_Error(t *testing.T) {
})
// Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
+ os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
@@ -106,7 +106,7 @@ func TestExecuteHeartbeat_Silent(t *testing.T) {
})
// Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
+ os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
hs.executeHeartbeat()
@@ -174,7 +174,7 @@ func TestExecuteHeartbeat_NilResult(t *testing.T) {
})
// Create HEARTBEAT.md
- os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
+ os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0o644)
// Should not panic with nil result
hs.executeHeartbeat()
@@ -191,7 +191,7 @@ func TestLogPath(t *testing.T) {
hs := NewHeartbeatService(tmpDir, 30, true)
// Write a log entry
- hs.log("INFO", "Test log entry")
+ hs.logf("INFO", "Test log entry")
// Verify log file exists at workspace root
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go
index 22f66829f..56dc87a53 100644
--- a/pkg/logger/logger.go
+++ b/pkg/logger/logger.go
@@ -41,12 +41,12 @@ type Logger struct {
}
type LogEntry struct {
- Level string `json:"level"`
- Timestamp string `json:"timestamp"`
- Component string `json:"component,omitempty"`
- Message string `json:"message"`
- Fields map[string]interface{} `json:"fields,omitempty"`
- Caller string `json:"caller,omitempty"`
+ Level string `json:"level"`
+ Timestamp string `json:"timestamp"`
+ Component string `json:"component,omitempty"`
+ Message string `json:"message"`
+ Fields map[string]any `json:"fields,omitempty"`
+ Caller string `json:"caller,omitempty"`
}
func init() {
@@ -71,7 +71,7 @@ func EnableFileLogging(filePath string) error {
mu.Lock()
defer mu.Unlock()
- file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
+ file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
if err != nil {
return fmt.Errorf("failed to open log file: %w", err)
}
@@ -96,7 +96,7 @@ func DisableFileLogging() {
}
}
-func logMessage(level LogLevel, component string, message string, fields map[string]interface{}) {
+func logMessage(level LogLevel, component string, message string, fields map[string]any) {
if level < currentLevel {
return
}
@@ -119,13 +119,15 @@ func logMessage(level LogLevel, component string, message string, fields map[str
if logger.file != nil {
jsonData, err := json.Marshal(entry)
if err == nil {
- logger.file.WriteString(string(jsonData) + "\n")
+ logger.file.Write(append(jsonData, '\n'))
}
}
var fieldStr string
if len(fields) > 0 {
fieldStr = " " + formatFields(fields)
+ } else {
+ fieldStr = ""
}
logLine := fmt.Sprintf("[%s] [%s]%s %s%s",
@@ -150,8 +152,8 @@ func formatComponent(component string) string {
return fmt.Sprintf(" %s:", component)
}
-func formatFields(fields map[string]interface{}) string {
- var parts []string
+func formatFields(fields map[string]any) string {
+ parts := make([]string, 0, len(fields))
for k, v := range fields {
parts = append(parts, fmt.Sprintf("%s=%v", k, v))
}
@@ -166,11 +168,11 @@ func DebugC(component string, message string) {
logMessage(DEBUG, component, message, nil)
}
-func DebugF(message string, fields map[string]interface{}) {
+func DebugF(message string, fields map[string]any) {
logMessage(DEBUG, "", message, fields)
}
-func DebugCF(component string, message string, fields map[string]interface{}) {
+func DebugCF(component string, message string, fields map[string]any) {
logMessage(DEBUG, component, message, fields)
}
@@ -182,11 +184,11 @@ func InfoC(component string, message string) {
logMessage(INFO, component, message, nil)
}
-func InfoF(message string, fields map[string]interface{}) {
+func InfoF(message string, fields map[string]any) {
logMessage(INFO, "", message, fields)
}
-func InfoCF(component string, message string, fields map[string]interface{}) {
+func InfoCF(component string, message string, fields map[string]any) {
logMessage(INFO, component, message, fields)
}
@@ -198,11 +200,11 @@ func WarnC(component string, message string) {
logMessage(WARN, component, message, nil)
}
-func WarnF(message string, fields map[string]interface{}) {
+func WarnF(message string, fields map[string]any) {
logMessage(WARN, "", message, fields)
}
-func WarnCF(component string, message string, fields map[string]interface{}) {
+func WarnCF(component string, message string, fields map[string]any) {
logMessage(WARN, component, message, fields)
}
@@ -214,11 +216,11 @@ func ErrorC(component string, message string) {
logMessage(ERROR, component, message, nil)
}
-func ErrorF(message string, fields map[string]interface{}) {
+func ErrorF(message string, fields map[string]any) {
logMessage(ERROR, "", message, fields)
}
-func ErrorCF(component string, message string, fields map[string]interface{}) {
+func ErrorCF(component string, message string, fields map[string]any) {
logMessage(ERROR, component, message, fields)
}
@@ -230,10 +232,10 @@ func FatalC(component string, message string) {
logMessage(FATAL, component, message, nil)
}
-func FatalF(message string, fields map[string]interface{}) {
+func FatalF(message string, fields map[string]any) {
logMessage(FATAL, "", message, fields)
}
-func FatalCF(component string, message string, fields map[string]interface{}) {
+func FatalCF(component string, message string, fields map[string]any) {
logMessage(FATAL, component, message, fields)
}
diff --git a/pkg/logger/logger_test.go b/pkg/logger/logger_test.go
index 9b9c96820..6e6f8dfa8 100644
--- a/pkg/logger/logger_test.go
+++ b/pkg/logger/logger_test.go
@@ -54,11 +54,11 @@ func TestLoggerWithComponent(t *testing.T) {
name string
component string
message string
- fields map[string]interface{}
+ fields map[string]any
}{
{"Simple message", "test", "Hello, world!", nil},
{"Message with component", "discord", "Discord message", nil},
- {"Message with fields", "telegram", "Telegram message", map[string]interface{}{
+ {"Message with fields", "telegram", "Telegram message", map[string]any{
"user_id": "12345",
"count": 42,
}},
@@ -128,12 +128,12 @@ func TestLoggerHelperFunctions(t *testing.T) {
Error("This should log")
InfoC("test", "Component message")
- InfoF("Fields message", map[string]interface{}{"key": "value"})
+ InfoF("Fields message", map[string]any{"key": "value"})
WarnC("test", "Warning with component")
- ErrorF("Error with fields", map[string]interface{}{"error": "test"})
+ ErrorF("Error with fields", map[string]any{"error": "test"})
SetLevel(DEBUG)
DebugC("test", "Debug with component")
- WarnF("Warning with fields", map[string]interface{}{"key": "value"})
+ WarnF("Warning with fields", map[string]any{"key": "value"})
}
diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go
index b01bb80e3..869b39827 100644
--- a/pkg/migrate/config.go
+++ b/pkg/migrate/config.go
@@ -22,6 +22,7 @@ var supportedProviders = map[string]bool{
"qwen": true,
"deepseek": true,
"github_copilot": true,
+ "mistral": true,
}
var supportedChannels = map[string]bool{
@@ -47,32 +48,35 @@ func findOpenClawConfig(openclawHome string) (string, error) {
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
-func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
+func LoadOpenClawConfig(configPath string) (map[string]any, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
}
- var raw map[string]interface{}
+ var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
}
converted := convertKeysToSnake(raw)
- result, ok := converted.(map[string]interface{})
+ result, ok := converted.(map[string]any)
if !ok {
return nil, fmt.Errorf("unexpected config format")
}
return result, nil
}
-func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
+func ConvertConfig(data map[string]any) (*config.Config, []string, error) {
cfg := config.DefaultConfig()
var warnings []string
if agents, ok := getMap(data, "agents"); ok {
if defaults, ok := getMap(agents, "defaults"); ok {
- if v, ok := getString(defaults, "model"); ok {
+ // Prefer model_name, fallback to model for backward compatibility
+ if v, ok := getString(defaults, "model_name"); ok {
+ cfg.Agents.Defaults.ModelName = v
+ } else if v, ok := getString(defaults, "model"); ok {
cfg.Agents.Defaults.Model = v
}
if v, ok := getFloat(defaults, "max_tokens"); ok {
@@ -92,7 +96,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
if providers, ok := getMap(data, "providers"); ok {
for name, val := range providers {
- pMap, ok := val.(map[string]interface{})
+ pMap, ok := val.(map[string]any)
if !ok {
continue
}
@@ -131,7 +135,7 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
if channels, ok := getMap(data, "channels"); ok {
for name, val := range channels {
- cMap, ok := val.(map[string]interface{})
+ cMap, ok := val.(map[string]any)
if !ok {
continue
}
@@ -318,16 +322,16 @@ func camelToSnake(s string) string {
return result.String()
}
-func convertKeysToSnake(data interface{}) interface{} {
+func convertKeysToSnake(data any) any {
switch v := data.(type) {
- case map[string]interface{}:
- result := make(map[string]interface{}, len(v))
+ case map[string]any:
+ result := make(map[string]any, len(v))
for key, val := range v {
result[camelToSnake(key)] = convertKeysToSnake(val)
}
return result
- case []interface{}:
- result := make([]interface{}, len(v))
+ case []any:
+ result := make([]any, len(v))
for i, val := range v {
result[i] = convertKeysToSnake(val)
}
@@ -342,16 +346,16 @@ func rewriteWorkspacePath(path string) string {
return path
}
-func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
+func getMap(data map[string]any, key string) (map[string]any, bool) {
v, ok := data[key]
if !ok {
return nil, false
}
- m, ok := v.(map[string]interface{})
+ m, ok := v.(map[string]any)
return m, ok
}
-func getString(data map[string]interface{}, key string) (string, bool) {
+func getString(data map[string]any, key string) (string, bool) {
v, ok := data[key]
if !ok {
return "", false
@@ -360,7 +364,7 @@ func getString(data map[string]interface{}, key string) (string, bool) {
return s, ok
}
-func getFloat(data map[string]interface{}, key string) (float64, bool) {
+func getFloat(data map[string]any, key string) (float64, bool) {
v, ok := data[key]
if !ok {
return 0, false
@@ -369,7 +373,7 @@ func getFloat(data map[string]interface{}, key string) (float64, bool) {
return f, ok
}
-func getBool(data map[string]interface{}, key string) (bool, bool) {
+func getBool(data map[string]any, key string) (bool, bool) {
v, ok := data[key]
if !ok {
return false, false
@@ -378,19 +382,19 @@ func getBool(data map[string]interface{}, key string) (bool, bool) {
return b, ok
}
-func getBoolOrDefault(data map[string]interface{}, key string, defaultVal bool) bool {
+func getBoolOrDefault(data map[string]any, key string, defaultVal bool) bool {
if v, ok := getBool(data, key); ok {
return v
}
return defaultVal
}
-func getStringSlice(data map[string]interface{}, key string) []string {
+func getStringSlice(data map[string]any, key string) []string {
v, ok := data[key]
if !ok {
return []string{}
}
- arr, ok := v.([]interface{})
+ arr, ok := v.([]any)
if !ok {
return []string{}
}
diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go
index 921f821cb..cfa82b7d7 100644
--- a/pkg/migrate/migrate.go
+++ b/pkg/migrate/migrate.go
@@ -67,7 +67,7 @@ func Run(opts Options) (*Result, error) {
return nil, err
}
- if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
+ if _, err = os.Stat(openclawHome); os.IsNotExist(err) {
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
}
@@ -161,7 +161,7 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
}
case ActionCreateDir:
- if err := os.MkdirAll(action.Destination, 0755); err != nil {
+ if err := os.MkdirAll(action.Destination, 0o755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
@@ -174,9 +174,13 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
continue
}
result.BackupsCreated++
- fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
+ fmt.Printf(
+ " ✓ Backed up %s -> %s.bak\n",
+ filepath.Base(action.Destination),
+ filepath.Base(action.Destination),
+ )
- if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
+ if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
@@ -188,7 +192,7 @@ func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionCopy:
- if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
+ if err := os.MkdirAll(filepath.Dir(action.Destination), 0o755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
@@ -226,7 +230,7 @@ func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) e
incoming = MergeConfig(existing, incoming)
}
- if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
+ if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0o755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go
index 759fc9024..b6b3d70aa 100644
--- a/pkg/migrate/migrate_test.go
+++ b/pkg/migrate/migrate_test.go
@@ -40,43 +40,43 @@ func TestCamelToSnake(t *testing.T) {
}
func TestConvertKeysToSnake(t *testing.T) {
- input := map[string]interface{}{
+ input := map[string]any{
"apiKey": "test-key",
"apiBase": "https://example.com",
- "nested": map[string]interface{}{
+ "nested": map[string]any{
"maxTokens": float64(8192),
- "allowFrom": []interface{}{"user1", "user2"},
- "deeperLevel": map[string]interface{}{
+ "allowFrom": []any{"user1", "user2"},
+ "deeperLevel": map[string]any{
"clientId": "abc",
},
},
}
result := convertKeysToSnake(input)
- m, ok := result.(map[string]interface{})
+ m, ok := result.(map[string]any)
if !ok {
t.Fatal("expected map[string]interface{}")
}
- if _, ok := m["api_key"]; !ok {
+ if _, ok = m["api_key"]; !ok {
t.Error("expected key 'api_key' after conversion")
}
- if _, ok := m["api_base"]; !ok {
+ if _, ok = m["api_base"]; !ok {
t.Error("expected key 'api_base' after conversion")
}
- nested, ok := m["nested"].(map[string]interface{})
+ nested, ok := m["nested"].(map[string]any)
if !ok {
t.Fatal("expected nested map")
}
- if _, ok := nested["max_tokens"]; !ok {
+ if _, ok = nested["max_tokens"]; !ok {
t.Error("expected key 'max_tokens' in nested map")
}
- if _, ok := nested["allow_from"]; !ok {
+ if _, ok = nested["allow_from"]; !ok {
t.Error("expected key 'allow_from' in nested map")
}
- deeper, ok := nested["deeper_level"].(map[string]interface{})
+ deeper, ok := nested["deeper_level"].(map[string]any)
if !ok {
t.Fatal("expected deeper_level map")
}
@@ -89,15 +89,15 @@ func TestLoadOpenClawConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
- openclawConfig := map[string]interface{}{
- "providers": map[string]interface{}{
- "anthropic": map[string]interface{}{
+ openclawConfig := map[string]any{
+ "providers": map[string]any{
+ "anthropic": map[string]any{
"apiKey": "sk-ant-test123",
"apiBase": "https://api.anthropic.com",
},
},
- "agents": map[string]interface{}{
- "defaults": map[string]interface{}{
+ "agents": map[string]any{
+ "defaults": map[string]any{
"maxTokens": float64(4096),
"model": "claude-3-opus",
},
@@ -108,7 +108,7 @@ func TestLoadOpenClawConfig(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if err := os.WriteFile(configPath, data, 0644); err != nil {
+ if err = os.WriteFile(configPath, data, 0o644); err != nil {
t.Fatal(err)
}
@@ -117,11 +117,11 @@ func TestLoadOpenClawConfig(t *testing.T) {
t.Fatalf("LoadOpenClawConfig: %v", err)
}
- providers, ok := result["providers"].(map[string]interface{})
+ providers, ok := result["providers"].(map[string]any)
if !ok {
t.Fatal("expected providers map")
}
- anthropic, ok := providers["anthropic"].(map[string]interface{})
+ anthropic, ok := providers["anthropic"].(map[string]any)
if !ok {
t.Fatal("expected anthropic map")
}
@@ -129,11 +129,11 @@ func TestLoadOpenClawConfig(t *testing.T) {
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
}
- agents, ok := result["agents"].(map[string]interface{})
+ agents, ok := result["agents"].(map[string]any)
if !ok {
t.Fatal("expected agents map")
}
- defaults, ok := agents["defaults"].(map[string]interface{})
+ defaults, ok := agents["defaults"].(map[string]any)
if !ok {
t.Fatal("expected defaults map")
}
@@ -144,16 +144,16 @@ func TestLoadOpenClawConfig(t *testing.T) {
func TestConvertConfig(t *testing.T) {
t.Run("providers mapping", func(t *testing.T) {
- data := map[string]interface{}{
- "providers": map[string]interface{}{
- "anthropic": map[string]interface{}{
+ data := map[string]any{
+ "providers": map[string]any{
+ "anthropic": map[string]any{
"api_key": "sk-ant-test",
"api_base": "https://api.anthropic.com",
},
- "openrouter": map[string]interface{}{
+ "openrouter": map[string]any{
"api_key": "sk-or-test",
},
- "groq": map[string]interface{}{
+ "groq": map[string]any{
"api_key": "gsk-test",
},
},
@@ -178,9 +178,9 @@ func TestConvertConfig(t *testing.T) {
})
t.Run("unsupported provider warning", func(t *testing.T) {
- data := map[string]interface{}{
- "providers": map[string]interface{}{
- "unknown_provider": map[string]interface{}{
+ data := map[string]any{
+ "providers": map[string]any{
+ "unknown_provider": map[string]any{
"api_key": "sk-test",
},
},
@@ -199,14 +199,14 @@ func TestConvertConfig(t *testing.T) {
})
t.Run("channels mapping", func(t *testing.T) {
- data := map[string]interface{}{
- "channels": map[string]interface{}{
- "telegram": map[string]interface{}{
+ data := map[string]any{
+ "channels": map[string]any{
+ "telegram": map[string]any{
"enabled": true,
"token": "tg-token-123",
- "allow_from": []interface{}{"user1"},
+ "allow_from": []any{"user1"},
},
- "discord": map[string]interface{}{
+ "discord": map[string]any{
"enabled": true,
"token": "disc-token-456",
},
@@ -232,9 +232,9 @@ func TestConvertConfig(t *testing.T) {
})
t.Run("unsupported channel warning", func(t *testing.T) {
- data := map[string]interface{}{
- "channels": map[string]interface{}{
- "email": map[string]interface{}{
+ data := map[string]any{
+ "channels": map[string]any{
+ "email": map[string]any{
"enabled": true,
},
},
@@ -253,9 +253,9 @@ func TestConvertConfig(t *testing.T) {
})
t.Run("agent defaults", func(t *testing.T) {
- data := map[string]interface{}{
- "agents": map[string]interface{}{
- "defaults": map[string]interface{}{
+ data := map[string]any{
+ "agents": map[string]any{
+ "defaults": map[string]any{
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
@@ -287,7 +287,7 @@ func TestConvertConfig(t *testing.T) {
})
t.Run("empty config", func(t *testing.T) {
- data := map[string]interface{}{}
+ data := map[string]any{}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
@@ -389,9 +389,9 @@ func TestPlanWorkspaceMigration(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
- os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
- os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
- os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
+ os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0o644)
+ os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0o644)
+ os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0o644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
@@ -420,8 +420,8 @@ func TestPlanWorkspaceMigration(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
- os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
- os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
+ os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0o644)
+ os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0o644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
@@ -443,8 +443,8 @@ func TestPlanWorkspaceMigration(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
- os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
- os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
+ os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0o644)
+ os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0o644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
if err != nil {
@@ -463,8 +463,8 @@ func TestPlanWorkspaceMigration(t *testing.T) {
dstDir := t.TempDir()
memDir := filepath.Join(srcDir, "memory")
- os.MkdirAll(memDir, 0755)
- os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
+ os.MkdirAll(memDir, 0o755)
+ os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0o644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
@@ -494,8 +494,8 @@ func TestPlanWorkspaceMigration(t *testing.T) {
dstDir := t.TempDir()
skillDir := filepath.Join(srcDir, "skills", "weather")
- os.MkdirAll(skillDir, 0755)
- os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
+ os.MkdirAll(skillDir, 0o755)
+ os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0o644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
@@ -518,7 +518,7 @@ func TestFindOpenClawConfig(t *testing.T) {
t.Run("finds openclaw.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
- os.WriteFile(configPath, []byte("{}"), 0644)
+ os.WriteFile(configPath, []byte("{}"), 0o644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
@@ -532,7 +532,7 @@ func TestFindOpenClawConfig(t *testing.T) {
t.Run("falls back to config.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
- os.WriteFile(configPath, []byte("{}"), 0644)
+ os.WriteFile(configPath, []byte("{}"), 0o644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
@@ -546,8 +546,8 @@ func TestFindOpenClawConfig(t *testing.T) {
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
tmpDir := t.TempDir()
openclawPath := filepath.Join(tmpDir, "openclaw.json")
- os.WriteFile(openclawPath, []byte("{}"), 0644)
- os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
+ os.WriteFile(openclawPath, []byte("{}"), 0o644)
+ os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0o644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
@@ -593,19 +593,19 @@ func TestRunDryRun(t *testing.T) {
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
- os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
+ os.MkdirAll(wsDir, 0o755)
+ os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644)
+ os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0o644)
- configData := map[string]interface{}{
- "providers": map[string]interface{}{
- "anthropic": map[string]interface{}{
+ configData := map[string]any{
+ "providers": map[string]any{
+ "anthropic": map[string]any{
"apiKey": "test-key",
},
},
}
data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
+ os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
opts := Options{
DryRun: true,
@@ -634,33 +634,33 @@ func TestRunFullMigration(t *testing.T) {
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
- os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
- os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
+ os.MkdirAll(wsDir, 0o755)
+ os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0o644)
+ os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0o644)
+ os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0o644)
memDir := filepath.Join(wsDir, "memory")
- os.MkdirAll(memDir, 0755)
- os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
+ os.MkdirAll(memDir, 0o755)
+ os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0o644)
- configData := map[string]interface{}{
- "providers": map[string]interface{}{
- "anthropic": map[string]interface{}{
+ configData := map[string]any{
+ "providers": map[string]any{
+ "anthropic": map[string]any{
"apiKey": "sk-ant-migrate-test",
},
- "openrouter": map[string]interface{}{
+ "openrouter": map[string]any{
"apiKey": "sk-or-migrate-test",
},
},
- "channels": map[string]interface{}{
- "telegram": map[string]interface{}{
+ "channels": map[string]any{
+ "telegram": map[string]any{
"enabled": true,
"token": "tg-migrate-test",
},
},
}
data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
+ os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
opts := Options{
Force: true,
@@ -754,7 +754,7 @@ func TestRunMutuallyExclusiveFlags(t *testing.T) {
func TestBackupFile(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.md")
- os.WriteFile(filePath, []byte("original content"), 0644)
+ os.WriteFile(filePath, []byte("original content"), 0o644)
if err := backupFile(filePath); err != nil {
t.Fatalf("backupFile: %v", err)
@@ -775,7 +775,7 @@ func TestCopyFile(t *testing.T) {
srcPath := filepath.Join(tmpDir, "src.md")
dstPath := filepath.Join(tmpDir, "dst.md")
- os.WriteFile(srcPath, []byte("file content"), 0644)
+ os.WriteFile(srcPath, []byte("file content"), 0o644)
if err := copyFile(srcPath, dstPath); err != nil {
t.Fatalf("copyFile: %v", err)
@@ -795,18 +795,18 @@ func TestRunConfigOnly(t *testing.T) {
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
+ os.MkdirAll(wsDir, 0o755)
+ os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644)
- configData := map[string]interface{}{
- "providers": map[string]interface{}{
- "anthropic": map[string]interface{}{
+ configData := map[string]any{
+ "providers": map[string]any{
+ "anthropic": map[string]any{
"apiKey": "sk-config-only",
},
},
}
data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
+ os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
opts := Options{
Force: true,
@@ -835,18 +835,18 @@ func TestRunWorkspaceOnly(t *testing.T) {
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
- os.MkdirAll(wsDir, 0755)
- os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
+ os.MkdirAll(wsDir, 0o755)
+ os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0o644)
- configData := map[string]interface{}{
- "providers": map[string]interface{}{
- "anthropic": map[string]interface{}{
+ configData := map[string]any{
+ "providers": map[string]any{
+ "anthropic": map[string]any{
"apiKey": "sk-ws-only",
},
},
}
data, _ := json.Marshal(configData)
- os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
+ os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0o644)
opts := Options{
Force: true,
diff --git a/pkg/providers/anthropic/provider.go b/pkg/providers/anthropic/provider.go
index a27a25a2d..9162174c9 100644
--- a/pkg/providers/anthropic/provider.go
+++ b/pkg/providers/anthropic/provider.go
@@ -9,16 +9,19 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
+
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
-type ToolCall = protocoltypes.ToolCall
-type FunctionCall = protocoltypes.FunctionCall
-type LLMResponse = protocoltypes.LLMResponse
-type UsageInfo = protocoltypes.UsageInfo
-type Message = protocoltypes.Message
-type ToolDefinition = protocoltypes.ToolDefinition
-type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
+type (
+ ToolCall = protocoltypes.ToolCall
+ FunctionCall = protocoltypes.FunctionCall
+ LLMResponse = protocoltypes.LLMResponse
+ UsageInfo = protocoltypes.UsageInfo
+ Message = protocoltypes.Message
+ ToolDefinition = protocoltypes.ToolDefinition
+ ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
+)
const defaultBaseURL = "https://api.anthropic.com"
@@ -61,7 +64,13 @@ func NewProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (stri
return p
}
-func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *Provider) Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
@@ -92,14 +101,32 @@ func (p *Provider) BaseURL() string {
return p.baseURL
}
-func buildParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
+func buildParams(
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
- system = append(system, anthropic.TextBlockParam{Text: msg.Content})
+ // Prefer structured SystemParts for per-block cache_control.
+ // This enables LLM-side KV cache reuse: the static block's prefix
+ // hash stays stable across requests while dynamic parts change freely.
+ if len(msg.SystemParts) > 0 {
+ for _, part := range msg.SystemParts {
+ block := anthropic.TextBlockParam{Text: part.Text}
+ if part.CacheControl != nil && part.CacheControl.Type == "ephemeral" {
+ block.CacheControl = anthropic.NewCacheControlEphemeralParam()
+ }
+ system = append(system, block)
+ }
+ } else {
+ system = append(system, anthropic.TextBlockParam{Text: msg.Content})
+ }
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
@@ -170,7 +197,7 @@ func translateTools(tools []ToolDefinition) []anthropic.ToolUnionParam {
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
- if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
+ if req, ok := t.Function.Parameters["required"].([]any); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
@@ -195,10 +222,10 @@ func parseResponse(resp *anthropic.Message) *LLMResponse {
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
- var args map[string]interface{}
+ var args map[string]any
if err := json.Unmarshal(tu.Input, &args); err != nil {
log.Printf("anthropic: failed to decode tool call input for %q: %v", tu.Name, err)
- args = map[string]interface{}{"raw": string(tu.Input)}
+ args = map[string]any{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
diff --git a/pkg/providers/anthropic/provider_test.go b/pkg/providers/anthropic/provider_test.go
index 08ac9c829..3d21c1d0b 100644
--- a/pkg/providers/anthropic/provider_test.go
+++ b/pkg/providers/anthropic/provider_test.go
@@ -15,7 +15,7 @@ func TestBuildParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
- params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]interface{}{
+ params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]any{
"max_tokens": 1024,
})
if err != nil {
@@ -37,7 +37,7 @@ func TestBuildParams_SystemMessage(t *testing.T) {
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
- params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]interface{}{})
+ params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]any{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
@@ -62,13 +62,13 @@ func TestBuildParams_ToolCallMessage(t *testing.T) {
{
ID: "call_1",
Name: "get_weather",
- Arguments: map[string]interface{}{"city": "SF"},
+ Arguments: map[string]any{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
- params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]interface{}{})
+ params, err := buildParams(messages, nil, "claude-sonnet-4.6", map[string]any{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
@@ -84,17 +84,17 @@ func TestBuildParams_WithTools(t *testing.T) {
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
- Parameters: map[string]interface{}{
+ Parameters: map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "city": map[string]interface{}{"type": "string"},
+ "properties": map[string]any{
+ "city": map[string]any{"type": "string"},
},
- "required": []interface{}{"city"},
+ "required": []any{"city"},
},
},
},
}
- params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4.6", map[string]interface{}{})
+ params, err := buildParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4.6", map[string]any{})
if err != nil {
t.Fatalf("buildParams() error: %v", err)
}
@@ -154,19 +154,19 @@ func TestProvider_ChatRoundTrip(t *testing.T) {
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
json.NewDecoder(r.Body).Decode(&reqBody)
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "text", "text": "Hello! How can I help you?"},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 15,
"output_tokens": 8,
},
@@ -178,7 +178,7 @@ func TestProvider_ChatRoundTrip(t *testing.T) {
provider := NewProviderWithClient(createAnthropicTestClient(server.URL, "test-token"))
messages := []Message{{Role: "user", Content: "Hello"}}
- resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4.6", map[string]interface{}{"max_tokens": 1024})
+ resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4.6", map[string]any{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
@@ -221,19 +221,19 @@ func TestProvider_ChatUsesTokenSource(t *testing.T) {
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
json.NewDecoder(r.Body).Decode(&reqBody)
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "text", "text": "ok"},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 1,
"output_tokens": 1,
},
@@ -247,7 +247,13 @@ func TestProvider_ChatUsesTokenSource(t *testing.T) {
return "refreshed-token", nil
}, server.URL)
- _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "claude-sonnet-4.6", map[string]interface{}{})
+ _, err := p.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hello"}},
+ nil,
+ "claude-sonnet-4.6",
+ map[string]any{},
+ )
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go
index 6c6bf7830..d4ee528b7 100644
--- a/pkg/providers/antigravity_provider.go
+++ b/pkg/providers/antigravity_provider.go
@@ -45,7 +45,13 @@ func NewAntigravityProvider() *AntigravityProvider {
// Chat implements LLMProvider.Chat using the Cloud Code Assist v1internal API.
// The v1internal endpoint wraps the standard Gemini request in an envelope with
// project, model, request, requestType, userAgent, and requestId fields.
-func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *AntigravityProvider) Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (*LLMResponse, error) {
accessToken, projectID, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("antigravity auth: %w", err)
@@ -58,7 +64,7 @@ func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tool
model = strings.TrimPrefix(model, "google-antigravity/")
model = strings.TrimPrefix(model, "antigravity/")
- logger.DebugCF("provider.antigravity", "Starting chat", map[string]interface{}{
+ logger.DebugCF("provider.antigravity", "Starting chat", map[string]any{
"model": model,
"project": projectID,
"requestId": fmt.Sprintf("agent-%d-%s", time.Now().UnixMilli(), randomString(9)),
@@ -68,7 +74,7 @@ func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tool
innerRequest := p.buildRequest(messages, tools, model, options)
// Wrap in v1internal envelope (matches pi-ai SDK format)
- envelope := map[string]interface{}{
+ envelope := map[string]any{
"project": projectID,
"model": model,
"request": innerRequest,
@@ -115,7 +121,7 @@ func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tool
}
if resp.StatusCode != http.StatusOK {
- logger.ErrorCF("provider.antigravity", "API call failed", map[string]interface{}{
+ logger.ErrorCF("provider.antigravity", "API call failed", map[string]any{
"status_code": resp.StatusCode,
"response": string(respBody),
"model": model,
@@ -133,7 +139,9 @@ func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tool
// Check for empty response (some models might return valid success but empty text)
if llmResp.Content == "" && len(llmResp.ToolCalls) == 0 {
- return nil, fmt.Errorf("antigravity: model returned an empty response (this model might be invalid or restricted)")
+ return nil, fmt.Errorf(
+ "antigravity: model returned an empty response (this model might be invalid or restricted)",
+ )
}
return llmResp, nil
@@ -167,13 +175,13 @@ type antigravityPart struct {
}
type antigravityFunctionCall struct {
- Name string `json:"name"`
- Args map[string]interface{} `json:"args"`
+ Name string `json:"name"`
+ Args map[string]any `json:"args"`
}
type antigravityFunctionResponse struct {
- Name string `json:"name"`
- Response map[string]interface{} `json:"response"`
+ Name string `json:"name"`
+ Response map[string]any `json:"response"`
}
type antigravityTool struct {
@@ -181,9 +189,9 @@ type antigravityTool struct {
}
type antigravityFuncDecl struct {
- Name string `json:"name"`
- Description string `json:"description,omitempty"`
- Parameters interface{} `json:"parameters,omitempty"`
+ Name string `json:"name"`
+ Description string `json:"description,omitempty"`
+ Parameters any `json:"parameters,omitempty"`
}
type antigravitySystemPrompt struct {
@@ -195,7 +203,12 @@ type antigravityGenConfig struct {
Temperature float64 `json:"temperature,omitempty"`
}
-func (p *AntigravityProvider) buildRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) antigravityRequest {
+func (p *AntigravityProvider) buildRequest(
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) antigravityRequest {
req := antigravityRequest{}
toolCallNames := make(map[string]string)
@@ -215,7 +228,7 @@ func (p *AntigravityProvider) buildRequest(messages []Message, tools []ToolDefin
Parts: []antigravityPart{{
FunctionResponse: &antigravityFunctionResponse{
Name: toolName,
- Response: map[string]interface{}{
+ Response: map[string]any{
"result": msg.Content,
},
},
@@ -237,9 +250,13 @@ func (p *AntigravityProvider) buildRequest(messages []Message, tools []ToolDefin
for _, tc := range msg.ToolCalls {
toolName, toolArgs, thoughtSignature := normalizeStoredToolCall(tc)
if toolName == "" {
- logger.WarnCF("provider.antigravity", "Skipping tool call with empty name in history", map[string]interface{}{
- "tool_call_id": tc.ID,
- })
+ logger.WarnCF(
+ "provider.antigravity",
+ "Skipping tool call with empty name in history",
+ map[string]any{
+ "tool_call_id": tc.ID,
+ },
+ )
continue
}
if tc.ID != "" {
@@ -264,7 +281,7 @@ func (p *AntigravityProvider) buildRequest(messages []Message, tools []ToolDefin
Parts: []antigravityPart{{
FunctionResponse: &antigravityFunctionResponse{
Name: toolName,
- Response: map[string]interface{}{
+ Response: map[string]any{
"result": msg.Content,
},
},
@@ -311,7 +328,7 @@ func (p *AntigravityProvider) buildRequest(messages []Message, tools []ToolDefin
return req
}
-func normalizeStoredToolCall(tc ToolCall) (string, map[string]interface{}, string) {
+func normalizeStoredToolCall(tc ToolCall) (string, map[string]any, string) {
name := tc.Name
args := tc.Arguments
thoughtSignature := ""
@@ -324,11 +341,11 @@ func normalizeStoredToolCall(tc ToolCall) (string, map[string]interface{}, strin
}
if args == nil {
- args = map[string]interface{}{}
+ args = map[string]any{}
}
if len(args) == 0 && tc.Function != nil && tc.Function.Arguments != "" {
- var parsed map[string]interface{}
+ var parsed map[string]any
if err := json.Unmarshal([]byte(tc.Function.Arguments), &parsed); err == nil && parsed != nil {
args = parsed
}
@@ -387,64 +404,6 @@ type antigravityJSONResponse struct {
} `json:"usageMetadata"`
}
-func (p *AntigravityProvider) parseJSONResponse(body []byte) (*LLMResponse, error) {
- var resp antigravityJSONResponse
- if err := json.Unmarshal(body, &resp); err != nil {
- return nil, fmt.Errorf("parsing antigravity response: %w", err)
- }
-
- if len(resp.Candidates) == 0 {
- return nil, fmt.Errorf("antigravity: no candidates in response")
- }
-
- candidate := resp.Candidates[0]
- var contentParts []string
- var toolCalls []ToolCall
-
- for _, part := range candidate.Content.Parts {
- if part.Text != "" {
- contentParts = append(contentParts, part.Text)
- }
- if part.FunctionCall != nil {
- argumentsJSON, _ := json.Marshal(part.FunctionCall.Args)
- toolCalls = append(toolCalls, ToolCall{
- ID: fmt.Sprintf("call_%s_%d", part.FunctionCall.Name, time.Now().UnixNano()),
- Name: part.FunctionCall.Name,
- Arguments: part.FunctionCall.Args,
- Function: &FunctionCall{
- Name: part.FunctionCall.Name,
- Arguments: string(argumentsJSON),
- ThoughtSignature: extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake),
- },
- })
- }
- }
-
- finishReason := "stop"
- if len(toolCalls) > 0 {
- finishReason = "tool_calls"
- }
- if candidate.FinishReason == "MAX_TOKENS" {
- finishReason = "length"
- }
-
- var usage *UsageInfo
- if resp.UsageMetadata.TotalTokenCount > 0 {
- usage = &UsageInfo{
- PromptTokens: resp.UsageMetadata.PromptTokenCount,
- CompletionTokens: resp.UsageMetadata.CandidatesTokenCount,
- TotalTokens: resp.UsageMetadata.TotalTokenCount,
- }
- }
-
- return &LLMResponse{
- Content: strings.Join(contentParts, ""),
- ToolCalls: toolCalls,
- FinishReason: finishReason,
- Usage: usage,
- }, nil
-}
-
func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error) {
var contentParts []string
var toolCalls []ToolCall
@@ -483,9 +442,12 @@ func (p *AntigravityProvider) parseSSEResponse(body string) (*LLMResponse, error
Name: part.FunctionCall.Name,
Arguments: part.FunctionCall.Args,
Function: &FunctionCall{
- Name: part.FunctionCall.Name,
- Arguments: string(argumentsJSON),
- ThoughtSignature: extractPartThoughtSignature(part.ThoughtSignature, part.ThoughtSignatureSnake),
+ Name: part.FunctionCall.Name,
+ Arguments: string(argumentsJSON),
+ ThoughtSignature: extractPartThoughtSignature(
+ part.ThoughtSignature,
+ part.ThoughtSignatureSnake,
+ ),
},
})
}
@@ -556,24 +518,24 @@ var geminiUnsupportedKeywords = map[string]bool{
"maxProperties": true,
}
-func sanitizeSchemaForGemini(schema map[string]interface{}) map[string]interface{} {
+func sanitizeSchemaForGemini(schema map[string]any) map[string]any {
if schema == nil {
return nil
}
- result := make(map[string]interface{})
+ result := make(map[string]any)
for k, v := range schema {
if geminiUnsupportedKeywords[k] {
continue
}
// Recursively sanitize nested objects
switch val := v.(type) {
- case map[string]interface{}:
+ case map[string]any:
result[k] = sanitizeSchemaForGemini(val)
- case []interface{}:
- sanitized := make([]interface{}, len(val))
+ case []any:
+ sanitized := make([]any, len(val))
for i, item := range val {
- if m, ok := item.(map[string]interface{}); ok {
+ if m, ok := item.(map[string]any); ok {
sanitized[i] = sanitizeSchemaForGemini(m)
} else {
sanitized[i] = item
@@ -604,7 +566,9 @@ func createAntigravityTokenSource() func() (string, string, error) {
return "", "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
- return "", "", fmt.Errorf("no credentials for google-antigravity. Run: picoclaw auth login --provider google-antigravity")
+ return "", "", fmt.Errorf(
+ "no credentials for google-antigravity. Run: picoclaw auth login --provider google-antigravity",
+ )
}
// Refresh if needed
@@ -625,7 +589,9 @@ func createAntigravityTokenSource() func() (string, string, error) {
}
if cred.IsExpired() {
- return "", "", fmt.Errorf("antigravity credentials expired. Run: picoclaw auth login --provider google-antigravity")
+ return "", "", fmt.Errorf(
+ "antigravity credentials expired. Run: picoclaw auth login --provider google-antigravity",
+ )
}
projectID := cred.ProjectID
@@ -633,7 +599,7 @@ func createAntigravityTokenSource() func() (string, string, error) {
// Try to fetch project ID from API
fetchedID, err := FetchAntigravityProjectID(cred.AccessToken)
if err != nil {
- logger.WarnCF("provider.antigravity", "Could not fetch project ID, using fallback", map[string]interface{}{
+ logger.WarnCF("provider.antigravity", "Could not fetch project ID, using fallback", map[string]any{
"error": err.Error(),
})
projectID = "rising-fact-p41fc" // Default fallback (same as OpenCode)
@@ -650,8 +616,8 @@ func createAntigravityTokenSource() func() (string, string, error) {
// FetchAntigravityProjectID retrieves the Google Cloud project ID from the loadCodeAssist endpoint.
func FetchAntigravityProjectID(accessToken string) (string, error) {
- reqBody, _ := json.Marshal(map[string]interface{}{
- "metadata": map[string]interface{}{
+ reqBody, _ := json.Marshal(map[string]any{
+ "metadata": map[string]any{
"ideType": "IDE_UNSPECIFIED",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
@@ -695,7 +661,7 @@ func FetchAntigravityProjectID(accessToken string) (string, error) {
// FetchAntigravityModels fetches available models from the Cloud Code Assist API.
func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelInfo, error) {
- reqBody, _ := json.Marshal(map[string]interface{}{
+ reqBody, _ := json.Marshal(map[string]any{
"project": projectID,
})
@@ -717,16 +683,20 @@ func FetchAntigravityModels(accessToken, projectID string) ([]AntigravityModelIn
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("fetchAvailableModels failed (HTTP %d): %s", resp.StatusCode, truncateString(string(body), 200))
+ return nil, fmt.Errorf(
+ "fetchAvailableModels failed (HTTP %d): %s",
+ resp.StatusCode,
+ truncateString(string(body), 200),
+ )
}
var result struct {
Models map[string]struct {
DisplayName string `json:"displayName"`
QuotaInfo struct {
- RemainingFraction interface{} `json:"remainingFraction"`
- ResetTime string `json:"resetTime"`
- IsExhausted bool `json:"isExhausted"`
+ RemainingFraction any `json:"remainingFraction"`
+ ResetTime string `json:"resetTime"`
+ IsExhausted bool `json:"isExhausted"`
} `json:"quotaInfo"`
} `json:"models"`
}
@@ -797,10 +767,10 @@ func randomString(n int) string {
func (p *AntigravityProvider) parseAntigravityError(statusCode int, body []byte) error {
var errResp struct {
Error struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Status string `json:"status"`
- Details []map[string]interface{} `json:"details"`
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Status string `json:"status"`
+ Details []map[string]any `json:"details"`
} `json:"error"`
}
@@ -813,7 +783,7 @@ func (p *AntigravityProvider) parseAntigravityError(statusCode int, body []byte)
// Try to extract quota reset info
for _, detail := range errResp.Error.Details {
if typeVal, ok := detail["@type"].(string); ok && strings.HasSuffix(typeVal, "ErrorInfo") {
- if metadata, ok := detail["metadata"].(map[string]interface{}); ok {
+ if metadata, ok := detail["metadata"].(map[string]any); ok {
if delay, ok := metadata["quotaResetDelay"].(string); ok {
return fmt.Errorf("antigravity rate limit exceeded: %s (reset in %s)", msg, delay)
}
diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go
index 58ba3647d..74ec33b98 100644
--- a/pkg/providers/claude_cli_provider.go
+++ b/pkg/providers/claude_cli_provider.go
@@ -24,7 +24,9 @@ func NewClaudeCliProvider(workspace string) *ClaudeCliProvider {
}
// Chat implements LLMProvider.Chat by executing the claude CLI.
-func (p *ClaudeCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *ClaudeCliProvider) Chat(
+ ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
+) (*LLMResponse, error) {
systemPrompt := p.buildSystemPrompt(messages, tools)
prompt := p.messagesToPrompt(messages)
@@ -111,7 +113,9 @@ func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
- sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
+ sb.WriteString(
+ `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
+ )
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
diff --git a/pkg/providers/claude_cli_provider_integration_test.go b/pkg/providers/claude_cli_provider_integration_test.go
index 9d1131ac4..f6e0d787a 100644
--- a/pkg/providers/claude_cli_provider_integration_test.go
+++ b/pkg/providers/claude_cli_provider_integration_test.go
@@ -28,7 +28,6 @@ func TestIntegration_RealClaudeCLI(t *testing.T) {
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
@@ -75,7 +74,6 @@ func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go
index 945f5bd4f..3a3cafaca 100644
--- a/pkg/providers/claude_cli_provider_test.go
+++ b/pkg/providers/claude_cli_provider_test.go
@@ -30,12 +30,12 @@ func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string {
dir := t.TempDir()
if stdout != "" {
- if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0644); err != nil {
+ if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0o644); err != nil {
t.Fatal(err)
}
}
if stderr != "" {
- if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0644); err != nil {
+ if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0o644); err != nil {
t.Fatal(err)
}
}
@@ -51,7 +51,7 @@ func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string {
sb.WriteString(fmt.Sprintf("exit %d\n", exitCode))
script := filepath.Join(dir, "claude")
- if err := os.WriteFile(script, []byte(sb.String()), 0755); err != nil {
+ if err := os.WriteFile(script, []byte(sb.String()), 0o755); err != nil {
t.Fatal(err)
}
return script
@@ -67,7 +67,7 @@ func createSlowMockCLI(t *testing.T, sleepSeconds int) string {
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := fmt.Sprintf("#!/bin/sh\nsleep %d\necho '{\"type\":\"result\",\"result\":\"late\"}'\n", sleepSeconds)
- if err := os.WriteFile(script, []byte(content), 0755); err != nil {
+ if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
t.Fatal(err)
}
return script
@@ -88,7 +88,7 @@ cat <<'EOFMOCK'
{"type":"result","result":"ok","session_id":"test"}
EOFMOCK
`, argsFile)
- if err := os.WriteFile(script, []byte(content), 0755); err != nil {
+ if err := os.WriteFile(script, []byte(content), 0o755); err != nil {
t.Fatal(err)
}
return script
@@ -137,7 +137,6 @@ func TestChat_Success(t *testing.T) {
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
@@ -193,7 +192,6 @@ func TestChat_WithToolCallsInResponse(t *testing.T) {
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "What's the weather?"},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
@@ -403,7 +401,6 @@ func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() with empty workspace error = %v", err)
}
@@ -622,10 +619,10 @@ func TestBuildSystemPrompt_WithTools(t *testing.T) {
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a location",
- Parameters: map[string]interface{}{
+ Parameters: map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "location": map[string]interface{}{"type": "string"},
+ "properties": map[string]any{
+ "location": map[string]any{"type": "string"},
},
},
},
diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go
index 3ca54d5a3..60639ca18 100644
--- a/pkg/providers/claude_provider.go
+++ b/pkg/providers/claude_provider.go
@@ -29,7 +29,9 @@ func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string,
}
}
-func NewClaudeProviderWithTokenSourceAndBaseURL(token string, tokenSource func() (string, error), apiBase string) *ClaudeProvider {
+func NewClaudeProviderWithTokenSourceAndBaseURL(
+ token string, tokenSource func() (string, error), apiBase string,
+) *ClaudeProvider {
return &ClaudeProvider{
delegate: anthropicprovider.NewProviderWithTokenSourceAndBaseURL(token, tokenSource, apiBase),
}
@@ -39,7 +41,9 @@ func newClaudeProviderWithDelegate(delegate *anthropicprovider.Provider) *Claude
return &ClaudeProvider{delegate: delegate}
}
-func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *ClaudeProvider) Chat(
+ ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
+) (*LLMResponse, error) {
resp, err := p.delegate.Chat(ctx, messages, tools, model, options)
if err != nil {
return nil, err
diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go
index b1bcd8b40..98e07bb80 100644
--- a/pkg/providers/claude_provider_test.go
+++ b/pkg/providers/claude_provider_test.go
@@ -8,6 +8,7 @@ import (
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
+
anthropicprovider "github.com/sipeed/picoclaw/pkg/providers/anthropic"
)
@@ -22,19 +23,19 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
json.NewDecoder(r.Body).Decode(&reqBody)
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "text", "text": "Hello! How can I help you?"},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 15,
"output_tokens": 8,
},
@@ -48,7 +49,7 @@ func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
provider := newClaudeProviderWithDelegate(delegate)
messages := []Message{{Role: "user", Content: "Hello"}}
- resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4.6", map[string]interface{}{"max_tokens": 1024})
+ resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4.6", map[string]any{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
diff --git a/pkg/providers/codex_cli_credentials.go b/pkg/providers/codex_cli_credentials.go
index 7ad39ce8e..40f3ee2a1 100644
--- a/pkg/providers/codex_cli_credentials.go
+++ b/pkg/providers/codex_cli_credentials.go
@@ -31,7 +31,7 @@ func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Ti
}
var auth CodexCliAuth
- if err := json.Unmarshal(data, &auth); err != nil {
+ if err = json.Unmarshal(data, &auth); err != nil {
return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err)
}
@@ -59,7 +59,9 @@ func CreateCodexCliTokenSource() func() (string, string, error) {
}
if time.Now().After(expiresAt) {
- return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login")
+ return "", "", fmt.Errorf(
+ "codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login",
+ )
}
return token, accountID, nil
diff --git a/pkg/providers/codex_cli_credentials_test.go b/pkg/providers/codex_cli_credentials_test.go
index 3267f2d16..1e88c1120 100644
--- a/pkg/providers/codex_cli_credentials_test.go
+++ b/pkg/providers/codex_cli_credentials_test.go
@@ -18,7 +18,7 @@ func TestReadCodexCliCredentials_Valid(t *testing.T) {
"account_id": "org-test123"
}
}`
- if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
+ if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
t.Fatal(err)
}
@@ -43,12 +43,18 @@ func TestReadCodexCliCredentials_Valid(t *testing.T) {
}
}
+// readCodexCliCredentialsErr calls ReadCodexCliCredentials and returns only the
+// error, for tests that only need to assert on failure.
+func readCodexCliCredentialsErr() error {
+ _, _, _, err := ReadCodexCliCredentials() //nolint:dogsled
+ return err
+}
+
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("CODEX_HOME", tmpDir)
- _, _, _, err := ReadCodexCliCredentials()
- if err == nil {
+ if err := readCodexCliCredentialsErr(); err == nil {
t.Fatal("expected error for missing auth.json")
}
}
@@ -58,14 +64,13 @@ func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`
- if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
+ if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
- _, _, _, err := ReadCodexCliCredentials()
- if err == nil {
+ if err := readCodexCliCredentialsErr(); err == nil {
t.Fatal("expected error for empty access_token")
}
}
@@ -74,14 +79,13 @@ func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
- if err := os.WriteFile(authPath, []byte("not json"), 0600); err != nil {
+ if err := os.WriteFile(authPath, []byte("not json"), 0o600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
- _, _, _, err := ReadCodexCliCredentials()
- if err == nil {
+ if err := readCodexCliCredentialsErr(); err == nil {
t.Fatal("expected error for invalid JSON")
}
}
@@ -91,7 +95,7 @@ func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`
- if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
+ if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
t.Fatal(err)
}
@@ -112,12 +116,12 @@ func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
tmpDir := t.TempDir()
customDir := filepath.Join(tmpDir, "custom-codex")
- if err := os.MkdirAll(customDir, 0755); err != nil {
+ if err := os.MkdirAll(customDir, 0o755); err != nil {
t.Fatal(err)
}
authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`
- if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0600); err != nil {
+ if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0o600); err != nil {
t.Fatal(err)
}
@@ -137,7 +141,7 @@ func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`
- if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
+ if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
t.Fatal(err)
}
@@ -161,7 +165,7 @@ func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`
- if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
+ if err := os.WriteFile(authPath, []byte(authJSON), 0o600); err != nil {
t.Fatal(err)
}
diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go
index 8886406b4..4c783ece5 100644
--- a/pkg/providers/codex_cli_provider.go
+++ b/pkg/providers/codex_cli_provider.go
@@ -25,7 +25,9 @@ func NewCodexCliProvider(workspace string) *CodexCliProvider {
}
// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode.
-func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *CodexCliProvider) Chat(
+ ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
+) (*LLMResponse, error) {
if p.command == "" {
return nil, fmt.Errorf("codex command not configured")
}
@@ -133,7 +135,9 @@ func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
- sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
+ sb.WriteString(
+ `{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`,
+ )
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
diff --git a/pkg/providers/codex_cli_provider_integration_test.go b/pkg/providers/codex_cli_provider_integration_test.go
index 0267c730f..17a8305ad 100644
--- a/pkg/providers/codex_cli_provider_integration_test.go
+++ b/pkg/providers/codex_cli_provider_integration_test.go
@@ -27,7 +27,6 @@ func TestIntegration_RealCodexCLI(t *testing.T) {
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
@@ -64,7 +63,6 @@ func TestIntegration_RealCodexCLI_WithSystemPrompt(t *testing.T) {
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
-
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
diff --git a/pkg/providers/codex_cli_provider_test.go b/pkg/providers/codex_cli_provider_test.go
index 7e4e1bc15..414e0844d 100644
--- a/pkg/providers/codex_cli_provider_test.go
+++ b/pkg/providers/codex_cli_provider_test.go
@@ -292,10 +292,10 @@ func TestBuildPrompt_WithTools(t *testing.T) {
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get current weather",
- Parameters: map[string]interface{}{
+ Parameters: map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "city": map[string]interface{}{"type": "string"},
+ "properties": map[string]any{
+ "city": map[string]any{"type": "string"},
},
},
},
@@ -409,7 +409,7 @@ func createMockCodexCLI(t *testing.T, events []string) string {
sb.WriteString(fmt.Sprintf("echo '%s'\n", event))
}
- if err := os.WriteFile(scriptPath, []byte(sb.String()), 0755); err != nil {
+ if err := os.WriteFile(scriptPath, []byte(sb.String()), 0o755); err != nil {
t.Fatal(err)
}
return scriptPath
@@ -480,7 +480,7 @@ echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `"
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
echo '{"type":"turn.completed"}'`
- if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
+ if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
t.Fatal(err)
}
@@ -522,7 +522,7 @@ func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
scriptPath := filepath.Join(tmpDir, "codex")
script := "#!/bin/bash\nsleep 60"
- if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
+ if err := os.WriteFile(scriptPath, []byte(script), 0o755); err != nil {
t.Fatal(err)
}
diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go
index e3526cfb5..dcc740ba4 100644
--- a/pkg/providers/codex_provider.go
+++ b/pkg/providers/codex_provider.go
@@ -10,12 +10,15 @@ import (
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
+
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
)
-const codexDefaultModel = "gpt-5.2"
-const codexDefaultInstructions = "You are Codex, a coding assistant."
+const (
+ codexDefaultModel = "gpt-5.2"
+ codexDefaultInstructions = "You are Codex, a coding assistant."
+)
type CodexProvider struct {
client *openai.Client
@@ -44,22 +47,30 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
}
}
-func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
+func NewCodexProviderWithTokenSource(
+ token, accountID string, tokenSource func() (string, string, error),
+) *CodexProvider {
p := NewCodexProvider(token, accountID)
p.tokenSource = tokenSource
return p
}
-func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *CodexProvider) Chat(
+ ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]any,
+) (*LLMResponse, error) {
var opts []option.RequestOption
accountID := p.accountID
resolvedModel, fallbackReason := resolveCodexModel(model)
if fallbackReason != "" {
- logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
- "requested_model": model,
- "resolved_model": resolvedModel,
- "reason": fallbackReason,
- })
+ logger.WarnCF(
+ "provider.codex",
+ "Requested model is not compatible with Codex backend, using fallback",
+ map[string]any{
+ "requested_model": model,
+ "resolved_model": resolvedModel,
+ "reason": fallbackReason,
+ },
+ )
}
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
@@ -74,10 +85,14 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
} else {
- logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
- "requested_model": model,
- "resolved_model": resolvedModel,
- })
+ logger.WarnCF(
+ "provider.codex",
+ "No account id found for Codex request; backend may reject with 400",
+ map[string]any{
+ "requested_model": model,
+ "resolved_model": resolvedModel,
+ },
+ )
}
params := buildCodexParams(messages, tools, resolvedModel, options, p.enableWebSearch)
@@ -91,14 +106,14 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
evtResp := evt.Response
if evtResp.ID != "" {
- copy := evtResp
- resp = ©
+ evtRespCopy := evtResp
+ resp = &evtRespCopy
}
}
}
err := stream.Err()
if err != nil {
- fields := map[string]interface{}{
+ fields := map[string]any{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
@@ -124,7 +139,7 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
return nil, fmt.Errorf("codex API call: %w", err)
}
if resp == nil {
- fields := map[string]interface{}{
+ fields := map[string]any{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
@@ -184,20 +199,29 @@ func resolveCodexModel(model string) (string, string) {
return codexDefaultModel, "unsupported model family"
}
-func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, enableWebSearch bool) responses.ResponseNewParams {
+func buildCodexParams(
+ messages []Message, tools []ToolDefinition, model string, options map[string]any, enableWebSearch bool,
+) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
for _, msg := range messages {
switch msg.Role {
case "system":
+ // Use the full concatenated system prompt (static + dynamic + summary)
+ // as instructions. This keeps behavior consistent with Anthropic and
+ // OpenAI-compat adapters where the complete system context lives in
+ // one place. Prefix caching is handled by prompt_cache_key below,
+ // not by splitting content across instructions vs input messages.
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
- Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
+ Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
+ OfString: openai.Opt(msg.Content),
+ },
},
})
} else {
@@ -221,7 +245,7 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
for _, tc := range msg.ToolCalls {
name, args, ok := resolveCodexToolCall(tc)
if !ok {
- logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
+ logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]any{
"call_id": tc.ID,
})
continue
@@ -246,7 +270,9 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
- Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
+ Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{
+ OfString: openai.Opt(msg.Content),
+ },
},
})
}
@@ -268,6 +294,13 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
params.Instructions = openai.Opt(defaultCodexInstructions)
}
+ // Prompt caching: pass a stable cache key so OpenAI can bucket requests
+ // and reuse prefix KV cache across calls with the same key.
+ // See: https://platform.openai.com/docs/guides/prompt-caching
+ if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
+ params.PromptCacheKey = openai.Opt(cacheKey)
+ }
+
if len(tools) > 0 || enableWebSearch {
params.Tools = translateToolsForCodex(tools, enableWebSearch)
}
@@ -341,9 +374,9 @@ func parseCodexResponse(resp *responses.Response) *LLMResponse {
}
}
case "function_call":
- var args map[string]interface{}
+ var args map[string]any
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
- args = map[string]interface{}{"raw": item.Arguments}
+ args = map[string]any{"raw": item.Arguments}
}
toolCalls = append(toolCalls, ToolCall{
ID: item.CallID,
diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go
index 92e276165..4157e53e9 100644
--- a/pkg/providers/codex_provider_test.go
+++ b/pkg/providers/codex_provider_test.go
@@ -16,7 +16,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
- params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
+ params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{
"max_tokens": 2048,
"temperature": 0.7,
}, true)
@@ -39,7 +39,7 @@ func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
- params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, true)
+ params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{}, true)
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
@@ -54,12 +54,12 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
{
Role: "assistant",
ToolCalls: []ToolCall{
- {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
+ {ID: "call_1", Name: "get_weather", Arguments: map[string]any{"city": "SF"}},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
- params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
+ params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{}, false)
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
@@ -87,7 +87,7 @@ func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
}
- params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}, false)
+ params := buildCodexParams(messages, nil, "gpt-4o", map[string]any{}, false)
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
@@ -114,16 +114,16 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather",
- Parameters: map[string]interface{}{
+ Parameters: map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "city": map[string]interface{}{"type": "string"},
+ "properties": map[string]any{
+ "city": map[string]any{"type": "string"},
},
},
},
},
}
- params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, false)
+ params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]any{}, false)
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
@@ -136,14 +136,14 @@ func TestBuildCodexParams_WithTools(t *testing.T) {
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
- params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, false)
+ params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]any{}, false)
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
- params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}, true)
+ params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]any{}, true)
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
@@ -151,7 +151,11 @@ func TestBuildCodexParams_DefaultWebSearchEnabled(t *testing.T) {
t.Fatal("Tool should include built-in web_search")
}
if params.Tools[0].OfWebSearch.Type != responses.WebSearchToolTypeWebSearch {
- t.Errorf("Web search tool type = %q, want %q", params.Tools[0].OfWebSearch.Type, responses.WebSearchToolTypeWebSearch)
+ t.Errorf(
+ "Web search tool type = %q, want %q",
+ params.Tools[0].OfWebSearch.Type,
+ responses.WebSearchToolTypeWebSearch,
+ )
}
}
@@ -162,7 +166,7 @@ func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
Function: ToolFunctionDefinition{
Name: "web_search",
Description: "local web search",
- Parameters: map[string]interface{}{
+ Parameters: map[string]any{
"type": "object",
},
},
@@ -172,14 +176,14 @@ func TestBuildCodexParams_WebSearchFunctionReplacedWithBuiltin(t *testing.T) {
Function: ToolFunctionDefinition{
Name: "read_file",
Description: "read file",
- Parameters: map[string]interface{}{
+ Parameters: map[string]any{
"type": "object",
},
},
},
}
- params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}, true)
+ params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]any{}, true)
if len(params.Tools) != 2 {
t.Fatalf("len(Tools) = %d, want 2", len(params.Tools))
}
@@ -296,7 +300,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
@@ -309,38 +313,38 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
return
}
- toolsAny, ok := reqBody["tools"].([]interface{})
+ toolsAny, ok := reqBody["tools"].([]any)
if !ok || len(toolsAny) != 1 {
http.Error(w, "missing default web search tool", http.StatusBadRequest)
return
}
- toolObj, ok := toolsAny[0].(map[string]interface{})
+ toolObj, ok := toolsAny[0].(map[string]any)
if !ok || toolObj["type"] != "web_search" {
http.Error(w, "expected web_search tool", http.StatusBadRequest)
return
}
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "resp_test",
"object": "response",
"status": "completed",
- "output": []map[string]interface{}{
+ "output": []map[string]any{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 12,
"output_tokens": 6,
"total_tokens": 18,
- "input_tokens_details": map[string]interface{}{"cached_tokens": 0},
- "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
+ "input_tokens_details": map[string]any{"cached_tokens": 0},
+ "output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
@@ -351,7 +355,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
- resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
+ resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
@@ -373,7 +377,7 @@ func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
@@ -383,27 +387,27 @@ func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
return
}
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "resp_test",
"object": "response",
"status": "completed",
- "output": []map[string]interface{}{
+ "output": []map[string]any{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 4,
"output_tokens": 3,
"total_tokens": 7,
- "input_tokens_details": map[string]interface{}{"cached_tokens": 0},
- "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
+ "input_tokens_details": map[string]any{"cached_tokens": 0},
+ "output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
@@ -415,7 +419,7 @@ func TestCodexProvider_ChatRoundTrip_WebSearchDisabled(t *testing.T) {
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
- resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{})
+ resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
@@ -439,7 +443,7 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
@@ -465,27 +469,27 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
return
}
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "resp_test",
"object": "response",
"status": "completed",
- "output": []map[string]interface{}{
+ "output": []map[string]any{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
- "input_tokens_details": map[string]interface{}{"cached_tokens": 0},
- "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
+ "input_tokens_details": map[string]any{"cached_tokens": 0},
+ "output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
@@ -499,7 +503,7 @@ func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T)
}
messages := []Message{{Role: "user", Content: "Hello"}}
- resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
+ resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]any{"temperature": 0.7})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
@@ -515,7 +519,7 @@ func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T)
return
}
- var reqBody map[string]interface{}
+ var reqBody map[string]any
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
@@ -533,27 +537,27 @@ func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T)
return
}
- resp := map[string]interface{}{
+ resp := map[string]any{
"id": "resp_test",
"object": "response",
"status": "completed",
- "output": []map[string]interface{}{
+ "output": []map[string]any{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
- "content": []map[string]interface{}{
+ "content": []map[string]any{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
- "input_tokens_details": map[string]interface{}{"cached_tokens": 0},
- "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
+ "input_tokens_details": map[string]any{"cached_tokens": 0},
+ "output_tokens_details": map[string]any{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
@@ -588,7 +592,12 @@ func TestResolveCodexModel(t *testing.T) {
wantFallback bool
}{
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
- {name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
+ {
+ name: "unsupported namespace",
+ input: "anthropic/claude-3.5",
+ wantModel: codexDefaultModel,
+ wantFallback: true,
+ },
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
@@ -622,8 +631,8 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
return &c
}
-func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
- event := map[string]interface{}{
+func writeCompletedSSE(w http.ResponseWriter, response map[string]any) {
+ event := map[string]any{
"type": "response.completed",
"sequence_number": 1,
"response": response,
diff --git a/pkg/providers/factory.go b/pkg/providers/factory.go
index b6f1b5e21..11af14da4 100644
--- a/pkg/providers/factory.go
+++ b/pkg/providers/factory.go
@@ -36,7 +36,7 @@ type providerSelection struct {
}
func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
- model := cfg.Agents.Defaults.Model
+ model := cfg.Agents.Defaults.GetModelName()
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
lowerModel := strings.ToLower(model)
@@ -172,6 +172,15 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
sel.model = "deepseek-chat"
}
}
+ case "mistral":
+ if cfg.Providers.Mistral.APIKey != "" {
+ sel.apiKey = cfg.Providers.Mistral.APIKey
+ sel.apiBase = cfg.Providers.Mistral.APIBase
+ sel.proxy = cfg.Providers.Mistral.Proxy
+ if sel.apiBase == "" {
+ sel.apiBase = "https://api.mistral.ai/v1"
+ }
+ }
case "github_copilot", "copilot":
sel.providerType = providerTypeGitHubCopilot
if cfg.Providers.GitHubCopilot.APIBase != "" {
@@ -275,6 +284,13 @@ func resolveProviderSelection(cfg *config.Config) (providerSelection, error) {
if sel.apiBase == "" {
sel.apiBase = "http://localhost:11434/v1"
}
+ case (strings.Contains(lowerModel, "mistral") || strings.HasPrefix(model, "mistral/")) && cfg.Providers.Mistral.APIKey != "":
+ sel.apiKey = cfg.Providers.Mistral.APIKey
+ sel.apiBase = cfg.Providers.Mistral.APIBase
+ sel.proxy = cfg.Providers.Mistral.Proxy
+ if sel.apiBase == "" {
+ sel.apiBase = "https://api.mistral.ai/v1"
+ }
case cfg.Providers.VLLM.APIBase != "":
sel.apiKey = cfg.Providers.VLLM.APIKey
sel.apiBase = cfg.Providers.VLLM.APIBase
diff --git a/pkg/providers/factory_provider.go b/pkg/providers/factory_provider.go
index 74fe8a36c..7d5566eef 100644
--- a/pkg/providers/factory_provider.go
+++ b/pkg/providers/factory_provider.go
@@ -88,7 +88,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
case "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
- "volcengine", "vllm", "qwen":
+ "volcengine", "vllm", "qwen", "mistral":
// All other OpenAI-compatible HTTP providers
if cfg.APIKey == "" && cfg.APIBase == "" {
return nil, "", fmt.Errorf("api_key or api_base is required for HTTP-based protocol %q", protocol)
@@ -186,6 +186,8 @@ func getDefaultAPIBase(protocol string) string {
return "https://dashscope.aliyuncs.com/compatible-mode/v1"
case "vllm":
return "http://localhost:8000/v1"
+ case "mistral":
+ return "https://api.mistral.ai/v1"
default:
return ""
}
diff --git a/pkg/providers/fallback.go b/pkg/providers/fallback.go
index 9b07f9153..ecd451ec9 100644
--- a/pkg/providers/fallback.go
+++ b/pkg/providers/fallback.go
@@ -110,7 +110,11 @@ func (fc *FallbackChain) Execute(
Model: candidate.Model,
Skipped: true,
Reason: FailoverRateLimit,
- Error: fmt.Errorf("provider %s in cooldown (%s remaining)", candidate.Provider, remaining.Round(time.Second)),
+ Error: fmt.Errorf(
+ "provider %s in cooldown (%s remaining)",
+ candidate.Provider,
+ remaining.Round(time.Second),
+ ),
})
continue
}
diff --git a/pkg/providers/fallback_test.go b/pkg/providers/fallback_test.go
index ea81e0d48..ebba054ef 100644
--- a/pkg/providers/fallback_test.go
+++ b/pkg/providers/fallback_test.go
@@ -17,12 +17,6 @@ func successRun(content string) func(ctx context.Context, provider, model string
}
}
-func failRun(err error) func(ctx context.Context, provider, model string) (*LLMResponse, error) {
- return func(ctx context.Context, provider, model string) (*LLMResponse, error) {
- return nil, err
- }
-}
-
func TestFallback_SingleCandidate_Success(t *testing.T) {
ct := NewCooldownTracker()
fc := NewFallbackChain(ct)
@@ -462,7 +456,13 @@ func TestResolveCandidates_EmptyPrimary(t *testing.T) {
func TestFallbackExhaustedError_Message(t *testing.T) {
e := &FallbackExhaustedError{
Attempts: []FallbackAttempt{
- {Provider: "openai", Model: "gpt-4", Error: errors.New("rate limited"), Reason: FailoverRateLimit, Duration: 500 * time.Millisecond},
+ {
+ Provider: "openai",
+ Model: "gpt-4",
+ Error: errors.New("rate limited"),
+ Reason: FailoverRateLimit,
+ Duration: 500 * time.Millisecond,
+ },
{Provider: "anthropic", Model: "claude", Skipped: true},
},
}
diff --git a/pkg/providers/github_copilot_provider.go b/pkg/providers/github_copilot_provider.go
index 5058819f5..3fb15db2f 100644
--- a/pkg/providers/github_copilot_provider.go
+++ b/pkg/providers/github_copilot_provider.go
@@ -2,60 +2,85 @@ package providers
import (
"context"
+ "encoding/json"
"fmt"
-
- json "encoding/json"
+ "sync"
copilot "github.com/github/copilot-sdk/go"
)
type GitHubCopilotProvider struct {
uri string
- connectMode string // `stdio` or `grpc``
+ connectMode string // "stdio" or "grpc"
+ client *copilot.Client
session *copilot.Session
+
+ mu sync.Mutex
}
func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
-
- var session *copilot.Session
if connectMode == "" {
connectMode = "grpc"
}
- switch connectMode {
+ switch connectMode {
case "stdio":
- //todo
+ // TODO:
+ return nil, fmt.Errorf("stdio mode not implemented")
case "grpc":
client := copilot.NewClient(&copilot.ClientOptions{
CLIUrl: uri,
})
if err := client.Start(context.Background()); err != nil {
- return nil, fmt.Errorf("Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details")
+ return nil, fmt.Errorf(
+ "can't connect to Github Copilot: %w; `https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server` for details",
+ err,
+ )
}
- defer client.Stop()
- session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{
+
+ session, err := client.CreateSession(context.Background(), &copilot.SessionConfig{
Model: model,
Hooks: &copilot.SessionHooks{},
})
+ if err != nil {
+ client.Stop()
+ return nil, fmt.Errorf("create session failed: %w", err)
+ }
+ return &GitHubCopilotProvider{
+ uri: uri,
+ connectMode: connectMode,
+ client: client,
+ session: session,
+ }, nil
+ default:
+ return nil, fmt.Errorf("unknown connect mode: %s", connectMode)
}
-
- return &GitHubCopilotProvider{
- uri: uri,
- connectMode: connectMode,
- session: session,
- }, nil
}
-// Chat sends a chat request to GitHub Copilot
-func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *GitHubCopilotProvider) Close() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.client != nil {
+ p.client.Stop()
+ p.client = nil
+ p.session = nil
+ }
+}
+
+func (p *GitHubCopilotProvider) Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (*LLMResponse, error) {
type tempMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
out := make([]tempMessage, 0, len(messages))
-
for _, msg := range messages {
out = append(out, tempMessage{
Role: msg.Role,
@@ -63,20 +88,36 @@ func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, to
})
}
- fullcontent, _ := json.Marshal(out)
+ fullcontent, err := json.Marshal(out)
+ if err != nil {
+ return nil, fmt.Errorf("marshal messages: %w", err)
+ }
+ p.mu.Lock()
+ session := p.session
+ p.mu.Unlock()
- content, _ := p.session.Send(ctx, copilot.MessageOptions{
+ if session == nil {
+ return nil, fmt.Errorf("provider closed")
+ }
+
+ resp, _ := session.SendAndWait(ctx, copilot.MessageOptions{
Prompt: string(fullcontent),
})
+ if resp == nil {
+ return nil, fmt.Errorf("empty response from copilot")
+ }
+ if resp.Data.Content == nil {
+ return nil, fmt.Errorf("no content in copilot response")
+ }
+ content := *resp.Data.Content
+
return &LLMResponse{
FinishReason: "stop",
Content: content,
}, nil
-
}
func (p *GitHubCopilotProvider) GetDefaultModel() string {
-
return "gpt-4.1"
}
diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go
index eeaa9690a..d0c4344f3 100644
--- a/pkg/providers/http_provider.go
+++ b/pkg/providers/http_provider.go
@@ -28,7 +28,13 @@ func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField st
}
}
-func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *HTTPProvider) Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (*LLMResponse, error) {
return p.delegate.Chat(ctx, messages, tools, model, options)
}
diff --git a/pkg/providers/legacy_provider.go b/pkg/providers/legacy_provider.go
index eb13cec65..23f137538 100644
--- a/pkg/providers/legacy_provider.go
+++ b/pkg/providers/legacy_provider.go
@@ -16,7 +16,7 @@ import (
// The old providers config is automatically converted to model_list during config loading.
// Returns the provider, the model ID to use, and any error.
func CreateProvider(cfg *config.Config) (LLMProvider, string, error) {
- model := cfg.Agents.Defaults.Model
+ model := cfg.Agents.Defaults.GetModelName()
// Ensure model_list is populated (should be done by LoadConfig, but handle edge cases)
if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() {
diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go
index 6bc43a470..087d3506e 100644
--- a/pkg/providers/openai_compat/provider.go
+++ b/pkg/providers/openai_compat/provider.go
@@ -15,15 +15,17 @@ import (
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
-type ToolCall = protocoltypes.ToolCall
-type FunctionCall = protocoltypes.FunctionCall
-type LLMResponse = protocoltypes.LLMResponse
-type UsageInfo = protocoltypes.UsageInfo
-type Message = protocoltypes.Message
-type ToolDefinition = protocoltypes.ToolDefinition
-type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
-type ExtraContent = protocoltypes.ExtraContent
-type GoogleExtra = protocoltypes.GoogleExtra
+type (
+ ToolCall = protocoltypes.ToolCall
+ FunctionCall = protocoltypes.FunctionCall
+ LLMResponse = protocoltypes.LLMResponse
+ UsageInfo = protocoltypes.UsageInfo
+ Message = protocoltypes.Message
+ ToolDefinition = protocoltypes.ToolDefinition
+ ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
+ ExtraContent = protocoltypes.ExtraContent
+ GoogleExtra = protocoltypes.GoogleExtra
+)
type Provider struct {
apiKey string
@@ -60,16 +62,22 @@ func NewProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string
}
}
-func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
+func (p *Provider) Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("API base not configured")
}
model = normalizeModel(model, p.apiBase)
- requestBody := map[string]interface{}{
+ requestBody := map[string]any{
"model": model,
- "messages": messages,
+ "messages": stripSystemParts(messages),
}
if len(tools) > 0 {
@@ -83,7 +91,8 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
if fieldName == "" {
// Fallback: detect from model name for backward compatibility
lowerModel := strings.ToLower(model)
- if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") || strings.Contains(lowerModel, "gpt-5") {
+ if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") ||
+ strings.Contains(lowerModel, "gpt-5") {
fieldName = "max_completion_tokens"
} else {
fieldName = "max_tokens"
@@ -102,6 +111,18 @@ func (p *Provider) Chat(ctx context.Context, messages []Message, tools []ToolDef
}
}
+ // Prompt caching: pass a stable cache key so OpenAI can bucket requests
+ // with the same key and reuse prefix KV cache across calls.
+ // The key is typically the agent ID — stable per agent, shared across requests.
+ // See: https://platform.openai.com/docs/guides/prompt-caching
+ // Prompt caching is only supported by OpenAI-native endpoints.
+ // Gemini and other providers reject unknown fields, so skip for non-OpenAI APIs.
+ if cacheKey, ok := options["prompt_cache_key"].(string); ok && cacheKey != "" {
+ if !strings.Contains(p.apiBase, "generativelanguage.googleapis.com") {
+ requestBody["prompt_cache_key"] = cacheKey
+ }
+ }
+
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
@@ -139,8 +160,9 @@ func parseResponse(body []byte) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
- Content string `json:"content"`
- ToolCalls []struct {
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content"`
+ ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
@@ -173,7 +195,7 @@ func parseResponse(body []byte) (*LLMResponse, error) {
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
- arguments := make(map[string]interface{})
+ arguments := make(map[string]any)
name := ""
// Extract thought_signature from Gemini/Google-specific extra content
@@ -212,13 +234,40 @@ func parseResponse(body []byte) (*LLMResponse, error) {
}
return &LLMResponse{
- Content: choice.Message.Content,
- ToolCalls: toolCalls,
- FinishReason: choice.FinishReason,
- Usage: apiResponse.Usage,
+ Content: choice.Message.Content,
+ ReasoningContent: choice.Message.ReasoningContent,
+ ToolCalls: toolCalls,
+ FinishReason: choice.FinishReason,
+ Usage: apiResponse.Usage,
}, nil
}
+// openaiMessage is the wire-format message for OpenAI-compatible APIs.
+// It mirrors protocoltypes.Message but omits SystemParts, which is an
+// internal field that would be unknown to third-party endpoints.
+type openaiMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
+}
+
+// stripSystemParts converts []Message to []openaiMessage, dropping the
+// SystemParts field so it doesn't leak into the JSON payload sent to
+// OpenAI-compatible APIs (some strict endpoints reject unknown fields).
+func stripSystemParts(messages []Message) []openaiMessage {
+ out := make([]openaiMessage, len(messages))
+ for i, m := range messages {
+ out[i] = openaiMessage{
+ Role: m.Role,
+ Content: m.Content,
+ ToolCalls: m.ToolCalls,
+ ToolCallID: m.ToolCallID,
+ }
+ }
+ return out
+}
+
func normalizeModel(model, apiBase string) string {
idx := strings.Index(model, "/")
if idx == -1 {
@@ -231,14 +280,14 @@ func normalizeModel(model, apiBase string) string {
prefix := strings.ToLower(model[:idx])
switch prefix {
- case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu":
+ case "moonshot", "nvidia", "groq", "ollama", "deepseek", "google", "openrouter", "zhipu", "mistral":
return model[idx+1:]
default:
return model
}
}
-func asInt(v interface{}) (int, bool) {
+func asInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
@@ -253,7 +302,7 @@ func asInt(v interface{}) (int, bool) {
}
}
-func asFloat(v interface{}) (float64, bool) {
+func asFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
diff --git a/pkg/providers/openai_compat/provider_test.go b/pkg/providers/openai_compat/provider_test.go
index 94779b39c..594a48213 100644
--- a/pkg/providers/openai_compat/provider_test.go
+++ b/pkg/providers/openai_compat/provider_test.go
@@ -9,7 +9,7 @@ import (
)
func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
- var requestBody map[string]interface{}
+ var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
@@ -20,10 +20,10 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- resp := map[string]interface{}{
- "choices": []map[string]interface{}{
+ resp := map[string]any{
+ "choices": []map[string]any{
{
- "message": map[string]interface{}{"content": "ok"},
+ "message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
@@ -34,7 +34,13 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
defer server.Close()
p := NewProvider("key", server.URL, "")
- _, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.7", map[string]interface{}{"max_tokens": 1234})
+ _, err := p.Chat(
+ t.Context(),
+ []Message{{Role: "user", Content: "hi"}},
+ nil,
+ "glm-4.7",
+ map[string]any{"max_tokens": 1234},
+ )
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
@@ -49,16 +55,16 @@ func TestProviderChat_UsesMaxCompletionTokensForGLM(t *testing.T) {
func TestProviderChat_ParsesToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- resp := map[string]interface{}{
- "choices": []map[string]interface{}{
+ resp := map[string]any{
+ "choices": []map[string]any{
{
- "message": map[string]interface{}{
+ "message": map[string]any{
"content": "",
- "tool_calls": []map[string]interface{}{
+ "tool_calls": []map[string]any{
{
"id": "call_1",
"type": "function",
- "function": map[string]interface{}{
+ "function": map[string]any{
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
},
@@ -68,7 +74,7 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
"finish_reason": "tool_calls",
},
},
- "usage": map[string]interface{}{
+ "usage": map[string]any{
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
@@ -95,6 +101,50 @@ func TestProviderChat_ParsesToolCalls(t *testing.T) {
}
}
+func TestProviderChat_ParsesReasoningContent(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := map[string]any{
+ "choices": []map[string]any{
+ {
+ "message": map[string]any{
+ "content": "The answer is 2",
+ "reasoning_content": "Let me think step by step... 1+1=2",
+ "tool_calls": []map[string]any{
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": map[string]any{
+ "name": "calculator",
+ "arguments": "{\"expr\":\"1+1\"}",
+ },
+ },
+ },
+ },
+ "finish_reason": "tool_calls",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+ }))
+ defer server.Close()
+
+ p := NewProvider("key", server.URL, "")
+ out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "1+1=?"}}, nil, "kimi-k2.5", nil)
+ if err != nil {
+ t.Fatalf("Chat() error = %v", err)
+ }
+ if out.ReasoningContent != "Let me think step by step... 1+1=2" {
+ t.Fatalf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think step by step... 1+1=2")
+ }
+ if out.Content != "The answer is 2" {
+ t.Fatalf("Content = %q, want %q", out.Content, "The answer is 2")
+ }
+ if len(out.ToolCalls) != 1 {
+ t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
+ }
+}
+
func TestProviderChat_HTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "bad request", http.StatusBadRequest)
@@ -109,17 +159,17 @@ func TestProviderChat_HTTPError(t *testing.T) {
}
func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testing.T) {
- var requestBody map[string]interface{}
+ var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- resp := map[string]interface{}{
- "choices": []map[string]interface{}{
+ resp := map[string]any{
+ "choices": []map[string]any{
{
- "message": map[string]interface{}{"content": "ok"},
+ "message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
@@ -135,7 +185,7 @@ func TestProviderChat_StripsMoonshotPrefixAndNormalizesKimiTemperature(t *testin
[]Message{{Role: "user", Content: "hi"}},
nil,
"moonshot/kimi-k2.5",
- map[string]interface{}{"temperature": 0.3},
+ map[string]any{"temperature": 0.3},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
@@ -174,17 +224,17 @@ func TestProviderChat_StripsGroqAndOllamaPrefixes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- var requestBody map[string]interface{}
+ var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- resp := map[string]interface{}{
- "choices": []map[string]interface{}{
+ resp := map[string]any{
+ "choices": []map[string]any{
{
- "message": map[string]interface{}{"content": "ok"},
+ "message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
@@ -227,17 +277,17 @@ func TestProvider_ProxyConfigured(t *testing.T) {
}
func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
- var requestBody map[string]interface{}
+ var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
- resp := map[string]interface{}{
- "choices": []map[string]interface{}{
+ resp := map[string]any{
+ "choices": []map[string]any{
{
- "message": map[string]interface{}{"content": "ok"},
+ "message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
@@ -253,7 +303,7 @@ func TestProviderChat_AcceptsNumericOptionTypes(t *testing.T) {
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
- map[string]interface{}{"max_tokens": float64(512), "temperature": 1},
+ map[string]any{"max_tokens": float64(512), "temperature": 1},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go
index b7e7062b9..33f052c5a 100644
--- a/pkg/providers/protocoltypes/types.go
+++ b/pkg/providers/protocoltypes/types.go
@@ -1,13 +1,13 @@
package protocoltypes
type ToolCall struct {
- ID string `json:"id"`
- Type string `json:"type,omitempty"`
- Function *FunctionCall `json:"function,omitempty"`
- Name string `json:"name,omitempty"`
- Arguments map[string]interface{} `json:"arguments,omitempty"`
- ThoughtSignature string `json:"-"` // Internal use only
- ExtraContent *ExtraContent `json:"extra_content,omitempty"`
+ ID string `json:"id"`
+ Type string `json:"type,omitempty"`
+ Function *FunctionCall `json:"function,omitempty"`
+ Name string `json:"-"`
+ Arguments map[string]any `json:"-"`
+ ThoughtSignature string `json:"-"` // Internal use only
+ ExtraContent *ExtraContent `json:"extra_content,omitempty"`
}
type ExtraContent struct {
@@ -25,10 +25,11 @@ type FunctionCall struct {
}
type LLMResponse struct {
- Content string `json:"content"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
- FinishReason string `json:"finish_reason"`
- Usage *UsageInfo `json:"usage,omitempty"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ FinishReason string `json:"finish_reason"`
+ Usage *UsageInfo `json:"usage,omitempty"`
}
type UsageInfo struct {
@@ -37,11 +38,28 @@ type UsageInfo struct {
TotalTokens int `json:"total_tokens"`
}
+// CacheControl marks a content block for LLM-side prefix caching.
+// Currently only "ephemeral" is supported (used by Anthropic).
+type CacheControl struct {
+ Type string `json:"type"` // "ephemeral"
+}
+
+// ContentBlock represents a structured segment of a system message.
+// Adapters that understand SystemParts can use these blocks to set
+// per-block cache control (e.g. Anthropic's cache_control: ephemeral).
+type ContentBlock struct {
+ Type string `json:"type"` // "text"
+ Text string `json:"text"`
+ CacheControl *CacheControl `json:"cache_control,omitempty"`
+}
+
type Message struct {
- Role string `json:"role"`
- Content string `json:"content"`
- ToolCalls []ToolCall `json:"tool_calls,omitempty"`
- ToolCallID string `json:"tool_call_id,omitempty"`
+ Role string `json:"role"`
+ Content string `json:"content"`
+ ReasoningContent string `json:"reasoning_content,omitempty"`
+ SystemParts []ContentBlock `json:"system_parts,omitempty"` // structured system blocks for cache-aware adapters
+ ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+ ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolDefinition struct {
@@ -50,7 +68,7 @@ type ToolDefinition struct {
}
type ToolFunctionDefinition struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Parameters map[string]interface{} `json:"parameters"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters map[string]any `json:"parameters"`
}
diff --git a/pkg/providers/tool_call_extract.go b/pkg/providers/tool_call_extract.go
index 97a219283..7ddea0e99 100644
--- a/pkg/providers/tool_call_extract.go
+++ b/pkg/providers/tool_call_extract.go
@@ -38,7 +38,7 @@ func extractToolCallsFromText(text string) []ToolCall {
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
- var args map[string]interface{}
+ var args map[string]any
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
diff --git a/pkg/providers/toolcall_utils.go b/pkg/providers/toolcall_utils.go
index c7c35ef42..49218b1b1 100644
--- a/pkg/providers/toolcall_utils.go
+++ b/pkg/providers/toolcall_utils.go
@@ -20,12 +20,12 @@ func NormalizeToolCall(tc ToolCall) ToolCall {
// Ensure Arguments is not nil
if normalized.Arguments == nil {
- normalized.Arguments = map[string]interface{}{}
+ normalized.Arguments = map[string]any{}
}
// Parse Arguments from Function.Arguments if not already set
if len(normalized.Arguments) == 0 && normalized.Function != nil && normalized.Function.Arguments != "" {
- var parsed map[string]interface{}
+ var parsed map[string]any
if err := json.Unmarshal([]byte(normalized.Function.Arguments), &parsed); err == nil && parsed != nil {
normalized.Arguments = parsed
}
diff --git a/pkg/providers/types.go b/pkg/providers/types.go
index e783e6348..f0c168bc6 100644
--- a/pkg/providers/types.go
+++ b/pkg/providers/types.go
@@ -7,21 +7,36 @@ import (
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
-type ToolCall = protocoltypes.ToolCall
-type FunctionCall = protocoltypes.FunctionCall
-type LLMResponse = protocoltypes.LLMResponse
-type UsageInfo = protocoltypes.UsageInfo
-type Message = protocoltypes.Message
-type ToolDefinition = protocoltypes.ToolDefinition
-type ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
-type ExtraContent = protocoltypes.ExtraContent
-type GoogleExtra = protocoltypes.GoogleExtra
+type (
+ ToolCall = protocoltypes.ToolCall
+ FunctionCall = protocoltypes.FunctionCall
+ LLMResponse = protocoltypes.LLMResponse
+ UsageInfo = protocoltypes.UsageInfo
+ Message = protocoltypes.Message
+ ToolDefinition = protocoltypes.ToolDefinition
+ ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
+ ExtraContent = protocoltypes.ExtraContent
+ GoogleExtra = protocoltypes.GoogleExtra
+ ContentBlock = protocoltypes.ContentBlock
+ CacheControl = protocoltypes.CacheControl
+)
type LLMProvider interface {
- Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
+ Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+ ) (*LLMResponse, error)
GetDefaultModel() string
}
+type StatefulProvider interface {
+ LLMProvider
+ Close()
+}
+
// FailoverReason classifies why an LLM request failed for fallback decisions.
type FailoverReason string
diff --git a/pkg/session/manager.go b/pkg/session/manager.go
index 12bf33df0..08f0b0ad2 100644
--- a/pkg/session/manager.go
+++ b/pkg/session/manager.go
@@ -32,7 +32,7 @@ func NewSessionManager(storage string) *SessionManager {
}
if storage != "" {
- os.MkdirAll(storage, 0755)
+ os.MkdirAll(storage, 0o755)
sm.loadSessions()
}
@@ -214,7 +214,7 @@ func (sm *SessionManager) Save(key string) error {
_ = tmpFile.Close()
return err
}
- if err := tmpFile.Chmod(0644); err != nil {
+ if err := tmpFile.Chmod(0o644); err != nil {
_ = tmpFile.Close()
return err
}
diff --git a/pkg/skills/clawhub_registry.go b/pkg/skills/clawhub_registry.go
new file mode 100644
index 000000000..f78197bbe
--- /dev/null
+++ b/pkg/skills/clawhub_registry.go
@@ -0,0 +1,314 @@
+package skills
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+const (
+ defaultClawHubTimeout = 30 * time.Second
+ defaultMaxZipSize = 50 * 1024 * 1024 // 50 MB
+ defaultMaxResponseSize = 2 * 1024 * 1024 // 2 MB
+)
+
+// ClawHubRegistry implements SkillRegistry for the ClawHub platform.
+type ClawHubRegistry struct {
+ baseURL string
+ authToken string // Optional - for elevated rate limits
+ searchPath string // Search API
+ skillsPath string // For retrieving skill metadata
+ downloadPath string // For fetching ZIP files for download
+ maxZipSize int
+ maxResponseSize int
+ client *http.Client
+}
+
+// NewClawHubRegistry creates a new ClawHub registry client from config.
+func NewClawHubRegistry(cfg ClawHubConfig) *ClawHubRegistry {
+ baseURL := cfg.BaseURL
+ if baseURL == "" {
+ baseURL = "https://clawhub.ai"
+ }
+ searchPath := cfg.SearchPath
+ if searchPath == "" {
+ searchPath = "/api/v1/search"
+ }
+ skillsPath := cfg.SkillsPath
+ if skillsPath == "" {
+ skillsPath = "/api/v1/skills"
+ }
+ downloadPath := cfg.DownloadPath
+ if downloadPath == "" {
+ downloadPath = "/api/v1/download"
+ }
+
+ timeout := defaultClawHubTimeout
+ if cfg.Timeout > 0 {
+ timeout = time.Duration(cfg.Timeout) * time.Second
+ }
+
+ maxZip := defaultMaxZipSize
+ if cfg.MaxZipSize > 0 {
+ maxZip = cfg.MaxZipSize
+ }
+
+ maxResp := defaultMaxResponseSize
+ if cfg.MaxResponseSize > 0 {
+ maxResp = cfg.MaxResponseSize
+ }
+
+ return &ClawHubRegistry{
+ baseURL: baseURL,
+ authToken: cfg.AuthToken,
+ searchPath: searchPath,
+ skillsPath: skillsPath,
+ downloadPath: downloadPath,
+ maxZipSize: maxZip,
+ maxResponseSize: maxResp,
+ client: &http.Client{
+ Timeout: timeout,
+ Transport: &http.Transport{
+ MaxIdleConns: 5,
+ IdleConnTimeout: 30 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ },
+ },
+ }
+}
+
+func (c *ClawHubRegistry) Name() string {
+ return "clawhub"
+}
+
+// --- Search ---
+
+type clawhubSearchResponse struct {
+ Results []clawhubSearchResult `json:"results"`
+}
+
+type clawhubSearchResult struct {
+ Score float64 `json:"score"`
+ Slug *string `json:"slug"`
+ DisplayName *string `json:"displayName"`
+ Summary *string `json:"summary"`
+ Version *string `json:"version"`
+}
+
+func (c *ClawHubRegistry) Search(ctx context.Context, query string, limit int) ([]SearchResult, error) {
+ u, err := url.Parse(c.baseURL + c.searchPath)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base URL: %w", err)
+ }
+
+ q := u.Query()
+ q.Set("q", query)
+ if limit > 0 {
+ q.Set("limit", fmt.Sprintf("%d", limit))
+ }
+ u.RawQuery = q.Encode()
+
+ body, err := c.doGet(ctx, u.String())
+ if err != nil {
+ return nil, fmt.Errorf("search request failed: %w", err)
+ }
+
+ var resp clawhubSearchResponse
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return nil, fmt.Errorf("failed to parse search response: %w", err)
+ }
+
+ results := make([]SearchResult, 0, len(resp.Results))
+ for _, r := range resp.Results {
+ slug := utils.DerefStr(r.Slug, "")
+ if slug == "" {
+ continue
+ }
+
+ summary := utils.DerefStr(r.Summary, "")
+ if summary == "" {
+ continue
+ }
+
+ displayName := utils.DerefStr(r.DisplayName, "")
+ if displayName == "" {
+ displayName = slug
+ }
+
+ results = append(results, SearchResult{
+ Score: r.Score,
+ Slug: slug,
+ DisplayName: displayName,
+ Summary: summary,
+ Version: utils.DerefStr(r.Version, ""),
+ RegistryName: c.Name(),
+ })
+ }
+
+ return results, nil
+}
+
+// --- GetSkillMeta ---
+
+type clawhubSkillResponse struct {
+ Slug string `json:"slug"`
+ DisplayName string `json:"displayName"`
+ Summary string `json:"summary"`
+ LatestVersion *clawhubVersionInfo `json:"latestVersion"`
+ Moderation *clawhubModerationInfo `json:"moderation"`
+}
+
+type clawhubVersionInfo struct {
+ Version string `json:"version"`
+}
+
+type clawhubModerationInfo struct {
+ IsMalwareBlocked bool `json:"isMalwareBlocked"`
+ IsSuspicious bool `json:"isSuspicious"`
+}
+
+func (c *ClawHubRegistry) GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error) {
+ if err := utils.ValidateSkillIdentifier(slug); err != nil {
+ return nil, fmt.Errorf("invalid slug %q: error: %s", slug, err.Error())
+ }
+
+ u := c.baseURL + c.skillsPath + "/" + url.PathEscape(slug)
+
+ body, err := c.doGet(ctx, u)
+ if err != nil {
+ return nil, fmt.Errorf("skill metadata request failed: %w", err)
+ }
+
+ var resp clawhubSkillResponse
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return nil, fmt.Errorf("failed to parse skill metadata: %w", err)
+ }
+
+ meta := &SkillMeta{
+ Slug: resp.Slug,
+ DisplayName: resp.DisplayName,
+ Summary: resp.Summary,
+ RegistryName: c.Name(),
+ }
+
+ if resp.LatestVersion != nil {
+ meta.LatestVersion = resp.LatestVersion.Version
+ }
+ if resp.Moderation != nil {
+ meta.IsMalwareBlocked = resp.Moderation.IsMalwareBlocked
+ meta.IsSuspicious = resp.Moderation.IsSuspicious
+ }
+
+ return meta, nil
+}
+
+// --- DownloadAndInstall ---
+
+// DownloadAndInstall fetches metadata (with fallback), resolves version,
+// downloads the skill ZIP, and extracts it to targetDir.
+// Returns an InstallResult for the caller to use for moderation decisions.
+func (c *ClawHubRegistry) DownloadAndInstall(
+ ctx context.Context,
+ slug, version, targetDir string,
+) (*InstallResult, error) {
+ if err := utils.ValidateSkillIdentifier(slug); err != nil {
+ return nil, fmt.Errorf("invalid slug %q: error: %s", slug, err.Error())
+ }
+
+ // Step 1: Fetch metadata (with fallback).
+ result := &InstallResult{}
+ meta, err := c.GetSkillMeta(ctx, slug)
+ if err != nil {
+ // Fallback: proceed without metadata.
+ meta = nil
+ }
+
+ if meta != nil {
+ result.IsMalwareBlocked = meta.IsMalwareBlocked
+ result.IsSuspicious = meta.IsSuspicious
+ result.Summary = meta.Summary
+ }
+
+ // Step 2: Resolve version.
+ installVersion := version
+ if installVersion == "" && meta != nil {
+ installVersion = meta.LatestVersion
+ }
+ if installVersion == "" {
+ installVersion = "latest"
+ }
+ result.Version = installVersion
+
+ // Step 3: Download ZIP to temp file (streams in ~32KB chunks).
+ u, err := url.Parse(c.baseURL + c.downloadPath)
+ if err != nil {
+ return nil, fmt.Errorf("invalid base URL: %w", err)
+ }
+
+ q := u.Query()
+ q.Set("slug", slug)
+ if installVersion != "latest" {
+ q.Set("version", installVersion)
+ }
+ u.RawQuery = q.Encode()
+
+ req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+ if c.authToken != "" {
+ req.Header.Set("Authorization", "Bearer "+c.authToken)
+ }
+
+ tmpPath, err := utils.DownloadToFile(ctx, c.client, req, int64(c.maxZipSize))
+ if err != nil {
+ return nil, fmt.Errorf("download failed: %w", err)
+ }
+ defer os.Remove(tmpPath)
+
+ // Step 4: Extract from file on disk.
+ if err := utils.ExtractZipFile(tmpPath, targetDir); err != nil {
+ return nil, err
+ }
+
+ return result, nil
+}
+
+// --- HTTP helper ---
+
+func (c *ClawHubRegistry) doGet(ctx context.Context, urlStr string) ([]byte, error) {
+ req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ req.Header.Set("Accept", "application/json")
+ if c.authToken != "" {
+ req.Header.Set("Authorization", "Bearer "+c.authToken)
+ }
+
+ resp, err := c.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ // Limit response body read to prevent memory issues.
+ body, err := io.ReadAll(io.LimitReader(resp.Body, int64(c.maxResponseSize)))
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
+ }
+
+ return body, nil
+}
diff --git a/pkg/skills/clawhub_registry_test.go b/pkg/skills/clawhub_registry_test.go
new file mode 100644
index 000000000..65ee638da
--- /dev/null
+++ b/pkg/skills/clawhub_registry_test.go
@@ -0,0 +1,257 @@
+package skills
+
+import (
+ "archive/zip"
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+func newTestRegistry(serverURL, authToken string) *ClawHubRegistry {
+ return NewClawHubRegistry(ClawHubConfig{
+ Enabled: true,
+ BaseURL: serverURL,
+ AuthToken: authToken,
+ })
+}
+
+func TestClawHubRegistrySearch(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "/api/v1/search", r.URL.Path)
+ assert.Equal(t, "github", r.URL.Query().Get("q"))
+
+ slug := "github"
+ name := "GitHub Integration"
+ summary := "Interact with GitHub repos"
+ version := "1.0.0"
+
+ json.NewEncoder(w).Encode(clawhubSearchResponse{
+ Results: []clawhubSearchResult{
+ {Score: 0.95, Slug: &slug, DisplayName: &name, Summary: &summary, Version: &version},
+ },
+ })
+ }))
+ defer srv.Close()
+
+ reg := newTestRegistry(srv.URL, "")
+ results, err := reg.Search(context.Background(), "github", 5)
+
+ require.NoError(t, err)
+ require.Len(t, results, 1)
+ assert.Equal(t, "github", results[0].Slug)
+ assert.Equal(t, "GitHub Integration", results[0].DisplayName)
+ assert.InDelta(t, 0.95, results[0].Score, 0.001)
+ assert.Equal(t, "clawhub", results[0].RegistryName)
+}
+
+func TestClawHubRegistryGetSkillMeta(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ assert.Equal(t, "/api/v1/skills/github", r.URL.Path)
+
+ json.NewEncoder(w).Encode(clawhubSkillResponse{
+ Slug: "github",
+ DisplayName: "GitHub Integration",
+ Summary: "Full GitHub API integration",
+ LatestVersion: &clawhubVersionInfo{
+ Version: "2.1.0",
+ },
+ Moderation: &clawhubModerationInfo{
+ IsMalwareBlocked: false,
+ IsSuspicious: true,
+ },
+ })
+ }))
+ defer srv.Close()
+
+ reg := newTestRegistry(srv.URL, "")
+ meta, err := reg.GetSkillMeta(context.Background(), "github")
+
+ require.NoError(t, err)
+ assert.Equal(t, "github", meta.Slug)
+ assert.Equal(t, "2.1.0", meta.LatestVersion)
+ assert.False(t, meta.IsMalwareBlocked)
+ assert.True(t, meta.IsSuspicious)
+}
+
+func TestClawHubRegistryGetSkillMetaUnsafeSlug(t *testing.T) {
+ reg := newTestRegistry("https://example.com", "")
+ _, err := reg.GetSkillMeta(context.Background(), "../etc/passwd")
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "invalid slug")
+}
+
+func TestClawHubRegistryDownloadAndInstall(t *testing.T) {
+ // Create a valid ZIP in memory.
+ zipBuf := createTestZip(t, map[string]string{
+ "SKILL.md": "---\nname: test-skill\ndescription: A test\n---\nHello skill",
+ "README.md": "# Test Skill\n",
+ })
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/v1/skills/test-skill":
+ // Metadata endpoint.
+ json.NewEncoder(w).Encode(clawhubSkillResponse{
+ Slug: "test-skill",
+ DisplayName: "Test Skill",
+ Summary: "A test skill",
+ LatestVersion: &clawhubVersionInfo{Version: "1.0.0"},
+ })
+ case "/api/v1/download":
+ assert.Equal(t, "test-skill", r.URL.Query().Get("slug"))
+ w.Header().Set("Content-Type", "application/zip")
+ w.Write(zipBuf)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ tmpDir := t.TempDir()
+ targetDir := filepath.Join(tmpDir, "test-skill")
+
+ reg := newTestRegistry(srv.URL, "")
+ result, err := reg.DownloadAndInstall(context.Background(), "test-skill", "1.0.0", targetDir)
+
+ require.NoError(t, err)
+ assert.Equal(t, "1.0.0", result.Version)
+ assert.False(t, result.IsMalwareBlocked)
+
+ // Verify extracted files.
+ skillContent, err := os.ReadFile(filepath.Join(targetDir, "SKILL.md"))
+ require.NoError(t, err)
+ assert.Contains(t, string(skillContent), "Hello skill")
+
+ readmeContent, err := os.ReadFile(filepath.Join(targetDir, "README.md"))
+ require.NoError(t, err)
+ assert.Contains(t, string(readmeContent), "# Test Skill")
+}
+
+func TestClawHubRegistryAuthToken(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authHeader := r.Header.Get("Authorization")
+ assert.Equal(t, "Bearer test-token-123", authHeader)
+ json.NewEncoder(w).Encode(clawhubSearchResponse{Results: nil})
+ }))
+ defer srv.Close()
+
+ reg := newTestRegistry(srv.URL, "test-token-123")
+ _, _ = reg.Search(context.Background(), "test", 5)
+}
+
+func TestExtractZipPathTraversal(t *testing.T) {
+ // Create a ZIP with a path traversal entry.
+ var buf bytes.Buffer
+ zw := zip.NewWriter(&buf)
+
+ // Malicious entry trying to escape directory.
+ w, err := zw.Create("../../etc/passwd")
+ require.NoError(t, err)
+ w.Write([]byte("malicious"))
+
+ zw.Close()
+
+ // Write to temp file for extractZipFile.
+ tmpZip := filepath.Join(t.TempDir(), "bad.zip")
+ require.NoError(t, os.WriteFile(tmpZip, buf.Bytes(), 0o644))
+
+ tmpDir := t.TempDir()
+ err = utils.ExtractZipFile(tmpZip, tmpDir)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "unsafe path")
+}
+
+func TestExtractZipWithSubdirectories(t *testing.T) {
+ zipBuf := createTestZip(t, map[string]string{
+ "SKILL.md": "root file",
+ "scripts/helper.sh": "#!/bin/bash\necho hello",
+ "examples/demo.yaml": "key: value",
+ })
+
+ // Write to temp file for extractZipFile.
+ tmpZip := filepath.Join(t.TempDir(), "test.zip")
+ require.NoError(t, os.WriteFile(tmpZip, zipBuf, 0o644))
+
+ tmpDir := t.TempDir()
+ targetDir := filepath.Join(tmpDir, "my-skill")
+
+ err := utils.ExtractZipFile(tmpZip, targetDir)
+ require.NoError(t, err)
+
+ // Verify nested file.
+ data, err := os.ReadFile(filepath.Join(targetDir, "scripts", "helper.sh"))
+ require.NoError(t, err)
+ assert.Contains(t, string(data), "#!/bin/bash")
+}
+
+func TestClawHubRegistrySearchHTTPError(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("Internal Server Error"))
+ }))
+ defer srv.Close()
+
+ reg := newTestRegistry(srv.URL, "")
+ _, err := reg.Search(context.Background(), "test", 5)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "500")
+}
+
+func TestClawHubRegistrySearchNullableFields(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ validSlug := "valid-slug"
+ validSummary := "valid summary"
+
+ // Return results with various null/empty fields
+ json.NewEncoder(w).Encode(clawhubSearchResponse{
+ Results: []clawhubSearchResult{
+ // Case 1: Null Slug -> Skip
+ {Score: 0.1, Slug: nil, DisplayName: nil, Summary: nil, Version: nil},
+ // Case 2: Valid Slug, Null Summary -> Skip
+ {Score: 0.2, Slug: &validSlug, DisplayName: nil, Summary: nil, Version: nil},
+ // Case 3: Valid Slug, Valid Summary, Null Name -> Keep, Name=Slug
+ {Score: 0.8, Slug: &validSlug, DisplayName: nil, Summary: &validSummary, Version: nil},
+ },
+ })
+ }))
+ defer srv.Close()
+
+ reg := newTestRegistry(srv.URL, "")
+ results, err := reg.Search(context.Background(), "test", 5)
+
+ require.NoError(t, err)
+ require.Len(t, results, 1, "should only return 1 valid result")
+
+ r := results[0]
+ assert.Equal(t, "valid-slug", r.Slug)
+ assert.Equal(t, "valid-slug", r.DisplayName, "should fallback name to slug")
+ assert.Equal(t, "valid summary", r.Summary)
+}
+
+// --- helpers ---
+
+func createTestZip(t *testing.T, files map[string]string) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ zw := zip.NewWriter(&buf)
+
+ for name, content := range files {
+ w, err := zw.Create(name)
+ require.NoError(t, err)
+ _, err = w.Write([]byte(content))
+ require.NoError(t, err)
+ }
+
+ require.NoError(t, zw.Close())
+ return buf.Bytes()
+}
diff --git a/pkg/skills/installer.go b/pkg/skills/installer.go
index 0856254e8..3210509df 100644
--- a/pkg/skills/installer.go
+++ b/pkg/skills/installer.go
@@ -59,12 +59,12 @@ func (si *SkillInstaller) InstallFromGitHub(ctx context.Context, repo string) er
return fmt.Errorf("failed to read response: %w", err)
}
- if err := os.MkdirAll(skillDir, 0755); err != nil {
+ if err := os.MkdirAll(skillDir, 0o755); err != nil {
return fmt.Errorf("failed to create skill directory: %w", err)
}
skillPath := filepath.Join(skillDir, "SKILL.md")
- if err := os.WriteFile(skillPath, body, 0644); err != nil {
+ if err := os.WriteFile(skillPath, body, 0o644); err != nil {
return fmt.Errorf("failed to write skill file: %w", err)
}
diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go
index bb0abbdcc..5749d8983 100644
--- a/pkg/skills/loader.go
+++ b/pkg/skills/loader.go
@@ -55,9 +55,9 @@ func (info SkillInfo) validate() error {
type SkillsLoader struct {
workspace string
- workspaceSkills string // workspace skills (项目级别)
- globalSkills string // 全局 skills (~/.picoclaw/skills)
- builtinSkills string // 内置 skills
+ workspaceSkills string // workspace skills (project-level)
+ globalSkills string // global skills (~/.picoclaw/skills)
+ builtinSkills string // builtin skills
}
func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string) *SkillsLoader {
@@ -71,118 +71,56 @@ func NewSkillsLoader(workspace string, globalSkills string, builtinSkills string
func (sl *SkillsLoader) ListSkills() []SkillInfo {
skills := make([]SkillInfo, 0)
+ seen := make(map[string]bool)
- if sl.workspaceSkills != "" {
- if dirs, err := os.ReadDir(sl.workspaceSkills); err == nil {
- for _, dir := range dirs {
- if dir.IsDir() {
- skillFile := filepath.Join(sl.workspaceSkills, dir.Name(), "SKILL.md")
- if _, err := os.Stat(skillFile); err == nil {
- info := SkillInfo{
- Name: dir.Name(),
- Path: skillFile,
- Source: "workspace",
- }
- metadata := sl.getSkillMetadata(skillFile)
- if metadata != nil {
- info.Description = metadata.Description
- info.Name = metadata.Name
- }
- if err := info.validate(); err != nil {
- slog.Warn("invalid skill from workspace", "name", info.Name, "error", err)
- continue
- }
- skills = append(skills, info)
- }
- }
+ addSkills := func(dir, source string) {
+ if dir == "" {
+ return
+ }
+ dirs, err := os.ReadDir(dir)
+ if err != nil {
+ return
+ }
+ for _, d := range dirs {
+ if !d.IsDir() {
+ continue
}
+ skillFile := filepath.Join(dir, d.Name(), "SKILL.md")
+ if _, err := os.Stat(skillFile); err != nil {
+ continue
+ }
+ info := SkillInfo{
+ Name: d.Name(),
+ Path: skillFile,
+ Source: source,
+ }
+ metadata := sl.getSkillMetadata(skillFile)
+ if metadata != nil {
+ info.Description = metadata.Description
+ info.Name = metadata.Name
+ }
+ if err := info.validate(); err != nil {
+ slog.Warn("invalid skill from "+source, "name", info.Name, "error", err)
+ continue
+ }
+ if seen[info.Name] {
+ continue
+ }
+ seen[info.Name] = true
+ skills = append(skills, info)
}
}
- // 全局 skills (~/.picoclaw/skills) - 被 workspace skills 覆盖
- if sl.globalSkills != "" {
- if dirs, err := os.ReadDir(sl.globalSkills); err == nil {
- for _, dir := range dirs {
- if dir.IsDir() {
- skillFile := filepath.Join(sl.globalSkills, dir.Name(), "SKILL.md")
- if _, err := os.Stat(skillFile); err == nil {
- // 检查是否已被 workspace skills 覆盖
- exists := false
- for _, s := range skills {
- if s.Name == dir.Name() && s.Source == "workspace" {
- exists = true
- break
- }
- }
- if exists {
- continue
- }
-
- info := SkillInfo{
- Name: dir.Name(),
- Path: skillFile,
- Source: "global",
- }
- metadata := sl.getSkillMetadata(skillFile)
- if metadata != nil {
- info.Description = metadata.Description
- info.Name = metadata.Name
- }
- if err := info.validate(); err != nil {
- slog.Warn("invalid skill from global", "name", info.Name, "error", err)
- continue
- }
- skills = append(skills, info)
- }
- }
- }
- }
- }
-
- if sl.builtinSkills != "" {
- if dirs, err := os.ReadDir(sl.builtinSkills); err == nil {
- for _, dir := range dirs {
- if dir.IsDir() {
- skillFile := filepath.Join(sl.builtinSkills, dir.Name(), "SKILL.md")
- if _, err := os.Stat(skillFile); err == nil {
- // 检查是否已被 workspace 或 global skills 覆盖
- exists := false
- for _, s := range skills {
- if s.Name == dir.Name() && (s.Source == "workspace" || s.Source == "global") {
- exists = true
- break
- }
- }
- if exists {
- continue
- }
-
- info := SkillInfo{
- Name: dir.Name(),
- Path: skillFile,
- Source: "builtin",
- }
- metadata := sl.getSkillMetadata(skillFile)
- if metadata != nil {
- info.Description = metadata.Description
- info.Name = metadata.Name
- }
- if err := info.validate(); err != nil {
- slog.Warn("invalid skill from builtin", "name", info.Name, "error", err)
- continue
- }
- skills = append(skills, info)
- }
- }
- }
- }
- }
+ // Priority: workspace > global > builtin
+ addSkills(sl.workspaceSkills, "workspace")
+ addSkills(sl.globalSkills, "global")
+ addSkills(sl.builtinSkills, "builtin")
return skills
}
func (sl *SkillsLoader) LoadSkill(name string) (string, bool) {
- // 1. 优先从 workspace skills 加载(项目级别)
+ // 1. load from workspace skills first (project-level)
if sl.workspaceSkills != "" {
skillFile := filepath.Join(sl.workspaceSkills, name, "SKILL.md")
if content, err := os.ReadFile(skillFile); err == nil {
@@ -190,7 +128,7 @@ func (sl *SkillsLoader) LoadSkill(name string) (string, bool) {
}
}
- // 2. 其次从全局 skills 加载 (~/.picoclaw/skills)
+ // 2. then load from global skills (~/.picoclaw/skills)
if sl.globalSkills != "" {
skillFile := filepath.Join(sl.globalSkills, name, "SKILL.md")
if content, err := os.ReadFile(skillFile); err == nil {
@@ -198,7 +136,7 @@ func (sl *SkillsLoader) LoadSkill(name string) (string, bool) {
}
}
- // 3. 最后从内置 skills 加载
+ // 3. finally load from builtin skills
if sl.builtinSkills != "" {
skillFile := filepath.Join(sl.builtinSkills, name, "SKILL.md")
if content, err := os.ReadFile(skillFile); err == nil {
@@ -254,7 +192,7 @@ func (sl *SkillsLoader) getSkillMetadata(skillPath string) *SkillMetadata {
content, err := os.ReadFile(skillPath)
if err != nil {
logger.WarnCF("skills", "Failed to read skill metadata",
- map[string]interface{}{
+ map[string]any{
"skill_path": skillPath,
"error": err.Error(),
})
diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go
index efadcdbf2..9428bea62 100644
--- a/pkg/skills/loader_test.go
+++ b/pkg/skills/loader_test.go
@@ -1,9 +1,12 @@
package skills
import (
+ "os"
+ "path/filepath"
"testing"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestSkillsInfoValidate(t *testing.T) {
@@ -117,12 +120,152 @@ func TestExtractFrontmatter(t *testing.T) {
// Parse YAML to get name and description (parseSimpleYAML now handles all line ending types)
yamlMeta := sl.parseSimpleYAML(frontmatter)
- assert.Equal(t, tc.expectedName, yamlMeta["name"], "Name should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
- assert.Equal(t, tc.expectedDesc, yamlMeta["description"], "Description should be correctly parsed from frontmatter with %s line endings", tc.lineEndingType)
+ assert.Equal(
+ t,
+ tc.expectedName,
+ yamlMeta["name"],
+ "Name should be correctly parsed from frontmatter with %s line endings",
+ tc.lineEndingType,
+ )
+ assert.Equal(
+ t,
+ tc.expectedDesc,
+ yamlMeta["description"],
+ "Description should be correctly parsed from frontmatter with %s line endings",
+ tc.lineEndingType,
+ )
})
}
}
+// createSkillDir creates a skill directory with a SKILL.md file containing the given frontmatter.
+func createSkillDir(t *testing.T, base, dirName, name, description string) {
+ t.Helper()
+ dir := filepath.Join(base, dirName)
+ require.NoError(t, os.MkdirAll(dir, 0o755))
+ content := "---\nname: " + name + "\ndescription: " + description + "\n---\n\n# " + name
+ require.NoError(t, os.WriteFile(filepath.Join(dir, "SKILL.md"), []byte(content), 0o644))
+}
+
+func TestListSkillsWorkspaceOverridesGlobal(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+
+ createSkillDir(t, filepath.Join(ws, "skills"), "my-skill", "my-skill", "workspace version")
+ createSkillDir(t, global, "my-skill", "my-skill", "global version")
+
+ sl := NewSkillsLoader(ws, global, "")
+ skills := sl.ListSkills()
+
+ assert.Len(t, skills, 1)
+ assert.Equal(t, "workspace", skills[0].Source)
+ assert.Equal(t, "workspace version", skills[0].Description)
+}
+
+func TestListSkillsGlobalOverridesBuiltin(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+ builtin := filepath.Join(tmp, "builtin")
+
+ createSkillDir(t, global, "my-skill", "my-skill", "global version")
+ createSkillDir(t, builtin, "my-skill", "my-skill", "builtin version")
+
+ sl := NewSkillsLoader(ws, global, builtin)
+ skills := sl.ListSkills()
+
+ assert.Len(t, skills, 1)
+ assert.Equal(t, "global", skills[0].Source)
+ assert.Equal(t, "global version", skills[0].Description)
+}
+
+func TestListSkillsMetadataNameDedup(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+
+ // Different directory names but same metadata name
+ createSkillDir(t, filepath.Join(ws, "skills"), "dir-a", "shared-name", "workspace version")
+ createSkillDir(t, global, "dir-b", "shared-name", "global version")
+
+ sl := NewSkillsLoader(ws, global, "")
+ skills := sl.ListSkills()
+
+ assert.Len(t, skills, 1)
+ assert.Equal(t, "shared-name", skills[0].Name)
+ assert.Equal(t, "workspace", skills[0].Source)
+}
+
+func TestListSkillsMultipleDistinctSkills(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+ builtin := filepath.Join(tmp, "builtin")
+
+ createSkillDir(t, filepath.Join(ws, "skills"), "skill-a", "skill-a", "desc a")
+ createSkillDir(t, global, "skill-b", "skill-b", "desc b")
+ createSkillDir(t, builtin, "skill-c", "skill-c", "desc c")
+
+ sl := NewSkillsLoader(ws, global, builtin)
+ skills := sl.ListSkills()
+
+ assert.Len(t, skills, 3)
+ names := map[string]string{}
+ for _, s := range skills {
+ names[s.Name] = s.Source
+ }
+ assert.Equal(t, "workspace", names["skill-a"])
+ assert.Equal(t, "global", names["skill-b"])
+ assert.Equal(t, "builtin", names["skill-c"])
+}
+
+func TestListSkillsInvalidSkillSkipped(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+
+ // Invalid name (underscore)
+ createSkillDir(t, filepath.Join(ws, "skills"), "bad_skill", "bad_skill", "desc")
+ // Valid skill
+ createSkillDir(t, global, "good-skill", "good-skill", "desc")
+
+ sl := NewSkillsLoader(ws, global, "")
+ skills := sl.ListSkills()
+
+ assert.Len(t, skills, 1)
+ assert.Equal(t, "good-skill", skills[0].Name)
+}
+
+func TestListSkillsEmptyAndNonexistentDirs(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ emptyDir := filepath.Join(tmp, "empty")
+ require.NoError(t, os.MkdirAll(emptyDir, 0o755))
+
+ sl := NewSkillsLoader(ws, emptyDir, filepath.Join(tmp, "nonexistent"))
+ skills := sl.ListSkills()
+
+ assert.Empty(t, skills)
+}
+
+func TestListSkillsDirWithoutSkillMD(t *testing.T) {
+ tmp := t.TempDir()
+ ws := filepath.Join(tmp, "workspace")
+ global := filepath.Join(tmp, "global")
+
+ // Directory exists but has no SKILL.md
+ require.NoError(t, os.MkdirAll(filepath.Join(global, "no-skillmd"), 0o755))
+ // Valid skill alongside
+ createSkillDir(t, global, "real-skill", "real-skill", "desc")
+
+ sl := NewSkillsLoader(ws, global, "")
+ skills := sl.ListSkills()
+
+ assert.Len(t, skills, 1)
+ assert.Equal(t, "real-skill", skills[0].Name)
+}
+
func TestStripFrontmatter(t *testing.T) {
sl := &SkillsLoader{}
@@ -173,7 +316,13 @@ func TestStripFrontmatter(t *testing.T) {
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := sl.stripFrontmatter(tc.content)
- assert.Equal(t, tc.expectedContent, result, "Frontmatter should be stripped correctly for %s", tc.lineEndingType)
+ assert.Equal(
+ t,
+ tc.expectedContent,
+ result,
+ "Frontmatter should be stripped correctly for %s",
+ tc.lineEndingType,
+ )
})
}
}
diff --git a/pkg/skills/registry.go b/pkg/skills/registry.go
new file mode 100644
index 000000000..45ae72253
--- /dev/null
+++ b/pkg/skills/registry.go
@@ -0,0 +1,223 @@
+package skills
+
+import (
+ "context"
+ "fmt"
+ "log/slog"
+ "sync"
+ "time"
+)
+
+const (
+ defaultMaxConcurrentSearches = 2
+)
+
+// SearchResult represents a single result from a skill registry search.
+type SearchResult struct {
+ Score float64 `json:"score"`
+ Slug string `json:"slug"`
+ DisplayName string `json:"display_name"`
+ Summary string `json:"summary"`
+ Version string `json:"version"`
+ RegistryName string `json:"registry_name"`
+}
+
+// SkillMeta holds metadata about a skill from a registry.
+type SkillMeta struct {
+ Slug string `json:"slug"`
+ DisplayName string `json:"display_name"`
+ Summary string `json:"summary"`
+ LatestVersion string `json:"latest_version"`
+ IsMalwareBlocked bool `json:"is_malware_blocked"`
+ IsSuspicious bool `json:"is_suspicious"`
+ RegistryName string `json:"registry_name"`
+}
+
+// InstallResult is returned by DownloadAndInstall to carry metadata
+// back to the caller for moderation and user messaging.
+type InstallResult struct {
+ Version string
+ IsMalwareBlocked bool
+ IsSuspicious bool
+ Summary string
+}
+
+// SkillRegistry is the interface that all skill registries must implement.
+// Each registry represents a different source of skills (e.g., clawhub.ai)
+type SkillRegistry interface {
+ // Name returns the unique name of this registry (e.g., "clawhub").
+ Name() string
+ // Search searches the registry for skills matching the query.
+ Search(ctx context.Context, query string, limit int) ([]SearchResult, error)
+ // GetSkillMeta retrieves metadata for a specific skill by slug.
+ GetSkillMeta(ctx context.Context, slug string) (*SkillMeta, error)
+ // DownloadAndInstall fetches metadata, resolves the version, downloads and
+ // installs the skill to targetDir. Returns an InstallResult with metadata
+ // for the caller to use for moderation and user messaging.
+ DownloadAndInstall(ctx context.Context, slug, version, targetDir string) (*InstallResult, error)
+}
+
+// RegistryConfig holds configuration for all skill registries.
+// This is the input to NewRegistryManagerFromConfig.
+type RegistryConfig struct {
+ ClawHub ClawHubConfig
+ MaxConcurrentSearches int
+}
+
+// ClawHubConfig configures the ClawHub registry.
+type ClawHubConfig struct {
+ Enabled bool
+ BaseURL string
+ AuthToken string
+ SearchPath string // e.g. "/api/v1/search"
+ SkillsPath string // e.g. "/api/v1/skills"
+ DownloadPath string // e.g. "/api/v1/download"
+ Timeout int // seconds, 0 = default (30s)
+ MaxZipSize int // bytes, 0 = default (50MB)
+ MaxResponseSize int // bytes, 0 = default (2MB)
+}
+
+// RegistryManager coordinates multiple skill registries.
+// It fans out search requests and routes installs to the correct registry.
+type RegistryManager struct {
+ registries []SkillRegistry
+ maxConcurrent int
+ mu sync.RWMutex
+}
+
+// NewRegistryManager creates an empty RegistryManager.
+func NewRegistryManager() *RegistryManager {
+ return &RegistryManager{
+ registries: make([]SkillRegistry, 0),
+ maxConcurrent: defaultMaxConcurrentSearches,
+ }
+}
+
+// NewRegistryManagerFromConfig builds a RegistryManager from config,
+// instantiating only the enabled registries.
+func NewRegistryManagerFromConfig(cfg RegistryConfig) *RegistryManager {
+ rm := NewRegistryManager()
+ if cfg.MaxConcurrentSearches > 0 {
+ rm.maxConcurrent = cfg.MaxConcurrentSearches
+ }
+ if cfg.ClawHub.Enabled {
+ rm.AddRegistry(NewClawHubRegistry(cfg.ClawHub))
+ }
+ return rm
+}
+
+// AddRegistry adds a registry to the manager.
+func (rm *RegistryManager) AddRegistry(r SkillRegistry) {
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
+ rm.registries = append(rm.registries, r)
+}
+
+// GetRegistry returns a registry by name, or nil if not found.
+func (rm *RegistryManager) GetRegistry(name string) SkillRegistry {
+ rm.mu.RLock()
+ defer rm.mu.RUnlock()
+ for _, r := range rm.registries {
+ if r.Name() == name {
+ return r
+ }
+ }
+ return nil
+}
+
+// SearchAll fans out the query to all registries concurrently
+// and merges results sorted by score descending.
+func (rm *RegistryManager) SearchAll(ctx context.Context, query string, limit int) ([]SearchResult, error) {
+ rm.mu.RLock()
+ regs := make([]SkillRegistry, len(rm.registries))
+ copy(regs, rm.registries)
+ rm.mu.RUnlock()
+
+ if len(regs) == 0 {
+ return nil, fmt.Errorf("no registries configured")
+ }
+
+ type regResult struct {
+ results []SearchResult
+ err error
+ }
+
+ // Semaphore: limit concurrency.
+ sem := make(chan struct{}, rm.maxConcurrent)
+ resultsCh := make(chan regResult, len(regs))
+
+ var wg sync.WaitGroup
+ for _, reg := range regs {
+ wg.Add(1)
+ go func(r SkillRegistry) {
+ defer wg.Done()
+
+ // Acquire semaphore slot.
+ select {
+ case sem <- struct{}{}:
+ defer func() { <-sem }()
+ case <-ctx.Done():
+ resultsCh <- regResult{err: ctx.Err()}
+ return
+ }
+
+ searchCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
+ defer cancel()
+
+ results, err := r.Search(searchCtx, query, limit)
+ if err != nil {
+ slog.Warn("registry search failed", "registry", r.Name(), "error", err)
+ resultsCh <- regResult{err: err}
+ return
+ }
+ resultsCh <- regResult{results: results}
+ }(reg)
+ }
+
+ // Close results channel after all goroutines complete.
+ go func() {
+ wg.Wait()
+ close(resultsCh)
+ }()
+
+ var merged []SearchResult
+ var lastErr error
+
+ var anyRegistrySucceeded bool
+ for rr := range resultsCh {
+ if rr.err != nil {
+ lastErr = rr.err
+ continue
+ }
+ anyRegistrySucceeded = true
+ merged = append(merged, rr.results...)
+ }
+
+ // If all registries failed, return the last error.
+ if !anyRegistrySucceeded && lastErr != nil {
+ return nil, fmt.Errorf("all registries failed: %w", lastErr)
+ }
+
+ // Sort by score descending.
+ sortByScoreDesc(merged)
+
+ // Clamp to limit.
+ if limit > 0 && len(merged) > limit {
+ merged = merged[:limit]
+ }
+
+ return merged, nil
+}
+
+// sortByScoreDesc sorts SearchResults by Score in descending order (insertion sort — small slices).
+func sortByScoreDesc(results []SearchResult) {
+ for i := 1; i < len(results); i++ {
+ key := results[i]
+ j := i - 1
+ for j >= 0 && results[j].Score < key.Score {
+ results[j+1] = results[j]
+ j--
+ }
+ results[j+1] = key
+ }
+}
diff --git a/pkg/skills/registry_test.go b/pkg/skills/registry_test.go
new file mode 100644
index 000000000..a4694bd43
--- /dev/null
+++ b/pkg/skills/registry_test.go
@@ -0,0 +1,180 @@
+package skills
+
+import (
+ "context"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+// mockRegistry is a test double implementing SkillRegistry.
+type mockRegistry struct {
+ name string
+ searchResults []SearchResult
+ searchErr error
+ meta *SkillMeta
+ metaErr error
+ installResult *InstallResult
+ installErr error
+}
+
+func (m *mockRegistry) Name() string { return m.name }
+
+func (m *mockRegistry) Search(_ context.Context, _ string, _ int) ([]SearchResult, error) {
+ return m.searchResults, m.searchErr
+}
+
+func (m *mockRegistry) GetSkillMeta(_ context.Context, _ string) (*SkillMeta, error) {
+ return m.meta, m.metaErr
+}
+
+func (m *mockRegistry) DownloadAndInstall(_ context.Context, _, _, _ string) (*InstallResult, error) {
+ return m.installResult, m.installErr
+}
+
+func TestRegistryManagerSearchAllSingle(t *testing.T) {
+ mgr := NewRegistryManager()
+ mgr.AddRegistry(&mockRegistry{
+ name: "test",
+ searchResults: []SearchResult{
+ {Slug: "skill-a", Score: 0.9, RegistryName: "test"},
+ {Slug: "skill-b", Score: 0.5, RegistryName: "test"},
+ },
+ })
+
+ results, err := mgr.SearchAll(context.Background(), "test query", 10)
+ assert.NoError(t, err)
+ assert.Len(t, results, 2)
+ assert.Equal(t, "skill-a", results[0].Slug)
+}
+
+func TestRegistryManagerSearchAllMultiple(t *testing.T) {
+ mgr := NewRegistryManager()
+ mgr.AddRegistry(&mockRegistry{
+ name: "alpha",
+ searchResults: []SearchResult{
+ {Slug: "skill-a", Score: 0.8, RegistryName: "alpha"},
+ },
+ })
+ mgr.AddRegistry(&mockRegistry{
+ name: "beta",
+ searchResults: []SearchResult{
+ {Slug: "skill-b", Score: 0.95, RegistryName: "beta"},
+ },
+ })
+
+ results, err := mgr.SearchAll(context.Background(), "test query", 10)
+ assert.NoError(t, err)
+ assert.Len(t, results, 2)
+ // Should be sorted by score descending
+ assert.Equal(t, "skill-b", results[0].Slug)
+ assert.Equal(t, "skill-a", results[1].Slug)
+}
+
+func TestRegistryManagerSearchAllOneFailsGracefully(t *testing.T) {
+ mgr := NewRegistryManager()
+ mgr.AddRegistry(&mockRegistry{
+ name: "failing",
+ searchErr: fmt.Errorf("network error"),
+ })
+ mgr.AddRegistry(&mockRegistry{
+ name: "working",
+ searchResults: []SearchResult{
+ {Slug: "skill-a", Score: 0.8, RegistryName: "working"},
+ },
+ })
+
+ results, err := mgr.SearchAll(context.Background(), "test query", 10)
+ assert.NoError(t, err)
+ assert.Len(t, results, 1)
+ assert.Equal(t, "skill-a", results[0].Slug)
+}
+
+func TestRegistryManagerSearchAllAllFail(t *testing.T) {
+ mgr := NewRegistryManager()
+ mgr.AddRegistry(&mockRegistry{
+ name: "fail-1",
+ searchErr: fmt.Errorf("error 1"),
+ })
+
+ _, err := mgr.SearchAll(context.Background(), "test query", 10)
+ assert.Error(t, err)
+}
+
+func TestRegistryManagerSearchAllNoRegistries(t *testing.T) {
+ mgr := NewRegistryManager()
+ _, err := mgr.SearchAll(context.Background(), "test query", 10)
+ assert.Error(t, err)
+}
+
+func TestRegistryManagerGetRegistry(t *testing.T) {
+ mgr := NewRegistryManager()
+ mock := &mockRegistry{name: "clawhub"}
+ mgr.AddRegistry(mock)
+
+ got := mgr.GetRegistry("clawhub")
+ assert.NotNil(t, got)
+ assert.Equal(t, "clawhub", got.Name())
+
+ got = mgr.GetRegistry("nonexistent")
+ assert.Nil(t, got)
+}
+
+func TestRegistryManagerSearchAllRespectLimit(t *testing.T) {
+ mgr := NewRegistryManager()
+ results := make([]SearchResult, 20)
+ for i := range results {
+ results[i] = SearchResult{Slug: fmt.Sprintf("skill-%d", i), Score: float64(20 - i)}
+ }
+ mgr.AddRegistry(&mockRegistry{
+ name: "test",
+ searchResults: results,
+ })
+
+ got, err := mgr.SearchAll(context.Background(), "test", 5)
+ assert.NoError(t, err)
+ assert.Len(t, got, 5)
+ // Top scores first
+ assert.Equal(t, "skill-0", got[0].Slug)
+}
+
+func TestRegistryManagerSearchAllTimeout(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+
+ time.Sleep(5 * time.Millisecond) // Let context expire.
+
+ mgr := NewRegistryManager()
+ mgr.AddRegistry(&mockRegistry{
+ name: "slow",
+ searchErr: fmt.Errorf("context deadline exceeded"),
+ })
+
+ _, err := mgr.SearchAll(ctx, "test", 5)
+ assert.Error(t, err)
+}
+
+func TestSortByScoreDesc(t *testing.T) {
+ results := []SearchResult{
+ {Slug: "c", Score: 0.3},
+ {Slug: "a", Score: 0.9},
+ {Slug: "b", Score: 0.5},
+ }
+ sortByScoreDesc(results)
+ assert.Equal(t, "a", results[0].Slug)
+ assert.Equal(t, "b", results[1].Slug)
+ assert.Equal(t, "c", results[2].Slug)
+}
+
+func TestIsSafeSlug(t *testing.T) {
+ assert.NoError(t, utils.ValidateSkillIdentifier("github"))
+ assert.NoError(t, utils.ValidateSkillIdentifier("docker-compose"))
+ assert.Error(t, utils.ValidateSkillIdentifier(""))
+ assert.Error(t, utils.ValidateSkillIdentifier("../etc/passwd"))
+ assert.Error(t, utils.ValidateSkillIdentifier("path/traversal"))
+ assert.Error(t, utils.ValidateSkillIdentifier("path\\traversal"))
+}
diff --git a/pkg/skills/search_cache.go b/pkg/skills/search_cache.go
new file mode 100644
index 000000000..5d7d2797e
--- /dev/null
+++ b/pkg/skills/search_cache.go
@@ -0,0 +1,229 @@
+package skills
+
+import (
+ "sort"
+ "strings"
+ "sync"
+ "time"
+)
+
+// SearchCache provides lightweight caching for search results.
+// It uses trigram-based similarity to match similar queries to cached results,
+// avoiding redundant API calls. Thread-safe for concurrent access.
+type SearchCache struct {
+ mu sync.RWMutex
+ entries map[string]*cacheEntry
+ order []string // LRU order: oldest first.
+ maxEntries int
+ ttl time.Duration
+}
+
+type cacheEntry struct {
+ query string
+ trigrams []uint32
+ results []SearchResult
+ createdAt time.Time
+}
+
+// similarityThreshold is the minimum trigram Jaccard similarity for a cache hit.
+const similarityThreshold = 0.7
+
+// NewSearchCache creates a new search cache.
+// maxEntries is the maximum number of cached queries (excess evicts LRU).
+// ttl is how long each entry lives before expiration.
+func NewSearchCache(maxEntries int, ttl time.Duration) *SearchCache {
+ if maxEntries <= 0 {
+ maxEntries = 50
+ }
+ if ttl <= 0 {
+ ttl = 5 * time.Minute
+ }
+ return &SearchCache{
+ entries: make(map[string]*cacheEntry),
+ order: make([]string, 0),
+ maxEntries: maxEntries,
+ ttl: ttl,
+ }
+}
+
+// Get looks up results for a query. Returns cached results and true if found
+// (either exact or similar match above threshold). Returns nil, false on miss.
+func (sc *SearchCache) Get(query string) ([]SearchResult, bool) {
+ normalized := normalizeQuery(query)
+ if normalized == "" {
+ return nil, false
+ }
+
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+
+ // Exact match first.
+ if entry, ok := sc.entries[normalized]; ok {
+ if time.Since(entry.createdAt) < sc.ttl {
+ sc.moveToEndLocked(normalized)
+ return copyResults(entry.results), true
+ }
+ }
+
+ // Similarity match.
+ queryTrigrams := buildTrigrams(normalized)
+ var bestEntry *cacheEntry
+ var bestSim float64
+
+ for _, entry := range sc.entries {
+ if time.Since(entry.createdAt) >= sc.ttl {
+ continue // Skip expired.
+ }
+ sim := jaccardSimilarity(queryTrigrams, entry.trigrams)
+ if sim > bestSim {
+ bestSim = sim
+ bestEntry = entry
+ }
+ }
+
+ if bestSim >= similarityThreshold && bestEntry != nil {
+ sc.moveToEndLocked(bestEntry.query)
+ return copyResults(bestEntry.results), true
+ }
+
+ return nil, false
+}
+
+// Put stores results for a query. Evicts the oldest entry if at capacity.
+func (sc *SearchCache) Put(query string, results []SearchResult) {
+ normalized := normalizeQuery(query)
+ if normalized == "" {
+ return
+ }
+
+ sc.mu.Lock()
+ defer sc.mu.Unlock()
+
+ // Evict expired entries first.
+ sc.evictExpiredLocked()
+
+ // If already exists, update.
+ if _, ok := sc.entries[normalized]; ok {
+ sc.entries[normalized] = &cacheEntry{
+ query: normalized,
+ trigrams: buildTrigrams(normalized),
+ results: copyResults(results),
+ createdAt: time.Now(),
+ }
+ // Move to end of LRU order.
+ sc.moveToEndLocked(normalized)
+ return
+ }
+
+ // Evict LRU if at capacity.
+ for len(sc.entries) >= sc.maxEntries && len(sc.order) > 0 {
+ oldest := sc.order[0]
+ sc.order = sc.order[1:]
+ delete(sc.entries, oldest)
+ }
+
+ // Insert new entry.
+ sc.entries[normalized] = &cacheEntry{
+ query: normalized,
+ trigrams: buildTrigrams(normalized),
+ results: copyResults(results),
+ createdAt: time.Now(),
+ }
+ sc.order = append(sc.order, normalized)
+}
+
+// Len returns the number of entries (for testing).
+func (sc *SearchCache) Len() int {
+ sc.mu.RLock()
+ defer sc.mu.RUnlock()
+ return len(sc.entries)
+}
+
+// --- internal ---
+
+func (sc *SearchCache) evictExpiredLocked() {
+ now := time.Now()
+ newOrder := make([]string, 0, len(sc.order))
+ for _, key := range sc.order {
+ entry, ok := sc.entries[key]
+ if !ok || now.Sub(entry.createdAt) >= sc.ttl {
+ delete(sc.entries, key)
+ continue
+ }
+ newOrder = append(newOrder, key)
+ }
+ sc.order = newOrder
+}
+
+func (sc *SearchCache) moveToEndLocked(key string) {
+ for i, k := range sc.order {
+ if k == key {
+ sc.order = append(sc.order[:i], sc.order[i+1:]...)
+ break
+ }
+ }
+ sc.order = append(sc.order, key)
+}
+
+func normalizeQuery(q string) string {
+ return strings.ToLower(strings.TrimSpace(q))
+}
+
+// buildTrigrams generates hash of trigrams from a string.
+// Example: "hello" → {"hel", "ell", "llo"}
+// "hel" -> 0x0068656c -> 4 bytes; compared to 16 bytes of a string
+func buildTrigrams(s string) []uint32 {
+ if len(s) < 3 {
+ return nil
+ }
+
+ trigrams := make([]uint32, 0, len(s)-2)
+ for i := 0; i <= len(s)-3; i++ {
+ trigrams = append(trigrams, uint32(s[i])<<16|uint32(s[i+1])<<8|uint32(s[i+2]))
+ }
+
+ // Sort and Deduplication
+ sort.Slice(trigrams, func(i, j int) bool { return trigrams[i] < trigrams[j] })
+ n := 1
+ for i := 1; i < len(trigrams); i++ {
+ if trigrams[i] != trigrams[i-1] {
+ trigrams[n] = trigrams[i]
+ n++
+ }
+ }
+
+ return trigrams[:n]
+}
+
+// jaccardSimilarity computes |A ∩ B| / |A ∪ B|.
+func jaccardSimilarity(a, b []uint32) float64 {
+ if len(a) == 0 && len(b) == 0 {
+ return 1
+ }
+ i, j := 0, 0
+ intersection := 0
+
+ for i < len(a) && j < len(b) {
+ if a[i] == b[j] {
+ intersection++
+ i++
+ j++
+ } else if a[i] < b[j] {
+ i++
+ } else {
+ j++
+ }
+ }
+
+ union := len(a) + len(b) - intersection
+ return float64(intersection) / float64(union)
+}
+
+func copyResults(results []SearchResult) []SearchResult {
+ if results == nil {
+ return nil
+ }
+ cp := make([]SearchResult, len(results))
+ copy(cp, results)
+ return cp
+}
diff --git a/pkg/skills/search_cache_test.go b/pkg/skills/search_cache_test.go
new file mode 100644
index 000000000..816bdfb93
--- /dev/null
+++ b/pkg/skills/search_cache_test.go
@@ -0,0 +1,200 @@
+package skills
+
+import (
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestSearchCacheExactHit(t *testing.T) {
+ cache := NewSearchCache(10, 5*time.Minute)
+
+ results := []SearchResult{
+ {Slug: "github", Score: 0.9, RegistryName: "clawhub"},
+ {Slug: "docker", Score: 0.7, RegistryName: "clawhub"},
+ }
+ cache.Put("github integration", results)
+
+ got, hit := cache.Get("github integration")
+ assert.True(t, hit)
+ assert.Len(t, got, 2)
+ assert.Equal(t, "github", got[0].Slug)
+}
+
+func TestSearchCacheExactHitCaseInsensitive(t *testing.T) {
+ cache := NewSearchCache(10, 5*time.Minute)
+
+ results := []SearchResult{{Slug: "github", Score: 0.9}}
+ cache.Put("GitHub Integration", results)
+
+ got, hit := cache.Get("github integration")
+ assert.True(t, hit)
+ assert.Len(t, got, 1)
+}
+
+func TestSearchCacheSimilarHit(t *testing.T) {
+ cache := NewSearchCache(10, 5*time.Minute)
+
+ results := []SearchResult{{Slug: "github", Score: 0.9}}
+ cache.Put("github integration tool", results)
+
+ // "github integration" is very similar to "github integration tool"
+ got, hit := cache.Get("github integration")
+ assert.True(t, hit)
+ assert.Len(t, got, 1)
+}
+
+func TestSearchCacheDissimilarMiss(t *testing.T) {
+ cache := NewSearchCache(10, 5*time.Minute)
+
+ results := []SearchResult{{Slug: "github", Score: 0.9}}
+ cache.Put("github integration", results)
+
+ // Completely unrelated query
+ _, hit := cache.Get("database management")
+ assert.False(t, hit)
+}
+
+func TestSearchCacheTTLExpiration(t *testing.T) {
+ cache := NewSearchCache(10, 50*time.Millisecond)
+
+ results := []SearchResult{{Slug: "github", Score: 0.9}}
+ cache.Put("github integration", results)
+
+ // Immediately should hit
+ _, hit := cache.Get("github integration")
+ assert.True(t, hit)
+
+ // Wait for expiration
+ time.Sleep(100 * time.Millisecond)
+
+ _, hit = cache.Get("github integration")
+ assert.False(t, hit)
+}
+
+func TestSearchCacheLRUEviction(t *testing.T) {
+ cache := NewSearchCache(3, 5*time.Minute)
+
+ cache.Put("query-1", []SearchResult{{Slug: "a"}})
+ cache.Put("query-2", []SearchResult{{Slug: "b"}})
+ cache.Put("query-3", []SearchResult{{Slug: "c"}})
+
+ assert.Equal(t, 3, cache.Len())
+
+ // Adding a 4th should evict query-1 (oldest)
+ cache.Put("query-4", []SearchResult{{Slug: "d"}})
+ assert.Equal(t, 3, cache.Len())
+
+ _, hit := cache.Get("query-1")
+ assert.False(t, hit, "oldest entry should be evicted")
+
+ got, hit := cache.Get("query-4")
+ assert.True(t, hit)
+ assert.Equal(t, "d", got[0].Slug)
+}
+
+func TestSearchCacheEmptyQuery(t *testing.T) {
+ cache := NewSearchCache(10, 5*time.Minute)
+
+ _, hit := cache.Get("")
+ assert.False(t, hit)
+
+ _, hit = cache.Get(" ")
+ assert.False(t, hit)
+}
+
+func TestSearchCacheResultsCopied(t *testing.T) {
+ cache := NewSearchCache(10, 5*time.Minute)
+
+ original := []SearchResult{{Slug: "github", Score: 0.9}}
+ cache.Put("test", original)
+
+ // Mutate original after putting
+ original[0].Slug = "mutated"
+
+ got, hit := cache.Get("test")
+ assert.True(t, hit)
+ assert.Equal(t, "github", got[0].Slug, "cache should hold a copy, not a reference")
+}
+
+func TestBuildTrigrams(t *testing.T) {
+ trigrams := buildTrigrams("hello")
+ assert.Contains(t, trigrams, uint32('h')<<16|uint32('e')<<8|uint32('l'))
+ assert.Contains(t, trigrams, uint32('e')<<16|uint32('l')<<8|uint32('l'))
+ assert.Contains(t, trigrams, uint32('l')<<16|uint32('l')<<8|uint32('o'))
+ assert.Len(t, trigrams, 3)
+}
+
+func TestJaccardSimilarity(t *testing.T) {
+ a := buildTrigrams("github integration")
+ b := buildTrigrams("github integration tool")
+
+ sim := jaccardSimilarity(a, b)
+ assert.Greater(t, sim, 0.5, "similar strings should have high sim")
+
+ c := buildTrigrams("completely different query about databases")
+ sim2 := jaccardSimilarity(a, c)
+ assert.Less(t, sim2, 0.3, "dissimilar strings should have low sim")
+}
+
+func TestJaccardSimilarityEdgeCases(t *testing.T) {
+ empty := buildTrigrams("")
+ nonempty := buildTrigrams("hello")
+
+ assert.Equal(t, 1.0, jaccardSimilarity(empty, empty))
+ assert.Equal(t, 0.0, jaccardSimilarity(empty, nonempty))
+ assert.Equal(t, 0.0, jaccardSimilarity(nonempty, empty))
+}
+
+func TestSearchCacheConcurrency(t *testing.T) {
+ cache := NewSearchCache(50, 5*time.Minute)
+ done := make(chan struct{})
+
+ // Concurrent writes
+ go func() {
+ for i := 0; i < 100; i++ {
+ cache.Put("query-write-"+string(rune('a'+i%26)), []SearchResult{{Slug: "x"}})
+ }
+ done <- struct{}{}
+ }()
+
+ // Concurrent reads
+ go func() {
+ for i := 0; i < 100; i++ {
+ cache.Get("query-write-a")
+ }
+ done <- struct{}{}
+ }()
+
+ <-done
+}
+
+func TestSearchCacheLRUUpdateOnGet(t *testing.T) {
+ // Capacity 3
+ cache := NewSearchCache(3, time.Hour)
+
+ // Fill cache: query-A, query-B, query-C
+ // Use longer strings to ensure trigrams are generated and avoid false positive similarity
+ cache.Put("query-A", []SearchResult{{Slug: "A"}})
+ cache.Put("query-B", []SearchResult{{Slug: "B"}})
+ cache.Put("query-C", []SearchResult{{Slug: "C"}})
+
+ // Access query-A (should make it most recently used)
+ if _, found := cache.Get("query-A"); !found {
+ t.Fatal("query-A should be in cache")
+ }
+
+ // Add query-D. Should evict query-B (LRU) instead of query-A (which was refreshed)
+ cache.Put("query-D", []SearchResult{{Slug: "D"}})
+
+ // Check if query-A is still there
+ if _, found := cache.Get("query-A"); !found {
+ t.Fatalf("query-A was evicted! valid LRU should have kept query-A and evicted query-B.")
+ }
+
+ // Check if query-B is evicted
+ if _, found := cache.Get("query-B"); found {
+ t.Fatal("query-B should have been evicted")
+ }
+}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index 0bb9cd497..1a92f82ed 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -38,7 +38,7 @@ func NewManager(workspace string) *Manager {
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
- os.MkdirAll(stateDir, 0755)
+ os.MkdirAll(stateDir, 0o755)
sm := &Manager{
workspace: workspace,
@@ -139,7 +139,7 @@ func (sm *Manager) saveAtomic() error {
}
// Write to temp file
- if err := os.WriteFile(tempFile, data, 0644); err != nil {
+ if err := os.WriteFile(tempFile, data, 0o644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
index ce3dd7215..f717a5bb4 100644
--- a/pkg/state/state_test.go
+++ b/pkg/state/state_test.go
@@ -98,7 +98,7 @@ func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
// Simulate a crash scenario by manually creating a corrupted temp file
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
- err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
+ err = os.WriteFile(tempFile, []byte("corrupted data"), 0o644)
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
diff --git a/pkg/tools/base.go b/pkg/tools/base.go
index b13174633..770d8cb04 100644
--- a/pkg/tools/base.go
+++ b/pkg/tools/base.go
@@ -6,8 +6,8 @@ import "context"
type Tool interface {
Name() string
Description() string
- Parameters() map[string]interface{}
- Execute(ctx context.Context, args map[string]interface{}) *ToolResult
+ Parameters() map[string]any
+ Execute(ctx context.Context, args map[string]any) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
@@ -69,10 +69,10 @@ type AsyncTool interface {
SetCallback(cb AsyncCallback)
}
-func ToolToSchema(tool Tool) map[string]interface{} {
- return map[string]interface{}{
+func ToolToSchema(tool Tool) map[string]any {
+ return map[string]any{
"type": "function",
- "function": map[string]interface{}{
+ "function": map[string]any{
"name": tool.Name(),
"description": tool.Description(),
"parameters": tool.Parameters(),
diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go
index e2764d8ac..562fffc84 100644
--- a/pkg/tools/cron.go
+++ b/pkg/tools/cron.go
@@ -30,7 +30,10 @@ type CronTool struct {
// NewCronTool creates a new CronTool
// execTimeout: 0 means no timeout, >0 sets the timeout duration
-func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration, config *config.Config) *CronTool {
+func NewCronTool(
+ cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
+ execTimeout time.Duration, config *config.Config,
+) *CronTool {
execTool := NewExecToolWithConfig(workspace, restrict, config)
execTool.SetTimeout(execTimeout)
return &CronTool{
@@ -52,40 +55,40 @@ func (t *CronTool) Description() string {
}
// Parameters returns the tool parameters schema
-func (t *CronTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *CronTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "action": map[string]interface{}{
+ "properties": map[string]any{
+ "action": map[string]any{
"type": "string",
"enum": []string{"add", "list", "remove", "enable", "disable"},
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
},
- "message": map[string]interface{}{
+ "message": map[string]any{
"type": "string",
"description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.",
},
- "command": map[string]interface{}{
+ "command": map[string]any{
"type": "string",
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
},
- "at_seconds": map[string]interface{}{
+ "at_seconds": map[string]any{
"type": "integer",
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
},
- "every_seconds": map[string]interface{}{
+ "every_seconds": map[string]any{
"type": "integer",
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
},
- "cron_expr": map[string]interface{}{
+ "cron_expr": map[string]any{
"type": "string",
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
},
- "job_id": map[string]interface{}{
+ "job_id": map[string]any{
"type": "string",
"description": "Job ID (for remove/enable/disable)",
},
- "deliver": map[string]interface{}{
+ "deliver": map[string]any{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
},
@@ -103,7 +106,7 @@ func (t *CronTool) SetContext(channel, chatID string) {
}
// Execute runs the tool with the given arguments
-func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *CronTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
action, ok := args["action"].(string)
if !ok {
return ErrorResult("action is required")
@@ -125,7 +128,7 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *To
}
}
-func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
+func (t *CronTool) addJob(args map[string]any) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
@@ -233,7 +236,7 @@ func (t *CronTool) listJobs() *ToolResult {
return SilentResult(result)
}
-func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
+func (t *CronTool) removeJob(args map[string]any) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return ErrorResult("job_id is required for remove")
@@ -245,7 +248,7 @@ func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
-func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
+func (t *CronTool) enableJob(args map[string]any, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return ErrorResult("job_id is required for enable/disable")
@@ -279,7 +282,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Execute command if present
if job.Payload.Command != "" {
- args := map[string]interface{}{
+ args := map[string]any{
"command": job.Payload.Command,
}
@@ -320,7 +323,6 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
channel,
chatID,
)
-
if err != nil {
return fmt.Sprintf("Error: %v", err)
}
diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go
index 1e7c33b45..d3ab267bf 100644
--- a/pkg/tools/edit.go
+++ b/pkg/tools/edit.go
@@ -2,24 +2,27 @@ package tools
import (
"context"
+ "errors"
"fmt"
- "os"
+ "io/fs"
"strings"
)
// EditFileTool edits a file by replacing old_text with new_text.
// The old_text must exist exactly in the file.
type EditFileTool struct {
- allowedDir string
- restrict bool
+ fs fileSystem
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
-func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool {
- return &EditFileTool{
- allowedDir: allowedDir,
- restrict: restrict,
+func NewEditFileTool(workspace string, restrict bool) *EditFileTool {
+ var fs fileSystem
+ if restrict {
+ fs = &sandboxFs{workspace: workspace}
+ } else {
+ fs = &hostFs{}
}
+ return &EditFileTool{fs: fs}
}
func (t *EditFileTool) Name() string {
@@ -30,19 +33,19 @@ func (t *EditFileTool) Description() string {
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
}
-func (t *EditFileTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *EditFileTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "path": map[string]interface{}{
+ "properties": map[string]any{
+ "path": map[string]any{
"type": "string",
"description": "The file path to edit",
},
- "old_text": map[string]interface{}{
+ "old_text": map[string]any{
"type": "string",
"description": "The exact text to find and replace",
},
- "new_text": map[string]interface{}{
+ "new_text": map[string]any{
"type": "string",
"description": "The text to replace with",
},
@@ -51,7 +54,7 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
}
}
-func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *EditFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
@@ -67,47 +70,24 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
return ErrorResult("new_text is required")
}
- resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
- if err != nil {
+ if err := editFile(t.fs, path, oldText, newText); err != nil {
return ErrorResult(err.Error())
}
-
- if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
- return ErrorResult(fmt.Sprintf("file not found: %s", path))
- }
-
- content, err := os.ReadFile(resolvedPath)
- if err != nil {
- return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
- }
-
- contentStr := string(content)
-
- if !strings.Contains(contentStr, oldText) {
- return ErrorResult("old_text not found in file. Make sure it matches exactly")
- }
-
- count := strings.Count(contentStr, oldText)
- if count > 1 {
- return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
- }
-
- newContent := strings.Replace(contentStr, oldText, newText, 1)
-
- if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
- return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
- }
-
return SilentResult(fmt.Sprintf("File edited: %s", path))
}
type AppendFileTool struct {
- workspace string
- restrict bool
+ fs fileSystem
}
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
- return &AppendFileTool{workspace: workspace, restrict: restrict}
+ var fs fileSystem
+ if restrict {
+ fs = &sandboxFs{workspace: workspace}
+ } else {
+ fs = &hostFs{}
+ }
+ return &AppendFileTool{fs: fs}
}
func (t *AppendFileTool) Name() string {
@@ -118,15 +98,15 @@ func (t *AppendFileTool) Description() string {
return "Append content to the end of a file"
}
-func (t *AppendFileTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *AppendFileTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "path": map[string]interface{}{
+ "properties": map[string]any{
+ "path": map[string]any{
"type": "string",
"description": "The file path to append to",
},
- "content": map[string]interface{}{
+ "content": map[string]any{
"type": "string",
"description": "The content to append",
},
@@ -135,7 +115,7 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
}
}
-func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *AppendFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
@@ -146,20 +126,52 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{
return ErrorResult("content is required")
}
- resolvedPath, err := validatePath(path, t.workspace, t.restrict)
- if err != nil {
+ if err := appendFile(t.fs, path, content); err != nil {
return ErrorResult(err.Error())
}
-
- f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
- if err != nil {
- return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
- }
- defer f.Close()
-
- if _, err := f.WriteString(content); err != nil {
- return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
- }
-
return SilentResult(fmt.Sprintf("Appended to %s", path))
}
+
+// editFile reads the file via sysFs, performs the replacement, and writes back.
+// It uses a fileSystem interface, allowing the same logic for both restricted and unrestricted modes.
+func editFile(sysFs fileSystem, path, oldText, newText string) error {
+ content, err := sysFs.ReadFile(path)
+ if err != nil {
+ return err
+ }
+
+ newContent, err := replaceEditContent(content, oldText, newText)
+ if err != nil {
+ return err
+ }
+
+ return sysFs.WriteFile(path, newContent)
+}
+
+// appendFile reads the existing content (if any) via sysFs, appends new content, and writes back.
+func appendFile(sysFs fileSystem, path, appendContent string) error {
+ content, err := sysFs.ReadFile(path)
+ if err != nil && !errors.Is(err, fs.ErrNotExist) {
+ return err
+ }
+
+ newContent := append(content, []byte(appendContent)...)
+ return sysFs.WriteFile(path, newContent)
+}
+
+// replaceEditContent handles the core logic of finding and replacing a single occurrence of oldText.
+func replaceEditContent(content []byte, oldText, newText string) ([]byte, error) {
+ contentStr := string(content)
+
+ if !strings.Contains(contentStr, oldText) {
+ return nil, fmt.Errorf("old_text not found in file. Make sure it matches exactly")
+ }
+
+ count := strings.Count(contentStr, oldText)
+ if count > 1 {
+ return nil, fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
+ }
+
+ newContent := strings.Replace(contentStr, oldText, newText, 1)
+ return []byte(newContent), nil
+}
diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go
index c4c02772d..83a7e778c 100644
--- a/pkg/tools/edit_test.go
+++ b/pkg/tools/edit_test.go
@@ -6,17 +6,19 @@ import (
"path/filepath"
"strings"
"testing"
+
+ "github.com/stretchr/testify/assert"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
- os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
+ os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0o644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
@@ -60,7 +62,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"old_text": "old",
"new_text": "new",
@@ -83,11 +85,11 @@ func TestEditTool_EditFile_NotFound(t *testing.T) {
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
- os.WriteFile(testFile, []byte("Hello World"), 0644)
+ os.WriteFile(testFile, []byte("Hello World"), 0o644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
@@ -110,11 +112,11 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
- os.WriteFile(testFile, []byte("test test test"), 0644)
+ os.WriteFile(testFile, []byte("test test test"), 0o644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"old_text": "test",
"new_text": "done",
@@ -138,11 +140,11 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
- os.WriteFile(testFile, []byte("content"), 0644)
+ os.WriteFile(testFile, []byte("content"), 0o644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"old_text": "content",
"new_text": "new",
@@ -151,21 +153,25 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
result := tool.Execute(ctx, args)
// Should return error result
- if !result.IsError {
- t.Errorf("Expected error when path is outside allowed directory")
- }
+ assert.True(t, result.IsError, "Expected error when path is outside allowed directory")
// Should mention outside allowed directory
- if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
- t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
- }
+ // Note: ErrorResult only sets ForLLM by default, so ForUser might be empty.
+ // We check ForLLM as it's the primary error channel.
+ assert.True(
+ t,
+ strings.Contains(result.ForLLM, "outside") || strings.Contains(result.ForLLM, "access denied") ||
+ strings.Contains(result.ForLLM, "escapes"),
+ "Expected 'outside allowed' or 'access denied' message, got ForLLM: %s",
+ result.ForLLM,
+ )
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"old_text": "old",
"new_text": "new",
}
@@ -182,7 +188,7 @@ func TestEditTool_EditFile_MissingPath(t *testing.T) {
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": "/tmp/test.txt",
"new_text": "new",
}
@@ -199,7 +205,7 @@ func TestEditTool_EditFile_MissingOldText(t *testing.T) {
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": "/tmp/test.txt",
"old_text": "old",
}
@@ -216,11 +222,11 @@ func TestEditTool_EditFile_MissingNewText(t *testing.T) {
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
- os.WriteFile(testFile, []byte("Initial content"), 0644)
+ os.WriteFile(testFile, []byte("Initial content"), 0o644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"content": "\nAppended content",
}
@@ -260,7 +266,7 @@ func TestEditTool_AppendFile_Success(t *testing.T) {
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "test",
}
@@ -276,7 +282,7 @@ func TestEditTool_AppendFile_MissingPath(t *testing.T) {
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": "/tmp/test.txt",
}
@@ -287,3 +293,145 @@ func TestEditTool_AppendFile_MissingContent(t *testing.T) {
t.Errorf("Expected error when content is missing")
}
}
+
+// TestReplaceEditContent verifies the helper function replaceEditContent
+func TestReplaceEditContent(t *testing.T) {
+ tests := []struct {
+ name string
+ content []byte
+ oldText string
+ newText string
+ expected []byte
+ expectError bool
+ }{
+ {
+ name: "successful replacement",
+ content: []byte("hello world"),
+ oldText: "world",
+ newText: "universe",
+ expected: []byte("hello universe"),
+ expectError: false,
+ },
+ {
+ name: "old text not found",
+ content: []byte("hello world"),
+ oldText: "golang",
+ newText: "rust",
+ expected: nil,
+ expectError: true,
+ },
+ {
+ name: "multiple matches found",
+ content: []byte("test text test"),
+ oldText: "test",
+ newText: "done",
+ expected: nil,
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := replaceEditContent(tt.content, tt.oldText, tt.newText)
+ if tt.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, result)
+ }
+ })
+ }
+}
+
+// TestAppendFileTool_AppendToNonExistent_Restricted verifies that AppendFileTool in restricted mode
+// can append to a file that does not yet exist — it should silently create the file.
+// This exercises the errors.Is(err, fs.ErrNotExist) path in appendFileWithRW + rootRW.
+func TestAppendFileTool_AppendToNonExistent_Restricted(t *testing.T) {
+ workspace := t.TempDir()
+ tool := NewAppendFileTool(workspace, true)
+ ctx := context.Background()
+
+ args := map[string]any{
+ "path": "brand_new_file.txt",
+ "content": "first content",
+ }
+
+ result := tool.Execute(ctx, args)
+ assert.False(
+ t,
+ result.IsError,
+ "Expected success when appending to non-existent file in restricted mode, got: %s",
+ result.ForLLM,
+ )
+
+ // Verify the file was created with correct content
+ data, err := os.ReadFile(filepath.Join(workspace, "brand_new_file.txt"))
+ assert.NoError(t, err)
+ assert.Equal(t, "first content", string(data))
+}
+
+// TestAppendFileTool_Restricted_Success verifies that AppendFileTool in restricted mode
+// correctly appends to an existing file within the sandbox.
+func TestAppendFileTool_Restricted_Success(t *testing.T) {
+ workspace := t.TempDir()
+ testFile := "existing.txt"
+ err := os.WriteFile(filepath.Join(workspace, testFile), []byte("initial"), 0o644)
+ assert.NoError(t, err)
+
+ tool := NewAppendFileTool(workspace, true)
+ ctx := context.Background()
+ args := map[string]any{
+ "path": testFile,
+ "content": " appended",
+ }
+
+ result := tool.Execute(ctx, args)
+ assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
+ assert.True(t, result.Silent)
+
+ data, err := os.ReadFile(filepath.Join(workspace, testFile))
+ assert.NoError(t, err)
+ assert.Equal(t, "initial appended", string(data))
+}
+
+// TestEditFileTool_Restricted_InPlaceEdit verifies that EditFileTool in restricted mode
+// correctly edits a file using the single-open editFileInRoot path.
+func TestEditFileTool_Restricted_InPlaceEdit(t *testing.T) {
+ workspace := t.TempDir()
+ testFile := "edit_target.txt"
+ err := os.WriteFile(filepath.Join(workspace, testFile), []byte("Hello World"), 0o644)
+ assert.NoError(t, err)
+
+ tool := NewEditFileTool(workspace, true)
+ ctx := context.Background()
+ args := map[string]any{
+ "path": testFile,
+ "old_text": "World",
+ "new_text": "Go",
+ }
+
+ result := tool.Execute(ctx, args)
+ assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
+ assert.True(t, result.Silent)
+
+ data, err := os.ReadFile(filepath.Join(workspace, testFile))
+ assert.NoError(t, err)
+ assert.Equal(t, "Hello Go", string(data))
+}
+
+// TestEditFileTool_Restricted_FileNotFound verifies that editFileInRoot returns a proper
+// error message when the target file does not exist.
+func TestEditFileTool_Restricted_FileNotFound(t *testing.T) {
+ workspace := t.TempDir()
+ tool := NewEditFileTool(workspace, true)
+ ctx := context.Background()
+ args := map[string]any{
+ "path": "no_such_file.txt",
+ "old_text": "old",
+ "new_text": "new",
+ }
+
+ result := tool.Execute(ctx, args)
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "not found")
+}
diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go
index 09063ea0a..37db8b4ae 100644
--- a/pkg/tools/filesystem.go
+++ b/pkg/tools/filesystem.go
@@ -3,15 +3,17 @@ package tools
import (
"context"
"fmt"
+ "io/fs"
"os"
"path/filepath"
"strings"
+ "time"
)
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
if workspace == "" {
- return path, nil
+ return path, fmt.Errorf("workspace is not defined")
}
absWorkspace, err := filepath.Abs(workspace)
@@ -34,17 +36,19 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
+ var resolved string
workspaceReal := absWorkspace
- if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
+ if resolved, err = filepath.EvalSymlinks(absWorkspace); err == nil {
workspaceReal = resolved
}
- if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
+ if resolved, err = filepath.EvalSymlinks(absPath); err == nil {
if !isWithinWorkspace(resolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if os.IsNotExist(err) {
- if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
+ var parentResolved string
+ if parentResolved, err = resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
if !isWithinWorkspace(parentResolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
@@ -74,16 +78,21 @@ func resolveExistingAncestor(path string) (string, error) {
func isWithinWorkspace(candidate, workspace string) bool {
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
- return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
+ return err == nil && filepath.IsLocal(rel)
}
type ReadFileTool struct {
- workspace string
- restrict bool
+ fs fileSystem
}
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
- return &ReadFileTool{workspace: workspace, restrict: restrict}
+ var fs fileSystem
+ if restrict {
+ fs = &sandboxFs{workspace: workspace}
+ } else {
+ fs = &hostFs{}
+ }
+ return &ReadFileTool{fs: fs}
}
func (t *ReadFileTool) Name() string {
@@ -94,11 +103,11 @@ func (t *ReadFileTool) Description() string {
return "Read the contents of a file"
}
-func (t *ReadFileTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *ReadFileTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "path": map[string]interface{}{
+ "properties": map[string]any{
+ "path": map[string]any{
"type": "string",
"description": "Path to the file to read",
},
@@ -107,32 +116,31 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
}
}
-func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *ReadFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
}
- resolvedPath, err := validatePath(path, t.workspace, t.restrict)
+ content, err := t.fs.ReadFile(path)
if err != nil {
return ErrorResult(err.Error())
}
-
- content, err := os.ReadFile(resolvedPath)
- if err != nil {
- return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
- }
-
return NewToolResult(string(content))
}
type WriteFileTool struct {
- workspace string
- restrict bool
+ fs fileSystem
}
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
- return &WriteFileTool{workspace: workspace, restrict: restrict}
+ var fs fileSystem
+ if restrict {
+ fs = &sandboxFs{workspace: workspace}
+ } else {
+ fs = &hostFs{}
+ }
+ return &WriteFileTool{fs: fs}
}
func (t *WriteFileTool) Name() string {
@@ -143,15 +151,15 @@ func (t *WriteFileTool) Description() string {
return "Write content to a file"
}
-func (t *WriteFileTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *WriteFileTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "path": map[string]interface{}{
+ "properties": map[string]any{
+ "path": map[string]any{
"type": "string",
"description": "Path to the file to write",
},
- "content": map[string]interface{}{
+ "content": map[string]any{
"type": "string",
"description": "Content to write to the file",
},
@@ -160,7 +168,7 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
}
}
-func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *WriteFileTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return ErrorResult("path is required")
@@ -171,30 +179,25 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}
return ErrorResult("content is required")
}
- resolvedPath, err := validatePath(path, t.workspace, t.restrict)
- if err != nil {
+ if err := t.fs.WriteFile(path, []byte(content)); err != nil {
return ErrorResult(err.Error())
}
- dir := filepath.Dir(resolvedPath)
- if err := os.MkdirAll(dir, 0755); err != nil {
- return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
- }
-
- if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
- return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
- }
-
return SilentResult(fmt.Sprintf("File written: %s", path))
}
type ListDirTool struct {
- workspace string
- restrict bool
+ fs fileSystem
}
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
- return &ListDirTool{workspace: workspace, restrict: restrict}
+ var fs fileSystem
+ if restrict {
+ fs = &sandboxFs{workspace: workspace}
+ } else {
+ fs = &hostFs{}
+ }
+ return &ListDirTool{fs: fs}
}
func (t *ListDirTool) Name() string {
@@ -205,11 +208,11 @@ func (t *ListDirTool) Description() string {
return "List files and directories in a path"
}
-func (t *ListDirTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *ListDirTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "path": map[string]interface{}{
+ "properties": map[string]any{
+ "path": map[string]any{
"type": "string",
"description": "Path to list",
},
@@ -218,30 +221,185 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
}
}
-func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *ListDirTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
path, ok := args["path"].(string)
if !ok {
path = "."
}
- resolvedPath, err := validatePath(path, t.workspace, t.restrict)
- if err != nil {
- return ErrorResult(err.Error())
- }
-
- entries, err := os.ReadDir(resolvedPath)
+ entries, err := t.fs.ReadDir(path)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
}
+ return formatDirEntries(entries)
+}
- result := ""
+func formatDirEntries(entries []os.DirEntry) *ToolResult {
+ var result strings.Builder
for _, entry := range entries {
if entry.IsDir() {
- result += "DIR: " + entry.Name() + "\n"
+ result.WriteString("DIR: " + entry.Name() + "\n")
} else {
- result += "FILE: " + entry.Name() + "\n"
+ result.WriteString("FILE: " + entry.Name() + "\n")
+ }
+ }
+ return NewToolResult(result.String())
+}
+
+// fileSystem abstracts reading, writing, and listing files, allowing both
+// unrestricted (host filesystem) and sandbox (os.Root) implementations to share the same polymorphic interface.
+type fileSystem interface {
+ ReadFile(path string) ([]byte, error)
+ WriteFile(path string, data []byte) error
+ ReadDir(path string) ([]os.DirEntry, error)
+}
+
+// hostFs is an unrestricted fileReadWriter that operates directly on the host filesystem.
+type hostFs struct{}
+
+func (h *hostFs) ReadFile(path string) ([]byte, error) {
+ content, err := os.ReadFile(path)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return nil, fmt.Errorf("failed to read file: file not found: %w", err)
+ }
+ if os.IsPermission(err) {
+ return nil, fmt.Errorf("failed to read file: access denied: %w", err)
+ }
+ return nil, fmt.Errorf("failed to read file: %w", err)
+ }
+ return content, nil
+}
+
+func (h *hostFs) ReadDir(path string) ([]os.DirEntry, error) {
+ return os.ReadDir(path)
+}
+
+func (h *hostFs) WriteFile(path string, data []byte) error {
+ dir := filepath.Dir(path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return fmt.Errorf("failed to create parent directories: %w", err)
+ }
+
+ // We use a "write-then-rename" pattern here to ensure an atomic write.
+ // This prevents the target file from being left in a truncated or partial state
+ // if the operation is interrupted, as the rename operation is atomic on Linux.
+ tmpPath := fmt.Sprintf("%s.%d.tmp", path, time.Now().UnixNano())
+ if err := os.WriteFile(tmpPath, data, 0o644); err != nil {
+ os.Remove(tmpPath) // Ensure cleanup of partial/empty temp file
+ return fmt.Errorf("failed to write temp file: %w", err)
+ }
+
+ if err := os.Rename(tmpPath, path); err != nil {
+ os.Remove(tmpPath)
+ return fmt.Errorf("failed to replace original file: %w", err)
+ }
+ return nil
+}
+
+// sandboxFs is a sandboxed fileSystem that operates within a strictly defined workspace using os.Root.
+type sandboxFs struct {
+ workspace string
+}
+
+func (r *sandboxFs) execute(path string, fn func(root *os.Root, relPath string) error) error {
+ if r.workspace == "" {
+ return fmt.Errorf("workspace is not defined")
+ }
+
+ root, err := os.OpenRoot(r.workspace)
+ if err != nil {
+ return fmt.Errorf("failed to open workspace: %w", err)
+ }
+ defer root.Close()
+
+ relPath, err := getSafeRelPath(r.workspace, path)
+ if err != nil {
+ return err
+ }
+
+ return fn(root, relPath)
+}
+
+func (r *sandboxFs) ReadFile(path string) ([]byte, error) {
+ var content []byte
+ err := r.execute(path, func(root *os.Root, relPath string) error {
+ fileContent, err := root.ReadFile(relPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return fmt.Errorf("failed to read file: file not found: %w", err)
+ }
+ // os.Root returns "escapes from parent" for paths outside the root
+ if os.IsPermission(err) || strings.Contains(err.Error(), "escapes from parent") ||
+ strings.Contains(err.Error(), "permission denied") {
+ return fmt.Errorf("failed to read file: access denied: %w", err)
+ }
+ return fmt.Errorf("failed to read file: %w", err)
+ }
+ content = fileContent
+ return nil
+ })
+ return content, err
+}
+
+func (r *sandboxFs) WriteFile(path string, data []byte) error {
+ return r.execute(path, func(root *os.Root, relPath string) error {
+ dir := filepath.Dir(relPath)
+ if dir != "." && dir != "/" {
+ if err := root.MkdirAll(dir, 0o755); err != nil {
+ return fmt.Errorf("failed to create parent directories: %w", err)
+ }
+ }
+
+ // We use a "write-then-rename" pattern here to ensure an atomic write.
+ // This prevents the target file from being left in a truncated or partial state
+ // if the operation is interrupted, as the rename operation is atomic on Linux.
+ tmpRelPath := fmt.Sprintf("%s.%d.tmp", relPath, time.Now().UnixNano())
+
+ if err := root.WriteFile(tmpRelPath, data, 0o644); err != nil {
+ root.Remove(tmpRelPath) // Ensure cleanup of partial/empty temp file
+ return fmt.Errorf("failed to write to temp file: %w", err)
+ }
+
+ if err := root.Rename(tmpRelPath, relPath); err != nil {
+ root.Remove(tmpRelPath)
+ return fmt.Errorf("failed to rename temp file over target: %w", err)
+ }
+ return nil
+ })
+}
+
+func (r *sandboxFs) ReadDir(path string) ([]os.DirEntry, error) {
+ var entries []os.DirEntry
+ err := r.execute(path, func(root *os.Root, relPath string) error {
+ dirEntries, err := fs.ReadDir(root.FS(), relPath)
+ if err != nil {
+ return err
+ }
+ entries = dirEntries
+ return nil
+ })
+ return entries, err
+}
+
+// Helper to get a safe relative path for os.Root usage
+func getSafeRelPath(workspace, path string) (string, error) {
+ if workspace == "" {
+ return "", fmt.Errorf("workspace is not defined")
+ }
+
+ rel := filepath.Clean(path)
+ if filepath.IsAbs(rel) {
+ var err error
+ rel, err = filepath.Rel(workspace, rel)
+ if err != nil {
+ return "", fmt.Errorf("failed to calculate relative path: %w", err)
}
}
- return NewToolResult(result)
+ if !filepath.IsLocal(rel) {
+ return "", fmt.Errorf("path escapes workspace: %s", path)
+ }
+
+ return rel, nil
}
diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go
index 958036419..6f896e22d 100644
--- a/pkg/tools/filesystem_test.go
+++ b/pkg/tools/filesystem_test.go
@@ -2,21 +2,24 @@ package tools
import (
"context"
+ "io"
"os"
"path/filepath"
"strings"
"testing"
+
+ "github.com/stretchr/testify/assert"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
- os.WriteFile(testFile, []byte("test content"), 0644)
+ os.WriteFile(testFile, []byte("test content"), 0o644)
- tool := &ReadFileTool{}
+ tool := NewReadFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
}
@@ -41,9 +44,9 @@ func TestFilesystemTool_ReadFile_Success(t *testing.T) {
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
- tool := &ReadFileTool{}
+ tool := NewReadFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": "/nonexistent_file_12345.txt",
}
@@ -64,7 +67,7 @@ func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
- args := map[string]interface{}{}
+ args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -84,9 +87,9 @@ func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
- tool := &WriteFileTool{}
+ tool := NewWriteFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"content": "hello world",
}
@@ -123,9 +126,9 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
- tool := &WriteFileTool{}
+ tool := NewWriteFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": testFile,
"content": "test",
}
@@ -149,9 +152,9 @@ func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
- tool := &WriteFileTool{}
+ tool := NewWriteFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "test",
}
@@ -165,9 +168,9 @@ func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
- tool := &WriteFileTool{}
+ tool := NewWriteFileTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": "/tmp/test.txt",
}
@@ -179,7 +182,8 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
}
// Should mention required parameter
- if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
+ if !strings.Contains(result.ForLLM, "content is required") &&
+ !strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
@@ -187,13 +191,13 @@ func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
- os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
- os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
- os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
+ os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0o644)
+ os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0o644)
+ os.Mkdir(filepath.Join(tmpDir, "subdir"), 0o755)
- tool := &ListDirTool{}
+ tool := NewListDirTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": tmpDir,
}
@@ -215,9 +219,9 @@ func TestFilesystemTool_ListDir_Success(t *testing.T) {
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
- tool := &ListDirTool{}
+ tool := NewListDirTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"path": "/nonexistent_directory_12345",
}
@@ -236,9 +240,9 @@ func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
- tool := &ListDirTool{}
+ tool := NewListDirTool("", false)
ctx := context.Background()
- args := map[string]interface{}{}
+ args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -250,15 +254,14 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
// Block paths that look inside workspace but point outside via symlink.
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
-
root := t.TempDir()
workspace := filepath.Join(root, "workspace")
- if err := os.MkdirAll(workspace, 0755); err != nil {
+ if err := os.MkdirAll(workspace, 0o755); err != nil {
t.Fatalf("failed to create workspace: %v", err)
}
secret := filepath.Join(root, "secret.txt")
- if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
+ if err := os.WriteFile(secret, []byte("top secret"), 0o644); err != nil {
t.Fatalf("failed to write secret file: %v", err)
}
@@ -268,14 +271,218 @@ func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
}
tool := NewReadFileTool(workspace, true)
- result := tool.Execute(context.Background(), map[string]interface{}{
+ result := tool.Execute(context.Background(), map[string]any{
"path": link,
})
if !result.IsError {
t.Fatalf("expected symlink escape to be blocked")
}
- if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
+ // os.Root might return different errors depending on platform/implementation
+ // but it definitely should error.
+ // Our wrapper returns "access denied or file not found"
+ if !strings.Contains(result.ForLLM, "access denied") && !strings.Contains(result.ForLLM, "file not found") &&
+ !strings.Contains(result.ForLLM, "no such file") {
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
}
}
+
+func TestFilesystemTool_EmptyWorkspace_AccessDenied(t *testing.T) {
+ tool := NewReadFileTool("", true) // restrict=true but workspace=""
+
+ // Try to read a sensitive file (simulated by a temp file outside workspace)
+ tmpDir := t.TempDir()
+ secretFile := filepath.Join(tmpDir, "shadow")
+ os.WriteFile(secretFile, []byte("secret data"), 0o600)
+
+ result := tool.Execute(context.Background(), map[string]any{
+ "path": secretFile,
+ })
+
+ // We EXPECT IsError=true (access blocked due to empty workspace)
+ assert.True(t, result.IsError, "Security Regression: Empty workspace allowed access! content: %s", result.ForLLM)
+
+ // Verify it failed for the right reason
+ assert.Contains(t, result.ForLLM, "workspace is not defined", "Expected 'workspace is not defined' error")
+}
+
+// TestRootMkdirAll verifies that root.MkdirAll (used by atomicWriteFileInRoot) handles all cases:
+// single dir, deeply nested dirs, already-existing dirs, and a file blocking a directory path.
+func TestRootMkdirAll(t *testing.T) {
+ workspace := t.TempDir()
+ root, err := os.OpenRoot(workspace)
+ if err != nil {
+ t.Fatalf("failed to open root: %v", err)
+ }
+ defer root.Close()
+
+ // Case 1: Single directory
+ err = root.MkdirAll("dir1", 0o755)
+ assert.NoError(t, err)
+ _, err = os.Stat(filepath.Join(workspace, "dir1"))
+ assert.NoError(t, err)
+
+ // Case 2: Deeply nested directory
+ err = root.MkdirAll("a/b/c/d", 0o755)
+ assert.NoError(t, err)
+ _, err = os.Stat(filepath.Join(workspace, "a/b/c/d"))
+ assert.NoError(t, err)
+
+ // Case 3: Already exists — must be idempotent
+ err = root.MkdirAll("a/b/c/d", 0o755)
+ assert.NoError(t, err)
+
+ // Case 4: A regular file blocks directory creation — must error
+ err = os.WriteFile(filepath.Join(workspace, "file_exists"), []byte("data"), 0o644)
+ assert.NoError(t, err)
+ err = root.MkdirAll("file_exists", 0o755)
+ assert.Error(t, err, "expected error when a file exists at the directory path")
+}
+
+func TestFilesystemTool_WriteFile_Restricted_CreateDir(t *testing.T) {
+ workspace := t.TempDir()
+ tool := NewWriteFileTool(workspace, true)
+ ctx := context.Background()
+
+ testFile := "deep/nested/path/to/file.txt"
+ content := "deep content"
+ args := map[string]any{
+ "path": testFile,
+ "content": content,
+ }
+
+ result := tool.Execute(ctx, args)
+ assert.False(t, result.IsError, "Expected success, got: %s", result.ForLLM)
+
+ // Verify file content
+ actualPath := filepath.Join(workspace, testFile)
+ data, err := os.ReadFile(actualPath)
+ assert.NoError(t, err)
+ assert.Equal(t, content, string(data))
+}
+
+// TestHostRW_Read_PermissionDenied verifies that hostRW.Read surfaces access denied errors.
+func TestHostRW_Read_PermissionDenied(t *testing.T) {
+ if os.Getuid() == 0 {
+ t.Skip("skipping permission test: running as root")
+ }
+ tmpDir := t.TempDir()
+ protected := filepath.Join(tmpDir, "protected.txt")
+ err := os.WriteFile(protected, []byte("secret"), 0o000)
+ assert.NoError(t, err)
+ defer os.Chmod(protected, 0o644) // ensure cleanup
+
+ _, err = (&hostFs{}).ReadFile(protected)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "access denied")
+}
+
+// TestHostRW_Read_Directory verifies that hostRW.Read returns an error when given a directory path.
+func TestHostRW_Read_Directory(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ _, err := (&hostFs{}).ReadFile(tmpDir)
+ assert.Error(t, err, "expected error when reading a directory as a file")
+}
+
+// TestRootRW_Read_Directory verifies that rootRW.Read returns an error when given a directory.
+func TestRootRW_Read_Directory(t *testing.T) {
+ workspace := t.TempDir()
+ root, err := os.OpenRoot(workspace)
+ assert.NoError(t, err)
+ defer root.Close()
+
+ // Create a subdirectory
+ err = root.Mkdir("subdir", 0o755)
+ assert.NoError(t, err)
+
+ _, err = (&sandboxFs{workspace: workspace}).ReadFile("subdir")
+ assert.Error(t, err, "expected error when reading a directory as a file")
+}
+
+// TestHostRW_Write_ParentDirMissing verifies that hostRW.Write creates parent dirs automatically.
+func TestHostRW_Write_ParentDirMissing(t *testing.T) {
+ tmpDir := t.TempDir()
+ target := filepath.Join(tmpDir, "a", "b", "c", "file.txt")
+
+ err := (&hostFs{}).WriteFile(target, []byte("hello"))
+ assert.NoError(t, err)
+
+ data, err := os.ReadFile(target)
+ assert.NoError(t, err)
+ assert.Equal(t, "hello", string(data))
+}
+
+// TestRootRW_Write_ParentDirMissing verifies that rootRW.Write creates
+// nested parent directories automatically within the sandbox.
+func TestRootRW_Write_ParentDirMissing(t *testing.T) {
+ workspace := t.TempDir()
+
+ relPath := "x/y/z/file.txt"
+ err := (&sandboxFs{workspace: workspace}).WriteFile(relPath, []byte("nested"))
+ assert.NoError(t, err)
+
+ data, err := os.ReadFile(filepath.Join(workspace, relPath))
+ assert.NoError(t, err)
+ assert.Equal(t, "nested", string(data))
+}
+
+// TestHostRW_Write verifies the hostRW.Write helper function
+func TestHostRW_Write(t *testing.T) {
+ tmpDir := t.TempDir()
+ testFile := filepath.Join(tmpDir, "atomic_test.txt")
+ testData := []byte("atomic test content")
+
+ err := (&hostFs{}).WriteFile(testFile, testData)
+ assert.NoError(t, err)
+
+ content, err := os.ReadFile(testFile)
+ assert.NoError(t, err)
+ assert.Equal(t, testData, content)
+
+ // Verify it overwrites correctly
+ newData := []byte("new atomic content")
+ err = (&hostFs{}).WriteFile(testFile, newData)
+ assert.NoError(t, err)
+
+ content, err = os.ReadFile(testFile)
+ assert.NoError(t, err)
+ assert.Equal(t, newData, content)
+}
+
+// TestRootRW_Write verifies the rootRW.Write helper function
+func TestRootRW_Write(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ relPath := "atomic_root_test.txt"
+ testData := []byte("atomic root test content")
+
+ erw := &sandboxFs{workspace: tmpDir}
+ err := erw.WriteFile(relPath, testData)
+ assert.NoError(t, err)
+
+ root, err := os.OpenRoot(tmpDir)
+ assert.NoError(t, err)
+ defer root.Close()
+
+ f, err := root.Open(relPath)
+ assert.NoError(t, err)
+ defer f.Close()
+
+ content, err := io.ReadAll(f)
+ assert.NoError(t, err)
+ assert.Equal(t, testData, content)
+
+ // Verify it overwrites correctly
+ newData := []byte("new root atomic content")
+ err = erw.WriteFile(relPath, newData)
+ assert.NoError(t, err)
+
+ f2, err := root.Open(relPath)
+ assert.NoError(t, err)
+ defer f2.Close()
+
+ content, err = io.ReadAll(f2)
+ assert.NoError(t, err)
+ assert.Equal(t, newData, content)
+}
diff --git a/pkg/tools/i2c.go b/pkg/tools/i2c.go
index abca5ec1e..779b1d5a7 100644
--- a/pkg/tools/i2c.go
+++ b/pkg/tools/i2c.go
@@ -24,37 +24,37 @@ func (t *I2CTool) Description() string {
return "Interact with I2C bus devices for reading sensors and controlling peripherals. Actions: detect (list buses), scan (find devices on a bus), read (read bytes from device), write (send bytes to device). Linux only."
}
-func (t *I2CTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *I2CTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "action": map[string]interface{}{
+ "properties": map[string]any{
+ "action": map[string]any{
"type": "string",
"enum": []string{"detect", "scan", "read", "write"},
"description": "Action to perform: detect (list available I2C buses), scan (find devices on a bus), read (read bytes from a device), write (send bytes to a device)",
},
- "bus": map[string]interface{}{
+ "bus": map[string]any{
"type": "string",
"description": "I2C bus number (e.g. \"1\" for /dev/i2c-1). Required for scan/read/write.",
},
- "address": map[string]interface{}{
+ "address": map[string]any{
"type": "integer",
"description": "7-bit I2C device address (0x03-0x77). Required for read/write.",
},
- "register": map[string]interface{}{
+ "register": map[string]any{
"type": "integer",
"description": "Register address to read from or write to. If set, sends register byte before read/write.",
},
- "data": map[string]interface{}{
+ "data": map[string]any{
"type": "array",
- "items": map[string]interface{}{"type": "integer"},
+ "items": map[string]any{"type": "integer"},
"description": "Bytes to write (0-255 each). Required for write action.",
},
- "length": map[string]interface{}{
+ "length": map[string]any{
"type": "integer",
"description": "Number of bytes to read (1-256). Default: 1. Used with read action.",
},
- "confirm": map[string]interface{}{
+ "confirm": map[string]any{
"type": "boolean",
"description": "Must be true for write operations. Safety guard to prevent accidental writes.",
},
@@ -63,7 +63,7 @@ func (t *I2CTool) Parameters() map[string]interface{} {
}
}
-func (t *I2CTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *I2CTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if runtime.GOOS != "linux" {
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
}
@@ -95,7 +95,9 @@ func (t *I2CTool) detect() *ToolResult {
}
if len(matches) == 0 {
- return SilentResult("No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)")
+ return SilentResult(
+ "No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)",
+ )
}
type busInfo struct {
@@ -115,14 +117,20 @@ func (t *I2CTool) detect() *ToolResult {
return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result)))
}
+// Helper functions for I2C operations (used by platform-specific implementations)
+
// isValidBusID checks that a bus identifier is a simple number (prevents path injection)
+//
+//nolint:unused // Used by i2c_linux.go
func isValidBusID(id string) bool {
matched, _ := regexp.MatchString(`^\d+$`, id)
return matched
}
// parseI2CAddress extracts and validates an I2C address from args
-func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
+//
+//nolint:unused // Used by i2c_linux.go
+func parseI2CAddress(args map[string]any) (int, *ToolResult) {
addrFloat, ok := args["address"].(float64)
if !ok {
return 0, ErrorResult("address is required (e.g. 0x38 for AHT20)")
@@ -135,7 +143,9 @@ func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
}
// parseI2CBus extracts and validates an I2C bus from args
-func parseI2CBus(args map[string]interface{}) (string, *ToolResult) {
+//
+//nolint:unused // Used by i2c_linux.go
+func parseI2CBus(args map[string]any) (string, *ToolResult) {
bus, ok := args["bus"].(string)
if !ok || bus == "" {
return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)")
diff --git a/pkg/tools/i2c_linux.go b/pkg/tools/i2c_linux.go
index 294f7ecbc..4eaaf8f09 100644
--- a/pkg/tools/i2c_linux.go
+++ b/pkg/tools/i2c_linux.go
@@ -74,7 +74,7 @@ func smbusProbe(fd int, addr int, hasQuick bool) bool {
// scan probes valid 7-bit addresses on a bus for connected devices.
// Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO:
// SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges.
-func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
+func (t *I2CTool) scan(args map[string]any) *ToolResult {
bus, errResult := parseI2CBus(args)
if errResult != nil {
return errResult
@@ -99,7 +99,9 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
hasReadByte := funcs&i2cFuncSmbusReadByte != 0
if !hasQuick && !hasReadByte {
- return ErrorResult(fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath))
+ return ErrorResult(
+ fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath),
+ )
}
type deviceEntry struct {
@@ -133,7 +135,7 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
}
- result, _ := json.MarshalIndent(map[string]interface{}{
+ result, _ := json.MarshalIndent(map[string]any{
"bus": devPath,
"devices": found,
"count": len(found),
@@ -142,7 +144,7 @@ func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
}
// readDevice reads bytes from an I2C device, optionally at a specific register
-func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
+func (t *I2CTool) readDevice(args map[string]any) *ToolResult {
bus, errResult := parseI2CBus(args)
if errResult != nil {
return errResult
@@ -180,7 +182,7 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
if reg < 0 || reg > 255 {
return ErrorResult("register must be between 0x00 and 0xFF")
}
- _, err := syscall.Write(fd, []byte{byte(reg)})
+ _, err = syscall.Write(fd, []byte{byte(reg)})
if err != nil {
return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err))
}
@@ -201,7 +203,7 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
intBytes[i] = int(buf[i])
}
- result, _ := json.MarshalIndent(map[string]interface{}{
+ result, _ := json.MarshalIndent(map[string]any{
"bus": devPath,
"address": fmt.Sprintf("0x%02x", addr),
"bytes": intBytes,
@@ -212,10 +214,12 @@ func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
}
// writeDevice writes bytes to an I2C device, optionally at a specific register
-func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
+func (t *I2CTool) writeDevice(args map[string]any) *ToolResult {
confirm, _ := args["confirm"].(bool)
if !confirm {
- return ErrorResult("write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.")
+ return ErrorResult(
+ "write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.",
+ )
}
bus, errResult := parseI2CBus(args)
@@ -228,7 +232,7 @@ func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
return errResult
}
- dataRaw, ok := args["data"].([]interface{})
+ dataRaw, ok := args["data"].([]any)
if !ok || len(dataRaw) == 0 {
return ErrorResult("data is required for write (array of byte values 0-255)")
}
diff --git a/pkg/tools/i2c_other.go b/pkg/tools/i2c_other.go
index d1d581348..7becf8339 100644
--- a/pkg/tools/i2c_other.go
+++ b/pkg/tools/i2c_other.go
@@ -3,16 +3,16 @@
package tools
// scan is a stub for non-Linux platforms.
-func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
+func (t *I2CTool) scan(args map[string]any) *ToolResult {
return ErrorResult("I2C is only supported on Linux")
}
// readDevice is a stub for non-Linux platforms.
-func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
+func (t *I2CTool) readDevice(args map[string]any) *ToolResult {
return ErrorResult("I2C is only supported on Linux")
}
// writeDevice is a stub for non-Linux platforms.
-func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
+func (t *I2CTool) writeDevice(args map[string]any) *ToolResult {
return ErrorResult("I2C is only supported on Linux")
}
diff --git a/pkg/tools/message.go b/pkg/tools/message.go
index abedb1316..15ef4ff73 100644
--- a/pkg/tools/message.go
+++ b/pkg/tools/message.go
@@ -26,19 +26,19 @@ func (t *MessageTool) Description() string {
return "Send a message to user on a chat channel. Use this when you want to communicate something."
}
-func (t *MessageTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *MessageTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "content": map[string]interface{}{
+ "properties": map[string]any{
+ "content": map[string]any{
"type": "string",
"description": "The message content to send",
},
- "channel": map[string]interface{}{
+ "channel": map[string]any{
"type": "string",
"description": "Optional: target channel (telegram, whatsapp, etc.)",
},
- "chat_id": map[string]interface{}{
+ "chat_id": map[string]any{
"type": "string",
"description": "Optional: target chat/user ID",
},
@@ -62,7 +62,7 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback
}
-func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return &ToolResult{ForLLM: "content is required", IsError: true}
diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go
index 4bedbe79b..717c1117b 100644
--- a/pkg/tools/message_test.go
+++ b/pkg/tools/message_test.go
@@ -19,7 +19,7 @@ func TestMessageTool_Execute_Success(t *testing.T) {
})
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "Hello, world!",
}
@@ -70,7 +70,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
})
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
@@ -104,7 +104,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) {
})
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "Test message",
}
@@ -136,7 +136,7 @@ func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
- args := map[string]interface{}{} // content missing
+ args := map[string]any{} // content missing
result := tool.Execute(ctx, args)
@@ -158,7 +158,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
})
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "Test message",
}
@@ -179,7 +179,7 @@ func TestMessageTool_Execute_NotConfigured(t *testing.T) {
// No SetSendCallback called
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"content": "Test message",
}
@@ -219,7 +219,7 @@ func TestMessageTool_Parameters(t *testing.T) {
t.Error("Expected type 'object'")
}
- props, ok := params["properties"].(map[string]interface{})
+ props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Expected properties to be a map")
}
@@ -231,7 +231,7 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check content property
- contentProp, ok := props["content"].(map[string]interface{})
+ contentProp, ok := props["content"].(map[string]any)
if !ok {
t.Error("Expected 'content' property")
}
@@ -240,7 +240,7 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check channel property (optional)
- channelProp, ok := props["channel"].(map[string]interface{})
+ channelProp, ok := props["channel"].(map[string]any)
if !ok {
t.Error("Expected 'channel' property")
}
@@ -249,7 +249,7 @@ func TestMessageTool_Parameters(t *testing.T) {
}
// Check chat_id property (optional)
- chatIDProp, ok := props["chat_id"].(map[string]interface{})
+ chatIDProp, ok := props["chat_id"].(map[string]any)
if !ok {
t.Error("Expected 'chat_id' property")
}
diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go
index c8cf92863..d37a093a8 100644
--- a/pkg/tools/registry.go
+++ b/pkg/tools/registry.go
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
+ "sort"
"sync"
"time"
@@ -34,16 +35,22 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok
}
-func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
+func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]any) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
}
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
-func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
+func (r *ToolRegistry) ExecuteWithContext(
+ ctx context.Context,
+ name string,
+ args map[string]any,
+ channel, chatID string,
+ asyncCallback AsyncCallback,
+) *ToolResult {
logger.InfoCF("tool", "Tool execution started",
- map[string]interface{}{
+ map[string]any{
"tool": name,
"args": args,
})
@@ -51,7 +58,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
tool, ok := r.Get(name)
if !ok {
logger.ErrorCF("tool", "Tool not found",
- map[string]interface{}{
+ map[string]any{
"tool": name,
})
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
@@ -66,7 +73,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
- map[string]interface{}{
+ map[string]any{
"tool": name,
})
}
@@ -78,20 +85,20 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
// Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed",
- map[string]interface{}{
+ map[string]any{
"tool": name,
"duration": duration.Milliseconds(),
"error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
- map[string]interface{}{
+ map[string]any{
"tool": name,
"duration": duration.Milliseconds(),
})
} else {
logger.InfoCF("tool", "Tool execution completed",
- map[string]interface{}{
+ map[string]any{
"tool": name,
"duration_ms": duration.Milliseconds(),
"result_length": len(result.ForLLM),
@@ -101,13 +108,27 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
return result
}
-func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
+// sortedToolNames returns tool names in sorted order for deterministic iteration.
+// This is critical for KV cache stability: non-deterministic map iteration would
+// produce different system prompts and tool definitions on each call, invalidating
+// the LLM's prefix cache even when no tools have changed.
+func (r *ToolRegistry) sortedToolNames() []string {
+ names := make([]string, 0, len(r.tools))
+ for name := range r.tools {
+ names = append(names, name)
+ }
+ sort.Strings(names)
+ return names
+}
+
+func (r *ToolRegistry) GetDefinitions() []map[string]any {
r.mu.RLock()
defer r.mu.RUnlock()
- definitions := make([]map[string]interface{}, 0, len(r.tools))
- for _, tool := range r.tools {
- definitions = append(definitions, ToolToSchema(tool))
+ sorted := r.sortedToolNames()
+ definitions := make([]map[string]any, 0, len(sorted))
+ for _, name := range sorted {
+ definitions = append(definitions, ToolToSchema(r.tools[name]))
}
return definitions
}
@@ -118,19 +139,21 @@ func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
- definitions := make([]providers.ToolDefinition, 0, len(r.tools))
- for _, tool := range r.tools {
+ sorted := r.sortedToolNames()
+ definitions := make([]providers.ToolDefinition, 0, len(sorted))
+ for _, name := range sorted {
+ tool := r.tools[name]
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
- fn, ok := schema["function"].(map[string]interface{})
+ fn, ok := schema["function"].(map[string]any)
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
- params, _ := fn["parameters"].(map[string]interface{})
+ params, _ := fn["parameters"].(map[string]any)
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
@@ -149,11 +172,7 @@ func (r *ToolRegistry) List() []string {
r.mu.RLock()
defer r.mu.RUnlock()
- names := make([]string, 0, len(r.tools))
- for name := range r.tools {
- names = append(names, name)
- }
- return names
+ return r.sortedToolNames()
}
// Count returns the number of registered tools.
@@ -169,8 +188,10 @@ func (r *ToolRegistry) GetSummaries() []string {
r.mu.RLock()
defer r.mu.RUnlock()
- summaries := make([]string, 0, len(r.tools))
- for _, tool := range r.tools {
+ sorted := r.sortedToolNames()
+ summaries := make([]string, 0, len(sorted))
+ for _, name := range sorted {
+ tool := r.tools[name]
summaries = append(summaries, fmt.Sprintf("- `%s` - %s", tool.Name(), tool.Description()))
}
return summaries
diff --git a/pkg/tools/registry_test.go b/pkg/tools/registry_test.go
new file mode 100644
index 000000000..8ae13b20c
--- /dev/null
+++ b/pkg/tools/registry_test.go
@@ -0,0 +1,350 @@
+package tools
+
+import (
+ "context"
+ "strings"
+ "sync"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// --- mock types ---
+
+type mockRegistryTool struct {
+ name string
+ desc string
+ params map[string]any
+ result *ToolResult
+}
+
+func (m *mockRegistryTool) Name() string { return m.name }
+func (m *mockRegistryTool) Description() string { return m.desc }
+func (m *mockRegistryTool) Parameters() map[string]any { return m.params }
+func (m *mockRegistryTool) Execute(_ context.Context, _ map[string]any) *ToolResult {
+ return m.result
+}
+
+type mockCtxTool struct {
+ mockRegistryTool
+ channel string
+ chatID string
+}
+
+func (m *mockCtxTool) SetContext(channel, chatID string) {
+ m.channel = channel
+ m.chatID = chatID
+}
+
+type mockAsyncRegistryTool struct {
+ mockRegistryTool
+ cb AsyncCallback
+}
+
+func (m *mockAsyncRegistryTool) SetCallback(cb AsyncCallback) {
+ m.cb = cb
+}
+
+// --- helpers ---
+
+func newMockTool(name, desc string) *mockRegistryTool {
+ return &mockRegistryTool{
+ name: name,
+ desc: desc,
+ params: map[string]any{"type": "object"},
+ result: SilentResult("ok"),
+ }
+}
+
+// --- tests ---
+
+func TestNewToolRegistry(t *testing.T) {
+ r := NewToolRegistry()
+ if r.Count() != 0 {
+ t.Errorf("expected empty registry, got count %d", r.Count())
+ }
+ if len(r.List()) != 0 {
+ t.Errorf("expected empty list, got %v", r.List())
+ }
+}
+
+func TestToolRegistry_RegisterAndGet(t *testing.T) {
+ r := NewToolRegistry()
+ tool := newMockTool("echo", "echoes input")
+ r.Register(tool)
+
+ got, ok := r.Get("echo")
+ if !ok {
+ t.Fatal("expected to find registered tool")
+ }
+ if got.Name() != "echo" {
+ t.Errorf("expected name 'echo', got %q", got.Name())
+ }
+}
+
+func TestToolRegistry_Get_NotFound(t *testing.T) {
+ r := NewToolRegistry()
+ _, ok := r.Get("nonexistent")
+ if ok {
+ t.Error("expected ok=false for unregistered tool")
+ }
+}
+
+func TestToolRegistry_RegisterOverwrite(t *testing.T) {
+ r := NewToolRegistry()
+ r.Register(newMockTool("dup", "first"))
+ r.Register(newMockTool("dup", "second"))
+
+ if r.Count() != 1 {
+ t.Errorf("expected count 1 after overwrite, got %d", r.Count())
+ }
+ tool, _ := r.Get("dup")
+ if tool.Description() != "second" {
+ t.Errorf("expected overwritten description 'second', got %q", tool.Description())
+ }
+}
+
+func TestToolRegistry_Execute_Success(t *testing.T) {
+ r := NewToolRegistry()
+ r.Register(&mockRegistryTool{
+ name: "greet",
+ desc: "says hello",
+ params: map[string]any{},
+ result: SilentResult("hello"),
+ })
+
+ result := r.Execute(context.Background(), "greet", nil)
+ if result.IsError {
+ t.Errorf("expected success, got error: %s", result.ForLLM)
+ }
+ if result.ForLLM != "hello" {
+ t.Errorf("expected ForLLM 'hello', got %q", result.ForLLM)
+ }
+}
+
+func TestToolRegistry_Execute_NotFound(t *testing.T) {
+ r := NewToolRegistry()
+ result := r.Execute(context.Background(), "missing", nil)
+ if !result.IsError {
+ t.Error("expected error for missing tool")
+ }
+ if !strings.Contains(result.ForLLM, "not found") {
+ t.Errorf("expected 'not found' in error, got %q", result.ForLLM)
+ }
+ if result.Err == nil {
+ t.Error("expected Err to be set via WithError")
+ }
+}
+
+func TestToolRegistry_ExecuteWithContext_ContextualTool(t *testing.T) {
+ r := NewToolRegistry()
+ ct := &mockCtxTool{
+ mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
+ }
+ r.Register(ct)
+
+ r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "telegram", "chat-42", nil)
+
+ if ct.channel != "telegram" {
+ t.Errorf("expected channel 'telegram', got %q", ct.channel)
+ }
+ if ct.chatID != "chat-42" {
+ t.Errorf("expected chatID 'chat-42', got %q", ct.chatID)
+ }
+}
+
+func TestToolRegistry_ExecuteWithContext_SkipsEmptyContext(t *testing.T) {
+ r := NewToolRegistry()
+ ct := &mockCtxTool{
+ mockRegistryTool: *newMockTool("ctx_tool", "needs context"),
+ }
+ r.Register(ct)
+
+ r.ExecuteWithContext(context.Background(), "ctx_tool", nil, "", "", nil)
+
+ if ct.channel != "" || ct.chatID != "" {
+ t.Error("SetContext should not be called with empty channel/chatID")
+ }
+}
+
+func TestToolRegistry_ExecuteWithContext_AsyncCallback(t *testing.T) {
+ r := NewToolRegistry()
+ at := &mockAsyncRegistryTool{
+ mockRegistryTool: *newMockTool("async_tool", "async work"),
+ }
+ at.result = AsyncResult("started")
+ r.Register(at)
+
+ called := false
+ cb := func(_ context.Context, _ *ToolResult) { called = true }
+
+ result := r.ExecuteWithContext(context.Background(), "async_tool", nil, "", "", cb)
+ if at.cb == nil {
+ t.Error("expected SetCallback to have been called")
+ }
+ if !result.Async {
+ t.Error("expected async result")
+ }
+
+ at.cb(context.Background(), SilentResult("done"))
+ if !called {
+ t.Error("expected callback to be invoked")
+ }
+}
+
+func TestToolRegistry_GetDefinitions(t *testing.T) {
+ r := NewToolRegistry()
+ r.Register(newMockTool("alpha", "tool A"))
+
+ defs := r.GetDefinitions()
+ if len(defs) != 1 {
+ t.Fatalf("expected 1 definition, got %d", len(defs))
+ }
+ if defs[0]["type"] != "function" {
+ t.Errorf("expected type 'function', got %v", defs[0]["type"])
+ }
+ fn, ok := defs[0]["function"].(map[string]any)
+ if !ok {
+ t.Fatal("expected 'function' key to be a map")
+ }
+ if fn["name"] != "alpha" {
+ t.Errorf("expected name 'alpha', got %v", fn["name"])
+ }
+ if fn["description"] != "tool A" {
+ t.Errorf("expected description 'tool A', got %v", fn["description"])
+ }
+}
+
+func TestToolRegistry_ToProviderDefs(t *testing.T) {
+ r := NewToolRegistry()
+ params := map[string]any{"type": "object", "properties": map[string]any{}}
+ r.Register(&mockRegistryTool{
+ name: "beta",
+ desc: "tool B",
+ params: params,
+ result: SilentResult("ok"),
+ })
+
+ defs := r.ToProviderDefs()
+ if len(defs) != 1 {
+ t.Fatalf("expected 1 provider def, got %d", len(defs))
+ }
+
+ want := providers.ToolDefinition{
+ Type: "function",
+ Function: providers.ToolFunctionDefinition{
+ Name: "beta",
+ Description: "tool B",
+ Parameters: params,
+ },
+ }
+ got := defs[0]
+ if got.Type != want.Type {
+ t.Errorf("Type: want %q, got %q", want.Type, got.Type)
+ }
+ if got.Function.Name != want.Function.Name {
+ t.Errorf("Name: want %q, got %q", want.Function.Name, got.Function.Name)
+ }
+ if got.Function.Description != want.Function.Description {
+ t.Errorf("Description: want %q, got %q", want.Function.Description, got.Function.Description)
+ }
+}
+
+func TestToolRegistry_List(t *testing.T) {
+ r := NewToolRegistry()
+ r.Register(newMockTool("x", ""))
+ r.Register(newMockTool("y", ""))
+
+ names := r.List()
+ if len(names) != 2 {
+ t.Fatalf("expected 2 names, got %d", len(names))
+ }
+
+ nameSet := map[string]bool{}
+ for _, n := range names {
+ nameSet[n] = true
+ }
+ if !nameSet["x"] || !nameSet["y"] {
+ t.Errorf("expected names {x, y}, got %v", names)
+ }
+}
+
+func TestToolRegistry_Count(t *testing.T) {
+ r := NewToolRegistry()
+ if r.Count() != 0 {
+ t.Errorf("expected 0, got %d", r.Count())
+ }
+
+ r.Register(newMockTool("a", ""))
+ r.Register(newMockTool("b", ""))
+ if r.Count() != 2 {
+ t.Errorf("expected 2, got %d", r.Count())
+ }
+
+ r.Register(newMockTool("a", "replaced"))
+ if r.Count() != 2 {
+ t.Errorf("expected 2 after overwrite, got %d", r.Count())
+ }
+}
+
+func TestToolRegistry_GetSummaries(t *testing.T) {
+ r := NewToolRegistry()
+ r.Register(newMockTool("read_file", "Reads a file"))
+
+ summaries := r.GetSummaries()
+ if len(summaries) != 1 {
+ t.Fatalf("expected 1 summary, got %d", len(summaries))
+ }
+ if !strings.Contains(summaries[0], "`read_file`") {
+ t.Errorf("expected backtick-quoted name in summary, got %q", summaries[0])
+ }
+ if !strings.Contains(summaries[0], "Reads a file") {
+ t.Errorf("expected description in summary, got %q", summaries[0])
+ }
+}
+
+func TestToolToSchema(t *testing.T) {
+ tool := newMockTool("demo", "demo tool")
+ schema := ToolToSchema(tool)
+
+ if schema["type"] != "function" {
+ t.Errorf("expected type 'function', got %v", schema["type"])
+ }
+ fn, ok := schema["function"].(map[string]any)
+ if !ok {
+ t.Fatal("expected 'function' to be a map")
+ }
+ if fn["name"] != "demo" {
+ t.Errorf("expected name 'demo', got %v", fn["name"])
+ }
+ if fn["description"] != "demo tool" {
+ t.Errorf("expected description 'demo tool', got %v", fn["description"])
+ }
+ if fn["parameters"] == nil {
+ t.Error("expected parameters to be set")
+ }
+}
+
+func TestToolRegistry_ConcurrentAccess(t *testing.T) {
+ r := NewToolRegistry()
+ var wg sync.WaitGroup
+
+ for i := 0; i < 50; i++ {
+ wg.Add(1)
+ go func(n int) {
+ defer wg.Done()
+ name := string(rune('A' + n%26))
+ r.Register(newMockTool(name, "concurrent"))
+ r.Get(name)
+ r.Count()
+ r.List()
+ r.GetDefinitions()
+ }(i)
+ }
+
+ wg.Wait()
+
+ if r.Count() == 0 {
+ t.Error("expected tools to be registered after concurrent access")
+ }
+}
diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go
index bc798cd70..a234e33f3 100644
--- a/pkg/tools/result_test.go
+++ b/pkg/tools/result_test.go
@@ -192,7 +192,7 @@ func TestToolResultJSONStructure(t *testing.T) {
}
// Verify JSON structure
- var parsed map[string]interface{}
+ var parsed map[string]any
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go
index d9430672f..ad1664b5b 100644
--- a/pkg/tools/shell.go
+++ b/pkg/tools/shell.go
@@ -3,6 +3,7 @@ package tools
import (
"bytes"
"context"
+ "errors"
"fmt"
"os"
"os/exec"
@@ -75,11 +76,11 @@ func NewExecTool(workingDir string, restrict bool) *ExecTool {
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) *ExecTool {
denyPatterns := make([]*regexp.Regexp, 0)
- enableDenyPatterns := true
if config != nil {
execConfig := config.Tools.Exec
- enableDenyPatterns = execConfig.EnableDenyPatterns
+ enableDenyPatterns := execConfig.EnableDenyPatterns
if enableDenyPatterns {
+ denyPatterns = append(denyPatterns, defaultDenyPatterns...)
if len(execConfig.CustomDenyPatterns) > 0 {
fmt.Printf("Using custom deny patterns: %v\n", execConfig.CustomDenyPatterns)
for _, pattern := range execConfig.CustomDenyPatterns {
@@ -90,8 +91,6 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
}
denyPatterns = append(denyPatterns, re)
}
- } else {
- denyPatterns = append(denyPatterns, defaultDenyPatterns...)
}
} else {
// If deny patterns are disabled, we won't add any patterns, allowing all commands.
@@ -118,15 +117,15 @@ func (t *ExecTool) Description() string {
return "Execute a shell command and return its output. Use with caution."
}
-func (t *ExecTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *ExecTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "command": map[string]interface{}{
+ "properties": map[string]any{
+ "command": map[string]any{
"type": "string",
"description": "The shell command to execute",
},
- "working_dir": map[string]interface{}{
+ "working_dir": map[string]any{
"type": "string",
"description": "Optional working directory for the command",
},
@@ -135,7 +134,7 @@ func (t *ExecTool) Parameters() map[string]interface{} {
}
}
-func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
command, ok := args["command"].(string)
if !ok {
return ErrorResult("command is required")
@@ -143,7 +142,15 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
cwd := t.workingDir
if wd, ok := args["working_dir"].(string); ok && wd != "" {
- cwd = wd
+ if t.restrictToWorkspace && t.workingDir != "" {
+ resolvedWD, err := validatePath(wd, t.workingDir, true)
+ if err != nil {
+ return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
+ }
+ cwd = resolvedWD
+ } else {
+ cwd = wd
+ }
}
if cwd == "" {
@@ -177,18 +184,43 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To
cmd.Dir = cwd
}
+ prepareCommandForTermination(cmd)
+
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
- err := cmd.Run()
+ if err := cmd.Start(); err != nil {
+ return ErrorResult(fmt.Sprintf("failed to start command: %v", err))
+ }
+
+ done := make(chan error, 1)
+ go func() {
+ done <- cmd.Wait()
+ }()
+
+ var err error
+ select {
+ case err = <-done:
+ case <-cmdCtx.Done():
+ _ = terminateProcessTree(cmd)
+ select {
+ case err = <-done:
+ case <-time.After(2 * time.Second):
+ if cmd.Process != nil {
+ _ = cmd.Process.Kill()
+ }
+ err = <-done
+ }
+ }
+
output := stdout.String()
if stderr.Len() > 0 {
output += "\nSTDERR:\n" + stderr.String()
}
if err != nil {
- if cmdCtx.Err() == context.DeadlineExceeded {
+ if errors.Is(cmdCtx.Err(), context.DeadlineExceeded) {
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
diff --git a/pkg/tools/shell_process_unix.go b/pkg/tools/shell_process_unix.go
new file mode 100644
index 000000000..7b29a81bf
--- /dev/null
+++ b/pkg/tools/shell_process_unix.go
@@ -0,0 +1,32 @@
+//go:build !windows
+
+package tools
+
+import (
+ "os/exec"
+ "syscall"
+)
+
+func prepareCommandForTermination(cmd *exec.Cmd) {
+ if cmd == nil {
+ return
+ }
+ cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+}
+
+func terminateProcessTree(cmd *exec.Cmd) error {
+ if cmd == nil || cmd.Process == nil {
+ return nil
+ }
+
+ pid := cmd.Process.Pid
+ if pid <= 0 {
+ return nil
+ }
+
+ // Kill the entire process group spawned by the shell command.
+ _ = syscall.Kill(-pid, syscall.SIGKILL)
+ // Fallback kill on the shell process itself.
+ _ = cmd.Process.Kill()
+ return nil
+}
diff --git a/pkg/tools/shell_process_windows.go b/pkg/tools/shell_process_windows.go
new file mode 100644
index 000000000..fe23b5c96
--- /dev/null
+++ b/pkg/tools/shell_process_windows.go
@@ -0,0 +1,27 @@
+//go:build windows
+
+package tools
+
+import (
+ "os/exec"
+ "strconv"
+)
+
+func prepareCommandForTermination(cmd *exec.Cmd) {
+ // no-op on Windows
+}
+
+func terminateProcessTree(cmd *exec.Cmd) error {
+ if cmd == nil || cmd.Process == nil {
+ return nil
+ }
+
+ pid := cmd.Process.Pid
+ if pid <= 0 {
+ return nil
+ }
+
+ _ = exec.Command("taskkill", "/T", "/F", "/PID", strconv.Itoa(pid)).Run()
+ _ = cmd.Process.Kill()
+ return nil
+}
diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go
index c06468a39..6d35815e8 100644
--- a/pkg/tools/shell_test.go
+++ b/pkg/tools/shell_test.go
@@ -14,7 +14,7 @@ func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "echo 'hello world'",
}
@@ -41,7 +41,7 @@ func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "ls /nonexistent_directory_12345",
}
@@ -69,7 +69,7 @@ func TestShellTool_Timeout(t *testing.T) {
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "sleep 10",
}
@@ -91,12 +91,12 @@ func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
- os.WriteFile(testFile, []byte("test content"), 0644)
+ os.WriteFile(testFile, []byte("test content"), 0o644)
tool := NewExecTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "cat test.txt",
"working_dir": tmpDir,
}
@@ -117,7 +117,7 @@ func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "rm -rf /",
}
@@ -138,7 +138,7 @@ func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
- args := map[string]interface{}{}
+ args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -153,7 +153,7 @@ func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
@@ -174,7 +174,7 @@ func TestShellTool_OutputTruncation(t *testing.T) {
ctx := context.Background()
// Generate long output (>10000 chars)
- args := map[string]interface{}{
+ args := map[string]any{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
@@ -186,6 +186,66 @@ func TestShellTool_OutputTruncation(t *testing.T) {
}
}
+// TestShellTool_WorkingDir_OutsideWorkspace verifies that working_dir cannot escape the workspace directly
+func TestShellTool_WorkingDir_OutsideWorkspace(t *testing.T) {
+ root := t.TempDir()
+ workspace := filepath.Join(root, "workspace")
+ outsideDir := filepath.Join(root, "outside")
+ if err := os.MkdirAll(workspace, 0o755); err != nil {
+ t.Fatalf("failed to create workspace: %v", err)
+ }
+ if err := os.MkdirAll(outsideDir, 0o755); err != nil {
+ t.Fatalf("failed to create outside dir: %v", err)
+ }
+
+ tool := NewExecTool(workspace, true)
+ result := tool.Execute(context.Background(), map[string]any{
+ "command": "pwd",
+ "working_dir": outsideDir,
+ })
+
+ if !result.IsError {
+ t.Fatalf("expected working_dir outside workspace to be blocked, got output: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "blocked") {
+ t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
+ }
+}
+
+// TestShellTool_WorkingDir_SymlinkEscape verifies that a symlink inside the workspace
+// pointing outside cannot be used as working_dir to escape the sandbox.
+func TestShellTool_WorkingDir_SymlinkEscape(t *testing.T) {
+ root := t.TempDir()
+ workspace := filepath.Join(root, "workspace")
+ secretDir := filepath.Join(root, "secret")
+ if err := os.MkdirAll(workspace, 0o755); err != nil {
+ t.Fatalf("failed to create workspace: %v", err)
+ }
+ if err := os.MkdirAll(secretDir, 0o755); err != nil {
+ t.Fatalf("failed to create secret dir: %v", err)
+ }
+ os.WriteFile(filepath.Join(secretDir, "secret.txt"), []byte("top secret"), 0o644)
+
+ // symlink lives inside the workspace but resolves to secretDir outside it
+ link := filepath.Join(workspace, "escape")
+ if err := os.Symlink(secretDir, link); err != nil {
+ t.Skipf("symlinks not supported in this environment: %v", err)
+ }
+
+ tool := NewExecTool(workspace, true)
+ result := tool.Execute(context.Background(), map[string]any{
+ "command": "cat secret.txt",
+ "working_dir": link,
+ })
+
+ if !result.IsError {
+ t.Fatalf("expected symlink working_dir escape to be blocked, got output: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "blocked") {
+ t.Errorf("expected 'blocked' in error, got: %s", result.ForLLM)
+ }
+}
+
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
@@ -193,7 +253,7 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"command": "cat ../../etc/passwd",
}
@@ -205,6 +265,10 @@ func TestShellTool_RestrictToWorkspace(t *testing.T) {
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
- t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
+ t.Errorf(
+ "Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s",
+ result.ForLLM,
+ result.ForUser,
+ )
}
}
diff --git a/pkg/tools/shell_timeout_unix_test.go b/pkg/tools/shell_timeout_unix_test.go
new file mode 100644
index 000000000..04ef8e441
--- /dev/null
+++ b/pkg/tools/shell_timeout_unix_test.go
@@ -0,0 +1,61 @@
+//go:build !windows
+
+package tools
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "syscall"
+ "testing"
+ "time"
+)
+
+func processExists(pid int) bool {
+ if pid <= 0 {
+ return false
+ }
+ err := syscall.Kill(pid, 0)
+ return err == nil || err == syscall.EPERM
+}
+
+func TestShellTool_TimeoutKillsChildProcess(t *testing.T) {
+ tool := NewExecTool(t.TempDir(), false)
+ tool.SetTimeout(500 * time.Millisecond)
+
+ args := map[string]any{
+ // Spawn a child process that would outlive the shell unless process-group kill is used.
+ "command": "sleep 60 & echo $! > child.pid; wait",
+ }
+
+ result := tool.Execute(context.Background(), args)
+ if !result.IsError {
+ t.Fatalf("expected timeout error, got success: %s", result.ForLLM)
+ }
+ if !strings.Contains(result.ForLLM, "timed out") {
+ t.Fatalf("expected timeout message, got: %s", result.ForLLM)
+ }
+
+ childPIDPath := filepath.Join(tool.workingDir, "child.pid")
+ data, err := os.ReadFile(childPIDPath)
+ if err != nil {
+ t.Fatalf("failed to read child pid file: %v", err)
+ }
+
+ childPID, err := strconv.Atoi(strings.TrimSpace(string(data)))
+ if err != nil {
+ t.Fatalf("failed to parse child pid: %v", err)
+ }
+
+ deadline := time.Now().Add(2 * time.Second)
+ for time.Now().Before(deadline) {
+ if !processExists(childPID) {
+ return
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+
+ t.Fatalf("child process %d is still running after timeout", childPID)
+}
diff --git a/pkg/tools/skills_install.go b/pkg/tools/skills_install.go
new file mode 100644
index 000000000..55c0b678d
--- /dev/null
+++ b/pkg/tools/skills_install.go
@@ -0,0 +1,201 @@
+package tools
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/skills"
+ "github.com/sipeed/picoclaw/pkg/utils"
+)
+
+// InstallSkillTool allows the LLM agent to install skills from registries.
+// It shares the same RegistryManager that FindSkillsTool uses,
+// so all registries configured in config are available for installation.
+type InstallSkillTool struct {
+ registryMgr *skills.RegistryManager
+ workspace string
+ mu sync.Mutex
+}
+
+// NewInstallSkillTool creates a new InstallSkillTool.
+// registryMgr is the shared registry manager (same instance as FindSkillsTool).
+// workspace is the root workspace directory; skills install to {workspace}/skills/{slug}/.
+func NewInstallSkillTool(registryMgr *skills.RegistryManager, workspace string) *InstallSkillTool {
+ return &InstallSkillTool{
+ registryMgr: registryMgr,
+ workspace: workspace,
+ mu: sync.Mutex{},
+ }
+}
+
+func (t *InstallSkillTool) Name() string {
+ return "install_skill"
+}
+
+func (t *InstallSkillTool) Description() string {
+ return "Install a skill from a registry by slug. Downloads and extracts the skill into the workspace. Use find_skills first to discover available skills."
+}
+
+func (t *InstallSkillTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "slug": map[string]any{
+ "type": "string",
+ "description": "The unique slug of the skill to install (e.g., 'github', 'docker-compose')",
+ },
+ "version": map[string]any{
+ "type": "string",
+ "description": "Specific version to install (optional, defaults to latest)",
+ },
+ "registry": map[string]any{
+ "type": "string",
+ "description": "Registry to install from (required, e.g., 'clawhub')",
+ },
+ "force": map[string]any{
+ "type": "boolean",
+ "description": "Force reinstall if skill already exists (default false)",
+ },
+ },
+ "required": []string{"slug", "registry"},
+ }
+}
+
+func (t *InstallSkillTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ // Install lock to prevent concurrent directory operations.
+ // Ideally this should be done at a `slug` level, currently, its at a `workspace` level.
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ // Validate slug
+ slug, _ := args["slug"].(string)
+ if err := utils.ValidateSkillIdentifier(slug); err != nil {
+ return ErrorResult(fmt.Sprintf("invalid slug %q: error: %s", slug, err.Error()))
+ }
+
+ // Validate registry
+ registryName, _ := args["registry"].(string)
+ if err := utils.ValidateSkillIdentifier(registryName); err != nil {
+ return ErrorResult(fmt.Sprintf("invalid registry %q: error: %s", registryName, err.Error()))
+ }
+
+ version, _ := args["version"].(string)
+ force, _ := args["force"].(bool)
+
+ // Check if already installed.
+ skillsDir := filepath.Join(t.workspace, "skills")
+ targetDir := filepath.Join(skillsDir, slug)
+
+ if !force {
+ if _, err := os.Stat(targetDir); err == nil {
+ return ErrorResult(
+ fmt.Sprintf("skill %q already installed at %s. Use force=true to reinstall.", slug, targetDir),
+ )
+ }
+ } else {
+ // Force: remove existing if present.
+ os.RemoveAll(targetDir)
+ }
+
+ // Resolve which registry to use.
+ registry := t.registryMgr.GetRegistry(registryName)
+ if registry == nil {
+ return ErrorResult(fmt.Sprintf("registry %q not found", registryName))
+ }
+
+ // Ensure skills directory exists.
+ if err := os.MkdirAll(skillsDir, 0o755); err != nil {
+ return ErrorResult(fmt.Sprintf("failed to create skills directory: %v", err))
+ }
+
+ // Download and install (handles metadata, version resolution, extraction).
+ result, err := registry.DownloadAndInstall(ctx, slug, version, targetDir)
+ if err != nil {
+ // Clean up partial install.
+ rmErr := os.RemoveAll(targetDir)
+ if rmErr != nil {
+ logger.ErrorCF("tool", "Failed to remove partial install",
+ map[string]any{
+ "tool": "install_skill",
+ "target_dir": targetDir,
+ "error": rmErr.Error(),
+ })
+ }
+ return ErrorResult(fmt.Sprintf("failed to install %q: %v", slug, err))
+ }
+
+ // Moderation: block malware.
+ if result.IsMalwareBlocked {
+ rmErr := os.RemoveAll(targetDir)
+ if rmErr != nil {
+ logger.ErrorCF("tool", "Failed to remove partial install",
+ map[string]any{
+ "tool": "install_skill",
+ "target_dir": targetDir,
+ "error": rmErr.Error(),
+ })
+ }
+ return ErrorResult(fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", slug))
+ }
+
+ // Write origin metadata.
+ if err := writeOriginMeta(targetDir, registry.Name(), slug, result.Version); err != nil {
+ logger.ErrorCF("tool", "Failed to write origin metadata",
+ map[string]any{
+ "tool": "install_skill",
+ "error": err.Error(),
+ "target": targetDir,
+ "registry": registry.Name(),
+ "slug": slug,
+ "version": result.Version,
+ })
+ _ = err
+ }
+
+ // Build result with moderation warning if suspicious.
+ var output string
+ if result.IsSuspicious {
+ output = fmt.Sprintf("⚠️ Warning: skill %q is flagged as suspicious (may contain risky patterns).\n\n", slug)
+ }
+ output += fmt.Sprintf("Successfully installed skill %q v%s from %s registry.\nLocation: %s\n",
+ slug, result.Version, registry.Name(), targetDir)
+
+ if result.Summary != "" {
+ output += fmt.Sprintf("Description: %s\n", result.Summary)
+ }
+ output += "\nThe skill is now available and can be loaded in the current session."
+
+ return SilentResult(output)
+}
+
+// originMeta tracks which registry a skill was installed from.
+type originMeta struct {
+ Version int `json:"version"`
+ Registry string `json:"registry"`
+ Slug string `json:"slug"`
+ InstalledVersion string `json:"installed_version"`
+ InstalledAt int64 `json:"installed_at"`
+}
+
+func writeOriginMeta(targetDir, registryName, slug, version string) error {
+ meta := originMeta{
+ Version: 1,
+ Registry: registryName,
+ Slug: slug,
+ InstalledVersion: version,
+ InstalledAt: time.Now().UnixMilli(),
+ }
+
+ data, err := json.MarshalIndent(meta, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return os.WriteFile(filepath.Join(targetDir, ".skill-origin.json"), data, 0o644)
+}
diff --git a/pkg/tools/skills_install_test.go b/pkg/tools/skills_install_test.go
new file mode 100644
index 000000000..676fcecc0
--- /dev/null
+++ b/pkg/tools/skills_install_test.go
@@ -0,0 +1,104 @@
+package tools
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func TestInstallSkillToolName(t *testing.T) {
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
+ assert.Equal(t, "install_skill", tool.Name())
+}
+
+func TestInstallSkillToolMissingSlug(t *testing.T) {
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
+ result := tool.Execute(context.Background(), map[string]any{})
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
+}
+
+func TestInstallSkillToolEmptySlug(t *testing.T) {
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
+ result := tool.Execute(context.Background(), map[string]any{
+ "slug": " ",
+ })
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "identifier is required and must be a non-empty string")
+}
+
+func TestInstallSkillToolUnsafeSlug(t *testing.T) {
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
+
+ cases := []string{
+ "../etc/passwd",
+ "path/traversal",
+ "path\\traversal",
+ }
+
+ for _, slug := range cases {
+ result := tool.Execute(context.Background(), map[string]any{
+ "slug": slug,
+ })
+ assert.True(t, result.IsError, "slug %q should be rejected", slug)
+ assert.Contains(t, result.ForLLM, "invalid slug")
+ }
+}
+
+func TestInstallSkillToolAlreadyExists(t *testing.T) {
+ workspace := t.TempDir()
+ skillDir := filepath.Join(workspace, "skills", "existing-skill")
+ require.NoError(t, os.MkdirAll(skillDir, 0o755))
+
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
+ result := tool.Execute(context.Background(), map[string]any{
+ "slug": "existing-skill",
+ "registry": "clawhub",
+ })
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "already installed")
+}
+
+func TestInstallSkillToolRegistryNotFound(t *testing.T) {
+ workspace := t.TempDir()
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), workspace)
+ result := tool.Execute(context.Background(), map[string]any{
+ "slug": "some-skill",
+ "registry": "nonexistent",
+ })
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "registry")
+ assert.Contains(t, result.ForLLM, "not found")
+}
+
+func TestInstallSkillToolParameters(t *testing.T) {
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
+ params := tool.Parameters()
+
+ props, ok := params["properties"].(map[string]any)
+ assert.True(t, ok)
+ assert.Contains(t, props, "slug")
+ assert.Contains(t, props, "version")
+ assert.Contains(t, props, "registry")
+ assert.Contains(t, props, "force")
+
+ required, ok := params["required"].([]string)
+ assert.True(t, ok)
+ assert.Contains(t, required, "slug")
+ assert.Contains(t, required, "registry")
+}
+
+func TestInstallSkillToolMissingRegistry(t *testing.T) {
+ tool := NewInstallSkillTool(skills.NewRegistryManager(), t.TempDir())
+ result := tool.Execute(context.Background(), map[string]any{
+ "slug": "some-skill",
+ })
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "invalid registry")
+}
diff --git a/pkg/tools/skills_search.go b/pkg/tools/skills_search.go
new file mode 100644
index 000000000..2b6cffd38
--- /dev/null
+++ b/pkg/tools/skills_search.go
@@ -0,0 +1,119 @@
+package tools
+
+import (
+ "context"
+ "fmt"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+// FindSkillsTool allows the LLM agent to search for installable skills from registries.
+type FindSkillsTool struct {
+ registryMgr *skills.RegistryManager
+ cache *skills.SearchCache
+}
+
+// NewFindSkillsTool creates a new FindSkillsTool.
+// registryMgr is the shared registry manager (built from config in createToolRegistry).
+// cache is the search cache for deduplicating similar queries.
+func NewFindSkillsTool(registryMgr *skills.RegistryManager, cache *skills.SearchCache) *FindSkillsTool {
+ return &FindSkillsTool{
+ registryMgr: registryMgr,
+ cache: cache,
+ }
+}
+
+func (t *FindSkillsTool) Name() string {
+ return "find_skills"
+}
+
+func (t *FindSkillsTool) Description() string {
+ return "Search for installable skills from skill registries. Returns skill slugs, descriptions, versions, and relevance scores. Use this to discover skills before installing them with install_skill."
+}
+
+func (t *FindSkillsTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "query": map[string]any{
+ "type": "string",
+ "description": "Search query describing the desired skill capability (e.g., 'github integration', 'database management')",
+ },
+ "limit": map[string]any{
+ "type": "integer",
+ "description": "Maximum number of results to return (1-20, default 5)",
+ "minimum": 1.0,
+ "maximum": 20.0,
+ },
+ },
+ "required": []string{"query"},
+ }
+}
+
+func (t *FindSkillsTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
+ query, ok := args["query"].(string)
+ query = strings.ToLower(strings.TrimSpace(query))
+ if !ok || query == "" {
+ return ErrorResult("query is required and must be a non-empty string")
+ }
+
+ limit := 5
+ if l, ok := args["limit"].(float64); ok {
+ li := int(l)
+ if li >= 1 && li <= 20 {
+ limit = li
+ }
+ }
+
+ // Check cache first.
+ if t.cache != nil {
+ if cached, hit := t.cache.Get(query); hit {
+ return SilentResult(formatSearchResults(query, cached, true))
+ }
+ }
+
+ // Search all registries.
+ results, err := t.registryMgr.SearchAll(ctx, query, limit)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("skill search failed: %v", err))
+ }
+
+ // Cache the results.
+ if t.cache != nil && len(results) > 0 {
+ t.cache.Put(query, results)
+ }
+
+ return SilentResult(formatSearchResults(query, results, false))
+}
+
+func formatSearchResults(query string, results []skills.SearchResult, cached bool) string {
+ if len(results) == 0 {
+ return fmt.Sprintf("No skills found for query: %q", query)
+ }
+
+ var sb strings.Builder
+ source := ""
+ if cached {
+ source = " (cached)"
+ }
+ sb.WriteString(fmt.Sprintf("Found %d skills for %q%s:\n\n", len(results), query, source))
+
+ for i, r := range results {
+ sb.WriteString(fmt.Sprintf("%d. **%s**", i+1, r.Slug))
+ if r.Version != "" {
+ sb.WriteString(fmt.Sprintf(" v%s", r.Version))
+ }
+ sb.WriteString(fmt.Sprintf(" (score: %.3f, registry: %s)\n", r.Score, r.RegistryName))
+ if r.DisplayName != "" && r.DisplayName != r.Slug {
+ sb.WriteString(fmt.Sprintf(" Name: %s\n", r.DisplayName))
+ }
+ if r.Summary != "" {
+ sb.WriteString(fmt.Sprintf(" %s\n", r.Summary))
+ }
+ sb.WriteString("\n")
+ }
+
+ sb.WriteString("Use install_skill with the slug to install a skill.")
+ return sb.String()
+}
diff --git a/pkg/tools/skills_search_test.go b/pkg/tools/skills_search_test.go
new file mode 100644
index 000000000..0e5387cf5
--- /dev/null
+++ b/pkg/tools/skills_search_test.go
@@ -0,0 +1,90 @@
+package tools
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "github.com/sipeed/picoclaw/pkg/skills"
+)
+
+func TestFindSkillsToolName(t *testing.T) {
+ tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+ assert.Equal(t, "find_skills", tool.Name())
+}
+
+func TestFindSkillsToolMissingQuery(t *testing.T) {
+ tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+ result := tool.Execute(context.Background(), map[string]any{})
+ assert.True(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "query is required")
+}
+
+func TestFindSkillsToolEmptyQuery(t *testing.T) {
+ tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+ result := tool.Execute(context.Background(), map[string]any{
+ "query": " ",
+ })
+ assert.True(t, result.IsError)
+}
+
+func TestFindSkillsToolCacheHit(t *testing.T) {
+ cache := skills.NewSearchCache(10, 5*60*1000*1000*1000) // 5 min
+ cache.Put("github", []skills.SearchResult{
+ {Slug: "github", Score: 0.9, RegistryName: "clawhub"},
+ })
+
+ tool := NewFindSkillsTool(skills.NewRegistryManager(), cache)
+ result := tool.Execute(context.Background(), map[string]any{
+ "query": "github",
+ })
+
+ assert.False(t, result.IsError)
+ assert.Contains(t, result.ForLLM, "github")
+ assert.Contains(t, result.ForLLM, "cached")
+}
+
+func TestFindSkillsToolParameters(t *testing.T) {
+ tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+ params := tool.Parameters()
+
+ props, ok := params["properties"].(map[string]any)
+ assert.True(t, ok)
+ assert.Contains(t, props, "query")
+ assert.Contains(t, props, "limit")
+
+ required, ok := params["required"].([]string)
+ assert.True(t, ok)
+ assert.Contains(t, required, "query")
+}
+
+func TestFindSkillsToolDescription(t *testing.T) {
+ tool := NewFindSkillsTool(skills.NewRegistryManager(), nil)
+ assert.NotEmpty(t, tool.Description())
+ assert.Contains(t, tool.Description(), "skill")
+}
+
+func TestFormatSearchResultsEmpty(t *testing.T) {
+ result := formatSearchResults("test query", nil, false)
+ assert.Contains(t, result, "No skills found")
+}
+
+func TestFormatSearchResultsWithData(t *testing.T) {
+ results := []skills.SearchResult{
+ {
+ Slug: "github",
+ Score: 0.95,
+ DisplayName: "GitHub",
+ Summary: "GitHub API integration",
+ Version: "1.0.0",
+ RegistryName: "clawhub",
+ },
+ }
+ output := formatSearchResults("github", results, false)
+ assert.Contains(t, output, "github")
+ assert.Contains(t, output, "v1.0.0")
+ assert.Contains(t, output, "0.950")
+ assert.Contains(t, output, "clawhub")
+ assert.Contains(t, output, "install_skill")
+}
diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go
index f01372467..8b166b41f 100644
--- a/pkg/tools/spawn.go
+++ b/pkg/tools/spawn.go
@@ -3,6 +3,7 @@ package tools
import (
"context"
"fmt"
+ "strings"
)
type SpawnTool struct {
@@ -34,19 +35,19 @@ func (t *SpawnTool) Description() string {
return "Spawn a subagent to handle a task in the background. Use this for complex or time-consuming tasks that can run independently. The subagent will complete the task and report back when done."
}
-func (t *SpawnTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *SpawnTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "task": map[string]interface{}{
+ "properties": map[string]any{
+ "task": map[string]any{
"type": "string",
"description": "The task for subagent to complete",
},
- "label": map[string]interface{}{
+ "label": map[string]any{
"type": "string",
"description": "Optional short label for the task (for display)",
},
- "agent_id": map[string]interface{}{
+ "agent_id": map[string]any{
"type": "string",
"description": "Optional target agent ID to delegate the task to",
},
@@ -64,10 +65,10 @@ func (t *SpawnTool) SetAllowlistChecker(check func(targetAgentID string) bool) {
t.allowlistCheck = check
}
-func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
- if !ok {
- return ErrorResult("task is required")
+ if !ok || strings.TrimSpace(task) == "" {
+ return ErrorResult("task is required and must be a non-empty string")
}
label, _ := args["label"].(string)
diff --git a/pkg/tools/spawn_test.go b/pkg/tools/spawn_test.go
new file mode 100644
index 000000000..0646c82a9
--- /dev/null
+++ b/pkg/tools/spawn_test.go
@@ -0,0 +1,79 @@
+package tools
+
+import (
+ "context"
+ "strings"
+ "testing"
+)
+
+func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
+ provider := &MockLLMProvider{}
+ manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
+ tool := NewSpawnTool(manager)
+
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ args map[string]any
+ }{
+ {"empty string", map[string]any{"task": ""}},
+ {"whitespace only", map[string]any{"task": " "}},
+ {"tabs and newlines", map[string]any{"task": "\t\n "}},
+ {"missing task key", map[string]any{"label": "test"}},
+ {"wrong type", map[string]any{"task": 123}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := tool.Execute(ctx, tt.args)
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if !result.IsError {
+ t.Error("Expected error for invalid task parameter")
+ }
+ if !strings.Contains(result.ForLLM, "task is required") {
+ t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
+ }
+ })
+ }
+}
+
+func TestSpawnTool_Execute_ValidTask(t *testing.T) {
+ provider := &MockLLMProvider{}
+ manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
+ tool := NewSpawnTool(manager)
+
+ ctx := context.Background()
+ args := map[string]any{
+ "task": "Write a haiku about coding",
+ "label": "haiku-task",
+ }
+
+ result := tool.Execute(ctx, args)
+ if result == nil {
+ t.Fatal("Result should not be nil")
+ }
+ if result.IsError {
+ t.Errorf("Expected success for valid task, got error: %s", result.ForLLM)
+ }
+ if !result.Async {
+ t.Error("SpawnTool should return async result")
+ }
+}
+
+func TestSpawnTool_Execute_NilManager(t *testing.T) {
+ tool := NewSpawnTool(nil)
+
+ ctx := context.Background()
+ args := map[string]any{"task": "test task"}
+
+ result := tool.Execute(ctx, args)
+ if !result.IsError {
+ t.Error("Expected error for nil manager")
+ }
+ if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
+ t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
+ }
+}
diff --git a/pkg/tools/spi.go b/pkg/tools/spi.go
index 4805d6a35..0ca17e84f 100644
--- a/pkg/tools/spi.go
+++ b/pkg/tools/spi.go
@@ -24,41 +24,41 @@ func (t *SPITool) Description() string {
return "Interact with SPI bus devices for high-speed peripheral communication. Actions: list (find SPI devices), transfer (full-duplex send/receive), read (receive bytes). Linux only."
}
-func (t *SPITool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *SPITool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "action": map[string]interface{}{
+ "properties": map[string]any{
+ "action": map[string]any{
"type": "string",
"enum": []string{"list", "transfer", "read"},
"description": "Action to perform: list (find available SPI devices), transfer (full-duplex send/receive), read (receive bytes by sending zeros)",
},
- "device": map[string]interface{}{
+ "device": map[string]any{
"type": "string",
"description": "SPI device identifier (e.g. \"2.0\" for /dev/spidev2.0). Required for transfer/read.",
},
- "speed": map[string]interface{}{
+ "speed": map[string]any{
"type": "integer",
"description": "SPI clock speed in Hz. Default: 1000000 (1 MHz).",
},
- "mode": map[string]interface{}{
+ "mode": map[string]any{
"type": "integer",
"description": "SPI mode (0-3). Default: 0. Mode sets CPOL and CPHA: 0=0,0 1=0,1 2=1,0 3=1,1.",
},
- "bits": map[string]interface{}{
+ "bits": map[string]any{
"type": "integer",
"description": "Bits per word. Default: 8.",
},
- "data": map[string]interface{}{
+ "data": map[string]any{
"type": "array",
- "items": map[string]interface{}{"type": "integer"},
+ "items": map[string]any{"type": "integer"},
"description": "Bytes to send (0-255 each). Required for transfer action.",
},
- "length": map[string]interface{}{
+ "length": map[string]any{
"type": "integer",
"description": "Number of bytes to read (1-4096). Required for read action.",
},
- "confirm": map[string]interface{}{
+ "confirm": map[string]any{
"type": "boolean",
"description": "Must be true for transfer operations. Safety guard to prevent accidental writes.",
},
@@ -67,7 +67,7 @@ func (t *SPITool) Parameters() map[string]interface{} {
}
}
-func (t *SPITool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *SPITool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if runtime.GOOS != "linux" {
return ErrorResult("SPI is only supported on Linux. This tool requires /dev/spidev* device files.")
}
@@ -97,7 +97,9 @@ func (t *SPITool) list() *ToolResult {
}
if len(matches) == 0 {
- return SilentResult("No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded")
+ return SilentResult(
+ "No SPI devices found. You may need to:\n1. Enable SPI in device tree\n2. Configure pinmux for your board (see hardware skill)\n3. Check that spidev module is loaded",
+ )
}
type devInfo struct {
@@ -117,8 +119,12 @@ func (t *SPITool) list() *ToolResult {
return SilentResult(fmt.Sprintf("Found %d SPI device(s):\n%s", len(devices), string(result)))
}
+// Helper function for SPI operations (used by platform-specific implementations)
+
// parseSPIArgs extracts and validates common SPI parameters
-func parseSPIArgs(args map[string]interface{}) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
+//
+//nolint:unused // Used by spi_linux.go
+func parseSPIArgs(args map[string]any) (device string, speed uint32, mode uint8, bits uint8, errMsg string) {
dev, ok := args["device"].(string)
if !ok || dev == "" {
return "", 0, 0, 0, "device is required (e.g. \"2.0\" for /dev/spidev2.0)"
diff --git a/pkg/tools/spi_linux.go b/pkg/tools/spi_linux.go
index 12b696007..9def73662 100644
--- a/pkg/tools/spi_linux.go
+++ b/pkg/tools/spi_linux.go
@@ -66,10 +66,12 @@ func configureSPI(devPath string, mode uint8, bits uint8, speed uint32) (int, *T
}
// transfer performs a full-duplex SPI transfer
-func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
+func (t *SPITool) transfer(args map[string]any) *ToolResult {
confirm, _ := args["confirm"].(bool)
if !confirm {
- return ErrorResult("transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.")
+ return ErrorResult(
+ "transfer operations require confirm: true. Please confirm with the user before sending data to SPI devices.",
+ )
}
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
@@ -77,7 +79,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
return ErrorResult(errMsg)
}
- dataRaw, ok := args["data"].([]interface{})
+ dataRaw, ok := args["data"].([]any)
if !ok || len(dataRaw) == 0 {
return ErrorResult("data is required for transfer (array of byte values 0-255)")
}
@@ -130,7 +132,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
intBytes[i] = int(b)
}
- result, _ := json.MarshalIndent(map[string]interface{}{
+ result, _ := json.MarshalIndent(map[string]any{
"device": devPath,
"sent": len(txBuf),
"received": intBytes,
@@ -140,7 +142,7 @@ func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
}
// readDevice reads bytes from SPI by sending zeros (read-only, no confirm needed)
-func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
+func (t *SPITool) readDevice(args map[string]any) *ToolResult {
dev, speed, mode, bits, errMsg := parseSPIArgs(args)
if errMsg != "" {
return ErrorResult(errMsg)
@@ -186,7 +188,7 @@ func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
intBytes[i] = int(b)
}
- result, _ := json.MarshalIndent(map[string]interface{}{
+ result, _ := json.MarshalIndent(map[string]any{
"device": devPath,
"bytes": intBytes,
"hex": hexBytes,
diff --git a/pkg/tools/spi_other.go b/pkg/tools/spi_other.go
index 6dfc86fd1..5d078ac3f 100644
--- a/pkg/tools/spi_other.go
+++ b/pkg/tools/spi_other.go
@@ -3,11 +3,11 @@
package tools
// transfer is a stub for non-Linux platforms.
-func (t *SPITool) transfer(args map[string]interface{}) *ToolResult {
+func (t *SPITool) transfer(args map[string]any) *ToolResult {
return ErrorResult("SPI is only supported on Linux")
}
// readDevice is a stub for non-Linux platforms.
-func (t *SPITool) readDevice(args map[string]interface{}) *ToolResult {
+func (t *SPITool) readDevice(args map[string]any) *ToolResult {
return ErrorResult("SPI is only supported on Linux")
}
diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go
index 294ba6ea8..ad371a649 100644
--- a/pkg/tools/subagent.go
+++ b/pkg/tools/subagent.go
@@ -38,7 +38,11 @@ type SubagentManager struct {
nextID int
}
-func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
+func NewSubagentManager(
+ provider providers.LLMProvider,
+ defaultModel, workspace string,
+ bus *bus.MessageBus,
+) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
@@ -76,7 +80,11 @@ func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.tools.Register(tool)
}
-func (sm *SubagentManager) Spawn(ctx context.Context, task, label, agentID, originChannel, originChatID string, callback AsyncCallback) (string, error) {
+func (sm *SubagentManager) Spawn(
+ ctx context.Context,
+ task, label, agentID, originChannel, originChatID string,
+ callback AsyncCallback,
+) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -124,12 +132,12 @@ After completing the task, provide a clear summary of what was done.`
},
}
- // Check if context is already cancelled before starting
+ // Check if context is already canceled before starting
select {
case <-ctx.Done():
sm.mu.Lock()
- task.Status = "cancelled"
- task.Result = "Task cancelled before execution"
+ task.Status = "canceled"
+ task.Result = "Task canceled before execution"
sm.mu.Unlock()
return
default:
@@ -177,10 +185,10 @@ After completing the task, provide a clear summary of what was done.`
if err != nil {
task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err)
- // Check if it was cancelled
+ // Check if it was canceled
if ctx.Err() != nil {
- task.Status = "cancelled"
- task.Result = "Task cancelled during execution"
+ task.Status = "canceled"
+ task.Result = "Task canceled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
@@ -194,7 +202,12 @@ After completing the task, provide a clear summary of what was done.`
task.Status = "completed"
task.Result = loopResult.Content
result = &ToolResult{
- ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
+ ForLLM: fmt.Sprintf(
+ "Subagent '%s' completed (iterations: %d): %s",
+ task.Label,
+ loopResult.Iterations,
+ loopResult.Content,
+ ),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
@@ -258,15 +271,15 @@ func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
-func (t *SubagentTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *SubagentTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "task": map[string]interface{}{
+ "properties": map[string]any{
+ "task": map[string]any{
"type": "string",
"description": "The task for subagent to complete",
},
- "label": map[string]interface{}{
+ "label": map[string]any{
"type": "string",
"description": "Optional short label for the task (for display)",
},
@@ -280,7 +293,7 @@ func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
-func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go
index f960a7fda..59bfdffae 100644
--- a/pkg/tools/subagent_tool_test.go
+++ b/pkg/tools/subagent_tool_test.go
@@ -11,10 +11,16 @@ import (
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct {
- lastOptions map[string]interface{}
+ lastOptions map[string]any
}
-func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
+func (m *MockLLMProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ options map[string]any,
+) (*providers.LLMResponse, error) {
m.lastOptions = options
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
@@ -47,7 +53,7 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
tool.SetContext("cli", "direct")
ctx := context.Background()
- args := map[string]interface{}{"task": "Do something"}
+ args := map[string]any{"task": "Do something"}
result := tool.Execute(ctx, args)
if result == nil || result.IsError {
@@ -108,13 +114,13 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
// Check properties
- props, ok := params["properties"].(map[string]interface{})
+ props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
- task, ok := props["task"].(map[string]interface{})
+ task, ok := props["task"].(map[string]any)
if !ok {
t.Fatal("Task parameter should exist")
}
@@ -123,7 +129,7 @@ func TestSubagentTool_Parameters(t *testing.T) {
}
// Verify label parameter
- label, ok := props["label"].(map[string]interface{})
+ label, ok := props["label"].(map[string]any)
if !ok {
t.Fatal("Label parameter should exist")
}
@@ -163,7 +169,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
@@ -218,7 +224,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
tool := NewSubagentTool(manager)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"task": "Test task without label",
}
@@ -241,7 +247,7 @@ func TestSubagentTool_Execute_MissingTask(t *testing.T) {
tool := NewSubagentTool(manager)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"label": "test",
}
@@ -268,7 +274,7 @@ func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"task": "test task",
}
@@ -297,7 +303,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
tool.SetContext(channel, chatID)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"task": "Test context passing",
}
@@ -324,7 +330,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
- args := map[string]interface{}{
+ args := map[string]any{
"task": longTask,
"label": "long-test",
}
diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go
index 08f14cc92..cdfe0d6ce 100644
--- a/pkg/tools/toolloop.go
+++ b/pkg/tools/toolloop.go
@@ -33,7 +33,12 @@ type ToolLoopResult struct {
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
-func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
+func RunToolLoop(
+ ctx context.Context,
+ config ToolLoopConfig,
+ messages []providers.Message,
+ channel, chatID string,
+) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
diff --git a/pkg/tools/types.go b/pkg/tools/types.go
index f8205b8bd..a6015cde3 100644
--- a/pkg/tools/types.go
+++ b/pkg/tools/types.go
@@ -10,11 +10,11 @@ type Message struct {
}
type ToolCall struct {
- ID string `json:"id"`
- Type string `json:"type"`
- Function *FunctionCall `json:"function,omitempty"`
- Name string `json:"name,omitempty"`
- Arguments map[string]interface{} `json:"arguments,omitempty"`
+ ID string `json:"id"`
+ Type string `json:"type"`
+ Function *FunctionCall `json:"function,omitempty"`
+ Name string `json:"name,omitempty"`
+ Arguments map[string]any `json:"arguments,omitempty"`
}
type FunctionCall struct {
@@ -36,7 +36,13 @@ type UsageInfo struct {
}
type LLMProvider interface {
- Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error)
+ Chat(
+ ctx context.Context,
+ messages []Message,
+ tools []ToolDefinition,
+ model string,
+ options map[string]any,
+ ) (*LLMResponse, error)
GetDefaultModel() string
}
@@ -46,7 +52,7 @@ type ToolDefinition struct {
}
type ToolFunctionDefinition struct {
- Name string `json:"name"`
- Description string `json:"description"`
- Parameters map[string]interface{} `json:"parameters"`
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Parameters map[string]any `json:"parameters"`
}
diff --git a/pkg/tools/web.go b/pkg/tools/web.go
index e1a640ff0..e95185599 100644
--- a/pkg/tools/web.go
+++ b/pkg/tools/web.go
@@ -1,6 +1,7 @@
package tools
import (
+ "bytes"
"context"
"encoding/json"
"fmt"
@@ -16,12 +17,50 @@ const (
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
)
+// createHTTPClient creates an HTTP client with optional proxy support
+func createHTTPClient(proxyURL string, timeout time.Duration) (*http.Client, error) {
+ client := &http.Client{
+ Timeout: timeout,
+ Transport: &http.Transport{
+ MaxIdleConns: 10,
+ IdleConnTimeout: 30 * time.Second,
+ DisableCompression: false,
+ TLSHandshakeTimeout: 15 * time.Second,
+ },
+ }
+
+ if proxyURL != "" {
+ proxy, err := url.Parse(proxyURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid proxy URL: %w", err)
+ }
+ scheme := strings.ToLower(proxy.Scheme)
+ switch scheme {
+ case "http", "https", "socks5", "socks5h":
+ default:
+ return nil, fmt.Errorf(
+ "unsupported proxy scheme %q (supported: http, https, socks5, socks5h)",
+ proxy.Scheme,
+ )
+ }
+ if proxy.Host == "" {
+ return nil, fmt.Errorf("invalid proxy URL: missing host")
+ }
+ client.Transport.(*http.Transport).Proxy = http.ProxyURL(proxy)
+ } else {
+ client.Transport.(*http.Transport).Proxy = http.ProxyFromEnvironment
+ }
+
+ return client, nil
+}
+
type SearchProvider interface {
Search(ctx context.Context, query string, count int) (string, error)
}
type BraveSearchProvider struct {
apiKey string
+ proxy string
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
@@ -36,7 +75,10 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
- client := &http.Client{Timeout: 10 * time.Second}
+ client, err := createHTTPClient(p.proxy, 10*time.Second)
+ if err != nil {
+ return "", fmt.Errorf("failed to create HTTP client: %w", err)
+ }
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
@@ -84,7 +126,95 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in
return strings.Join(lines, "\n"), nil
}
-type DuckDuckGoSearchProvider struct{}
+type TavilySearchProvider struct {
+ apiKey string
+ baseURL string
+ proxy string
+}
+
+func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
+ searchURL := p.baseURL
+ if searchURL == "" {
+ searchURL = "https://api.tavily.com/search"
+ }
+
+ payload := map[string]any{
+ "api_key": p.apiKey,
+ "query": query,
+ "search_depth": "advanced",
+ "include_answer": false,
+ "include_images": false,
+ "include_raw_content": false,
+ "max_results": count,
+ }
+
+ bodyBytes, err := json.Marshal(payload)
+ if err != nil {
+ return "", fmt.Errorf("failed to marshal payload: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST", searchURL, bytes.NewBuffer(bodyBytes))
+ if err != nil {
+ return "", fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("User-Agent", userAgent)
+
+ client, err := createHTTPClient(p.proxy, 10*time.Second)
+ if err != nil {
+ return "", fmt.Errorf("failed to create HTTP client: %w", err)
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", fmt.Errorf("request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return "", fmt.Errorf("failed to read response: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ return "", fmt.Errorf("tavily api error (status %d): %s", resp.StatusCode, string(body))
+ }
+
+ var searchResp struct {
+ Results []struct {
+ Title string `json:"title"`
+ URL string `json:"url"`
+ Content string `json:"content"`
+ } `json:"results"`
+ }
+
+ if err := json.Unmarshal(body, &searchResp); err != nil {
+ return "", fmt.Errorf("failed to parse response: %w", err)
+ }
+
+ results := searchResp.Results
+ if len(results) == 0 {
+ return fmt.Sprintf("No results for: %s", query), nil
+ }
+
+ var lines []string
+ lines = append(lines, fmt.Sprintf("Results for: %s (via Tavily)", query))
+ for i, item := range results {
+ if i >= count {
+ break
+ }
+ lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
+ if item.Content != "" {
+ lines = append(lines, fmt.Sprintf(" %s", item.Content))
+ }
+ }
+
+ return strings.Join(lines, "\n"), nil
+}
+
+type DuckDuckGoSearchProvider struct {
+ proxy string
+}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query))
@@ -96,7 +226,10 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("User-Agent", userAgent)
- client := &http.Client{Timeout: 10 * time.Second}
+ client, err := createHTTPClient(p.proxy, 10*time.Second)
+ if err != nil {
+ return "", fmt.Errorf("failed to create HTTP client: %w", err)
+ }
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
@@ -178,16 +311,23 @@ func stripTags(content string) string {
type PerplexitySearchProvider struct {
apiKey string
+ proxy string
}
func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := "https://api.perplexity.ai/chat/completions"
- payload := map[string]interface{}{
+ payload := map[string]any{
"model": "sonar",
"messages": []map[string]string{
- {"role": "system", "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary."},
- {"role": "user", "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count)},
+ {
+ "role": "system",
+ "content": "You are a search assistant. Provide concise search results with titles, URLs, and brief descriptions in the following format:\n1. Title\n URL\n Description\n\nDo not add extra commentary.",
+ },
+ {
+ "role": "user",
+ "content": fmt.Sprintf("Search for: %s. Provide up to %d relevant results.", query, count),
+ },
},
"max_tokens": 1000,
}
@@ -206,7 +346,10 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou
req.Header.Set("Authorization", "Bearer "+p.apiKey)
req.Header.Set("User-Agent", userAgent)
- client := &http.Client{Timeout: 30 * time.Second}
+ client, err := createHTTPClient(p.proxy, 30*time.Second)
+ if err != nil {
+ return "", fmt.Errorf("failed to create HTTP client: %w", err)
+ }
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
@@ -312,6 +455,10 @@ type WebSearchToolOptions struct {
BraveAPIKey string
BraveMaxResults int
BraveEnabled bool
+ TavilyAPIKey string
+ TavilyBaseURL string
+ TavilyMaxResults int
+ TavilyEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
PerplexityAPIKey string
@@ -320,20 +467,21 @@ type WebSearchToolOptions struct {
SearXNGBaseURL string
SearXNGMaxResults int
SearXNGEnabled bool
+ Proxy string
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
var provider SearchProvider
maxResults := 5
- // Priority: Perplexity > Brave > SearXNG > DuckDuckGo
+ // Priority: Perplexity > Brave > SearXNG > Tavily > DuckDuckGo
if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" {
- provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey}
+ provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy}
if opts.PerplexityMaxResults > 0 {
maxResults = opts.PerplexityMaxResults
}
} else if opts.BraveEnabled && opts.BraveAPIKey != "" {
- provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
+ provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
@@ -342,8 +490,17 @@ func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
if opts.SearXNGMaxResults > 0 {
maxResults = opts.SearXNGMaxResults
}
+ } else if opts.TavilyEnabled && opts.TavilyAPIKey != "" {
+ provider = &TavilySearchProvider{
+ apiKey: opts.TavilyAPIKey,
+ baseURL: opts.TavilyBaseURL,
+ proxy: opts.Proxy,
+ }
+ if opts.TavilyMaxResults > 0 {
+ maxResults = opts.TavilyMaxResults
+ }
} else if opts.DuckDuckGoEnabled {
- provider = &DuckDuckGoSearchProvider{}
+ provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
@@ -365,15 +522,15 @@ func (t *WebSearchTool) Description() string {
return "Search the web for current information. Returns titles, URLs, and snippets from search results."
}
-func (t *WebSearchTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *WebSearchTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "query": map[string]interface{}{
+ "properties": map[string]any{
+ "query": map[string]any{
"type": "string",
"description": "Search query",
},
- "count": map[string]interface{}{
+ "count": map[string]any{
"type": "integer",
"description": "Number of results (1-10)",
"minimum": 1.0,
@@ -384,7 +541,7 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
}
}
-func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
query, ok := args["query"].(string)
if !ok {
return ErrorResult("query is required")
@@ -410,6 +567,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
type WebFetchTool struct {
maxChars int
+ proxy string
}
func NewWebFetchTool(maxChars int) *WebFetchTool {
@@ -421,6 +579,16 @@ func NewWebFetchTool(maxChars int) *WebFetchTool {
}
}
+func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool {
+ if maxChars <= 0 {
+ maxChars = 50000
+ }
+ return &WebFetchTool{
+ maxChars: maxChars,
+ proxy: proxy,
+ }
+}
+
func (t *WebFetchTool) Name() string {
return "web_fetch"
}
@@ -429,15 +597,15 @@ func (t *WebFetchTool) Description() string {
return "Fetch a URL and extract readable content (HTML to text). Use this to get weather info, news, articles, or any web content."
}
-func (t *WebFetchTool) Parameters() map[string]interface{} {
- return map[string]interface{}{
+func (t *WebFetchTool) Parameters() map[string]any {
+ return map[string]any{
"type": "object",
- "properties": map[string]interface{}{
- "url": map[string]interface{}{
+ "properties": map[string]any{
+ "url": map[string]any{
"type": "string",
"description": "URL to fetch",
},
- "maxChars": map[string]interface{}{
+ "maxChars": map[string]any{
"type": "integer",
"description": "Maximum characters to extract",
"minimum": 100.0,
@@ -447,7 +615,7 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
}
}
-func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
+func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
urlStr, ok := args["url"].(string)
if !ok {
return ErrorResult("url is required")
@@ -480,20 +648,17 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req.Header.Set("User-Agent", userAgent)
- client := &http.Client{
- Timeout: 60 * time.Second,
- Transport: &http.Transport{
- MaxIdleConns: 10,
- IdleConnTimeout: 30 * time.Second,
- DisableCompression: false,
- TLSHandshakeTimeout: 15 * time.Second,
- },
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- if len(via) >= 5 {
- return fmt.Errorf("stopped after 5 redirects")
- }
- return nil
- },
+ client, err := createHTTPClient(t.proxy, 60*time.Second)
+ if err != nil {
+ return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err))
+ }
+
+ // Configure redirect handling
+ client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
+ if len(via) >= 5 {
+ return fmt.Errorf("stopped after 5 redirects")
+ }
+ return nil
}
resp, err := client.Do(req)
@@ -512,7 +677,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
var text, extractor string
if strings.Contains(contentType, "application/json") {
- var jsonData interface{}
+ var jsonData any
if err := json.Unmarshal(body, &jsonData); err == nil {
formatted, _ := json.MarshalIndent(jsonData, "", " ")
text = string(formatted)
@@ -535,7 +700,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
text = text[:maxChars]
}
- result := map[string]interface{}{
+ result := map[string]any{
"url": urlStr,
"status": resp.StatusCode,
"extractor": extractor,
@@ -547,7 +712,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resultJSON, _ := json.MarshalIndent(result, "", " ")
return &ToolResult{
- ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
+ ForLLM: fmt.Sprintf(
+ "Fetched %d bytes from %s (extractor: %s, truncated: %v)",
+ len(text),
+ urlStr,
+ extractor,
+ truncated,
+ ),
ForUser: string(resultJSON),
}
}
diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go
index 7e6d62213..2cd79eb24 100644
--- a/pkg/tools/web_test.go
+++ b/pkg/tools/web_test.go
@@ -7,6 +7,7 @@ import (
"net/http/httptest"
"strings"
"testing"
+ "time"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
@@ -20,7 +21,7 @@ func TestWebTool_WebFetch_Success(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": server.URL,
}
@@ -56,7 +57,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": server.URL,
}
@@ -77,7 +78,7 @@ func TestWebTool_WebFetch_JSON(t *testing.T) {
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": "not-a-valid-url",
}
@@ -98,7 +99,7 @@ func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": "ftp://example.com/file.txt",
}
@@ -119,7 +120,7 @@ func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{}
+ args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -147,7 +148,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": server.URL,
}
@@ -159,7 +160,7 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
// ForUser should contain truncated content (not the full 20000 chars)
- resultMap := make(map[string]interface{})
+ resultMap := make(map[string]any)
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
@@ -191,7 +192,7 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
ctx := context.Background()
- args := map[string]interface{}{}
+ args := map[string]any{}
result := tool.Execute(ctx, args)
@@ -206,13 +207,17 @@ func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
- w.Write([]byte(`Title
Content
`))
+ w.Write(
+ []byte(
+ `Title
Content
`,
+ ),
+ )
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": server.URL,
}
@@ -251,7 +256,8 @@ func TestWebFetchTool_extractText(t *testing.T) {
if len(lines) < 2 {
t.Errorf("Expected multiple lines, got %d: %q", len(lines), got)
}
- if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") || !strings.Contains(got, "Paragraph 2") {
+ if !strings.Contains(got, "Title") || !strings.Contains(got, "Paragraph 1") ||
+ !strings.Contains(got, "Paragraph 2") {
t.Errorf("Missing expected text: %q", got)
}
},
@@ -312,7 +318,7 @@ func TestWebFetchTool_extractText(t *testing.T) {
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
- args := map[string]interface{}{
+ args := map[string]any{
"url": "https://",
}
@@ -328,3 +334,241 @@ func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}
+
+func TestCreateHTTPClient_ProxyConfigured(t *testing.T) {
+ client, err := createHTTPClient("http://127.0.0.1:7890", 12*time.Second)
+ if err != nil {
+ t.Fatalf("createHTTPClient() error: %v", err)
+ }
+ if client.Timeout != 12*time.Second {
+ t.Fatalf("client.Timeout = %v, want %v", client.Timeout, 12*time.Second)
+ }
+
+ tr, ok := client.Transport.(*http.Transport)
+ if !ok {
+ t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
+ }
+ if tr.Proxy == nil {
+ t.Fatal("transport.Proxy is nil, want non-nil")
+ }
+
+ req, err := http.NewRequest("GET", "https://example.com", nil)
+ if err != nil {
+ t.Fatalf("http.NewRequest() error: %v", err)
+ }
+ proxyURL, err := tr.Proxy(req)
+ if err != nil {
+ t.Fatalf("transport.Proxy(req) error: %v", err)
+ }
+ if proxyURL == nil || proxyURL.String() != "http://127.0.0.1:7890" {
+ t.Fatalf("proxy URL = %v, want %q", proxyURL, "http://127.0.0.1:7890")
+ }
+}
+
+func TestCreateHTTPClient_InvalidProxy(t *testing.T) {
+ _, err := createHTTPClient("://bad-proxy", 10*time.Second)
+ if err == nil {
+ t.Fatal("createHTTPClient() expected error for invalid proxy URL, got nil")
+ }
+}
+
+func TestCreateHTTPClient_Socks5ProxyConfigured(t *testing.T) {
+ client, err := createHTTPClient("socks5://127.0.0.1:1080", 8*time.Second)
+ if err != nil {
+ t.Fatalf("createHTTPClient() error: %v", err)
+ }
+
+ tr, ok := client.Transport.(*http.Transport)
+ if !ok {
+ t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
+ }
+ req, err := http.NewRequest("GET", "https://example.com", nil)
+ if err != nil {
+ t.Fatalf("http.NewRequest() error: %v", err)
+ }
+ proxyURL, err := tr.Proxy(req)
+ if err != nil {
+ t.Fatalf("transport.Proxy(req) error: %v", err)
+ }
+ if proxyURL == nil || proxyURL.String() != "socks5://127.0.0.1:1080" {
+ t.Fatalf("proxy URL = %v, want %q", proxyURL, "socks5://127.0.0.1:1080")
+ }
+}
+
+func TestCreateHTTPClient_UnsupportedProxyScheme(t *testing.T) {
+ _, err := createHTTPClient("ftp://127.0.0.1:21", 10*time.Second)
+ if err == nil {
+ t.Fatal("createHTTPClient() expected error for unsupported scheme, got nil")
+ }
+ if !strings.Contains(err.Error(), "unsupported proxy scheme") {
+ t.Fatalf("error = %q, want to contain %q", err.Error(), "unsupported proxy scheme")
+ }
+}
+
+func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) {
+ t.Setenv("HTTP_PROXY", "http://127.0.0.1:8888")
+ t.Setenv("http_proxy", "http://127.0.0.1:8888")
+ t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888")
+ t.Setenv("https_proxy", "http://127.0.0.1:8888")
+ t.Setenv("ALL_PROXY", "")
+ t.Setenv("all_proxy", "")
+ t.Setenv("NO_PROXY", "")
+ t.Setenv("no_proxy", "")
+
+ client, err := createHTTPClient("", 10*time.Second)
+ if err != nil {
+ t.Fatalf("createHTTPClient() error: %v", err)
+ }
+
+ tr, ok := client.Transport.(*http.Transport)
+ if !ok {
+ t.Fatalf("client.Transport type = %T, want *http.Transport", client.Transport)
+ }
+ if tr.Proxy == nil {
+ t.Fatal("transport.Proxy is nil, want proxy function from environment")
+ }
+
+ req, err := http.NewRequest("GET", "https://example.com", nil)
+ if err != nil {
+ t.Fatalf("http.NewRequest() error: %v", err)
+ }
+ if _, err := tr.Proxy(req); err != nil {
+ t.Fatalf("transport.Proxy(req) error: %v", err)
+ }
+}
+
+func TestNewWebFetchToolWithProxy(t *testing.T) {
+ tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890")
+ if tool.maxChars != 1024 {
+ t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024)
+ }
+ if tool.proxy != "http://127.0.0.1:7890" {
+ t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890")
+ }
+
+ tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890")
+ if tool.maxChars != 50000 {
+ t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000)
+ }
+}
+
+func TestNewWebSearchTool_PropagatesProxy(t *testing.T) {
+ t.Run("perplexity", func(t *testing.T) {
+ tool := NewWebSearchTool(WebSearchToolOptions{
+ PerplexityEnabled: true,
+ PerplexityAPIKey: "k",
+ PerplexityMaxResults: 3,
+ Proxy: "http://127.0.0.1:7890",
+ })
+ p, ok := tool.provider.(*PerplexitySearchProvider)
+ if !ok {
+ t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider)
+ }
+ if p.proxy != "http://127.0.0.1:7890" {
+ t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
+ }
+ })
+
+ t.Run("brave", func(t *testing.T) {
+ tool := NewWebSearchTool(WebSearchToolOptions{
+ BraveEnabled: true,
+ BraveAPIKey: "k",
+ BraveMaxResults: 3,
+ Proxy: "http://127.0.0.1:7890",
+ })
+ p, ok := tool.provider.(*BraveSearchProvider)
+ if !ok {
+ t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider)
+ }
+ if p.proxy != "http://127.0.0.1:7890" {
+ t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
+ }
+ })
+
+ t.Run("duckduckgo", func(t *testing.T) {
+ tool := NewWebSearchTool(WebSearchToolOptions{
+ DuckDuckGoEnabled: true,
+ DuckDuckGoMaxResults: 3,
+ Proxy: "http://127.0.0.1:7890",
+ })
+ p, ok := tool.provider.(*DuckDuckGoSearchProvider)
+ if !ok {
+ t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider)
+ }
+ if p.proxy != "http://127.0.0.1:7890" {
+ t.Fatalf("provider proxy = %q, want %q", p.proxy, "http://127.0.0.1:7890")
+ }
+ })
+}
+
+// TestWebTool_TavilySearch_Success verifies successful Tavily search
+func TestWebTool_TavilySearch_Success(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != "POST" {
+ t.Errorf("Expected POST request, got %s", r.Method)
+ }
+ if r.Header.Get("Content-Type") != "application/json" {
+ t.Errorf("Expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
+ }
+
+ // Verify payload
+ var payload map[string]any
+ json.NewDecoder(r.Body).Decode(&payload)
+ if payload["api_key"] != "test-key" {
+ t.Errorf("Expected api_key test-key, got %v", payload["api_key"])
+ }
+ if payload["query"] != "test query" {
+ t.Errorf("Expected query 'test query', got %v", payload["query"])
+ }
+
+ // Return mock response
+ response := map[string]any{
+ "results": []map[string]any{
+ {
+ "title": "Test Result 1",
+ "url": "https://example.com/1",
+ "content": "Content for result 1",
+ },
+ {
+ "title": "Test Result 2",
+ "url": "https://example.com/2",
+ "content": "Content for result 2",
+ },
+ },
+ }
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusOK)
+ json.NewEncoder(w).Encode(response)
+ }))
+ defer server.Close()
+
+ tool := NewWebSearchTool(WebSearchToolOptions{
+ TavilyEnabled: true,
+ TavilyAPIKey: "test-key",
+ TavilyBaseURL: server.URL,
+ TavilyMaxResults: 5,
+ })
+
+ ctx := context.Background()
+ args := map[string]any{
+ "query": "test query",
+ }
+
+ result := tool.Execute(ctx, args)
+
+ // Success should not be an error
+ if result.IsError {
+ t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
+ }
+
+ // ForUser should contain result titles and URLs
+ if !strings.Contains(result.ForUser, "Test Result 1") ||
+ !strings.Contains(result.ForUser, "https://example.com/1") {
+ t.Errorf("Expected results in output, got: %s", result.ForUser)
+ }
+
+ // Should mention via Tavily
+ if !strings.Contains(result.ForUser, "via Tavily") {
+ t.Errorf("Expected 'via Tavily' in output, got: %s", result.ForUser)
+ }
+}
diff --git a/pkg/utils/download.go b/pkg/utils/download.go
new file mode 100644
index 000000000..5d9a13a30
--- /dev/null
+++ b/pkg/utils/download.go
@@ -0,0 +1,93 @@
+package utils
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// DownloadToFile streams an HTTP response body to a temporary file in small
+// chunks (~32KB), keeping peak memory usage constant regardless of file size.
+//
+// Parameters:
+// - ctx: context for cancellation/timeout
+// - client: HTTP client to use (caller controls timeouts, transport, etc.)
+// - req: fully prepared *http.Request (method, URL, headers, etc.)
+// - maxBytes: maximum bytes to download; 0 means no limit
+//
+// Returns the path to the temporary file. The caller is responsible for
+// removing it when done (defer os.Remove(path)).
+//
+// On any error the temp file is cleaned up automatically.
+func DownloadToFile(ctx context.Context, client *http.Client, req *http.Request, maxBytes int64) (string, error) {
+ // Attach context.
+ req = req.WithContext(ctx)
+
+ logger.DebugCF("download", "Starting download", map[string]any{
+ "url": req.URL.String(),
+ "max_bytes": maxBytes,
+ })
+
+ resp, err := client.Do(req)
+ if err != nil {
+ return "", fmt.Errorf("request failed: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ // Read a small amount for the error message.
+ errBody := make([]byte, 512)
+ n, _ := io.ReadFull(resp.Body, errBody)
+ return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(errBody[:n]))
+ }
+
+ // Create temp file.
+ tmpFile, err := os.CreateTemp("", "picoclaw-dl-*")
+ if err != nil {
+ return "", fmt.Errorf("failed to create temp file: %w", err)
+ }
+ tmpPath := tmpFile.Name()
+
+ logger.DebugCF("download", "Streaming to temp file", map[string]any{
+ "path": tmpPath,
+ })
+
+ // Cleanup helper — removes the temp file on any error.
+ cleanup := func() {
+ _ = tmpFile.Close()
+ _ = os.Remove(tmpPath)
+ }
+
+ // Optionally limit the download size.
+ var src io.Reader = resp.Body
+ if maxBytes > 0 {
+ src = io.LimitReader(resp.Body, maxBytes+1) // +1 to detect overflow
+ }
+
+ written, err := io.Copy(tmpFile, src)
+ if err != nil {
+ cleanup()
+ return "", fmt.Errorf("download write failed: %w", err)
+ }
+
+ if maxBytes > 0 && written > maxBytes {
+ cleanup()
+ return "", fmt.Errorf("download too large: %d bytes (max %d)", written, maxBytes)
+ }
+
+ if err := tmpFile.Close(); err != nil {
+ _ = os.Remove(tmpPath)
+ return "", fmt.Errorf("failed to close temp file: %w", err)
+ }
+
+ logger.DebugCF("download", "Download complete", map[string]any{
+ "path": tmpPath,
+ "bytes_written": written,
+ })
+
+ return tmpPath, nil
+}
diff --git a/pkg/utils/media.go b/pkg/utils/media.go
index 2b184f2ec..a34889fb8 100644
--- a/pkg/utils/media.go
+++ b/pkg/utils/media.go
@@ -9,6 +9,7 @@ import (
"time"
"github.com/google/uuid"
+
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -65,8 +66,8 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
- if err := os.MkdirAll(mediaDir, 0700); err != nil {
- logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
+ if err := os.MkdirAll(mediaDir, 0o700); err != nil {
+ logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
"error": err.Error(),
})
return ""
@@ -79,7 +80,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
- logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
+ logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]any{
"error": err.Error(),
})
return ""
@@ -93,7 +94,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
- logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
+ logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]any{
"error": err.Error(),
"url": url,
})
@@ -102,7 +103,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
+ logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]any{
"status": resp.StatusCode,
"url": url,
})
@@ -111,7 +112,7 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
out, err := os.Create(localPath)
if err != nil {
- logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
+ logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]any{
"error": err.Error(),
})
return ""
@@ -121,13 +122,13 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
- logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
+ logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]any{
"error": err.Error(),
})
return ""
}
- logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
+ logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]any{
"path": localPath,
})
diff --git a/pkg/utils/skills.go b/pkg/utils/skills.go
new file mode 100644
index 000000000..1d2cfac7f
--- /dev/null
+++ b/pkg/utils/skills.go
@@ -0,0 +1,19 @@
+package utils
+
+import (
+ "fmt"
+ "strings"
+)
+
+// ValidateSkillIdentifier validates that the given skill identifier (slug or registry name) is non-empty
+// and does not contain path separators ("/", "\\") or ".." for security.
+func ValidateSkillIdentifier(identifier string) error {
+ trimmed := strings.TrimSpace(identifier)
+ if trimmed == "" {
+ return fmt.Errorf("identifier is required and must be a non-empty string")
+ }
+ if strings.ContainsAny(trimmed, "/\\") || strings.Contains(trimmed, "..") {
+ return fmt.Errorf("identifier must not contain path separators or '..' to prevent directory traversal")
+ }
+ return nil
+}
diff --git a/pkg/utils/string.go b/pkg/utils/string.go
index 0d9837cb9..62d9beee0 100644
--- a/pkg/utils/string.go
+++ b/pkg/utils/string.go
@@ -4,6 +4,9 @@ package utils
// Handles multi-byte Unicode characters properly.
// If the string is truncated, "..." is appended to indicate truncation.
func Truncate(s string, maxLen int) string {
+ if maxLen <= 0 {
+ return ""
+ }
runes := []rune(s)
if len(runes) <= maxLen {
return s
@@ -14,3 +17,12 @@ func Truncate(s string, maxLen int) string {
}
return string(runes[:maxLen-3]) + "..."
}
+
+// DerefStr dereferences a pointer to a string and
+// returns the value or a fallback if the pointer is nil.
+func DerefStr(s *string, fallback string) string {
+ if s == nil {
+ return fallback
+ }
+ return *s
+}
diff --git a/pkg/utils/string_test.go b/pkg/utils/string_test.go
new file mode 100644
index 000000000..a44ead228
--- /dev/null
+++ b/pkg/utils/string_test.go
@@ -0,0 +1,106 @@
+package utils
+
+import "testing"
+
+func TestTruncate(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLen int
+ want string
+ }{
+ {
+ name: "short string unchanged",
+ input: "hi",
+ maxLen: 10,
+ want: "hi",
+ },
+ {
+ name: "exact length unchanged",
+ input: "hello",
+ maxLen: 5,
+ want: "hello",
+ },
+ {
+ name: "long string truncated with ellipsis",
+ input: "hello world",
+ maxLen: 8,
+ want: "hello...",
+ },
+ {
+ name: "maxLen equals 4 leaves 1 char plus ellipsis",
+ input: "abcdef",
+ maxLen: 4,
+ want: "a...",
+ },
+ {
+ name: "maxLen 3 returns first 3 chars without ellipsis",
+ input: "abcdef",
+ maxLen: 3,
+ want: "abc",
+ },
+ {
+ name: "maxLen 2 returns first 2 chars",
+ input: "abcdef",
+ maxLen: 2,
+ want: "ab",
+ },
+ {
+ name: "maxLen 1 returns first char",
+ input: "abcdef",
+ maxLen: 1,
+ want: "a",
+ },
+ {
+ name: "maxLen 0 returns empty",
+ input: "hello",
+ maxLen: 0,
+ want: "",
+ },
+ {
+ name: "negative maxLen returns empty",
+ input: "hello",
+ maxLen: -1,
+ want: "",
+ },
+ {
+ name: "empty string unchanged",
+ input: "",
+ maxLen: 5,
+ want: "",
+ },
+ {
+ name: "empty string with zero maxLen",
+ input: "",
+ maxLen: 0,
+ want: "",
+ },
+ {
+ name: "unicode truncated correctly",
+ input: "\U0001f600\U0001f601\U0001f602\U0001f603\U0001f604",
+ maxLen: 4,
+ want: "\U0001f600...",
+ },
+ {
+ name: "unicode short enough",
+ input: "\u00e9\u00e8",
+ maxLen: 5,
+ want: "\u00e9\u00e8",
+ },
+ {
+ name: "mixed ascii and unicode",
+ input: "Go\U0001f680\U0001f525\U0001f4a5\U0001f30d",
+ maxLen: 5,
+ want: "Go...",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := Truncate(tt.input, tt.maxLen)
+ if got != tt.want {
+ t.Errorf("Truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/utils/zip.go b/pkg/utils/zip.go
new file mode 100644
index 000000000..919ce5a20
--- /dev/null
+++ b/pkg/utils/zip.go
@@ -0,0 +1,121 @@
+package utils
+
+import (
+ "archive/zip"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// ExtractZipFile extracts a ZIP archive from disk to targetDir.
+// It reads entries one at a time from disk, keeping memory usage minimal.
+//
+// Security: rejects path traversal attempts and symlinks.
+func ExtractZipFile(zipPath string, targetDir string) error {
+ reader, err := zip.OpenReader(zipPath)
+ if err != nil {
+ return fmt.Errorf("invalid ZIP: %w", err)
+ }
+ defer reader.Close()
+
+ logger.DebugCF("zip", "Extracting ZIP", map[string]any{
+ "zip_path": zipPath,
+ "target_dir": targetDir,
+ "entries": len(reader.File),
+ })
+
+ if err := os.MkdirAll(targetDir, 0o755); err != nil {
+ return fmt.Errorf("failed to create target dir: %w", err)
+ }
+
+ for _, f := range reader.File {
+ // Path traversal protection.
+ cleanName := filepath.Clean(f.Name)
+ if strings.HasPrefix(cleanName, "..") || filepath.IsAbs(cleanName) {
+ return fmt.Errorf("zip entry has unsafe path: %q", f.Name)
+ }
+
+ destPath := filepath.Join(targetDir, cleanName)
+
+ // Double-check the resolved path is within target directory (defense-in-depth).
+ targetDirClean := filepath.Clean(targetDir)
+ if !strings.HasPrefix(filepath.Clean(destPath), targetDirClean+string(filepath.Separator)) &&
+ filepath.Clean(destPath) != targetDirClean {
+ return fmt.Errorf("zip entry escapes target dir: %q", f.Name)
+ }
+
+ mode := f.FileInfo().Mode()
+
+ // Reject any symlink.
+ if mode&os.ModeSymlink != 0 {
+ return fmt.Errorf("zip contains symlink %q; symlinks are not allowed", f.Name)
+ }
+
+ if f.FileInfo().IsDir() {
+ if err := os.MkdirAll(destPath, 0o755); err != nil {
+ return err
+ }
+ continue
+ }
+
+ // Ensure parent directory exists.
+ if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
+ return err
+ }
+
+ if err := extractSingleFile(f, destPath); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// extractSingleFile extracts one zip.File entry to destPath, with a size check.
+func extractSingleFile(f *zip.File, destPath string) error {
+ const maxFileSize = 5 * 1024 * 1024 // 5MB, adjust as appropriate
+
+ // Check the uncompressed size from the header, if available.
+ if f.UncompressedSize64 > maxFileSize {
+ return fmt.Errorf("zip entry %q is too large (%d bytes)", f.Name, f.UncompressedSize64)
+ }
+
+ rc, err := f.Open()
+ if err != nil {
+ return fmt.Errorf("failed to open zip entry %q: %w", f.Name, err)
+ }
+ defer rc.Close()
+
+ outFile, err := os.Create(destPath)
+ if err != nil {
+ return fmt.Errorf("failed to create file %q: %w", destPath, err)
+ }
+ // We don't return the close error via return, since it's not a named error return.
+ // Instead, we log to stderr and remove the partially written file as defensive cleanup.
+ defer func() {
+ if cerr := outFile.Close(); cerr != nil {
+ _ = os.Remove(destPath)
+ logger.ErrorCF("zip", "Failed to close file", map[string]any{
+ "dest_path": destPath,
+ "error": cerr.Error(),
+ })
+ }
+ }()
+
+ // Streamed size check: prevent overruns and malicious/corrupt headers.
+ written, err := io.CopyN(outFile, rc, maxFileSize+1)
+ if err != nil && err != io.EOF {
+ _ = os.Remove(destPath)
+ return fmt.Errorf("failed to extract %q: %w", f.Name, err)
+ }
+ if written > maxFileSize {
+ _ = os.Remove(destPath)
+ return fmt.Errorf("zip entry %q exceeds max size (%d bytes)", f.Name, written)
+ }
+
+ return nil
+}
diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go
index 9af2ea6bb..f973e77fe 100644
--- a/pkg/voice/transcriber.go
+++ b/pkg/voice/transcriber.go
@@ -29,7 +29,7 @@ type TranscriptionResponse struct {
}
func NewGroqTranscriber(apiKey string) *GroqTranscriber {
- logger.DebugCF("voice", "Creating Groq transcriber", map[string]interface{}{"has_api_key": apiKey != ""})
+ logger.DebugCF("voice", "Creating Groq transcriber", map[string]any{"has_api_key": apiKey != ""})
apiBase := "https://api.groq.com/openai/v1"
return &GroqTranscriber{
@@ -42,22 +42,22 @@ func NewGroqTranscriber(apiKey string) *GroqTranscriber {
}
func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) (*TranscriptionResponse, error) {
- logger.InfoCF("voice", "Starting transcription", map[string]interface{}{"audio_file": audioFilePath})
+ logger.InfoCF("voice", "Starting transcription", map[string]any{"audio_file": audioFilePath})
audioFile, err := os.Open(audioFilePath)
if err != nil {
- logger.ErrorCF("voice", "Failed to open audio file", map[string]interface{}{"path": audioFilePath, "error": err})
+ logger.ErrorCF("voice", "Failed to open audio file", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to open audio file: %w", err)
}
defer audioFile.Close()
fileInfo, err := audioFile.Stat()
if err != nil {
- logger.ErrorCF("voice", "Failed to get file info", map[string]interface{}{"path": audioFilePath, "error": err})
+ logger.ErrorCF("voice", "Failed to get file info", map[string]any{"path": audioFilePath, "error": err})
return nil, fmt.Errorf("failed to get file info: %w", err)
}
- logger.DebugCF("voice", "Audio file details", map[string]interface{}{
+ logger.DebugCF("voice", "Audio file details", map[string]any{
"size_bytes": fileInfo.Size(),
"file_name": filepath.Base(audioFilePath),
})
@@ -67,44 +67,44 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
part, err := writer.CreateFormFile("file", filepath.Base(audioFilePath))
if err != nil {
- logger.ErrorCF("voice", "Failed to create form file", map[string]interface{}{"error": err})
+ logger.ErrorCF("voice", "Failed to create form file", map[string]any{"error": err})
return nil, fmt.Errorf("failed to create form file: %w", err)
}
copied, err := io.Copy(part, audioFile)
if err != nil {
- logger.ErrorCF("voice", "Failed to copy file content", map[string]interface{}{"error": err})
+ logger.ErrorCF("voice", "Failed to copy file content", map[string]any{"error": err})
return nil, fmt.Errorf("failed to copy file content: %w", err)
}
- logger.DebugCF("voice", "File copied to request", map[string]interface{}{"bytes_copied": copied})
+ logger.DebugCF("voice", "File copied to request", map[string]any{"bytes_copied": copied})
- if err := writer.WriteField("model", "whisper-large-v3"); err != nil {
- logger.ErrorCF("voice", "Failed to write model field", map[string]interface{}{"error": err})
+ if err = writer.WriteField("model", "whisper-large-v3"); err != nil {
+ logger.ErrorCF("voice", "Failed to write model field", map[string]any{"error": err})
return nil, fmt.Errorf("failed to write model field: %w", err)
}
- if err := writer.WriteField("response_format", "json"); err != nil {
- logger.ErrorCF("voice", "Failed to write response_format field", map[string]interface{}{"error": err})
+ if err = writer.WriteField("response_format", "json"); err != nil {
+ logger.ErrorCF("voice", "Failed to write response_format field", map[string]any{"error": err})
return nil, fmt.Errorf("failed to write response_format field: %w", err)
}
- if err := writer.Close(); err != nil {
- logger.ErrorCF("voice", "Failed to close multipart writer", map[string]interface{}{"error": err})
+ if err = writer.Close(); err != nil {
+ logger.ErrorCF("voice", "Failed to close multipart writer", map[string]any{"error": err})
return nil, fmt.Errorf("failed to close multipart writer: %w", err)
}
url := t.apiBase + "/audio/transcriptions"
req, err := http.NewRequestWithContext(ctx, "POST", url, &requestBody)
if err != nil {
- logger.ErrorCF("voice", "Failed to create request", map[string]interface{}{"error": err})
+ logger.ErrorCF("voice", "Failed to create request", map[string]any{"error": err})
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+t.apiKey)
- logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]interface{}{
+ logger.DebugCF("voice", "Sending transcription request to Groq API", map[string]any{
"url": url,
"request_size_bytes": requestBody.Len(),
"file_size_bytes": fileInfo.Size(),
@@ -112,37 +112,37 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
resp, err := t.httpClient.Do(req)
if err != nil {
- logger.ErrorCF("voice", "Failed to send request", map[string]interface{}{"error": err})
+ logger.ErrorCF("voice", "Failed to send request", map[string]any{"error": err})
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
- logger.ErrorCF("voice", "Failed to read response", map[string]interface{}{"error": err})
+ logger.ErrorCF("voice", "Failed to read response", map[string]any{"error": err})
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
- logger.ErrorCF("voice", "API error", map[string]interface{}{
+ logger.ErrorCF("voice", "API error", map[string]any{
"status_code": resp.StatusCode,
"response": string(body),
})
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
- logger.DebugCF("voice", "Received response from Groq API", map[string]interface{}{
+ logger.DebugCF("voice", "Received response from Groq API", map[string]any{
"status_code": resp.StatusCode,
"response_size_bytes": len(body),
})
var result TranscriptionResponse
if err := json.Unmarshal(body, &result); err != nil {
- logger.ErrorCF("voice", "Failed to unmarshal response", map[string]interface{}{"error": err})
+ logger.ErrorCF("voice", "Failed to unmarshal response", map[string]any{"error": err})
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
- logger.InfoCF("voice", "Transcription completed successfully", map[string]interface{}{
+ logger.InfoCF("voice", "Transcription completed successfully", map[string]any{
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
@@ -154,6 +154,6 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
func (t *GroqTranscriber) IsAvailable() bool {
available := t.apiKey != ""
- logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
+ logger.DebugCF("voice", "Checking transcriber availability", map[string]any{"available": available})
return available
}