Merge remote-tracking branch 'origin/main' into feat/cron-exec-timeout-config

This commit is contained in:
yinwm
2026-02-17 20:53:31 +08:00
52 changed files with 3459 additions and 241 deletions
+28
View File
@@ -0,0 +1,28 @@
---
name: Bug report
about: Report a bug or unexpected behavior
title: "[BUG]"
labels: bug
assignees: ''
---
## Quick Summary
## Environment & Tools
- **PicoClaw Version:** (e.g., v0.1.2 or commit hash)
- **Go Version:** (e.g., go 1.22)
- **AI Model & Provider:** (e.g., GPT-4o via OpenAI / DeepSeek via SiliconFlow)
- **Operating System:** (e.g., Ubuntu 22.04 / macOS / Android Termux)
- **Channels:** (e.g., Discord, Telegram, Feishu, ...)
## 📸 Steps to Reproduce
1.
2.
3.
## ❌ Actual Behavior
## ✅ Expected Behavior
## 💬 Additional Context
+23
View File
@@ -0,0 +1,23 @@
---
name: Feature request
about: Suggest a new idea or improvement
title: "[Feature]"
labels: enhancement
assignees: ''
---
## 🎯 The Goal / Use Case
## 💡 Proposed Solution
## 🛠 Potential Implementation (Optional)
## 🚦 Impact & Roadmap Alignment
- [ ] This is a Core Feature
- [ ] This is a Nice-to-Have / Enhancement
- [ ] This aligns with the current Roadmap
## 🔄 Alternatives Considered
## 💬 Additional Context
@@ -0,0 +1,26 @@
---
name: General Task / Todo
about: A specific piece of work like doc, refactoring, or maintenance.
title: "[Task]"
labels: ''
assignees: ''
---
## 📝 Objective
## 📋 To-Do List
- [ ] Step 1
- [ ] Step 2
- [ ] Step 3
## 🎯 Definition of Done (Acceptance Criteria)
- [ ] Documentation is updated in the README/docs folder.
- [ ] Code follows project linting standards.
- [ ] (If applicable) Basic tests pass.
## 💡 Context / Motivation
## 🔗 Related Issues / PRs
- Fixes #
- Relates to #
+37
View File
@@ -0,0 +1,37 @@
## 📝 Description
## 🗣️ Type of Change
- [ ] 🐞 Bug fix (non-breaking change which fixes an issue)
- [ ] ✨ New feature (non-breaking change which adds functionality)
- [ ] 📖 Documentation update
- [ ] ⚡ Code refactoring (no functional changes, no api changes)
## 🤖 AI Code Generation
- [ ] 🤖 Fully AI-generated (100% AI, 0% Human)
- [ ] 🛠️ Mostly AI-generated (AI draft, Human verified/modified)
- [ ] 👨‍💻 Mostly Human-written (Human lead, AI assisted or none)
## 🔗 Linked Issue
## 📚 Technical Context (Skip for Docs)
* **Reference:** [URL]
* **Reasoning:** ...
## 🧪 Test Environment & Hardware
- **Hardware:** [e.g. Raspberry Pi 5, Orange Pi, PC]
- **OS:** [e.g. Debian 12, Ubuntu 22.04]
- **Model/Provider:** [e.g. OpenAI GPT-4o, Kimi k2, DeepSeek-V3]
- **Channels:** [e.g. Discord, Telegram, Feishu, ...]
## 📸 Proof of Work (Optional for Docs)
<details>
<summary>Click to view Logs/Screenshots</summary>
</details>
## ☑️ Checklist
- [ ] My code/docs follow the style of this project.
- [ ] I have performed a self-review of my own changes.
- [ ] I have updated the documentation accordingly.
+37 -23
View File
@@ -1,12 +1,18 @@
name: 🐳 Build & Push Docker Image
on:
release:
types: [published]
workflow_call:
inputs:
tag:
description: "Release tag"
required: true
type: string
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
GHCR_REGISTRY: ghcr.io
GHCR_IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
DOCKERHUB_REGISTRY: docker.io
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
jobs:
build:
@@ -20,6 +26,8 @@ jobs:
# ── Checkout ──────────────────────────────
- name: 📥 Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.tag }}
# ── Docker Buildx ─────────────────────────
- name: 🔧 Set up Docker Buildx
@@ -27,36 +35,42 @@ jobs:
# ── Login to GHCR ─────────────────────────
- name: 🔑 Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
registry: ${{ env.GHCR_REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# ── Metadata (tags & labels) ──────────────
- name: 🏷️ Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
# ── Login to Docker Hub ────────────────────
- name: 🔑 Login to Docker Hub
uses: docker/login-action@v3
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}}
registry: ${{ env.DOCKERHUB_REGISTRY }}
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# ── Metadata (tags & labels) ──────────────
- name: 🏷️ Prepare image tags
id: tags
shell: bash
run: |
tag="${{ inputs.tag }}"
echo "ghcr_tag=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT"
echo "ghcr_latest=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT"
echo "dockerhub_tag=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT"
echo "dockerhub_latest=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT"
# ── Build & Push ──────────────────────────
- name: 🚀 Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
push: true
tags: |
${{ steps.tags.outputs.ghcr_tag }}
${{ steps.tags.outputs.ghcr_latest }}
${{ steps.tags.outputs.dockerhub_tag }}
${{ steps.tags.outputs.dockerhub_latest }}
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64,linux/arm64
platforms: linux/amd64,linux/arm64,linux/riscv64
+40 -39
View File
@@ -32,20 +32,26 @@ jobs:
- name: Create and push tag
shell: bash
env:
RELEASE_TAG: ${{ inputs.tag }}
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}"
git push origin "${{ inputs.tag }}"
git tag -a "$RELEASE_TAG" -m "Release $RELEASE_TAG"
git push origin "$RELEASE_TAG"
build-binaries:
name: Build Release Binaries
release:
name: GoReleaser Release
needs: create-tag
runs-on: ubuntu-latest
permissions:
contents: write
packages: write
steps:
- name: Checkout tag
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ inputs.tag }}
- name: Setup Go from go.mod
@@ -53,47 +59,42 @@ jobs:
with:
go-version-file: go.mod
- name: Build all binaries
run: make build-all
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Generate checksums
shell: bash
run: |
shasum -a 256 build/picoclaw-* > build/sha256sums.txt
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Upload release binaries artifact
uses: actions/upload-artifact@v4
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
name: picoclaw-binaries
path: |
build/picoclaw-*
build/sha256sums.txt
if-no-files-found: error
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
create-release:
name: Create GitHub Release
needs: [create-tag, build-binaries]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download all artifacts
uses: actions/download-artifact@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
path: release-artifacts
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Show downloaded files
run: ls -R release-artifacts
- name: Create release
uses: softprops/action-gh-release@v2
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
tag_name: ${{ inputs.tag }}
name: ${{ inputs.tag }}
draft: ${{ inputs.draft }}
prerelease: ${{ inputs.prerelease }}
files: |
release-artifacts/**/*
generate_release_notes: true
distribution: goreleaser
version: ~> v2
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
- name: Apply release flags
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
gh release edit "${{ inputs.tag }}" \
--draft=${{ inputs.draft }} \
--prerelease=${{ inputs.prerelease }}
+23 -5
View File
@@ -5,9 +5,11 @@ version: 2
before:
hooks:
- go mod tidy
- go generate ./cmd/picoclaw
builds:
- env:
- id: picoclaw
env:
- CGO_ENABLED=0
goos:
- linux
@@ -26,6 +28,22 @@ builds:
- goos: windows
goarch: arm
dockers_v2:
- id: picoclaw
dockerfile: Dockerfile.goreleaser
ids:
- picoclaw
images:
- "ghcr.io/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw"
- "docker.io/{{ .Env.DOCKERHUB_IMAGE_NAME }}"
tags:
- "{{ .Tag }}"
- "latest"
platforms:
- linux/amd64
- linux/arm64
- linux/riscv64
archives:
- formats: [tar.gz]
# this name template makes the OS and Arch compatible with the results of `uname`.
@@ -48,10 +66,10 @@ changelog:
- "^docs:"
- "^test:"
upx:
- enabled: true
compress: best
lzma: true
# upx:
# - enabled: true
# compress: best
# lzma: true
release:
footer: >-
+4
View File
@@ -22,6 +22,10 @@ FROM alpine:3.23
RUN apk add --no-cache ca-certificates tzdata curl
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget -q --spider http://localhost:18790/health || exit 1
# Copy binary
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
+10
View File
@@ -0,0 +1,10 @@
FROM alpine:3.21
ARG TARGETPLATFORM
RUN apk add --no-cache ca-certificates tzdata
COPY $TARGETPLATFORM/picoclaw /usr/local/bin/picoclaw
ENTRYPOINT ["picoclaw"]
CMD ["gateway"]
+10 -2
View File
@@ -119,7 +119,7 @@ clean:
@rm -rf $(BUILD_DIR)
@echo "Clean complete"
## fmt: Format Go code
## vet: Run go vet for static analysis
vet:
@$(GO) vet ./...
@@ -131,11 +131,19 @@ test:
fmt:
@$(GO) fmt ./...
## deps: Update dependencies
## deps: Download dependencies
deps:
@$(GO) mod download
@$(GO) mod verify
## update-deps: Update dependencies
update-deps:
@$(GO) get -u ./...
@$(GO) mod tidy
## check: Run vet, fmt, and verify dependencies
check: deps fmt vet test
## run: Build and run picoclaw
run: build
@$(BUILD_DIR)/$(BINARY_NAME) $(ARGS)
+18 -1
View File
@@ -44,9 +44,12 @@
> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**.
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
>
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (1020MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state.
## 📢 News
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](docs/picoclaw_community_roadmap_260216.md) —we cant wait to have you on board!
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
@@ -96,6 +99,20 @@
</tr>
</table>
### 📱 Run on old Android Phones
Give your decade-old phone a second life! Turn it into a smart AI Assistant with PicoClaw. Quick Start:
1. **Install Termux** (Available on F-Droid or Google Play).
2. **Execute cmds**
```bash
# Note: Replace v0.1.1 with the latest version from the Releases page
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
And then follow the instructions in the "Quick Start" section to complete the configuration!
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
### 🐜 Innovative Low-Footprint Deploy
PicoClaw can be deployed on almost any Linux device!
+21 -2
View File
@@ -45,10 +45,12 @@
> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
>
>
> * **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
> * **注意:** picoclaw最近合并了大量PRs,近期版本可能内存占用较大(10~20MB),我们将在功能较为收敛后进行资源占用优化.
## 📢 新闻 (News)
2026-02-16 🎉 PicoClaw 在一周内突破了12K star! 感谢大家的关注!PicoClaw 的成长速度超乎我们预期. 由于PR数量的快速膨胀,我们亟需社区开发者参与维护. 我们需要的志愿者角色和roadmap已经发布到了[这里](docs/picoclaw_community_roadmap_260216.md), 期待你的参与!
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
@@ -98,6 +100,23 @@
</tr>
</table>
### 📱 在手机上轻松运行
picoclaw 可以将你10年前的老旧手机废物利用,变身成为你的AI助理!快速指南:
1. 先去应用商店下载安装Termux
2. 打开后执行指令
```bash
# 注意: 下面的v0.1.1 可以换为你实际看到的最新版本
wget https://github.com/sipeed/picoclaw/releases/download/v0.1.1/picoclaw-linux-arm64
chmod +x picoclaw-linux-arm64
pkg install proot
termux-chroot ./picoclaw-linux-arm64 onboard
```
然后跟随下面的“快速开始”章节继续配置picoclaw即可使用!
<img src="assets/termux.jpg" alt="PicoClaw" width="512">
### 🐜 创新的低占用部署
PicoClaw 几乎可以部署在任何 Linux 设备上!
+116
View File
@@ -0,0 +1,116 @@
# 🦐 PicoClaw Roadmap
> **Vision**: To build the ultimate lightweight, secure, and fully autonomous AI Agent infrastructure.automate the mundane, unleash your creativity
---
## 🚀 1. Core Optimization: Extreme Lightweight
*Our defining characteristic. We fight software bloat to ensure PicoClaw runs smoothly on the smallest embedded devices.*
* [**Memory Footprint Reduction**](https://github.com/sipeed/picoclaw/issues/346)
* **Goal**: Run smoothly on 64MB RAM embedded boards (e.g., low-end RISC-V SBCs) with the core process consuming < 20MB.
* **Context**: RAM is expensive and scarce on edge devices. Memory optimization takes precedence over storage size.
* **Action**: Analyze memory growth between releases, remove redundant dependencies, and optimize data structures.
## 🛡️ 2. Security Hardening: Defense in Depth
*Paying off early technical debt. We invite security experts to help build a "Secure-by-Default" agent.*
* **Input Defense & Permission Control**
* **Prompt Injection Defense**: Harden JSON extraction logic to prevent LLM manipulation.
* **Tool Abuse Prevention**: Strict parameter validation to ensure generated commands stay within safe boundaries.
* **SSRF Protection**: Built-in blocklists for network tools to prevent accessing internal IPs (LAN/Metadata services).
* **Sandboxing & Isolation**
* **Filesystem Sandbox**: Restrict file R/W operations to specific directories only.
* **Context Isolation**: Prevent data leakage between different user sessions or channels.
* **Privacy Redaction**: Auto-redact sensitive info (API Keys, PII) from logs and standard outputs.
* **Authentication & Secrets**
* **Crypto Upgrade**: Adopt modern algorithms like `ChaCha20-Poly1305` for secret storage.
* **OAuth 2.0 Flow**: Deprecate hardcoded API keys in the CLI; move to secure OAuth flows.
## 🔌 3. Connectivity: Protocol-First Architecture
*Connect every model, reach every platform.*
* **Provider**
* [**Architecture Upgrade**](https://github.com/sipeed/picoclaw/issues/283): Refactor from "Vendor-based" to "Protocol-based" classification (e.g., OpenAI-compatible, Ollama-compatible). *(Status: In progress by @Daming, ETA 5 days)*
* **Local Models**: Deep integration with **Ollama**, **vLLM**, **LM Studio**, and **Mistral** (local inference).
* **Online Models**: Continued support for frontier closed-source models.
* **Channel**
* **IM Matrix**: QQ, WeChat (Work), DingTalk, Feishu (Lark), Telegram, Discord, WhatsApp, LINE, Slack, Email, KOOK, Signal, ...
* **Standards**: Support for the **OneBot** protocol.
* [**attachment**](https://github.com/sipeed/picoclaw/issues/348): Native handling of images, audio, and video attachments.
* **Skill Marketplace**
* [**Discovery skills**](https://github.com/sipeed/picoclaw/issues/287): Implement `find_skill` to automatically discover and install skills from the [GitHub Skills Repo] or other registries.
## 🧠 4. Advanced Capabilities: From Chatbot to Agentic AI
*Beyond conversation—focusing on action and collaboration.*
* **Operations**
* [**MCP Support**](https://github.com/sipeed/picoclaw/issues/290): Native support for the **Model Context Protocol (MCP)**.
* [**Browser Automation**](https://github.com/sipeed/picoclaw/issues/293): Headless browser control via CDP (Chrome DevTools Protocol) or ActionBook.
* [**Mobile Operation**](https://github.com/sipeed/picoclaw/issues/292): Android device control (similar to BotDrop).
* **Multi-Agent Collaboration**
* [**Basic Multi-Agent**](https://github.com/sipeed/picoclaw/issues/294) implement
* [**Model Routing**](https://github.com/sipeed/picoclaw/issues/295): "Smart Routing" — dispatch simple tasks to small/local models (fast/cheap) and complex tasks to SOTA models (smart).
* [**Swarm Mode**](https://github.com/sipeed/picoclaw/issues/284): Collaboration between multiple PicoClaw instances on the same network.
* [**AIEOS**](https://github.com/sipeed/picoclaw/issues/296): Exploring AI-Native Operating System interaction paradigms.
## 📚 5. Developer Experience (DevEx) & Documentation
*Lowering the barrier to entry so anyone can deploy in minutes.*
* [**QuickGuide (Zero-Config Start)**](https://github.com/sipeed/picoclaw/issues/350)
* Interactive CLI Wizard: If launched without config, automatically detect the environment and guide the user through Token/Network setup step-by-step.
* **Comprehensive Documentation**
* **Platform Guides**: Dedicated guides for Windows, macOS, Linux, and Android.
* **Step-by-Step Tutorials**: "Babysitter-level" guides for configuring Providers and Channels.
* **AI-Assisted Docs**: Using AI to auto-generate API references and code comments (with human verification to prevent hallucinations).
## 🤖 6. Engineering: AI-Powered Open Source
*Born from Vibe Coding, we continue to use AI to accelerate development.*
* **AI-Enhanced CI/CD**
* Integrate AI for automated Code Review, Linting, and PR Labeling.
* **Bot Noise Reduction**: Optimize bot interactions to keep PR timelines clean.
* **Issue Triage**: AI agents to analyze incoming issues and suggest preliminary fixes.
## 🎨 7. Brand & Community
* [**Logo Design**](https://github.com/sipeed/picoclaw/issues/297): We are looking for a **Mantis Shrimp (Stomatopoda)** logo design!
* *Concept*: Needs to reflect "Small but Mighty" and "Lightning Fast Strikes."
---
### 🤝 Call for Contributions
We welcome community contributions to any item on this roadmap! Please comment on the relevant Issue or submit a PR. Let's build the best Edge AI Agent together!
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 145 KiB

After

Width:  |  Height:  |  Size: 142 KiB

+18 -3
View File
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"io/fs"
"net/http"
"os"
"os/signal"
"path/filepath"
@@ -28,6 +29,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/migrate"
@@ -560,7 +562,8 @@ func gatewayCmd() {
})
// Setup cron tool and service
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes)*time.Minute)
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace, execTimeout)
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -592,6 +595,9 @@ func gatewayCmd() {
os.Exit(1)
}
// Inject channel manager into agent loop for command handling
agentLoop.SetChannelManager(channelManager)
var transcriber *voice.GroqTranscriber
if cfg.Providers.Groq.APIKey != "" {
transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey)
@@ -658,6 +664,14 @@ func gatewayCmd() {
fmt.Printf("Error starting channels: %v\n", err)
}
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
go func() {
if err := healthServer.Start(); err != nil && err != http.ErrServerClosed {
logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()})
}
}()
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
go agentLoop.Run(ctx)
sigChan := make(chan os.Signal, 1)
@@ -666,6 +680,7 @@ func gatewayCmd() {
fmt.Println("\nShutting down...")
cancel()
healthServer.Stop(context.Background())
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
@@ -973,14 +988,14 @@ func getConfigPath() string {
return filepath.Join(home, ".picoclaw", "config.json")
}
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, execTimeout time.Duration) *cron.CronService {
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *cron.CronService {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, execTimeout)
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout)
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
+4
View File
@@ -107,6 +107,10 @@
"moonshot": {
"api_key": "sk-xxx",
"api_base": ""
},
"ollama": {
"api_key": "",
"api_base": "http://localhost:11434/v1"
}
},
"tools": {
+112
View File
@@ -0,0 +1,112 @@
## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal
**Hello, PicoClaw Community!**
First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, weve seen a non-stop stream of high-quality PRs!
PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum.
This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**.
### 🎁 Community Perks
To show our appreciation, developers who officially join our community operations will receive:
* **Exclusive AI Hardware:** Our upcoming, unreleased AI device.
* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers).
### 🎥 Calling All Content Creators!
Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**.
* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**.
* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM.
We will be rewarding high-quality content creators with the same perks as our community developers!
---
## 🛠️ Urgent Volunteer Roles
We are looking for experts in the following areas:
1. **Issue/PR Reviewers**
* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging.
* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles.
2. **Resource Optimization Experts**
* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean.
* **Focus:** Analyzing resource growth between releases and trimming redundancy.
* **Priority:** **RAM usage optimization** > Binary size reduction.
3. **Security Audit & Bug Fixes**
* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management.
* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes.
4. **Documentation & DX (Developer Experience)**
* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow.
* **Focus:** Creating clear, user-friendly documentation for both setup and development.
5. **AI-Powered CI/CD Optimization**
* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it.
* **Focus:** Automating builds with AI and exploring AI-driven issue resolution.
**How to Apply:** > If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role.
Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors!
---
## 📍 The Roadmap
Interested in a specific feature? You can "claim" these tasks and start building:
###
* **Provider:**
* **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days)
* You can still submit code; Daming will merge it into the new implementation.
* **Channels:**
* Support for OneBot, additional platforms
* attachments (images, audio, video, files).
* **Skills:**
* Implementing `find_skill` to discover tools via [openclaw/skills](https://github.com/openclaw/skills) and other platforms.
* **Operations:** * MCP Support.
* Android operations (e.g., botdrop).
* Browser automation via CDP or ActionBook.
* **Multi-Agent Ecosystem:**
* **Basic Model-Agnet** S
* **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens).
* **Swarm Mode.**
* **AIEOS Integration.**
* **Branding:**
* **Logo**: We need a cute logo! Were leaning toward a **Mantis Shrimp**—small, but packs a legendary punch!
We have officially created these tasks as GitHub Issues, all marked with the roadmap tag.
This list will be updated continuously as we progress.
If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue!
---
## 🤝 How to Join
**Everything is open to your creativity!** If you have a wild idea, just PR it.
1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw.
2. **The Application Track:** If you havent submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject:
> `[Apply Join PicoClaw Dev Group] + Your GitHub Account`
> Include the role you're interested in and any evidence of your development experience.
### Looking Ahead
Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential.
**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎
+7 -3
View File
@@ -15,11 +15,16 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
github.com/slack-go/slack v0.17.3
github.com/stretchr/testify v1.11.1
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
@@ -28,9 +33,9 @@ require (
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/github/copilot-sdk/go v0.1.23
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
@@ -47,5 +52,4 @@ require (
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
)
+6
View File
@@ -58,6 +58,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -78,9 +80,11 @@ github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzh
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
@@ -102,6 +106,7 @@ github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsK
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
@@ -242,6 +247,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
+298 -16
View File
@@ -19,6 +19,7 @@ import (
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -42,6 +43,7 @@ type AgentLoop struct {
tools *tools.ToolRegistry
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
channelManager *channels.Manager
}
// processOptions configures how a message is processed
@@ -199,6 +201,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
}
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
@@ -263,6 +269,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.processSystemMessage(ctx, msg)
}
// Check for commands
if response, handled := al.handleCommand(ctx, msg); handled {
return response, nil
}
// Process as user message
return al.runAgentLoop(ctx, processOptions{
SessionKey: msg.SessionKey,
@@ -383,7 +394,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
// 7. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(opts.SessionKey)
al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID)
}
// 8. Optional: send response via bus
@@ -445,11 +456,131 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
var response *providers.LLMResponse
var err error
// Retry loop for context/token errors
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
if err == nil {
break // Success
}
errMsg := strings.ToLower(err.Error())
// Check for context window errors (provider specific, but usually contain "token" or "invalid")
isContextError := strings.Contains(errMsg, "token") ||
strings.Contains(errMsg, "context") ||
strings.Contains(errMsg, "invalidparameter") ||
strings.Contains(errMsg, "length")
if isContextError && retry < maxRetries {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
"error": err.Error(),
"retry": retry,
})
// Notify user on first retry only
if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: "⚠️ Context window exceeded. Compressing history and retrying...",
})
}
// Force compression
al.forceCompression(opts.SessionKey)
// Rebuild messages with compressed history
// Note: We need to reload history from session manager because forceCompression changed it
newHistory := al.sessions.GetHistory(opts.SessionKey)
newSummary := al.sessions.GetSummary(opts.SessionKey)
// Re-create messages for the next attempt
// We keep the current user message (opts.UserMessage) effectively
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
opts.UserMessage,
nil,
opts.Channel,
opts.ChatID,
)
// Important: If we are in the middle of a tool loop (iteration > 1),
// rebuilding messages from session history might duplicate the flow or miss context
// if intermediate steps weren't saved correctly.
// However, al.sessions.AddFullMessage is called after every tool execution,
// so GetHistory should reflect the current state including partial tool execution.
// But we need to ensure we don't duplicate the user message which is appended in BuildMessages.
// BuildMessages(history...) takes the stored history and appends the *current* user message.
// If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop.
// So if we pass opts.UserMessage again, we might duplicate it?
// Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// So GetHistory ALREADY contains the user message!
// CORRECTION:
// BuildMessages combines: [System] + [History] + [CurrentMessage]
// But Step 3 added CurrentMessage to History.
// So if we use GetHistory now, it has the user message.
// If we pass opts.UserMessage to BuildMessages, it adds it AGAIN.
// For retry in the middle of a loop, we should rely on what's in the session.
// BUT checking BuildMessages implementation:
// It appends history... then appends currentMessage.
// Logic fix for retry:
// If iteration == 1, opts.UserMessage corresponds to the user input.
// If iteration > 1, we are processing tool results. The "messages" passed to Chat
// already accumulated tool outputs.
// Rebuilding from session history is safest because it persists state.
// Start fresh with rebuilt history.
// Special case: standard BuildMessages appends "currentMessage".
// If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed.
// However, the "messages" argument passed to runLLMIteration is constructed by the caller.
// If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history.
// In runAgentLoop:
// 3. sessions.AddMessage(userMsg)
// 4. runLLMIteration(..., UserMessage)
// So History contains the user message.
// BuildMessages typically appends the user message as a *new* pending message.
// Wait, standard BuildMessages usage in runAgentLoop:
// messages := BuildMessages(history (has old), UserMessage)
// THEN AddMessage(UserMessage).
// So "history" passed to BuildMessages does NOT contain the current UserMessage yet.
// But here, inside the loop, we have already saved it.
// So GetHistory() includes the current user message.
// If we call BuildMessages(GetHistory(), UserMessage), we get duplicates.
// Hack/Fix:
// If we are retrying, we rebuild from Session History ONLY.
// We pass empty string as "currentMessage" to BuildMessages
// because the "current message" is already saved in history (step 3).
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
"", // Empty because history already contains the relevant messages
nil,
opts.Channel,
opts.ChatID,
)
continue
}
// Real error or success, break loop
break
}
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
@@ -457,7 +588,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration,
"error": err.Error(),
})
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
// Check if no tool calls - we're done
@@ -589,7 +720,7 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) {
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(sessionKey string) {
func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
newHistory := al.sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := al.contextWindow * 75 / 100
@@ -598,12 +729,80 @@ func (al *AgentLoop) maybeSummarize(sessionKey string) {
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
go func() {
defer al.summarizing.Delete(sessionKey)
// Notify user about optimization if not an internal channel
if !constants.IsInternalChannel(channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
})
}
al.summarizeSession(sessionKey)
}()
}
}
}
// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest 50% of messages (keeping system prompt and last user message).
func (al *AgentLoop) forceCompression(sessionKey string) {
history := al.sessions.GetHistory(sessionKey)
if len(history) <= 4 {
return
}
// Keep system prompt (usually [0]) and the very last message (user's trigger)
// We want to drop the oldest half of the *conversation*
// Assuming [0] is system, [1:] is conversation
conversation := history[1 : len(history)-1]
if len(conversation) == 0 {
return
}
// Helper to find the mid-point of the conversation
mid := len(conversation) / 2
// New history structure:
// 1. System Prompt
// 2. [Summary of dropped part] - synthesized
// 3. Second half of conversation
// 4. Last message
// Simplified approach for emergency: Drop first half of conversation
// and rely on existing summary if present, or create a placeholder.
droppedCount := mid
keptConversation := conversation[mid:]
newHistory := make([]providers.Message, 0)
newHistory = append(newHistory, history[0]) // System prompt
// Add a note about compression
compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
// If there was an existing summary, we might lose it if it was in the dropped part (which is just messages).
// The summary is stored separately in session.Summary, so it persists!
// We just need to ensure the user knows there's a gap.
// We only modify the messages list here
newHistory = append(newHistory, providers.Message{
Role: "system",
Content: compressionNote,
})
newHistory = append(newHistory, keptConversation...)
newHistory = append(newHistory, history[len(history)-1]) // Last message
// Update session
al.sessions.SetHistory(sessionKey, newHistory)
al.sessions.Save(sessionKey)
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
"session_key": sessionKey,
"dropped_msgs": droppedCount,
"new_count": len(newHistory),
})
}
// GetStartupInfo returns information about loaded tools and skills for logging.
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
info := make(map[string]interface{})
@@ -631,7 +830,7 @@ func formatMessagesForLog(messages []providers.Message) string {
result += "[\n"
for i, msg := range messages {
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
if len(msg.ToolCalls) > 0 {
result += " ToolCalls:\n"
for _, tc := range msg.ToolCalls {
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
@@ -698,7 +897,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
continue
}
// Estimate tokens for this message
msgTokens := len(m.Content) / 4
msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety)
if msgTokens > maxMessageTokens {
omitted = true
continue
@@ -769,13 +968,96 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
}
// estimateTokens estimates the number of tokens in a message list.
// Uses rune count instead of byte length so that CJK and other multi-byte
// characters are not over-counted (a Chinese character is 3 bytes but roughly
// one token).
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
// overheads better than the previous 3 chars/token.
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
total := 0
totalChars := 0
for _, m := range messages {
total += utf8.RuneCountInString(m.Content) / 3
totalChars += utf8.RuneCountInString(m.Content)
}
return total
// 2.5 chars per token = totalChars * 2 / 5
return totalChars * 2 / 5
}
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
content := strings.TrimSpace(msg.Content)
if !strings.HasPrefix(content, "/") {
return "", false
}
parts := strings.Fields(content)
if len(parts) == 0 {
return "", false
}
cmd := parts[0]
args := parts[1:]
switch cmd {
case "/show":
if len(args) < 1 {
return "Usage: /show [model|channel]", true
}
switch args[0] {
case "model":
return fmt.Sprintf("Current model: %s", al.model), true
case "channel":
return fmt.Sprintf("Current channel: %s", msg.Channel), true
default:
return fmt.Sprintf("Unknown show target: %s", args[0]), true
}
case "/list":
if len(args) < 1 {
return "Usage: /list [models|channels]", true
}
switch args[0] {
case "models":
// TODO: Fetch available models dynamically if possible
return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
case "channels":
if al.channelManager == nil {
return "Channel manager not initialized", true
}
channels := al.channelManager.GetEnabledChannels()
if len(channels) == 0 {
return "No channels enabled", true
}
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
default:
return fmt.Sprintf("Unknown list target: %s", args[0]), true
}
case "/switch":
if len(args) < 3 || args[1] != "to" {
return "Usage: /switch [model|channel] to <name>", true
}
target := args[0]
value := args[2]
switch target {
case "model":
oldModel := al.model
al.model = value
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
case "channel":
// This changes the 'default' channel for some operations, or effectively redirects output?
// For now, let's just validate if the channel exists
if al.channelManager == nil {
return "Channel manager not initialized", true
}
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
}
// If message came from CLI, maybe we want to redirect CLI output to this channel?
// That would require state persistence about "redirected channel"
// For now, just acknowledged.
return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
default:
return fmt.Sprintf("Unknown switch target: %s", target), true
}
}
return "", false
}
+97
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
@@ -527,3 +528,99 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}
// failFirstMockProvider fails on the first N calls with a specific error
type failFirstMockProvider struct {
failures int
currentCall int
failError error
successResp string
}
func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
m.currentCall++
if m.currentCall <= m.failures {
return nil, m.failError
}
return &providers.LLMResponse{
Content: m.successResp,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *failFirstMockProvider) GetDefaultModel() string {
return "mock-fail-model"
}
// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
// Create a provider that fails once with a context error
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
successResp: "Recovered from context error",
}
al := NewAgentLoop(cfg, msgBus, provider)
// Inject some history to simulate a full context
sessionKey := "test-session-context"
// Create dummy history
history := []providers.Message{
{Role: "system", Content: "System prompt"},
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
}
al.sessions.SetHistory(sessionKey, history)
// Call ProcessDirectWithChannel
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
if err != nil {
t.Fatalf("Expected success after retry, got error: %v", err)
}
if response != "Recovered from context error" {
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
}
// We expect 2 calls: 1st failed, 2nd succeeded
if provider.currentCall != 2 {
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
}
// Check final history length
finalHistory := al.sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
// We can assert that the length is NOT what it would be without compression.
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
if len(finalHistory) >= 8 {
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
}
}
+56 -14
View File
@@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
refreshed, err := parseTokenResponse(body, cred.Provider)
if err != nil {
return nil, err
}
if refreshed.RefreshToken == "" {
refreshed.RefreshToken = cred.RefreshToken
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
return refreshed, nil
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
@@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
"codex_cli_simplified_flow": {"true"},
"state": {state},
}
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
params.Set("originator", "picoclaw")
}
if cfg.Originator != "" {
params.Set("originator", cfg.Originator)
}
@@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
func extractAccountID(token string) string {
claims, err := parseJWTClaims(token)
if err != nil {
return ""
}
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
}
if orgs, ok := claims["organizations"].([]interface{}); ok {
for _, org := range orgs {
if orgMap, ok := org.(map[string]interface{}); ok {
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
return accountID
}
}
}
}
return ""
}
func parseJWTClaims(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("token is not a JWT")
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
@@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string {
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
return nil, err
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
return nil, err
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
return claims, nil
}
func base64URLDecode(s string) ([]byte, error) {
+97
View File
@@ -5,10 +5,23 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
t.Helper()
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
payloadJSON, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
return header + "." + payload + ".sig"
}
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
@@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) {
}
}
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
cfg := OpenAIOAuthConfig()
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
parsed, err := url.Parse(u)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
q := parsed.Query()
if q.Get("id_token_add_organizations") != "true" {
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
}
if q.Get("codex_cli_simplified_flow") != "true" {
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
}
if q.Get("originator") != "codex_cli_rs" {
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
@@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
}
}
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
resp := map[string]interface{}{
"access_token": "opaque-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": idToken,
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccountID != "acc-id-from-id-token" {
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
}
}
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
token := makeJWTForClaims(t, map[string]interface{}{
"organizations": []interface{}{
map[string]interface{}{"id": "org_from_orgs"},
},
})
if got := extractAccountID(token); got != "org_from_orgs" {
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
@@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
}
}
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"access_token": "new-access-token-only",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
cred := &AuthCredential{
AccessToken: "old-access",
RefreshToken: "existing-refresh",
AccountID: "acc_existing",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.RefreshToken != "existing-refresh" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
}
if refreshed.AccountID != "acc_existing" {
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
+17
View File
@@ -9,6 +9,7 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
handlers map[string]MessageHandler
closed bool
mu sync.RWMutex
}
@@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus {
}
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
}
mb.inbound <- msg
}
@@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool)
}
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
}
mb.outbound <- msg
}
@@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
}
func (mb *MessageBus) Close() {
mb.mu.Lock()
defer mb.mu.Unlock()
if mb.closed {
return
}
mb.closed = true
close(mb.inbound)
close(mb.outbound)
}
+150 -2
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/bwmarrin/discordgo"
@@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return fmt.Errorf("channel ID is empty")
}
message := msg.Content
runes := []rune(msg.Content)
if len(runes) == 0 {
return nil
}
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
for _, chunk := range chunks {
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
return err
}
}
return nil
}
// splitMessage splits long messages into chunks, preserving code block integrity
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
func splitMessage(content string, limit int) []string {
var messages []string
for len(content) > 0 {
if len(content) <= limit {
messages = append(messages, content)
break
}
msgEnd := limit
// Find natural split point within the limit
msgEnd = findLastNewline(content[:limit], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:limit], 100)
}
if msgEnd <= 0 {
msgEnd = limit
}
// Check if this would end with an incomplete code block
candidate := content[:msgEnd]
unclosedIdx := findLastUnclosedCodeBlock(candidate)
if unclosedIdx >= 0 {
// Message would end with incomplete code block
// Try to extend to include the closing ``` (with some buffer)
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
if len(content) > extendedLimit {
closingIdx := findNextClosingCodeBlock(content, msgEnd)
if closingIdx > 0 && closingIdx <= extendedLimit {
// Extend to include the closing ```
msgEnd = closingIdx
} else {
// Can't find closing, split before the code block
msgEnd = findLastNewline(content[:unclosedIdx], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:unclosedIdx], 100)
}
if msgEnd <= 0 {
msgEnd = unclosedIdx
}
}
} else {
// Remaining content fits within extended limit
msgEnd = len(content)
}
}
if msgEnd <= 0 {
msgEnd = limit
}
messages = append(messages, content[:msgEnd])
content = strings.TrimSpace(content[msgEnd:])
}
return messages
}
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
// Returns the position of the opening ``` or -1 if all code blocks are complete
func findLastUnclosedCodeBlock(text string) int {
count := 0
lastOpenIdx := -1
for i := 0; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
if count == 0 {
lastOpenIdx = i
}
count++
i += 2
}
}
// If odd number of ``` markers, last one is unclosed
if count%2 == 1 {
return lastOpenIdx
}
return -1
}
// findNextClosingCodeBlock finds the next closing ``` starting from a position
// Returns the position after the closing ``` or -1 if not found
func findNextClosingCodeBlock(text string, startIdx int) int {
for i := startIdx; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
return i + 3
}
}
return -1
}
// findLastNewline finds the last newline character within the last N characters
// Returns the position of the newline or -1 if not found
func findLastNewline(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == '\n' {
return i
}
}
return -1
}
// findLastSpace finds the last space character within the last N characters
// Returns the position of the space or -1 if not found
func findLastSpace(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == ' ' || s[i] == '\t' {
return i
}
}
return -1
}
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
_, err := c.session.ChannelMessageSend(channelID, content)
done <- err
}()
@@ -140,6 +282,12 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
if err := c.session.ChannelTyping(m.ChannelID); err != nil {
logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
"error": err.Error(),
})
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
+1 -1
View File
@@ -48,7 +48,7 @@ func (m *Manager) initChannels() error {
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
logger.DebugC("channels", "Attempting to initialize Telegram channel")
telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus)
telegram, err := NewTelegramChannel(m.config, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
"error": err.Error(),
+14
View File
@@ -296,6 +296,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
return
}
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
@@ -345,6 +352,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
c.socketClient.Ack(*event.Request)
}
if !c.IsAllowed(cmd.UserID) {
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
"user_id": cmd.UserID,
})
return
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
+58 -40
View File
@@ -11,7 +11,10 @@ import (
"sync"
"time"
th "github.com/mymmrac/telego/telegohandler"
"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -24,7 +27,8 @@ import (
type TelegramChannel struct {
*BaseChannel
bot *telego.Bot
config config.TelegramConfig
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
@@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() {
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
var opts []telego.BotOption
telegramCfg := cfg.Channels.Telegram
if cfg.Proxy != "" {
proxyURL, parseErr := url.Parse(cfg.Proxy)
if telegramCfg.Proxy != "" {
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
if parseErr != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
}
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
@@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr
}))
}
bot, err := telego.NewBot(cfg.Token, opts...)
bot, err := telego.NewBot(telegramCfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom)
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
return &TelegramChannel{
BaseChannel: base,
commands: NewTelegramCommands(bot, cfg),
bot: bot,
config: cfg,
chatIDs: make(map[string]int64),
@@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return fmt.Errorf("failed to start long polling: %w", err)
}
bh, err := telegohandler.NewBotHandler(c.bot, updates)
if err != nil {
return fmt.Errorf("failed to create bot handler: %w", err)
}
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
c.commands.Help(ctx, message)
return nil
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
}, th.CommandEqual("show"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.List(ctx, message)
}, th.CommandEqual("list"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.handleMessage(ctx, &message)
}, th.AnyMessage())
c.setRunning(true)
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go bh.Start()
go func() {
for {
select {
case <-ctx.Done():
return
case update, ok := <-updates:
if !ok {
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
c.handleMessage(ctx, update)
}
}
}
<-ctx.Done()
bh.Stop()
}()
return nil
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
@@ -166,30 +186,27 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return nil
}
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
message := update.Message
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
if message == nil {
return
return fmt.Errorf("message is nil")
}
user := message.From
if user == nil {
return
return fmt.Errorf("message sender (user) is nil")
}
userID := fmt.Sprintf("%d", user.ID)
senderID := userID
senderID := fmt.Sprintf("%d", user.ID)
if user.Username != "" {
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
"username": user.Username,
"user_id": senderID,
})
return
return nil
}
chatID := message.Chat.ID
@@ -222,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content += message.Caption
}
if message.Photo != nil && len(message.Photo) > 0 {
if len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
@@ -231,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: photo]")
content += "[image: photo]"
}
}
@@ -252,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
transcribedText = "[voice (transcription failed)]"
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
@@ -260,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
})
}
} else {
transcribedText = fmt.Sprintf("[voice]")
transcribedText = "[voice]"
}
if content != "" {
@@ -278,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio]")
content += "[audio]"
}
}
@@ -290,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file]")
content += "[file]"
}
}
@@ -338,7 +355,8 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
return nil
}
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
+153
View File
@@ -0,0 +1,153 @@
package channels
import (
"context"
"fmt"
"strings"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/config"
)
type TelegramCommander interface {
Help(ctx context.Context, message telego.Message) error
Start(ctx context.Context, message telego.Message) error
Show(ctx context.Context, message telego.Message) error
List(ctx context.Context, message telego.Message) error
}
type cmd struct {
bot *telego.Bot
config *config.Config
}
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
return &cmd{
bot: bot,
config: cfg,
}
}
func commandArgs(text string) string {
parts := strings.SplitN(text, " ", 2)
if len(parts) < 2 {
return ""
}
return strings.TrimSpace(parts[1])
}
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
msg := `/start - Start the bot
/help - Show this help message
/show [model|channel] - Show current configuration
/list [models|channels] - List available options
`
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: msg,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Hello! I am PicoClaw 🦞",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /show [model|channel]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "model":
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
c.config.Agents.Defaults.Model,
c.config.Agents.Defaults.Provider)
case "channel":
response = "Current Channel: telegram"
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) List(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /list [models|channels]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "models":
provider := c.config.Agents.Defaults.Provider
if provider == "" {
provider = "configured default"
}
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
c.config.Agents.Defaults.Model, provider)
case "channels":
var enabled []string
if c.config.Channels.Telegram.Enabled {
enabled = append(enabled, "telegram")
}
if c.config.Channels.WhatsApp.Enabled {
enabled = append(enabled, "whatsapp")
}
if c.config.Channels.Feishu.Enabled {
enabled = append(enabled, "feishu")
}
if c.config.Channels.Discord.Enabled {
enabled = append(enabled, "discord")
}
if c.config.Channels.Slack.Enabled {
enabled = append(enabled, "slack")
}
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
+2 -1
View File
@@ -175,6 +175,7 @@ type ProvidersConfig struct {
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
@@ -377,7 +378,7 @@ func SaveConfig(path string, cfg *Config) error {
return err
}
return os.WriteFile(path, data, 0644)
return os.WriteFile(path, data, 0600)
}
func (c *Config) WorkspacePath() string {
+27
View File
@@ -1,6 +1,9 @@
package config
import (
"os"
"path/filepath"
"runtime"
"testing"
)
@@ -147,6 +150,30 @@ func TestDefaultConfig_WebTools(t *testing.T) {
}
}
func TestSaveConfig_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
cfg := DefaultConfig()
if err := SaveConfig(path, cfg); err != nil {
t.Fatalf("SaveConfig failed: %v", err)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("config file has permission %04o, want 0600", perm)
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
+1 -1
View File
@@ -340,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error {
return err
}
return os.WriteFile(cs.storePath, data, 0644)
return os.WriteFile(cs.storePath, data, 0600)
}
func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
+38
View File
@@ -0,0 +1,38 @@
package cron
import (
"os"
"path/filepath"
"runtime"
"testing"
)
func TestSaveStore_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
}
tmpDir := t.TempDir()
storePath := filepath.Join(tmpDir, "cron", "jobs.json")
cs := NewCronService(storePath, nil)
_, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct")
if err != nil {
t.Fatalf("AddJob failed: %v", err)
}
info, err := os.Stat(storePath)
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("cron store has permission %04o, want 0600", perm)
}
}
func int64Ptr(v int64) *int64 {
return &v
}
+164
View File
@@ -0,0 +1,164 @@
package health
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
)
type Server struct {
server *http.Server
mu sync.RWMutex
ready bool
checks map[string]Check
startTime time.Time
}
type Check struct {
Name string `json:"name"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
}
func NewServer(host string, port int) *Server {
mux := http.NewServeMux()
s := &Server{
ready: false,
checks: make(map[string]Check),
startTime: time.Now(),
}
mux.HandleFunc("/health", s.healthHandler)
mux.HandleFunc("/ready", s.readyHandler)
addr := fmt.Sprintf("%s:%d", host, port)
s.server = &http.Server{
Addr: addr,
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
return s
}
func (s *Server) Start() error {
s.mu.Lock()
s.ready = true
s.mu.Unlock()
return s.server.ListenAndServe()
}
func (s *Server) StartContext(ctx context.Context) error {
s.mu.Lock()
s.ready = true
s.mu.Unlock()
errCh := make(chan error, 1)
go func() {
errCh <- s.server.ListenAndServe()
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
return s.server.Shutdown(context.Background())
}
}
func (s *Server) Stop(ctx context.Context) error {
s.mu.Lock()
s.ready = false
s.mu.Unlock()
return s.server.Shutdown(ctx)
}
func (s *Server) SetReady(ready bool) {
s.mu.Lock()
s.ready = ready
s.mu.Unlock()
}
func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) {
s.mu.Lock()
defer s.mu.Unlock()
status, msg := checkFn()
s.checks[name] = Check{
Name: name,
Status: statusString(status),
Message: msg,
Timestamp: time.Now(),
}
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
uptime := time.Since(s.startTime)
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
}
json.NewEncoder(w).Encode(resp)
}
func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
s.mu.RLock()
ready := s.ready
checks := make(map[string]Check)
for k, v := range s.checks {
checks[k] = v
}
s.mu.RUnlock()
if !ready {
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(StatusResponse{
Status: "not ready",
Checks: checks,
})
return
}
for _, check := range checks {
if check.Status == "fail" {
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(StatusResponse{
Status: "not ready",
Checks: checks,
})
return
}
}
w.WriteHeader(http.StatusOK)
uptime := time.Since(s.startTime)
json.NewEncoder(w).Encode(StatusResponse{
Status: "ready",
Uptime: uptime.String(),
Checks: checks,
})
}
func statusString(ok bool) string {
if ok {
return "ok"
}
return "fail"
}
+4 -58
View File
@@ -171,68 +171,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse,
}, nil
}
// extractToolCalls parses tool call JSON from the response text.
// extractToolCalls delegates to the shared extractToolCallsFromText function.
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
return extractToolCallsFromText(text)
}
// stripToolCallsJSON removes tool call JSON from response text.
// stripToolCallsJSON delegates to the shared stripToolCallsFromText function.
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
return stripToolCallsFromText(text)
}
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
+79
View File
@@ -0,0 +1,79 @@
package providers
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// CodexCliAuth represents the ~/.codex/auth.json file structure.
type CodexCliAuth struct {
Tokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
AccountID string `json:"account_id"`
} `json:"tokens"`
}
// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file.
// Expiry is estimated as file modification time + 1 hour (same approach as moltbot).
func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) {
authPath, err := resolveCodexAuthPath()
if err != nil {
return "", "", time.Time{}, err
}
data, err := os.ReadFile(authPath)
if err != nil {
return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err)
}
var auth CodexCliAuth
if err := json.Unmarshal(data, &auth); err != nil {
return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err)
}
if auth.Tokens.AccessToken == "" {
return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath)
}
stat, err := os.Stat(authPath)
if err != nil {
expiresAt = time.Now().Add(time.Hour)
} else {
expiresAt = stat.ModTime().Add(time.Hour)
}
return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil
}
// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json.
// This allows the existing CodexProvider to reuse Codex CLI credentials.
func CreateCodexCliTokenSource() func() (string, string, error) {
return func() (string, string, error) {
token, accountID, expiresAt, err := ReadCodexCliCredentials()
if err != nil {
return "", "", fmt.Errorf("reading codex cli credentials: %w", err)
}
if time.Now().After(expiresAt) {
return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login")
}
return token, accountID, nil
}
}
func resolveCodexAuthPath() (string, error) {
codexHome := os.Getenv("CODEX_HOME")
if codexHome == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("getting home dir: %w", err)
}
codexHome = filepath.Join(home, ".codex")
}
return filepath.Join(codexHome, "auth.json"), nil
}
+181
View File
@@ -0,0 +1,181 @@
package providers
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestReadCodexCliCredentials_Valid(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{
"tokens": {
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"account_id": "org-test123"
}
}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
token, accountID, expiresAt, err := ReadCodexCliCredentials()
if err != nil {
t.Fatalf("ReadCodexCliCredentials() error: %v", err)
}
if token != "test-access-token" {
t.Errorf("token = %q, want %q", token, "test-access-token")
}
if accountID != "org-test123" {
t.Errorf("accountID = %q, want %q", accountID, "org-test123")
}
// Expiry should be within ~1 hour from now (file was just written)
if expiresAt.Before(time.Now()) {
t.Errorf("expiresAt = %v, should be in the future", expiresAt)
}
if expiresAt.After(time.Now().Add(2 * time.Hour)) {
t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt)
}
}
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("CODEX_HOME", tmpDir)
_, _, _, err := ReadCodexCliCredentials()
if err == nil {
t.Fatal("expected error for missing auth.json")
}
}
func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
_, _, _, err := ReadCodexCliCredentials()
if err == nil {
t.Fatal("expected error for empty access_token")
}
}
func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
if err := os.WriteFile(authPath, []byte("not json"), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
_, _, _, err := ReadCodexCliCredentials()
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}
func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
token, accountID, _, err := ReadCodexCliCredentials()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token != "tok123" {
t.Errorf("token = %q, want %q", token, "tok123")
}
if accountID != "" {
t.Errorf("accountID = %q, want empty", accountID)
}
}
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
tmpDir := t.TempDir()
customDir := filepath.Join(tmpDir, "custom-codex")
if err := os.MkdirAll(customDir, 0755); err != nil {
t.Fatal(err)
}
authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`
if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", customDir)
token, _, _, err := ReadCodexCliCredentials()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token != "custom-token" {
t.Errorf("token = %q, want %q", token, "custom-token")
}
}
func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
source := CreateCodexCliTokenSource()
token, accountID, err := source()
if err != nil {
t.Fatalf("token source error: %v", err)
}
if token != "fresh-token" {
t.Errorf("token = %q, want %q", token, "fresh-token")
}
if accountID != "acc" {
t.Errorf("accountID = %q, want %q", accountID, "acc")
}
}
func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
// Set file modification time to 2 hours ago
oldTime := time.Now().Add(-2 * time.Hour)
if err := os.Chtimes(authPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
source := CreateCodexCliTokenSource()
_, _, err := source()
if err == nil {
t.Fatal("expected error for expired credentials")
}
}
+251
View File
@@ -0,0 +1,251 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
)
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
type CodexCliProvider struct {
command string
workspace string
}
// NewCodexCliProvider creates a new Codex CLI provider.
func NewCodexCliProvider(workspace string) *CodexCliProvider {
return &CodexCliProvider{
command: "codex",
workspace: workspace,
}
}
// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode.
func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.command == "" {
return nil, fmt.Errorf("codex command not configured")
}
prompt := p.buildPrompt(messages, tools)
args := []string{
"exec",
"--json",
"--dangerously-bypass-approvals-and-sandbox",
"--skip-git-repo-check",
"--color", "never",
}
if model != "" && model != "codex-cli" {
args = append(args, "-m", model)
}
if p.workspace != "" {
args = append(args, "-C", p.workspace)
}
args = append(args, "-") // read prompt from stdin
cmd := exec.CommandContext(ctx, p.command, args...)
cmd.Stdin = bytes.NewReader([]byte(prompt))
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
// Parse JSONL from stdout even if exit code is non-zero,
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
// but still produces valid JSONL output.
if stdoutStr := stdout.String(); stdoutStr != "" {
resp, parseErr := p.parseJSONLEvents(stdoutStr)
if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) {
return resp, nil
}
}
if err != nil {
if ctx.Err() == context.Canceled {
return nil, ctx.Err()
}
if stderrStr := stderr.String(); stderrStr != "" {
return nil, fmt.Errorf("codex cli error: %s", stderrStr)
}
return nil, fmt.Errorf("codex cli error: %w", err)
}
return p.parseJSONLEvents(stdout.String())
}
// GetDefaultModel returns the default model identifier.
func (p *CodexCliProvider) GetDefaultModel() string {
return "codex-cli"
}
// buildPrompt converts messages to a prompt string for the Codex CLI.
// System messages are prepended as instructions since Codex CLI has no --system-prompt flag.
func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string {
var systemParts []string
var conversationParts []string
for _, msg := range messages {
switch msg.Role {
case "system":
systemParts = append(systemParts, msg.Content)
case "user":
conversationParts = append(conversationParts, msg.Content)
case "assistant":
conversationParts = append(conversationParts, "Assistant: "+msg.Content)
case "tool":
conversationParts = append(conversationParts,
fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
}
}
var sb strings.Builder
if len(systemParts) > 0 {
sb.WriteString("## System Instructions\n\n")
sb.WriteString(strings.Join(systemParts, "\n\n"))
sb.WriteString("\n\n## Task\n\n")
}
if len(tools) > 0 {
sb.WriteString(p.buildToolsPrompt(tools))
sb.WriteString("\n\n")
}
// Simplify single user message (no prefix)
if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 {
return conversationParts[0]
}
sb.WriteString(strings.Join(conversationParts, "\n"))
return sb.String()
}
// buildToolsPrompt creates a tool definitions section for the prompt.
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// codexEvent represents a single JSONL event from `codex exec --json`.
type codexEvent struct {
Type string `json:"type"`
ThreadID string `json:"thread_id,omitempty"`
Message string `json:"message,omitempty"`
Item *codexEventItem `json:"item,omitempty"`
Usage *codexUsage `json:"usage,omitempty"`
Error *codexEventErr `json:"error,omitempty"`
}
type codexEventItem struct {
ID string `json:"id"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Command string `json:"command,omitempty"`
Status string `json:"status,omitempty"`
ExitCode *int `json:"exit_code,omitempty"`
Output string `json:"output,omitempty"`
}
type codexUsage struct {
InputTokens int `json:"input_tokens"`
CachedInputTokens int `json:"cached_input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type codexEventErr struct {
Message string `json:"message"`
}
// parseJSONLEvents processes the JSONL output from codex exec --json.
func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) {
var contentParts []string
var usage *UsageInfo
var lastError string
scanner := bufio.NewScanner(strings.NewReader(output))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
var event codexEvent
if err := json.Unmarshal([]byte(line), &event); err != nil {
continue // skip malformed lines
}
switch event.Type {
case "item.completed":
if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" {
contentParts = append(contentParts, event.Item.Text)
}
case "turn.completed":
if event.Usage != nil {
promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens
usage = &UsageInfo{
PromptTokens: promptTokens,
CompletionTokens: event.Usage.OutputTokens,
TotalTokens: promptTokens + event.Usage.OutputTokens,
}
}
case "error":
lastError = event.Message
case "turn.failed":
if event.Error != nil {
lastError = event.Error.Message
}
}
}
if lastError != "" && len(contentParts) == 0 {
return nil, fmt.Errorf("codex cli: %s", lastError)
}
content := strings.Join(contentParts, "\n")
// Extract tool calls from response text (same pattern as ClaudeCliProvider)
toolCalls := extractToolCallsFromText(content)
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
content = stripToolCallsFromText(content)
}
return &LLMResponse{
Content: strings.TrimSpace(content),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
+585
View File
@@ -0,0 +1,585 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
)
// --- JSONL Event Parsing Tests ---
func TestParseJSONLEvents_AgentMessage(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"thread.started","thread_id":"abc-123"}
{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}}
{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.Content != "Hello from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 150 {
t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
}
if resp.Usage.TotalTokens != 170 {
t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens)
}
if len(resp.ToolCalls) != 0 {
t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls))
}
}
func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) {
p := &CodexCliProvider{}
toolCallText := `Let me read that file.
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}`
// Build valid JSONL by marshaling the event
item := codexEvent{
Type: "item.completed",
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
}
itemJSON, _ := json.Marshal(item)
usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}`
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "read_file" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
}
if resp.ToolCalls[0].ID != "call_1" {
t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
}
if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` {
t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments)
}
// Content should have the tool call JSON stripped
if strings.Contains(resp.Content, "tool_calls") {
t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content)
}
}
func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) {
p := &CodexCliProvider{}
toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}`
item := codexEvent{
Type: "item.completed",
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
}
itemJSON, _ := json.Marshal(item)
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if len(resp.ToolCalls) != 2 {
t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "read_file" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
}
if resp.ToolCalls[1].Name != "write_file" {
t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file")
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
}
func TestParseJSONLEvents_MultipleMessages(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}}
{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}}
{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}}
{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.Content != "First part.\nSecond part." {
t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.")
}
}
func TestParseJSONLEvents_ErrorEvent(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"thread.started","thread_id":"abc"}
{"type":"turn.started"}
{"type":"error","message":"token expired"}
{"type":"turn.failed","error":{"message":"token expired"}}`
_, err := p.parseJSONLEvents(events)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "token expired") {
t.Errorf("error = %q, want to contain 'token expired'", err.Error())
}
}
func TestParseJSONLEvents_TurnFailed(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"turn.failed","error":{"message":"rate limit exceeded"}}`
_, err := p.parseJSONLEvents(events)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "rate limit exceeded") {
t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error())
}
}
func TestParseJSONLEvents_ErrorWithContent(t *testing.T) {
p := &CodexCliProvider{}
// If there's an error but also content, return the content (partial success)
events := `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}}
{"type":"error","message":"connection reset"}
{"type":"turn.failed","error":{"message":"connection reset"}}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("should not error when content exists: %v", err)
}
if resp.Content != "Partial result." {
t.Errorf("Content = %q, want %q", resp.Content, "Partial result.")
}
}
func TestParseJSONLEvents_EmptyOutput(t *testing.T) {
p := &CodexCliProvider{}
resp, err := p.parseJSONLEvents("")
if err != nil {
t.Fatalf("empty output should not error: %v", err)
}
if resp.Content != "" {
t.Errorf("Content = %q, want empty", resp.Content)
}
}
func TestParseJSONLEvents_MalformedLines(t *testing.T) {
p := &CodexCliProvider{}
events := `not json at all
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}}
another bad line
{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("should skip malformed lines: %v", err)
}
if resp.Content != "Good line." {
t.Errorf("Content = %q, want %q", resp.Content, "Good line.")
}
if resp.Usage == nil || resp.Usage.TotalTokens != 15 {
t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage)
}
}
func TestParseJSONLEvents_CommandExecution(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}}
{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}}
{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}}
{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
// command_execution items should be skipped; only agent_message text is returned
if resp.Content != "Found 2 files." {
t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.")
}
}
func TestParseJSONLEvents_NoUsage(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}}
{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.Usage != nil {
t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage)
}
}
// --- Prompt Building Tests ---
func TestBuildPrompt_SystemAsInstructions(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi there"},
}
prompt := p.buildPrompt(messages, nil)
if !strings.Contains(prompt, "## System Instructions") {
t.Error("prompt should contain '## System Instructions'")
}
if !strings.Contains(prompt, "You are helpful.") {
t.Error("prompt should contain system content")
}
if !strings.Contains(prompt, "## Task") {
t.Error("prompt should contain '## Task'")
}
if !strings.Contains(prompt, "Hi there") {
t.Error("prompt should contain user message")
}
}
func TestBuildPrompt_NoSystem(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Just a question"},
}
prompt := p.buildPrompt(messages, nil)
if strings.Contains(prompt, "## System Instructions") {
t.Error("prompt should not contain system instructions header")
}
if prompt != "Just a question" {
t.Errorf("prompt = %q, want %q", prompt, "Just a question")
}
}
func TestBuildPrompt_WithTools(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Get weather"},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get current weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
prompt := p.buildPrompt(messages, tools)
if !strings.Contains(prompt, "## Available Tools") {
t.Error("prompt should contain tools section")
}
if !strings.Contains(prompt, "get_weather") {
t.Error("prompt should contain tool name")
}
if !strings.Contains(prompt, "Get current weather") {
t.Error("prompt should contain tool description")
}
}
func TestBuildPrompt_MultipleMessages(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi! How can I help?"},
{Role: "user", Content: "Tell me about Go"},
}
prompt := p.buildPrompt(messages, nil)
if !strings.Contains(prompt, "Hello") {
t.Error("prompt should contain first user message")
}
if !strings.Contains(prompt, "Assistant: Hi! How can I help?") {
t.Error("prompt should contain assistant message with prefix")
}
if !strings.Contains(prompt, "Tell me about Go") {
t.Error("prompt should contain second user message")
}
}
func TestBuildPrompt_ToolResults(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Weather?"},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
prompt := p.buildPrompt(messages, nil)
if !strings.Contains(prompt, "[Tool Result for call_1]") {
t.Error("prompt should contain tool result")
}
if !strings.Contains(prompt, `{"temp": 72}`) {
t.Error("prompt should contain tool result content")
}
}
func TestBuildPrompt_SystemAndTools(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "system", Content: "Be concise."},
{Role: "user", Content: "Do something"},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "my_tool",
Description: "A tool",
},
},
}
prompt := p.buildPrompt(messages, tools)
// System instructions should come first
sysIdx := strings.Index(prompt, "## System Instructions")
toolIdx := strings.Index(prompt, "## Available Tools")
taskIdx := strings.Index(prompt, "## Task")
if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 {
t.Fatal("prompt should contain all sections")
}
if sysIdx >= taskIdx {
t.Error("system instructions should come before task")
}
if taskIdx >= toolIdx {
t.Error("task section should come before tools in the output")
}
}
// --- CLI Argument Tests ---
func TestCodexCliProvider_GetDefaultModel(t *testing.T) {
p := NewCodexCliProvider("")
if got := p.GetDefaultModel(); got != "codex-cli" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli")
}
}
// --- Mock CLI Integration Test ---
func createMockCodexCLI(t *testing.T, events []string) string {
t.Helper()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "codex")
var sb strings.Builder
sb.WriteString("#!/bin/bash\n")
for _, event := range events {
sb.WriteString(fmt.Sprintf("echo '%s'\n", event))
}
if err := os.WriteFile(scriptPath, []byte(sb.String()), 0755); err != nil {
t.Fatal(err)
}
return scriptPath
}
func TestCodexCliProvider_MockCLI_Success(t *testing.T) {
scriptPath := createMockCodexCLI(t, []string{
`{"type":"thread.started","thread_id":"test-123"}`,
`{"type":"turn.started"}`,
`{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`,
`{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`,
})
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Mock response from Codex CLI" {
t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI")
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 60 {
t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 15 {
t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens)
}
}
func TestCodexCliProvider_MockCLI_Error(t *testing.T) {
scriptPath := createMockCodexCLI(t, []string{
`{"type":"thread.started","thread_id":"test-err"}`,
`{"type":"turn.started"}`,
`{"type":"error","message":"auth token expired"}`,
`{"type":"turn.failed","error":{"message":"auth token expired"}}`,
})
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
messages := []Message{{Role: "user", Content: "Hello"}}
_, err := p.Chat(context.Background(), messages, nil, "", nil)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "auth token expired") {
t.Errorf("error = %q, want to contain 'auth token expired'", err.Error())
}
}
func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) {
// Mock script that captures args to verify model flag is passed
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "codex")
script := `#!/bin/bash
# Write args to a file for verification
echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `"
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
echo '{"type":"turn.completed"}'`
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
t.Fatal(err)
}
p := &CodexCliProvider{
command: scriptPath,
workspace: "/tmp/test-workspace",
}
messages := []Message{{Role: "user", Content: "test"}}
_, err := p.Chat(context.Background(), messages, nil, "gpt-5.2-codex", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
// Verify the args
argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt"))
if err != nil {
t.Fatalf("reading args: %v", err)
}
args := string(argsData)
if !strings.Contains(args, "-m gpt-5.2-codex") {
t.Errorf("args should contain model flag, got: %s", args)
}
if !strings.Contains(args, "-C /tmp/test-workspace") {
t.Errorf("args should contain workspace flag, got: %s", args)
}
if !strings.Contains(args, "--json") {
t.Errorf("args should contain --json, got: %s", args)
}
if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
t.Errorf("args should contain bypass flag, got: %s", args)
}
}
func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
// Script that sleeps forever
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "codex")
script := "#!/bin/bash\nsleep 60"
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
t.Fatal(err)
}
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
messages := []Message{{Role: "user", Content: "test"}}
_, err := p.Chat(ctx, messages, nil, "", nil)
if err == nil {
t.Fatal("expected error on canceled context")
}
}
func TestCodexCliProvider_EmptyCommand(t *testing.T) {
p := &CodexCliProvider{command: ""}
messages := []Message{{Role: "user", Content: "test"}}
_, err := p.Chat(context.Background(), messages, nil, "", nil)
if err == nil {
t.Fatal("expected error for empty command")
}
}
// --- Integration Test (requires real codex CLI with valid auth) ---
func TestCodexCliProvider_Integration(t *testing.T) {
if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" {
t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)")
}
// Verify codex is available
codexPath, err := exec.LookPath("codex")
if err != nil {
t.Skip("codex CLI not found in PATH")
}
p := &CodexCliProvider{
command: codexPath,
workspace: "",
}
messages := []Message{
{Role: "user", Content: "Respond with just the word 'hello' and nothing else."},
}
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
lower := strings.ToLower(strings.TrimSpace(resp.Content))
if !strings.Contains(lower, "hello") {
t.Errorf("Content = %q, expected to contain 'hello'", resp.Content)
}
}
+123 -9
View File
@@ -3,6 +3,7 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
@@ -10,8 +11,12 @@ import (
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
)
const codexDefaultModel = "gpt-5.2"
const codexDefaultInstructions = "You are Codex, a coding assistant."
type CodexProvider struct {
client *openai.Client
accountID string
@@ -24,6 +29,8 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
option.WithHeader("originator", "codex_cli_rs"),
option.WithHeader("OpenAI-Beta", "responses=experimental"),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
@@ -43,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
accountID := p.accountID
resolvedModel, fallbackReason := resolveCodexModel(model)
if fallbackReason != "" {
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"reason": fallbackReason,
})
}
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
@@ -50,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
accountID = accID
}
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
} else {
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
})
}
params := buildCodexParams(messages, tools, model, options)
params := buildCodexParams(messages, tools, resolvedModel, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
defer stream.Close()
var resp *responses.Response
for stream.Next() {
evt := stream.Current()
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
evtResp := evt.Response
if evtResp.ID != "" {
copy := evtResp
resp = &copy
}
}
}
err := stream.Err()
if err != nil {
fields := map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
"tools_count": len(tools),
"account_id_present": accountID != "",
"error": err.Error(),
}
var apiErr *openai.Error
if errors.As(err, &apiErr) {
fields["status_code"] = apiErr.StatusCode
fields["api_type"] = apiErr.Type
fields["api_code"] = apiErr.Code
fields["api_param"] = apiErr.Param
fields["api_message"] = apiErr.Message
if apiErr.StatusCode == 400 {
fields["hint"] = "verify account id header and model compatibility for codex backend"
}
if apiErr.Response != nil {
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
}
}
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
return nil, fmt.Errorf("codex API call: %w", err)
}
if resp == nil {
fields := map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
"tools_count": len(tools),
"account_id_present": accountID != "",
}
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
return nil, fmt.Errorf("codex API call: stream ended without completed response")
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
return codexDefaultModel
}
func resolveCodexModel(model string) (string, string) {
m := strings.ToLower(strings.TrimSpace(model))
if m == "" {
return codexDefaultModel, "empty model"
}
if strings.HasPrefix(m, "openai/") {
m = strings.TrimPrefix(m, "openai/")
} else if strings.Contains(m, "/") {
return codexDefaultModel, "non-openai model namespace"
}
unsupportedPrefixes := []string{
"glm",
"claude",
"anthropic",
"gemini",
"google",
"moonshot",
"kimi",
"qwen",
"deepseek",
"llama",
"meta-llama",
"mistral",
"grok",
"xai",
"zhipu",
}
for _, prefix := range unsupportedPrefixes {
if strings.HasPrefix(m, prefix) {
return codexDefaultModel, "unsupported model prefix"
}
}
if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") {
return m, ""
}
return codexDefaultModel, "unsupported model family"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
@@ -135,7 +249,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
Instructions: openai.Opt(instructions),
Store: openai.Opt(false),
}
if instructions != "" {
@@ -149,10 +264,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
@@ -242,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) {
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
+204 -5
View File
@@ -2,6 +2,7 @@ package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -16,7 +17,8 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
"max_tokens": 2048,
"temperature": 0.7,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
@@ -203,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
@@ -226,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
writeCompletedSSE(w, resp)
}))
defer server.Close()
@@ -250,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
}
}
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if _, ok := reqBody["instructions"]; !ok {
http.Error(w, "missing instructions", http.StatusBadRequest)
return
}
if reqBody["instructions"] == "" {
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
return
}
if _, ok := reqBody["temperature"]; ok {
http.Error(w, "temperature is not supported", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
provider := NewCodexProvider("stale-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
provider.tokenSource = func() (string, string, error) {
return "refreshed-token", "", nil
}
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if reqBody["model"] != codexDefaultModel {
http.Error(w, "unsupported model", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
if reqBody["instructions"] != codexDefaultInstructions {
http.Error(w, "missing default instructions", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
if got := p.GetDefaultModel(); got != codexDefaultModel {
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
}
}
func TestResolveCodexModel(t *testing.T) {
tests := []struct {
name string
input string
wantModel string
wantFallback bool
}{
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotModel, reason := resolveCodexModel(tt.input)
if gotModel != tt.wantModel {
t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel)
}
if tt.wantFallback && reason == "" {
t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input)
}
if !tt.wantFallback && reason != "" {
t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason)
}
})
}
}
@@ -268,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
c := openai.NewClient(opts...)
return &c
}
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
event := map[string]interface{}{
"type": "response.completed",
"sequence_number": 1,
"response": response,
}
b, _ := json.Marshal(event)
w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprintf(w, "event: response.completed\n")
fmt.Fprintf(w, "data: %s\n\n", string(b))
fmt.Fprintf(w, "data: [DONE]\n\n")
}
+21 -4
View File
@@ -53,10 +53,10 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" {
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
model = model[idx+1:]
}
}
@@ -240,6 +240,9 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
@@ -299,11 +302,17 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
@@ -400,7 +409,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
+72
View File
@@ -0,0 +1,72 @@
package providers
import (
"encoding/json"
"strings"
)
// extractToolCallsFromText parses tool call JSON from response text.
// Both ClaudeCliProvider and CodexCliProvider use this to extract
// tool calls that the model outputs in its response text.
func extractToolCallsFromText(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
}
// stripToolCallsFromText removes tool call JSON from response text.
func stripToolCallsFromText(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
}
+16
View File
@@ -264,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
return nil
}
// SetHistory updates the messages of a session.
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
sm.mu.Lock()
defer sm.mu.Unlock()
session, ok := sm.sessions[key]
if ok {
// Create a deep copy to strictly isolate internal state
// from the caller's slice.
msgs := make([]providers.Message, len(history))
copy(msgs, history)
session.Messages = msgs
session.Updated = time.Now()
}
}
+45
View File
@@ -2,13 +2,22 @@ package skills
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
)
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
const (
MaxNameLength = 64
MaxDescriptionLength = 1024
)
type SkillMetadata struct {
Name string `json:"name"`
Description string `json:"description"`
@@ -21,6 +30,27 @@ type SkillInfo struct {
Description string `json:"description"`
}
func (info SkillInfo) validate() error {
var errs error
if info.Name == "" {
errs = errors.Join(errs, errors.New("name is required"))
} else {
if len(info.Name) > MaxNameLength {
errs = errors.Join(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength))
}
if !namePattern.MatchString(info.Name) {
errs = errors.Join(errs, errors.New("name must be alphanumeric with hyphens"))
}
}
if info.Description == "" {
errs = errors.Join(errs, errors.New("description is required"))
} else if len(info.Description) > MaxDescriptionLength {
errs = errors.Join(errs, fmt.Errorf("description exceeds %d character", MaxDescriptionLength))
}
return errs
}
type SkillsLoader struct {
workspace string
workspaceSkills string // workspace skills (项目级别)
@@ -54,6 +84,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from workspace", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
@@ -89,6 +124,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from global", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
@@ -123,6 +163,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from builtin", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
+77
View File
@@ -0,0 +1,77 @@
package skills
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSkillsInfoValidate(t *testing.T) {
testcases := []struct {
name string
skillName string
description string
wantErr bool
errContains []string
}{
{
name: "valid-skill",
skillName: "valid-skill",
description: "a valid skill description",
wantErr: false,
},
{
name: "empty-name",
skillName: "",
description: "description without name",
wantErr: true,
errContains: []string{"name is required"},
},
{
name: "empty-description",
skillName: "skill-without-description",
description: "",
wantErr: true,
errContains: []string{"description is required"},
},
{
name: "empty-both",
skillName: "",
description: "",
wantErr: true,
errContains: []string{"name is required", "description is required"},
},
{
name: "name-with-spaces",
skillName: "skill with spaces",
description: "invalid name with spaces",
wantErr: true,
errContains: []string{"name must be alphanumeric with hyphens"},
},
{
name: "name-with-underscore",
skillName: "skill_underscore",
description: "invalid name with underscore",
wantErr: true,
errContains: []string{"name must be alphanumeric with hyphens"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
info := SkillInfo{
Name: tc.skillName,
Description: tc.description,
}
err := info.validate()
if tc.wantErr {
assert.Error(t, err)
for _, msg := range tc.errContains {
assert.ErrorContains(t, err, msg)
}
} else {
assert.NoError(t, err)
}
})
}
}
+2 -2
View File
@@ -28,8 +28,8 @@ type CronTool struct {
}
// NewCronTool creates a new CronTool
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, execTimeout time.Duration) *CronTool {
execTool := NewExecTool(workspace, false)
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool, execTimeout time.Duration) *CronTool {
execTool := NewExecTool(workspace, restrict)
if execTimeout > 0 {
execTool.SetTimeout(execTimeout)
}
+43 -2
View File
@@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
}
}
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
if restrict {
if !isWithinWorkspace(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
workspaceReal := absWorkspace
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
workspaceReal = resolved
}
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
if !isWithinWorkspace(resolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if os.IsNotExist(err) {
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
if !isWithinWorkspace(parentResolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if !os.IsNotExist(err) {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
} else {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
}
return absPath, nil
}
func resolveExistingAncestor(path string) (string, error) {
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
if resolved, err := filepath.EvalSymlinks(current); err == nil {
return resolved, nil
} else if !os.IsNotExist(err) {
return "", err
}
if filepath.Dir(current) == current {
return "", os.ErrNotExist
}
}
}
func isWithinWorkspace(candidate, workspace string) bool {
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
}
type ReadFileTool struct {
workspace string
restrict bool
+32
View File
@@ -247,3 +247,35 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}
// Block paths that look inside workspace but point outside via symlink.
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
root := t.TempDir()
workspace := filepath.Join(root, "workspace")
if err := os.MkdirAll(workspace, 0755); err != nil {
t.Fatalf("failed to create workspace: %v", err)
}
secret := filepath.Join(root, "secret.txt")
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
t.Fatalf("failed to write secret file: %v", err)
}
link := filepath.Join(workspace, "leak.txt")
if err := os.Symlink(secret, link); err != nil {
t.Skipf("symlink not supported in this environment: %v", err)
}
tool := NewReadFileTool(workspace, true)
result := tool.Execute(context.Background(), map[string]interface{}{
"path": link,
})
if !result.IsError {
t.Fatalf("expected symlink escape to be blocked")
}
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
}
}
+10 -6
View File
@@ -173,19 +173,23 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
}
// TestWebTool_WebSearch_NoApiKey verifies that nil is returned when no provider is configured
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "", BraveMaxResults: 5})
// Should return nil when no provider is enabled
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
if tool != nil {
t.Errorf("Expected nil when no search provider is configured")
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
tool = NewWebSearchTool(WebSearchToolOptions{})
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "test-key", BraveMaxResults: 5, BraveEnabled: true})
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
ctx := context.Background()
args := map[string]interface{}{}
+1 -2
View File
@@ -73,9 +73,8 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)