mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Compare commits
175 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 50c58a3462 | |||
| 403e048821 | |||
| 4fde0175cf | |||
| 0d16525fab | |||
| 4cd3f99dd6 | |||
| 920e30a241 | |||
| 7b9b8104c8 | |||
| 881999aceb | |||
| f929268ab2 | |||
| 684e7413e1 | |||
| da79c201c7 | |||
| 5fb2721d22 | |||
| 951b05d255 | |||
| ac4b16dfb4 | |||
| 0fadbcd340 | |||
| a961a2df87 | |||
| 57dac394c5 | |||
| 13e4028d42 | |||
| e7f15afdd4 | |||
| 8d757fbb6f | |||
| e3f65fc3d6 | |||
| 5c321a90de | |||
| 17685da584 | |||
| 159a954122 | |||
| a371d53438 | |||
| 9d5728ec5b | |||
| 32cb8fdc12 | |||
| 341dbd3007 | |||
| 0cb9387cf8 | |||
| e77b0a6755 | |||
| a5503aea36 | |||
| cd638fff6c | |||
| 0d18210803 | |||
| 1d748fb742 | |||
| c6c82b3c44 | |||
| 811e4f8728 | |||
| 7ce5b75178 | |||
| 214b201bfa | |||
| 40f90281e5 | |||
| 82856bc57a | |||
| 9a3f3611c3 | |||
| 8eb9dcd99a | |||
| 0a88ff0817 | |||
| ddd73cad48 | |||
| 3334595859 | |||
| 25a47b50ef | |||
| 0f506d4202 | |||
| 6ce7659090 | |||
| 6f2e730eba | |||
| 16e5a02953 | |||
| 4b886b6573 | |||
| 0a4bf32e81 | |||
| 9eb1a53fb8 | |||
| 5a6ad37dab | |||
| f6d6221c09 | |||
| 7a9659971d | |||
| ecbe31599e | |||
| 6c392c3387 | |||
| 0aab8d8afc | |||
| 5faa67b77d | |||
| 7140e73d46 | |||
| dbf2739783 | |||
| a9557aa073 | |||
| a24cbd4385 | |||
| 8cff6cf312 | |||
| 60c3d96b5e | |||
| 773ce9bcb6 | |||
| 1cff7d4e37 | |||
| f7421128a0 | |||
| 1cb690df32 | |||
| 59c7aa1628 | |||
| 1fcf3fde0f | |||
| b805ec8bde | |||
| 378045510d | |||
| 2720fa71c7 | |||
| 7fa641a2ed | |||
| 18d3634f1b | |||
| afc3a2cda3 | |||
| ba0ef4b62c | |||
| 7d45101fcd | |||
| 1299b20465 | |||
| da804a0748 | |||
| 11a6f5eb71 | |||
| 7304ab7d33 | |||
| 5872e0f55e | |||
| 6463796fa1 | |||
| ac79a23e0a | |||
| ff92973361 | |||
| d7822e5d52 | |||
| aa7a8b89c3 | |||
| 6083168ab4 | |||
| f294a71bc5 | |||
| 2cb90f2fe6 | |||
| a286100db5 | |||
| 45351a6a79 | |||
| 1aea912fcd | |||
| 55d5e89246 | |||
| bd9c9d7efc | |||
| 15e3c7d08a | |||
| 82a9a80d94 | |||
| 878650c459 | |||
| 7f60392d88 | |||
| bc27707671 | |||
| 1e17bac9f0 | |||
| 14de80d35f | |||
| b5a8effcd6 | |||
| 82fea61551 | |||
| 92490feff0 | |||
| 3c2e467324 | |||
| 9036a511fb | |||
| e0f702fe03 | |||
| ee3e8ccebb | |||
| b484d3fcf3 | |||
| 000e64c55a | |||
| c86e121688 | |||
| e0a766243e | |||
| 1516cb57b4 | |||
| 584a3dcc87 | |||
| c58f8b740f | |||
| 5339389ef7 | |||
| f739c459bf | |||
| 2f5849b39d | |||
| cf5a84aac1 | |||
| 5aa4dd2975 | |||
| 896eae4c56 | |||
| 0d339d9e5a | |||
| c0d1346b5c | |||
| 7fa70b8cdf | |||
| 9ccfea4ed4 | |||
| cef1e39734 | |||
| a6aa833237 | |||
| 8851152cbd | |||
| 132fe7db51 | |||
| 0cce9fc905 | |||
| fc40f291d1 | |||
| 0c4b8b00f4 | |||
| 42e0e588dd | |||
| 68abf6b2ee | |||
| 4dfa133cb8 | |||
| 8fbbb67f70 | |||
| 7fa341c449 | |||
| 5893245b45 | |||
| b59464230a | |||
| 875a16d2d6 | |||
| e353844dfb | |||
| 474f3dbf90 | |||
| e7e086155e | |||
| 53b5be862f | |||
| ab20314882 | |||
| fbe1152e2d | |||
| 04924ed640 | |||
| 53df8d1f3d | |||
| b36c87bd60 | |||
| 3eb9d6a409 | |||
| 5582b6d910 | |||
| be81ba1f30 | |||
| e63f96794f | |||
| 03b02cc7d7 | |||
| 28734c3a2e | |||
| 061b07192d | |||
| a14181543e | |||
| 35fa64cde8 | |||
| 0ac93d4429 | |||
| 88014ecaff | |||
| e7e3f95ebe | |||
| 2989c391e3 | |||
| feba44ecf0 | |||
| b94941da4a | |||
| 4c4c10c915 | |||
| 7bcd8b284f | |||
| 56ac18ab70 | |||
| b573d61a58 | |||
| c6c61b4e9d | |||
| ca781d4b37 | |||
| cddafb403a |
@@ -9,6 +9,8 @@
|
||||
# ── Chat Channel ──────────────────────────
|
||||
# TELEGRAM_BOT_TOKEN=123456:ABC...
|
||||
# DISCORD_BOT_TOKEN=xxx
|
||||
# LINE_CHANNEL_SECRET=xxx
|
||||
# LINE_CHANNEL_ACCESS_TOKEN=xxx
|
||||
|
||||
# ── Web Search (optional) ────────────────
|
||||
# BRAVE_SEARCH_API_KEY=BSA...
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Report a bug or unexpected behavior
|
||||
title: "[BUG]"
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Quick Summary
|
||||
|
||||
## Environment & Tools
|
||||
- **PicoClaw Version:** (e.g., v0.1.2 or commit hash)
|
||||
- **Go Version:** (e.g., go 1.22)
|
||||
- **AI Model & Provider:** (e.g., GPT-4o via OpenAI / DeepSeek via SiliconFlow)
|
||||
- **Operating System:** (e.g., Ubuntu 22.04 / macOS / Android Termux)
|
||||
- **Channels:** (e.g., Discord, Telegram, Feishu, ...)
|
||||
|
||||
## 📸 Steps to Reproduce
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## ❌ Actual Behavior
|
||||
|
||||
## ✅ Expected Behavior
|
||||
|
||||
## 💬 Additional Context
|
||||
@@ -0,0 +1,23 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest a new idea or improvement
|
||||
title: "[Feature]"
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## 🎯 The Goal / Use Case
|
||||
|
||||
## 💡 Proposed Solution
|
||||
|
||||
## 🛠 Potential Implementation (Optional)
|
||||
|
||||
## 🚦 Impact & Roadmap Alignment
|
||||
- [ ] This is a Core Feature
|
||||
- [ ] This is a Nice-to-Have / Enhancement
|
||||
- [ ] This aligns with the current Roadmap
|
||||
|
||||
## 🔄 Alternatives Considered
|
||||
|
||||
## 💬 Additional Context
|
||||
@@ -0,0 +1,26 @@
|
||||
---
|
||||
name: General Task / Todo
|
||||
about: A specific piece of work like doc, refactoring, or maintenance.
|
||||
title: "[Task]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## 📝 Objective
|
||||
|
||||
## 📋 To-Do List
|
||||
- [ ] Step 1
|
||||
- [ ] Step 2
|
||||
- [ ] Step 3
|
||||
|
||||
## 🎯 Definition of Done (Acceptance Criteria)
|
||||
- [ ] Documentation is updated in the README/docs folder.
|
||||
- [ ] Code follows project linting standards.
|
||||
- [ ] (If applicable) Basic tests pass.
|
||||
|
||||
## 💡 Context / Motivation
|
||||
|
||||
## 🔗 Related Issues / PRs
|
||||
- Fixes #
|
||||
- Relates to #
|
||||
@@ -0,0 +1,37 @@
|
||||
## 📝 Description
|
||||
## 🗣️ Type of Change
|
||||
- [ ] 🐞 Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] ✨ New feature (non-breaking change which adds functionality)
|
||||
- [ ] 📖 Documentation update
|
||||
- [ ] ⚡ Code refactoring (no functional changes, no api changes)
|
||||
|
||||
## 🤖 AI Code Generation
|
||||
- [ ] 🤖 Fully AI-generated (100% AI, 0% Human)
|
||||
- [ ] 🛠️ Mostly AI-generated (AI draft, Human verified/modified)
|
||||
- [ ] 👨💻 Mostly Human-written (Human lead, AI assisted or none)
|
||||
|
||||
|
||||
## 🔗 Linked Issue
|
||||
## 📚 Technical Context (Skip for Docs)
|
||||
* **Reference:** [URL]
|
||||
* **Reasoning:** ...
|
||||
|
||||
|
||||
## 🧪 Test Environment & Hardware
|
||||
- **Hardware:** [e.g. Raspberry Pi 5, Orange Pi, PC]
|
||||
- **OS:** [e.g. Debian 12, Ubuntu 22.04]
|
||||
- **Model/Provider:** [e.g. OpenAI GPT-4o, Kimi k2, DeepSeek-V3]
|
||||
- **Channels:** [e.g. Discord, Telegram, Feishu, ...]
|
||||
|
||||
|
||||
## 📸 Proof of Work (Optional for Docs)
|
||||
<details>
|
||||
<summary>Click to view Logs/Screenshots</summary>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## ☑️ Checklist
|
||||
- [ ] My code/docs follow the style of this project.
|
||||
- [ ] I have performed a self-review of my own changes.
|
||||
- [ ] I have updated the documentation accordingly.
|
||||
@@ -3,7 +3,6 @@ name: build
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -17,5 +16,10 @@ jobs:
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: fmt
|
||||
run: |
|
||||
make fmt
|
||||
git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
|
||||
|
||||
- name: Build
|
||||
run: make build-all
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
name: 🐳 Build & Push Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
tags: ["v*"]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
workflow_call:
|
||||
inputs:
|
||||
tag:
|
||||
description: "Release tag"
|
||||
required: true
|
||||
type: string
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
|
||||
GHCR_REGISTRY: ghcr.io
|
||||
GHCR_IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
|
||||
DOCKERHUB_REGISTRY: docker.io
|
||||
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -23,6 +26,8 @@ jobs:
|
||||
# ── Checkout ──────────────────────────────
|
||||
- name: 📥 Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.tag }}
|
||||
|
||||
# ── Docker Buildx ─────────────────────────
|
||||
- name: 🔧 Set up Docker Buildx
|
||||
@@ -30,36 +35,42 @@ jobs:
|
||||
|
||||
# ── Login to GHCR ─────────────────────────
|
||||
- name: 🔑 Login to GitHub Container Registry
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
registry: ${{ env.GHCR_REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# ── Metadata (tags & labels) ──────────────
|
||||
- name: 🏷️ Extract Docker metadata
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
# ── Login to Docker Hub ────────────────────
|
||||
- name: 🔑 Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=sha,prefix=
|
||||
type=raw,value=latest,enable={{is_default_branch}}
|
||||
type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}}
|
||||
registry: ${{ env.DOCKERHUB_REGISTRY }}
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
# ── Metadata (tags & labels) ──────────────
|
||||
- name: 🏷️ Prepare image tags
|
||||
id: tags
|
||||
shell: bash
|
||||
run: |
|
||||
tag="${{ inputs.tag }}"
|
||||
echo "ghcr_tag=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "ghcr_latest=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT"
|
||||
echo "dockerhub_tag=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT"
|
||||
echo "dockerhub_latest=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# ── Build & Push ──────────────────────────
|
||||
- name: 🚀 Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
push: true
|
||||
tags: |
|
||||
${{ steps.tags.outputs.ghcr_tag }}
|
||||
${{ steps.tags.outputs.ghcr_latest }}
|
||||
${{ steps.tags.outputs.dockerhub_tag }}
|
||||
${{ steps.tags.outputs.dockerhub_latest }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
platforms: linux/amd64,linux/arm64
|
||||
platforms: linux/amd64,linux/arm64,linux/riscv64
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
name: pr-check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
fmt-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Check formatting
|
||||
run: |
|
||||
make fmt
|
||||
git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
|
||||
|
||||
vet:
|
||||
runs-on: ubuntu-latest
|
||||
needs: fmt-check
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run go generate
|
||||
run: go generate ./...
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
needs: fmt-check
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Run go generate
|
||||
run: go generate ./...
|
||||
|
||||
- name: Run go test
|
||||
run: go test ./...
|
||||
|
||||
@@ -32,20 +32,26 @@ jobs:
|
||||
|
||||
- name: Create and push tag
|
||||
shell: bash
|
||||
env:
|
||||
RELEASE_TAG: ${{ inputs.tag }}
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}"
|
||||
git push origin "${{ inputs.tag }}"
|
||||
git tag -a "$RELEASE_TAG" -m "Release $RELEASE_TAG"
|
||||
git push origin "$RELEASE_TAG"
|
||||
|
||||
build-binaries:
|
||||
name: Build Release Binaries
|
||||
release:
|
||||
name: GoReleaser Release
|
||||
needs: create-tag
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
steps:
|
||||
- name: Checkout tag
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.tag }}
|
||||
|
||||
- name: Setup Go from go.mod
|
||||
@@ -53,47 +59,42 @@ jobs:
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Build all binaries
|
||||
run: make build-all
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Generate checksums
|
||||
shell: bash
|
||||
run: |
|
||||
shasum -a 256 build/picoclaw-* > build/sha256sums.txt
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Upload release binaries artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
name: picoclaw-binaries
|
||||
path: |
|
||||
build/picoclaw-*
|
||||
build/sha256sums.txt
|
||||
if-no-files-found: error
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
create-release:
|
||||
name: Create GitHub Release
|
||||
needs: [create-tag, build-binaries]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Download all artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
path: release-artifacts
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
|
||||
- name: Show downloaded files
|
||||
run: ls -R release-artifacts
|
||||
|
||||
- name: Create release
|
||||
uses: softprops/action-gh-release@v2
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v6
|
||||
with:
|
||||
tag_name: ${{ inputs.tag }}
|
||||
name: ${{ inputs.tag }}
|
||||
draft: ${{ inputs.draft }}
|
||||
prerelease: ${{ inputs.prerelease }}
|
||||
files: |
|
||||
release-artifacts/**/*
|
||||
generate_release_notes: true
|
||||
distribution: goreleaser
|
||||
version: ~> v2
|
||||
args: release --clean
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
|
||||
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
|
||||
|
||||
- name: Apply release flags
|
||||
shell: bash
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
gh release edit "${{ inputs.tag }}" \
|
||||
--draft=${{ inputs.draft }} \
|
||||
--prerelease=${{ inputs.prerelease }}
|
||||
|
||||
+10
@@ -10,6 +10,7 @@ build/
|
||||
*.out
|
||||
/picoclaw
|
||||
/picoclaw-test
|
||||
cmd/picoclaw/workspace
|
||||
|
||||
# Picoclaw specific
|
||||
|
||||
@@ -34,3 +35,12 @@ coverage.html
|
||||
|
||||
# Ralph workspace
|
||||
ralph/
|
||||
.ralph/
|
||||
tasks/
|
||||
|
||||
# Editors
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Added by goreleaser init:
|
||||
dist/
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
|
||||
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
|
||||
version: 2
|
||||
|
||||
before:
|
||||
hooks:
|
||||
- go mod tidy
|
||||
- go generate ./cmd/picoclaw
|
||||
|
||||
builds:
|
||||
- id: picoclaw
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
- darwin
|
||||
- freebsd
|
||||
goarch:
|
||||
- amd64
|
||||
- arm64
|
||||
- riscv64
|
||||
- s390x
|
||||
- mips64
|
||||
- arm
|
||||
main: ./cmd/picoclaw
|
||||
ignore:
|
||||
- goos: windows
|
||||
goarch: arm
|
||||
|
||||
dockers_v2:
|
||||
- id: picoclaw
|
||||
dockerfile: Dockerfile.goreleaser
|
||||
ids:
|
||||
- picoclaw
|
||||
images:
|
||||
- "ghcr.io/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw"
|
||||
- "docker.io/{{ .Env.DOCKERHUB_IMAGE_NAME }}"
|
||||
tags:
|
||||
- "{{ .Tag }}"
|
||||
- "latest"
|
||||
platforms:
|
||||
- linux/amd64
|
||||
- linux/arm64
|
||||
- linux/riscv64
|
||||
|
||||
archives:
|
||||
- formats: [tar.gz]
|
||||
# this name template makes the OS and Arch compatible with the results of `uname`.
|
||||
name_template: >-
|
||||
{{ .ProjectName }}_
|
||||
{{- title .Os }}_
|
||||
{{- if eq .Arch "amd64" }}x86_64
|
||||
{{- else if eq .Arch "386" }}i386
|
||||
{{- else }}{{ .Arch }}{{ end }}
|
||||
{{- if .Arm }}v{{ .Arm }}{{ end }}
|
||||
# use zip for windows archives
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
formats: [zip]
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
|
||||
# upx:
|
||||
# - enabled: true
|
||||
# compress: best
|
||||
# lzma: true
|
||||
|
||||
release:
|
||||
footer: >-
|
||||
|
||||
---
|
||||
|
||||
Released by [GoReleaser](https://github.com/goreleaser/goreleaser).
|
||||
+8
-8
@@ -1,7 +1,7 @@
|
||||
# ============================================================
|
||||
# Stage 1: Build the picoclaw binary
|
||||
# ============================================================
|
||||
FROM golang:1.24-alpine AS builder
|
||||
FROM golang:1.26.0-alpine AS builder
|
||||
|
||||
RUN apk add --no-cache git make
|
||||
|
||||
@@ -18,19 +18,19 @@ RUN make build
|
||||
# ============================================================
|
||||
# Stage 2: Minimal runtime image
|
||||
# ============================================================
|
||||
FROM alpine:3.21
|
||||
FROM alpine:3.23
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata
|
||||
RUN apk add --no-cache ca-certificates tzdata curl
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||
CMD wget -q --spider http://localhost:18790/health || exit 1
|
||||
|
||||
# Copy binary
|
||||
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
|
||||
|
||||
# Copy builtin skills
|
||||
COPY --from=builder /src/skills /opt/picoclaw/skills
|
||||
|
||||
# Create picoclaw home directory
|
||||
RUN mkdir -p /root/.picoclaw/workspace/skills && \
|
||||
cp -r /opt/picoclaw/skills/* /root/.picoclaw/workspace/skills/ 2>/dev/null || true
|
||||
RUN /usr/local/bin/picoclaw onboard
|
||||
|
||||
ENTRYPOINT ["picoclaw"]
|
||||
CMD ["gateway"]
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
FROM alpine:3.21
|
||||
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
RUN apk add --no-cache ca-certificates tzdata
|
||||
|
||||
COPY $TARGETPLATFORM/picoclaw /usr/local/bin/picoclaw
|
||||
|
||||
ENTRYPOINT ["picoclaw"]
|
||||
CMD ["gateway"]
|
||||
@@ -8,9 +8,10 @@ MAIN_GO=$(CMD_DIR)/main.go
|
||||
|
||||
# Version
|
||||
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
|
||||
BUILD_TIME=$(shell date +%FT%T%z)
|
||||
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
|
||||
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
|
||||
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
|
||||
|
||||
# Go variables
|
||||
GO?=go
|
||||
@@ -38,6 +39,8 @@ ifeq ($(UNAME_S),Linux)
|
||||
ARCH=amd64
|
||||
else ifeq ($(UNAME_M),aarch64)
|
||||
ARCH=arm64
|
||||
else ifeq ($(UNAME_M),loongarch64)
|
||||
ARCH=loong64
|
||||
else ifeq ($(UNAME_M),riscv64)
|
||||
ARCH=riscv64
|
||||
else
|
||||
@@ -62,20 +65,28 @@ BINARY_PATH=$(BUILD_DIR)/$(BINARY_NAME)-$(PLATFORM)-$(ARCH)
|
||||
# Default target
|
||||
all: build
|
||||
|
||||
## generate: Run generate
|
||||
generate:
|
||||
@echo "Run generate..."
|
||||
@rm -r ./$(CMD_DIR)/workspace 2>/dev/null || true
|
||||
@$(GO) generate ./...
|
||||
@echo "Run generate complete"
|
||||
|
||||
## build: Build the picoclaw binary for current platform
|
||||
build:
|
||||
build: generate
|
||||
@echo "Building $(BINARY_NAME) for $(PLATFORM)/$(ARCH)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
$(GO) build $(GOFLAGS) $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR)
|
||||
@$(GO) build $(GOFLAGS) $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR)
|
||||
@echo "Build complete: $(BINARY_PATH)"
|
||||
@ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME)
|
||||
|
||||
## build-all: Build picoclaw for all platforms
|
||||
build-all:
|
||||
build-all: generate
|
||||
@echo "Building for multiple platforms..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
@@ -88,35 +99,8 @@ install: build
|
||||
@cp $(BUILD_DIR)/$(BINARY_NAME) $(INSTALL_BIN_DIR)/$(BINARY_NAME)
|
||||
@chmod +x $(INSTALL_BIN_DIR)/$(BINARY_NAME)
|
||||
@echo "Installed binary to $(INSTALL_BIN_DIR)/$(BINARY_NAME)"
|
||||
@echo "Installing builtin skills to $(WORKSPACE_SKILLS_DIR)..."
|
||||
@mkdir -p $(WORKSPACE_SKILLS_DIR)
|
||||
@for skill in $(BUILTIN_SKILLS_DIR)/*/; do \
|
||||
if [ -d "$$skill" ]; then \
|
||||
skill_name=$$(basename "$$skill"); \
|
||||
if [ -f "$$skill/SKILL.md" ]; then \
|
||||
cp -r "$$skill" $(WORKSPACE_SKILLS_DIR); \
|
||||
echo " ✓ Installed skill: $$skill_name"; \
|
||||
fi; \
|
||||
fi; \
|
||||
done
|
||||
@echo "Installation complete!"
|
||||
|
||||
## install-skills: Install builtin skills to workspace
|
||||
install-skills:
|
||||
@echo "Installing builtin skills to $(WORKSPACE_SKILLS_DIR)..."
|
||||
@mkdir -p $(WORKSPACE_SKILLS_DIR)
|
||||
@for skill in $(BUILTIN_SKILLS_DIR)/*/; do \
|
||||
if [ -d "$$skill" ]; then \
|
||||
skill_name=$$(basename "$$skill"); \
|
||||
if [ -f "$$skill/SKILL.md" ]; then \
|
||||
mkdir -p $(WORKSPACE_SKILLS_DIR)/$$skill_name; \
|
||||
cp -r "$$skill" $(WORKSPACE_SKILLS_DIR); \
|
||||
echo " ✓ Installed skill: $$skill_name"; \
|
||||
fi; \
|
||||
fi; \
|
||||
done
|
||||
@echo "Skills installation complete!"
|
||||
|
||||
## uninstall: Remove picoclaw from system
|
||||
uninstall:
|
||||
@echo "Uninstalling $(BINARY_NAME)..."
|
||||
@@ -138,15 +122,31 @@ clean:
|
||||
@rm -rf $(BUILD_DIR)
|
||||
@echo "Clean complete"
|
||||
|
||||
## vet: Run go vet for static analysis
|
||||
vet:
|
||||
@$(GO) vet ./...
|
||||
|
||||
## fmt: Format Go code
|
||||
test:
|
||||
@$(GO) test ./...
|
||||
|
||||
## fmt: Format Go code
|
||||
fmt:
|
||||
@$(GO) fmt ./...
|
||||
|
||||
## deps: Update dependencies
|
||||
## deps: Download dependencies
|
||||
deps:
|
||||
@$(GO) mod download
|
||||
@$(GO) mod verify
|
||||
|
||||
## update-deps: Update dependencies
|
||||
update-deps:
|
||||
@$(GO) get -u ./...
|
||||
@$(GO) mod tidy
|
||||
|
||||
## check: Run vet, fmt, and verify dependencies
|
||||
check: deps fmt vet test
|
||||
|
||||
## run: Build and run picoclaw
|
||||
run: build
|
||||
@$(BUILD_DIR)/$(BINARY_NAME) $(ARGS)
|
||||
|
||||
+331
-12
@@ -186,7 +186,7 @@ picoclaw onboard
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "xxx",
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4"
|
||||
"api_base": "https://openrouter.ai/api/v1"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
@@ -195,7 +195,14 @@ picoclaw onboard
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -219,12 +226,15 @@ picoclaw agent -m "What is 2+2?"
|
||||
|
||||
## 💬 チャットアプリ
|
||||
|
||||
Telegram で PicoClaw と会話できます
|
||||
Telegram、Discord、QQ、DingTalk、LINE で PicoClaw と会話できます
|
||||
|
||||
| チャネル | セットアップ |
|
||||
|---------|------------|
|
||||
| **Telegram** | 簡単(トークンのみ) |
|
||||
| **Discord** | 簡単(Bot トークン + Intents) |
|
||||
| **QQ** | 簡単(AppID + AppSecret) |
|
||||
| **DingTalk** | 普通(アプリ認証情報) |
|
||||
| **LINE** | 普通(認証情報 + Webhook URL) |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b>(推奨)</summary>
|
||||
@@ -303,22 +313,324 @@ picoclaw gateway
|
||||
|
||||
</details>
|
||||
|
||||
## 設定 (Configuration)
|
||||
<details>
|
||||
<summary><b>QQ</b></summary>
|
||||
|
||||
PicoClaw は設定に `config.json` を使用します。
|
||||
**1. Bot を作成**
|
||||
|
||||
- [QQ オープンプラットフォーム](https://q.qq.com/#) にアクセス
|
||||
- アプリケーションを作成 → **AppID** と **AppSecret** を取得
|
||||
|
||||
**2. 設定**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"qq": {
|
||||
"enabled": true,
|
||||
"app_id": "YOUR_APP_ID",
|
||||
"app_secret": "YOUR_APP_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> `allow_from` を空にすると全ユーザーを許可、QQ番号を指定してアクセス制限可能。
|
||||
|
||||
**3. 起動**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>DingTalk</b></summary>
|
||||
|
||||
**1. Bot を作成**
|
||||
|
||||
- [オープンプラットフォーム](https://open.dingtalk.com/) にアクセス
|
||||
- 内部アプリを作成
|
||||
- Client ID と Client Secret をコピー
|
||||
|
||||
**2. 設定**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"dingtalk": {
|
||||
"enabled": true,
|
||||
"client_id": "YOUR_CLIENT_ID",
|
||||
"client_secret": "YOUR_CLIENT_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> `allow_from` を空にすると全ユーザーを許可、ユーザーIDを指定してアクセス制限可能。
|
||||
|
||||
**3. 起動**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>LINE</b></summary>
|
||||
|
||||
**1. LINE 公式アカウントを作成**
|
||||
|
||||
- [LINE Developers Console](https://developers.line.biz/) にアクセス
|
||||
- プロバイダーを作成 → Messaging API チャネルを作成
|
||||
- **チャネルシークレット** と **チャネルアクセストークン** をコピー
|
||||
|
||||
**2. 設定**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"line": {
|
||||
"enabled": true,
|
||||
"channel_secret": "YOUR_CHANNEL_SECRET",
|
||||
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
|
||||
"webhook_host": "0.0.0.0",
|
||||
"webhook_port": 18791,
|
||||
"webhook_path": "/webhook/line",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Webhook URL を設定**
|
||||
|
||||
LINE の Webhook には HTTPS が必要です。リバースプロキシまたはトンネルを使用してください:
|
||||
|
||||
```bash
|
||||
# ngrok の例
|
||||
ngrok http 18791
|
||||
```
|
||||
|
||||
LINE Developers Console で Webhook URL を `https://あなたのドメイン/webhook/line` に設定し、**Webhook の利用** を有効にしてください。
|
||||
|
||||
**4. 起動**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
> グループチャットでは @メンション時のみ応答します。返信は元メッセージを引用する形式です。
|
||||
|
||||
> **Docker Compose**: `picoclaw-gateway` サービスに `ports: ["18791:18791"]` を追加して Webhook ポートを公開してください。
|
||||
|
||||
</details>
|
||||
|
||||
## ⚙️ 設定
|
||||
|
||||
設定ファイル: `~/.picoclaw/config.json`
|
||||
|
||||
### ワークスペース構成
|
||||
|
||||
PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します:
|
||||
|
||||
```
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # 会話セッションと履歴
|
||||
├── memory/ # 長期メモリ(MEMORY.md)
|
||||
├── state/ # 永続状態(最後のチャネルなど)
|
||||
├── cron/ # スケジュールジョブデータベース
|
||||
├── skills/ # カスタムスキル
|
||||
├── AGENTS.md # エージェントの行動ガイド
|
||||
├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認)
|
||||
├── IDENTITY.md # エージェントのアイデンティティ
|
||||
├── SOUL.md # エージェントのソウル
|
||||
├── TOOLS.md # ツールの説明
|
||||
└── USER.md # ユーザー設定
|
||||
```
|
||||
|
||||
### 🔒 セキュリティサンドボックス
|
||||
|
||||
PicoClaw はデフォルトでサンドボックス環境で実行されます。エージェントは設定されたワークスペース内のファイルにのみアクセスし、コマンドを実行できます。
|
||||
|
||||
#### デフォルト設定
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"restrict_to_workspace": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| オプション | デフォルト | 説明 |
|
||||
|-----------|-----------|------|
|
||||
| `workspace` | `~/.picoclaw/workspace` | エージェントの作業ディレクトリ |
|
||||
| `restrict_to_workspace` | `true` | ファイル/コマンドアクセスをワークスペースに制限 |
|
||||
|
||||
#### 保護対象ツール
|
||||
|
||||
`restrict_to_workspace: true` の場合、以下のツールがサンドボックス化されます:
|
||||
|
||||
| ツール | 機能 | 制限 |
|
||||
|-------|------|------|
|
||||
| `read_file` | ファイル読み込み | ワークスペース内のファイルのみ |
|
||||
| `write_file` | ファイル書き込み | ワークスペース内のファイルのみ |
|
||||
| `list_dir` | ディレクトリ一覧 | ワークスペース内のディレクトリのみ |
|
||||
| `edit_file` | ファイル編集 | ワークスペース内のファイルのみ |
|
||||
| `append_file` | ファイル追記 | ワークスペース内のファイルのみ |
|
||||
| `exec` | コマンド実行 | コマンドパスはワークスペース内である必要あり |
|
||||
|
||||
#### exec ツールの追加保護
|
||||
|
||||
`restrict_to_workspace: false` でも、`exec` ツールは以下の危険なコマンドをブロックします:
|
||||
|
||||
- `rm -rf`, `del /f`, `rmdir /s` — 一括削除
|
||||
- `format`, `mkfs`, `diskpart` — ディスクフォーマット
|
||||
- `dd if=` — ディスクイメージング
|
||||
- `/dev/sd[a-z]` への書き込み — 直接ディスク書き込み
|
||||
- `shutdown`, `reboot`, `poweroff` — システムシャットダウン
|
||||
- フォークボム `:(){ :|:& };:`
|
||||
|
||||
#### エラー例
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
|
||||
```
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
|
||||
```
|
||||
|
||||
#### 制限の無効化(セキュリティリスク)
|
||||
|
||||
エージェントにワークスペース外のパスへのアクセスが必要な場合:
|
||||
|
||||
**方法1: 設定ファイル**
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"restrict_to_workspace": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**方法2: 環境変数**
|
||||
```bash
|
||||
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
|
||||
```
|
||||
|
||||
> ⚠️ **警告**: この制限を無効にすると、エージェントはシステム上の任意のパスにアクセスできるようになります。制御された環境でのみ慎重に使用してください。
|
||||
|
||||
#### セキュリティ境界の一貫性
|
||||
|
||||
`restrict_to_workspace` 設定は、すべての実行パスで一貫して適用されます:
|
||||
|
||||
| 実行パス | セキュリティ境界 |
|
||||
|---------|-----------------|
|
||||
| メインエージェント | `restrict_to_workspace` ✅ |
|
||||
| サブエージェント / Spawn | 同じ制限を継承 ✅ |
|
||||
| ハートビートタスク | 同じ制限を継承 ✅ |
|
||||
|
||||
すべてのパスで同じワークスペース制限が適用されます — サブエージェントやスケジュールタスクを通じてセキュリティ境界をバイパスする方法はありません。
|
||||
|
||||
### ハートビート(定期タスク)
|
||||
|
||||
PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します:
|
||||
|
||||
```markdown
|
||||
# 定期タスク
|
||||
|
||||
- 重要なメールをチェック
|
||||
- 今後の予定を確認
|
||||
- 天気予報をチェック
|
||||
```
|
||||
|
||||
エージェントは30分ごと(設定可能)にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。
|
||||
|
||||
#### spawn で非同期タスク実行
|
||||
|
||||
時間のかかるタスク(Web検索、API呼び出し)には `spawn` ツールを使って**サブエージェント**を作成します:
|
||||
|
||||
```markdown
|
||||
# 定期タスク
|
||||
|
||||
## クイックタスク(直接応答)
|
||||
- 現在時刻を報告
|
||||
|
||||
## 長時間タスク(spawn で非同期)
|
||||
- AIニュースを検索して要約
|
||||
- メールをチェックして重要なメッセージを報告
|
||||
```
|
||||
|
||||
**主な特徴:**
|
||||
|
||||
| 機能 | 説明 |
|
||||
|------|------|
|
||||
| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない |
|
||||
| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし |
|
||||
| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 |
|
||||
| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 |
|
||||
|
||||
#### サブエージェントの通信方法
|
||||
|
||||
```
|
||||
ハートビート発動
|
||||
↓
|
||||
エージェントが HEARTBEAT.md を読む
|
||||
↓
|
||||
長いタスク: spawn サブエージェント
|
||||
↓ ↓
|
||||
次のタスクへ継続 サブエージェントが独立して動作
|
||||
↓ ↓
|
||||
全タスク完了 message ツールを使用
|
||||
↓ ↓
|
||||
HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
```
|
||||
|
||||
サブエージェントはツール(message、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。
|
||||
|
||||
**設定:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| オプション | デフォルト | 説明 |
|
||||
|-----------|-----------|------|
|
||||
| `enabled` | `true` | ハートビートの有効/無効 |
|
||||
| `interval` | `30` | チェック間隔(分)、最小5分 |
|
||||
|
||||
**環境変数:**
|
||||
- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化
|
||||
- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更
|
||||
|
||||
### 基本設定
|
||||
|
||||
1. **設定ファイルの作成:**
|
||||
|
||||
サンプル設定ファイルをコピーします:
|
||||
|
||||
```bash
|
||||
cp config.example.json config/config.json
|
||||
```
|
||||
|
||||
2. **設定の編集:**
|
||||
|
||||
`config/config.json` を開き、APIキーや設定を記述します。
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
@@ -335,11 +647,11 @@ PicoClaw は設定に `config.json` を使用します。
|
||||
}
|
||||
```
|
||||
|
||||
**3. 実行**
|
||||
3. **実行**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "Hello"
|
||||
```
|
||||
```bash
|
||||
picoclaw agent -m "Hello"
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -388,7 +700,14 @@ picoclaw agent -m "Hello"
|
||||
"search": {
|
||||
"apiKey": "BSA..."
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
|
||||
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
<h1>PicoClaw: Ultra-Efficient AI Assistant in Go</h1>
|
||||
<h1>PicoClaw: Ultra-Efficient AI Assistant in Go</h1>
|
||||
|
||||
<h3>$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!</h3>
|
||||
<h3></h3>
|
||||
<h3>$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!</h3>
|
||||
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
|
||||
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
|
||||
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
|
||||
</p>
|
||||
|
||||
[日本語](README.ja.md) | **English**
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
|
||||
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
|
||||
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
|
||||
<br>
|
||||
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
|
||||
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
|
||||
</p>
|
||||
|
||||
[中文](README.zh.md) | [日本語](README.ja.md) | **English**
|
||||
</div>
|
||||
|
||||
---
|
||||
@@ -37,9 +38,23 @@
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 📢 News
|
||||
> [!CAUTION]
|
||||
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
|
||||
>
|
||||
> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**.
|
||||
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
|
||||
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
|
||||
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
|
||||
> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (10–20MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state.
|
||||
|
||||
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
|
||||
|
||||
## 📢 News
|
||||
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/picoclaw_community_roadmap_260216.md) —we can’t wait to have you on board!
|
||||
|
||||
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
|
||||
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
|
||||
|
||||
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClaw,Let's Go!
|
||||
|
||||
## ✨ Features
|
||||
|
||||
@@ -53,12 +68,12 @@
|
||||
|
||||
🤖 **AI-Bootstrapped**: Autonomous Go-native implementation — 95% Agent-generated core with human-in-the-loop refinement.
|
||||
|
||||
| | OpenClaw | NanoBot | **PicoClaw** |
|
||||
| --- | --- | --- |--- |
|
||||
| **Language** | TypeScript | Python | **Go** |
|
||||
| **RAM** | >1GB |>100MB| **< 10MB** |
|
||||
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
||||
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
|
||||
| | OpenClaw | NanoBot | **PicoClaw** |
|
||||
| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- |
|
||||
| **Language** | TypeScript | Python | **Go** |
|
||||
| **RAM** | >1GB | >100MB | **< 10MB** |
|
||||
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
||||
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ | **Any Linux Board**</br>**As low as 10$** |
|
||||
|
||||
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
@@ -84,11 +99,25 @@
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 📱 Run on old Android Phones
|
||||
Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw. Quick Start:
|
||||
1. **Install Termux** (Available on F-Droid or Google Play).
|
||||
2. **Execute cmds**
|
||||
```bash
|
||||
# Note: Replace v0.1.1 with the latest version from the Releases page
|
||||
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
|
||||
chmod +x picoclaw-linux-arm64
|
||||
pkg install proot
|
||||
termux-chroot ./picoclaw-linux-arm64 onboard
|
||||
```
|
||||
And then follow the instructions in the "Quick Start" section to complete the configuration!
|
||||
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
### 🐜 Innovative Low-Footprint Deploy
|
||||
|
||||
PicoClaw can be deployed on almost any Linux device!
|
||||
|
||||
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
|
||||
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
|
||||
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
|
||||
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
|
||||
|
||||
@@ -165,7 +194,7 @@ docker compose --profile gateway up -d
|
||||
> [!TIP]
|
||||
> Set your API key in `~/.picoclaw/config.json`.
|
||||
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
|
||||
> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month)
|
||||
> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
|
||||
|
||||
**1. Initialize**
|
||||
|
||||
@@ -194,9 +223,14 @@ picoclaw onboard
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -205,8 +239,8 @@ picoclaw onboard
|
||||
|
||||
**3. Get API Keys**
|
||||
|
||||
- **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
|
||||
- **Web Search** (optional): [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month)
|
||||
* **LLM Provider**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
|
||||
* **Web Search** (optional): [Brave Search](https://brave.com/search/api) - Free tier available (2000 requests/month)
|
||||
|
||||
> **Note**: See `config.example.json` for a complete configuration template.
|
||||
|
||||
@@ -222,23 +256,24 @@ That's it! You have a working AI assistant in 2 minutes.
|
||||
|
||||
## 💬 Chat Apps
|
||||
|
||||
Talk to your picoclaw through Telegram, Discord, or DingTalk
|
||||
Talk to your picoclaw through Telegram, Discord, DingTalk, or LINE
|
||||
|
||||
| Channel | Setup |
|
||||
|---------|-------|
|
||||
| **Telegram** | Easy (just a token) |
|
||||
| **Discord** | Easy (bot token + intents) |
|
||||
| **QQ** | Easy (AppID + AppSecret) |
|
||||
| **DingTalk** | Medium (app credentials) |
|
||||
| Channel | Setup |
|
||||
| ------------ | ---------------------------------- |
|
||||
| **Telegram** | Easy (just a token) |
|
||||
| **Discord** | Easy (bot token + intents) |
|
||||
| **QQ** | Easy (AppID + AppSecret) |
|
||||
| **DingTalk** | Medium (app credentials) |
|
||||
| **LINE** | Medium (credentials + webhook URL) |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b> (Recommended)</summary>
|
||||
|
||||
**1. Create a bot**
|
||||
|
||||
- Open Telegram, search `@BotFather`
|
||||
- Send `/newbot`, follow prompts
|
||||
- Copy the token
|
||||
* Open Telegram, search `@BotFather`
|
||||
* Send `/newbot`, follow prompts
|
||||
* Copy the token
|
||||
|
||||
**2. Configure**
|
||||
|
||||
@@ -269,19 +304,19 @@ picoclaw gateway
|
||||
|
||||
**1. Create a bot**
|
||||
|
||||
- Go to <https://discord.com/developers/applications>
|
||||
- Create an application → Bot → Add Bot
|
||||
- Copy the bot token
|
||||
* Go to <https://discord.com/developers/applications>
|
||||
* Create an application → Bot → Add Bot
|
||||
* Copy the bot token
|
||||
|
||||
**2. Enable intents**
|
||||
|
||||
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
|
||||
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
|
||||
* In the Bot settings, enable **MESSAGE CONTENT INTENT**
|
||||
* (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
|
||||
|
||||
**3. Get your User ID**
|
||||
|
||||
- Discord Settings → Advanced → enable **Developer Mode**
|
||||
- Right-click your avatar → **Copy User ID**
|
||||
* Discord Settings → Advanced → enable **Developer Mode**
|
||||
* Right-click your avatar → **Copy User ID**
|
||||
|
||||
**4. Configure**
|
||||
|
||||
@@ -299,10 +334,10 @@ picoclaw gateway
|
||||
|
||||
**5. Invite the bot**
|
||||
|
||||
- OAuth2 → URL Generator
|
||||
- Scopes: `bot`
|
||||
- Bot Permissions: `Send Messages`, `Read Message History`
|
||||
- Open the generated invite URL and add the bot to your server
|
||||
* OAuth2 → URL Generator
|
||||
* Scopes: `bot`
|
||||
* Bot Permissions: `Send Messages`, `Read Message History`
|
||||
* Open the generated invite URL and add the bot to your server
|
||||
|
||||
**6. Run**
|
||||
|
||||
@@ -317,7 +352,7 @@ picoclaw gateway
|
||||
|
||||
**1. Create a bot**
|
||||
|
||||
- Go to [QQ Open Platform](https://connect.qq.com/)
|
||||
- Go to [QQ Open Platform](https://q.qq.com/#)
|
||||
- Create an application → Get **AppID** and **AppSecret**
|
||||
|
||||
**2. Configure**
|
||||
@@ -350,9 +385,9 @@ picoclaw gateway
|
||||
|
||||
**1. Create a bot**
|
||||
|
||||
- Go to [Open Platform](https://open.dingtalk.com/)
|
||||
- Create an internal app
|
||||
- Copy Client ID and Client Secret
|
||||
* Go to [Open Platform](https://open.dingtalk.com/)
|
||||
* Create an internal app
|
||||
* Copy Client ID and Client Secret
|
||||
|
||||
**2. Configure**
|
||||
|
||||
@@ -379,14 +414,62 @@ picoclaw gateway
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>LINE</b></summary>
|
||||
|
||||
**1. Create a LINE Official Account**
|
||||
|
||||
- Go to [LINE Developers Console](https://developers.line.biz/)
|
||||
- Create a provider → Create a Messaging API channel
|
||||
- Copy **Channel Secret** and **Channel Access Token**
|
||||
|
||||
**2. Configure**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"line": {
|
||||
"enabled": true,
|
||||
"channel_secret": "YOUR_CHANNEL_SECRET",
|
||||
"channel_access_token": "YOUR_CHANNEL_ACCESS_TOKEN",
|
||||
"webhook_host": "0.0.0.0",
|
||||
"webhook_port": 18791,
|
||||
"webhook_path": "/webhook/line",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**3. Set up Webhook URL**
|
||||
|
||||
LINE requires HTTPS for webhooks. Use a reverse proxy or tunnel:
|
||||
|
||||
```bash
|
||||
# Example with ngrok
|
||||
ngrok http 18791
|
||||
```
|
||||
|
||||
Then set the Webhook URL in LINE Developers Console to `https://your-domain/webhook/line` and enable **Use webhook**.
|
||||
|
||||
**4. Run**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
```
|
||||
|
||||
> In group chats, the bot responds only when @mentioned. Replies quote the original message.
|
||||
|
||||
> **Docker Compose**: Add `ports: ["18791:18791"]` to the `picoclaw-gateway` service to expose the webhook port.
|
||||
|
||||
</details>
|
||||
|
||||
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> Join the Agent Social Network
|
||||
|
||||
Connect Picoclaw to the Agent Social Network simply by sending a single message via the CLI or any integrated Chat App.
|
||||
|
||||
**Read `https://clawdchat.ai/skill.md` and follow the instructions to join [ClawdChat.ai](https://clawdchat.ai)**
|
||||
|
||||
|
||||
|
||||
## ⚙️ Configuration
|
||||
|
||||
Config file: `~/.picoclaw/config.json`
|
||||
@@ -399,36 +482,209 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # Conversation sessions and history
|
||||
├── memory/ # Long-term memory (MEMORY.md)
|
||||
├── state/ # Persistent state (last channel, etc.)
|
||||
├── cron/ # Scheduled jobs database
|
||||
├── skills/ # Custom skills
|
||||
├── AGENTS.md # Agent behavior guide
|
||||
├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
|
||||
├── IDENTITY.md # Agent identity
|
||||
├── SOUL.md # Agent soul
|
||||
├── TOOLS.md # Tool descriptions
|
||||
└── USER.md # User preferences
|
||||
```
|
||||
|
||||
### 🔒 Security Sandbox
|
||||
|
||||
PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace.
|
||||
|
||||
#### Default Configuration
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"restrict_to_workspace": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent |
|
||||
| `restrict_to_workspace` | `true` | Restrict file/command access to workspace |
|
||||
|
||||
#### Protected Tools
|
||||
|
||||
When `restrict_to_workspace: true`, the following tools are sandboxed:
|
||||
|
||||
| Tool | Function | Restriction |
|
||||
|------|----------|-------------|
|
||||
| `read_file` | Read files | Only files within workspace |
|
||||
| `write_file` | Write files | Only files within workspace |
|
||||
| `list_dir` | List directories | Only directories within workspace |
|
||||
| `edit_file` | Edit files | Only files within workspace |
|
||||
| `append_file` | Append to files | Only files within workspace |
|
||||
| `exec` | Execute commands | Command paths must be within workspace |
|
||||
|
||||
#### Additional Exec Protection
|
||||
|
||||
Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands:
|
||||
|
||||
* `rm -rf`, `del /f`, `rmdir /s` — Bulk deletion
|
||||
* `format`, `mkfs`, `diskpart` — Disk formatting
|
||||
* `dd if=` — Disk imaging
|
||||
* Writing to `/dev/sd[a-z]` — Direct disk writes
|
||||
* `shutdown`, `reboot`, `poweroff` — System shutdown
|
||||
* Fork bomb `:(){ :|:& };:`
|
||||
|
||||
#### Error Examples
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
|
||||
```
|
||||
|
||||
```
|
||||
[ERROR] tool: Tool execution failed
|
||||
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
|
||||
```
|
||||
|
||||
#### Disabling Restrictions (Security Risk)
|
||||
|
||||
If you need the agent to access paths outside the workspace:
|
||||
|
||||
**Method 1: Config file**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"restrict_to_workspace": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Method 2: Environment variable**
|
||||
|
||||
```bash
|
||||
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
|
||||
```
|
||||
|
||||
> ⚠️ **Warning**: Disabling this restriction allows the agent to access any path on your system. Use with caution in controlled environments only.
|
||||
|
||||
#### Security Boundary Consistency
|
||||
|
||||
The `restrict_to_workspace` setting applies consistently across all execution paths:
|
||||
|
||||
| Execution Path | Security Boundary |
|
||||
|----------------|-------------------|
|
||||
| Main Agent | `restrict_to_workspace` ✅ |
|
||||
| Subagent / Spawn | Inherits same restriction ✅ |
|
||||
| Heartbeat tasks | Inherits same restriction ✅ |
|
||||
|
||||
All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks.
|
||||
|
||||
### Heartbeat (Periodic Tasks)
|
||||
|
||||
PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace:
|
||||
|
||||
```markdown
|
||||
# Periodic Tasks
|
||||
|
||||
- Check my email for important messages
|
||||
- Review my calendar for upcoming events
|
||||
- Check the weather forecast
|
||||
```
|
||||
|
||||
The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools.
|
||||
|
||||
#### Async Tasks with Spawn
|
||||
|
||||
For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**:
|
||||
|
||||
```markdown
|
||||
# Periodic Tasks
|
||||
|
||||
## Quick Tasks (respond directly)
|
||||
- Report current time
|
||||
|
||||
## Long Tasks (use spawn for async)
|
||||
- Search the web for AI news and summarize
|
||||
- Check email and report important messages
|
||||
```
|
||||
|
||||
**Key behaviors:**
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| **spawn** | Creates async subagent, doesn't block heartbeat |
|
||||
| **Independent context** | Subagent has its own context, no session history |
|
||||
| **message tool** | Subagent communicates with user directly via message tool |
|
||||
| **Non-blocking** | After spawning, heartbeat continues to next task |
|
||||
|
||||
#### How Subagent Communication Works
|
||||
|
||||
```
|
||||
Heartbeat triggers
|
||||
↓
|
||||
Agent reads HEARTBEAT.md
|
||||
↓
|
||||
For long task: spawn subagent
|
||||
↓ ↓
|
||||
Continue to next task Subagent works independently
|
||||
↓ ↓
|
||||
All tasks done Subagent uses "message" tool
|
||||
↓ ↓
|
||||
Respond HEARTBEAT_OK User receives result directly
|
||||
```
|
||||
|
||||
The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `enabled` | `true` | Enable/disable heartbeat |
|
||||
| `interval` | `30` | Check interval in minutes (min: 5) |
|
||||
|
||||
**Environment variables:**
|
||||
|
||||
* `PICOCLAW_HEARTBEAT_ENABLED=false` to disable
|
||||
* `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval
|
||||
|
||||
### Providers
|
||||
|
||||
> [!NOTE]
|
||||
> Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
|
||||
|
||||
| Provider | Purpose | Get API Key |
|
||||
|----------|---------|-------------|
|
||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) |
|
||||
| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
| Provider | Purpose | Get API Key |
|
||||
| -------------------------- | --------------------------------------- | ------------------------------------------------------ |
|
||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) |
|
||||
| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
|
||||
<details>
|
||||
<summary><b>Zhipu</b></summary>
|
||||
|
||||
**1. Get API key and base URL**
|
||||
|
||||
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||
* Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||
|
||||
**2. Configure**
|
||||
|
||||
@@ -447,8 +703,8 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
|
||||
"zhipu": {
|
||||
"api_key": "Your API Key",
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4"
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -509,10 +765,23 @@ picoclaw agent -m "Hello"
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"api_key": "BSA..."
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "BSA...",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -521,23 +790,23 @@ picoclaw agent -m "Hello"
|
||||
|
||||
## CLI Reference
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `picoclaw onboard` | Initialize config & workspace |
|
||||
| `picoclaw agent -m "..."` | Chat with the agent |
|
||||
| `picoclaw agent` | Interactive chat mode |
|
||||
| `picoclaw gateway` | Start the gateway |
|
||||
| `picoclaw status` | Show status |
|
||||
| `picoclaw cron list` | List all scheduled jobs |
|
||||
| `picoclaw cron add ...` | Add a scheduled job |
|
||||
| Command | Description |
|
||||
| ------------------------- | ----------------------------- |
|
||||
| `picoclaw onboard` | Initialize config & workspace |
|
||||
| `picoclaw agent -m "..."` | Chat with the agent |
|
||||
| `picoclaw agent` | Interactive chat mode |
|
||||
| `picoclaw gateway` | Start the gateway |
|
||||
| `picoclaw status` | Show status |
|
||||
| `picoclaw cron list` | List all scheduled jobs |
|
||||
| `picoclaw cron add ...` | Add a scheduled job |
|
||||
|
||||
### Scheduled Tasks / Reminders
|
||||
|
||||
PicoClaw supports scheduled reminders and recurring tasks through the `cron` tool:
|
||||
|
||||
- **One-time reminders**: "Remind me in 10 minutes" → triggers once after 10min
|
||||
- **Recurring tasks**: "Remind me every 2 hours" → triggers every 2 hours
|
||||
- **Cron expressions**: "Remind me at 9am daily" → uses cron expression
|
||||
* **One-time reminders**: "Remind me in 10 minutes" → triggers once after 10min
|
||||
* **Recurring tasks**: "Remind me every 2 hours" → triggers every 2 hours
|
||||
* **Cron expressions**: "Remind me at 9am daily" → uses cron expression
|
||||
|
||||
Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
|
||||
|
||||
@@ -545,6 +814,12 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
|
||||
|
||||
PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||
|
||||
Roadmap coming soon...
|
||||
|
||||
Developer group building, Entry Requirement: At least 1 Merged PR.
|
||||
|
||||
User Groups:
|
||||
|
||||
discord: <https://discord.gg/V4sAZ9XWpN>
|
||||
|
||||
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
||||
@@ -557,21 +832,28 @@ This is normal if you haven't configured a search API key yet. PicoClaw will pro
|
||||
|
||||
To enable web search:
|
||||
|
||||
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
|
||||
2. Add to `~/.picoclaw/config.json`:
|
||||
1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results.
|
||||
2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required).
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
Add the key to `~/.picoclaw/config.json` if using Brave:
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"duckduckgo": {
|
||||
"enabled": true,
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Getting content filtering errors
|
||||
|
||||
@@ -585,9 +867,9 @@ This happens when another instance of the bot is running. Make sure only one `pi
|
||||
|
||||
## 📝 API Key Comparison
|
||||
|
||||
| Service | Free Tier | Use Case |
|
||||
|---------|-----------|-----------|
|
||||
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
|
||||
| **Zhipu** | 200K tokens/month | Best for Chinese users |
|
||||
| **Brave Search** | 2000 queries/month | Web search functionality |
|
||||
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
|
||||
| Service | Free Tier | Use Case |
|
||||
| ---------------- | ------------------- | ------------------------------------- |
|
||||
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
|
||||
| **Zhipu** | 200K tokens/month | Best for Chinese users |
|
||||
| **Brave Search** | 2000 queries/month | Web search functionality |
|
||||
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
|
||||
|
||||
+744
@@ -0,0 +1,744 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
<h1>PicoClaw: 基于Go语言的超高效 AI 助手</h1>
|
||||
|
||||
<h3>10$硬件 · 10MB内存 · 1秒启动 · 皮皮虾,我们走!</h3>
|
||||
|
||||
<p>
|
||||
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
|
||||
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
|
||||
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
|
||||
<br>
|
||||
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
|
||||
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
|
||||
</p>
|
||||
|
||||
**中文** | [日本語](README.ja.md) | [English](README.md)
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。
|
||||
|
||||
⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%!
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<td align="center" valign="top">
|
||||
<p align="center">
|
||||
<img src="assets/picoclaw_mem.gif" width="360" height="240">
|
||||
</p>
|
||||
</td>
|
||||
<td align="center" valign="top">
|
||||
<p align="center">
|
||||
<img src="assets/licheervnano.png" width="400" height="240">
|
||||
</p>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
注意:人手有限,中文文档可能略有滞后,请优先查看英文文档。
|
||||
|
||||
> [!CAUTION]
|
||||
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
|
||||
> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
|
||||
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
|
||||
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
|
||||
> * **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
|
||||
> * **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化.
|
||||
|
||||
|
||||
## 📢 新闻 (News)
|
||||
2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/picoclaw_community_roadmap_260216.md), 期待你的参与!
|
||||
|
||||
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars!** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
|
||||
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
|
||||
|
||||
2026-02-09 🎉 **PicoClaw 正式发布!** 仅用 1 天构建,旨在将 AI Agent 带入 10 美元硬件与 <10MB 内存的世界。🦐 PicoClaw(皮皮虾),我们走!
|
||||
|
||||
## ✨ 特性
|
||||
|
||||
🪶 **超轻量级**: 核心功能内存占用 <10MB — 比 Clawdbot 小 99%。
|
||||
|
||||
💰 **极低成本**: 高效到足以在 10 美元的硬件上运行 — 比 Mac mini 便宜 98%。
|
||||
|
||||
⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。
|
||||
|
||||
🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行!
|
||||
|
||||
🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。
|
||||
|
||||
| | OpenClaw | NanoBot | **PicoClaw** |
|
||||
| --- | --- | --- | --- |
|
||||
| **语言** | TypeScript | Python | **Go** |
|
||||
| **RAM** | >1GB | >100MB | **< 10MB** |
|
||||
| **启动时间**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
||||
| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板**</br>**低至 $10** |
|
||||
|
||||
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
## 🦾 演示
|
||||
|
||||
### 🛠️ 标准助手工作流
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<th><p align="center">🧩 全栈工程师模式</p></th>
|
||||
<th><p align="center">🗂️ 日志与规划管理</p></th>
|
||||
<th><p align="center">🔎 网络搜索与学习</p></th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
|
||||
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">开发 • 部署 • 扩展</td>
|
||||
<td align="center">日程 • 自动化 • 记忆</td>
|
||||
<td align="center">发现 • 洞察 • 趋势</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### 📱 在手机上轻松运行
|
||||
picoclaw 可以将你10年前的老旧手机废物利用,变身成为你的AI助理!快速指南:
|
||||
1. 先去应用商店下载安装Termux
|
||||
2. 打开后执行指令
|
||||
```bash
|
||||
# 注意: 下面的v0.1.1 可以换为你实际看到的最新版本
|
||||
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
|
||||
chmod +x picoclaw-linux-arm64
|
||||
pkg install proot
|
||||
termux-chroot ./picoclaw-linux-arm64 onboard
|
||||
```
|
||||
然后跟随下面的“快速开始”章节继续配置picoclaw即可使用!
|
||||
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
|
||||
|
||||
|
||||
|
||||
|
||||
### 🐜 创新的低占用部署
|
||||
|
||||
PicoClaw 几乎可以部署在任何 Linux 设备上!
|
||||
|
||||
* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。
|
||||
* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。
|
||||
* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。
|
||||
|
||||
[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4)
|
||||
|
||||
🌟 更多部署案例敬请期待!
|
||||
|
||||
## 📦 安装
|
||||
|
||||
### 使用预编译二进制文件安装
|
||||
|
||||
从 [Release 页面](https://github.com/sipeed/picoclaw/releases) 下载适用于您平台的固件。
|
||||
|
||||
### 从源码安装(获取最新特性,开发推荐)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/sipeed/picoclaw.git
|
||||
|
||||
cd picoclaw
|
||||
make deps
|
||||
|
||||
# 构建(无需安装)
|
||||
make build
|
||||
|
||||
# 为多平台构建
|
||||
make build-all
|
||||
|
||||
# 构建并安装
|
||||
make install
|
||||
|
||||
```
|
||||
|
||||
## 🐳 Docker Compose
|
||||
|
||||
您也可以使用 Docker Compose 运行 PicoClaw,无需在本地安装任何环境。
|
||||
|
||||
```bash
|
||||
# 1. 克隆仓库
|
||||
git clone https://github.com/sipeed/picoclaw.git
|
||||
cd picoclaw
|
||||
|
||||
# 2. 设置 API Key
|
||||
cp config/config.example.json config/config.json
|
||||
vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等
|
||||
|
||||
# 3. 构建并启动
|
||||
docker compose --profile gateway up -d
|
||||
|
||||
# 4. 查看日志
|
||||
docker compose logs -f picoclaw-gateway
|
||||
|
||||
# 5. 停止
|
||||
docker compose --profile gateway down
|
||||
|
||||
```
|
||||
|
||||
### Agent 模式 (一次性运行)
|
||||
|
||||
```bash
|
||||
# 提问
|
||||
docker compose run --rm picoclaw-agent -m "2+2 等于几?"
|
||||
|
||||
# 交互模式
|
||||
docker compose run --rm picoclaw-agent
|
||||
|
||||
```
|
||||
|
||||
### 重新构建
|
||||
|
||||
```bash
|
||||
docker compose --profile gateway build --no-cache
|
||||
docker compose --profile gateway up -d
|
||||
|
||||
```
|
||||
|
||||
### 🚀 快速开始
|
||||
|
||||
> [!TIP]
|
||||
> 在 `~/.picoclaw/config.json` 中设置您的 API Key。
|
||||
> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
|
||||
> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询)
|
||||
|
||||
**1. 初始化 (Initialize)**
|
||||
|
||||
```bash
|
||||
picoclaw onboard
|
||||
|
||||
```
|
||||
|
||||
**2. 配置 (Configure)** (`~/.picoclaw/config.json`)
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "xxx",
|
||||
"api_base": "https://openrouter.ai/api/v1"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
**3. 获取 API Key**
|
||||
|
||||
* **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
|
||||
* **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月)
|
||||
|
||||
> **注意**: 完整的配置模板请参考 `config.example.json`。
|
||||
|
||||
**4. 对话 (Chat)**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "2+2 等于几?"
|
||||
|
||||
```
|
||||
|
||||
就是这样!您在 2 分钟内就拥有了一个可工作的 AI 助手。
|
||||
|
||||
---
|
||||
|
||||
## 💬 聊天应用集成 (Chat Apps)
|
||||
|
||||
通过 Telegram, Discord 或钉钉与您的 PicoClaw 对话。
|
||||
|
||||
| 渠道 | 设置难度 |
|
||||
| --- | --- |
|
||||
| **Telegram** | 简单 (仅需 token) |
|
||||
| **Discord** | 简单 (bot token + intents) |
|
||||
| **QQ** | 简单 (AppID + AppSecret) |
|
||||
| **钉钉 (DingTalk)** | 中等 (app credentials) |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b> (推荐)</summary>
|
||||
|
||||
**1. 创建机器人**
|
||||
|
||||
* 打开 Telegram,搜索 `@BotFather`
|
||||
* 发送 `/newbot`,按照提示操作
|
||||
* 复制 token
|
||||
|
||||
**2. 配置**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
> 从 Telegram 上的 `@userinfobot` 获取您的用户 ID。
|
||||
|
||||
**3. 运行**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Discord</b></summary>
|
||||
|
||||
**1. 创建机器人**
|
||||
|
||||
* 前往 [https://discord.com/developers/applications](https://discord.com/developers/applications)
|
||||
* Create an application → Bot → Add Bot
|
||||
* 复制 bot token
|
||||
|
||||
**2. 开启 Intents**
|
||||
|
||||
* 在 Bot 设置中,开启 **MESSAGE CONTENT INTENT**
|
||||
* (可选) 如果计划基于成员数据使用白名单,开启 **SERVER MEMBERS INTENT**
|
||||
|
||||
**3. 获取您的 User ID**
|
||||
|
||||
* Discord 设置 → Advanced → 开启 **Developer Mode**
|
||||
* 右键点击您的头像 → **Copy User ID**
|
||||
|
||||
**4. 配置**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
**5. 邀请机器人**
|
||||
|
||||
* OAuth2 → URL Generator
|
||||
* Scopes: `bot`
|
||||
* Bot Permissions: `Send Messages`, `Read Message History`
|
||||
* 打开生成的邀请 URL,将机器人添加到您的服务器
|
||||
|
||||
**6. 运行**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>QQ</b></summary>
|
||||
|
||||
**1. 创建机器人**
|
||||
|
||||
* 前往 [QQ 开放平台](https://q.qq.com/#)
|
||||
* 创建应用 → 获取 **AppID** 和 **AppSecret**
|
||||
|
||||
**2. 配置**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"qq": {
|
||||
"enabled": true,
|
||||
"app_id": "YOUR_APP_ID",
|
||||
"app_secret": "YOUR_APP_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
> 将 `allow_from` 设为空以允许所有用户,或指定 QQ 号以限制访问。
|
||||
|
||||
**3. 运行**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>钉钉 (DingTalk)</b></summary>
|
||||
|
||||
**1. 创建机器人**
|
||||
|
||||
* 前往 [开放平台](https://open.dingtalk.com/)
|
||||
* 创建内部应用
|
||||
* 复制 Client ID 和 Client Secret
|
||||
|
||||
**2. 配置**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"dingtalk": {
|
||||
"enabled": true,
|
||||
"client_id": "YOUR_CLIENT_ID",
|
||||
"client_secret": "YOUR_CLIENT_SECRET",
|
||||
"allow_from": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
> 将 `allow_from` 设为空以允许所有用户,或指定 ID 以限制访问。
|
||||
|
||||
**3. 运行**
|
||||
|
||||
```bash
|
||||
picoclaw gateway
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> 加入 Agent 社交网络
|
||||
|
||||
只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。
|
||||
|
||||
**阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai**](https://clawdchat.ai)
|
||||
|
||||
## ⚙️ 配置详解
|
||||
|
||||
配置文件路径: `~/.picoclaw/config.json`
|
||||
|
||||
### 工作区布局 (Workspace Layout)
|
||||
|
||||
PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`):
|
||||
|
||||
```
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # 对话会话和历史
|
||||
├── memory/ # 长期记忆 (MEMORY.md)
|
||||
├── state/ # 持久化状态 (最后一次频道等)
|
||||
├── cron/ # 定时任务数据库
|
||||
├── skills/ # 自定义技能
|
||||
├── AGENTS.md # Agent 行为指南
|
||||
├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次)
|
||||
├── IDENTITY.md # Agent 身份设定
|
||||
├── SOUL.md # Agent 灵魂/性格
|
||||
├── TOOLS.md # 工具描述
|
||||
└── USER.md # 用户偏好
|
||||
|
||||
```
|
||||
|
||||
### 心跳 / 周期性任务 (Heartbeat)
|
||||
|
||||
PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件:
|
||||
|
||||
```markdown
|
||||
# Periodic Tasks
|
||||
|
||||
- Check my email for important messages
|
||||
- Review my calendar for upcoming events
|
||||
- Check the weather forecast
|
||||
|
||||
```
|
||||
|
||||
Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。
|
||||
|
||||
#### 使用 Spawn 的异步任务
|
||||
|
||||
对于耗时较长的任务(网络搜索、API 调用),使用 `spawn` 工具创建一个 **子 Agent (subagent)**:
|
||||
|
||||
```markdown
|
||||
# Periodic Tasks
|
||||
|
||||
## Quick Tasks (respond directly)
|
||||
- Report current time
|
||||
|
||||
## Long Tasks (use spawn for async)
|
||||
- Search the web for AI news and summarize
|
||||
- Check email and report important messages
|
||||
|
||||
```
|
||||
|
||||
**关键行为:**
|
||||
|
||||
| 特性 | 描述 |
|
||||
| --- | --- |
|
||||
| **spawn** | 创建异步子 Agent,不阻塞主心跳进程 |
|
||||
| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 |
|
||||
| **message tool** | 子 Agent 通过 message 工具直接与用户通信 |
|
||||
| **非阻塞** | spawn 后,心跳继续处理下一个任务 |
|
||||
|
||||
#### 子 Agent 通信原理
|
||||
|
||||
```
|
||||
心跳触发 (Heartbeat triggers)
|
||||
↓
|
||||
Agent 读取 HEARTBEAT.md
|
||||
↓
|
||||
对于长任务: spawn 子 Agent
|
||||
↓ ↓
|
||||
继续下一个任务 子 Agent 独立工作
|
||||
↓ ↓
|
||||
所有任务完成 子 Agent 使用 "message" 工具
|
||||
↓ ↓
|
||||
响应 HEARTBEAT_OK 用户直接收到结果
|
||||
|
||||
```
|
||||
|
||||
子 Agent 可以访问工具(message, web_search 等),并且无需通过主 Agent 即可独立与用户通信。
|
||||
|
||||
**配置:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
| 选项 | 默认值 | 描述 |
|
||||
| --- | --- | --- |
|
||||
| `enabled` | `true` | 启用/禁用心跳 |
|
||||
| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) |
|
||||
|
||||
**环境变量:**
|
||||
|
||||
* `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用
|
||||
* `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔
|
||||
|
||||
### 提供商 (Providers)
|
||||
|
||||
> [!NOTE]
|
||||
> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。
|
||||
|
||||
| 提供商 | 用途 | 获取 API Key |
|
||||
| --- | --- | --- |
|
||||
| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) |
|
||||
| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) |
|
||||
| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) |
|
||||
| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
|
||||
<details>
|
||||
<summary><b>智谱 (Zhipu) 配置示例</b></summary>
|
||||
|
||||
**1. 获取 API key 和 base URL**
|
||||
|
||||
* 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||
|
||||
**2. 配置**
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "glm-4.7",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"zhipu": {
|
||||
"api_key": "Your API Key",
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4"
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
**3. 运行**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "你好"
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>完整配置示例</b></summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "anthropic/claude-opus-4-5"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"openrouter": {
|
||||
"api_key": "sk-or-v1-xxx"
|
||||
},
|
||||
"groq": {
|
||||
"api_key": "gsk_xxx"
|
||||
}
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "123456:ABC...",
|
||||
"allow_from": ["123456789"]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "",
|
||||
"allow_from": [""]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": false
|
||||
},
|
||||
"feishu": {
|
||||
"enabled": false,
|
||||
"app_id": "cli_xxx",
|
||||
"app_secret": "xxx",
|
||||
"encrypt_key": "",
|
||||
"verification_token": "",
|
||||
"allow_from": []
|
||||
},
|
||||
"qq": {
|
||||
"enabled": false,
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"allow_from": []
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"api_key": "BSA..."
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## CLI 命令行参考
|
||||
|
||||
| 命令 | 描述 |
|
||||
| --- | --- |
|
||||
| `picoclaw onboard` | 初始化配置和工作区 |
|
||||
| `picoclaw agent -m "..."` | 与 Agent 对话 |
|
||||
| `picoclaw agent` | 交互式聊天模式 |
|
||||
| `picoclaw gateway` | 启动网关 (Gateway) |
|
||||
| `picoclaw status` | 显示状态 |
|
||||
| `picoclaw cron list` | 列出所有定时任务 |
|
||||
| `picoclaw cron add ...` | 添加定时任务 |
|
||||
|
||||
### 定时任务 / 提醒 (Scheduled Tasks)
|
||||
|
||||
PicoClaw 通过 `cron` 工具支持定时提醒和重复任务:
|
||||
|
||||
* **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次
|
||||
* **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发
|
||||
* **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式
|
||||
|
||||
任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。
|
||||
|
||||
## 🤝 贡献与路线图 (Roadmap)
|
||||
|
||||
欢迎提交 PR!代码库刻意保持小巧和可读。🤗
|
||||
|
||||
路线图即将发布...
|
||||
|
||||
开发者群组正在组建中,入群门槛:至少合并过 1 个 PR。
|
||||
|
||||
用户群组:
|
||||
|
||||
Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
|
||||
|
||||
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
||||
|
||||
## 🐛 疑难解答 (Troubleshooting)
|
||||
|
||||
### 网络搜索提示 "API 配置问题"
|
||||
|
||||
如果您尚未配置搜索 API Key,这是正常的。PicoClaw 会提供手动搜索的帮助链接。
|
||||
|
||||
启用网络搜索:
|
||||
|
||||
1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询)
|
||||
2. 添加到 `~/.picoclaw/config.json`:
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 遇到内容过滤错误 (Content Filtering Errors)
|
||||
|
||||
某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。
|
||||
|
||||
### Telegram bot 提示 "Conflict: terminated by other getUpdates"
|
||||
|
||||
这表示有另一个机器人实例正在运行。请确保同一时间只有一个 `picoclaw gateway` 进程在运行。
|
||||
|
||||
---
|
||||
|
||||
## 📝 API Key 对比
|
||||
|
||||
| 服务 | 免费层级 | 适用场景 |
|
||||
| --- | --- | --- |
|
||||
| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) |
|
||||
| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 |
|
||||
| **Brave Search** | 2000 次查询/月 | 网络搜索功能 |
|
||||
| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) |
|
||||
+116
@@ -0,0 +1,116 @@
|
||||
|
||||
# 🦐 PicoClaw Roadmap
|
||||
|
||||
> **Vision**: To build the ultimate lightweight, secure, and fully autonomous AI Agent infrastructure.automate the mundane, unleash your creativity
|
||||
|
||||
---
|
||||
|
||||
## 🚀 1. Core Optimization: Extreme Lightweight
|
||||
|
||||
*Our defining characteristic. We fight software bloat to ensure PicoClaw runs smoothly on the smallest embedded devices.*
|
||||
|
||||
* [**Memory Footprint Reduction**](https://github.com/sipeed/picoclaw/issues/346)
|
||||
* **Goal**: Run smoothly on 64MB RAM embedded boards (e.g., low-end RISC-V SBCs) with the core process consuming < 20MB.
|
||||
* **Context**: RAM is expensive and scarce on edge devices. Memory optimization takes precedence over storage size.
|
||||
* **Action**: Analyze memory growth between releases, remove redundant dependencies, and optimize data structures.
|
||||
|
||||
|
||||
## 🛡️ 2. Security Hardening: Defense in Depth
|
||||
|
||||
*Paying off early technical debt. We invite security experts to help build a "Secure-by-Default" agent.*
|
||||
|
||||
* **Input Defense & Permission Control**
|
||||
* **Prompt Injection Defense**: Harden JSON extraction logic to prevent LLM manipulation.
|
||||
* **Tool Abuse Prevention**: Strict parameter validation to ensure generated commands stay within safe boundaries.
|
||||
* **SSRF Protection**: Built-in blocklists for network tools to prevent accessing internal IPs (LAN/Metadata services).
|
||||
|
||||
|
||||
* **Sandboxing & Isolation**
|
||||
* **Filesystem Sandbox**: Restrict file R/W operations to specific directories only.
|
||||
* **Context Isolation**: Prevent data leakage between different user sessions or channels.
|
||||
* **Privacy Redaction**: Auto-redact sensitive info (API Keys, PII) from logs and standard outputs.
|
||||
|
||||
|
||||
* **Authentication & Secrets**
|
||||
* **Crypto Upgrade**: Adopt modern algorithms like `ChaCha20-Poly1305` for secret storage.
|
||||
* **OAuth 2.0 Flow**: Deprecate hardcoded API keys in the CLI; move to secure OAuth flows.
|
||||
|
||||
|
||||
|
||||
## 🔌 3. Connectivity: Protocol-First Architecture
|
||||
|
||||
*Connect every model, reach every platform.*
|
||||
|
||||
* **Provider**
|
||||
* [**Architecture Upgrade**](https://github.com/sipeed/picoclaw/issues/283): Refactor from "Vendor-based" to "Protocol-based" classification (e.g., OpenAI-compatible, Ollama-compatible). *(Status: In progress by @Daming, ETA 5 days)*
|
||||
* **Local Models**: Deep integration with **Ollama**, **vLLM**, **LM Studio**, and **Mistral** (local inference).
|
||||
* **Online Models**: Continued support for frontier closed-source models.
|
||||
|
||||
|
||||
* **Channel**
|
||||
* **IM Matrix**: QQ, WeChat (Work), DingTalk, Feishu (Lark), Telegram, Discord, WhatsApp, LINE, Slack, Email, KOOK, Signal, ...
|
||||
* **Standards**: Support for the **OneBot** protocol.
|
||||
* [**attachment**](https://github.com/sipeed/picoclaw/issues/348): Native handling of images, audio, and video attachments.
|
||||
|
||||
|
||||
* **Skill Marketplace**
|
||||
* [**Discovery skills**](https://github.com/sipeed/picoclaw/issues/287): Implement `find_skill` to automatically discover and install skills from the [GitHub Skills Repo] or other registries.
|
||||
|
||||
|
||||
|
||||
## 🧠 4. Advanced Capabilities: From Chatbot to Agentic AI
|
||||
|
||||
*Beyond conversation—focusing on action and collaboration.*
|
||||
|
||||
* **Operations**
|
||||
* [**MCP Support**](https://github.com/sipeed/picoclaw/issues/290): Native support for the **Model Context Protocol (MCP)**.
|
||||
* [**Browser Automation**](https://github.com/sipeed/picoclaw/issues/293): Headless browser control via CDP (Chrome DevTools Protocol) or ActionBook.
|
||||
* [**Mobile Operation**](https://github.com/sipeed/picoclaw/issues/292): Android device control (similar to BotDrop).
|
||||
|
||||
|
||||
* **Multi-Agent Collaboration**
|
||||
* [**Basic Multi-Agent**](https://github.com/sipeed/picoclaw/issues/294) implement
|
||||
* [**Model Routing**](https://github.com/sipeed/picoclaw/issues/295): "Smart Routing" — dispatch simple tasks to small/local models (fast/cheap) and complex tasks to SOTA models (smart).
|
||||
* [**Swarm Mode**](https://github.com/sipeed/picoclaw/issues/284): Collaboration between multiple PicoClaw instances on the same network.
|
||||
* [**AIEOS**](https://github.com/sipeed/picoclaw/issues/296): Exploring AI-Native Operating System interaction paradigms.
|
||||
|
||||
|
||||
|
||||
## 📚 5. Developer Experience (DevEx) & Documentation
|
||||
|
||||
*Lowering the barrier to entry so anyone can deploy in minutes.*
|
||||
|
||||
* [**QuickGuide (Zero-Config Start)**](https://github.com/sipeed/picoclaw/issues/350)
|
||||
* Interactive CLI Wizard: If launched without config, automatically detect the environment and guide the user through Token/Network setup step-by-step.
|
||||
|
||||
|
||||
* **Comprehensive Documentation**
|
||||
* **Platform Guides**: Dedicated guides for Windows, macOS, Linux, and Android.
|
||||
* **Step-by-Step Tutorials**: "Babysitter-level" guides for configuring Providers and Channels.
|
||||
* **AI-Assisted Docs**: Using AI to auto-generate API references and code comments (with human verification to prevent hallucinations).
|
||||
|
||||
|
||||
|
||||
## 🤖 6. Engineering: AI-Powered Open Source
|
||||
|
||||
*Born from Vibe Coding, we continue to use AI to accelerate development.*
|
||||
|
||||
* **AI-Enhanced CI/CD**
|
||||
* Integrate AI for automated Code Review, Linting, and PR Labeling.
|
||||
* **Bot Noise Reduction**: Optimize bot interactions to keep PR timelines clean.
|
||||
* **Issue Triage**: AI agents to analyze incoming issues and suggest preliminary fixes.
|
||||
|
||||
|
||||
|
||||
## 🎨 7. Brand & Community
|
||||
|
||||
* [**Logo Design**](https://github.com/sipeed/picoclaw/issues/297): We are looking for a **Mantis Shrimp (Stomatopoda)** logo design!
|
||||
* *Concept*: Needs to reflect "Small but Mighty" and "Lightning Fast Strikes."
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
### 🤝 Call for Contributions
|
||||
|
||||
We welcome community contributions to any item on this roadmap! Please comment on the relevant Issue or submit a PR. Let's build the best Edge AI Agent together!
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 97 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 143 KiB After Width: | Height: | Size: 142 KiB |
+140
-226
@@ -9,8 +9,11 @@ package main
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -25,32 +28,58 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/devices"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/migrate"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
//go:generate cp -r ../../workspace .
|
||||
//go:embed workspace
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
gitCommit string
|
||||
buildTime string
|
||||
goVersion string
|
||||
)
|
||||
|
||||
const logo = "🦞"
|
||||
|
||||
func printVersion() {
|
||||
fmt.Printf("%s picoclaw %s\n", logo, version)
|
||||
if buildTime != "" {
|
||||
fmt.Printf(" Build: %s\n", buildTime)
|
||||
// formatVersion returns the version string with optional git commit
|
||||
func formatVersion() string {
|
||||
v := version
|
||||
if gitCommit != "" {
|
||||
v += fmt.Sprintf(" (git: %s)", gitCommit)
|
||||
}
|
||||
goVer := goVersion
|
||||
return v
|
||||
}
|
||||
|
||||
// formatBuildInfo returns build time and go version info
|
||||
func formatBuildInfo() (build string, goVer string) {
|
||||
if buildTime != "" {
|
||||
build = buildTime
|
||||
}
|
||||
goVer = goVersion
|
||||
if goVer == "" {
|
||||
goVer = runtime.Version()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func printVersion() {
|
||||
fmt.Printf("%s picoclaw %s\n", logo, formatVersion())
|
||||
build, goVer := formatBuildInfo()
|
||||
if build != "" {
|
||||
fmt.Printf(" Build: %s\n", build)
|
||||
}
|
||||
if goVer != "" {
|
||||
fmt.Printf(" Go: %s\n", goVer)
|
||||
}
|
||||
@@ -208,10 +237,6 @@ func onboard() {
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
os.MkdirAll(workspace, 0755)
|
||||
os.MkdirAll(filepath.Join(workspace, "memory"), 0755)
|
||||
os.MkdirAll(filepath.Join(workspace, "skills"), 0755)
|
||||
|
||||
createWorkspaceTemplates(workspace)
|
||||
|
||||
fmt.Printf("%s picoclaw is ready!\n", logo)
|
||||
@@ -221,170 +246,57 @@ func onboard() {
|
||||
fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"")
|
||||
}
|
||||
|
||||
func copyEmbeddedToTarget(targetDir string) error {
|
||||
// Ensure target directory exists
|
||||
if err := os.MkdirAll(targetDir, 0755); err != nil {
|
||||
return fmt.Errorf("Failed to create target directory: %w", err)
|
||||
}
|
||||
|
||||
// Walk through all files in embed.FS
|
||||
err := fs.WalkDir(embeddedFiles, "workspace", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read embedded file
|
||||
data, err := embeddedFiles.ReadFile(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to read embedded file %s: %w", path, err)
|
||||
}
|
||||
|
||||
new_path, err := filepath.Rel("workspace", path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to get relative path for %s: %v\n", path, err)
|
||||
}
|
||||
|
||||
// Build target file path
|
||||
targetPath := filepath.Join(targetDir, new_path)
|
||||
|
||||
// Ensure target file's directory exists
|
||||
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
|
||||
return fmt.Errorf("Failed to create directory %s: %w", filepath.Dir(targetPath), err)
|
||||
}
|
||||
|
||||
// Write file
|
||||
if err := os.WriteFile(targetPath, data, 0644); err != nil {
|
||||
return fmt.Errorf("Failed to write file %s: %w", targetPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func createWorkspaceTemplates(workspace string) {
|
||||
templates := map[string]string{
|
||||
"AGENTS.md": `# Agent Instructions
|
||||
|
||||
You are a helpful AI assistant. Be concise, accurate, and friendly.
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Always explain what you're doing before taking actions
|
||||
- Ask for clarification when request is ambiguous
|
||||
- Use tools to help accomplish tasks
|
||||
- Remember important information in your memory files
|
||||
- Be proactive and helpful
|
||||
- Learn from user feedback
|
||||
`,
|
||||
"SOUL.md": `# Soul
|
||||
|
||||
I am picoclaw, a lightweight AI assistant powered by AI.
|
||||
|
||||
## Personality
|
||||
|
||||
- Helpful and friendly
|
||||
- Concise and to the point
|
||||
- Curious and eager to learn
|
||||
- Honest and transparent
|
||||
|
||||
## Values
|
||||
|
||||
- Accuracy over speed
|
||||
- User privacy and safety
|
||||
- Transparency in actions
|
||||
- Continuous improvement
|
||||
`,
|
||||
"USER.md": `# User
|
||||
|
||||
Information about user goes here.
|
||||
|
||||
## Preferences
|
||||
|
||||
- Communication style: (casual/formal)
|
||||
- Timezone: (your timezone)
|
||||
- Language: (your preferred language)
|
||||
|
||||
## Personal Information
|
||||
|
||||
- Name: (optional)
|
||||
- Location: (optional)
|
||||
- Occupation: (optional)
|
||||
|
||||
## Learning Goals
|
||||
|
||||
- What the user wants to learn from AI
|
||||
- Preferred interaction style
|
||||
- Areas of interest
|
||||
`,
|
||||
"IDENTITY.md": `# Identity
|
||||
|
||||
## Name
|
||||
PicoClaw 🦞
|
||||
|
||||
## Description
|
||||
Ultra-lightweight personal AI assistant written in Go, inspired by nanobot.
|
||||
|
||||
## Version
|
||||
0.1.0
|
||||
|
||||
## Purpose
|
||||
- Provide intelligent AI assistance with minimal resource usage
|
||||
- Support multiple LLM providers (OpenAI, Anthropic, Zhipu, etc.)
|
||||
- Enable easy customization through skills system
|
||||
- Run on minimal hardware ($10 boards, <10MB RAM)
|
||||
|
||||
## Capabilities
|
||||
|
||||
- Web search and content fetching
|
||||
- File system operations (read, write, edit)
|
||||
- Shell command execution
|
||||
- Multi-channel messaging (Telegram, WhatsApp, Feishu)
|
||||
- Skill-based extensibility
|
||||
- Memory and context management
|
||||
|
||||
## Philosophy
|
||||
|
||||
- Simplicity over complexity
|
||||
- Performance over features
|
||||
- User control and privacy
|
||||
- Transparent operation
|
||||
- Community-driven development
|
||||
|
||||
## Goals
|
||||
|
||||
- Provide a fast, lightweight AI assistant
|
||||
- Support offline-first operation where possible
|
||||
- Enable easy customization and extension
|
||||
- Maintain high quality responses
|
||||
- Run efficiently on constrained hardware
|
||||
|
||||
## License
|
||||
MIT License - Free and open source
|
||||
|
||||
## Repository
|
||||
https://github.com/sipeed/picoclaw
|
||||
|
||||
## Contact
|
||||
Issues: https://github.com/sipeed/picoclaw/issues
|
||||
Discussions: https://github.com/sipeed/picoclaw/discussions
|
||||
|
||||
---
|
||||
|
||||
"Every bit helps, every bit matters."
|
||||
- Picoclaw
|
||||
`,
|
||||
}
|
||||
|
||||
for filename, content := range templates {
|
||||
filePath := filepath.Join(workspace, filename)
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
os.WriteFile(filePath, []byte(content), 0644)
|
||||
fmt.Printf(" Created %s\n", filename)
|
||||
}
|
||||
}
|
||||
|
||||
memoryDir := filepath.Join(workspace, "memory")
|
||||
os.MkdirAll(memoryDir, 0755)
|
||||
memoryFile := filepath.Join(memoryDir, "MEMORY.md")
|
||||
if _, err := os.Stat(memoryFile); os.IsNotExist(err) {
|
||||
memoryContent := `# Long-term Memory
|
||||
|
||||
This file stores important information that should persist across sessions.
|
||||
|
||||
## User Information
|
||||
|
||||
(Important facts about user)
|
||||
|
||||
## Preferences
|
||||
|
||||
(User preferences learned over time)
|
||||
|
||||
## Important Notes
|
||||
|
||||
(Things to remember)
|
||||
|
||||
## Configuration
|
||||
|
||||
- Model preferences
|
||||
- Channel settings
|
||||
- Skills enabled
|
||||
`
|
||||
os.WriteFile(memoryFile, []byte(memoryContent), 0644)
|
||||
fmt.Println(" Created memory/MEMORY.md")
|
||||
|
||||
skillsDir := filepath.Join(workspace, "skills")
|
||||
if _, err := os.Stat(skillsDir); os.IsNotExist(err) {
|
||||
os.MkdirAll(skillsDir, 0755)
|
||||
fmt.Println(" Created skills/")
|
||||
}
|
||||
}
|
||||
|
||||
for filename, content := range templates {
|
||||
filePath := filepath.Join(workspace, filename)
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
os.WriteFile(filePath, []byte(content), 0644)
|
||||
fmt.Printf(" Created %s\n", filename)
|
||||
}
|
||||
err := copyEmbeddedToTarget(workspace)
|
||||
if err != nil {
|
||||
fmt.Printf("Error copying workspace templates: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -650,14 +562,32 @@ func gatewayCmd() {
|
||||
})
|
||||
|
||||
// Setup cron tool and service
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath())
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout)
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
nil,
|
||||
30*60,
|
||||
true,
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
heartbeatService.SetBus(msgBus)
|
||||
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
// Use cli:direct as fallback if no valid channel
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
// Use ProcessHeartbeat - no session history, each heartbeat is independent
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
// For heartbeat, always return silent - the subagent result will be
|
||||
// sent to user via processSystemMessage when the async task completes
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
|
||||
channelManager, err := channels.NewManager(cfg, msgBus)
|
||||
if err != nil {
|
||||
@@ -665,6 +595,9 @@ func gatewayCmd() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Inject channel manager into agent loop for command handling
|
||||
agentLoop.SetChannelManager(channelManager)
|
||||
|
||||
var transcriber *voice.GroqTranscriber
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey)
|
||||
@@ -715,10 +648,30 @@ func gatewayCmd() {
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
deviceService := devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
deviceService.SetBus(msgBus)
|
||||
if err := deviceService.Start(ctx); err != nil {
|
||||
fmt.Printf("Error starting device service: %v\n", err)
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
if err := channelManager.StartAll(ctx); err != nil {
|
||||
fmt.Printf("Error starting channels: %v\n", err)
|
||||
}
|
||||
|
||||
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
go func() {
|
||||
if err := healthServer.Start(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
}()
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
@@ -727,6 +680,8 @@ func gatewayCmd() {
|
||||
|
||||
fmt.Println("\nShutting down...")
|
||||
cancel()
|
||||
healthServer.Stop(context.Background())
|
||||
deviceService.Stop()
|
||||
heartbeatService.Stop()
|
||||
cronService.Stop()
|
||||
agentLoop.Stop()
|
||||
@@ -743,7 +698,13 @@ func statusCmd() {
|
||||
|
||||
configPath := getConfigPath()
|
||||
|
||||
fmt.Printf("%s picoclaw Status\n\n", logo)
|
||||
fmt.Printf("%s picoclaw Status\n", logo)
|
||||
fmt.Printf("Version: %s\n", formatVersion())
|
||||
build, _ := formatBuildInfo()
|
||||
if build != "" {
|
||||
fmt.Printf("Build: %s\n", build)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
fmt.Println("Config:", configPath, "✓")
|
||||
@@ -1027,14 +988,14 @@ func getConfigPath() string {
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
}
|
||||
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService {
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace)
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout)
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
// Set the onJob handler
|
||||
@@ -1264,53 +1225,6 @@ func cronEnableCmd(storePath string, disable bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func skillsCmd() {
|
||||
if len(os.Args) < 3 {
|
||||
skillsHelp()
|
||||
return
|
||||
}
|
||||
|
||||
subcommand := os.Args[2]
|
||||
|
||||
cfg, err := loadConfig()
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
installer := skills.NewSkillInstaller(workspace)
|
||||
// 获取全局配置目录和内置 skills 目录
|
||||
globalDir := filepath.Dir(getConfigPath())
|
||||
globalSkillsDir := filepath.Join(globalDir, "skills")
|
||||
builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills")
|
||||
skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir)
|
||||
|
||||
switch subcommand {
|
||||
case "list":
|
||||
skillsListCmd(skillsLoader)
|
||||
case "install":
|
||||
skillsInstallCmd(installer)
|
||||
case "remove", "uninstall":
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw skills remove <skill-name>")
|
||||
return
|
||||
}
|
||||
skillsRemoveCmd(installer, os.Args[3])
|
||||
case "search":
|
||||
skillsSearchCmd(installer)
|
||||
case "show":
|
||||
if len(os.Args) < 4 {
|
||||
fmt.Println("Usage: picoclaw skills show <skill-name>")
|
||||
return
|
||||
}
|
||||
skillsShowCmd(skillsLoader, os.Args[3])
|
||||
default:
|
||||
fmt.Printf("Unknown skills command: %s\n", subcommand)
|
||||
skillsHelp()
|
||||
}
|
||||
}
|
||||
|
||||
func skillsHelp() {
|
||||
fmt.Println("\nSkills commands:")
|
||||
fmt.Println(" list List installed skills")
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
"enabled": false,
|
||||
"token": "YOUR_TELEGRAM_BOT_TOKEN",
|
||||
"proxy": "",
|
||||
"allow_from": ["YOUR_USER_ID"]
|
||||
"allow_from": [
|
||||
"YOUR_USER_ID"
|
||||
]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": false,
|
||||
@@ -51,6 +53,23 @@
|
||||
"bot_token": "xoxb-YOUR-BOT-TOKEN",
|
||||
"app_token": "xapp-YOUR-APP-TOKEN",
|
||||
"allow_from": []
|
||||
},
|
||||
"line": {
|
||||
"enabled": false,
|
||||
"channel_secret": "YOUR_LINE_CHANNEL_SECRET",
|
||||
"channel_access_token": "YOUR_LINE_CHANNEL_ACCESS_TOKEN",
|
||||
"webhook_host": "0.0.0.0",
|
||||
"webhook_port": 18791,
|
||||
"webhook_path": "/webhook/line",
|
||||
"allow_from": []
|
||||
},
|
||||
"onebot": {
|
||||
"enabled": false,
|
||||
"ws_url": "ws://127.0.0.1:3001",
|
||||
"access_token": "",
|
||||
"reconnect_interval": 5,
|
||||
"group_trigger_prefix": [],
|
||||
"allow_from": []
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
@@ -90,16 +109,59 @@
|
||||
"moonshot": {
|
||||
"api_key": "sk-xxx",
|
||||
"api_base": ""
|
||||
},
|
||||
"ollama": {
|
||||
"api_key": "",
|
||||
"api_base": "http://localhost:11434/v1"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"brave": {
|
||||
"enabled": false,
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
},
|
||||
"perplexity": {
|
||||
"enabled": false,
|
||||
"api_key": "pplx-xxx",
|
||||
"max_results": 5
|
||||
}
|
||||
},
|
||||
"cron": {
|
||||
"exec_timeout_minutes": 5
|
||||
},
|
||||
"mcp": {
|
||||
"enabled": false,
|
||||
"servers": {
|
||||
"filesystem": {
|
||||
"enabled": false,
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
"/tmp"
|
||||
],
|
||||
"protocol": "mcp",
|
||||
"env": {},
|
||||
"working_dir": "",
|
||||
"init_timeout_seconds": 60,
|
||||
"call_timeout_seconds": 30,
|
||||
"max_response_bytes": 65536,
|
||||
"include_tools": [],
|
||||
"exclude_tools": []
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
},
|
||||
"devices": {
|
||||
"enabled": false,
|
||||
"monitor_usb": true
|
||||
},
|
||||
"gateway": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"workspace": "~/.picoclaw/workspace",
|
||||
"model": "arcee-ai/trinity-large-preview:free",
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
"max_tool_iterations": 20
|
||||
}
|
||||
},
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": false,
|
||||
"token": "YOUR_TELEGRAM_BOT_TOKEN",
|
||||
"allow_from": [
|
||||
"YOUR_USER_ID"
|
||||
]
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_DISCORD_BOT_TOKEN",
|
||||
"allow_from": []
|
||||
},
|
||||
"maixcam": {
|
||||
"enabled": false,
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790,
|
||||
"allow_from": []
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": false,
|
||||
"bridge_url": "ws://localhost:3001",
|
||||
"allow_from": []
|
||||
},
|
||||
"feishu": {
|
||||
"enabled": false,
|
||||
"app_id": "",
|
||||
"app_secret": "",
|
||||
"encrypt_key": "",
|
||||
"verification_token": "",
|
||||
"allow_from": []
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"anthropic": {
|
||||
"api_key": "",
|
||||
"api_base": ""
|
||||
},
|
||||
"openai": {
|
||||
"api_key": "",
|
||||
"api_base": ""
|
||||
},
|
||||
"openrouter": {
|
||||
"api_key": "sk-or-v1-xxx",
|
||||
"api_base": ""
|
||||
},
|
||||
"groq": {
|
||||
"api_key": "gsk_xxx",
|
||||
"api_base": ""
|
||||
},
|
||||
"zhipu": {
|
||||
"api_key": "YOUR_ZHIPU_API_KEY",
|
||||
"api_base": ""
|
||||
},
|
||||
"gemini": {
|
||||
"api_key": "",
|
||||
"api_base": ""
|
||||
},
|
||||
"vllm": {
|
||||
"api_key": "",
|
||||
"api_base": ""
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"api_key": "YOUR_BRAVE_API_KEY",
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
},
|
||||
"gateway": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal
|
||||
|
||||
**Hello, PicoClaw Community!**
|
||||
|
||||
First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, we’ve seen a non-stop stream of high-quality PRs!
|
||||
|
||||
PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum.
|
||||
|
||||
This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**.
|
||||
|
||||
### 🎁 Community Perks
|
||||
|
||||
To show our appreciation, developers who officially join our community operations will receive:
|
||||
|
||||
* **Exclusive AI Hardware:** Our upcoming, unreleased AI device.
|
||||
* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers).
|
||||
|
||||
### 🎥 Calling All Content Creators!
|
||||
|
||||
Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**.
|
||||
|
||||
* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**.
|
||||
* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM.
|
||||
We will be rewarding high-quality content creators with the same perks as our community developers!
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Urgent Volunteer Roles
|
||||
|
||||
We are looking for experts in the following areas:
|
||||
|
||||
1. **Issue/PR Reviewers**
|
||||
* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging.
|
||||
* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles.
|
||||
|
||||
|
||||
2. **Resource Optimization Experts**
|
||||
* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean.
|
||||
* **Focus:** Analyzing resource growth between releases and trimming redundancy.
|
||||
* **Priority:** **RAM usage optimization** > Binary size reduction.
|
||||
|
||||
|
||||
3. **Security Audit & Bug Fixes**
|
||||
* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management.
|
||||
* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes.
|
||||
|
||||
|
||||
4. **Documentation & DX (Developer Experience)**
|
||||
* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow.
|
||||
* **Focus:** Creating clear, user-friendly documentation for both setup and development.
|
||||
|
||||
|
||||
5. **AI-Powered CI/CD Optimization**
|
||||
* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it.
|
||||
* **Focus:** Automating builds with AI and exploring AI-driven issue resolution.
|
||||
|
||||
**How to Apply:** > If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role.
|
||||
Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors!
|
||||
|
||||
---
|
||||
|
||||
## 📍 The Roadmap
|
||||
|
||||
Interested in a specific feature? You can "claim" these tasks and start building:
|
||||
|
||||
###
|
||||
* **Provider:**
|
||||
* **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days)
|
||||
* You can still submit code; Daming will merge it into the new implementation.
|
||||
* **Channels:**
|
||||
* Support for OneBot, additional platforms
|
||||
* attachments (images, audio, video, files).
|
||||
* **Skills:**
|
||||
* Implementing `find_skill` to discover tools via [openclaw/skills](https://github.com/openclaw/skills) and other platforms.
|
||||
* **Operations:** * MCP Support.
|
||||
* Android operations (e.g., botdrop).
|
||||
* Browser automation via CDP or ActionBook.
|
||||
|
||||
|
||||
* **Multi-Agent Ecosystem:**
|
||||
* **Basic Model-Agnet** S
|
||||
* **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens).
|
||||
* **Swarm Mode.**
|
||||
* **AIEOS Integration.**
|
||||
|
||||
|
||||
* **Branding:**
|
||||
* **Logo**: We need a cute logo! We’re leaning toward a **Mantis Shrimp**—small, but packs a legendary punch!
|
||||
|
||||
|
||||
We have officially created these tasks as GitHub Issues, all marked with the roadmap tag.
|
||||
This list will be updated continuously as we progress.
|
||||
If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue!
|
||||
|
||||
---
|
||||
|
||||
## 🤝 How to Join
|
||||
|
||||
**Everything is open to your creativity!** If you have a wild idea, just PR it.
|
||||
|
||||
1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw.
|
||||
2. **The Application Track:** If you haven’t submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject:
|
||||
> `[Apply Join PicoClaw Dev Group] + Your GitHub Account`
|
||||
> Include the role you're interested in and any evidence of your development experience.
|
||||
|
||||
|
||||
|
||||
### Looking Ahead
|
||||
|
||||
Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential.
|
||||
|
||||
**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎
|
||||
@@ -13,20 +13,29 @@ require (
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mymmrac/telego v1.6.0
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.21.0
|
||||
github.com/openai/openai-go/v3 v3.22.0
|
||||
github.com/slack-go/slack v0.17.3
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tencent-connect/botgo v0.2.1
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/github/copilot-sdk/go v0.1.23
|
||||
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grbit/go-json v0.11.0 // indirect
|
||||
github.com/klauspost/compress v1.18.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
|
||||
@@ -32,6 +32,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
|
||||
github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw=
|
||||
github.com/github/copilot-sdk/go v0.1.23/go.mod h1:GdwwBfMbm9AABLEM3x5IZKw4ZfwCYxZ1BgyytmZenQ0=
|
||||
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
|
||||
github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
|
||||
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
|
||||
@@ -56,6 +58,10 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -74,9 +80,11 @@ github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzh
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
||||
@@ -92,12 +100,13 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
|
||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
|
||||
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||
github.com/openai/openai-go/v3 v3.22.0 h1:6MEoNoV8sbjOVmXdvhmuX3BjVbVdcExbVyGixiyJ8ys=
|
||||
github.com/openai/openai-go/v3 v3.22.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
|
||||
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
|
||||
@@ -238,6 +247,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
|
||||
|
||||
@@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
// Log system prompt summary for debugging (debug mode only)
|
||||
logger.DebugCF("agent", "System prompt built",
|
||||
map[string]interface{}{
|
||||
"total_chars": len(systemPrompt),
|
||||
"total_lines": strings.Count(systemPrompt, "\n") + 1,
|
||||
"total_chars": len(systemPrompt),
|
||||
"total_lines": strings.Count(systemPrompt, "\n") + 1,
|
||||
"section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
|
||||
})
|
||||
|
||||
@@ -193,9 +193,9 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
// --- INICIO DEL FIX ---
|
||||
//Diegox-17
|
||||
for len(history) > 0 && (history[0].Role == "tool") {
|
||||
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
|
||||
map[string]interface{}{"role": history[0].Role})
|
||||
history = history[1:]
|
||||
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
|
||||
map[string]interface{}{"role": history[0].Role})
|
||||
history = history[1:]
|
||||
}
|
||||
//Diegox-17
|
||||
// --- FIN DEL FIX ---
|
||||
|
||||
+569
-83
@@ -16,12 +16,17 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
@@ -31,15 +36,21 @@ type AgentLoop struct {
|
||||
provider providers.LLMProvider
|
||||
workspace string
|
||||
model string
|
||||
contextWindow int // Maximum context window size in tokens
|
||||
contextWindow int // Maximum context window size in tokens
|
||||
maxIterations int
|
||||
sessions *session.SessionManager
|
||||
state *state.Manager
|
||||
contextBuilder *ContextBuilder
|
||||
tools *tools.ToolRegistry
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
channelManager *channels.Manager
|
||||
mcpManager *mcp.Manager
|
||||
mcpCloseOnce sync.Once
|
||||
}
|
||||
|
||||
const defaultWebFetchMaxChars = 50000
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
type processOptions struct {
|
||||
SessionKey string // Session identifier for history/context
|
||||
@@ -49,25 +60,53 @@ type processOptions struct {
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
workspace := cfg.WorkspacePath()
|
||||
os.MkdirAll(workspace, 0755)
|
||||
// createToolRegistry creates a tool registry with common tools.
|
||||
// This is shared between main agent and subagents.
|
||||
func createToolRegistry(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
cfg *config.Config,
|
||||
msgBus *bus.MessageBus,
|
||||
mcpManager *mcp.Manager,
|
||||
discoveredMCPTools []mcp.RegisteredTool,
|
||||
) *tools.ToolRegistry {
|
||||
registry := tools.NewToolRegistry()
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
// File system tools
|
||||
registry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
registry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewExecTool(workspace, restrict))
|
||||
// Shell execution
|
||||
registry.Register(tools.NewExecTool(workspace, restrict))
|
||||
|
||||
braveAPIKey := cfg.Tools.Web.Search.APIKey
|
||||
toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
|
||||
toolsRegistry.Register(tools.NewWebFetchTool(50000))
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
|
||||
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
|
||||
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
|
||||
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
|
||||
PerplexityAPIKey: cfg.Tools.Web.Perplexity.APIKey,
|
||||
PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults,
|
||||
PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled,
|
||||
}); searchTool != nil {
|
||||
registry.Register(searchTool)
|
||||
}
|
||||
registry.Register(tools.NewWebFetchTool(defaultWebFetchMaxChars))
|
||||
|
||||
// Register message tool
|
||||
tools.RegisterKnownMCPTools(registry, mcpManager, discoveredMCPTools)
|
||||
|
||||
// Hardware tools (I2C, SPI) - Linux only, returns error on other platforms
|
||||
registry.Register(tools.NewI2CTool())
|
||||
registry.Register(tools.NewSPITool())
|
||||
|
||||
// Message tool - available to both agent and subagent
|
||||
// Subagent uses it to communicate directly with user
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
@@ -77,20 +116,62 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
})
|
||||
return nil
|
||||
})
|
||||
toolsRegistry.Register(messageTool)
|
||||
registry.Register(messageTool)
|
||||
|
||||
// Register spawn tool
|
||||
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus)
|
||||
return registry
|
||||
}
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
workspace := cfg.WorkspacePath()
|
||||
os.MkdirAll(workspace, 0755)
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
|
||||
var (
|
||||
mcpManager *mcp.Manager
|
||||
discoveredMCPTools []mcp.RegisteredTool
|
||||
)
|
||||
if cfg.Tools.MCP.Enabled {
|
||||
bootstrap, err := bootstrapMCP(cfg.Tools.MCP)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "MCP tool bootstrap failed",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else if bootstrap != nil {
|
||||
mcpManager = bootstrap.Manager
|
||||
discoveredMCPTools = bootstrap.Tools
|
||||
if len(discoveredMCPTools) > 0 {
|
||||
logger.InfoCF("agent", "MCP tools registered",
|
||||
map[string]interface{}{
|
||||
"count": len(discoveredMCPTools),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create tool registry for main agent
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||
|
||||
// Create subagent manager with its own tool registry
|
||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus, mcpManager, discoveredMCPTools)
|
||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||
subagentManager.SetTools(subagentTools)
|
||||
|
||||
// Register spawn tool (for main agent)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
toolsRegistry.Register(spawnTool)
|
||||
|
||||
// Register edit file tool
|
||||
editFileTool := tools.NewEditFileTool(workspace, restrict)
|
||||
toolsRegistry.Register(editFileTool)
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
// Register subagent tool (synchronous execution)
|
||||
subagentTool := tools.NewSubagentTool(subagentManager)
|
||||
toolsRegistry.Register(subagentTool)
|
||||
|
||||
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
|
||||
|
||||
// Create state manager for atomic state persistence
|
||||
stateManager := state.NewManager(workspace)
|
||||
|
||||
// Create context builder and set tools registry
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
contextBuilder.SetToolsRegistry(toolsRegistry)
|
||||
@@ -103,14 +184,17 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
|
||||
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
|
||||
sessions: sessionsManager,
|
||||
state: stateManager,
|
||||
contextBuilder: contextBuilder,
|
||||
tools: toolsRegistry,
|
||||
summarizing: sync.Map{},
|
||||
mcpManager: mcpManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
defer al.closeMCP()
|
||||
|
||||
for al.running.Load() {
|
||||
select {
|
||||
@@ -128,11 +212,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if response != "" {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
alreadySent := false
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -142,12 +237,44 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
|
||||
func (al *AgentLoop) Stop() {
|
||||
al.running.Store(false)
|
||||
al.closeMCP()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) closeMCP() {
|
||||
if al.mcpManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
al.mcpCloseOnce.Do(func() {
|
||||
if err := al.mcpManager.Close(); err != nil {
|
||||
logger.WarnCF("agent", "Failed to close MCP manager",
|
||||
map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
al.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
al.channelManager = cm
|
||||
}
|
||||
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
return al.state.SetLastChannel(channel)
|
||||
}
|
||||
|
||||
// RecordLastChatID records the last active chat ID for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||
}
|
||||
@@ -164,10 +291,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
|
||||
return al.processMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: "heartbeat",
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
UserMessage: content,
|
||||
DefaultResponse: "I've completed processing but have no response to give.",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
NoHistory: true, // Don't load session history for heartbeat
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
// Add message preview to log
|
||||
preview := utils.Truncate(msg.Content, 80)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
|
||||
// Add message preview to log (show full content for error messages)
|
||||
var logContent string
|
||||
if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
|
||||
logContent = msg.Content // Full content for errors
|
||||
} else {
|
||||
logContent = utils.Truncate(msg.Content, 80)
|
||||
}
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
map[string]interface{}{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
@@ -180,6 +327,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.processSystemMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if response, handled := al.handleCommand(ctx, msg); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Process as user message
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: msg.SessionKey,
|
||||
@@ -204,41 +356,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
"chat_id": msg.ChatID,
|
||||
})
|
||||
|
||||
// Parse origin from chat_id (format: "channel:chat_id")
|
||||
var originChannel, originChatID string
|
||||
// Parse origin channel from chat_id (format: "channel:chat_id")
|
||||
var originChannel string
|
||||
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
|
||||
originChannel = msg.ChatID[:idx]
|
||||
originChatID = msg.ChatID[idx+1:]
|
||||
} else {
|
||||
// Fallback
|
||||
originChannel = "cli"
|
||||
originChatID = msg.ChatID
|
||||
}
|
||||
|
||||
// Use the origin session for context
|
||||
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
|
||||
// Extract subagent result from message content
|
||||
// Format: "Task 'label' completed.\n\nResult:\n<actual content>"
|
||||
content := msg.Content
|
||||
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
|
||||
content = content[idx+8:] // Extract just the result part
|
||||
}
|
||||
|
||||
// Process as system message with routing back to origin
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: originChannel,
|
||||
ChatID: originChatID,
|
||||
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
|
||||
DefaultResponse: "Background task completed.",
|
||||
EnableSummary: false,
|
||||
SendResponse: true, // Send response back to original channel
|
||||
})
|
||||
// Skip internal channels - only log, don't send to user
|
||||
if constants.IsInternalChannel(originChannel) {
|
||||
logger.InfoCF("agent", "Subagent completed (internal channel)",
|
||||
map[string]interface{}{
|
||||
"sender_id": msg.SenderID,
|
||||
"content_len": len(content),
|
||||
"channel": originChannel,
|
||||
})
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Agent acts as dispatcher only - subagent handles user interaction via message tool
|
||||
// Don't forward result here, subagent should use message tool to communicate with user
|
||||
logger.InfoCF("agent", "Subagent completed",
|
||||
map[string]interface{}{
|
||||
"sender_id": msg.SenderID,
|
||||
"channel": originChannel,
|
||||
"content_len": len(content),
|
||||
})
|
||||
|
||||
// Agent only logs, does not respond to user
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
// It handles context building, LLM calls, tool execution, and response handling.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Update tool contexts
|
||||
al.updateToolContexts(opts.Channel, opts.ChatID)
|
||||
|
||||
// 2. Build messages
|
||||
history := al.sessions.GetHistory(opts.SessionKey)
|
||||
summary := al.sessions.GetSummary(opts.SessionKey)
|
||||
// 2. Build messages (skip history for heartbeat)
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !opts.NoHistory {
|
||||
history = al.sessions.GetHistory(opts.SessionKey)
|
||||
summary = al.sessions.GetSummary(opts.SessionKey)
|
||||
}
|
||||
messages := al.contextBuilder.BuildMessages(
|
||||
history,
|
||||
summary,
|
||||
@@ -257,6 +438,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
return "", err
|
||||
}
|
||||
|
||||
// If last tool had ForUser content and we already sent it, we might not need to send final response
|
||||
// This is controlled by the tool's Silent flag and ForUser content
|
||||
|
||||
// 5. Handle empty response
|
||||
if finalContent == "" {
|
||||
finalContent = opts.DefaultResponse
|
||||
@@ -264,11 +448,11 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
|
||||
// 6. Save final assistant message to session
|
||||
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
|
||||
al.sessions.Save(opts.SessionKey)
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(opts.SessionKey)
|
||||
al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
@@ -308,18 +492,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
})
|
||||
|
||||
// Build tool definitions
|
||||
toolDefs := al.tools.GetDefinitions()
|
||||
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
||||
for _, td := range toolDefs {
|
||||
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
|
||||
Type: td["type"].(string),
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: td["function"].(map[string]interface{})["name"].(string),
|
||||
Description: td["function"].(map[string]interface{})["description"].(string),
|
||||
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
|
||||
},
|
||||
})
|
||||
}
|
||||
providerToolDefs := al.tools.ToProviderDefs()
|
||||
|
||||
// Log LLM request details
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
@@ -341,11 +514,131 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM
|
||||
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
// Retry loop for context/token errors
|
||||
maxRetries := 2
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
break // Success
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
// Check for context window errors (provider specific, but usually contain "token" or "invalid")
|
||||
isContextError := strings.Contains(errMsg, "token") ||
|
||||
strings.Contains(errMsg, "context") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
strings.Contains(errMsg, "length")
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
|
||||
// Notify user on first retry only
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "⚠️ Context window exceeded. Compressing history and retrying...",
|
||||
})
|
||||
}
|
||||
|
||||
// Force compression
|
||||
al.forceCompression(opts.SessionKey)
|
||||
|
||||
// Rebuild messages with compressed history
|
||||
// Note: We need to reload history from session manager because forceCompression changed it
|
||||
newHistory := al.sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := al.sessions.GetSummary(opts.SessionKey)
|
||||
|
||||
// Re-create messages for the next attempt
|
||||
// We keep the current user message (opts.UserMessage) effectively
|
||||
messages = al.contextBuilder.BuildMessages(
|
||||
newHistory,
|
||||
newSummary,
|
||||
opts.UserMessage,
|
||||
nil,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
// Important: If we are in the middle of a tool loop (iteration > 1),
|
||||
// rebuilding messages from session history might duplicate the flow or miss context
|
||||
// if intermediate steps weren't saved correctly.
|
||||
// However, al.sessions.AddFullMessage is called after every tool execution,
|
||||
// so GetHistory should reflect the current state including partial tool execution.
|
||||
// But we need to ensure we don't duplicate the user message which is appended in BuildMessages.
|
||||
// BuildMessages(history...) takes the stored history and appends the *current* user message.
|
||||
// If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop.
|
||||
// So if we pass opts.UserMessage again, we might duplicate it?
|
||||
// Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
// So GetHistory ALREADY contains the user message!
|
||||
|
||||
// CORRECTION:
|
||||
// BuildMessages combines: [System] + [History] + [CurrentMessage]
|
||||
// But Step 3 added CurrentMessage to History.
|
||||
// So if we use GetHistory now, it has the user message.
|
||||
// If we pass opts.UserMessage to BuildMessages, it adds it AGAIN.
|
||||
|
||||
// For retry in the middle of a loop, we should rely on what's in the session.
|
||||
// BUT checking BuildMessages implementation:
|
||||
// It appends history... then appends currentMessage.
|
||||
|
||||
// Logic fix for retry:
|
||||
// If iteration == 1, opts.UserMessage corresponds to the user input.
|
||||
// If iteration > 1, we are processing tool results. The "messages" passed to Chat
|
||||
// already accumulated tool outputs.
|
||||
// Rebuilding from session history is safest because it persists state.
|
||||
// Start fresh with rebuilt history.
|
||||
|
||||
// Special case: standard BuildMessages appends "currentMessage".
|
||||
// If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed.
|
||||
// However, the "messages" argument passed to runLLMIteration is constructed by the caller.
|
||||
// If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history.
|
||||
|
||||
// In runAgentLoop:
|
||||
// 3. sessions.AddMessage(userMsg)
|
||||
// 4. runLLMIteration(..., UserMessage)
|
||||
|
||||
// So History contains the user message.
|
||||
// BuildMessages typically appends the user message as a *new* pending message.
|
||||
// Wait, standard BuildMessages usage in runAgentLoop:
|
||||
// messages := BuildMessages(history (has old), UserMessage)
|
||||
// THEN AddMessage(UserMessage).
|
||||
// So "history" passed to BuildMessages does NOT contain the current UserMessage yet.
|
||||
|
||||
// But here, inside the loop, we have already saved it.
|
||||
// So GetHistory() includes the current user message.
|
||||
// If we call BuildMessages(GetHistory(), UserMessage), we get duplicates.
|
||||
|
||||
// Hack/Fix:
|
||||
// If we are retrying, we rebuild from Session History ONLY.
|
||||
// We pass empty string as "currentMessage" to BuildMessages
|
||||
// because the "current message" is already saved in history (step 3).
|
||||
|
||||
messages = al.contextBuilder.BuildMessages(
|
||||
newHistory,
|
||||
newSummary,
|
||||
"", // Empty because history already contains the relevant messages
|
||||
nil,
|
||||
opts.Channel,
|
||||
opts.ChatID,
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Real error or success, break loop
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
@@ -353,7 +646,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
// Check if no tool calls - we're done
|
||||
@@ -375,7 +668,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
logger.InfoCF("agent", "LLM requested tool calls",
|
||||
map[string]interface{}{
|
||||
"tools": toolNames,
|
||||
"count": len(toolNames),
|
||||
"count": len(response.ToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
@@ -411,14 +704,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Error: %v", err)
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
|
||||
// Instead, they notify the agent via PublishInbound, and the agent decides
|
||||
// whether to forward the result to the user (in processSystemMessage).
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
// Log the async completion but don't send directly to user
|
||||
// The agent will handle user notification via processSystemMessage
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]interface{}{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
|
||||
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: toolResult.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]interface{}{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
@@ -433,20 +759,26 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("subagent"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(sessionKey string) {
|
||||
func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
|
||||
newHistory := al.sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := al.contextWindow * 75 / 100
|
||||
@@ -455,12 +787,80 @@ func (al *AgentLoop) maybeSummarize(sessionKey string) {
|
||||
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(sessionKey)
|
||||
// Notify user about optimization if not an internal channel
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
|
||||
})
|
||||
}
|
||||
al.summarizeSession(sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest 50% of messages (keeping system prompt and last user message).
|
||||
func (al *AgentLoop) forceCompression(sessionKey string) {
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
// Keep system prompt (usually [0]) and the very last message (user's trigger)
|
||||
// We want to drop the oldest half of the *conversation*
|
||||
// Assuming [0] is system, [1:] is conversation
|
||||
conversation := history[1 : len(history)-1]
|
||||
if len(conversation) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt
|
||||
// 2. [Summary of dropped part] - synthesized
|
||||
// 3. Second half of conversation
|
||||
// 4. Last message
|
||||
|
||||
// Simplified approach for emergency: Drop first half of conversation
|
||||
// and rely on existing summary if present, or create a placeholder.
|
||||
|
||||
droppedCount := mid
|
||||
keptConversation := conversation[mid:]
|
||||
|
||||
newHistory := make([]providers.Message, 0)
|
||||
newHistory = append(newHistory, history[0]) // System prompt
|
||||
|
||||
// Add a note about compression
|
||||
compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
|
||||
// If there was an existing summary, we might lose it if it was in the dropped part (which is just messages).
|
||||
// The summary is stored separately in session.Summary, so it persists!
|
||||
// We just need to ensure the user knows there's a gap.
|
||||
|
||||
// We only modify the messages list here
|
||||
newHistory = append(newHistory, providers.Message{
|
||||
Role: "system",
|
||||
Content: compressionNote,
|
||||
})
|
||||
|
||||
newHistory = append(newHistory, keptConversation...)
|
||||
newHistory = append(newHistory, history[len(history)-1]) // Last message
|
||||
|
||||
// Update session
|
||||
al.sessions.SetHistory(sessionKey, newHistory)
|
||||
al.sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(newHistory),
|
||||
})
|
||||
}
|
||||
|
||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
|
||||
info := make(map[string]interface{})
|
||||
@@ -488,7 +888,7 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
result += "[\n"
|
||||
for i, msg := range messages {
|
||||
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
|
||||
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
result += " ToolCalls:\n"
|
||||
for _, tc := range msg.ToolCalls {
|
||||
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
@@ -555,7 +955,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
continue
|
||||
}
|
||||
// Estimate tokens for this message
|
||||
msgTokens := len(m.Content) / 4
|
||||
msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety)
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
@@ -600,7 +1000,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
if finalSummary != "" {
|
||||
al.sessions.SetSummary(sessionKey, finalSummary)
|
||||
al.sessions.TruncateHistory(sessionKey, 4)
|
||||
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
|
||||
al.sessions.Save(sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -626,10 +1026,96 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
|
||||
// overheads better than the previous 3 chars/token.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
totalChars := 0
|
||||
for _, m := range messages {
|
||||
total += len(m.Content) / 4 // Simple heuristic: 4 chars per token
|
||||
totalChars += utf8.RuneCountInString(m.Content)
|
||||
}
|
||||
return total
|
||||
// 2.5 chars per token = totalChars * 2 / 5
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
|
||||
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if !strings.HasPrefix(content, "/") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.Fields(content)
|
||||
if len(parts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
return fmt.Sprintf("Current model: %s", al.model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
// TODO: Fetch available models dynamically if possible
|
||||
return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
|
||||
case "channels":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
channels := al.channelManager.GetEnabledChannels()
|
||||
if len(channels) == 0 {
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/switch":
|
||||
if len(args) < 3 || args[1] != "to" {
|
||||
return "Usage: /switch [model|channel] to <name>", true
|
||||
}
|
||||
target := args[0]
|
||||
value := args[2]
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
oldModel := al.model
|
||||
al.model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
// This changes the 'default' channel for some operations, or effectively redirects output?
|
||||
// For now, let's just validate if the channel exists
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
}
|
||||
|
||||
// If message came from CLI, maybe we want to redirect CLI output to this channel?
|
||||
// That would require state persistence about "redirected channel"
|
||||
// For now, just acknowledged.
|
||||
return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,626 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// mockProvider is a simple mock LLM provider for testing
|
||||
type mockProvider struct{}
|
||||
|
||||
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChannel
|
||||
testChannel := "test-channel"
|
||||
err = al.RecordLastChannel(testChannel)
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify channel was saved
|
||||
lastChannel := al.state.GetLastChannel()
|
||||
if lastChannel != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChannel() != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChatID(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChatID
|
||||
testChatID := "test-chat-id-123"
|
||||
err = al.RecordLastChatID(testChatID)
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify chat ID was saved
|
||||
lastChatID := al.state.GetLastChatID()
|
||||
if lastChatID != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChatID() != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentLoop_StateInitialized(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Verify state manager is initialized
|
||||
if al.state == nil {
|
||||
t.Error("Expected state manager to be initialized")
|
||||
}
|
||||
|
||||
// Verify state directory was created
|
||||
stateDir := filepath.Join(tmpDir, "state")
|
||||
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
|
||||
t.Error("Expected state directory to exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
|
||||
func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Register a custom tool
|
||||
customTool := &mockCustomTool{}
|
||||
al.RegisterTool(customTool)
|
||||
|
||||
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
|
||||
// (actual tool retrieval is tested in tools package tests)
|
||||
info := al.GetStartupInfo()
|
||||
toolsInfo := info["tools"].(map[string]interface{})
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolContext_Updates verifies tool context is updated with channel/chatID
|
||||
func TestToolContext_Updates(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "OK"}
|
||||
_ = NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Verify that ContextualTool interface is defined and can be implemented
|
||||
// This test validates the interface contract exists
|
||||
ctxTool := &mockContextualTool{}
|
||||
|
||||
// Verify the tool implements the interface correctly
|
||||
var _ tools.ContextualTool = ctxTool
|
||||
}
|
||||
|
||||
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
|
||||
func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Register a test tool and verify it shows up in startup info
|
||||
testTool := &mockCustomTool{}
|
||||
al.RegisterTool(testTool)
|
||||
|
||||
info := al.GetStartupInfo()
|
||||
toolsInfo := info["tools"].(map[string]interface{})
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
|
||||
func TestAgentLoop_GetStartupInfo(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
info := al.GetStartupInfo()
|
||||
|
||||
// Verify tools info exists
|
||||
toolsInfo, ok := info["tools"]
|
||||
if !ok {
|
||||
t.Fatal("Expected 'tools' key in startup info")
|
||||
}
|
||||
|
||||
toolsMap, ok := toolsInfo.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected 'tools' to be a map")
|
||||
}
|
||||
|
||||
count, ok := toolsMap["count"]
|
||||
if !ok {
|
||||
t.Fatal("Expected 'count' in tools info")
|
||||
}
|
||||
|
||||
// Should have default tools registered
|
||||
if count.(int) == 0 {
|
||||
t.Error("Expected at least some tools to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentLoop_Stop verifies Stop() sets running to false
|
||||
func TestAgentLoop_Stop(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Note: running is only set to true when Run() is called
|
||||
// We can't test that without starting the event loop
|
||||
// Instead, verify the Stop method can be called safely
|
||||
al.Stop()
|
||||
|
||||
// Verify running is false (initial state or after Stop)
|
||||
if al.running.Load() {
|
||||
t.Error("Expected agent to be stopped (or never started)")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock implementations for testing
|
||||
|
||||
type simpleMockProvider struct {
|
||||
response string
|
||||
}
|
||||
|
||||
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: m.response,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *simpleMockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
func (m *mockCustomTool) Name() string {
|
||||
return "mock_custom"
|
||||
}
|
||||
|
||||
func (m *mockCustomTool) Description() string {
|
||||
return "Mock custom tool for testing"
|
||||
}
|
||||
|
||||
func (m *mockCustomTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
// mockContextualTool tracks context updates
|
||||
type mockContextualTool struct {
|
||||
lastChannel string
|
||||
lastChatID string
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Name() string {
|
||||
return "mock_contextual"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Description() string {
|
||||
return "Mock contextual tool"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
|
||||
return tools.SilentResult("Contextual tool executed")
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) SetContext(channel, chatID string) {
|
||||
m.lastChannel = channel
|
||||
m.lastChatID = chatID
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
response, err := h.al.processMessage(timeoutCtx, msg)
|
||||
if err != nil {
|
||||
tb.Fatalf("processMessage failed: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
const responseTimeout = 3 * time.Second
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "File operation complete"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
// ReadFileTool returns SilentResult, which should not send user message
|
||||
ctx := context.Background()
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "read test.txt",
|
||||
SessionKey: "test-session",
|
||||
}
|
||||
|
||||
response := helper.executeAndGetResponse(t, ctx, msg)
|
||||
|
||||
// Silent tool should return the LLM's response directly
|
||||
if response != "File operation complete" {
|
||||
t.Errorf("Expected 'File operation complete', got: %s", response)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
|
||||
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "Command output: hello world"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
// ExecTool returns UserResult, which should send user message
|
||||
ctx := context.Background()
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "run hello",
|
||||
SessionKey: "test-session",
|
||||
}
|
||||
|
||||
response := helper.executeAndGetResponse(t, ctx, msg)
|
||||
|
||||
// User-facing tool should include the output in final response
|
||||
if response != "Command output: hello world" {
|
||||
t.Errorf("Expected 'Command output: hello world', got: %s", response)
|
||||
}
|
||||
}
|
||||
|
||||
// failFirstMockProvider fails on the first N calls with a specific error
|
||||
type failFirstMockProvider struct {
|
||||
failures int
|
||||
currentCall int
|
||||
failError error
|
||||
successResp string
|
||||
}
|
||||
|
||||
func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
m.currentCall++
|
||||
if m.currentCall <= m.failures {
|
||||
return nil, m.failError
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
Content: m.successResp,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *failFirstMockProvider) GetDefaultModel() string {
|
||||
return "mock-fail-model"
|
||||
}
|
||||
|
||||
// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors
|
||||
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Create a provider that fails once with a context error
|
||||
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
successResp: "Recovered from context error",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Inject some history to simulate a full context
|
||||
sessionKey := "test-session-context"
|
||||
// Create dummy history
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "System prompt"},
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
}
|
||||
al.sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call ProcessDirectWithChannel
|
||||
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected success after retry, got error: %v", err)
|
||||
}
|
||||
|
||||
if response != "Recovered from context error" {
|
||||
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
|
||||
}
|
||||
|
||||
// We expect 2 calls: 1st failed, 2nd succeeded
|
||||
if provider.currentCall != 2 {
|
||||
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
|
||||
}
|
||||
|
||||
// Check final history length
|
||||
finalHistory := al.sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
// We can assert that the length is NOT what it would be without compression.
|
||||
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
|
||||
if len(finalHistory) >= 8 {
|
||||
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
const (
|
||||
mcpBootstrapMinTimeout = 10 * time.Second
|
||||
mcpBootstrapMaxTimeout = 5 * time.Minute
|
||||
mcpBootstrapGraceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type mcpBootstrapResult struct {
|
||||
Manager *mcp.Manager
|
||||
Tools []mcp.RegisteredTool
|
||||
}
|
||||
|
||||
func bootstrapMCP(cfg config.MCPToolsConfig) (*mcpBootstrapResult, error) {
|
||||
serverConfigs := buildMCPServerConfigs(cfg)
|
||||
if len(serverConfigs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
manager := mcp.NewManager(serverConfigs)
|
||||
|
||||
discoveryTimeout := calculateMCPDiscoveryTimeout(serverConfigs)
|
||||
discoveryCtx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
||||
defer cancel()
|
||||
|
||||
discoveredTools, err := manager.DiscoverTools(discoveryCtx)
|
||||
if err != nil {
|
||||
_ = manager.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mcpBootstrapResult{
|
||||
Manager: manager,
|
||||
Tools: discoveredTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func calculateMCPDiscoveryTimeout(serverConfigs map[string]mcp.ServerConfig) time.Duration {
|
||||
maxInitTimeout := mcpBootstrapMinTimeout
|
||||
|
||||
for _, serverConfig := range serverConfigs {
|
||||
initTimeout := serverConfig.InitTimeout()
|
||||
if initTimeout > maxInitTimeout {
|
||||
maxInitTimeout = initTimeout
|
||||
}
|
||||
}
|
||||
|
||||
timeout := maxInitTimeout + mcpBootstrapGraceTimeout
|
||||
if timeout < mcpBootstrapMinTimeout {
|
||||
return mcpBootstrapMinTimeout
|
||||
}
|
||||
if timeout > mcpBootstrapMaxTimeout {
|
||||
return mcpBootstrapMaxTimeout
|
||||
}
|
||||
return timeout
|
||||
}
|
||||
|
||||
func buildMCPServerConfigs(cfg config.MCPToolsConfig) map[string]mcp.ServerConfig {
|
||||
servers := make(map[string]mcp.ServerConfig, len(cfg.Servers))
|
||||
|
||||
for serverName, serverCfg := range cfg.Servers {
|
||||
if !serverCfg.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
envCopy := make(map[string]string, len(serverCfg.Env))
|
||||
for key, value := range serverCfg.Env {
|
||||
envCopy[key] = value
|
||||
}
|
||||
|
||||
servers[serverName] = mcp.ServerConfig{
|
||||
Name: serverName,
|
||||
Command: serverCfg.Command,
|
||||
Args: append([]string{}, serverCfg.Args...),
|
||||
Env: envCopy,
|
||||
WorkingDir: serverCfg.WorkingDir,
|
||||
Protocol: inferMCPProtocol(serverCfg.Protocol, serverCfg.Command),
|
||||
InitTimeoutSeconds: serverCfg.InitTimeoutSeconds,
|
||||
CallTimeoutSeconds: serverCfg.CallTimeoutSeconds,
|
||||
MaxResponseBytes: serverCfg.MaxResponseBytes,
|
||||
IncludeTools: append([]string{}, serverCfg.IncludeTools...),
|
||||
ExcludeTools: append([]string{}, serverCfg.ExcludeTools...),
|
||||
}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
func inferMCPProtocol(configuredProtocol, command string) string {
|
||||
if protocol := strings.TrimSpace(configuredProtocol); protocol != "" {
|
||||
return protocol
|
||||
}
|
||||
|
||||
// Context7 currently emits JSON-RPC messages as JSONL on stdio,
|
||||
// so defaulting avoids long startup waits when protocol is omitted.
|
||||
if strings.Contains(strings.ToLower(command), "context7-mcp") {
|
||||
return mcp.ProtocolJSONLines
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
+4
-4
@@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore {
|
||||
|
||||
// getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md).
|
||||
func (ms *MemoryStore) getTodayFile() string {
|
||||
today := time.Now().Format("20060102") // YYYYMMDD
|
||||
monthDir := today[:6] // YYYYMM
|
||||
today := time.Now().Format("20060102") // YYYYMMDD
|
||||
monthDir := today[:6] // YYYYMM
|
||||
filePath := filepath.Join(ms.memoryDir, monthDir, today+".md")
|
||||
return filePath
|
||||
}
|
||||
@@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
|
||||
|
||||
for i := 0; i < days; i++ {
|
||||
date := time.Now().AddDate(0, 0, -i)
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
|
||||
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
|
||||
+82
-30
@@ -19,18 +19,20 @@ import (
|
||||
)
|
||||
|
||||
type OAuthProviderConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Port int
|
||||
Issuer string
|
||||
ClientID string
|
||||
Scopes string
|
||||
Originator string
|
||||
Port int
|
||||
}
|
||||
|
||||
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||
return OAuthProviderConfig{
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Port: 1455,
|
||||
Issuer: "https://auth.openai.com",
|
||||
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||
Scopes: "openid profile email offline_access",
|
||||
Originator: "codex_cli_rs",
|
||||
Port: 1455,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, cred.Provider)
|
||||
refreshed, err := parseTokenResponse(body, cred.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshed.RefreshToken == "" {
|
||||
refreshed.RefreshToken = cred.RefreshToken
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
return refreshed, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
@@ -288,15 +300,23 @@ func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
|
||||
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
params := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {state},
|
||||
"response_type": {"code"},
|
||||
"client_id": {cfg.ClientID},
|
||||
"redirect_uri": {redirectURI},
|
||||
"scope": {cfg.Scopes},
|
||||
"code_challenge": {pkce.CodeChallenge},
|
||||
"code_challenge_method": {"S256"},
|
||||
"id_token_add_organizations": {"true"},
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"state": {state},
|
||||
}
|
||||
return cfg.Issuer + "/authorize?" + params.Encode()
|
||||
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||
params.Set("originator", "picoclaw")
|
||||
}
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
}
|
||||
return cfg.Issuer + "/oauth/authorize?" + params.Encode()
|
||||
}
|
||||
|
||||
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||
@@ -350,19 +370,57 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
cred.AccountID = accountID
|
||||
}
|
||||
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func extractAccountID(accessToken string) string {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
func extractAccountID(token string) string {
|
||||
claims, err := parseJWTClaims(token)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
if orgs, ok := claims["organizations"].([]interface{}); ok {
|
||||
for _, org := range orgs {
|
||||
if orgMap, ok := org.(map[string]interface{}); ok {
|
||||
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseJWTClaims(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("token is not a JWT")
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
@@ -373,21 +431,15 @@ func extractAccountID(accessToken string) string {
|
||||
|
||||
decoded, err := base64URLDecode(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func base64URLDecode(s string) ([]byte, error) {
|
||||
|
||||
+139
-5
@@ -1,19 +1,34 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payloadJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Port: 1455,
|
||||
Issuer: "https://auth.example.com",
|
||||
ClientID: "test-client-id",
|
||||
Scopes: "openid profile",
|
||||
Originator: "codex_cli_rs",
|
||||
Port: 1455,
|
||||
}
|
||||
pkce := PKCECodes{
|
||||
CodeVerifier: "test-verifier",
|
||||
@@ -22,7 +37,7 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
|
||||
if !strings.HasPrefix(u, "https://auth.example.com/oauth/authorize?") {
|
||||
t.Errorf("URL does not start with expected prefix: %s", u)
|
||||
}
|
||||
if !strings.Contains(u, "client_id=test-client-id") {
|
||||
@@ -40,6 +55,37 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
if !strings.Contains(u, "response_type=code") {
|
||||
t.Error("URL missing response_type")
|
||||
}
|
||||
if !strings.Contains(u, "id_token_add_organizations=true") {
|
||||
t.Error("URL missing id_token_add_organizations")
|
||||
}
|
||||
if !strings.Contains(u, "codex_cli_simplified_flow=true") {
|
||||
t.Error("URL missing codex_cli_simplified_flow")
|
||||
}
|
||||
if !strings.Contains(u, "originator=codex_cli_rs") {
|
||||
t.Error("URL missing originator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if q.Get("id_token_add_organizations") != "true" {
|
||||
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
|
||||
}
|
||||
if q.Get("codex_cli_simplified_flow") != "true" {
|
||||
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
|
||||
}
|
||||
if q.Get("originator") != "codex_cli_rs" {
|
||||
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
@@ -73,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "opaque-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
if cred.AccountID != "acc-id-from-id-token" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
|
||||
token := makeJWTForClaims(t, map[string]interface{}{
|
||||
"organizations": []interface{}{
|
||||
map[string]interface{}{"id": "org_from_orgs"},
|
||||
},
|
||||
})
|
||||
|
||||
if got := extractAccountID(token); got != "org_from_orgs" {
|
||||
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
body := []byte(`{"refresh_token": "test"}`)
|
||||
_, err := parseTokenResponse(body, "openai")
|
||||
@@ -81,6 +158,32 @@ func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTWithAccountID("acc-from-id")
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "not-a-jwt",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
|
||||
if cred.AccountID != "acc-from-id" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-from-id")
|
||||
}
|
||||
}
|
||||
|
||||
func makeJWTWithAccountID(accountID string) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(`{"https://api.openai.com/auth":{"chatgpt_account_id":"` + accountID + `"}}`))
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestExchangeCodeForTokens(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/oauth/token" {
|
||||
@@ -185,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "new-access-token-only",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "old-access",
|
||||
RefreshToken: "existing-refresh",
|
||||
AccountID: "acc_existing",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||
}
|
||||
if refreshed.RefreshToken != "existing-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
|
||||
}
|
||||
if refreshed.AccountID != "acc_existing" {
|
||||
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
if cfg.Issuer != "https://auth.openai.com" {
|
||||
|
||||
@@ -9,6 +9,7 @@ type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
handlers map[string]MessageHandler
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.inbound <- msg
|
||||
}
|
||||
|
||||
@@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.outbound <- msg
|
||||
}
|
||||
|
||||
@@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.closed = true
|
||||
close(mb.inbound)
|
||||
close(mb.outbound)
|
||||
}
|
||||
|
||||
+16
-1
@@ -59,7 +59,22 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
|
||||
for _, allowed := range c.allowList {
|
||||
// Strip leading "@" from allowed value for username matching
|
||||
trimmed := strings.TrimPrefix(allowed, "@")
|
||||
if senderID == allowed || idPart == allowed || senderID == trimmed || idPart == trimmed || (userPart != "" && (userPart == allowed || userPart == trimmed)) {
|
||||
allowedID := trimmed
|
||||
allowedUser := ""
|
||||
if idx := strings.Index(trimmed, "|"); idx > 0 {
|
||||
allowedID = trimmed[:idx]
|
||||
allowedUser = trimmed[idx+1:]
|
||||
}
|
||||
|
||||
// Support either side using "id|username" compound form.
|
||||
// This keeps backward compatibility with legacy Telegram allowlist entries.
|
||||
if senderID == allowed ||
|
||||
idPart == allowed ||
|
||||
senderID == trimmed ||
|
||||
idPart == trimmed ||
|
||||
idPart == allowedID ||
|
||||
(allowedUser != "" && senderID == allowedUser) ||
|
||||
(userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
package channels
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBaseChannelIsAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowList []string
|
||||
senderID string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty allowlist allows all",
|
||||
allowList: nil,
|
||||
senderID: "anyone",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "compound sender matches numeric allowlist",
|
||||
allowList: []string{"123456"},
|
||||
senderID: "123456|alice",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "compound sender matches username allowlist",
|
||||
allowList: []string{"@alice"},
|
||||
senderID: "123456|alice",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "numeric sender matches legacy compound allowlist",
|
||||
allowList: []string{"123456|alice"},
|
||||
senderID: "123456",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non matching sender is denied",
|
||||
allowList: []string{"123456"},
|
||||
senderID: "654321|bob",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := NewBaseChannel("test", nil, nil, tt.allowList)
|
||||
if got := ch.IsAllowed(tt.senderID); got != tt.want {
|
||||
t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -20,12 +20,12 @@ import (
|
||||
// It uses WebSocket for receiving messages via stream mode and API for sending
|
||||
type DingTalkChannel struct {
|
||||
*BaseChannel
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Map to store session webhooks for each chat
|
||||
sessionWebhooks sync.Map // chatID -> sessionWebhook
|
||||
}
|
||||
@@ -109,8 +109,8 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
}
|
||||
|
||||
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
// Use the session webhook to send the reply
|
||||
|
||||
+150
-2
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
@@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
return fmt.Errorf("channel ID is empty")
|
||||
}
|
||||
|
||||
message := msg.Content
|
||||
runes := []rune(msg.Content)
|
||||
if len(runes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitMessage splits long messages into chunks, preserving code block integrity
|
||||
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
|
||||
func splitMessage(content string, limit int) []string {
|
||||
var messages []string
|
||||
|
||||
for len(content) > 0 {
|
||||
if len(content) <= limit {
|
||||
messages = append(messages, content)
|
||||
break
|
||||
}
|
||||
|
||||
msgEnd := limit
|
||||
|
||||
// Find natural split point within the limit
|
||||
msgEnd = findLastNewline(content[:limit], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:limit], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
// Check if this would end with an incomplete code block
|
||||
candidate := content[:msgEnd]
|
||||
unclosedIdx := findLastUnclosedCodeBlock(candidate)
|
||||
|
||||
if unclosedIdx >= 0 {
|
||||
// Message would end with incomplete code block
|
||||
// Try to extend to include the closing ``` (with some buffer)
|
||||
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
|
||||
if len(content) > extendedLimit {
|
||||
closingIdx := findNextClosingCodeBlock(content, msgEnd)
|
||||
if closingIdx > 0 && closingIdx <= extendedLimit {
|
||||
// Extend to include the closing ```
|
||||
msgEnd = closingIdx
|
||||
} else {
|
||||
// Can't find closing, split before the code block
|
||||
msgEnd = findLastNewline(content[:unclosedIdx], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:unclosedIdx], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = unclosedIdx
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Remaining content fits within extended limit
|
||||
msgEnd = len(content)
|
||||
}
|
||||
}
|
||||
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
messages = append(messages, content[:msgEnd])
|
||||
content = strings.TrimSpace(content[msgEnd:])
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
|
||||
// Returns the position of the opening ``` or -1 if all code blocks are complete
|
||||
func findLastUnclosedCodeBlock(text string) int {
|
||||
count := 0
|
||||
lastOpenIdx := -1
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
if count == 0 {
|
||||
lastOpenIdx = i
|
||||
}
|
||||
count++
|
||||
i += 2
|
||||
}
|
||||
}
|
||||
|
||||
// If odd number of ``` markers, last one is unclosed
|
||||
if count%2 == 1 {
|
||||
return lastOpenIdx
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findNextClosingCodeBlock finds the next closing ``` starting from a position
|
||||
// Returns the position after the closing ``` or -1 if not found
|
||||
func findNextClosingCodeBlock(text string, startIdx int) int {
|
||||
for i := startIdx; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
return i + 3
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastNewline finds the last newline character within the last N characters
|
||||
// Returns the position of the newline or -1 if not found
|
||||
func findLastNewline(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == '\n' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastSpace finds the last space character within the last N characters
|
||||
// Returns the position of the space or -1 if not found
|
||||
func findLastSpace(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == ' ' || s[i] == '\t' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
|
||||
// 使用传入的 ctx 进行超时控制
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.session.ChannelMessageSend(channelID, message)
|
||||
_, err := c.session.ChannelMessageSend(channelID, content)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
@@ -140,6 +282,12 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.session.ChannelTyping(m.ChannelID); err != nil {
|
||||
logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件和转录
|
||||
if !c.IsAllowed(m.Author.ID) {
|
||||
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// FeishuChannel is a stub implementation for 32-bit architectures
|
||||
type FeishuChannel struct {
|
||||
*BaseChannel
|
||||
}
|
||||
|
||||
// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config")
|
||||
}
|
||||
|
||||
// Start is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Start(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send is a stub method to satisfy the Channel interface
|
||||
func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
return errors.New("feishu channel is not supported on 32-bit architectures")
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
@@ -0,0 +1,598 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
lineAPIBase = "https://api.line.me/v2/bot"
|
||||
lineDataAPIBase = "https://api-data.line.me/v2/bot"
|
||||
lineReplyEndpoint = lineAPIBase + "/message/reply"
|
||||
linePushEndpoint = lineAPIBase + "/message/push"
|
||||
lineContentEndpoint = lineDataAPIBase + "/message/%s/content"
|
||||
lineBotInfoEndpoint = lineAPIBase + "/info"
|
||||
lineLoadingEndpoint = lineAPIBase + "/chat/loading/start"
|
||||
lineReplyTokenMaxAge = 25 * time.Second
|
||||
)
|
||||
|
||||
type replyTokenEntry struct {
|
||||
token string
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// LINEChannel implements the Channel interface for LINE Official Account
|
||||
// using the LINE Messaging API with HTTP webhook for receiving messages
|
||||
// and REST API for sending messages.
|
||||
type LINEChannel struct {
|
||||
*BaseChannel
|
||||
config config.LINEConfig
|
||||
httpServer *http.Server
|
||||
botUserID string // Bot's user ID
|
||||
botBasicID string // Bot's basic ID (e.g. @216ru...)
|
||||
botDisplayName string // Bot's display name for text-based mention detection
|
||||
replyTokens sync.Map // chatID -> replyTokenEntry
|
||||
quoteTokens sync.Map // chatID -> quoteToken (string)
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewLINEChannel creates a new LINE channel instance.
|
||||
func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) {
|
||||
if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" {
|
||||
return nil, fmt.Errorf("line channel_secret and channel_access_token are required")
|
||||
}
|
||||
|
||||
base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &LINEChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start launches the HTTP webhook server.
|
||||
func (c *LINEChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("line", "Starting LINE channel (Webhook Mode)")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
// Fetch bot profile to get bot's userId for mention detection
|
||||
if err := c.fetchBotInfo(); err != nil {
|
||||
logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("line", "Bot info fetched", map[string]interface{}{
|
||||
"bot_user_id": c.botUserID,
|
||||
"basic_id": c.botBasicID,
|
||||
"display_name": c.botDisplayName,
|
||||
})
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
path := c.config.WebhookPath
|
||||
if path == "" {
|
||||
path = "/webhook/line"
|
||||
}
|
||||
mux.HandleFunc(path, c.webhookHandler)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort)
|
||||
c.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
go func() {
|
||||
logger.InfoCF("line", "LINE webhook server listening", map[string]interface{}{
|
||||
"addr": addr,
|
||||
"path": path,
|
||||
})
|
||||
if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("line", "Webhook server error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("line", "LINE channel started (Webhook Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API.
|
||||
func (c *LINEChannel) fetchBotInfo() error {
|
||||
req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bot info API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var info struct {
|
||||
UserID string `json:"userId"`
|
||||
BasicID string `json:"basicId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.botUserID = info.UserID
|
||||
c.botBasicID = info.BasicID
|
||||
c.botDisplayName = info.DisplayName
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the HTTP server.
|
||||
func (c *LINEChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("line", "Stopping LINE channel")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
if c.httpServer != nil {
|
||||
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
if err := c.httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.ErrorCF("line", "Webhook server shutdown error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("line", "LINE channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// webhookHandler handles incoming LINE webhook requests.
|
||||
func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
logger.ErrorCF("line", "Failed to read request body", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
signature := r.Header.Get("X-Line-Signature")
|
||||
if !c.verifySignature(body, signature) {
|
||||
logger.WarnC("line", "Invalid webhook signature")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Events []lineEvent `json:"events"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
logger.ErrorCF("line", "Failed to parse webhook payload", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Return 200 immediately, process events asynchronously
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
for _, event := range payload.Events {
|
||||
go c.processEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
// verifySignature validates the X-Line-Signature using HMAC-SHA256.
|
||||
func (c *LINEChannel) verifySignature(body []byte, signature string) bool {
|
||||
if signature == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret))
|
||||
mac.Write(body)
|
||||
expected := base64.StdEncoding.EncodeToString(mac.Sum(nil))
|
||||
|
||||
return hmac.Equal([]byte(expected), []byte(signature))
|
||||
}
|
||||
|
||||
// LINE webhook event types
|
||||
type lineEvent struct {
|
||||
Type string `json:"type"`
|
||||
ReplyToken string `json:"replyToken"`
|
||||
Source lineSource `json:"source"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
type lineSource struct {
|
||||
Type string `json:"type"` // "user", "group", "room"
|
||||
UserID string `json:"userId"`
|
||||
GroupID string `json:"groupId"`
|
||||
RoomID string `json:"roomId"`
|
||||
}
|
||||
|
||||
type lineMessage struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"` // "text", "image", "video", "audio", "file", "sticker"
|
||||
Text string `json:"text"`
|
||||
QuoteToken string `json:"quoteToken"`
|
||||
Mention *struct {
|
||||
Mentionees []lineMentionee `json:"mentionees"`
|
||||
} `json:"mention"`
|
||||
ContentProvider struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"contentProvider"`
|
||||
}
|
||||
|
||||
type lineMentionee struct {
|
||||
Index int `json:"index"`
|
||||
Length int `json:"length"`
|
||||
Type string `json:"type"` // "user", "all"
|
||||
UserID string `json:"userId"`
|
||||
}
|
||||
|
||||
func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
if event.Type != "message" {
|
||||
logger.DebugCF("line", "Ignoring non-message event", map[string]interface{}{
|
||||
"type": event.Type,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := event.Source.UserID
|
||||
chatID := c.resolveChatID(event.Source)
|
||||
isGroup := event.Source.Type == "group" || event.Source.Type == "room"
|
||||
|
||||
var msg lineMessage
|
||||
if err := json.Unmarshal(event.Message, &msg); err != nil {
|
||||
logger.ErrorCF("line", "Failed to parse message", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// In group chats, only respond when the bot is mentioned
|
||||
if isGroup && !c.isBotMentioned(msg) {
|
||||
logger.DebugCF("line", "Ignoring group message without mention", map[string]interface{}{
|
||||
"chat_id": chatID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Store reply token for later use
|
||||
if event.ReplyToken != "" {
|
||||
c.replyTokens.Store(chatID, replyTokenEntry{
|
||||
token: event.ReplyToken,
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// Store quote token for quoting the original message in reply
|
||||
if msg.QuoteToken != "" {
|
||||
c.quoteTokens.Store(chatID, msg.QuoteToken)
|
||||
}
|
||||
|
||||
var content string
|
||||
var mediaPaths []string
|
||||
localFiles := []string{}
|
||||
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("line", "Failed to cleanup temp file", map[string]interface{}{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
switch msg.Type {
|
||||
case "text":
|
||||
content = msg.Text
|
||||
// Strip bot mention from text in group chats
|
||||
if isGroup {
|
||||
content = c.stripBotMention(content, msg)
|
||||
}
|
||||
case "image":
|
||||
localPath := c.downloadContent(msg.ID, "image.jpg")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[image]"
|
||||
}
|
||||
case "audio":
|
||||
localPath := c.downloadContent(msg.ID, "audio.m4a")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[audio]"
|
||||
}
|
||||
case "video":
|
||||
localPath := c.downloadContent(msg.ID, "video.mp4")
|
||||
if localPath != "" {
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
content = "[video]"
|
||||
}
|
||||
case "file":
|
||||
content = "[file]"
|
||||
case "sticker":
|
||||
content = "[sticker]"
|
||||
default:
|
||||
content = fmt.Sprintf("[%s]", msg.Type)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "line",
|
||||
"source_type": event.Source.Type,
|
||||
"message_id": msg.ID,
|
||||
}
|
||||
|
||||
logger.DebugCF("line", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_type": msg.Type,
|
||||
"is_group": isGroup,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Show typing/loading indicator (requires user ID, not group ID)
|
||||
c.sendLoading(senderID)
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message.
|
||||
// It first checks the mention metadata (userId match), then falls back
|
||||
// to text-based detection using the bot's display name, since LINE may
|
||||
// not include userId in mentionees for Official Accounts.
|
||||
func (c *LINEChannel) isBotMentioned(msg lineMessage) bool {
|
||||
// Check mention metadata
|
||||
if msg.Mention != nil {
|
||||
for _, m := range msg.Mention.Mentionees {
|
||||
if m.Type == "all" {
|
||||
return true
|
||||
}
|
||||
if c.botUserID != "" && m.UserID == c.botUserID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
// Mention metadata exists with mentionees but bot not matched by userId.
|
||||
// The bot IS likely mentioned (LINE includes mention struct when bot is @-ed),
|
||||
// so check if any mentionee overlaps with bot display name in text.
|
||||
if c.botDisplayName != "" {
|
||||
for _, m := range msg.Mention.Mentionees {
|
||||
if m.Index >= 0 && m.Length > 0 {
|
||||
runes := []rune(msg.Text)
|
||||
end := m.Index + m.Length
|
||||
if end <= len(runes) {
|
||||
mentionText := string(runes[m.Index:end])
|
||||
if strings.Contains(mentionText, c.botDisplayName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: text-based detection with display name
|
||||
if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// stripBotMention removes the @BotName mention text from the message.
|
||||
func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string {
|
||||
stripped := false
|
||||
|
||||
// Try to strip using mention metadata indices
|
||||
if msg.Mention != nil {
|
||||
runes := []rune(text)
|
||||
for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- {
|
||||
m := msg.Mention.Mentionees[i]
|
||||
// Strip if userId matches OR if the mention text contains the bot display name
|
||||
shouldStrip := false
|
||||
if c.botUserID != "" && m.UserID == c.botUserID {
|
||||
shouldStrip = true
|
||||
} else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 {
|
||||
end := m.Index + m.Length
|
||||
if end <= len(runes) {
|
||||
mentionText := string(runes[m.Index:end])
|
||||
if strings.Contains(mentionText, c.botDisplayName) {
|
||||
shouldStrip = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if shouldStrip {
|
||||
start := m.Index
|
||||
end := m.Index + m.Length
|
||||
if start >= 0 && end <= len(runes) {
|
||||
runes = append(runes[:start], runes[end:]...)
|
||||
stripped = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if stripped {
|
||||
return strings.TrimSpace(string(runes))
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: strip @DisplayName from text
|
||||
if c.botDisplayName != "" {
|
||||
text = strings.ReplaceAll(text, "@"+c.botDisplayName, "")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
// resolveChatID determines the chat ID from the event source.
|
||||
// For group/room messages, use the group/room ID; for 1:1, use the user ID.
|
||||
func (c *LINEChannel) resolveChatID(source lineSource) string {
|
||||
switch source.Type {
|
||||
case "group":
|
||||
return source.GroupID
|
||||
case "room":
|
||||
return source.RoomID
|
||||
default:
|
||||
return source.UserID
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends a message to LINE. It first tries the Reply API (free)
|
||||
// using a cached reply token, then falls back to the Push API.
|
||||
func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("line channel not running")
|
||||
}
|
||||
|
||||
// Load and consume quote token for this chat
|
||||
var quoteToken string
|
||||
if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
quoteToken = qt.(string)
|
||||
}
|
||||
|
||||
// Try reply token first (free, valid for ~25 seconds)
|
||||
if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok {
|
||||
tokenEntry := entry.(replyTokenEntry)
|
||||
if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge {
|
||||
if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil {
|
||||
logger.DebugCF("line", "Message sent via Reply API", map[string]interface{}{
|
||||
"chat_id": msg.ChatID,
|
||||
"quoted": quoteToken != "",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
logger.DebugC("line", "Reply API failed, falling back to Push API")
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to Push API
|
||||
return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken)
|
||||
}
|
||||
|
||||
// buildTextMessage creates a text message object, optionally with quoteToken.
|
||||
func buildTextMessage(content, quoteToken string) map[string]string {
|
||||
msg := map[string]string{
|
||||
"type": "text",
|
||||
"text": content,
|
||||
}
|
||||
if quoteToken != "" {
|
||||
msg["quoteToken"] = quoteToken
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// sendReply sends a message using the LINE Reply API.
|
||||
func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error {
|
||||
payload := map[string]interface{}{
|
||||
"replyToken": replyToken,
|
||||
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
|
||||
}
|
||||
|
||||
return c.callAPI(ctx, lineReplyEndpoint, payload)
|
||||
}
|
||||
|
||||
// sendPush sends a message using the LINE Push API.
|
||||
func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error {
|
||||
payload := map[string]interface{}{
|
||||
"to": to,
|
||||
"messages": []map[string]string{buildTextMessage(content, quoteToken)},
|
||||
}
|
||||
|
||||
return c.callAPI(ctx, linePushEndpoint, payload)
|
||||
}
|
||||
|
||||
// sendLoading sends a loading animation indicator to the chat.
|
||||
func (c *LINEChannel) sendLoading(chatID string) {
|
||||
payload := map[string]interface{}{
|
||||
"chatId": chatID,
|
||||
"loadingSeconds": 60,
|
||||
}
|
||||
if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil {
|
||||
logger.DebugCF("line", "Failed to send loading indicator", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// callAPI makes an authenticated POST request to the LINE API.
|
||||
func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload interface{}) error {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken)
|
||||
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// downloadContent downloads media content from the LINE API.
|
||||
func (c *LINEChannel) downloadContent(messageID, filename string) string {
|
||||
url := fmt.Sprintf(lineContentEndpoint, messageID)
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "line",
|
||||
ExtraHeaders: map[string]string{
|
||||
"Authorization": "Bearer " + c.config.ChannelAccessToken,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -18,7 +18,6 @@ type MaixCamChannel struct {
|
||||
listener net.Listener
|
||||
clients map[net.Conn]bool
|
||||
clientsMux sync.RWMutex
|
||||
running bool
|
||||
}
|
||||
|
||||
type MaixCamMessage struct {
|
||||
@@ -35,7 +34,6 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
clients: make(map[net.Conn]bool),
|
||||
running: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+33
-1
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -47,7 +48,7 @@ func (m *Manager) initChannels() error {
|
||||
|
||||
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize Telegram channel")
|
||||
telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus)
|
||||
telegram, err := NewTelegramChannel(m.config, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
@@ -149,6 +150,32 @@ func (m *Manager) initChannels() error {
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize LINE channel")
|
||||
line, err := NewLINEChannel(m.config.Channels.LINE, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
m.channels["line"] = line
|
||||
logger.InfoC("channels", "LINE channel enabled successfully")
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize OneBot channel")
|
||||
onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
m.channels["onebot"] = onebot
|
||||
logger.InfoC("channels", "OneBot channel enabled successfully")
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
@@ -229,6 +256,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
channel, exists := m.channels[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -0,0 +1,686 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
type OneBotChannel struct {
|
||||
*BaseChannel
|
||||
config config.OneBotConfig
|
||||
conn *websocket.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
dedup map[string]struct{}
|
||||
dedupRing []string
|
||||
dedupIdx int
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
echoCounter int64
|
||||
}
|
||||
|
||||
type oneBotRawEvent struct {
|
||||
PostType string `json:"post_type"`
|
||||
MessageType string `json:"message_type"`
|
||||
SubType string `json:"sub_type"`
|
||||
MessageID json.RawMessage `json:"message_id"`
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
GroupID json.RawMessage `json:"group_id"`
|
||||
RawMessage string `json:"raw_message"`
|
||||
Message json.RawMessage `json:"message"`
|
||||
Sender json.RawMessage `json:"sender"`
|
||||
SelfID json.RawMessage `json:"self_id"`
|
||||
Time json.RawMessage `json:"time"`
|
||||
MetaEventType string `json:"meta_event_type"`
|
||||
Echo string `json:"echo"`
|
||||
RetCode json.RawMessage `json:"retcode"`
|
||||
Status BotStatus `json:"status"`
|
||||
}
|
||||
|
||||
type BotStatus struct {
|
||||
Online bool `json:"online"`
|
||||
Good bool `json:"good"`
|
||||
}
|
||||
|
||||
type oneBotSender struct {
|
||||
UserID json.RawMessage `json:"user_id"`
|
||||
Nickname string `json:"nickname"`
|
||||
Card string `json:"card"`
|
||||
}
|
||||
|
||||
type oneBotEvent struct {
|
||||
PostType string
|
||||
MessageType string
|
||||
SubType string
|
||||
MessageID string
|
||||
UserID int64
|
||||
GroupID int64
|
||||
Content string
|
||||
RawContent string
|
||||
IsBotMentioned bool
|
||||
Sender oneBotSender
|
||||
SelfID int64
|
||||
Time int64
|
||||
MetaEventType string
|
||||
}
|
||||
|
||||
type oneBotAPIRequest struct {
|
||||
Action string `json:"action"`
|
||||
Params interface{} `json:"params"`
|
||||
Echo string `json:"echo,omitempty"`
|
||||
}
|
||||
|
||||
type oneBotSendPrivateMsgParams struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type oneBotSendGroupMsgParams struct {
|
||||
GroupID int64 `json:"group_id"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) {
|
||||
base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
const dedupSize = 1024
|
||||
return &OneBotChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
dedup: make(map[string]struct{}, dedupSize),
|
||||
dedupRing: make([]string, dedupSize),
|
||||
dedupIdx: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Start(ctx context.Context) error {
|
||||
if c.config.WSUrl == "" {
|
||||
return fmt.Errorf("OneBot ws_url not configured")
|
||||
}
|
||||
|
||||
logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{
|
||||
"ws_url": c.config.WSUrl,
|
||||
})
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
if err := c.connect(); err != nil {
|
||||
logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
}
|
||||
|
||||
if c.config.ReconnectInterval > 0 {
|
||||
go c.reconnectLoop()
|
||||
} else {
|
||||
// If reconnect is disabled but initial connection failed, we cannot recover
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("failed to connect to OneBot and reconnect is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("onebot", "OneBot channel started successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) connect() error {
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer.HandshakeTimeout = 10 * time.Second
|
||||
|
||||
header := make(map[string][]string)
|
||||
if c.config.AccessToken != "" {
|
||||
header["Authorization"] = []string{"Bearer " + c.config.AccessToken}
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(c.config.WSUrl, header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
logger.InfoC("onebot", "WebSocket connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) reconnectLoop() {
|
||||
interval := time.Duration(c.config.ReconnectInterval) * time.Second
|
||||
if interval < 5*time.Second {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case <-time.After(interval):
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.InfoC("onebot", "Attempting to reconnect...")
|
||||
if err := c.connect(); err != nil {
|
||||
logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
go c.listen()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("onebot", "Stopping OneBot channel")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("OneBot channel not running")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return fmt.Errorf("OneBot WebSocket not connected")
|
||||
}
|
||||
|
||||
action, params, err := c.buildSendRequest(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
c.echoCounter++
|
||||
echo := fmt.Sprintf("send_%d", c.echoCounter)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
req := oneBotAPIRequest{
|
||||
Action: action,
|
||||
Params: params,
|
||||
Echo: echo,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal OneBot request: %w", err)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
err = conn.WriteMessage(websocket.TextMessage, data)
|
||||
c.writeMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) {
|
||||
chatID := msg.ChatID
|
||||
|
||||
if len(chatID) > 6 && chatID[:6] == "group:" {
|
||||
groupID, err := strconv.ParseInt(chatID[6:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid group ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_group_msg", oneBotSendGroupMsgParams{
|
||||
GroupID: groupID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(chatID) > 8 && chatID[:8] == "private:" {
|
||||
userID, err := strconv.ParseInt(chatID[8:], 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid user ID in chatID: %s", chatID)
|
||||
}
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
userID, err := strconv.ParseInt(chatID, 10, 64)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid chatID for OneBot: %s", chatID)
|
||||
}
|
||||
|
||||
return "send_private_msg", oneBotSendPrivateMsgParams{
|
||||
UserID: userID,
|
||||
Message: msg.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) listen() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
logger.WarnC("onebot", "WebSocket connection is nil, listener exiting")
|
||||
return
|
||||
}
|
||||
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
c.mu.Lock()
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Raw WebSocket message received", map[string]interface{}{
|
||||
"length": len(message),
|
||||
"payload": string(message),
|
||||
})
|
||||
|
||||
var raw oneBotRawEvent
|
||||
if err := json.Unmarshal(message, &raw); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"payload": string(message),
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Echo != "" || raw.Status.Online || raw.Status.Good {
|
||||
logger.DebugCF("onebot", "Received API response, skipping", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Parsed raw event", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
"message_type": raw.MessageType,
|
||||
"sub_type": raw.SubType,
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
|
||||
c.handleRawEvent(&raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONInt64(raw json.RawMessage) (int64, error) {
|
||||
if len(raw) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var n int64
|
||||
if err := json.Unmarshal(raw, &n); err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return strconv.ParseInt(s, 10, 64)
|
||||
}
|
||||
return 0, fmt.Errorf("cannot parse as int64: %s", string(raw))
|
||||
}
|
||||
|
||||
func parseJSONString(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
type parseMessageResult struct {
|
||||
Text string
|
||||
IsBotMentioned bool
|
||||
}
|
||||
|
||||
func parseMessageContentEx(raw json.RawMessage, selfID int64) parseMessageResult {
|
||||
if len(raw) == 0 {
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
var s string
|
||||
if err := json.Unmarshal(raw, &s); err == nil {
|
||||
mentioned := false
|
||||
if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(s, cqAt) {
|
||||
mentioned = true
|
||||
s = strings.ReplaceAll(s, cqAt, "")
|
||||
s = strings.TrimSpace(s)
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: s, IsBotMentioned: mentioned}
|
||||
}
|
||||
|
||||
var segments []map[string]interface{}
|
||||
if err := json.Unmarshal(raw, &segments); err == nil {
|
||||
var text string
|
||||
mentioned := false
|
||||
selfIDStr := strconv.FormatInt(selfID, 10)
|
||||
for _, seg := range segments {
|
||||
segType, _ := seg["type"].(string)
|
||||
data, _ := seg["data"].(map[string]interface{})
|
||||
switch segType {
|
||||
case "text":
|
||||
if data != nil {
|
||||
if t, ok := data["text"].(string); ok {
|
||||
text += t
|
||||
}
|
||||
}
|
||||
case "at":
|
||||
if data != nil && selfID > 0 {
|
||||
qqVal := fmt.Sprintf("%v", data["qq"])
|
||||
if qqVal == selfIDStr || qqVal == "all" {
|
||||
mentioned = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return parseMessageResult{Text: strings.TrimSpace(text), IsBotMentioned: mentioned}
|
||||
}
|
||||
return parseMessageResult{}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) {
|
||||
switch raw.PostType {
|
||||
case "message":
|
||||
evt, err := c.normalizeMessageEvent(raw)
|
||||
if err != nil {
|
||||
logger.WarnCF("onebot", "Failed to normalize message event", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.handleMessage(evt)
|
||||
case "meta_event":
|
||||
c.handleMetaEvent(raw)
|
||||
case "notice":
|
||||
logger.DebugCF("onebot", "Notice event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "request":
|
||||
logger.DebugCF("onebot", "Request event received", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "":
|
||||
logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{
|
||||
"echo": raw.Echo,
|
||||
"status": raw.Status,
|
||||
})
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{
|
||||
"post_type": raw.PostType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) normalizeMessageEvent(raw *oneBotRawEvent) (*oneBotEvent, error) {
|
||||
userID, err := parseJSONInt64(raw.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse user_id: %w (raw: %s)", err, string(raw.UserID))
|
||||
}
|
||||
|
||||
groupID, _ := parseJSONInt64(raw.GroupID)
|
||||
selfID, _ := parseJSONInt64(raw.SelfID)
|
||||
ts, _ := parseJSONInt64(raw.Time)
|
||||
messageID := parseJSONString(raw.MessageID)
|
||||
|
||||
parsed := parseMessageContentEx(raw.Message, selfID)
|
||||
isBotMentioned := parsed.IsBotMentioned
|
||||
|
||||
content := raw.RawMessage
|
||||
if content == "" {
|
||||
content = parsed.Text
|
||||
} else if selfID > 0 {
|
||||
cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID)
|
||||
if strings.Contains(content, cqAt) {
|
||||
isBotMentioned = true
|
||||
content = strings.ReplaceAll(content, cqAt, "")
|
||||
content = strings.TrimSpace(content)
|
||||
}
|
||||
}
|
||||
|
||||
var sender oneBotSender
|
||||
if len(raw.Sender) > 0 {
|
||||
if err := json.Unmarshal(raw.Sender, &sender); err != nil {
|
||||
logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"sender": string(raw.Sender),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Normalized message event", map[string]interface{}{
|
||||
"message_type": raw.MessageType,
|
||||
"user_id": userID,
|
||||
"group_id": groupID,
|
||||
"message_id": messageID,
|
||||
"content_len": len(content),
|
||||
"nickname": sender.Nickname,
|
||||
})
|
||||
|
||||
return &oneBotEvent{
|
||||
PostType: raw.PostType,
|
||||
MessageType: raw.MessageType,
|
||||
SubType: raw.SubType,
|
||||
MessageID: messageID,
|
||||
UserID: userID,
|
||||
GroupID: groupID,
|
||||
Content: content,
|
||||
RawContent: raw.RawMessage,
|
||||
IsBotMentioned: isBotMentioned,
|
||||
Sender: sender,
|
||||
SelfID: selfID,
|
||||
Time: ts,
|
||||
MetaEventType: raw.MetaEventType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) {
|
||||
switch raw.MetaEventType {
|
||||
case "lifecycle":
|
||||
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{
|
||||
"sub_type": raw.SubType,
|
||||
})
|
||||
case "heartbeat":
|
||||
logger.DebugC("onebot", "Heartbeat received")
|
||||
default:
|
||||
logger.DebugCF("onebot", "Unknown meta_event_type", map[string]interface{}{
|
||||
"meta_event_type": raw.MetaEventType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) handleMessage(evt *oneBotEvent) {
|
||||
if c.isDuplicate(evt.MessageID) {
|
||||
logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
content := evt.Content
|
||||
if content == "" {
|
||||
logger.DebugCF("onebot", "Received empty message, ignoring", map[string]interface{}{
|
||||
"message_id": evt.MessageID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := strconv.FormatInt(evt.UserID, 10)
|
||||
var chatID string
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": evt.MessageID,
|
||||
}
|
||||
|
||||
switch evt.MessageType {
|
||||
case "private":
|
||||
chatID = "private:" + senderID
|
||||
logger.InfoCF("onebot", "Received private message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"message_id": evt.MessageID,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(evt.GroupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(evt.Sender.UserID)
|
||||
if senderUserID > 0 {
|
||||
metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10)
|
||||
}
|
||||
|
||||
if evt.Sender.Card != "" {
|
||||
metadata["sender_name"] = evt.Sender.Card
|
||||
} else if evt.Sender.Nickname != "" {
|
||||
metadata["sender_name"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
triggered, strippedContent := c.checkGroupTrigger(content, evt.IsBotMentioned)
|
||||
if !triggered {
|
||||
logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
return
|
||||
}
|
||||
content = strippedContent
|
||||
|
||||
logger.InfoCF("onebot", "Received group message", map[string]interface{}{
|
||||
"sender": senderID,
|
||||
"group": groupIDStr,
|
||||
"message_id": evt.MessageID,
|
||||
"is_mentioned": evt.IsBotMentioned,
|
||||
"length": len(content),
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
default:
|
||||
logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{
|
||||
"type": evt.MessageType,
|
||||
"message_id": evt.MessageID,
|
||||
"user_id": evt.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if evt.Sender.Nickname != "" {
|
||||
metadata["nickname"] = evt.Sender.Nickname
|
||||
}
|
||||
|
||||
logger.DebugCF("onebot", "Forwarding message to bus", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"content": truncate(content, 100),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, []string{}, metadata)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
if messageID == "" || messageID == "0" {
|
||||
return false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if _, exists := c.dedup[messageID]; exists {
|
||||
return true
|
||||
}
|
||||
|
||||
if old := c.dedupRing[c.dedupIdx]; old != "" {
|
||||
delete(c.dedup, old)
|
||||
}
|
||||
c.dedupRing[c.dedupIdx] = messageID
|
||||
c.dedup[messageID] = struct{}{}
|
||||
c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= n {
|
||||
return s
|
||||
}
|
||||
return string(runes[:n]) + "..."
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) {
|
||||
if isBotMentioned {
|
||||
return true, strings.TrimSpace(content)
|
||||
}
|
||||
|
||||
for _, prefix := range c.config.GroupTriggerPrefix {
|
||||
if prefix == "" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(content, prefix) {
|
||||
return true, strings.TrimSpace(strings.TrimPrefix(content, prefix))
|
||||
}
|
||||
}
|
||||
|
||||
return false, content
|
||||
}
|
||||
+17
-3
@@ -282,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
@@ -296,6 +296,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
@@ -345,6 +352,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
if !c.IsAllowed(cmd.UserID) {
|
||||
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
|
||||
"user_id": cmd.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := cmd.UserID
|
||||
channelID := cmd.ChannelID
|
||||
chatID := channelID
|
||||
|
||||
+63
-60
@@ -11,7 +11,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
th "github.com/mymmrac/telego/telegohandler"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/mymmrac/telego/telegohandler"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -24,7 +27,8 @@ import (
|
||||
type TelegramChannel struct {
|
||||
*BaseChannel
|
||||
bot *telego.Bot
|
||||
config config.TelegramConfig
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
transcriber *voice.GroqTranscriber
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
@@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() {
|
||||
}
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
var opts []telego.BotOption
|
||||
telegramCfg := cfg.Channels.Telegram
|
||||
|
||||
if cfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(cfg.Proxy)
|
||||
if telegramCfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
|
||||
}
|
||||
opts = append(opts, telego.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr
|
||||
}))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(cfg.Token, opts...)
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
}
|
||||
|
||||
base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom)
|
||||
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
commands: NewTelegramCommands(bot, cfg),
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
chatIDs: make(map[string]int64),
|
||||
@@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
return nil
|
||||
}, th.CommandEqual("help"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
}, th.CommandEqual("show"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.List(ctx, message)
|
||||
}, th.CommandEqual("list"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.handleMessage(ctx, &message)
|
||||
}, th.AnyMessage())
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go bh.Start()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case update, ok := <-updates:
|
||||
if !ok {
|
||||
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
|
||||
return
|
||||
}
|
||||
if update.Message != nil {
|
||||
c.handleMessage(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
bh.Stop()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.setRunning(false)
|
||||
@@ -166,15 +186,14 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
|
||||
message := update.Message
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
|
||||
if message == nil {
|
||||
return
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
user := message.From
|
||||
if user == nil {
|
||||
return
|
||||
return fmt.Errorf("message sender (user) is nil")
|
||||
}
|
||||
|
||||
senderID := fmt.Sprintf("%d", user.ID)
|
||||
@@ -187,7 +206,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": senderID,
|
||||
})
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
chatID := message.Chat.ID
|
||||
@@ -220,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
content += message.Caption
|
||||
}
|
||||
|
||||
if message.Photo != nil && len(message.Photo) > 0 {
|
||||
if len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||
if photoPath != "" {
|
||||
@@ -229,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[image: photo]")
|
||||
content += "[image: photo]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
"error": err.Error(),
|
||||
"path": voicePath,
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
|
||||
transcribedText = "[voice (transcription failed)]"
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
|
||||
@@ -258,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice]")
|
||||
transcribedText = "[voice]"
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
@@ -276,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[audio]")
|
||||
content += "[audio]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[file]")
|
||||
content += "[file]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,37 +337,14 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
}
|
||||
}
|
||||
|
||||
// Create new context for thinking animation with timeout
|
||||
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
// Create cancel function for thinking state
|
||||
_, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||
if err == nil {
|
||||
pID := pMsg.MessageID
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
|
||||
go func(cid int64, mid int) {
|
||||
dots := []string{".", "..", "..."}
|
||||
emotes := []string{"💭", "🤔", "☁️"}
|
||||
i := 0
|
||||
ticker := time.NewTicker(2000 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-thinkCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
i++
|
||||
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
||||
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||
if editErr != nil {
|
||||
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
|
||||
"error": editErr.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}(chatID, pID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -360,6 +356,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
}
|
||||
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
@@ -470,8 +467,11 @@ func extractCodeBlocks(text string) codeBlockMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00CB%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
@@ -491,8 +491,11 @@ func extractInlineCodes(text string) inlineCodeMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type TelegramCommander interface {
|
||||
Help(ctx context.Context, message telego.Message) error
|
||||
Start(ctx context.Context, message telego.Message) error
|
||||
Show(ctx context.Context, message telego.Message) error
|
||||
List(ctx context.Context, message telego.Message) error
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
bot *telego.Bot
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
|
||||
return &cmd{
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func commandArgs(text string) string {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
|
||||
msg := `/start - Start the bot
|
||||
/help - Show this help message
|
||||
/show [model|channel] - Show current configuration
|
||||
/list [models|channels] - List available options
|
||||
`
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: msg,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Hello! I am PicoClaw 🦞",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /show [model|channel]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "model":
|
||||
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
|
||||
c.config.Agents.Defaults.Model,
|
||||
c.config.Agents.Defaults.Provider)
|
||||
case "channel":
|
||||
response = "Current Channel: telegram"
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
func (c *cmd) List(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /list [models|channels]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "models":
|
||||
provider := c.config.Agents.Defaults.Provider
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
|
||||
c.config.Agents.Defaults.Model, provider)
|
||||
|
||||
case "channels":
|
||||
var enabled []string
|
||||
if c.config.Channels.Telegram.Enabled {
|
||||
enabled = append(enabled, "telegram")
|
||||
}
|
||||
if c.config.Channels.WhatsApp.Enabled {
|
||||
enabled = append(enabled, "whatsapp")
|
||||
}
|
||||
if c.config.Channels.Feishu.Enabled {
|
||||
enabled = append(enabled, "feishu")
|
||||
}
|
||||
if c.config.Channels.Discord.Enabled {
|
||||
enabled = append(enabled, "discord")
|
||||
}
|
||||
if c.config.Channels.Slack.Enabled {
|
||||
enabled = append(enabled, "slack")
|
||||
}
|
||||
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
|
||||
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
+228
-48
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
@@ -49,7 +50,12 @@ type Config struct {
|
||||
Providers ProvidersConfig `json:"providers"`
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
mu sync.RWMutex
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
Devices DevicesConfig `json:"devices"`
|
||||
// MCPServers is a compatibility alias for configs using top-level "mcpServers".
|
||||
// Canonical config remains tools.mcp.servers.
|
||||
MCPServers map[string]LegacyMCPServerConfig `json:"mcpServers,omitempty"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type AgentsConfig struct {
|
||||
@@ -57,13 +63,13 @@ type AgentsConfig struct {
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
}
|
||||
|
||||
type ChannelsConfig struct {
|
||||
@@ -75,6 +81,8 @@ type ChannelsConfig struct {
|
||||
QQ QQConfig `json:"qq"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||
Slack SlackConfig `json:"slack"`
|
||||
LINE LINEConfig `json:"line"`
|
||||
OneBot OneBotConfig `json:"onebot"`
|
||||
}
|
||||
|
||||
type WhatsAppConfig struct {
|
||||
@@ -127,29 +135,63 @@ type DingTalkConfig struct {
|
||||
}
|
||||
|
||||
type SlackConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type LINEConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"`
|
||||
ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"`
|
||||
ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"`
|
||||
WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"`
|
||||
WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"`
|
||||
WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type OneBotConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"`
|
||||
WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"`
|
||||
AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"`
|
||||
ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"`
|
||||
GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"`
|
||||
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
|
||||
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
|
||||
}
|
||||
|
||||
type DevicesConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_DEVICES_ENABLED"`
|
||||
MonitorUSB bool `json:"monitor_usb" env:"PICOCLAW_DEVICES_MONITOR_USB"`
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Ollama ProviderConfig `json:"ollama"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
DeepSeek ProviderConfig `json:"deepseek"`
|
||||
GitHubCopilot ProviderConfig `json:"github_copilot"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
|
||||
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||
ConnectMode string `json:"connect_mode,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_CONNECT_MODE"` //only for Github Copilot, `stdio` or `grpc`
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
@@ -157,30 +199,78 @@ type GatewayConfig struct {
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
}
|
||||
|
||||
type WebSearchConfig struct {
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"`
|
||||
type BraveConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type DuckDuckGoConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type PerplexityConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_ENABLED"`
|
||||
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_API_KEY"`
|
||||
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_PERPLEXITY_MAX_RESULTS"`
|
||||
}
|
||||
|
||||
type WebToolsConfig struct {
|
||||
Search WebSearchConfig `json:"search"`
|
||||
Brave BraveConfig `json:"brave"`
|
||||
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
|
||||
Perplexity PerplexityConfig `json:"perplexity"`
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ExecTimeoutMinutes int `json:"exec_timeout_minutes" env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES"` // 0 means no timeout
|
||||
}
|
||||
|
||||
type MCPServerConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
WorkingDir string `json:"working_dir"`
|
||||
Protocol string `json:"protocol"`
|
||||
InitTimeoutSeconds int `json:"init_timeout_seconds"`
|
||||
CallTimeoutSeconds int `json:"call_timeout_seconds"`
|
||||
MaxResponseBytes int `json:"max_response_bytes"`
|
||||
IncludeTools []string `json:"include_tools"`
|
||||
ExcludeTools []string `json:"exclude_tools"`
|
||||
}
|
||||
|
||||
type MCPToolsConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_MCP_ENABLED"`
|
||||
Servers map[string]MCPServerConfig `json:"servers"`
|
||||
}
|
||||
|
||||
// LegacyMCPServerConfig supports compatibility with "mcpServers" style config.
|
||||
type LegacyMCPServerConfig struct {
|
||||
Type string `json:"type"`
|
||||
Command string `json:"command"`
|
||||
Args []string `json:"args"`
|
||||
Env map[string]string `json:"env"`
|
||||
Protocol string `json:"protocol"`
|
||||
}
|
||||
|
||||
type ToolsConfig struct {
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Web WebToolsConfig `json:"web"`
|
||||
Cron CronToolsConfig `json:"cron"`
|
||||
MCP MCPToolsConfig `json:"mcp"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
@@ -229,19 +319,37 @@ func DefaultConfig() *Config {
|
||||
Enabled: false,
|
||||
BotToken: "",
|
||||
AppToken: "",
|
||||
AllowFrom: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
LINE: LINEConfig{
|
||||
Enabled: false,
|
||||
ChannelSecret: "",
|
||||
ChannelAccessToken: "",
|
||||
WebhookHost: "0.0.0.0",
|
||||
WebhookPort: 18791,
|
||||
WebhookPath: "/webhook/line",
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
OneBot: OneBotConfig{
|
||||
Enabled: false,
|
||||
WSUrl: "ws://127.0.0.1:3001",
|
||||
AccessToken: "",
|
||||
ReconnectInterval: 5,
|
||||
GroupTriggerPrefix: []string{},
|
||||
AllowFrom: FlexibleStringSlice{},
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
ShengSuanYun: ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
@@ -249,11 +357,36 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
Web: WebToolsConfig{
|
||||
Search: WebSearchConfig{
|
||||
Brave: BraveConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
DuckDuckGo: DuckDuckGoConfig{
|
||||
Enabled: true,
|
||||
MaxResults: 5,
|
||||
},
|
||||
Perplexity: PerplexityConfig{
|
||||
Enabled: false,
|
||||
APIKey: "",
|
||||
MaxResults: 5,
|
||||
},
|
||||
},
|
||||
Cron: CronToolsConfig{
|
||||
ExecTimeoutMinutes: 5, // default 5 minutes for LLM operations
|
||||
},
|
||||
MCP: MCPToolsConfig{
|
||||
Enabled: false,
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
Devices: DevicesConfig{
|
||||
Enabled: false,
|
||||
MonitorUSB: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -277,9 +410,53 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg.applyLegacyMCPServers()
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) applyLegacyMCPServers() {
|
||||
// If canonical MCP config already exists, keep it as source of truth.
|
||||
if len(c.Tools.MCP.Servers) > 0 {
|
||||
return
|
||||
}
|
||||
if len(c.MCPServers) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Tools.MCP.Servers == nil {
|
||||
c.Tools.MCP.Servers = map[string]MCPServerConfig{}
|
||||
}
|
||||
|
||||
for name, legacy := range c.MCPServers {
|
||||
if strings.TrimSpace(legacy.Command) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if legacy.Type != "" && legacy.Type != "stdio" {
|
||||
enabled = false
|
||||
}
|
||||
|
||||
envCopy := make(map[string]string, len(legacy.Env))
|
||||
for key, value := range legacy.Env {
|
||||
envCopy[key] = value
|
||||
}
|
||||
|
||||
c.Tools.MCP.Servers[name] = MCPServerConfig{
|
||||
Enabled: enabled,
|
||||
Command: legacy.Command,
|
||||
Args: append([]string{}, legacy.Args...),
|
||||
Env: envCopy,
|
||||
Protocol: legacy.Protocol,
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.Tools.MCP.Servers) > 0 {
|
||||
c.Tools.MCP.Enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
cfg.mu.RLock()
|
||||
defer cfg.mu.RUnlock()
|
||||
@@ -294,7 +471,7 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(path, data, 0644)
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func (c *Config) WorkspacePath() string {
|
||||
@@ -327,6 +504,9 @@ func (c *Config) GetAPIKey() string {
|
||||
if c.Providers.VLLM.APIKey != "" {
|
||||
return c.Providers.VLLM.APIKey
|
||||
}
|
||||
if c.Providers.ShengSuanYun.APIKey != "" {
|
||||
return c.Providers.ShengSuanYun.APIKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
|
||||
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if !cfg.Heartbeat.Enabled {
|
||||
t.Error("Heartbeat should be enabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
|
||||
func TestDefaultConfig_WorkspacePath(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Just verify the workspace is set, don't compare exact paths
|
||||
// since expandHome behavior may differ based on environment
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Model verifies model is set
|
||||
func TestDefaultConfig_Model(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.Model == "" {
|
||||
t.Error("Model should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_MaxTokens verifies max tokens has default value
|
||||
func TestDefaultConfig_MaxTokens(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.MaxTokens == 0 {
|
||||
t.Error("MaxTokens should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
|
||||
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.MaxToolIterations == 0 {
|
||||
t.Error("MaxToolIterations should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Temperature verifies temperature has default value
|
||||
func TestDefaultConfig_Temperature(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.Temperature == 0 {
|
||||
t.Error("Temperature should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Gateway verifies gateway defaults
|
||||
func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Gateway.Host != "0.0.0.0" {
|
||||
t.Error("Gateway host should have default value")
|
||||
}
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Providers verifies provider structure
|
||||
func TestDefaultConfig_Providers(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all providers are empty by default
|
||||
if cfg.Providers.Anthropic.APIKey != "" {
|
||||
t.Error("Anthropic API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.OpenAI.APIKey != "" {
|
||||
t.Error("OpenAI API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
t.Error("OpenRouter API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
t.Error("Groq API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
t.Error("Zhipu API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.VLLM.APIKey != "" {
|
||||
t.Error("VLLM API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
t.Error("Gemini API key should be empty by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Channels verifies channels are disabled by default
|
||||
func TestDefaultConfig_Channels(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all channels are disabled by default
|
||||
if cfg.Channels.WhatsApp.Enabled {
|
||||
t.Error("WhatsApp should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Telegram.Enabled {
|
||||
t.Error("Telegram should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Feishu.Enabled {
|
||||
t.Error("Feishu should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Discord.Enabled {
|
||||
t.Error("Discord should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.MaixCam.Enabled {
|
||||
t.Error("MaixCam should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.QQ.Enabled {
|
||||
t.Error("QQ should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.DingTalk.Enabled {
|
||||
t.Error("DingTalk should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Slack.Enabled {
|
||||
t.Error("Slack should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_WebTools verifies web tools config
|
||||
func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify web tools defaults
|
||||
if cfg.Tools.Web.Brave.MaxResults != 5 {
|
||||
t.Error("Expected Brave MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults)
|
||||
}
|
||||
if cfg.Tools.Web.Brave.APIKey != "" {
|
||||
t.Error("Brave API key should be empty by default")
|
||||
}
|
||||
if cfg.Tools.Web.DuckDuckGo.MaxResults != 5 {
|
||||
t.Error("Expected DuckDuckGo MaxResults 5, got ", cfg.Tools.Web.DuckDuckGo.MaxResults)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConfig_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := SaveConfig(path, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig failed: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("config file has permission %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfig_Complete verifies all config fields are set
|
||||
func TestConfig_Complete(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify complete config structure
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
if cfg.Agents.Defaults.Model == "" {
|
||||
t.Error("Model should not be empty")
|
||||
}
|
||||
if cfg.Agents.Defaults.Temperature == 0 {
|
||||
t.Error("Temperature should have default value")
|
||||
}
|
||||
if cfg.Agents.Defaults.MaxTokens == 0 {
|
||||
t.Error("MaxTokens should not be zero")
|
||||
}
|
||||
if cfg.Agents.Defaults.MaxToolIterations == 0 {
|
||||
t.Error("MaxToolIterations should not be zero")
|
||||
}
|
||||
if cfg.Gateway.Host != "0.0.0.0" {
|
||||
t.Error("Gateway host should have default value")
|
||||
}
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
if !cfg.Heartbeat.Enabled {
|
||||
t.Error("Heartbeat should be enabled by default")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
// Package constants provides shared constants across the codebase.
|
||||
package constants
|
||||
|
||||
// InternalChannels defines channels that are used for internal communication
|
||||
// and should not be exposed to external users or recorded as last active channel.
|
||||
var InternalChannels = map[string]bool{
|
||||
"cli": true,
|
||||
"system": true,
|
||||
"subagent": true,
|
||||
}
|
||||
|
||||
// IsInternalChannel returns true if the channel is an internal channel.
|
||||
func IsInternalChannel(channel string) bool {
|
||||
return InternalChannels[channel]
|
||||
}
|
||||
+67
-46
@@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
|
||||
cs := &CronService{
|
||||
storePath: storePath,
|
||||
onJob: onJob,
|
||||
stopChan: make(chan struct{}),
|
||||
gronx: gronx.New(),
|
||||
}
|
||||
// Initialize and load store on creation
|
||||
@@ -96,8 +95,9 @@ func (cs *CronService) Start() error {
|
||||
return fmt.Errorf("failed to save store: %w", err)
|
||||
}
|
||||
|
||||
cs.stopChan = make(chan struct{})
|
||||
cs.running = true
|
||||
go cs.runLoop()
|
||||
go cs.runLoop(cs.stopChan)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -111,16 +111,19 @@ func (cs *CronService) Stop() {
|
||||
}
|
||||
|
||||
cs.running = false
|
||||
close(cs.stopChan)
|
||||
if cs.stopChan != nil {
|
||||
close(cs.stopChan)
|
||||
cs.stopChan = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) runLoop() {
|
||||
func (cs *CronService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cs.stopChan:
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
cs.checkJobs()
|
||||
@@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() {
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
var dueJobs []*CronJob
|
||||
var dueJobIDs []string
|
||||
|
||||
// Collect jobs that are due (we need to copy them to execute outside lock)
|
||||
for i := range cs.store.Jobs {
|
||||
job := &cs.store.Jobs[i]
|
||||
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
|
||||
// Create a shallow copy of the job for execution
|
||||
jobCopy := *job
|
||||
dueJobs = append(dueJobs, &jobCopy)
|
||||
dueJobIDs = append(dueJobIDs, job.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update next run times for due jobs immediately (before executing)
|
||||
// Use map for O(n) lookup instead of O(n²) nested loop
|
||||
dueMap := make(map[string]bool, len(dueJobs))
|
||||
for _, job := range dueJobs {
|
||||
dueMap[job.ID] = true
|
||||
// Reset next run for due jobs before unlocking to avoid duplicate execution.
|
||||
dueMap := make(map[string]bool, len(dueJobIDs))
|
||||
for _, jobID := range dueJobIDs {
|
||||
dueMap[jobID] = true
|
||||
}
|
||||
for i := range cs.store.Jobs {
|
||||
if dueMap[cs.store.Jobs[i].ID] {
|
||||
// Reset NextRunAtMS temporarily so we don't re-execute
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||
}
|
||||
}
|
||||
@@ -168,53 +167,75 @@ func (cs *CronService) checkJobs() {
|
||||
|
||||
cs.mu.Unlock()
|
||||
|
||||
// Execute jobs outside the lock
|
||||
for _, job := range dueJobs {
|
||||
cs.executeJob(job)
|
||||
// Execute jobs outside lock.
|
||||
for _, jobID := range dueJobIDs {
|
||||
cs.executeJobByID(jobID)
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) executeJob(job *CronJob) {
|
||||
func (cs *CronService) executeJobByID(jobID string) {
|
||||
startTime := time.Now().UnixMilli()
|
||||
|
||||
cs.mu.RLock()
|
||||
var callbackJob *CronJob
|
||||
for i := range cs.store.Jobs {
|
||||
job := &cs.store.Jobs[i]
|
||||
if job.ID == jobID {
|
||||
jobCopy := *job
|
||||
callbackJob = &jobCopy
|
||||
break
|
||||
}
|
||||
}
|
||||
cs.mu.RUnlock()
|
||||
|
||||
if callbackJob == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if cs.onJob != nil {
|
||||
_, err = cs.onJob(job)
|
||||
_, err = cs.onJob(callbackJob)
|
||||
}
|
||||
|
||||
// Now acquire lock to update state
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
|
||||
// Find the job in store and update it
|
||||
var job *CronJob
|
||||
for i := range cs.store.Jobs {
|
||||
if cs.store.Jobs[i].ID == job.ID {
|
||||
cs.store.Jobs[i].State.LastRunAtMS = &startTime
|
||||
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
if err != nil {
|
||||
cs.store.Jobs[i].State.LastStatus = "error"
|
||||
cs.store.Jobs[i].State.LastError = err.Error()
|
||||
} else {
|
||||
cs.store.Jobs[i].State.LastStatus = "ok"
|
||||
cs.store.Jobs[i].State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
if cs.store.Jobs[i].Schedule.Kind == "at" {
|
||||
if cs.store.Jobs[i].DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
} else {
|
||||
cs.store.Jobs[i].Enabled = false
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nextRun
|
||||
}
|
||||
if cs.store.Jobs[i].ID == jobID {
|
||||
job = &cs.store.Jobs[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if job == nil {
|
||||
log.Printf("[cron] job %s disappeared before state update", jobID)
|
||||
return
|
||||
}
|
||||
|
||||
job.State.LastRunAtMS = &startTime
|
||||
job.UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
if err != nil {
|
||||
job.State.LastStatus = "error"
|
||||
job.State.LastError = err.Error()
|
||||
} else {
|
||||
job.State.LastStatus = "ok"
|
||||
job.State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
if job.Schedule.Kind == "at" {
|
||||
if job.DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
} else {
|
||||
job.Enabled = false
|
||||
job.State.NextRunAtMS = nil
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
|
||||
job.State.NextRunAtMS = nextRun
|
||||
}
|
||||
|
||||
if err := cs.saveStoreUnsafe(); err != nil {
|
||||
log.Printf("[cron] failed to save store: %v", err)
|
||||
@@ -319,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(cs.storePath, data, 0644)
|
||||
return os.WriteFile(cs.storePath, data, 0600)
|
||||
}
|
||||
|
||||
func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveStore_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
storePath := filepath.Join(tmpDir, "cron", "jobs.json")
|
||||
|
||||
cs := NewCronService(storePath, nil)
|
||||
|
||||
_, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("AddJob failed: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("cron store has permission %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package events
|
||||
|
||||
import "context"
|
||||
|
||||
type EventSource interface {
|
||||
Kind() Kind
|
||||
Start(ctx context.Context) (<-chan *DeviceEvent, error)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
type Action string
|
||||
|
||||
const (
|
||||
ActionAdd Action = "add"
|
||||
ActionRemove Action = "remove"
|
||||
ActionChange Action = "change"
|
||||
)
|
||||
|
||||
type Kind string
|
||||
|
||||
const (
|
||||
KindUSB Kind = "usb"
|
||||
KindBluetooth Kind = "bluetooth"
|
||||
KindPCI Kind = "pci"
|
||||
KindGeneric Kind = "generic"
|
||||
)
|
||||
|
||||
type DeviceEvent struct {
|
||||
Action Action
|
||||
Kind Kind
|
||||
DeviceID string // e.g. "1-2" for USB bus 1 dev 2
|
||||
Vendor string // Vendor name or ID
|
||||
Product string // Product name or ID
|
||||
Serial string // Serial number if available
|
||||
Capabilities string // Human-readable capability description
|
||||
Raw map[string]string // Raw properties for extensibility
|
||||
}
|
||||
|
||||
func (e *DeviceEvent) FormatMessage() string {
|
||||
actionEmoji := "🔌"
|
||||
actionText := "Connected"
|
||||
if e.Action == ActionRemove {
|
||||
actionEmoji = "🔌"
|
||||
actionText = "Disconnected"
|
||||
}
|
||||
|
||||
msg := actionEmoji + " Device " + actionText + "\n\n"
|
||||
msg += "Type: " + string(e.Kind) + "\n"
|
||||
msg += "Device: " + e.Vendor + " " + e.Product + "\n"
|
||||
if e.Capabilities != "" {
|
||||
msg += "Capabilities: " + e.Capabilities + "\n"
|
||||
}
|
||||
if e.Serial != "" {
|
||||
msg += "Serial: " + e.Serial + "\n"
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -0,0 +1,152 @@
|
||||
package devices
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
"github.com/sipeed/picoclaw/pkg/devices/sources"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
bus *bus.MessageBus
|
||||
state *state.Manager
|
||||
sources []events.EventSource
|
||||
enabled bool
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
MonitorUSB bool // When true, monitor USB hotplug (Linux only)
|
||||
// Future: MonitorBluetooth, MonitorPCI, etc.
|
||||
}
|
||||
|
||||
func NewService(cfg Config, stateMgr *state.Manager) *Service {
|
||||
s := &Service{
|
||||
state: stateMgr,
|
||||
enabled: cfg.Enabled,
|
||||
sources: make([]EventSource, 0),
|
||||
}
|
||||
|
||||
if cfg.Enabled && cfg.MonitorUSB {
|
||||
s.sources = append(s.sources, sources.NewUSBMonitor())
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Service) SetBus(msgBus *bus.MessageBus) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.bus = msgBus
|
||||
}
|
||||
|
||||
func (s *Service) Start(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.enabled || len(s.sources) == 0 {
|
||||
logger.InfoC("devices", "Device event service disabled or no sources")
|
||||
return nil
|
||||
}
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
|
||||
for _, src := range s.sources {
|
||||
eventCh, err := src.Start(s.ctx)
|
||||
if err != nil {
|
||||
logger.ErrorCF("devices", "Failed to start source", map[string]interface{}{
|
||||
"kind": src.Kind(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
go s.handleEvents(src.Kind(), eventCh)
|
||||
logger.InfoCF("devices", "Device source started", map[string]interface{}{
|
||||
"kind": src.Kind(),
|
||||
})
|
||||
}
|
||||
|
||||
logger.InfoC("devices", "Device event service started")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Stop() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
s.cancel = nil
|
||||
}
|
||||
|
||||
for _, src := range s.sources {
|
||||
src.Stop()
|
||||
}
|
||||
|
||||
logger.InfoC("devices", "Device event service stopped")
|
||||
}
|
||||
|
||||
func (s *Service) handleEvents(kind events.Kind, eventCh <-chan *events.DeviceEvent) {
|
||||
for ev := range eventCh {
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
s.sendNotification(ev)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) sendNotification(ev *events.DeviceEvent) {
|
||||
s.mu.RLock()
|
||||
msgBus := s.bus
|
||||
s.mu.RUnlock()
|
||||
|
||||
if msgBus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
lastChannel := s.state.GetLastChannel()
|
||||
if lastChannel == "" {
|
||||
logger.DebugCF("devices", "No last channel, skipping notification", map[string]interface{}{
|
||||
"event": ev.FormatMessage(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
platform, userID := parseLastChannel(lastChannel)
|
||||
if platform == "" || userID == "" || constants.IsInternalChannel(platform) {
|
||||
return
|
||||
}
|
||||
|
||||
msg := ev.FormatMessage()
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Content: msg,
|
||||
})
|
||||
|
||||
logger.InfoCF("devices", "Device notification sent", map[string]interface{}{
|
||||
"kind": ev.Kind,
|
||||
"action": ev.Action,
|
||||
"to": platform,
|
||||
})
|
||||
}
|
||||
|
||||
func parseLastChannel(lastChannel string) (platform, userID string) {
|
||||
if lastChannel == "" {
|
||||
return "", ""
|
||||
}
|
||||
parts := strings.SplitN(lastChannel, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", ""
|
||||
}
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package devices
|
||||
|
||||
import "github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
|
||||
type EventSource = events.EventSource
|
||||
@@ -0,0 +1,198 @@
|
||||
//go:build linux
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var usbClassToCapability = map[string]string{
|
||||
"00": "Interface Definition (by interface)",
|
||||
"01": "Audio",
|
||||
"02": "CDC Communication (Network Card/Modem)",
|
||||
"03": "HID (Keyboard/Mouse/Gamepad)",
|
||||
"05": "Physical Interface",
|
||||
"06": "Image (Scanner/Camera)",
|
||||
"07": "Printer",
|
||||
"08": "Mass Storage (USB Flash Drive/Hard Disk)",
|
||||
"09": "USB Hub",
|
||||
"0a": "CDC Data",
|
||||
"0b": "Smart Card",
|
||||
"0e": "Video (Camera)",
|
||||
"dc": "Diagnostic Device",
|
||||
"e0": "Wireless Controller (Bluetooth)",
|
||||
"ef": "Miscellaneous",
|
||||
"fe": "Application Specific",
|
||||
"ff": "Vendor Specific",
|
||||
}
|
||||
|
||||
type USBMonitor struct {
|
||||
cmd *exec.Cmd
|
||||
cancel context.CancelFunc
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewUSBMonitor() *USBMonitor {
|
||||
return &USBMonitor{}
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Kind() events.Kind {
|
||||
return events.KindUSB
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// udevadm monitor outputs: UDEV/KERNEL [timestamp] action devpath (subsystem)
|
||||
// Followed by KEY=value lines, empty line separates events
|
||||
// Use -s/--subsystem-match (eudev) or --udev-subsystem-match (systemd udev)
|
||||
cmd := exec.CommandContext(ctx, "udevadm", "monitor", "--property", "--subsystem-match=usb")
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("udevadm stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("udevadm start: %w (is udevadm installed?)", err)
|
||||
}
|
||||
|
||||
m.cmd = cmd
|
||||
eventCh := make(chan *events.DeviceEvent, 16)
|
||||
|
||||
go func() {
|
||||
defer close(eventCh)
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
var props map[string]string
|
||||
var action string
|
||||
isUdev := false // Only UDEV events have complete info (ID_VENDOR, ID_MODEL); KERNEL events come first with less info
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if line == "" {
|
||||
// End of event block - only process UDEV events (skip KERNEL to avoid duplicate/incomplete notifications)
|
||||
if isUdev && props != nil && (action == "add" || action == "remove") {
|
||||
if ev := parseUSBEvent(action, props); ev != nil {
|
||||
select {
|
||||
case eventCh <- ev:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
props = nil
|
||||
action = ""
|
||||
isUdev = false
|
||||
continue
|
||||
}
|
||||
|
||||
idx := strings.Index(line, "=")
|
||||
// First line of block: "UDEV [ts] action devpath" or "KERNEL[ts] action devpath" - no KEY=value
|
||||
if idx <= 0 {
|
||||
isUdev = strings.HasPrefix(strings.TrimSpace(line), "UDEV")
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse KEY=value
|
||||
key := line[:idx]
|
||||
val := line[idx+1:]
|
||||
if props == nil {
|
||||
props = make(map[string]string)
|
||||
}
|
||||
props[key] = val
|
||||
|
||||
if key == "ACTION" {
|
||||
action = val
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
logger.ErrorCF("devices", "udevadm scan error", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
cmd.Wait()
|
||||
}()
|
||||
|
||||
return eventCh, nil
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Stop() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.cmd != nil && m.cmd.Process != nil {
|
||||
m.cmd.Process.Kill()
|
||||
m.cmd = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseUSBEvent(action string, props map[string]string) *events.DeviceEvent {
|
||||
// Only care about add/remove for physical devices (not interfaces)
|
||||
subsystem := props["SUBSYSTEM"]
|
||||
if subsystem != "usb" {
|
||||
return nil
|
||||
}
|
||||
// Skip interface events - we want device-level only to avoid duplicates
|
||||
devType := props["DEVTYPE"]
|
||||
if devType == "usb_interface" {
|
||||
return nil
|
||||
}
|
||||
// Prefer usb_device, but accept if DEVTYPE not set (varies by udev version)
|
||||
if devType != "" && devType != "usb_device" {
|
||||
return nil
|
||||
}
|
||||
|
||||
ev := &events.DeviceEvent{
|
||||
Raw: props,
|
||||
}
|
||||
switch action {
|
||||
case "add":
|
||||
ev.Action = events.ActionAdd
|
||||
case "remove":
|
||||
ev.Action = events.ActionRemove
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
ev.Kind = events.KindUSB
|
||||
|
||||
ev.Vendor = props["ID_VENDOR"]
|
||||
if ev.Vendor == "" {
|
||||
ev.Vendor = props["ID_VENDOR_ID"]
|
||||
}
|
||||
if ev.Vendor == "" {
|
||||
ev.Vendor = "Unknown Vendor"
|
||||
}
|
||||
|
||||
ev.Product = props["ID_MODEL"]
|
||||
if ev.Product == "" {
|
||||
ev.Product = props["ID_MODEL_ID"]
|
||||
}
|
||||
if ev.Product == "" {
|
||||
ev.Product = "Unknown Device"
|
||||
}
|
||||
|
||||
ev.Serial = props["ID_SERIAL_SHORT"]
|
||||
ev.DeviceID = props["DEVPATH"]
|
||||
if bus := props["BUSNUM"]; bus != "" {
|
||||
if dev := props["DEVNUM"]; dev != "" {
|
||||
ev.DeviceID = bus + ":" + dev
|
||||
}
|
||||
}
|
||||
|
||||
// Map USB class to capability
|
||||
if class := props["ID_USB_CLASS"]; class != "" {
|
||||
ev.Capabilities = usbClassToCapability[strings.ToLower(class)]
|
||||
}
|
||||
if ev.Capabilities == "" {
|
||||
ev.Capabilities = "USB Device"
|
||||
}
|
||||
|
||||
return ev
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
//go:build !linux
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/devices/events"
|
||||
)
|
||||
|
||||
type USBMonitor struct{}
|
||||
|
||||
func NewUSBMonitor() *USBMonitor {
|
||||
return &USBMonitor{}
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Kind() events.Kind {
|
||||
return events.KindUSB
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Start(ctx context.Context) (<-chan *events.DeviceEvent, error) {
|
||||
ch := make(chan *events.DeviceEvent)
|
||||
close(ch) // Immediately close, no events
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (m *USBMonitor) Stop() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
server *http.Server
|
||||
mu sync.RWMutex
|
||||
ready bool
|
||||
checks map[string]Check
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
type Check struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
mux := http.NewServeMux()
|
||||
s := &Server{
|
||||
ready: false,
|
||||
checks: make(map[string]Check),
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
mux.HandleFunc("/health", s.healthHandler)
|
||||
mux.HandleFunc("/ready", s.readyHandler)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
s.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
s.ready = true
|
||||
s.mu.Unlock()
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
func (s *Server) StartContext(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
s.ready = true
|
||||
s.mu.Unlock()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.server.ListenAndServe()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return s.server.Shutdown(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
s.ready = false
|
||||
s.mu.Unlock()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *Server) SetReady(ready bool) {
|
||||
s.mu.Lock()
|
||||
s.ready = ready
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, msg := checkFn()
|
||||
s.checks[name] = Check{
|
||||
Name: name,
|
||||
Status: statusString(status),
|
||||
Message: msg,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
uptime := time.Since(s.startTime)
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
s.mu.RLock()
|
||||
ready := s.ready
|
||||
checks := make(map[string]Check)
|
||||
for k, v := range s.checks {
|
||||
checks[k] = v
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ready {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(StatusResponse{
|
||||
Status: "not ready",
|
||||
Checks: checks,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
for _, check := range checks {
|
||||
if check.Status == "fail" {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
json.NewEncoder(w).Encode(StatusResponse{
|
||||
Status: "not ready",
|
||||
Checks: checks,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
uptime := time.Since(s.startTime)
|
||||
json.NewEncoder(w).Encode(StatusResponse{
|
||||
Status: "ready",
|
||||
Uptime: uptime.String(),
|
||||
Checks: checks,
|
||||
})
|
||||
}
|
||||
|
||||
func statusString(ok bool) string {
|
||||
if ok {
|
||||
return "ok"
|
||||
}
|
||||
return "fail"
|
||||
}
|
||||
+282
-54
@@ -1,131 +1,359 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package heartbeat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
minIntervalMinutes = 5
|
||||
defaultIntervalMinutes = 30
|
||||
)
|
||||
|
||||
// HeartbeatHandler is the function type for handling heartbeat.
|
||||
// It returns a ToolResult that can indicate async operations.
|
||||
// channel and chatID are derived from the last active user channel.
|
||||
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
|
||||
|
||||
// HeartbeatService manages periodic heartbeat checks
|
||||
type HeartbeatService struct {
|
||||
workspace string
|
||||
onHeartbeat func(string) (string, error)
|
||||
interval time.Duration
|
||||
enabled bool
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
stopChan chan struct{}
|
||||
workspace string
|
||||
bus *bus.MessageBus
|
||||
state *state.Manager
|
||||
handler HeartbeatHandler
|
||||
interval time.Duration
|
||||
enabled bool
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService {
|
||||
// NewHeartbeatService creates a new heartbeat service
|
||||
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
|
||||
// Apply minimum interval
|
||||
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
|
||||
intervalMinutes = minIntervalMinutes
|
||||
}
|
||||
|
||||
if intervalMinutes == 0 {
|
||||
intervalMinutes = defaultIntervalMinutes
|
||||
}
|
||||
|
||||
return &HeartbeatService{
|
||||
workspace: workspace,
|
||||
onHeartbeat: onHeartbeat,
|
||||
interval: time.Duration(intervalS) * time.Second,
|
||||
enabled: enabled,
|
||||
stopChan: make(chan struct{}),
|
||||
workspace: workspace,
|
||||
interval: time.Duration(intervalMinutes) * time.Minute,
|
||||
enabled: enabled,
|
||||
state: state.NewManager(workspace),
|
||||
}
|
||||
}
|
||||
|
||||
// SetBus sets the message bus for delivering heartbeat results.
|
||||
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
hs.bus = msgBus
|
||||
}
|
||||
|
||||
// SetHandler sets the heartbeat handler.
|
||||
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
hs.handler = handler
|
||||
}
|
||||
|
||||
// Start begins the heartbeat service
|
||||
func (hs *HeartbeatService) Start() error {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if hs.started {
|
||||
if hs.stopChan != nil {
|
||||
logger.InfoC("heartbeat", "Heartbeat service already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !hs.enabled {
|
||||
return fmt.Errorf("heartbeat service is disabled")
|
||||
logger.InfoC("heartbeat", "Heartbeat service disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
hs.started = true
|
||||
go hs.runLoop()
|
||||
hs.stopChan = make(chan struct{})
|
||||
go hs.runLoop(hs.stopChan)
|
||||
|
||||
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
|
||||
"interval_minutes": hs.interval.Minutes(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the heartbeat service
|
||||
func (hs *HeartbeatService) Stop() {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if !hs.started {
|
||||
if hs.stopChan == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hs.started = false
|
||||
logger.InfoC("heartbeat", "Stopping heartbeat service")
|
||||
close(hs.stopChan)
|
||||
hs.stopChan = nil
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) running() bool {
|
||||
select {
|
||||
case <-hs.stopChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
// IsRunning returns whether the service is running
|
||||
func (hs *HeartbeatService) IsRunning() bool {
|
||||
hs.mu.RLock()
|
||||
defer hs.mu.RUnlock()
|
||||
return hs.stopChan != nil
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) runLoop() {
|
||||
// runLoop runs the heartbeat ticker
|
||||
func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(hs.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run first heartbeat after initial delay
|
||||
time.AfterFunc(time.Second, func() {
|
||||
hs.executeHeartbeat()
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hs.stopChan:
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hs.checkHeartbeat()
|
||||
hs.executeHeartbeat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) checkHeartbeat() {
|
||||
// executeHeartbeat performs a single heartbeat check
|
||||
func (hs *HeartbeatService) executeHeartbeat() {
|
||||
hs.mu.RLock()
|
||||
if !hs.enabled || !hs.running() {
|
||||
enabled := hs.enabled
|
||||
handler := hs.handler
|
||||
if !hs.enabled || hs.stopChan == nil {
|
||||
hs.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
hs.mu.RUnlock()
|
||||
|
||||
prompt := hs.buildPrompt()
|
||||
|
||||
if hs.onHeartbeat != nil {
|
||||
_, err := hs.onHeartbeat(prompt)
|
||||
if err != nil {
|
||||
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugC("heartbeat", "Executing heartbeat")
|
||||
|
||||
prompt := hs.buildPrompt()
|
||||
if prompt == "" {
|
||||
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
|
||||
return
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
hs.logError("Heartbeat handler not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// Get last channel info for context
|
||||
lastChannel := hs.state.GetLastChannel()
|
||||
channel, chatID := hs.parseLastChannel(lastChannel)
|
||||
|
||||
// Debug log for channel resolution
|
||||
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
|
||||
|
||||
result := handler(prompt, channel, chatID)
|
||||
|
||||
if result == nil {
|
||||
hs.logInfo("Heartbeat handler returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle different result types
|
||||
if result.IsError {
|
||||
hs.logError("Heartbeat error: %s", result.ForLLM)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Async {
|
||||
hs.logInfo("Async task started: %s", result.ForLLM)
|
||||
logger.InfoCF("heartbeat", "Async heartbeat task started",
|
||||
map[string]interface{}{
|
||||
"message": result.ForLLM,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if silent
|
||||
if result.Silent {
|
||||
hs.logInfo("Heartbeat OK - silent")
|
||||
return
|
||||
}
|
||||
|
||||
// Send result to user
|
||||
if result.ForUser != "" {
|
||||
hs.sendResponse(result.ForUser)
|
||||
} else if result.ForLLM != "" {
|
||||
hs.sendResponse(result.ForLLM)
|
||||
}
|
||||
|
||||
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
|
||||
func (hs *HeartbeatService) buildPrompt() string {
|
||||
notesDir := filepath.Join(hs.workspace, "memory")
|
||||
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
|
||||
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
|
||||
|
||||
var notes string
|
||||
if data, err := os.ReadFile(notesFile); err == nil {
|
||||
notes = string(data)
|
||||
data, err := os.ReadFile(heartbeatPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
hs.createDefaultHeartbeatTemplate()
|
||||
return ""
|
||||
}
|
||||
hs.logError("Error reading HEARTBEAT.md: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02 15:04")
|
||||
content := string(data)
|
||||
if len(content) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`# Heartbeat Check
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
return fmt.Sprintf(`# Heartbeat Check
|
||||
|
||||
Current time: %s
|
||||
|
||||
Check if there are any tasks I should be aware of or actions I should take.
|
||||
Review the memory file for any important updates or changes.
|
||||
Be proactive in identifying potential issues or improvements.
|
||||
You are a proactive AI assistant. This is a scheduled heartbeat check.
|
||||
Review the following tasks and execute any necessary actions using available skills.
|
||||
If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
|
||||
|
||||
%s
|
||||
`, now, notes)
|
||||
|
||||
return prompt
|
||||
`, now, content)
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) log(message string) {
|
||||
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log")
|
||||
// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
|
||||
func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
|
||||
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
|
||||
|
||||
defaultContent := `# Heartbeat Check List
|
||||
|
||||
This file contains tasks for the heartbeat service to check periodically.
|
||||
|
||||
## Examples
|
||||
|
||||
- Check for unread messages
|
||||
- Review upcoming calendar events
|
||||
- Check device status (e.g., MaixCam)
|
||||
|
||||
## Instructions
|
||||
|
||||
- Execute ALL tasks listed below. Do NOT skip any task.
|
||||
- For simple tasks (e.g., report current time), respond directly.
|
||||
- For complex tasks that may take time, use the spawn tool to create a subagent.
|
||||
- The spawn tool is async - subagent results will be sent to the user automatically.
|
||||
- After spawning a subagent, CONTINUE to process remaining tasks.
|
||||
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
|
||||
|
||||
---
|
||||
|
||||
Add your heartbeat tasks below this line:
|
||||
`
|
||||
|
||||
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
|
||||
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
|
||||
} else {
|
||||
hs.logInfo("Created default HEARTBEAT.md template")
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends the heartbeat response to the last channel
|
||||
func (hs *HeartbeatService) sendResponse(response string) {
|
||||
hs.mu.RLock()
|
||||
msgBus := hs.bus
|
||||
hs.mu.RUnlock()
|
||||
|
||||
if msgBus == nil {
|
||||
hs.logInfo("No message bus configured, heartbeat result not sent")
|
||||
return
|
||||
}
|
||||
|
||||
// Get last channel from state
|
||||
lastChannel := hs.state.GetLastChannel()
|
||||
if lastChannel == "" {
|
||||
hs.logInfo("No last channel recorded, heartbeat result not sent")
|
||||
return
|
||||
}
|
||||
|
||||
platform, userID := hs.parseLastChannel(lastChannel)
|
||||
|
||||
// Skip internal channels that can't receive messages
|
||||
if platform == "" || userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Content: response,
|
||||
})
|
||||
|
||||
hs.logInfo("Heartbeat result sent to %s", platform)
|
||||
}
|
||||
|
||||
// parseLastChannel parses the last channel string into platform and userID.
|
||||
// Returns empty strings for invalid or internal channels.
|
||||
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
|
||||
if lastChannel == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
|
||||
parts := strings.SplitN(lastChannel, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
hs.logError("Invalid last channel format: %s", lastChannel)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
platform, userID = parts[0], parts[1]
|
||||
|
||||
// Skip internal channels
|
||||
if constants.IsInternalChannel(platform) {
|
||||
hs.logInfo("Skipping internal channel: %s", platform)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
return platform, userID
|
||||
}
|
||||
|
||||
// logInfo logs an informational message to the heartbeat log
|
||||
func (hs *HeartbeatService) logInfo(format string, args ...any) {
|
||||
hs.log("INFO", format, args...)
|
||||
}
|
||||
|
||||
// logError logs an error message to the heartbeat log
|
||||
func (hs *HeartbeatService) logError(format string, args ...any) {
|
||||
hs.log("ERROR", format, args...)
|
||||
}
|
||||
|
||||
// log writes a message to the heartbeat log file
|
||||
func (hs *HeartbeatService) log(level, format string, args ...any) {
|
||||
logFile := filepath.Join(hs.workspace, "heartbeat.log")
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -133,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
|
||||
defer f.Close()
|
||||
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05")
|
||||
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message))
|
||||
fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,221 @@
|
||||
package heartbeat
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
asyncCalled := false
|
||||
asyncResult := &tools.ToolResult{
|
||||
ForLLM: "Background task started",
|
||||
ForUser: "Task started in background",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: true,
|
||||
}
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
asyncCalled = true
|
||||
if prompt == "" {
|
||||
t.Error("Expected non-empty prompt")
|
||||
}
|
||||
return asyncResult
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
// Execute heartbeat directly (internal method for testing)
|
||||
hs.executeHeartbeat()
|
||||
|
||||
if !asyncCalled {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for error message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for completion message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain completion message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatService_StartStop(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 1, true)
|
||||
|
||||
err = hs.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start heartbeat service: %v", err)
|
||||
}
|
||||
|
||||
hs.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestHeartbeatService_Disabled(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 1, false)
|
||||
|
||||
if hs.enabled != false {
|
||||
t.Error("Expected service to be disabled")
|
||||
}
|
||||
|
||||
err = hs.Start()
|
||||
_ = err // Disabled service returns nil
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_NilResult(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
// Should not panic with nil result
|
||||
hs.executeHeartbeat()
|
||||
}
|
||||
|
||||
// TestLogPath verifies heartbeat log is written to workspace directory
|
||||
func TestLogPath(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
|
||||
// Write a log entry
|
||||
hs.log("INFO", "Test log entry")
|
||||
|
||||
// Verify log file exists at workspace root
|
||||
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
|
||||
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
|
||||
func TestHeartbeatFilePath(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
|
||||
// Trigger default template creation
|
||||
hs.buildPrompt()
|
||||
|
||||
// Verify HEARTBEAT.md exists at workspace root
|
||||
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
|
||||
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,603 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// Client is the transport-agnostic MCP client contract.
|
||||
type Client interface {
|
||||
Start(ctx context.Context) error
|
||||
ListTools(ctx context.Context) ([]RemoteTool, error)
|
||||
CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// StdioClient speaks MCP over stdio (JSON-RPC framed with Content-Length headers).
|
||||
type StdioClient struct {
|
||||
config ServerConfig
|
||||
mode string
|
||||
|
||||
mu sync.Mutex
|
||||
writeMu sync.Mutex
|
||||
|
||||
started bool
|
||||
closed bool
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout io.ReadCloser
|
||||
stderr io.ReadCloser
|
||||
waitCh chan struct{}
|
||||
pending map[string]chan rpcResponse
|
||||
|
||||
nextID uint64
|
||||
}
|
||||
|
||||
type rpcRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params any `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type rpcResponseEnvelope struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID json.RawMessage `json:"id,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *rpcError `json:"error,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
}
|
||||
|
||||
type rpcError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type rpcResponse struct {
|
||||
result json.RawMessage
|
||||
rpcErr *rpcError
|
||||
err error
|
||||
}
|
||||
|
||||
type initializeParams struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]any `json:"capabilities"`
|
||||
ClientInfo map[string]interface{} `json:"clientInfo"`
|
||||
}
|
||||
|
||||
func NewStdioClient(config ServerConfig) *StdioClient {
|
||||
return &StdioClient{
|
||||
config: config,
|
||||
mode: normalizeProtocol(config.Protocol),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) Start(ctx context.Context) error {
|
||||
c.mu.Lock()
|
||||
if c.started {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(c.config.Command) == "" {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("mcp server %q command is empty", c.config.Name)
|
||||
}
|
||||
|
||||
cmd := exec.Command(c.config.Command, c.config.Args...)
|
||||
if c.config.WorkingDir != "" {
|
||||
cmd.Dir = c.config.WorkingDir
|
||||
}
|
||||
cmd.Env = buildProcessEnv(c.config.Env)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("start process: %w", err)
|
||||
}
|
||||
|
||||
c.started = true
|
||||
c.closed = false
|
||||
c.cmd = cmd
|
||||
c.stdin = stdin
|
||||
c.stdout = stdout
|
||||
c.stderr = stderr
|
||||
c.waitCh = make(chan struct{})
|
||||
c.pending = make(map[string]chan rpcResponse)
|
||||
c.mu.Unlock()
|
||||
|
||||
go c.readLoop()
|
||||
go c.waitLoop()
|
||||
go c.drainStderr()
|
||||
|
||||
initCtx, cancel := withTimeoutIfMissing(ctx, c.config.InitTimeout())
|
||||
defer cancel()
|
||||
|
||||
_, err = c.request(initCtx, "initialize", initializeParams{
|
||||
ProtocolVersion: "2024-11-05",
|
||||
Capabilities: map[string]any{
|
||||
"tools": map[string]any{},
|
||||
},
|
||||
ClientInfo: map[string]any{
|
||||
"name": "picoclaw",
|
||||
"version": "dev",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialize failed: %w", err)
|
||||
}
|
||||
|
||||
if err := c.notify("notifications/initialized", map[string]any{}); err != nil {
|
||||
_ = c.Close()
|
||||
return fmt.Errorf("initialized notification failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) ListTools(ctx context.Context) ([]RemoteTool, error) {
|
||||
if err := c.Start(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
type listToolsResponse struct {
|
||||
Tools []struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]any `json:"inputSchema"`
|
||||
} `json:"tools"`
|
||||
NextCursor string `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
allTools := make([]RemoteTool, 0, 8)
|
||||
cursor := ""
|
||||
|
||||
for page := 0; page < maxToolListPages; page++ {
|
||||
params := map[string]any{}
|
||||
if cursor != "" {
|
||||
params["cursor"] = cursor
|
||||
}
|
||||
|
||||
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||
raw, err := c.request(callCtx, "tools/list", params)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var response listToolsResponse
|
||||
if err := json.Unmarshal(raw, &response); err != nil {
|
||||
return nil, fmt.Errorf("decode tools/list response: %w", err)
|
||||
}
|
||||
|
||||
for _, tool := range response.Tools {
|
||||
allTools = append(allTools, RemoteTool{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: tool.InputSchema,
|
||||
})
|
||||
}
|
||||
|
||||
if response.NextCursor == "" {
|
||||
return allTools, nil
|
||||
}
|
||||
cursor = response.NextCursor
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tools/list exceeded %d pages", maxToolListPages)
|
||||
}
|
||||
|
||||
func (c *StdioClient) CallTool(ctx context.Context, toolName string, arguments map[string]any) (CallResult, error) {
|
||||
if err := c.Start(ctx); err != nil {
|
||||
return CallResult{}, err
|
||||
}
|
||||
|
||||
callCtx, cancel := withTimeoutIfMissing(ctx, c.config.CallTimeout())
|
||||
defer cancel()
|
||||
|
||||
raw, err := c.request(callCtx, "tools/call", map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": arguments,
|
||||
})
|
||||
if err != nil {
|
||||
return CallResult{}, err
|
||||
}
|
||||
|
||||
return formatCallPayload(raw, c.config.ResponseLimit())
|
||||
}
|
||||
|
||||
func (c *StdioClient) Close() error {
|
||||
c.mu.Lock()
|
||||
if !c.started || c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.closed = true
|
||||
cmd := c.cmd
|
||||
stdin := c.stdin
|
||||
waitCh := c.waitCh
|
||||
c.mu.Unlock()
|
||||
|
||||
c.failPending(errors.New("mcp client closed"))
|
||||
|
||||
if stdin != nil {
|
||||
_ = stdin.Close()
|
||||
}
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
|
||||
if waitCh != nil {
|
||||
select {
|
||||
case <-waitCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) request(ctx context.Context, method string, params any) (json.RawMessage, error) {
|
||||
id := strconv.FormatUint(atomic.AddUint64(&c.nextID, 1), 10)
|
||||
responseCh := make(chan rpcResponse, 1)
|
||||
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("mcp server %q is closed", c.config.Name)
|
||||
}
|
||||
c.pending[id] = responseCh
|
||||
c.mu.Unlock()
|
||||
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
if err := c.writeMessage(req); err != nil {
|
||||
c.removePending(id)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.removePending(id)
|
||||
return nil, ctx.Err()
|
||||
case response := <-responseCh:
|
||||
if response.err != nil {
|
||||
return nil, response.err
|
||||
}
|
||||
if response.rpcErr != nil {
|
||||
return nil, fmt.Errorf("mcp error %d: %s", response.rpcErr.Code, response.rpcErr.Message)
|
||||
}
|
||||
return response.result, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) notify(method string, params any) error {
|
||||
req := rpcRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
return c.writeMessage(req)
|
||||
}
|
||||
|
||||
func (c *StdioClient) writeMessage(payload any) error {
|
||||
c.mu.Lock()
|
||||
if c.closed || c.stdin == nil {
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("mcp server %q is not writable", c.config.Name)
|
||||
}
|
||||
stdin := c.stdin
|
||||
c.mu.Unlock()
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal json-rpc payload: %w", err)
|
||||
}
|
||||
|
||||
if c.mode == ProtocolJSONLines {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := stdin.Write(append(data, '\n')); err != nil {
|
||||
return fmt.Errorf("write jsonl body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
frameHeader := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data))
|
||||
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
if _, err := io.WriteString(stdin, frameHeader); err != nil {
|
||||
return fmt.Errorf("write frame header: %w", err)
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write frame body: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *StdioClient) readLoop() {
|
||||
if c.mode == ProtocolJSONLines {
|
||||
c.readJSONLLoop()
|
||||
return
|
||||
}
|
||||
|
||||
c.readMCPFrameLoop()
|
||||
}
|
||||
|
||||
func (c *StdioClient) readMCPFrameLoop() {
|
||||
reader := bufio.NewReader(c.stdout)
|
||||
|
||||
for {
|
||||
payload, err := readFramePayload(reader)
|
||||
if err != nil {
|
||||
c.failPending(err)
|
||||
return
|
||||
}
|
||||
|
||||
var envelope rpcResponseEnvelope
|
||||
if err := json.Unmarshal(payload, &envelope); err != nil {
|
||||
continue
|
||||
}
|
||||
c.dispatchResponse(envelope)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) readJSONLLoop() {
|
||||
scanner := bufio.NewScanner(c.stdout)
|
||||
scanner.Buffer(make([]byte, 0, defaultScannerBufferBytes), maxFrameBytes)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var envelope rpcResponseEnvelope
|
||||
if err := json.Unmarshal([]byte(line), &envelope); err != nil {
|
||||
continue
|
||||
}
|
||||
c.dispatchResponse(envelope)
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
c.failPending(err)
|
||||
return
|
||||
}
|
||||
c.failPending(io.EOF)
|
||||
}
|
||||
|
||||
func (c *StdioClient) dispatchResponse(envelope rpcResponseEnvelope) {
|
||||
if len(envelope.ID) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
id, ok := parseRPCID(envelope.ID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
responseCh := c.pending[id]
|
||||
if responseCh != nil {
|
||||
delete(c.pending, id)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if responseCh == nil {
|
||||
return
|
||||
}
|
||||
|
||||
response := rpcResponse{
|
||||
result: envelope.Result,
|
||||
rpcErr: envelope.Error,
|
||||
}
|
||||
select {
|
||||
case responseCh <- response:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) waitLoop() {
|
||||
c.mu.Lock()
|
||||
cmd := c.cmd
|
||||
waitCh := c.waitCh
|
||||
serverName := c.config.Name
|
||||
c.mu.Unlock()
|
||||
|
||||
if cmd == nil {
|
||||
if waitCh != nil {
|
||||
close(waitCh)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err := cmd.Wait()
|
||||
if waitCh != nil {
|
||||
close(waitCh)
|
||||
}
|
||||
if err != nil {
|
||||
logger.WarnCF("mcp", "MCP process exited with error",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) drainStderr() {
|
||||
c.mu.Lock()
|
||||
stderr := c.stderr
|
||||
serverName := c.config.Name
|
||||
c.mu.Unlock()
|
||||
|
||||
if stderr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
logger.DebugCF("mcp", "MCP server stderr",
|
||||
map[string]any{
|
||||
"server": serverName,
|
||||
"line": line,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) failPending(err error) {
|
||||
c.mu.Lock()
|
||||
pending := c.pending
|
||||
c.pending = make(map[string]chan rpcResponse)
|
||||
c.mu.Unlock()
|
||||
|
||||
if len(pending) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range pending {
|
||||
select {
|
||||
case ch <- rpcResponse{err: err}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *StdioClient) removePending(id string) {
|
||||
c.mu.Lock()
|
||||
delete(c.pending, id)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func readFramePayload(reader *bufio.Reader) ([]byte, error) {
|
||||
contentLength := -1
|
||||
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if trimmed == "" {
|
||||
break
|
||||
}
|
||||
|
||||
parts := strings.SplitN(trimmed, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
headerName := strings.TrimSpace(strings.ToLower(parts[0]))
|
||||
if headerName != "content-length" {
|
||||
continue
|
||||
}
|
||||
value := strings.TrimSpace(parts[1])
|
||||
length, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid content-length %q: %w", value, err)
|
||||
}
|
||||
contentLength = length
|
||||
}
|
||||
|
||||
if contentLength <= 0 {
|
||||
return nil, fmt.Errorf("missing content-length")
|
||||
}
|
||||
if contentLength > maxFrameBytes {
|
||||
return nil, fmt.Errorf("frame too large (%d bytes)", contentLength)
|
||||
}
|
||||
|
||||
payload := make([]byte, contentLength)
|
||||
if _, err := io.ReadFull(reader, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func parseRPCID(raw json.RawMessage) (string, bool) {
|
||||
var stringID string
|
||||
if err := json.Unmarshal(raw, &stringID); err == nil {
|
||||
return stringID, true
|
||||
}
|
||||
|
||||
var numberID float64
|
||||
if err := json.Unmarshal(raw, &numberID); err == nil {
|
||||
return strconv.FormatInt(int64(numberID), 10), true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
func withTimeoutIfMissing(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) {
|
||||
if _, hasDeadline := parent.Deadline(); hasDeadline {
|
||||
return context.WithCancel(parent)
|
||||
}
|
||||
return context.WithTimeout(parent, timeout)
|
||||
}
|
||||
|
||||
func buildProcessEnv(custom map[string]string) []string {
|
||||
base := os.Environ()
|
||||
if len(custom) == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(custom))
|
||||
for key := range custom {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
env := make([]string, 0, len(base)+len(keys))
|
||||
env = append(env, base...)
|
||||
for _, key := range keys {
|
||||
env = append(env, key+"="+custom[key])
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func normalizeProtocol(protocol string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(protocol)) {
|
||||
case "", ProtocolMCPFrames:
|
||||
return ProtocolMCPFrames
|
||||
case ProtocolJSONLines:
|
||||
return ProtocolJSONLines
|
||||
default:
|
||||
return ProtocolMCPFrames
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type callResponse struct {
|
||||
Content []contentBlock `json:"content"`
|
||||
StructuredContent any `json:"structuredContent,omitempty"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
type contentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
func formatCallPayload(raw json.RawMessage, responseLimit int) (CallResult, error) {
|
||||
var payload callResponse
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
// Fallback for servers that return non-standard payloads.
|
||||
return CallResult{
|
||||
Content: truncateString(strings.TrimSpace(string(raw)), responseLimit),
|
||||
IsError: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(payload.Content)+1)
|
||||
for _, block := range payload.Content {
|
||||
if block.Type == "text" && strings.TrimSpace(block.Text) != "" {
|
||||
parts = append(parts, block.Text)
|
||||
}
|
||||
}
|
||||
|
||||
if payload.StructuredContent != nil {
|
||||
if encoded, err := json.Marshal(payload.StructuredContent); err == nil {
|
||||
parts = append(parts, string(encoded))
|
||||
}
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(strings.Join(parts, "\n"))
|
||||
if content == "" {
|
||||
content = "{}"
|
||||
}
|
||||
|
||||
return CallResult{
|
||||
Content: truncateString(content, responseLimit),
|
||||
IsError: payload.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func truncateString(value string, maxBytes int) string {
|
||||
if maxBytes <= 0 || len(value) <= maxBytes {
|
||||
return value
|
||||
}
|
||||
if maxBytes <= 12 {
|
||||
return value[:maxBytes]
|
||||
}
|
||||
return value[:maxBytes-12] + "\n...[truncated]"
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type clientFactory func(config ServerConfig) Client
|
||||
|
||||
type managedServer struct {
|
||||
config ServerConfig
|
||||
client Client
|
||||
}
|
||||
|
||||
// Manager owns MCP servers and maps discovered MCP tools to PicoClaw tools.
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
servers map[string]*managedServer
|
||||
tools map[string]RegisteredTool
|
||||
|
||||
discovered bool
|
||||
newClient clientFactory
|
||||
}
|
||||
|
||||
func NewManager(configs map[string]ServerConfig) *Manager {
|
||||
servers := make(map[string]*managedServer, len(configs))
|
||||
for name, cfg := range configs {
|
||||
copied := cfg
|
||||
copied.Name = name
|
||||
servers[name] = &managedServer{config: copied}
|
||||
}
|
||||
return &Manager{
|
||||
servers: servers,
|
||||
tools: make(map[string]RegisteredTool),
|
||||
discovered: false,
|
||||
newClient: func(config ServerConfig) Client {
|
||||
return NewStdioClient(config)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DiscoverTools starts configured MCP servers and returns discovered tool metadata.
|
||||
func (m *Manager) DiscoverTools(ctx context.Context) ([]RegisteredTool, error) {
|
||||
m.mu.Lock()
|
||||
if m.discovered {
|
||||
tools := toolsFromMap(m.tools)
|
||||
m.mu.Unlock()
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
discoveryErrors := make([]string, 0)
|
||||
|
||||
for serverName, server := range m.servers {
|
||||
client := m.newClient(server.config)
|
||||
if err := client.Start(ctx); err != nil {
|
||||
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
remoteTools, err := client.ListTools(ctx)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
discoveryErrors = append(discoveryErrors, fmt.Sprintf("%s: %v", serverName, err))
|
||||
continue
|
||||
}
|
||||
|
||||
server.client = client
|
||||
for _, remoteTool := range remoteTools {
|
||||
if !isToolAllowed(remoteTool.Name, server.config.IncludeTools, server.config.ExcludeTools) {
|
||||
continue
|
||||
}
|
||||
|
||||
qualifiedName := m.makeUniqueToolName(serverName, remoteTool.Name)
|
||||
parameters := normalizeSchema(remoteTool.InputSchema)
|
||||
m.tools[qualifiedName] = RegisteredTool{
|
||||
QualifiedName: qualifiedName,
|
||||
ServerName: serverName,
|
||||
ToolName: remoteTool.Name,
|
||||
Description: remoteTool.Description,
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.discovered = true
|
||||
tools := toolsFromMap(m.tools)
|
||||
m.mu.Unlock()
|
||||
|
||||
if len(tools) == 0 && len(discoveryErrors) > 0 {
|
||||
return nil, fmt.Errorf("mcp tool discovery failed: %s", strings.Join(discoveryErrors, "; "))
|
||||
}
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
func (m *Manager) CallTool(ctx context.Context, qualifiedName string, args map[string]any) (CallResult, error) {
|
||||
m.mu.RLock()
|
||||
tool, ok := m.tools[qualifiedName]
|
||||
if !ok {
|
||||
m.mu.RUnlock()
|
||||
return CallResult{}, fmt.Errorf("mcp tool %q not found", qualifiedName)
|
||||
}
|
||||
|
||||
server := m.servers[tool.ServerName]
|
||||
if server == nil || server.client == nil {
|
||||
m.mu.RUnlock()
|
||||
return CallResult{}, fmt.Errorf("mcp server %q is not active", tool.ServerName)
|
||||
}
|
||||
client := server.client
|
||||
toolName := tool.ToolName
|
||||
m.mu.RUnlock()
|
||||
|
||||
if args == nil {
|
||||
args = map[string]any{}
|
||||
}
|
||||
return client.CallTool(ctx, toolName, args)
|
||||
}
|
||||
|
||||
func (m *Manager) Close() error {
|
||||
m.mu.Lock()
|
||||
servers := make([]*managedServer, 0, len(m.servers))
|
||||
for _, server := range m.servers {
|
||||
servers = append(servers, server)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
var firstErr error
|
||||
for _, server := range servers {
|
||||
if server.client == nil {
|
||||
continue
|
||||
}
|
||||
if err := server.client.Close(); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (m *Manager) makeUniqueToolName(serverName, toolName string) string {
|
||||
base := QualifiedToolName(serverName, toolName)
|
||||
if _, exists := m.tools[base]; !exists {
|
||||
return base
|
||||
}
|
||||
|
||||
for index := 2; ; index++ {
|
||||
candidate := fmt.Sprintf("%s_%d", base, index)
|
||||
if len(candidate) > qualifiedNameMaxLen {
|
||||
overflow := len(candidate) - qualifiedNameMaxLen
|
||||
if overflow < len(base) {
|
||||
candidate = base[:len(base)-overflow] + fmt.Sprintf("_%d", index)
|
||||
} else {
|
||||
candidate = candidate[:qualifiedNameMaxLen]
|
||||
}
|
||||
}
|
||||
if _, exists := m.tools[candidate]; !exists {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeSchema(schema map[string]any) map[string]any {
|
||||
if len(schema) == 0 {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func isToolAllowed(name string, include, exclude []string) bool {
|
||||
if len(include) > 0 && !slices.Contains(include, name) {
|
||||
return false
|
||||
}
|
||||
if slices.Contains(exclude, name) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func toolsFromMap(tools map[string]RegisteredTool) []RegisteredTool {
|
||||
out := make([]RegisteredTool, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
out = append(out, tool)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package mcp
|
||||
|
||||
import "strings"
|
||||
|
||||
const qualifiedNameMaxLen = 64
|
||||
|
||||
// QualifiedToolName creates a stable, provider-safe function name.
|
||||
func QualifiedToolName(serverName, toolName string) string {
|
||||
prefix := "mcp_" + sanitizeName(serverName) + "__"
|
||||
tool := sanitizeName(toolName)
|
||||
maxToolLen := qualifiedNameMaxLen - len(prefix)
|
||||
if maxToolLen <= 0 {
|
||||
return prefix[:qualifiedNameMaxLen]
|
||||
}
|
||||
if len(tool) > maxToolLen {
|
||||
tool = tool[:maxToolLen]
|
||||
}
|
||||
return prefix + tool
|
||||
}
|
||||
|
||||
func sanitizeName(value string) string {
|
||||
trimmed := strings.TrimSpace(strings.ToLower(value))
|
||||
if trimmed == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(trimmed))
|
||||
|
||||
lastUnderscore := false
|
||||
for i := 0; i < len(trimmed); i++ {
|
||||
ch := trimmed[i]
|
||||
isAlphaNum := (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9')
|
||||
if isAlphaNum {
|
||||
b.WriteByte(ch)
|
||||
lastUnderscore = false
|
||||
continue
|
||||
}
|
||||
if !lastUnderscore {
|
||||
b.WriteByte('_')
|
||||
lastUnderscore = true
|
||||
}
|
||||
}
|
||||
|
||||
s := strings.Trim(b.String(), "_")
|
||||
if s == "" {
|
||||
s = "unknown"
|
||||
}
|
||||
if s[0] >= '0' && s[0] <= '9' {
|
||||
return "t_" + s
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package mcp
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
defaultInitTimeoutSeconds = 60
|
||||
defaultCallTimeoutSeconds = 30
|
||||
defaultMaxResponseBytes = 64 * 1024
|
||||
defaultScannerBufferBytes = 64 * 1024
|
||||
maxFrameBytes = 2 * 1024 * 1024
|
||||
maxToolListPages = 50
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolMCPFrames = "mcp"
|
||||
ProtocolJSONLines = "jsonl"
|
||||
)
|
||||
|
||||
// ServerConfig defines one MCP server connection.
|
||||
type ServerConfig struct {
|
||||
Name string
|
||||
Command string
|
||||
Args []string
|
||||
Env map[string]string
|
||||
WorkingDir string
|
||||
Protocol string
|
||||
InitTimeoutSeconds int
|
||||
CallTimeoutSeconds int
|
||||
MaxResponseBytes int
|
||||
IncludeTools []string
|
||||
ExcludeTools []string
|
||||
}
|
||||
|
||||
func (c ServerConfig) InitTimeout() time.Duration {
|
||||
seconds := c.InitTimeoutSeconds
|
||||
if seconds <= 0 {
|
||||
seconds = defaultInitTimeoutSeconds
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c ServerConfig) CallTimeout() time.Duration {
|
||||
seconds := c.CallTimeoutSeconds
|
||||
if seconds <= 0 {
|
||||
seconds = defaultCallTimeoutSeconds
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
func (c ServerConfig) ResponseLimit() int {
|
||||
if c.MaxResponseBytes <= 0 {
|
||||
return defaultMaxResponseBytes
|
||||
}
|
||||
return c.MaxResponseBytes
|
||||
}
|
||||
|
||||
// RemoteTool is an MCP tool discovered from a server.
|
||||
type RemoteTool struct {
|
||||
Name string
|
||||
Description string
|
||||
InputSchema map[string]any
|
||||
}
|
||||
|
||||
// RegisteredTool is a discovered tool with a PicoClaw-facing qualified name.
|
||||
type RegisteredTool struct {
|
||||
QualifiedName string
|
||||
ServerName string
|
||||
ToolName string
|
||||
Description string
|
||||
Parameters map[string]any
|
||||
}
|
||||
|
||||
// CallResult is a normalized MCP tool call result.
|
||||
type CallResult struct {
|
||||
Content string
|
||||
IsError bool
|
||||
}
|
||||
+10
-5
@@ -27,7 +27,7 @@ var supportedChannels = map[string]bool{
|
||||
"whatsapp": true,
|
||||
"feishu": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"dingtalk": true,
|
||||
"maixcam": true,
|
||||
}
|
||||
|
||||
@@ -212,12 +212,17 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error
|
||||
|
||||
if tools, ok := getMap(data, "tools"); ok {
|
||||
if web, ok := getMap(tools, "web"); ok {
|
||||
// Migrate old "search" config to "brave" if api_key is present
|
||||
if search, ok := getMap(web, "search"); ok {
|
||||
if v, ok := getString(search, "api_key"); ok {
|
||||
cfg.Tools.Web.Search.APIKey = v
|
||||
cfg.Tools.Web.Brave.APIKey = v
|
||||
if v != "" {
|
||||
cfg.Tools.Web.Brave.Enabled = true
|
||||
}
|
||||
}
|
||||
if v, ok := getFloat(search, "max_results"); ok {
|
||||
cfg.Tools.Web.Search.MaxResults = int(v)
|
||||
cfg.Tools.Web.Brave.MaxResults = int(v)
|
||||
cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -271,8 +276,8 @@ func MergeConfig(existing, incoming *config.Config) *config.Config {
|
||||
existing.Channels.MaixCam = incoming.Channels.MaixCam
|
||||
}
|
||||
|
||||
if existing.Tools.Web.Search.APIKey == "" {
|
||||
existing.Tools.Web.Search = incoming.Tools.Web.Search
|
||||
if existing.Tools.Web.Brave.APIKey == "" {
|
||||
existing.Tools.Web.Brave = incoming.Tools.Web.Brave
|
||||
}
|
||||
|
||||
return existing
|
||||
|
||||
@@ -44,8 +44,8 @@ func TestConvertKeysToSnake(t *testing.T) {
|
||||
"apiKey": "test-key",
|
||||
"apiBase": "https://example.com",
|
||||
"nested": map[string]interface{}{
|
||||
"maxTokens": float64(8192),
|
||||
"allowFrom": []interface{}{"user1", "user2"},
|
||||
"maxTokens": float64(8192),
|
||||
"allowFrom": []interface{}{"user1", "user2"},
|
||||
"deeperLevel": map[string]interface{}{
|
||||
"clientId": "abc",
|
||||
},
|
||||
@@ -256,11 +256,11 @@ func TestConvertConfig(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"agents": map[string]interface{}{
|
||||
"defaults": map[string]interface{}{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": float64(4096),
|
||||
"temperature": 0.5,
|
||||
"max_tool_iterations": float64(10),
|
||||
"workspace": "~/.openclaw/workspace",
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": float64(4096),
|
||||
"temperature": 0.5,
|
||||
"max_tool_iterations": float64(10),
|
||||
"workspace": "~/.openclaw/workspace",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -171,68 +171,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractToolCalls parses tool call JSON from the response text.
|
||||
// extractToolCalls delegates to the shared extractToolCallsFromText function.
|
||||
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonStr := text[start:end]
|
||||
|
||||
var wrapper struct {
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ToolCall
|
||||
for _, tc := range wrapper.ToolCalls {
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
|
||||
result = append(result, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
return extractToolCallsFromText(text)
|
||||
}
|
||||
|
||||
// stripToolCallsJSON removes tool call JSON from response text.
|
||||
// stripToolCallsJSON delegates to the shared stripToolCallsFromText function.
|
||||
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return text
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return text
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text[:start] + text[end:])
|
||||
return stripToolCallsFromText(text)
|
||||
}
|
||||
|
||||
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
|
||||
@@ -254,22 +200,22 @@ func findMatchingBrace(text string, pos int) int {
|
||||
// claudeCliJSONResponse represents the JSON output from the claude CLI.
|
||||
// Matches the real claude CLI v2.x output format.
|
||||
type claudeCliJSONResponse struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"`
|
||||
SessionID string `json:"session_id"`
|
||||
TotalCostUSD float64 `json:"total_cost_usd"`
|
||||
DurationMS int `json:"duration_ms"`
|
||||
DurationAPI int `json:"duration_api_ms"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
Usage claudeCliUsageInfo `json:"usage"`
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"`
|
||||
SessionID string `json:"session_id"`
|
||||
TotalCostUSD float64 `json:"total_cost_usd"`
|
||||
DurationMS int `json:"duration_ms"`
|
||||
DurationAPI int `json:"duration_api_ms"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
Usage claudeCliUsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
// claudeCliUsageInfo represents token usage from the claude CLI response.
|
||||
type claudeCliUsageInfo struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
//go:build integration
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealClaudeCLI(t *testing.T) {
|
||||
// Check if claude CLI is available
|
||||
path, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
t.Skip("claude CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using claude CLI at: %s", path)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
// Verify response structure
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil from real CLI")
|
||||
} else {
|
||||
if resp.Usage.PromptTokens == 0 {
|
||||
t.Error("PromptTokens should be > 0")
|
||||
}
|
||||
if resp.Usage.CompletionTokens == 0 {
|
||||
t.Error("CompletionTokens should be > 0")
|
||||
}
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run claude directly and verify our parser handles real output
|
||||
cmd := exec.Command("claude", "-p", "--output-format", "json",
|
||||
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
cmd.Dir = t.TempDir()
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
t.Fatalf("claude CLI failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output: %s", string(output))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewClaudeCliProvider("")
|
||||
resp, err := p.parseClaudeCliResponse(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil")
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -968,9 +967,9 @@ func TestFindMatchingBrace(t *testing.T) {
|
||||
{`{"a":1}`, 0, 7},
|
||||
{`{"a":{"b":2}}`, 0, 13},
|
||||
{`text {"a":1} more`, 5, 12},
|
||||
{`{unclosed`, 0, 0}, // no match returns pos
|
||||
{`{}`, 0, 2}, // empty object
|
||||
{`{{{}}}`, 0, 6}, // deeply nested
|
||||
{`{unclosed`, 0, 0}, // no match returns pos
|
||||
{`{}`, 0, 2}, // empty object
|
||||
{`{{{}}}`, 0, 6}, // deeply nested
|
||||
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Integration test: real claude CLI ---
|
||||
|
||||
func TestIntegration_RealClaudeCLI(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Check if claude CLI is available
|
||||
path, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
t.Skip("claude CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using claude CLI at: %s", path)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
// Verify response structure
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil from real CLI")
|
||||
} else {
|
||||
if resp.Usage.PromptTokens == 0 {
|
||||
t.Error("PromptTokens should be > 0")
|
||||
}
|
||||
if resp.Usage.CompletionTokens == 0 {
|
||||
t.Error("CompletionTokens should be > 0")
|
||||
}
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run claude directly and verify our parser handles real output
|
||||
cmd := exec.Command("claude", "-p", "--output-format", "json",
|
||||
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
cmd.Dir = t.TempDir()
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
t.Fatalf("claude CLI failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output: %s", string(output))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewClaudeCliProvider("")
|
||||
resp, err := p.parseClaudeCliResponse(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil")
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CodexCliAuth represents the ~/.codex/auth.json file structure.
|
||||
type CodexCliAuth struct {
|
||||
Tokens struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccountID string `json:"account_id"`
|
||||
} `json:"tokens"`
|
||||
}
|
||||
|
||||
// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file.
|
||||
// Expiry is estimated as file modification time + 1 hour (same approach as moltbot).
|
||||
func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) {
|
||||
authPath, err := resolveCodexAuthPath()
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, err
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(authPath)
|
||||
if err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
var auth CodexCliAuth
|
||||
if err := json.Unmarshal(data, &auth); err != nil {
|
||||
return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err)
|
||||
}
|
||||
|
||||
if auth.Tokens.AccessToken == "" {
|
||||
return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath)
|
||||
}
|
||||
|
||||
stat, err := os.Stat(authPath)
|
||||
if err != nil {
|
||||
expiresAt = time.Now().Add(time.Hour)
|
||||
} else {
|
||||
expiresAt = stat.ModTime().Add(time.Hour)
|
||||
}
|
||||
|
||||
return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil
|
||||
}
|
||||
|
||||
// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json.
|
||||
// This allows the existing CodexProvider to reuse Codex CLI credentials.
|
||||
func CreateCodexCliTokenSource() func() (string, string, error) {
|
||||
return func() (string, string, error) {
|
||||
token, accountID, expiresAt, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("reading codex cli credentials: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().After(expiresAt) {
|
||||
return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login")
|
||||
}
|
||||
|
||||
return token, accountID, nil
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCodexAuthPath() (string, error) {
|
||||
codexHome := os.Getenv("CODEX_HOME")
|
||||
if codexHome == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting home dir: %w", err)
|
||||
}
|
||||
codexHome = filepath.Join(home, ".codex")
|
||||
}
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReadCodexCliCredentials_Valid(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{
|
||||
"tokens": {
|
||||
"access_token": "test-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"account_id": "org-test123"
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
token, accountID, expiresAt, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadCodexCliCredentials() error: %v", err)
|
||||
}
|
||||
if token != "test-access-token" {
|
||||
t.Errorf("token = %q, want %q", token, "test-access-token")
|
||||
}
|
||||
if accountID != "org-test123" {
|
||||
t.Errorf("accountID = %q, want %q", accountID, "org-test123")
|
||||
}
|
||||
// Expiry should be within ~1 hour from now (file was just written)
|
||||
if expiresAt.Before(time.Now()) {
|
||||
t.Errorf("expiresAt = %v, should be in the future", expiresAt)
|
||||
}
|
||||
if expiresAt.After(time.Now().Add(2 * time.Hour)) {
|
||||
t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
_, _, _, err := ReadCodexCliCredentials()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing auth.json")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
_, _, _, err := ReadCodexCliCredentials()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty access_token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
if err := os.WriteFile(authPath, []byte("not json"), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
_, _, _, err := ReadCodexCliCredentials()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
token, accountID, _, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "tok123" {
|
||||
t.Errorf("token = %q, want %q", token, "tok123")
|
||||
}
|
||||
if accountID != "" {
|
||||
t.Errorf("accountID = %q, want empty", accountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
customDir := filepath.Join(tmpDir, "custom-codex")
|
||||
if err := os.MkdirAll(customDir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`
|
||||
if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", customDir)
|
||||
|
||||
token, _, _, err := ReadCodexCliCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if token != "custom-token" {
|
||||
t.Errorf("token = %q, want %q", token, "custom-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
source := CreateCodexCliTokenSource()
|
||||
token, accountID, err := source()
|
||||
if err != nil {
|
||||
t.Fatalf("token source error: %v", err)
|
||||
}
|
||||
if token != "fresh-token" {
|
||||
t.Errorf("token = %q, want %q", token, "fresh-token")
|
||||
}
|
||||
if accountID != "acc" {
|
||||
t.Errorf("accountID = %q, want %q", accountID, "acc")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
authPath := filepath.Join(tmpDir, "auth.json")
|
||||
|
||||
authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`
|
||||
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Set file modification time to 2 hours ago
|
||||
oldTime := time.Now().Add(-2 * time.Hour)
|
||||
if err := os.Chtimes(authPath, oldTime, oldTime); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Setenv("CODEX_HOME", tmpDir)
|
||||
|
||||
source := CreateCodexCliTokenSource()
|
||||
_, _, err := source()
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired credentials")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
|
||||
type CodexCliProvider struct {
|
||||
command string
|
||||
workspace string
|
||||
}
|
||||
|
||||
// NewCodexCliProvider creates a new Codex CLI provider.
|
||||
func NewCodexCliProvider(workspace string) *CodexCliProvider {
|
||||
return &CodexCliProvider{
|
||||
command: "codex",
|
||||
workspace: workspace,
|
||||
}
|
||||
}
|
||||
|
||||
// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode.
|
||||
func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
if p.command == "" {
|
||||
return nil, fmt.Errorf("codex command not configured")
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
args := []string{
|
||||
"exec",
|
||||
"--json",
|
||||
"--dangerously-bypass-approvals-and-sandbox",
|
||||
"--skip-git-repo-check",
|
||||
"--color", "never",
|
||||
}
|
||||
if model != "" && model != "codex-cli" {
|
||||
args = append(args, "-m", model)
|
||||
}
|
||||
if p.workspace != "" {
|
||||
args = append(args, "-C", p.workspace)
|
||||
}
|
||||
args = append(args, "-") // read prompt from stdin
|
||||
|
||||
cmd := exec.CommandContext(ctx, p.command, args...)
|
||||
cmd.Stdin = bytes.NewReader([]byte(prompt))
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
|
||||
// Parse JSONL from stdout even if exit code is non-zero,
|
||||
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
|
||||
// but still produces valid JSONL output.
|
||||
if stdoutStr := stdout.String(); stdoutStr != "" {
|
||||
resp, parseErr := p.parseJSONLEvents(stdoutStr)
|
||||
if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if ctx.Err() == context.Canceled {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
if stderrStr := stderr.String(); stderrStr != "" {
|
||||
return nil, fmt.Errorf("codex cli error: %s", stderrStr)
|
||||
}
|
||||
return nil, fmt.Errorf("codex cli error: %w", err)
|
||||
}
|
||||
|
||||
return p.parseJSONLEvents(stdout.String())
|
||||
}
|
||||
|
||||
// GetDefaultModel returns the default model identifier.
|
||||
func (p *CodexCliProvider) GetDefaultModel() string {
|
||||
return "codex-cli"
|
||||
}
|
||||
|
||||
// buildPrompt converts messages to a prompt string for the Codex CLI.
|
||||
// System messages are prepended as instructions since Codex CLI has no --system-prompt flag.
|
||||
func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string {
|
||||
var systemParts []string
|
||||
var conversationParts []string
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "system":
|
||||
systemParts = append(systemParts, msg.Content)
|
||||
case "user":
|
||||
conversationParts = append(conversationParts, msg.Content)
|
||||
case "assistant":
|
||||
conversationParts = append(conversationParts, "Assistant: "+msg.Content)
|
||||
case "tool":
|
||||
conversationParts = append(conversationParts,
|
||||
fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
if len(systemParts) > 0 {
|
||||
sb.WriteString("## System Instructions\n\n")
|
||||
sb.WriteString(strings.Join(systemParts, "\n\n"))
|
||||
sb.WriteString("\n\n## Task\n\n")
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString(p.buildToolsPrompt(tools))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
// Simplify single user message (no prefix)
|
||||
if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 {
|
||||
return conversationParts[0]
|
||||
}
|
||||
|
||||
sb.WriteString(strings.Join(conversationParts, "\n"))
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// buildToolsPrompt creates a tool definitions section for the prompt.
|
||||
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("## Available Tools\n\n")
|
||||
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
|
||||
sb.WriteString("```json\n")
|
||||
sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
|
||||
if tool.Function.Description != "" {
|
||||
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
|
||||
}
|
||||
if len(tool.Function.Parameters) > 0 {
|
||||
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
|
||||
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// codexEvent represents a single JSONL event from `codex exec --json`.
|
||||
type codexEvent struct {
|
||||
Type string `json:"type"`
|
||||
ThreadID string `json:"thread_id,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Item *codexEventItem `json:"item,omitempty"`
|
||||
Usage *codexUsage `json:"usage,omitempty"`
|
||||
Error *codexEventErr `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type codexEventItem struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
ExitCode *int `json:"exit_code,omitempty"`
|
||||
Output string `json:"output,omitempty"`
|
||||
}
|
||||
|
||||
type codexUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
CachedInputTokens int `json:"cached_input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type codexEventErr struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// parseJSONLEvents processes the JSONL output from codex exec --json.
|
||||
func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) {
|
||||
var contentParts []string
|
||||
var usage *UsageInfo
|
||||
var lastError string
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var event codexEvent
|
||||
if err := json.Unmarshal([]byte(line), &event); err != nil {
|
||||
continue // skip malformed lines
|
||||
}
|
||||
|
||||
switch event.Type {
|
||||
case "item.completed":
|
||||
if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" {
|
||||
contentParts = append(contentParts, event.Item.Text)
|
||||
}
|
||||
case "turn.completed":
|
||||
if event.Usage != nil {
|
||||
promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens
|
||||
usage = &UsageInfo{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: event.Usage.OutputTokens,
|
||||
TotalTokens: promptTokens + event.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
case "error":
|
||||
lastError = event.Message
|
||||
case "turn.failed":
|
||||
if event.Error != nil {
|
||||
lastError = event.Error.Message
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lastError != "" && len(contentParts) == 0 {
|
||||
return nil, fmt.Errorf("codex cli: %s", lastError)
|
||||
}
|
||||
|
||||
content := strings.Join(contentParts, "\n")
|
||||
|
||||
// Extract tool calls from response text (same pattern as ClaudeCliProvider)
|
||||
toolCalls := extractToolCallsFromText(content)
|
||||
|
||||
finishReason := "stop"
|
||||
if len(toolCalls) > 0 {
|
||||
finishReason = "tool_calls"
|
||||
content = stripToolCallsFromText(content)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: strings.TrimSpace(content),
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,585 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- JSONL Event Parsing Tests ---
|
||||
|
||||
func TestParseJSONLEvents_AgentMessage(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"thread.started","thread_id":"abc-123"}
|
||||
{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}}
|
||||
{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hello from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 150 {
|
||||
t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 20 {
|
||||
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
|
||||
}
|
||||
if resp.Usage.TotalTokens != 170 {
|
||||
t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens)
|
||||
}
|
||||
if len(resp.ToolCalls) != 0 {
|
||||
t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
toolCallText := `Let me read that file.
|
||||
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}`
|
||||
// Build valid JSONL by marshaling the event
|
||||
item := codexEvent{
|
||||
Type: "item.completed",
|
||||
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
|
||||
}
|
||||
itemJSON, _ := json.Marshal(item)
|
||||
usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}`
|
||||
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
if len(resp.ToolCalls) != 1 {
|
||||
t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "read_file" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
|
||||
}
|
||||
if resp.ToolCalls[0].ID != "call_1" {
|
||||
t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
|
||||
}
|
||||
if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` {
|
||||
t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments)
|
||||
}
|
||||
// Content should have the tool call JSON stripped
|
||||
if strings.Contains(resp.Content, "tool_calls") {
|
||||
t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}`
|
||||
item := codexEvent{
|
||||
Type: "item.completed",
|
||||
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
|
||||
}
|
||||
itemJSON, _ := json.Marshal(item)
|
||||
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if len(resp.ToolCalls) != 2 {
|
||||
t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls))
|
||||
}
|
||||
if resp.ToolCalls[0].Name != "read_file" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
|
||||
}
|
||||
if resp.ToolCalls[1].Name != "write_file" {
|
||||
t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file")
|
||||
}
|
||||
if resp.FinishReason != "tool_calls" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MultipleMessages(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}}
|
||||
{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}}
|
||||
{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Content != "First part.\nSecond part." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ErrorEvent(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"thread.started","thread_id":"abc"}
|
||||
{"type":"turn.started"}
|
||||
{"type":"error","message":"token expired"}
|
||||
{"type":"turn.failed","error":{"message":"token expired"}}`
|
||||
|
||||
_, err := p.parseJSONLEvents(events)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token expired") {
|
||||
t.Errorf("error = %q, want to contain 'token expired'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_TurnFailed(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"turn.failed","error":{"message":"rate limit exceeded"}}`
|
||||
|
||||
_, err := p.parseJSONLEvents(events)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "rate limit exceeded") {
|
||||
t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_ErrorWithContent(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
// If there's an error but also content, return the content (partial success)
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}}
|
||||
{"type":"error","message":"connection reset"}
|
||||
{"type":"turn.failed","error":{"message":"connection reset"}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("should not error when content exists: %v", err)
|
||||
}
|
||||
if resp.Content != "Partial result." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Partial result.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_EmptyOutput(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
resp, err := p.parseJSONLEvents("")
|
||||
if err != nil {
|
||||
t.Fatalf("empty output should not error: %v", err)
|
||||
}
|
||||
if resp.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_MalformedLines(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `not json at all
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}}
|
||||
another bad line
|
||||
{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("should skip malformed lines: %v", err)
|
||||
}
|
||||
if resp.Content != "Good line." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Good line.")
|
||||
}
|
||||
if resp.Usage == nil || resp.Usage.TotalTokens != 15 {
|
||||
t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_CommandExecution(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}}
|
||||
{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
// command_execution items should be skipped; only agent_message text is returned
|
||||
if resp.Content != "Found 2 files." {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJSONLEvents_NoUsage(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
events := `{"type":"turn.started"}
|
||||
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}}
|
||||
{"type":"turn.completed"}`
|
||||
|
||||
resp, err := p.parseJSONLEvents(events)
|
||||
if err != nil {
|
||||
t.Fatalf("parseJSONLEvents() error: %v", err)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Prompt Building Tests ---
|
||||
|
||||
func TestBuildPrompt_SystemAsInstructions(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "You are helpful."},
|
||||
{Role: "user", Content: "Hi there"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "## System Instructions") {
|
||||
t.Error("prompt should contain '## System Instructions'")
|
||||
}
|
||||
if !strings.Contains(prompt, "You are helpful.") {
|
||||
t.Error("prompt should contain system content")
|
||||
}
|
||||
if !strings.Contains(prompt, "## Task") {
|
||||
t.Error("prompt should contain '## Task'")
|
||||
}
|
||||
if !strings.Contains(prompt, "Hi there") {
|
||||
t.Error("prompt should contain user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_NoSystem(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Just a question"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if strings.Contains(prompt, "## System Instructions") {
|
||||
t.Error("prompt should not contain system instructions header")
|
||||
}
|
||||
if prompt != "Just a question" {
|
||||
t.Errorf("prompt = %q, want %q", prompt, "Just a question")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_WithTools(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Get weather"},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Get current weather",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"city": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
if !strings.Contains(prompt, "## Available Tools") {
|
||||
t.Error("prompt should contain tools section")
|
||||
}
|
||||
if !strings.Contains(prompt, "get_weather") {
|
||||
t.Error("prompt should contain tool name")
|
||||
}
|
||||
if !strings.Contains(prompt, "Get current weather") {
|
||||
t.Error("prompt should contain tool description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_MultipleMessages(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi! How can I help?"},
|
||||
{Role: "user", Content: "Tell me about Go"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "Hello") {
|
||||
t.Error("prompt should contain first user message")
|
||||
}
|
||||
if !strings.Contains(prompt, "Assistant: Hi! How can I help?") {
|
||||
t.Error("prompt should contain assistant message with prefix")
|
||||
}
|
||||
if !strings.Contains(prompt, "Tell me about Go") {
|
||||
t.Error("prompt should contain second user message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_ToolResults(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Weather?"},
|
||||
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, nil)
|
||||
|
||||
if !strings.Contains(prompt, "[Tool Result for call_1]") {
|
||||
t.Error("prompt should contain tool result")
|
||||
}
|
||||
if !strings.Contains(prompt, `{"temp": 72}`) {
|
||||
t.Error("prompt should contain tool result content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPrompt_SystemAndTools(t *testing.T) {
|
||||
p := &CodexCliProvider{}
|
||||
messages := []Message{
|
||||
{Role: "system", Content: "Be concise."},
|
||||
{Role: "user", Content: "Do something"},
|
||||
}
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
Type: "function",
|
||||
Function: ToolFunctionDefinition{
|
||||
Name: "my_tool",
|
||||
Description: "A tool",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
prompt := p.buildPrompt(messages, tools)
|
||||
|
||||
// System instructions should come first
|
||||
sysIdx := strings.Index(prompt, "## System Instructions")
|
||||
toolIdx := strings.Index(prompt, "## Available Tools")
|
||||
taskIdx := strings.Index(prompt, "## Task")
|
||||
|
||||
if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 {
|
||||
t.Fatal("prompt should contain all sections")
|
||||
}
|
||||
if sysIdx >= taskIdx {
|
||||
t.Error("system instructions should come before task")
|
||||
}
|
||||
if taskIdx >= toolIdx {
|
||||
t.Error("task section should come before tools in the output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- CLI Argument Tests ---
|
||||
|
||||
func TestCodexCliProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexCliProvider("")
|
||||
if got := p.GetDefaultModel(); got != "codex-cli" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mock CLI Integration Test ---
|
||||
|
||||
func createMockCodexCLI(t *testing.T, events []string) string {
|
||||
t.Helper()
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("#!/bin/bash\n")
|
||||
for _, event := range events {
|
||||
sb.WriteString(fmt.Sprintf("echo '%s'\n", event))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(sb.String()), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return scriptPath
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_Success(t *testing.T) {
|
||||
scriptPath := createMockCodexCLI(t, []string{
|
||||
`{"type":"thread.started","thread_id":"test-123"}`,
|
||||
`{"type":"turn.started"}`,
|
||||
`{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`,
|
||||
`{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`,
|
||||
})
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Mock response from Codex CLI" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Fatal("Usage should not be nil")
|
||||
}
|
||||
if resp.Usage.PromptTokens != 60 {
|
||||
t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens)
|
||||
}
|
||||
if resp.Usage.CompletionTokens != 15 {
|
||||
t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_Error(t *testing.T) {
|
||||
scriptPath := createMockCodexCLI(t, []string{
|
||||
`{"type":"thread.started","thread_id":"test-err"}`,
|
||||
`{"type":"turn.started"}`,
|
||||
`{"type":"error","message":"auth token expired"}`,
|
||||
`{"type":"turn.failed","error":{"message":"auth token expired"}}`,
|
||||
})
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "auth token expired") {
|
||||
t.Errorf("error = %q, want to contain 'auth token expired'", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) {
|
||||
// Mock script that captures args to verify model flag is passed
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
script := `#!/bin/bash
|
||||
# Write args to a file for verification
|
||||
echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `"
|
||||
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
|
||||
echo '{"type":"turn.completed"}'`
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "/tmp/test-workspace",
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "gpt-5.2-codex", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the args
|
||||
argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("reading args: %v", err)
|
||||
}
|
||||
args := string(argsData)
|
||||
|
||||
if !strings.Contains(args, "-m gpt-5.2-codex") {
|
||||
t.Errorf("args should contain model flag, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "-C /tmp/test-workspace") {
|
||||
t.Errorf("args should contain workspace flag, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "--json") {
|
||||
t.Errorf("args should contain --json, got: %s", args)
|
||||
}
|
||||
if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
|
||||
t.Errorf("args should contain bypass flag, got: %s", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
|
||||
// Script that sleeps forever
|
||||
tmpDir := t.TempDir()
|
||||
scriptPath := filepath.Join(tmpDir, "codex")
|
||||
script := "#!/bin/bash\nsleep 60"
|
||||
|
||||
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: scriptPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // cancel immediately
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(ctx, messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error on canceled context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexCliProvider_EmptyCommand(t *testing.T) {
|
||||
p := &CodexCliProvider{command: ""}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "test"}}
|
||||
_, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Integration Test (requires real codex CLI with valid auth) ---
|
||||
|
||||
func TestCodexCliProvider_Integration(t *testing.T) {
|
||||
if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" {
|
||||
t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)")
|
||||
}
|
||||
|
||||
// Verify codex is available
|
||||
codexPath, err := exec.LookPath("codex")
|
||||
if err != nil {
|
||||
t.Skip("codex CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := &CodexCliProvider{
|
||||
command: codexPath,
|
||||
workspace: "",
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Respond with just the word 'hello' and nothing else."},
|
||||
}
|
||||
|
||||
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
|
||||
lower := strings.ToLower(strings.TrimSpace(resp.Content))
|
||||
if !strings.Contains(lower, "hello") {
|
||||
t.Errorf("Content = %q, expected to contain 'hello'", resp.Content)
|
||||
}
|
||||
}
|
||||
+161
-16
@@ -3,6 +3,7 @@ package providers
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -10,18 +11,26 @@ import (
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const codexDefaultModel = "gpt-5.2"
|
||||
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
tokenSource func() (string, string, error)
|
||||
}
|
||||
|
||||
const defaultCodexInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||
option.WithAPIKey(token),
|
||||
option.WithHeader("originator", "codex_cli_rs"),
|
||||
option.WithHeader("OpenAI-Beta", "responses=experimental"),
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
@@ -41,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
|
||||
|
||||
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
accountID := p.accountID
|
||||
resolvedModel, fallbackReason := resolveCodexModel(model)
|
||||
if fallbackReason != "" {
|
||||
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"reason": fallbackReason,
|
||||
})
|
||||
}
|
||||
if p.tokenSource != nil {
|
||||
tok, accID, err := p.tokenSource()
|
||||
if err != nil {
|
||||
@@ -48,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
||||
}
|
||||
opts = append(opts, option.WithAPIKey(tok))
|
||||
if accID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||
accountID = accID
|
||||
}
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
} else {
|
||||
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
})
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, model, options)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||
|
||||
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
|
||||
var resp *responses.Response
|
||||
for stream.Next() {
|
||||
evt := stream.Current()
|
||||
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
|
||||
evtResp := evt.Response
|
||||
if evtResp.ID != "" {
|
||||
copy := evtResp
|
||||
resp = ©
|
||||
}
|
||||
}
|
||||
}
|
||||
err := stream.Err()
|
||||
if err != nil {
|
||||
fields := map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(tools),
|
||||
"account_id_present": accountID != "",
|
||||
"error": err.Error(),
|
||||
}
|
||||
var apiErr *openai.Error
|
||||
if errors.As(err, &apiErr) {
|
||||
fields["status_code"] = apiErr.StatusCode
|
||||
fields["api_type"] = apiErr.Type
|
||||
fields["api_code"] = apiErr.Code
|
||||
fields["api_param"] = apiErr.Param
|
||||
fields["api_message"] = apiErr.Message
|
||||
if apiErr.StatusCode == 400 {
|
||||
fields["hint"] = "verify account id header and model compatibility for codex backend"
|
||||
}
|
||||
if apiErr.Response != nil {
|
||||
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
|
||||
}
|
||||
}
|
||||
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
|
||||
return nil, fmt.Errorf("codex API call: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
fields := map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(tools),
|
||||
"account_id_present": accountID != "",
|
||||
}
|
||||
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
|
||||
return nil, fmt.Errorf("codex API call: stream ended without completed response")
|
||||
}
|
||||
|
||||
return parseCodexResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) GetDefaultModel() string {
|
||||
return "gpt-4o"
|
||||
return codexDefaultModel
|
||||
}
|
||||
|
||||
func resolveCodexModel(model string) (string, string) {
|
||||
m := strings.ToLower(strings.TrimSpace(model))
|
||||
if m == "" {
|
||||
return codexDefaultModel, "empty model"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "openai/") {
|
||||
m = strings.TrimPrefix(m, "openai/")
|
||||
} else if strings.Contains(m, "/") {
|
||||
return codexDefaultModel, "non-openai model namespace"
|
||||
}
|
||||
|
||||
unsupportedPrefixes := []string{
|
||||
"glm",
|
||||
"claude",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"google",
|
||||
"moonshot",
|
||||
"kimi",
|
||||
"qwen",
|
||||
"deepseek",
|
||||
"llama",
|
||||
"meta-llama",
|
||||
"mistral",
|
||||
"grok",
|
||||
"xai",
|
||||
"zhipu",
|
||||
}
|
||||
for _, prefix := range unsupportedPrefixes {
|
||||
if strings.HasPrefix(m, prefix) {
|
||||
return codexDefaultModel, "unsupported model prefix"
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") {
|
||||
return m, ""
|
||||
}
|
||||
|
||||
return codexDefaultModel, "unsupported model family"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
@@ -101,12 +217,18 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
})
|
||||
}
|
||||
for _, tc := range msg.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
name, args, ok := resolveCodexToolCall(tc)
|
||||
if !ok {
|
||||
logger.WarnCF("provider.codex", "Skipping invalid tool call in history", map[string]interface{}{
|
||||
"call_id": tc.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||
CallID: tc.ID,
|
||||
Name: tc.Name,
|
||||
Arguments: string(argsJSON),
|
||||
Name: name,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -133,19 +255,15 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: inputItems,
|
||||
},
|
||||
Store: openai.Opt(false),
|
||||
Instructions: openai.Opt(instructions),
|
||||
Store: openai.Opt(false),
|
||||
}
|
||||
|
||||
if instructions != "" {
|
||||
params.Instructions = openai.Opt(instructions)
|
||||
}
|
||||
|
||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = openai.Opt(temp)
|
||||
} else {
|
||||
// ChatGPT Codex backend requires instructions to be present.
|
||||
params.Instructions = openai.Opt(defaultCodexInstructions)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
@@ -155,6 +273,30 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
return params
|
||||
}
|
||||
|
||||
func resolveCodexToolCall(tc ToolCall) (name string, arguments string, ok bool) {
|
||||
name = tc.Name
|
||||
if name == "" && tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
if name == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if len(tc.Arguments) > 0 {
|
||||
argsJSON, err := json.Marshal(tc.Arguments)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
return name, string(argsJSON), true
|
||||
}
|
||||
|
||||
if tc.Function != nil && tc.Function.Arguments != "" {
|
||||
return name, tc.Function.Arguments, true
|
||||
}
|
||||
|
||||
return name, "{}", true
|
||||
}
|
||||
|
||||
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
@@ -237,6 +379,9 @@ func createCodexTokenSource() func() (string, string, error) {
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -16,11 +17,21 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
}
|
||||
if !params.Instructions.Valid() {
|
||||
t.Fatal("Instructions should be set")
|
||||
}
|
||||
if params.Instructions.Or("") != defaultCodexInstructions {
|
||||
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), defaultCodexInstructions)
|
||||
}
|
||||
if params.MaxOutputTokens.Valid() {
|
||||
t.Fatalf("MaxOutputTokens should not be set for Codex backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||
@@ -57,6 +68,45 @@ func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_ToolCallFunctionFallback(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "Read a file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "ok", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||
if params.Input.OfInputItemList == nil {
|
||||
t.Fatal("Input.OfInputItemList should not be nil")
|
||||
}
|
||||
if len(params.Input.OfInputItemList) != 3 {
|
||||
t.Fatalf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||
}
|
||||
|
||||
fc := params.Input.OfInputItemList[1].OfFunctionCall
|
||||
if fc == nil {
|
||||
t.Fatal("assistant tool call should be converted to function_call input item")
|
||||
}
|
||||
if fc.Name != "read_file" {
|
||||
t.Errorf("Function call name = %q, want %q", fc.Name, "read_file")
|
||||
}
|
||||
if fc.Arguments != `{"path":"README.md"}` {
|
||||
t.Errorf("Function call arguments = %q, want %q", fc.Arguments, `{"path":"README.md"}`)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||
tools := []ToolDefinition{
|
||||
{
|
||||
@@ -197,6 +247,20 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
@@ -220,8 +284,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -244,10 +307,189 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["instructions"]; !ok {
|
||||
http.Error(w, "missing instructions", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["instructions"] == "" {
|
||||
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["temperature"]; ok {
|
||||
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["max_output_tokens"]; ok {
|
||||
http.Error(w, "max_output_tokens is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 8,
|
||||
"output_tokens": 4,
|
||||
"total_tokens": 12,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("stale-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
|
||||
provider.tokenSource = func() (string, string, error) {
|
||||
return "refreshed-token", "", nil
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["model"] != codexDefaultModel {
|
||||
http.Error(w, "unsupported model", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["instructions"] != codexDefaultInstructions {
|
||||
http.Error(w, "missing default instructions", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 8,
|
||||
"output_tokens": 4,
|
||||
"total_tokens": 12,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexProvider("test-token", "")
|
||||
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||
if got := p.GetDefaultModel(); got != codexDefaultModel {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCodexModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantModel string
|
||||
wantFallback bool
|
||||
}{
|
||||
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
|
||||
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotModel, reason := resolveCodexModel(tt.input)
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel)
|
||||
}
|
||||
if tt.wantFallback && reason == "" {
|
||||
t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input)
|
||||
}
|
||||
if !tt.wantFallback && reason != "" {
|
||||
t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,3 +504,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||
c := openai.NewClient(opts...)
|
||||
return &c
|
||||
}
|
||||
|
||||
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
|
||||
event := map[string]interface{}{
|
||||
"type": "response.completed",
|
||||
"sequence_number": 1,
|
||||
"response": response,
|
||||
}
|
||||
b, _ := json.Marshal(event)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
fmt.Fprintf(w, "event: response.completed\n")
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(b))
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
json "encoding/json"
|
||||
|
||||
copilot "github.com/github/copilot-sdk/go"
|
||||
)
|
||||
|
||||
type GitHubCopilotProvider struct {
|
||||
uri string
|
||||
connectMode string // `stdio` or `grpc``
|
||||
|
||||
session *copilot.Session
|
||||
}
|
||||
|
||||
func NewGitHubCopilotProvider(uri string, connectMode string, model string) (*GitHubCopilotProvider, error) {
|
||||
|
||||
var session *copilot.Session
|
||||
if connectMode == "" {
|
||||
connectMode = "grpc"
|
||||
}
|
||||
switch connectMode {
|
||||
|
||||
case "stdio":
|
||||
//todo
|
||||
case "grpc":
|
||||
client := copilot.NewClient(&copilot.ClientOptions{
|
||||
CLIUrl: uri,
|
||||
})
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("Can't connect to Github Copilot, https://github.com/github/copilot-sdk/blob/main/docs/getting-started.md#connecting-to-an-external-cli-server for details")
|
||||
}
|
||||
defer client.Stop()
|
||||
session, _ = client.CreateSession(context.Background(), &copilot.SessionConfig{
|
||||
Model: model,
|
||||
Hooks: &copilot.SessionHooks{},
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
return &GitHubCopilotProvider{
|
||||
uri: uri,
|
||||
connectMode: connectMode,
|
||||
session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Chat sends a chat request to GitHub Copilot
|
||||
func (p *GitHubCopilotProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
type tempMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
out := make([]tempMessage, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
out = append(out, tempMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
|
||||
fullcontent, _ := json.Marshal(out)
|
||||
|
||||
content, _ := p.session.Send(ctx, copilot.MessageOptions{
|
||||
Prompt: string(fullcontent),
|
||||
})
|
||||
|
||||
return &LLMResponse{
|
||||
FinishReason: "stop",
|
||||
Content: content,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *GitHubCopilotProvider) GetDefaultModel() string {
|
||||
|
||||
return "gpt-4.1"
|
||||
}
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -28,7 +29,7 @@ type HTTPProvider struct {
|
||||
|
||||
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
client := &http.Client{
|
||||
Timeout: 0,
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
@@ -42,7 +43,7 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
|
||||
return &HTTPProvider{
|
||||
apiKey: apiKey,
|
||||
apiBase: apiBase,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
}
|
||||
}
|
||||
@@ -52,10 +53,10 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
return nil, fmt.Errorf("API base not configured")
|
||||
}
|
||||
|
||||
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
|
||||
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
|
||||
if idx := strings.Index(model, "/"); idx != -1 {
|
||||
prefix := model[:idx]
|
||||
if prefix == "moonshot" || prefix == "nvidia" {
|
||||
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
|
||||
model = model[idx+1:]
|
||||
}
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error: %s", string(body))
|
||||
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
return p.parseResponse(body)
|
||||
@@ -239,6 +240,9 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
}
|
||||
case "openai", "gpt":
|
||||
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
|
||||
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
|
||||
}
|
||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||
return createCodexAuthProvider()
|
||||
}
|
||||
@@ -289,13 +293,47 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.Agents.Defaults.Workspace
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewClaudeCliProvider(workspace), nil
|
||||
case "codex-cli", "codex-code":
|
||||
workspace := cfg.WorkspacePath()
|
||||
if workspace == "" {
|
||||
workspace = "."
|
||||
}
|
||||
return NewCodexCliProvider(workspace), nil
|
||||
case "deepseek":
|
||||
if cfg.Providers.DeepSeek.APIKey != "" {
|
||||
apiKey = cfg.Providers.DeepSeek.APIKey
|
||||
apiBase = cfg.Providers.DeepSeek.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://api.deepseek.com/v1"
|
||||
}
|
||||
if model != "deepseek-chat" && model != "deepseek-reasoner" {
|
||||
model = "deepseek-chat"
|
||||
}
|
||||
}
|
||||
case "github_copilot", "copilot":
|
||||
if cfg.Providers.GitHubCopilot.APIBase != "" {
|
||||
apiBase = cfg.Providers.GitHubCopilot.APIBase
|
||||
} else {
|
||||
apiBase = "localhost:4321"
|
||||
}
|
||||
return NewGitHubCopilotProvider(apiBase, cfg.Providers.GitHubCopilot.ConnectMode, model)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Fallback: detect provider from model name
|
||||
@@ -371,7 +409,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
if apiBase == "" {
|
||||
apiBase = "https://integrate.api.nvidia.com/v1"
|
||||
}
|
||||
|
||||
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
|
||||
fmt.Println("Ollama provider selected based on model name prefix")
|
||||
apiKey = cfg.Providers.Ollama.APIKey
|
||||
apiBase = cfg.Providers.Ollama.APIBase
|
||||
proxy = cfg.Providers.Ollama.Proxy
|
||||
if apiBase == "" {
|
||||
apiBase = "http://localhost:11434/v1"
|
||||
}
|
||||
fmt.Println("Ollama apiBase:", apiBase)
|
||||
case cfg.Providers.VLLM.APIBase != "":
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// extractToolCallsFromText parses tool call JSON from response text.
|
||||
// Both ClaudeCliProvider and CodexCliProvider use this to extract
|
||||
// tool calls that the model outputs in its response text.
|
||||
func extractToolCallsFromText(text string) []ToolCall {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return nil
|
||||
}
|
||||
|
||||
jsonStr := text[start:end]
|
||||
|
||||
var wrapper struct {
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
} `json:"function"`
|
||||
} `json:"tool_calls"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []ToolCall
|
||||
for _, tc := range wrapper.ToolCalls {
|
||||
var args map[string]interface{}
|
||||
json.Unmarshal([]byte(tc.Function.Arguments), &args)
|
||||
|
||||
result = append(result, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: tc.Type,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: args,
|
||||
Function: &FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// stripToolCallsFromText removes tool call JSON from response text.
|
||||
func stripToolCallsFromText(text string) string {
|
||||
start := strings.Index(text, `{"tool_calls"`)
|
||||
if start == -1 {
|
||||
return text
|
||||
}
|
||||
|
||||
end := findMatchingBrace(text, start)
|
||||
if end == start {
|
||||
return text
|
||||
}
|
||||
|
||||
return strings.TrimSpace(text[:start] + text[end:])
|
||||
}
|
||||
+112
-19
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager {
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetOrCreate(key string) *Session {
|
||||
sm.mu.RLock()
|
||||
session, ok := sm.sessions[key]
|
||||
sm.mu.RUnlock()
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
sm.mu.Lock()
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
sm.mu.Unlock()
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
return session
|
||||
}
|
||||
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
return
|
||||
}
|
||||
|
||||
if keepLast <= 0 {
|
||||
session.Messages = []providers.Message{}
|
||||
session.Updated = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
if len(session.Messages) <= keepLast {
|
||||
return
|
||||
}
|
||||
@@ -138,22 +145,92 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(session *Session) error {
|
||||
// sanitizeFilename converts a session key into a cross-platform safe filename.
|
||||
// Session keys use "channel:chatID" (e.g. "telegram:123456") but ':' is the
|
||||
// volume separator on Windows, so filepath.Base would misinterpret the key.
|
||||
// We replace it with '_'. The original key is preserved inside the JSON file,
|
||||
// so loadSessions still maps back to the right in-memory key.
|
||||
func sanitizeFilename(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(key string) error {
|
||||
if sm.storage == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
filename := sanitizeFilename(key)
|
||||
|
||||
sessionPath := filepath.Join(sm.storage, session.Key+".json")
|
||||
// filepath.IsLocal rejects empty names, "..", absolute paths, and
|
||||
// OS-reserved device names (NUL, COM1 … on Windows).
|
||||
// The extra checks reject "." and any directory separators so that
|
||||
// the session file is always written directly inside sm.storage.
|
||||
if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(session, "", " ")
|
||||
// Snapshot under read lock, then perform slow file I/O after unlock.
|
||||
sm.mu.RLock()
|
||||
stored, ok := sm.sessions[key]
|
||||
if !ok {
|
||||
sm.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
snapshot := Session{
|
||||
Key: stored.Key,
|
||||
Summary: stored.Summary,
|
||||
Created: stored.Created,
|
||||
Updated: stored.Updated,
|
||||
}
|
||||
if len(stored.Messages) > 0 {
|
||||
snapshot.Messages = make([]providers.Message, len(stored.Messages))
|
||||
copy(snapshot.Messages, stored.Messages)
|
||||
} else {
|
||||
snapshot.Messages = []providers.Message{}
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
data, err := json.MarshalIndent(snapshot, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(sessionPath, data, 0644)
|
||||
sessionPath := filepath.Join(sm.storage, filename+".json")
|
||||
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := true
|
||||
defer func() {
|
||||
if cleanup {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Chmod(0644); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, sessionPath); err != nil {
|
||||
return err
|
||||
}
|
||||
cleanup = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) loadSessions() error {
|
||||
@@ -187,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHistory updates the messages of a session.
|
||||
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
// Create a deep copy to strictly isolate internal state
|
||||
// from the caller's slice.
|
||||
msgs := make([]providers.Message, len(history))
|
||||
copy(msgs, history)
|
||||
session.Messages = msgs
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "simple"},
|
||||
{"telegram:123456", "telegram_123456"},
|
||||
{"discord:987654321", "discord_987654321"},
|
||||
{"slack:C01234", "slack_C01234"},
|
||||
{"no-colons-here", "no-colons-here"},
|
||||
{"multiple:colons:here", "multiple_colons_here"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := sanitizeFilename(tt.input)
|
||||
if got != tt.expected {
|
||||
t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_WithColonInKey(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
sm := NewSessionManager(tmpDir)
|
||||
|
||||
// Create a session with a key containing colon (typical channel session key).
|
||||
key := "telegram:123456"
|
||||
sm.GetOrCreate(key)
|
||||
sm.AddMessage(key, "user", "hello")
|
||||
|
||||
// Save should succeed even though the key contains ':'
|
||||
if err := sm.Save(key); err != nil {
|
||||
t.Fatalf("Save(%q) failed: %v", key, err)
|
||||
}
|
||||
|
||||
// The file on disk should use sanitized name.
|
||||
expectedFile := filepath.Join(tmpDir, "telegram_123456.json")
|
||||
if _, err := os.Stat(expectedFile); os.IsNotExist(err) {
|
||||
t.Fatalf("expected session file %s to exist", expectedFile)
|
||||
}
|
||||
|
||||
// Load into a fresh manager and verify the session round-trips.
|
||||
sm2 := NewSessionManager(tmpDir)
|
||||
history := sm2.GetHistory(key)
|
||||
if len(history) != 1 {
|
||||
t.Fatalf("expected 1 message after reload, got %d", len(history))
|
||||
}
|
||||
if history[0].Content != "hello" {
|
||||
t.Errorf("expected message content %q, got %q", "hello", history[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSave_RejectsPathTraversal(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
sm := NewSessionManager(tmpDir)
|
||||
|
||||
badKeys := []string{"", ".", "..", "foo/bar", "foo\\bar"}
|
||||
for _, key := range badKeys {
|
||||
sm.GetOrCreate(key)
|
||||
if err := sm.Save(key); err == nil {
|
||||
t.Errorf("Save(%q) should have failed but didn't", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,13 +2,22 @@ package skills
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
|
||||
|
||||
const (
|
||||
MaxNameLength = 64
|
||||
MaxDescriptionLength = 1024
|
||||
)
|
||||
|
||||
type SkillMetadata struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
@@ -21,6 +30,27 @@ type SkillInfo struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (info SkillInfo) validate() error {
|
||||
var errs error
|
||||
if info.Name == "" {
|
||||
errs = errors.Join(errs, errors.New("name is required"))
|
||||
} else {
|
||||
if len(info.Name) > MaxNameLength {
|
||||
errs = errors.Join(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength))
|
||||
}
|
||||
if !namePattern.MatchString(info.Name) {
|
||||
errs = errors.Join(errs, errors.New("name must be alphanumeric with hyphens"))
|
||||
}
|
||||
}
|
||||
|
||||
if info.Description == "" {
|
||||
errs = errors.Join(errs, errors.New("description is required"))
|
||||
} else if len(info.Description) > MaxDescriptionLength {
|
||||
errs = errors.Join(errs, fmt.Errorf("description exceeds %d character", MaxDescriptionLength))
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
type SkillsLoader struct {
|
||||
workspace string
|
||||
workspaceSkills string // workspace skills (项目级别)
|
||||
@@ -54,6 +84,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
|
||||
metadata := sl.getSkillMetadata(skillFile)
|
||||
if metadata != nil {
|
||||
info.Description = metadata.Description
|
||||
info.Name = metadata.Name
|
||||
}
|
||||
if err := info.validate(); err != nil {
|
||||
slog.Warn("invalid skill from workspace", "name", info.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
skills = append(skills, info)
|
||||
}
|
||||
@@ -89,6 +124,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
|
||||
metadata := sl.getSkillMetadata(skillFile)
|
||||
if metadata != nil {
|
||||
info.Description = metadata.Description
|
||||
info.Name = metadata.Name
|
||||
}
|
||||
if err := info.validate(); err != nil {
|
||||
slog.Warn("invalid skill from global", "name", info.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
skills = append(skills, info)
|
||||
}
|
||||
@@ -123,6 +163,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
|
||||
metadata := sl.getSkillMetadata(skillFile)
|
||||
if metadata != nil {
|
||||
info.Description = metadata.Description
|
||||
info.Name = metadata.Name
|
||||
}
|
||||
if err := info.validate(); err != nil {
|
||||
slog.Warn("invalid skill from builtin", "name", info.Name, "error", err)
|
||||
continue
|
||||
}
|
||||
skills = append(skills, info)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package skills
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSkillsInfoValidate(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
skillName string
|
||||
description string
|
||||
wantErr bool
|
||||
errContains []string
|
||||
}{
|
||||
{
|
||||
name: "valid-skill",
|
||||
skillName: "valid-skill",
|
||||
description: "a valid skill description",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty-name",
|
||||
skillName: "",
|
||||
description: "description without name",
|
||||
wantErr: true,
|
||||
errContains: []string{"name is required"},
|
||||
},
|
||||
{
|
||||
name: "empty-description",
|
||||
skillName: "skill-without-description",
|
||||
description: "",
|
||||
wantErr: true,
|
||||
errContains: []string{"description is required"},
|
||||
},
|
||||
{
|
||||
name: "empty-both",
|
||||
skillName: "",
|
||||
description: "",
|
||||
wantErr: true,
|
||||
errContains: []string{"name is required", "description is required"},
|
||||
},
|
||||
{
|
||||
name: "name-with-spaces",
|
||||
skillName: "skill with spaces",
|
||||
description: "invalid name with spaces",
|
||||
wantErr: true,
|
||||
errContains: []string{"name must be alphanumeric with hyphens"},
|
||||
},
|
||||
{
|
||||
name: "name-with-underscore",
|
||||
skillName: "skill_underscore",
|
||||
description: "invalid name with underscore",
|
||||
wantErr: true,
|
||||
errContains: []string{"name must be alphanumeric with hyphens"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
info := SkillInfo{
|
||||
Name: tc.skillName,
|
||||
Description: tc.description,
|
||||
}
|
||||
err := info.validate()
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
for _, msg := range tc.errContains {
|
||||
assert.ErrorContains(t, err, msg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// State represents the persistent state for a workspace.
|
||||
// It includes information about the last active channel/chat.
|
||||
type State struct {
|
||||
// LastChannel is the last channel used for communication
|
||||
LastChannel string `json:"last_channel,omitempty"`
|
||||
|
||||
// LastChatID is the last chat ID used for communication
|
||||
LastChatID string `json:"last_chat_id,omitempty"`
|
||||
|
||||
// Timestamp is the last time this state was updated
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Manager manages persistent state with atomic saves.
|
||||
type Manager struct {
|
||||
workspace string
|
||||
state *State
|
||||
mu sync.RWMutex
|
||||
stateFile string
|
||||
}
|
||||
|
||||
// NewManager creates a new state manager for the given workspace.
|
||||
func NewManager(workspace string) *Manager {
|
||||
stateDir := filepath.Join(workspace, "state")
|
||||
stateFile := filepath.Join(stateDir, "state.json")
|
||||
oldStateFile := filepath.Join(workspace, "state.json")
|
||||
|
||||
// Create state directory if it doesn't exist
|
||||
os.MkdirAll(stateDir, 0755)
|
||||
|
||||
sm := &Manager{
|
||||
workspace: workspace,
|
||||
stateFile: stateFile,
|
||||
state: &State{},
|
||||
}
|
||||
|
||||
// Try to load from new location first
|
||||
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
|
||||
// New file doesn't exist, try migrating from old location
|
||||
if data, err := os.ReadFile(oldStateFile); err == nil {
|
||||
if err := json.Unmarshal(data, sm.state); err == nil {
|
||||
// Migrate to new location
|
||||
sm.saveAtomic()
|
||||
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Load from new location
|
||||
sm.load()
|
||||
}
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// SetLastChannel atomically updates the last channel and saves the state.
|
||||
// This method uses a temp file + rename pattern for atomic writes,
|
||||
// ensuring that the state file is never corrupted even if the process crashes.
|
||||
func (sm *Manager) SetLastChannel(channel string) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Update state
|
||||
sm.state.LastChannel = channel
|
||||
sm.state.Timestamp = time.Now()
|
||||
|
||||
// Atomic save using temp file + rename
|
||||
if err := sm.saveAtomic(); err != nil {
|
||||
return fmt.Errorf("failed to save state atomically: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLastChatID atomically updates the last chat ID and saves the state.
|
||||
func (sm *Manager) SetLastChatID(chatID string) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Update state
|
||||
sm.state.LastChatID = chatID
|
||||
sm.state.Timestamp = time.Now()
|
||||
|
||||
// Atomic save using temp file + rename
|
||||
if err := sm.saveAtomic(); err != nil {
|
||||
return fmt.Errorf("failed to save state atomically: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastChannel returns the last channel from the state.
|
||||
func (sm *Manager) GetLastChannel() string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state.LastChannel
|
||||
}
|
||||
|
||||
// GetLastChatID returns the last chat ID from the state.
|
||||
func (sm *Manager) GetLastChatID() string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state.LastChatID
|
||||
}
|
||||
|
||||
// GetTimestamp returns the timestamp of the last state update.
|
||||
func (sm *Manager) GetTimestamp() time.Time {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state.Timestamp
|
||||
}
|
||||
|
||||
// saveAtomic performs an atomic save using temp file + rename.
|
||||
// This ensures that the state file is never corrupted:
|
||||
// 1. Write to a temp file
|
||||
// 2. Rename temp file to target (atomic on POSIX systems)
|
||||
// 3. If rename fails, cleanup the temp file
|
||||
//
|
||||
// Must be called with the lock held.
|
||||
func (sm *Manager) saveAtomic() error {
|
||||
// Create temp file in the same directory as the target
|
||||
tempFile := sm.stateFile + ".tmp"
|
||||
|
||||
// Marshal state to JSON
|
||||
data, err := json.MarshalIndent(sm.state, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal state: %w", err)
|
||||
}
|
||||
|
||||
// Write to temp file
|
||||
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename from temp to target
|
||||
if err := os.Rename(tempFile, sm.stateFile); err != nil {
|
||||
// Cleanup temp file if rename fails
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("failed to rename temp file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// load loads the state from disk.
|
||||
func (sm *Manager) load() error {
|
||||
data, err := os.ReadFile(sm.stateFile)
|
||||
if err != nil {
|
||||
// File doesn't exist yet, that's OK
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read state file: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, sm.state); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAtomicSave(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Test SetLastChannel
|
||||
err = sm.SetLastChannel("test-channel")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the channel was saved
|
||||
lastChannel := sm.GetLastChannel()
|
||||
if lastChannel != "test-channel" {
|
||||
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
|
||||
}
|
||||
|
||||
// Verify timestamp was updated
|
||||
if sm.GetTimestamp().IsZero() {
|
||||
t.Error("Expected timestamp to be updated")
|
||||
}
|
||||
|
||||
// Verify state file exists
|
||||
stateFile := filepath.Join(tmpDir, "state", "state.json")
|
||||
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
|
||||
t.Error("Expected state file to exist")
|
||||
}
|
||||
|
||||
// Create a new manager to verify persistence
|
||||
sm2 := NewManager(tmpDir)
|
||||
if sm2.GetLastChannel() != "test-channel" {
|
||||
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLastChatID(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Test SetLastChatID
|
||||
err = sm.SetLastChatID("test-chat-id")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the chat ID was saved
|
||||
lastChatID := sm.GetLastChatID()
|
||||
if lastChatID != "test-chat-id" {
|
||||
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
|
||||
}
|
||||
|
||||
// Verify timestamp was updated
|
||||
if sm.GetTimestamp().IsZero() {
|
||||
t.Error("Expected timestamp to be updated")
|
||||
}
|
||||
|
||||
// Create a new manager to verify persistence
|
||||
sm2 := NewManager(tmpDir)
|
||||
if sm2.GetLastChatID() != "test-chat-id" {
|
||||
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Write initial state
|
||||
err = sm.SetLastChannel("initial-channel")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Simulate a crash scenario by manually creating a corrupted temp file
|
||||
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
|
||||
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the original state is still intact
|
||||
lastChannel := sm.GetLastChannel()
|
||||
if lastChannel != "initial-channel" {
|
||||
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
|
||||
}
|
||||
|
||||
// Clean up the temp file manually
|
||||
os.Remove(tempFile)
|
||||
|
||||
// Now do a proper save
|
||||
err = sm.SetLastChannel("new-channel")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the new state was saved
|
||||
if sm.GetLastChannel() != "new-channel" {
|
||||
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Test concurrent writes
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(idx int) {
|
||||
channel := fmt.Sprintf("channel-%d", idx)
|
||||
sm.SetLastChannel(channel)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify the final state is consistent
|
||||
lastChannel := sm.GetLastChannel()
|
||||
if lastChannel == "" {
|
||||
t.Error("Expected non-empty channel after concurrent writes")
|
||||
}
|
||||
|
||||
// Verify state file is valid JSON
|
||||
stateFile := filepath.Join(tmpDir, "state", "state.json")
|
||||
data, err := os.ReadFile(stateFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read state file: %v", err)
|
||||
}
|
||||
|
||||
var state State
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("State file contains invalid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_ExistingState(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create initial state
|
||||
sm1 := NewManager(tmpDir)
|
||||
sm1.SetLastChannel("existing-channel")
|
||||
sm1.SetLastChatID("existing-chat-id")
|
||||
|
||||
// Create new manager with same workspace
|
||||
sm2 := NewManager(tmpDir)
|
||||
|
||||
// Verify state was loaded
|
||||
if sm2.GetLastChannel() != "existing-channel" {
|
||||
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
|
||||
}
|
||||
|
||||
if sm2.GetLastChatID() != "existing-chat-id" {
|
||||
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_EmptyWorkspace(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Verify default state
|
||||
if sm.GetLastChannel() != "" {
|
||||
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
|
||||
}
|
||||
|
||||
if sm.GetLastChatID() != "" {
|
||||
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
|
||||
}
|
||||
|
||||
if !sm.GetTimestamp().IsZero() {
|
||||
t.Error("Expected zero timestamp for new state")
|
||||
}
|
||||
}
|
||||
+54
-1
@@ -2,11 +2,12 @@ package tools
|
||||
|
||||
import "context"
|
||||
|
||||
// Tool is the interface that all tools must implement.
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() map[string]interface{}
|
||||
Execute(ctx context.Context, args map[string]interface{}) (string, error)
|
||||
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
@@ -16,6 +17,58 @@ type ContextualTool interface {
|
||||
SetContext(channel, chatID string)
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
// When an async tool finishes its work, it calls this callback with the result.
|
||||
//
|
||||
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
|
||||
// The result parameter contains the tool's execution result.
|
||||
//
|
||||
// Example usage in an async tool:
|
||||
//
|
||||
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// // Start async work in background
|
||||
// go func() {
|
||||
// result := doAsyncWork()
|
||||
// if t.callback != nil {
|
||||
// t.callback(ctx, result)
|
||||
// }
|
||||
// }()
|
||||
// return AsyncResult("Async task started")
|
||||
// }
|
||||
type AsyncCallback func(ctx context.Context, result *ToolResult)
|
||||
|
||||
// AsyncTool is an optional interface that tools can implement to support
|
||||
// asynchronous execution with completion callbacks.
|
||||
//
|
||||
// Async tools return immediately with an AsyncResult, then notify completion
|
||||
// via the callback set by SetCallback.
|
||||
//
|
||||
// This is useful for:
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type SpawnTool struct {
|
||||
// callback AsyncCallback
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
// t.callback = cb
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// go t.runSubagent(ctx, args)
|
||||
// return AsyncResult("Subagent spawned, will report back")
|
||||
// }
|
||||
type AsyncTool interface {
|
||||
Tool
|
||||
// SetCallback registers a callback function to be invoked when the async operation completes.
|
||||
// The callback will be called from a goroutine and should handle thread-safety if needed.
|
||||
SetCallback(cb AsyncCallback)
|
||||
}
|
||||
|
||||
func ToolToSchema(tool Tool) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "function",
|
||||
|
||||
+35
-31
@@ -1,4 +1,4 @@
|
||||
package tools
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -28,12 +28,15 @@ type CronTool struct {
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
|
||||
// execTimeout: 0 means no timeout, >0 sets the timeout duration
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *CronTool {
|
||||
execTool := NewExecTool(workspace, restrict)
|
||||
execTool.SetTimeout(execTimeout) // 0 means no timeout
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: NewExecTool(workspace, false),
|
||||
execTool: execTool,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +86,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -98,11 +101,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
// Execute runs the tool with the given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("action is required")
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
@@ -117,23 +120,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
case "disable":
|
||||
return t.enableJob(args, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown action: %s", action)
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
if !ok || message == "" {
|
||||
return "Error: message is required for add", nil
|
||||
return ErrorResult("message is required for add")
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
@@ -162,7 +165,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
Expr: cronExpr,
|
||||
}
|
||||
} else {
|
||||
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
@@ -192,23 +195,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error adding job: %v", err), nil
|
||||
return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
|
||||
}
|
||||
|
||||
|
||||
if command != "" {
|
||||
job.Payload.Command = command
|
||||
// Need to save the updated payload
|
||||
t.cronService.UpdateJob(job)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
|
||||
}
|
||||
|
||||
func (t *CronTool) listJobs() (string, error) {
|
||||
func (t *CronTool) listJobs() *ToolResult {
|
||||
jobs := t.cronService.ListJobs(false)
|
||||
|
||||
if len(jobs) == 0 {
|
||||
return "No scheduled jobs.", nil
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
@@ -226,37 +229,37 @@ func (t *CronTool) listJobs() (string, error) {
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return SilentResult(result)
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for remove", nil
|
||||
return ErrorResult("job_id is required for remove")
|
||||
}
|
||||
|
||||
if t.cronService.RemoveJob(jobID) {
|
||||
return fmt.Sprintf("Removed job %s", jobID), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
|
||||
}
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for enable/disable", nil
|
||||
return ErrorResult("job_id is required for enable/disable")
|
||||
}
|
||||
|
||||
job := t.cronService.EnableJob(jobID, enable)
|
||||
if job == nil {
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !enable {
|
||||
status = "disabled"
|
||||
}
|
||||
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
|
||||
}
|
||||
|
||||
// ExecuteJob executes a cron job through the agent
|
||||
@@ -279,11 +282,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
"command": job.Payload.Command,
|
||||
}
|
||||
|
||||
output, err := t.execTool.Execute(ctx, args)
|
||||
if err != nil {
|
||||
output = fmt.Sprintf("Error executing scheduled command: %v", err)
|
||||
result := t.execTool.Execute(ctx, args)
|
||||
var output string
|
||||
if result.IsError {
|
||||
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
|
||||
} else {
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output)
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
|
||||
}
|
||||
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
@@ -307,7 +311,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// For deliver=false, process through agent (for complex tasks)
|
||||
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||
|
||||
// Call agent with the job's message
|
||||
// Call agent with job's message
|
||||
response, err := t.executor.ProcessDirectWithChannel(
|
||||
ctx,
|
||||
job.Payload.Message,
|
||||
|
||||
+18
-18
@@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
oldText, ok := args["old_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("old_text is required")
|
||||
return ErrorResult("old_text is required")
|
||||
}
|
||||
|
||||
newText, ok := args["new_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("new_text is required")
|
||||
return ErrorResult("new_text is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("file not found: %s", path)
|
||||
return ErrorResult(fmt.Sprintf("file not found: %s", path))
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
if !strings.Contains(contentStr, oldText) {
|
||||
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
|
||||
return ErrorResult("old_text not found in file. Make sure it matches exactly")
|
||||
}
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully edited %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("File edited: %s", path))
|
||||
}
|
||||
|
||||
type AppendFileTool struct {
|
||||
@@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
return "", fmt.Errorf("failed to append to file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully appended to %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("Appended to %s", path))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,289 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEditTool_EditFile_Success verifies successful file editing
|
||||
func TestEditTool_EditFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "World",
|
||||
"new_text": "Universe",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should return SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for EditFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify file was actually edited
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read edited file: %v", err)
|
||||
}
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "Hello Universe") {
|
||||
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
|
||||
}
|
||||
if strings.Contains(contentStr, "Hello World") {
|
||||
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
|
||||
func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "nonexistent.txt")
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for non-existent file")
|
||||
}
|
||||
|
||||
// Should mention file not found
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
|
||||
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "Goodbye",
|
||||
"new_text": "Hello",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text not found")
|
||||
}
|
||||
|
||||
// Should mention old_text not found
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
|
||||
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test test test"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "test",
|
||||
"new_text": "done",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text appears multiple times")
|
||||
}
|
||||
|
||||
// Should mention multiple occurrences
|
||||
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
|
||||
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
|
||||
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
otherDir := t.TempDir()
|
||||
testFile := filepath.Join(otherDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("content"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "content",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is outside allowed directory")
|
||||
}
|
||||
|
||||
// Should mention outside allowed directory
|
||||
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
|
||||
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_EditFile_MissingPath(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
|
||||
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
|
||||
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
"old_text": "old",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when new_text is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_Success verifies successful file appending
|
||||
func TestEditTool_AppendFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Initial content"), 0644)
|
||||
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "\nAppended content",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should return SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for AppendFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify content was actually appended
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read file: %v", err)
|
||||
}
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "Initial content") {
|
||||
t.Errorf("Expected original content to remain, got: %s", contentStr)
|
||||
}
|
||||
if !strings.Contains(contentStr, "Appended content") {
|
||||
t.Errorf("Expected appended content, got: %s", contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
|
||||
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
}
|
||||
+59
-18
@@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
if restrict {
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
|
||||
workspaceReal := absWorkspace
|
||||
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
|
||||
workspaceReal = resolved
|
||||
}
|
||||
|
||||
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
|
||||
if !isWithinWorkspace(resolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if os.IsNotExist(err) {
|
||||
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
|
||||
if !isWithinWorkspace(parentResolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
return resolved, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
@@ -66,23 +107,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
return NewToolResult(string(content))
|
||||
}
|
||||
|
||||
type WriteFileTool struct {
|
||||
@@ -119,32 +160,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
dir := filepath.Dir(resolvedPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return "File written successfully", nil
|
||||
return SilentResult(fmt.Sprintf("File written: %s", path))
|
||||
}
|
||||
|
||||
type ListDirTool struct {
|
||||
@@ -177,7 +218,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
path = "."
|
||||
@@ -185,12 +226,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
|
||||
}
|
||||
|
||||
result := ""
|
||||
@@ -202,5 +243,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return NewToolResult(result)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFilesystemTool_ReadFile_Success verifies successful file reading
|
||||
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain file content
|
||||
if !strings.Contains(result.ForLLM, "test content") {
|
||||
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
|
||||
// This is the expected behavior - file content goes to LLM, not directly to user
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
|
||||
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/nonexistent_file_12345.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for missing file, got IsError=false")
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
|
||||
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_Success verifies successful file writing
|
||||
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "hello world",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// WriteFile returns SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for WriteFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify file was actually written
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(content) != "hello world" {
|
||||
t.Errorf("Expected file content 'hello world', got: %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
|
||||
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Verify directory was created and file written
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(content) != "test" {
|
||||
t.Errorf("Expected file content 'test', got: %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
|
||||
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
|
||||
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_Success verifies successful directory listing
|
||||
func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
|
||||
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
|
||||
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": tmpDir,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should list files and directories
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
|
||||
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subdir") {
|
||||
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
|
||||
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for non-existent directory, got IsError=false")
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
|
||||
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should use "." as default path
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Block paths that look inside workspace but point outside via symlink.
|
||||
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
|
||||
root := t.TempDir()
|
||||
workspace := filepath.Join(root, "workspace")
|
||||
if err := os.MkdirAll(workspace, 0755); err != nil {
|
||||
t.Fatalf("failed to create workspace: %v", err)
|
||||
}
|
||||
|
||||
secret := filepath.Join(root, "secret.txt")
|
||||
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
link := filepath.Join(workspace, "leak.txt")
|
||||
if err := os.Symlink(secret, link); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(workspace, true)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"path": link,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// I2CTool provides I2C bus interaction for reading sensors and controlling peripherals.
|
||||
type I2CTool struct{}
|
||||
|
||||
func NewI2CTool() *I2CTool {
|
||||
return &I2CTool{}
|
||||
}
|
||||
|
||||
func (t *I2CTool) Name() string {
|
||||
return "i2c"
|
||||
}
|
||||
|
||||
func (t *I2CTool) Description() string {
|
||||
return "Interact with I2C bus devices for reading sensors and controlling peripherals. Actions: detect (list buses), scan (find devices on a bus), read (read bytes from device), write (send bytes to device). Linux only."
|
||||
}
|
||||
|
||||
func (t *I2CTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"detect", "scan", "read", "write"},
|
||||
"description": "Action to perform: detect (list available I2C buses), scan (find devices on a bus), read (read bytes from a device), write (send bytes to a device)",
|
||||
},
|
||||
"bus": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "I2C bus number (e.g. \"1\" for /dev/i2c-1). Required for scan/read/write.",
|
||||
},
|
||||
"address": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "7-bit I2C device address (0x03-0x77). Required for read/write.",
|
||||
},
|
||||
"register": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Register address to read from or write to. If set, sends register byte before read/write.",
|
||||
},
|
||||
"data": map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{"type": "integer"},
|
||||
"description": "Bytes to write (0-255 each). Required for write action.",
|
||||
},
|
||||
"length": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Number of bytes to read (1-256). Default: 1. Used with read action.",
|
||||
},
|
||||
"confirm": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "Must be true for write operations. Safety guard to prevent accidental writes.",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *I2CTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if runtime.GOOS != "linux" {
|
||||
return ErrorResult("I2C is only supported on Linux. This tool requires /dev/i2c-* device files.")
|
||||
}
|
||||
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "detect":
|
||||
return t.detect()
|
||||
case "scan":
|
||||
return t.scan(args)
|
||||
case "read":
|
||||
return t.readDevice(args)
|
||||
case "write":
|
||||
return t.writeDevice(args)
|
||||
default:
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s (valid: detect, scan, read, write)", action))
|
||||
}
|
||||
}
|
||||
|
||||
// detect lists available I2C buses by globbing /dev/i2c-*
|
||||
func (t *I2CTool) detect() *ToolResult {
|
||||
matches, err := filepath.Glob("/dev/i2c-*")
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to scan for I2C buses: %v", err))
|
||||
}
|
||||
|
||||
if len(matches) == 0 {
|
||||
return SilentResult("No I2C buses found. You may need to:\n1. Load the i2c-dev module: modprobe i2c-dev\n2. Check that I2C is enabled in device tree\n3. Configure pinmux for your board (see hardware skill)")
|
||||
}
|
||||
|
||||
type busInfo struct {
|
||||
Path string `json:"path"`
|
||||
Bus string `json:"bus"`
|
||||
}
|
||||
|
||||
buses := make([]busInfo, 0, len(matches))
|
||||
re := regexp.MustCompile(`/dev/i2c-(\d+)`)
|
||||
for _, m := range matches {
|
||||
if sub := re.FindStringSubmatch(m); sub != nil {
|
||||
buses = append(buses, busInfo{Path: m, Bus: sub[1]})
|
||||
}
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(buses, "", " ")
|
||||
return SilentResult(fmt.Sprintf("Found %d I2C bus(es):\n%s", len(buses), string(result)))
|
||||
}
|
||||
|
||||
// isValidBusID checks that a bus identifier is a simple number (prevents path injection)
|
||||
func isValidBusID(id string) bool {
|
||||
matched, _ := regexp.MatchString(`^\d+$`, id)
|
||||
return matched
|
||||
}
|
||||
|
||||
// parseI2CAddress extracts and validates an I2C address from args
|
||||
func parseI2CAddress(args map[string]interface{}) (int, *ToolResult) {
|
||||
addrFloat, ok := args["address"].(float64)
|
||||
if !ok {
|
||||
return 0, ErrorResult("address is required (e.g. 0x38 for AHT20)")
|
||||
}
|
||||
addr := int(addrFloat)
|
||||
if addr < 0x03 || addr > 0x77 {
|
||||
return 0, ErrorResult("address must be in valid 7-bit range (0x03-0x77)")
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
// parseI2CBus extracts and validates an I2C bus from args
|
||||
func parseI2CBus(args map[string]interface{}) (string, *ToolResult) {
|
||||
bus, ok := args["bus"].(string)
|
||||
if !ok || bus == "" {
|
||||
return "", ErrorResult("bus is required (e.g. \"1\" for /dev/i2c-1)")
|
||||
}
|
||||
if !isValidBusID(bus) {
|
||||
return "", ErrorResult("invalid bus identifier: must be a number (e.g. \"1\")")
|
||||
}
|
||||
return bus, nil
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// I2C ioctl constants from Linux kernel headers (<linux/i2c-dev.h>, <linux/i2c.h>)
|
||||
const (
|
||||
i2cSlave = 0x0703 // Set slave address (fails if in use by driver)
|
||||
i2cFuncs = 0x0705 // Query adapter functionality bitmask
|
||||
i2cSmbus = 0x0720 // Perform SMBus transaction
|
||||
|
||||
// I2C_FUNC capability bits
|
||||
i2cFuncSmbusQuick = 0x00010000
|
||||
i2cFuncSmbusReadByte = 0x00020000
|
||||
|
||||
// SMBus transaction types
|
||||
i2cSmbusRead = 0
|
||||
i2cSmbusWrite = 1
|
||||
|
||||
// SMBus protocol sizes
|
||||
i2cSmbusQuick = 0
|
||||
i2cSmbusByte = 1
|
||||
)
|
||||
|
||||
// i2cSmbusData matches the kernel union i2c_smbus_data (34 bytes max).
|
||||
// For quick and byte transactions only the first byte is used (if at all).
|
||||
type i2cSmbusData [34]byte
|
||||
|
||||
// i2cSmbusArgs matches the kernel struct i2c_smbus_ioctl_data.
|
||||
type i2cSmbusArgs struct {
|
||||
readWrite uint8
|
||||
command uint8
|
||||
size uint32
|
||||
data *i2cSmbusData
|
||||
}
|
||||
|
||||
// smbusProbe performs a single SMBus probe at the given address.
|
||||
// Uses SMBus Quick Write (safest) or falls back to SMBus Read Byte for
|
||||
// EEPROM address ranges where quick write can corrupt AT24RF08 chips.
|
||||
// This matches i2cdetect's MODE_AUTO behavior.
|
||||
func smbusProbe(fd int, addr int, hasQuick bool) bool {
|
||||
// EEPROM ranges: use read byte (quick write can corrupt AT24RF08)
|
||||
useReadByte := (addr >= 0x30 && addr <= 0x37) || (addr >= 0x50 && addr <= 0x5F)
|
||||
|
||||
if !useReadByte && hasQuick {
|
||||
// SMBus Quick Write: [START] [ADDR|W] [ACK/NACK] [STOP]
|
||||
// Safest probe — no data transferred
|
||||
args := i2cSmbusArgs{
|
||||
readWrite: i2cSmbusWrite,
|
||||
command: 0,
|
||||
size: i2cSmbusQuick,
|
||||
data: nil,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
// SMBus Read Byte: [START] [ADDR|R] [ACK/NACK] [DATA] [STOP]
|
||||
var data i2cSmbusData
|
||||
args := i2cSmbusArgs{
|
||||
readWrite: i2cSmbusRead,
|
||||
command: 0,
|
||||
size: i2cSmbusByte,
|
||||
data: &data,
|
||||
}
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSmbus, uintptr(unsafe.Pointer(&args)))
|
||||
return errno == 0
|
||||
}
|
||||
|
||||
// scan probes valid 7-bit addresses on a bus for connected devices.
|
||||
// Uses the same hybrid probe strategy as i2cdetect's MODE_AUTO:
|
||||
// SMBus Quick Write for most addresses, SMBus Read Byte for EEPROM ranges.
|
||||
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v (check permissions and i2c-dev module)", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Query adapter capabilities to determine available probe methods.
|
||||
// I2C_FUNCS writes an unsigned long, which is word-sized on Linux.
|
||||
var funcs uintptr
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cFuncs, uintptr(unsafe.Pointer(&funcs)))
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("failed to query I2C adapter capabilities on %s: %v", devPath, errno))
|
||||
}
|
||||
|
||||
hasQuick := funcs&i2cFuncSmbusQuick != 0
|
||||
hasReadByte := funcs&i2cFuncSmbusReadByte != 0
|
||||
|
||||
if !hasQuick && !hasReadByte {
|
||||
return ErrorResult(fmt.Sprintf("I2C adapter %s supports neither SMBus Quick nor Read Byte — cannot probe safely", devPath))
|
||||
}
|
||||
|
||||
type deviceEntry struct {
|
||||
Address string `json:"address"`
|
||||
Status string `json:"status,omitempty"`
|
||||
}
|
||||
|
||||
var found []deviceEntry
|
||||
// Scan 0x08-0x77, skipping I2C reserved addresses 0x00-0x07
|
||||
for addr := 0x08; addr <= 0x77; addr++ {
|
||||
// Set slave address — EBUSY means a kernel driver owns this address
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr))
|
||||
if errno != 0 {
|
||||
if errno == syscall.EBUSY {
|
||||
found = append(found, deviceEntry{
|
||||
Address: fmt.Sprintf("0x%02x", addr),
|
||||
Status: "busy (in use by kernel driver)",
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if smbusProbe(fd, addr, hasQuick) {
|
||||
found = append(found, deviceEntry{
|
||||
Address: fmt.Sprintf("0x%02x", addr),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(found) == 0 {
|
||||
return SilentResult(fmt.Sprintf("No devices found on %s. Check wiring and pull-up resistors.", devPath))
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
"bus": devPath,
|
||||
"devices": found,
|
||||
"count": len(found),
|
||||
}, "", " ")
|
||||
return SilentResult(fmt.Sprintf("Scan of %s:\n%s", devPath, string(result)))
|
||||
}
|
||||
|
||||
// readDevice reads bytes from an I2C device, optionally at a specific register
|
||||
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
addr, errResult := parseI2CAddress(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
length := 1
|
||||
if l, ok := args["length"].(float64); ok {
|
||||
length = int(l)
|
||||
}
|
||||
if length < 1 || length > 256 {
|
||||
return ErrorResult("length must be between 1 and 256")
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Set slave address
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr))
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno))
|
||||
}
|
||||
|
||||
// If register is specified, write it first
|
||||
if regFloat, ok := args["register"].(float64); ok {
|
||||
reg := int(regFloat)
|
||||
if reg < 0 || reg > 255 {
|
||||
return ErrorResult("register must be between 0x00 and 0xFF")
|
||||
}
|
||||
_, err := syscall.Write(fd, []byte{byte(reg)})
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write register 0x%02x: %v", reg, err))
|
||||
}
|
||||
}
|
||||
|
||||
// Read data
|
||||
buf := make([]byte, length)
|
||||
n, err := syscall.Read(fd, buf)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to read from device 0x%02x: %v", addr, err))
|
||||
}
|
||||
|
||||
// Format as hex bytes
|
||||
hexBytes := make([]string, n)
|
||||
intBytes := make([]int, n)
|
||||
for i := 0; i < n; i++ {
|
||||
hexBytes[i] = fmt.Sprintf("0x%02x", buf[i])
|
||||
intBytes[i] = int(buf[i])
|
||||
}
|
||||
|
||||
result, _ := json.MarshalIndent(map[string]interface{}{
|
||||
"bus": devPath,
|
||||
"address": fmt.Sprintf("0x%02x", addr),
|
||||
"bytes": intBytes,
|
||||
"hex": hexBytes,
|
||||
"length": n,
|
||||
}, "", " ")
|
||||
return SilentResult(string(result))
|
||||
}
|
||||
|
||||
// writeDevice writes bytes to an I2C device, optionally at a specific register
|
||||
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
confirm, _ := args["confirm"].(bool)
|
||||
if !confirm {
|
||||
return ErrorResult("write operations require confirm: true. Please confirm with the user before writing to I2C devices, as incorrect writes can misconfigure hardware.")
|
||||
}
|
||||
|
||||
bus, errResult := parseI2CBus(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
addr, errResult := parseI2CAddress(args)
|
||||
if errResult != nil {
|
||||
return errResult
|
||||
}
|
||||
|
||||
dataRaw, ok := args["data"].([]interface{})
|
||||
if !ok || len(dataRaw) == 0 {
|
||||
return ErrorResult("data is required for write (array of byte values 0-255)")
|
||||
}
|
||||
if len(dataRaw) > 256 {
|
||||
return ErrorResult("data too long: maximum 256 bytes per I2C transaction")
|
||||
}
|
||||
|
||||
data := make([]byte, 0, len(dataRaw)+1)
|
||||
|
||||
// If register is specified, prepend it to the data
|
||||
if regFloat, ok := args["register"].(float64); ok {
|
||||
reg := int(regFloat)
|
||||
if reg < 0 || reg > 255 {
|
||||
return ErrorResult("register must be between 0x00 and 0xFF")
|
||||
}
|
||||
data = append(data, byte(reg))
|
||||
}
|
||||
|
||||
for i, v := range dataRaw {
|
||||
f, ok := v.(float64)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("data[%d] is not a valid byte value", i))
|
||||
}
|
||||
b := int(f)
|
||||
if b < 0 || b > 255 {
|
||||
return ErrorResult(fmt.Sprintf("data[%d] = %d is out of byte range (0-255)", i, b))
|
||||
}
|
||||
data = append(data, byte(b))
|
||||
}
|
||||
|
||||
devPath := fmt.Sprintf("/dev/i2c-%s", bus)
|
||||
fd, err := syscall.Open(devPath, syscall.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to open %s: %v", devPath, err))
|
||||
}
|
||||
defer syscall.Close(fd)
|
||||
|
||||
// Set slave address
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), i2cSlave, uintptr(addr))
|
||||
if errno != 0 {
|
||||
return ErrorResult(fmt.Sprintf("failed to set I2C address 0x%02x: %v", addr, errno))
|
||||
}
|
||||
|
||||
// Write data
|
||||
n, err := syscall.Write(fd, data)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("failed to write to device 0x%02x: %v", addr, err))
|
||||
}
|
||||
|
||||
return SilentResult(fmt.Sprintf("Wrote %d byte(s) to device 0x%02x on %s", n, addr, devPath))
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//go:build !linux
|
||||
|
||||
package tools
|
||||
|
||||
// scan is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) scan(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
// readDevice is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) readDevice(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
|
||||
// writeDevice is a stub for non-Linux platforms.
|
||||
func (t *I2CTool) writeDevice(args map[string]interface{}) *ToolResult {
|
||||
return ErrorResult("I2C is only supported on Linux")
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/mcp"
|
||||
)
|
||||
|
||||
type MCPTool struct {
|
||||
manager *mcp.Manager
|
||||
name string
|
||||
description string
|
||||
parameters map[string]any
|
||||
}
|
||||
|
||||
func NewMCPTool(manager *mcp.Manager, tool mcp.RegisteredTool) *MCPTool {
|
||||
description := tool.Description
|
||||
if description == "" {
|
||||
description = fmt.Sprintf("MCP tool %s from server %s", tool.ToolName, tool.ServerName)
|
||||
}
|
||||
|
||||
return &MCPTool{
|
||||
manager: manager,
|
||||
name: tool.QualifiedName,
|
||||
description: description,
|
||||
parameters: tool.Parameters,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MCPTool) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *MCPTool) Description() string {
|
||||
return t.description
|
||||
}
|
||||
|
||||
func (t *MCPTool) Parameters() map[string]interface{} {
|
||||
return t.parameters
|
||||
}
|
||||
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("MCP manager is not configured")
|
||||
}
|
||||
|
||||
result, err := t.manager.CallTool(ctx, t.name, args)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("MCP tool %s failed: %v", t.name, err)).WithError(err)
|
||||
}
|
||||
if result.IsError {
|
||||
err := errors.New(result.Content)
|
||||
return ErrorResult(result.Content).WithError(err)
|
||||
}
|
||||
return SilentResult(result.Content)
|
||||
}
|
||||
|
||||
// RegisterMCPTools discovers tools from MCP servers and registers them into the registry.
|
||||
func RegisterMCPTools(ctx context.Context, registry *ToolRegistry, manager *mcp.Manager) (int, error) {
|
||||
if registry == nil || manager == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
discoveredTools, err := manager.DiscoverTools(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return RegisterKnownMCPTools(registry, manager, discoveredTools), nil
|
||||
}
|
||||
|
||||
// RegisterKnownMCPTools registers already-discovered MCP tools.
|
||||
// This avoids repeated discovery work when multiple registries share one manager.
|
||||
func RegisterKnownMCPTools(registry *ToolRegistry, manager *mcp.Manager, discoveredTools []mcp.RegisteredTool) int {
|
||||
if registry == nil || manager == nil || len(discoveredTools) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
for _, tool := range discoveredTools {
|
||||
registry.Register(NewMCPTool(manager, tool))
|
||||
}
|
||||
return len(discoveredTools)
|
||||
}
|
||||
+22
-6
@@ -11,6 +11,7 @@ type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound = false // Reset send tracking for new processing round
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: No target channel/chat specified", nil
|
||||
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
|
||||
}
|
||||
|
||||
if t.sendCallback == nil {
|
||||
return "Error: Message sending not configured", nil
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content); err != nil {
|
||||
return fmt.Sprintf("Error sending message: %v", err), nil
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
|
||||
t.sentInRound = true
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,259 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify message was sent with correct parameters
|
||||
if sentChannel != "test-channel" {
|
||||
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "test-chat-id" {
|
||||
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
if sentContent != "Hello, world!" {
|
||||
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
|
||||
}
|
||||
|
||||
// Verify ToolResult meets US-011 criteria:
|
||||
// - Send success returns SilentResult (Silent=true)
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true for successful send")
|
||||
}
|
||||
|
||||
// - ForLLM contains send status description
|
||||
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
|
||||
// - ForUser is empty (user already received message directly)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
|
||||
}
|
||||
|
||||
// - IsError should be false
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError=false for successful send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("default-channel", "default-chat-id")
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
"chat_id": "custom-chat-id",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify custom channel/chatID were used instead of defaults
|
||||
if sentChannel != "custom-channel" {
|
||||
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "custom-chat-id" {
|
||||
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true")
|
||||
}
|
||||
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify ToolResult for send failure:
|
||||
// - Send failure returns ErrorResult (IsError=true)
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for failed send")
|
||||
}
|
||||
|
||||
// - ForLLM contains error description
|
||||
expectedErrMsg := "sending message: network error"
|
||||
if result.ForLLM != expectedErrMsg {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
|
||||
}
|
||||
|
||||
// - Err field should contain original error
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err != sendErr {
|
||||
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error result for missing content
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for missing content")
|
||||
}
|
||||
if result.ForLLM != "content is required" {
|
||||
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No SetContext called, so defaultChannel and defaultChatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when no target channel specified
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when no target channel")
|
||||
}
|
||||
if result.ForLLM != "No target channel/chat specified" {
|
||||
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when send callback not configured
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when send callback not configured")
|
||||
}
|
||||
if result.ForLLM != "Message sending not configured" {
|
||||
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Name(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
if tool.Name() != "message" {
|
||||
t.Errorf("Expected name 'message', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Description(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Parameters(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
params := tool.Parameters()
|
||||
|
||||
// Verify parameters structure
|
||||
typ, ok := params["type"].(string)
|
||||
if !ok || typ != "object" {
|
||||
t.Error("Expected type 'object'")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected properties to be a map")
|
||||
}
|
||||
|
||||
// Check required properties
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok || len(required) != 1 || required[0] != "content" {
|
||||
t.Error("Expected 'content' to be required")
|
||||
}
|
||||
|
||||
// Check content property
|
||||
contentProp, ok := props["content"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'content' property")
|
||||
}
|
||||
if contentProp["type"] != "string" {
|
||||
t.Error("Expected content type to be 'string'")
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'channel' property")
|
||||
}
|
||||
if channelProp["type"] != "string" {
|
||||
t.Error("Expected channel type to be 'string'")
|
||||
}
|
||||
|
||||
// Check chat_id property (optional)
|
||||
chatIDProp, ok := props["chat_id"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'chat_id' property")
|
||||
}
|
||||
if chatIDProp["type"] != "string" {
|
||||
t.Error("Expected chat_id type to be 'string'")
|
||||
}
|
||||
}
|
||||
+61
-9
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type ToolRegistry struct {
|
||||
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "")
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
|
||||
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
|
||||
// If the tool implements AsyncTool and a non-nil callback is provided,
|
||||
// the callback will be set on the tool before execution.
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
|
||||
logger.InfoCF("tool", "Tool execution started",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
return "", fmt.Errorf("tool '%s' not found", name)
|
||||
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
}
|
||||
|
||||
// If tool implements ContextualTool, set context
|
||||
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
contextualTool.SetContext(channel, chatID)
|
||||
}
|
||||
|
||||
// If tool implements AsyncTool and callback is provided, set callback
|
||||
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
|
||||
asyncTool.SetCallback(asyncCallback)
|
||||
logger.DebugCF("tool", "Async callback injected",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := tool.Execute(ctx, args)
|
||||
result := tool.Execute(ctx, args)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Log based on result type
|
||||
if result.IsError {
|
||||
logger.ErrorCF("tool", "Tool execution failed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
"error": result.ForLLM,
|
||||
})
|
||||
} else if result.Async {
|
||||
logger.InfoCF("tool", "Tool started (async)",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("tool", "Tool execution completed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"result_length": len(result),
|
||||
"result_length": len(result.ForLLM),
|
||||
})
|
||||
}
|
||||
|
||||
return result, err
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
return definitions
|
||||
}
|
||||
|
||||
// ToProviderDefs converts tool definitions to provider-compatible format.
|
||||
// This is the format expected by LLM provider APIs.
|
||||
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
|
||||
for _, tool := range r.tools {
|
||||
schema := ToolToSchema(tool)
|
||||
|
||||
// Safely extract nested values with type checks
|
||||
fn, ok := schema["function"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params, _ := fn["parameters"].(map[string]interface{})
|
||||
|
||||
definitions = append(definitions, providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
Parameters: params,
|
||||
},
|
||||
})
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
// List returns a list of all registered tool names.
|
||||
func (r *ToolRegistry) List() []string {
|
||||
r.mu.RLock()
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package tools
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ToolResult represents the structured return value from tool execution.
|
||||
// It provides clear semantics for different types of results and supports
|
||||
// async operations, user-facing messages, and error handling.
|
||||
type ToolResult struct {
|
||||
// ForLLM is the content sent to the LLM for context.
|
||||
// Required for all results.
|
||||
ForLLM string `json:"for_llm"`
|
||||
|
||||
// ForUser is the content sent directly to the user.
|
||||
// If empty, no user message is sent.
|
||||
// Silent=true overrides this field.
|
||||
ForUser string `json:"for_user,omitempty"`
|
||||
|
||||
// Silent suppresses sending any message to the user.
|
||||
// When true, ForUser is ignored even if set.
|
||||
Silent bool `json:"silent"`
|
||||
|
||||
// IsError indicates whether the tool execution failed.
|
||||
// When true, the result should be treated as an error.
|
||||
IsError bool `json:"is_error"`
|
||||
|
||||
// Async indicates whether the tool is running asynchronously.
|
||||
// When true, the tool will complete later and notify via callback.
|
||||
Async bool `json:"async"`
|
||||
|
||||
// Err is the underlying error (not JSON serialized).
|
||||
// Used for internal error handling and logging.
|
||||
Err error `json:"-"`
|
||||
}
|
||||
|
||||
// NewToolResult creates a basic ToolResult with content for the LLM.
|
||||
// Use this when you need a simple result with default behavior.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := NewToolResult("File updated successfully")
|
||||
func NewToolResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
}
|
||||
}
|
||||
|
||||
// SilentResult creates a ToolResult that is silent (no user message).
|
||||
// The content is only sent to the LLM for context.
|
||||
//
|
||||
// Use this for operations that should not spam the user, such as:
|
||||
// - File reads/writes
|
||||
// - Status updates
|
||||
// - Background operations
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := SilentResult("Config file saved")
|
||||
func SilentResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// AsyncResult creates a ToolResult for async operations.
|
||||
// The task will run in the background and complete later.
|
||||
//
|
||||
// Use this for long-running operations like:
|
||||
// - Subagent spawns
|
||||
// - Background processing
|
||||
// - External API calls with callbacks
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := AsyncResult("Subagent spawned, will report back")
|
||||
func AsyncResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResult creates a ToolResult representing an error.
|
||||
// Sets IsError=true and includes the error message.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Failed to connect to database: connection refused")
|
||||
func ErrorResult(message string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: message,
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// UserResult creates a ToolResult with content for both LLM and user.
|
||||
// Both ForLLM and ForUser are set to the same content.
|
||||
//
|
||||
// Use this when the user needs to see the result directly:
|
||||
// - Command execution output
|
||||
// - Fetched web content
|
||||
// - Query results
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := UserResult("Total files found: 42")
|
||||
func UserResult(content string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: content,
|
||||
ForUser: content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON serialization.
|
||||
// The Err field is excluded from JSON output via the json:"-" tag.
|
||||
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
|
||||
type Alias ToolResult
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(tr),
|
||||
})
|
||||
}
|
||||
|
||||
// WithError sets the Err field and returns the result for chaining.
|
||||
// This preserves the error for logging while keeping it out of JSON.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Operation failed").WithError(err)
|
||||
func (tr *ToolResult) WithError(err error) *ToolResult {
|
||||
tr.Err = err
|
||||
return tr
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewToolResult(t *testing.T) {
|
||||
result := NewToolResult("test content")
|
||||
|
||||
if result.ForLLM != "test content" {
|
||||
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSilentResult(t *testing.T) {
|
||||
result := SilentResult("silent operation")
|
||||
|
||||
if result.ForLLM != "silent operation" {
|
||||
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
|
||||
}
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent to be true")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsyncResult(t *testing.T) {
|
||||
result := AsyncResult("async task started")
|
||||
|
||||
if result.ForLLM != "async task started" {
|
||||
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("Expected Async to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorResult(t *testing.T) {
|
||||
result := ErrorResult("operation failed")
|
||||
|
||||
if result.ForLLM != "operation failed" {
|
||||
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserResult(t *testing.T) {
|
||||
content := "user visible message"
|
||||
result := UserResult(content)
|
||||
|
||||
if result.ForLLM != content {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
|
||||
}
|
||||
if result.ForUser != content {
|
||||
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *ToolResult
|
||||
}{
|
||||
{
|
||||
name: "basic result",
|
||||
result: NewToolResult("basic content"),
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: SilentResult("silent content"),
|
||||
},
|
||||
{
|
||||
name: "async result",
|
||||
result: AsyncResult("async content"),
|
||||
},
|
||||
{
|
||||
name: "error result",
|
||||
result: ErrorResult("error content"),
|
||||
},
|
||||
{
|
||||
name: "user result",
|
||||
result: UserResult("user content"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(tt.result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded ToolResult
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields match (Err should be excluded)
|
||||
if decoded.ForLLM != tt.result.ForLLM {
|
||||
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
|
||||
}
|
||||
if decoded.ForUser != tt.result.ForUser {
|
||||
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
|
||||
}
|
||||
if decoded.Silent != tt.result.Silent {
|
||||
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
|
||||
}
|
||||
if decoded.IsError != tt.result.IsError {
|
||||
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
|
||||
}
|
||||
if decoded.Async != tt.result.Async {
|
||||
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultWithErrors(t *testing.T) {
|
||||
err := errors.New("underlying error")
|
||||
result := ErrorResult("error message").WithError(err)
|
||||
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err.Error() != "underlying error" {
|
||||
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
|
||||
}
|
||||
|
||||
// Verify Err is not serialized
|
||||
data, marshalErr := json.Marshal(result)
|
||||
if marshalErr != nil {
|
||||
t.Fatalf("Failed to marshal: %v", marshalErr)
|
||||
}
|
||||
|
||||
var decoded ToolResult
|
||||
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
if decoded.Err != nil {
|
||||
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONStructure(t *testing.T) {
|
||||
result := UserResult("test content")
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("Failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check expected keys exist
|
||||
if _, ok := parsed["for_llm"]; !ok {
|
||||
t.Error("Expected 'for_llm' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["for_user"]; !ok {
|
||||
t.Error("Expected 'for_user' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["silent"]; !ok {
|
||||
t.Error("Expected 'silent' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["is_error"]; !ok {
|
||||
t.Error("Expected 'is_error' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["async"]; !ok {
|
||||
t.Error("Expected 'async' key in JSON")
|
||||
}
|
||||
|
||||
// Check that 'err' is NOT present (it should have json:"-" tag)
|
||||
if _, ok := parsed["err"]; ok {
|
||||
t.Error("Expected 'err' key to be excluded from JSON")
|
||||
}
|
||||
|
||||
// Verify values
|
||||
if parsed["for_llm"] != "test content" {
|
||||
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
|
||||
}
|
||||
if parsed["silent"] != false {
|
||||
t.Errorf("Expected silent false, got %v", parsed["silent"])
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user