Merge branch 'main' into mcp-tools-support

This commit is contained in:
yuchou87
2026-02-16 16:37:24 +08:00
45 changed files with 3073 additions and 236 deletions
+37 -23
View File
@@ -1,12 +1,18 @@
name: 🐳 Build & Push Docker Image
on:
release:
types: [published]
workflow_call:
inputs:
tag:
description: "Release tag"
required: true
type: string
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
GHCR_REGISTRY: ghcr.io
GHCR_IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
DOCKERHUB_REGISTRY: docker.io
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
jobs:
build:
@@ -20,6 +26,8 @@ jobs:
# ── Checkout ──────────────────────────────
- name: 📥 Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.tag }}
# ── Docker Buildx ─────────────────────────
- name: 🔧 Set up Docker Buildx
@@ -27,36 +35,42 @@ jobs:
# ── Login to GHCR ─────────────────────────
- name: 🔑 Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
registry: ${{ env.GHCR_REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# ── Metadata (tags & labels) ──────────────
- name: 🏷️ Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
# ── Login to Docker Hub ────────────────────
- name: 🔑 Login to Docker Hub
uses: docker/login-action@v3
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}}
registry: ${{ env.DOCKERHUB_REGISTRY }}
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# ── Metadata (tags & labels) ──────────────
- name: 🏷️ Prepare image tags
id: tags
shell: bash
run: |
tag="${{ inputs.tag }}"
echo "ghcr_tag=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT"
echo "ghcr_latest=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT"
echo "dockerhub_tag=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT"
echo "dockerhub_latest=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT"
# ── Build & Push ──────────────────────────
- name: 🚀 Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
push: true
tags: |
${{ steps.tags.outputs.ghcr_tag }}
${{ steps.tags.outputs.ghcr_latest }}
${{ steps.tags.outputs.dockerhub_tag }}
${{ steps.tags.outputs.dockerhub_latest }}
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64,linux/arm64
platforms: linux/amd64,linux/arm64,linux/riscv64
+36 -37
View File
@@ -38,14 +38,18 @@ jobs:
git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}"
git push origin "${{ inputs.tag }}"
build-binaries:
name: Build Release Binaries
release:
name: GoReleaser Release
needs: create-tag
runs-on: ubuntu-latest
permissions:
contents: write
packages: write
steps:
- name: Checkout tag
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ inputs.tag }}
- name: Setup Go from go.mod
@@ -53,47 +57,42 @@ jobs:
with:
go-version-file: go.mod
- name: Build all binaries
run: make build-all
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Generate checksums
shell: bash
run: |
shasum -a 256 build/picoclaw-* > build/sha256sums.txt
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Upload release binaries artifact
uses: actions/upload-artifact@v4
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
with:
name: picoclaw-binaries
path: |
build/picoclaw-*
build/sha256sums.txt
if-no-files-found: error
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
create-release:
name: Create GitHub Release
needs: [create-tag, build-binaries]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download all artifacts
uses: actions/download-artifact@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
path: release-artifacts
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Show downloaded files
run: ls -R release-artifacts
- name: Create release
uses: softprops/action-gh-release@v2
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
tag_name: ${{ inputs.tag }}
name: ${{ inputs.tag }}
draft: ${{ inputs.draft }}
prerelease: ${{ inputs.prerelease }}
files: |
release-artifacts/**/*
generate_release_notes: true
distribution: goreleaser
version: ~> v2
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }}
DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }}
- name: Apply release flags
shell: bash
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
gh release edit "${{ inputs.tag }}" \
--draft=${{ inputs.draft }} \
--prerelease=${{ inputs.prerelease }}
+23 -5
View File
@@ -5,9 +5,11 @@ version: 2
before:
hooks:
- go mod tidy
- go generate ./cmd/picoclaw
builds:
- env:
- id: picoclaw
env:
- CGO_ENABLED=0
goos:
- linux
@@ -26,6 +28,22 @@ builds:
- goos: windows
goarch: arm
dockers_v2:
- id: picoclaw
dockerfile: Dockerfile.goreleaser
ids:
- picoclaw
images:
- "ghcr.io/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw"
- "docker.io/{{ .Env.DOCKERHUB_IMAGE_NAME }}"
tags:
- "{{ .Tag }}"
- "latest"
platforms:
- linux/amd64
- linux/arm64
- linux/riscv64
archives:
- formats: [tar.gz]
# this name template makes the OS and Arch compatible with the results of `uname`.
@@ -48,10 +66,10 @@ changelog:
- "^docs:"
- "^test:"
upx:
- enabled: true
compress: best
lzma: true
# upx:
# - enabled: true
# compress: best
# lzma: true
release:
footer: >-
+4
View File
@@ -22,6 +22,10 @@ FROM alpine:3.23
RUN apk add --no-cache ca-certificates tzdata curl
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget -q --spider http://localhost:18790/health || exit 1
# Copy binary
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
+10
View File
@@ -0,0 +1,10 @@
FROM alpine:3.21
ARG TARGETPLATFORM
RUN apk add --no-cache ca-certificates tzdata
COPY $TARGETPLATFORM/picoclaw /usr/local/bin/picoclaw
ENTRYPOINT ["picoclaw"]
CMD ["gateway"]
+10 -2
View File
@@ -119,7 +119,7 @@ clean:
@rm -rf $(BUILD_DIR)
@echo "Clean complete"
## fmt: Format Go code
## vet: Run go vet for static analysis
vet:
@$(GO) vet ./...
@@ -131,11 +131,19 @@ test:
fmt:
@$(GO) fmt ./...
## deps: Update dependencies
## deps: Download dependencies
deps:
@$(GO) mod download
@$(GO) mod verify
## update-deps: Update dependencies
update-deps:
@$(GO) get -u ./...
@$(GO) mod tidy
## check: Run vet, fmt, and verify dependencies
check: deps fmt vet test
## run: Build and run picoclaw
run: build
@$(BUILD_DIR)/$(BINARY_NAME) $(ARGS)
+1 -1
View File
@@ -44,7 +44,7 @@
> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**.
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
>
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
## 📢 News
+2 -2
View File
@@ -45,8 +45,8 @@
> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
>
>
> * **注意:** picoclaw正在初期的快速功能开发阶段,可能有尚未修复的网络安全问题,在1.0正式版发布前,请不要将其部署到生产环境中
## 📢 新闻 (News)
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 145 KiB

After

Width:  |  Height:  |  Size: 140 KiB

+17 -3
View File
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"io/fs"
"net/http"
"os"
"os/signal"
"path/filepath"
@@ -28,6 +29,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/devices"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/migrate"
@@ -560,7 +562,7 @@ func gatewayCmd() {
})
// Setup cron tool and service
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath())
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace)
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -592,6 +594,9 @@ func gatewayCmd() {
os.Exit(1)
}
// Inject channel manager into agent loop for command handling
agentLoop.SetChannelManager(channelManager)
var transcriber *voice.GroqTranscriber
if cfg.Providers.Groq.APIKey != "" {
transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey)
@@ -658,6 +663,14 @@ func gatewayCmd() {
fmt.Printf("Error starting channels: %v\n", err)
}
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
go func() {
if err := healthServer.Start(); err != nil && err != http.ErrServerClosed {
logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()})
}
}()
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
go agentLoop.Run(ctx)
sigChan := make(chan os.Signal, 1)
@@ -666,6 +679,7 @@ func gatewayCmd() {
fmt.Println("\nShutting down...")
cancel()
healthServer.Stop(context.Background())
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
@@ -973,14 +987,14 @@ func getConfigPath() string {
return filepath.Join(home, ".picoclaw", "config.json")
}
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService {
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool) *cron.CronService {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace)
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict)
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
+4
View File
@@ -109,6 +109,10 @@
"moonshot": {
"api_key": "sk-xxx",
"api_base": ""
},
"ollama": {
"api_key": "",
"api_base": "http://localhost:11434/v1"
}
},
"tools": {
+6
View File
@@ -16,10 +16,16 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
github.com/slack-go/slack v0.17.3
github.com/stretchr/testify v1.11.1
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
require (
+4
View File
@@ -82,9 +82,11 @@ github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzh
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
@@ -108,6 +110,7 @@ github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsK
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
@@ -252,6 +255,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
+298 -16
View File
@@ -19,6 +19,7 @@ import (
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -44,6 +45,7 @@ type AgentLoop struct {
mcpManager *mcp.Manager // MCP server manager for resource cleanup
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
channelManager *channels.Manager
}
// processOptions configures how a message is processed
@@ -239,6 +241,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
}
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
@@ -303,6 +309,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.processSystemMessage(ctx, msg)
}
// Check for commands
if response, handled := al.handleCommand(ctx, msg); handled {
return response, nil
}
// Process as user message
return al.runAgentLoop(ctx, processOptions{
SessionKey: msg.SessionKey,
@@ -423,7 +434,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
// 7. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(opts.SessionKey)
al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID)
}
// 8. Optional: send response via bus
@@ -485,11 +496,131 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
var response *providers.LLMResponse
var err error
// Retry loop for context/token errors
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
if err == nil {
break // Success
}
errMsg := strings.ToLower(err.Error())
// Check for context window errors (provider specific, but usually contain "token" or "invalid")
isContextError := strings.Contains(errMsg, "token") ||
strings.Contains(errMsg, "context") ||
strings.Contains(errMsg, "invalidparameter") ||
strings.Contains(errMsg, "length")
if isContextError && retry < maxRetries {
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
"error": err.Error(),
"retry": retry,
})
// Notify user on first retry only
if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: "⚠️ Context window exceeded. Compressing history and retrying...",
})
}
// Force compression
al.forceCompression(opts.SessionKey)
// Rebuild messages with compressed history
// Note: We need to reload history from session manager because forceCompression changed it
newHistory := al.sessions.GetHistory(opts.SessionKey)
newSummary := al.sessions.GetSummary(opts.SessionKey)
// Re-create messages for the next attempt
// We keep the current user message (opts.UserMessage) effectively
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
opts.UserMessage,
nil,
opts.Channel,
opts.ChatID,
)
// Important: If we are in the middle of a tool loop (iteration > 1),
// rebuilding messages from session history might duplicate the flow or miss context
// if intermediate steps weren't saved correctly.
// However, al.sessions.AddFullMessage is called after every tool execution,
// so GetHistory should reflect the current state including partial tool execution.
// But we need to ensure we don't duplicate the user message which is appended in BuildMessages.
// BuildMessages(history...) takes the stored history and appends the *current* user message.
// If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop.
// So if we pass opts.UserMessage again, we might duplicate it?
// Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// So GetHistory ALREADY contains the user message!
// CORRECTION:
// BuildMessages combines: [System] + [History] + [CurrentMessage]
// But Step 3 added CurrentMessage to History.
// So if we use GetHistory now, it has the user message.
// If we pass opts.UserMessage to BuildMessages, it adds it AGAIN.
// For retry in the middle of a loop, we should rely on what's in the session.
// BUT checking BuildMessages implementation:
// It appends history... then appends currentMessage.
// Logic fix for retry:
// If iteration == 1, opts.UserMessage corresponds to the user input.
// If iteration > 1, we are processing tool results. The "messages" passed to Chat
// already accumulated tool outputs.
// Rebuilding from session history is safest because it persists state.
// Start fresh with rebuilt history.
// Special case: standard BuildMessages appends "currentMessage".
// If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed.
// However, the "messages" argument passed to runLLMIteration is constructed by the caller.
// If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history.
// In runAgentLoop:
// 3. sessions.AddMessage(userMsg)
// 4. runLLMIteration(..., UserMessage)
// So History contains the user message.
// BuildMessages typically appends the user message as a *new* pending message.
// Wait, standard BuildMessages usage in runAgentLoop:
// messages := BuildMessages(history (has old), UserMessage)
// THEN AddMessage(UserMessage).
// So "history" passed to BuildMessages does NOT contain the current UserMessage yet.
// But here, inside the loop, we have already saved it.
// So GetHistory() includes the current user message.
// If we call BuildMessages(GetHistory(), UserMessage), we get duplicates.
// Hack/Fix:
// If we are retrying, we rebuild from Session History ONLY.
// We pass empty string as "currentMessage" to BuildMessages
// because the "current message" is already saved in history (step 3).
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
"", // Empty because history already contains the relevant messages
nil,
opts.Channel,
opts.ChatID,
)
continue
}
// Real error or success, break loop
break
}
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
@@ -497,7 +628,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration,
"error": err.Error(),
})
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
// Check if no tool calls - we're done
@@ -629,7 +760,7 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) {
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(sessionKey string) {
func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
newHistory := al.sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := al.contextWindow * 75 / 100
@@ -638,12 +769,80 @@ func (al *AgentLoop) maybeSummarize(sessionKey string) {
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
go func() {
defer al.summarizing.Delete(sessionKey)
// Notify user about optimization if not an internal channel
if !constants.IsInternalChannel(channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
})
}
al.summarizeSession(sessionKey)
}()
}
}
}
// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest 50% of messages (keeping system prompt and last user message).
func (al *AgentLoop) forceCompression(sessionKey string) {
history := al.sessions.GetHistory(sessionKey)
if len(history) <= 4 {
return
}
// Keep system prompt (usually [0]) and the very last message (user's trigger)
// We want to drop the oldest half of the *conversation*
// Assuming [0] is system, [1:] is conversation
conversation := history[1 : len(history)-1]
if len(conversation) == 0 {
return
}
// Helper to find the mid-point of the conversation
mid := len(conversation) / 2
// New history structure:
// 1. System Prompt
// 2. [Summary of dropped part] - synthesized
// 3. Second half of conversation
// 4. Last message
// Simplified approach for emergency: Drop first half of conversation
// and rely on existing summary if present, or create a placeholder.
droppedCount := mid
keptConversation := conversation[mid:]
newHistory := make([]providers.Message, 0)
newHistory = append(newHistory, history[0]) // System prompt
// Add a note about compression
compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
// If there was an existing summary, we might lose it if it was in the dropped part (which is just messages).
// The summary is stored separately in session.Summary, so it persists!
// We just need to ensure the user knows there's a gap.
// We only modify the messages list here
newHistory = append(newHistory, providers.Message{
Role: "system",
Content: compressionNote,
})
newHistory = append(newHistory, keptConversation...)
newHistory = append(newHistory, history[len(history)-1]) // Last message
// Update session
al.sessions.SetHistory(sessionKey, newHistory)
al.sessions.Save(sessionKey)
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
"session_key": sessionKey,
"dropped_msgs": droppedCount,
"new_count": len(newHistory),
})
}
// GetStartupInfo returns information about loaded tools and skills for logging.
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
info := make(map[string]interface{})
@@ -671,7 +870,7 @@ func formatMessagesForLog(messages []providers.Message) string {
result += "[\n"
for i, msg := range messages {
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
if len(msg.ToolCalls) > 0 {
result += " ToolCalls:\n"
for _, tc := range msg.ToolCalls {
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
@@ -738,7 +937,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
continue
}
// Estimate tokens for this message
msgTokens := len(m.Content) / 4
msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety)
if msgTokens > maxMessageTokens {
omitted = true
continue
@@ -809,13 +1008,96 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
}
// estimateTokens estimates the number of tokens in a message list.
// Uses rune count instead of byte length so that CJK and other multi-byte
// characters are not over-counted (a Chinese character is 3 bytes but roughly
// one token).
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
// overheads better than the previous 3 chars/token.
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
total := 0
totalChars := 0
for _, m := range messages {
total += utf8.RuneCountInString(m.Content) / 3
totalChars += utf8.RuneCountInString(m.Content)
}
return total
// 2.5 chars per token = totalChars * 2 / 5
return totalChars * 2 / 5
}
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
content := strings.TrimSpace(msg.Content)
if !strings.HasPrefix(content, "/") {
return "", false
}
parts := strings.Fields(content)
if len(parts) == 0 {
return "", false
}
cmd := parts[0]
args := parts[1:]
switch cmd {
case "/show":
if len(args) < 1 {
return "Usage: /show [model|channel]", true
}
switch args[0] {
case "model":
return fmt.Sprintf("Current model: %s", al.model), true
case "channel":
return fmt.Sprintf("Current channel: %s", msg.Channel), true
default:
return fmt.Sprintf("Unknown show target: %s", args[0]), true
}
case "/list":
if len(args) < 1 {
return "Usage: /list [models|channels]", true
}
switch args[0] {
case "models":
// TODO: Fetch available models dynamically if possible
return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
case "channels":
if al.channelManager == nil {
return "Channel manager not initialized", true
}
channels := al.channelManager.GetEnabledChannels()
if len(channels) == 0 {
return "No channels enabled", true
}
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
default:
return fmt.Sprintf("Unknown list target: %s", args[0]), true
}
case "/switch":
if len(args) < 3 || args[1] != "to" {
return "Usage: /switch [model|channel] to <name>", true
}
target := args[0]
value := args[2]
switch target {
case "model":
oldModel := al.model
al.model = value
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
case "channel":
// This changes the 'default' channel for some operations, or effectively redirects output?
// For now, let's just validate if the channel exists
if al.channelManager == nil {
return "Channel manager not initialized", true
}
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
}
// If message came from CLI, maybe we want to redirect CLI output to this channel?
// That would require state persistence about "redirected channel"
// For now, just acknowledged.
return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
default:
return fmt.Sprintf("Unknown switch target: %s", target), true
}
}
return "", false
}
+97
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
@@ -527,3 +528,99 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}
// failFirstMockProvider fails on the first N calls with a specific error
type failFirstMockProvider struct {
failures int
currentCall int
failError error
successResp string
}
func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
m.currentCall++
if m.currentCall <= m.failures {
return nil, m.failError
}
return &providers.LLMResponse{
Content: m.successResp,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *failFirstMockProvider) GetDefaultModel() string {
return "mock-fail-model"
}
// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
// Create a provider that fails once with a context error
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
successResp: "Recovered from context error",
}
al := NewAgentLoop(cfg, msgBus, provider)
// Inject some history to simulate a full context
sessionKey := "test-session-context"
// Create dummy history
history := []providers.Message{
{Role: "system", Content: "System prompt"},
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
}
al.sessions.SetHistory(sessionKey, history)
// Call ProcessDirectWithChannel
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
if err != nil {
t.Fatalf("Expected success after retry, got error: %v", err)
}
if response != "Recovered from context error" {
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
}
// We expect 2 calls: 1st failed, 2nd succeeded
if provider.currentCall != 2 {
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
}
// Check final history length
finalHistory := al.sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
// We can assert that the length is NOT what it would be without compression.
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
if len(finalHistory) >= 8 {
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
}
}
+56 -14
View File
@@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
refreshed, err := parseTokenResponse(body, cred.Provider)
if err != nil {
return nil, err
}
if refreshed.RefreshToken == "" {
refreshed.RefreshToken = cred.RefreshToken
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
return refreshed, nil
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
@@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
"codex_cli_simplified_flow": {"true"},
"state": {state},
}
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
params.Set("originator", "picoclaw")
}
if cfg.Originator != "" {
params.Set("originator", cfg.Originator)
}
@@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
func extractAccountID(token string) string {
claims, err := parseJWTClaims(token)
if err != nil {
return ""
}
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
}
if orgs, ok := claims["organizations"].([]interface{}); ok {
for _, org := range orgs {
if orgMap, ok := org.(map[string]interface{}); ok {
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
return accountID
}
}
}
}
return ""
}
func parseJWTClaims(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("token is not a JWT")
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
@@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string {
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
return nil, err
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
return nil, err
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
return claims, nil
}
func base64URLDecode(s string) ([]byte, error) {
+97
View File
@@ -5,10 +5,23 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
t.Helper()
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
payloadJSON, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
return header + "." + payload + ".sig"
}
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
@@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) {
}
}
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
cfg := OpenAIOAuthConfig()
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
parsed, err := url.Parse(u)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
q := parsed.Query()
if q.Get("id_token_add_organizations") != "true" {
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
}
if q.Get("codex_cli_simplified_flow") != "true" {
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
}
if q.Get("originator") != "codex_cli_rs" {
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
@@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
}
}
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
resp := map[string]interface{}{
"access_token": "opaque-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": idToken,
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccountID != "acc-id-from-id-token" {
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
}
}
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
token := makeJWTForClaims(t, map[string]interface{}{
"organizations": []interface{}{
map[string]interface{}{"id": "org_from_orgs"},
},
})
if got := extractAccountID(token); got != "org_from_orgs" {
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
@@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
}
}
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"access_token": "new-access-token-only",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
cred := &AuthCredential{
AccessToken: "old-access",
RefreshToken: "existing-refresh",
AccountID: "acc_existing",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.RefreshToken != "existing-refresh" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
}
if refreshed.AccountID != "acc_existing" {
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
+17
View File
@@ -9,6 +9,7 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
handlers map[string]MessageHandler
closed bool
mu sync.RWMutex
}
@@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus {
}
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
}
mb.inbound <- msg
}
@@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool)
}
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
}
mb.outbound <- msg
}
@@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
}
func (mb *MessageBus) Close() {
mb.mu.Lock()
defer mb.mu.Unlock()
if mb.closed {
return
}
mb.closed = true
close(mb.inbound)
close(mb.outbound)
}
+150 -2
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/bwmarrin/discordgo"
@@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
return fmt.Errorf("channel ID is empty")
}
message := msg.Content
runes := []rune(msg.Content)
if len(runes) == 0 {
return nil
}
chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks
for _, chunk := range chunks {
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
return err
}
}
return nil
}
// splitMessage splits long messages into chunks, preserving code block integrity
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
func splitMessage(content string, limit int) []string {
var messages []string
for len(content) > 0 {
if len(content) <= limit {
messages = append(messages, content)
break
}
msgEnd := limit
// Find natural split point within the limit
msgEnd = findLastNewline(content[:limit], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:limit], 100)
}
if msgEnd <= 0 {
msgEnd = limit
}
// Check if this would end with an incomplete code block
candidate := content[:msgEnd]
unclosedIdx := findLastUnclosedCodeBlock(candidate)
if unclosedIdx >= 0 {
// Message would end with incomplete code block
// Try to extend to include the closing ``` (with some buffer)
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
if len(content) > extendedLimit {
closingIdx := findNextClosingCodeBlock(content, msgEnd)
if closingIdx > 0 && closingIdx <= extendedLimit {
// Extend to include the closing ```
msgEnd = closingIdx
} else {
// Can't find closing, split before the code block
msgEnd = findLastNewline(content[:unclosedIdx], 200)
if msgEnd <= 0 {
msgEnd = findLastSpace(content[:unclosedIdx], 100)
}
if msgEnd <= 0 {
msgEnd = unclosedIdx
}
}
} else {
// Remaining content fits within extended limit
msgEnd = len(content)
}
}
if msgEnd <= 0 {
msgEnd = limit
}
messages = append(messages, content[:msgEnd])
content = strings.TrimSpace(content[msgEnd:])
}
return messages
}
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
// Returns the position of the opening ``` or -1 if all code blocks are complete
func findLastUnclosedCodeBlock(text string) int {
count := 0
lastOpenIdx := -1
for i := 0; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
if count == 0 {
lastOpenIdx = i
}
count++
i += 2
}
}
// If odd number of ``` markers, last one is unclosed
if count%2 == 1 {
return lastOpenIdx
}
return -1
}
// findNextClosingCodeBlock finds the next closing ``` starting from a position
// Returns the position after the closing ``` or -1 if not found
func findNextClosingCodeBlock(text string, startIdx int) int {
for i := startIdx; i < len(text); i++ {
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
return i + 3
}
}
return -1
}
// findLastNewline finds the last newline character within the last N characters
// Returns the position of the newline or -1 if not found
func findLastNewline(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == '\n' {
return i
}
}
return -1
}
// findLastSpace finds the last space character within the last N characters
// Returns the position of the space or -1 if not found
func findLastSpace(s string, searchWindow int) int {
searchStart := len(s) - searchWindow
if searchStart < 0 {
searchStart = 0
}
for i := len(s) - 1; i >= searchStart; i-- {
if s[i] == ' ' || s[i] == '\t' {
return i
}
}
return -1
}
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
_, err := c.session.ChannelMessageSend(channelID, content)
done <- err
}()
@@ -140,6 +282,12 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
if err := c.session.ChannelTyping(m.ChannelID); err != nil {
logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
"error": err.Error(),
})
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
+1 -1
View File
@@ -48,7 +48,7 @@ func (m *Manager) initChannels() error {
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
logger.DebugC("channels", "Attempting to initialize Telegram channel")
telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus)
telegram, err := NewTelegramChannel(m.config, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
"error": err.Error(),
+14
View File
@@ -296,6 +296,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
return
}
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
@@ -345,6 +352,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
c.socketClient.Ack(*event.Request)
}
if !c.IsAllowed(cmd.UserID) {
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
"user_id": cmd.UserID,
})
return
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
+58 -40
View File
@@ -11,7 +11,10 @@ import (
"sync"
"time"
th "github.com/mymmrac/telego/telegohandler"
"github.com/mymmrac/telego"
"github.com/mymmrac/telego/telegohandler"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -24,7 +27,8 @@ import (
type TelegramChannel struct {
*BaseChannel
bot *telego.Bot
config config.TelegramConfig
commands TelegramCommander
config *config.Config
chatIDs map[string]int64
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
@@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() {
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
var opts []telego.BotOption
telegramCfg := cfg.Channels.Telegram
if cfg.Proxy != "" {
proxyURL, parseErr := url.Parse(cfg.Proxy)
if telegramCfg.Proxy != "" {
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
if parseErr != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
}
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
@@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr
}))
}
bot, err := telego.NewBot(cfg.Token, opts...)
bot, err := telego.NewBot(telegramCfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom)
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
return &TelegramChannel{
BaseChannel: base,
commands: NewTelegramCommands(bot, cfg),
bot: bot,
config: cfg,
chatIDs: make(map[string]int64),
@@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return fmt.Errorf("failed to start long polling: %w", err)
}
bh, err := telegohandler.NewBotHandler(c.bot, updates)
if err != nil {
return fmt.Errorf("failed to create bot handler: %w", err)
}
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
c.commands.Help(ctx, message)
return nil
}, th.CommandEqual("help"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Start(ctx, message)
}, th.CommandEqual("start"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.Show(ctx, message)
}, th.CommandEqual("show"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.commands.List(ctx, message)
}, th.CommandEqual("list"))
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
return c.handleMessage(ctx, &message)
}, th.AnyMessage())
c.setRunning(true)
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go bh.Start()
go func() {
for {
select {
case <-ctx.Done():
return
case update, ok := <-updates:
if !ok {
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
c.handleMessage(ctx, update)
}
}
}
<-ctx.Done()
bh.Stop()
}()
return nil
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
@@ -166,30 +186,27 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return nil
}
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
message := update.Message
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
if message == nil {
return
return fmt.Errorf("message is nil")
}
user := message.From
if user == nil {
return
return fmt.Errorf("message sender (user) is nil")
}
userID := fmt.Sprintf("%d", user.ID)
senderID := userID
senderID := fmt.Sprintf("%d", user.ID)
if user.Username != "" {
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
"username": user.Username,
"user_id": senderID,
})
return
return nil
}
chatID := message.Chat.ID
@@ -222,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content += message.Caption
}
if message.Photo != nil && len(message.Photo) > 0 {
if len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
@@ -231,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: photo]")
content += "[image: photo]"
}
}
@@ -252,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
transcribedText = "[voice (transcription failed)]"
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
@@ -260,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
})
}
} else {
transcribedText = fmt.Sprintf("[voice]")
transcribedText = "[voice]"
}
if content != "" {
@@ -278,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio]")
content += "[audio]"
}
}
@@ -290,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file]")
content += "[file]"
}
}
@@ -338,7 +355,8 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
return nil
}
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
+153
View File
@@ -0,0 +1,153 @@
package channels
import (
"context"
"fmt"
"strings"
"github.com/mymmrac/telego"
"github.com/sipeed/picoclaw/pkg/config"
)
type TelegramCommander interface {
Help(ctx context.Context, message telego.Message) error
Start(ctx context.Context, message telego.Message) error
Show(ctx context.Context, message telego.Message) error
List(ctx context.Context, message telego.Message) error
}
type cmd struct {
bot *telego.Bot
config *config.Config
}
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
return &cmd{
bot: bot,
config: cfg,
}
}
func commandArgs(text string) string {
parts := strings.SplitN(text, " ", 2)
if len(parts) < 2 {
return ""
}
return strings.TrimSpace(parts[1])
}
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
msg := `/start - Start the bot
/help - Show this help message
/show [model|channel] - Show current configuration
/list [models|channels] - List available options
`
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: msg,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Hello! I am PicoClaw 🦞",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /show [model|channel]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "model":
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
c.config.Agents.Defaults.Model,
c.config.Agents.Defaults.Provider)
case "channel":
response = "Current Channel: telegram"
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
func (c *cmd) List(ctx context.Context, message telego.Message) error {
args := commandArgs(message.Text)
if args == "" {
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: "Usage: /list [models|channels]",
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
var response string
switch args {
case "models":
provider := c.config.Agents.Defaults.Provider
if provider == "" {
provider = "configured default"
}
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
c.config.Agents.Defaults.Model, provider)
case "channels":
var enabled []string
if c.config.Channels.Telegram.Enabled {
enabled = append(enabled, "telegram")
}
if c.config.Channels.WhatsApp.Enabled {
enabled = append(enabled, "whatsapp")
}
if c.config.Channels.Feishu.Enabled {
enabled = append(enabled, "feishu")
}
if c.config.Channels.Discord.Enabled {
enabled = append(enabled, "discord")
}
if c.config.Channels.Slack.Enabled {
enabled = append(enabled, "slack")
}
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
default:
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
}
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
ChatID: telego.ChatID{ID: message.Chat.ID},
Text: response,
ReplyParameters: &telego.ReplyParameters{
MessageID: message.MessageID,
},
})
return err
}
+2 -1
View File
@@ -175,6 +175,7 @@ type ProvidersConfig struct {
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Ollama ProviderConfig `json:"ollama"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
DeepSeek ProviderConfig `json:"deepseek"`
@@ -430,7 +431,7 @@ func SaveConfig(path string, cfg *Config) error {
return err
}
return os.WriteFile(path, data, 0644)
return os.WriteFile(path, data, 0600)
}
func (c *Config) WorkspacePath() string {
+27
View File
@@ -1,6 +1,9 @@
package config
import (
"os"
"path/filepath"
"runtime"
"testing"
)
@@ -147,6 +150,30 @@ func TestDefaultConfig_WebTools(t *testing.T) {
}
}
func TestSaveConfig_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "config.json")
cfg := DefaultConfig()
if err := SaveConfig(path, cfg); err != nil {
t.Fatalf("SaveConfig failed: %v", err)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("config file has permission %04o, want 0600", perm)
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
+1 -1
View File
@@ -340,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error {
return err
}
return os.WriteFile(cs.storePath, data, 0644)
return os.WriteFile(cs.storePath, data, 0600)
}
func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
+38
View File
@@ -0,0 +1,38 @@
package cron
import (
"os"
"path/filepath"
"runtime"
"testing"
)
func TestSaveStore_FilePermissions(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("file permission bits are not enforced on Windows")
}
tmpDir := t.TempDir()
storePath := filepath.Join(tmpDir, "cron", "jobs.json")
cs := NewCronService(storePath, nil)
_, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct")
if err != nil {
t.Fatalf("AddJob failed: %v", err)
}
info, err := os.Stat(storePath)
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("cron store has permission %04o, want 0600", perm)
}
}
func int64Ptr(v int64) *int64 {
return &v
}
+164
View File
@@ -0,0 +1,164 @@
package health
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
)
type Server struct {
server *http.Server
mu sync.RWMutex
ready bool
checks map[string]Check
startTime time.Time
}
type Check struct {
Name string `json:"name"`
Status string `json:"status"`
Message string `json:"message,omitempty"`
Timestamp time.Time `json:"timestamp"`
}
type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
}
func NewServer(host string, port int) *Server {
mux := http.NewServeMux()
s := &Server{
ready: false,
checks: make(map[string]Check),
startTime: time.Now(),
}
mux.HandleFunc("/health", s.healthHandler)
mux.HandleFunc("/ready", s.readyHandler)
addr := fmt.Sprintf("%s:%d", host, port)
s.server = &http.Server{
Addr: addr,
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
return s
}
func (s *Server) Start() error {
s.mu.Lock()
s.ready = true
s.mu.Unlock()
return s.server.ListenAndServe()
}
func (s *Server) StartContext(ctx context.Context) error {
s.mu.Lock()
s.ready = true
s.mu.Unlock()
errCh := make(chan error, 1)
go func() {
errCh <- s.server.ListenAndServe()
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
return s.server.Shutdown(context.Background())
}
}
func (s *Server) Stop(ctx context.Context) error {
s.mu.Lock()
s.ready = false
s.mu.Unlock()
return s.server.Shutdown(ctx)
}
func (s *Server) SetReady(ready bool) {
s.mu.Lock()
s.ready = ready
s.mu.Unlock()
}
func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) {
s.mu.Lock()
defer s.mu.Unlock()
status, msg := checkFn()
s.checks[name] = Check{
Name: name,
Status: statusString(status),
Message: msg,
Timestamp: time.Now(),
}
}
func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
uptime := time.Since(s.startTime)
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
}
json.NewEncoder(w).Encode(resp)
}
func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
s.mu.RLock()
ready := s.ready
checks := make(map[string]Check)
for k, v := range s.checks {
checks[k] = v
}
s.mu.RUnlock()
if !ready {
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(StatusResponse{
Status: "not ready",
Checks: checks,
})
return
}
for _, check := range checks {
if check.Status == "fail" {
w.WriteHeader(http.StatusServiceUnavailable)
json.NewEncoder(w).Encode(StatusResponse{
Status: "not ready",
Checks: checks,
})
return
}
}
w.WriteHeader(http.StatusOK)
uptime := time.Since(s.startTime)
json.NewEncoder(w).Encode(StatusResponse{
Status: "ready",
Uptime: uptime.String(),
Checks: checks,
})
}
func statusString(ok bool) string {
if ok {
return "ok"
}
return "fail"
}
+4 -58
View File
@@ -171,68 +171,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse,
}, nil
}
// extractToolCalls parses tool call JSON from the response text.
// extractToolCalls delegates to the shared extractToolCallsFromText function.
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
return extractToolCallsFromText(text)
}
// stripToolCallsJSON removes tool call JSON from response text.
// stripToolCallsJSON delegates to the shared stripToolCallsFromText function.
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
return stripToolCallsFromText(text)
}
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
+79
View File
@@ -0,0 +1,79 @@
package providers
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
// CodexCliAuth represents the ~/.codex/auth.json file structure.
type CodexCliAuth struct {
Tokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
AccountID string `json:"account_id"`
} `json:"tokens"`
}
// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file.
// Expiry is estimated as file modification time + 1 hour (same approach as moltbot).
func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) {
authPath, err := resolveCodexAuthPath()
if err != nil {
return "", "", time.Time{}, err
}
data, err := os.ReadFile(authPath)
if err != nil {
return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err)
}
var auth CodexCliAuth
if err := json.Unmarshal(data, &auth); err != nil {
return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err)
}
if auth.Tokens.AccessToken == "" {
return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath)
}
stat, err := os.Stat(authPath)
if err != nil {
expiresAt = time.Now().Add(time.Hour)
} else {
expiresAt = stat.ModTime().Add(time.Hour)
}
return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil
}
// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json.
// This allows the existing CodexProvider to reuse Codex CLI credentials.
func CreateCodexCliTokenSource() func() (string, string, error) {
return func() (string, string, error) {
token, accountID, expiresAt, err := ReadCodexCliCredentials()
if err != nil {
return "", "", fmt.Errorf("reading codex cli credentials: %w", err)
}
if time.Now().After(expiresAt) {
return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login")
}
return token, accountID, nil
}
}
func resolveCodexAuthPath() (string, error) {
codexHome := os.Getenv("CODEX_HOME")
if codexHome == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("getting home dir: %w", err)
}
codexHome = filepath.Join(home, ".codex")
}
return filepath.Join(codexHome, "auth.json"), nil
}
+181
View File
@@ -0,0 +1,181 @@
package providers
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestReadCodexCliCredentials_Valid(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{
"tokens": {
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"account_id": "org-test123"
}
}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
token, accountID, expiresAt, err := ReadCodexCliCredentials()
if err != nil {
t.Fatalf("ReadCodexCliCredentials() error: %v", err)
}
if token != "test-access-token" {
t.Errorf("token = %q, want %q", token, "test-access-token")
}
if accountID != "org-test123" {
t.Errorf("accountID = %q, want %q", accountID, "org-test123")
}
// Expiry should be within ~1 hour from now (file was just written)
if expiresAt.Before(time.Now()) {
t.Errorf("expiresAt = %v, should be in the future", expiresAt)
}
if expiresAt.After(time.Now().Add(2 * time.Hour)) {
t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt)
}
}
func TestReadCodexCliCredentials_MissingFile(t *testing.T) {
tmpDir := t.TempDir()
t.Setenv("CODEX_HOME", tmpDir)
_, _, _, err := ReadCodexCliCredentials()
if err == nil {
t.Fatal("expected error for missing auth.json")
}
}
func TestReadCodexCliCredentials_EmptyToken(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
_, _, _, err := ReadCodexCliCredentials()
if err == nil {
t.Fatal("expected error for empty access_token")
}
}
func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
if err := os.WriteFile(authPath, []byte("not json"), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
_, _, _, err := ReadCodexCliCredentials()
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}
func TestReadCodexCliCredentials_NoAccountID(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
token, accountID, _, err := ReadCodexCliCredentials()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token != "tok123" {
t.Errorf("token = %q, want %q", token, "tok123")
}
if accountID != "" {
t.Errorf("accountID = %q, want empty", accountID)
}
}
func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) {
tmpDir := t.TempDir()
customDir := filepath.Join(tmpDir, "custom-codex")
if err := os.MkdirAll(customDir, 0755); err != nil {
t.Fatal(err)
}
authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}`
if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", customDir)
token, _, _, err := ReadCodexCliCredentials()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if token != "custom-token" {
t.Errorf("token = %q, want %q", token, "custom-token")
}
}
func TestCreateCodexCliTokenSource_Valid(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
source := CreateCodexCliTokenSource()
token, accountID, err := source()
if err != nil {
t.Fatalf("token source error: %v", err)
}
if token != "fresh-token" {
t.Errorf("token = %q, want %q", token, "fresh-token")
}
if accountID != "acc" {
t.Errorf("accountID = %q, want %q", accountID, "acc")
}
}
func TestCreateCodexCliTokenSource_Expired(t *testing.T) {
tmpDir := t.TempDir()
authPath := filepath.Join(tmpDir, "auth.json")
authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}`
if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil {
t.Fatal(err)
}
// Set file modification time to 2 hours ago
oldTime := time.Now().Add(-2 * time.Hour)
if err := os.Chtimes(authPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
t.Setenv("CODEX_HOME", tmpDir)
source := CreateCodexCliTokenSource()
_, _, err := source()
if err == nil {
t.Fatal("expected error for expired credentials")
}
}
+251
View File
@@ -0,0 +1,251 @@
package providers
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
)
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
type CodexCliProvider struct {
command string
workspace string
}
// NewCodexCliProvider creates a new Codex CLI provider.
func NewCodexCliProvider(workspace string) *CodexCliProvider {
return &CodexCliProvider{
command: "codex",
workspace: workspace,
}
}
// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode.
func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
if p.command == "" {
return nil, fmt.Errorf("codex command not configured")
}
prompt := p.buildPrompt(messages, tools)
args := []string{
"exec",
"--json",
"--dangerously-bypass-approvals-and-sandbox",
"--skip-git-repo-check",
"--color", "never",
}
if model != "" && model != "codex-cli" {
args = append(args, "-m", model)
}
if p.workspace != "" {
args = append(args, "-C", p.workspace)
}
args = append(args, "-") // read prompt from stdin
cmd := exec.CommandContext(ctx, p.command, args...)
cmd.Stdin = bytes.NewReader([]byte(prompt))
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
// Parse JSONL from stdout even if exit code is non-zero,
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
// but still produces valid JSONL output.
if stdoutStr := stdout.String(); stdoutStr != "" {
resp, parseErr := p.parseJSONLEvents(stdoutStr)
if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) {
return resp, nil
}
}
if err != nil {
if ctx.Err() == context.Canceled {
return nil, ctx.Err()
}
if stderrStr := stderr.String(); stderrStr != "" {
return nil, fmt.Errorf("codex cli error: %s", stderrStr)
}
return nil, fmt.Errorf("codex cli error: %w", err)
}
return p.parseJSONLEvents(stdout.String())
}
// GetDefaultModel returns the default model identifier.
func (p *CodexCliProvider) GetDefaultModel() string {
return "codex-cli"
}
// buildPrompt converts messages to a prompt string for the Codex CLI.
// System messages are prepended as instructions since Codex CLI has no --system-prompt flag.
func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string {
var systemParts []string
var conversationParts []string
for _, msg := range messages {
switch msg.Role {
case "system":
systemParts = append(systemParts, msg.Content)
case "user":
conversationParts = append(conversationParts, msg.Content)
case "assistant":
conversationParts = append(conversationParts, "Assistant: "+msg.Content)
case "tool":
conversationParts = append(conversationParts,
fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
}
}
var sb strings.Builder
if len(systemParts) > 0 {
sb.WriteString("## System Instructions\n\n")
sb.WriteString(strings.Join(systemParts, "\n\n"))
sb.WriteString("\n\n## Task\n\n")
}
if len(tools) > 0 {
sb.WriteString(p.buildToolsPrompt(tools))
sb.WriteString("\n\n")
}
// Simplify single user message (no prefix)
if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 {
return conversationParts[0]
}
sb.WriteString(strings.Join(conversationParts, "\n"))
return sb.String()
}
// buildToolsPrompt creates a tool definitions section for the prompt.
func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// codexEvent represents a single JSONL event from `codex exec --json`.
type codexEvent struct {
Type string `json:"type"`
ThreadID string `json:"thread_id,omitempty"`
Message string `json:"message,omitempty"`
Item *codexEventItem `json:"item,omitempty"`
Usage *codexUsage `json:"usage,omitempty"`
Error *codexEventErr `json:"error,omitempty"`
}
type codexEventItem struct {
ID string `json:"id"`
Type string `json:"type"`
Text string `json:"text,omitempty"`
Command string `json:"command,omitempty"`
Status string `json:"status,omitempty"`
ExitCode *int `json:"exit_code,omitempty"`
Output string `json:"output,omitempty"`
}
type codexUsage struct {
InputTokens int `json:"input_tokens"`
CachedInputTokens int `json:"cached_input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type codexEventErr struct {
Message string `json:"message"`
}
// parseJSONLEvents processes the JSONL output from codex exec --json.
func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) {
var contentParts []string
var usage *UsageInfo
var lastError string
scanner := bufio.NewScanner(strings.NewReader(output))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
var event codexEvent
if err := json.Unmarshal([]byte(line), &event); err != nil {
continue // skip malformed lines
}
switch event.Type {
case "item.completed":
if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" {
contentParts = append(contentParts, event.Item.Text)
}
case "turn.completed":
if event.Usage != nil {
promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens
usage = &UsageInfo{
PromptTokens: promptTokens,
CompletionTokens: event.Usage.OutputTokens,
TotalTokens: promptTokens + event.Usage.OutputTokens,
}
}
case "error":
lastError = event.Message
case "turn.failed":
if event.Error != nil {
lastError = event.Error.Message
}
}
}
if lastError != "" && len(contentParts) == 0 {
return nil, fmt.Errorf("codex cli: %s", lastError)
}
content := strings.Join(contentParts, "\n")
// Extract tool calls from response text (same pattern as ClaudeCliProvider)
toolCalls := extractToolCallsFromText(content)
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
content = stripToolCallsFromText(content)
}
return &LLMResponse{
Content: strings.TrimSpace(content),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
+585
View File
@@ -0,0 +1,585 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
)
// --- JSONL Event Parsing Tests ---
func TestParseJSONLEvents_AgentMessage(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"thread.started","thread_id":"abc-123"}
{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}}
{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.Content != "Hello from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 150 {
t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
}
if resp.Usage.TotalTokens != 170 {
t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens)
}
if len(resp.ToolCalls) != 0 {
t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls))
}
}
func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) {
p := &CodexCliProvider{}
toolCallText := `Let me read that file.
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}`
// Build valid JSONL by marshaling the event
item := codexEvent{
Type: "item.completed",
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
}
itemJSON, _ := json.Marshal(item)
usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}`
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "read_file" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
}
if resp.ToolCalls[0].ID != "call_1" {
t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1")
}
if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` {
t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments)
}
// Content should have the tool call JSON stripped
if strings.Contains(resp.Content, "tool_calls") {
t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content)
}
}
func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) {
p := &CodexCliProvider{}
toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}`
item := codexEvent{
Type: "item.completed",
Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText},
}
itemJSON, _ := json.Marshal(item)
events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if len(resp.ToolCalls) != 2 {
t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "read_file" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file")
}
if resp.ToolCalls[1].Name != "write_file" {
t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file")
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
}
func TestParseJSONLEvents_MultipleMessages(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}}
{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}}
{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}}
{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.Content != "First part.\nSecond part." {
t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.")
}
}
func TestParseJSONLEvents_ErrorEvent(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"thread.started","thread_id":"abc"}
{"type":"turn.started"}
{"type":"error","message":"token expired"}
{"type":"turn.failed","error":{"message":"token expired"}}`
_, err := p.parseJSONLEvents(events)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "token expired") {
t.Errorf("error = %q, want to contain 'token expired'", err.Error())
}
}
func TestParseJSONLEvents_TurnFailed(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"turn.failed","error":{"message":"rate limit exceeded"}}`
_, err := p.parseJSONLEvents(events)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "rate limit exceeded") {
t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error())
}
}
func TestParseJSONLEvents_ErrorWithContent(t *testing.T) {
p := &CodexCliProvider{}
// If there's an error but also content, return the content (partial success)
events := `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}}
{"type":"error","message":"connection reset"}
{"type":"turn.failed","error":{"message":"connection reset"}}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("should not error when content exists: %v", err)
}
if resp.Content != "Partial result." {
t.Errorf("Content = %q, want %q", resp.Content, "Partial result.")
}
}
func TestParseJSONLEvents_EmptyOutput(t *testing.T) {
p := &CodexCliProvider{}
resp, err := p.parseJSONLEvents("")
if err != nil {
t.Fatalf("empty output should not error: %v", err)
}
if resp.Content != "" {
t.Errorf("Content = %q, want empty", resp.Content)
}
}
func TestParseJSONLEvents_MalformedLines(t *testing.T) {
p := &CodexCliProvider{}
events := `not json at all
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}}
another bad line
{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("should skip malformed lines: %v", err)
}
if resp.Content != "Good line." {
t.Errorf("Content = %q, want %q", resp.Content, "Good line.")
}
if resp.Usage == nil || resp.Usage.TotalTokens != 15 {
t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage)
}
}
func TestParseJSONLEvents_CommandExecution(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}}
{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}}
{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}}
{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
// command_execution items should be skipped; only agent_message text is returned
if resp.Content != "Found 2 files." {
t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.")
}
}
func TestParseJSONLEvents_NoUsage(t *testing.T) {
p := &CodexCliProvider{}
events := `{"type":"turn.started"}
{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}}
{"type":"turn.completed"}`
resp, err := p.parseJSONLEvents(events)
if err != nil {
t.Fatalf("parseJSONLEvents() error: %v", err)
}
if resp.Usage != nil {
t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage)
}
}
// --- Prompt Building Tests ---
func TestBuildPrompt_SystemAsInstructions(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi there"},
}
prompt := p.buildPrompt(messages, nil)
if !strings.Contains(prompt, "## System Instructions") {
t.Error("prompt should contain '## System Instructions'")
}
if !strings.Contains(prompt, "You are helpful.") {
t.Error("prompt should contain system content")
}
if !strings.Contains(prompt, "## Task") {
t.Error("prompt should contain '## Task'")
}
if !strings.Contains(prompt, "Hi there") {
t.Error("prompt should contain user message")
}
}
func TestBuildPrompt_NoSystem(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Just a question"},
}
prompt := p.buildPrompt(messages, nil)
if strings.Contains(prompt, "## System Instructions") {
t.Error("prompt should not contain system instructions header")
}
if prompt != "Just a question" {
t.Errorf("prompt = %q, want %q", prompt, "Just a question")
}
}
func TestBuildPrompt_WithTools(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Get weather"},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get current weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
prompt := p.buildPrompt(messages, tools)
if !strings.Contains(prompt, "## Available Tools") {
t.Error("prompt should contain tools section")
}
if !strings.Contains(prompt, "get_weather") {
t.Error("prompt should contain tool name")
}
if !strings.Contains(prompt, "Get current weather") {
t.Error("prompt should contain tool description")
}
}
func TestBuildPrompt_MultipleMessages(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi! How can I help?"},
{Role: "user", Content: "Tell me about Go"},
}
prompt := p.buildPrompt(messages, nil)
if !strings.Contains(prompt, "Hello") {
t.Error("prompt should contain first user message")
}
if !strings.Contains(prompt, "Assistant: Hi! How can I help?") {
t.Error("prompt should contain assistant message with prefix")
}
if !strings.Contains(prompt, "Tell me about Go") {
t.Error("prompt should contain second user message")
}
}
func TestBuildPrompt_ToolResults(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "user", Content: "Weather?"},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
prompt := p.buildPrompt(messages, nil)
if !strings.Contains(prompt, "[Tool Result for call_1]") {
t.Error("prompt should contain tool result")
}
if !strings.Contains(prompt, `{"temp": 72}`) {
t.Error("prompt should contain tool result content")
}
}
func TestBuildPrompt_SystemAndTools(t *testing.T) {
p := &CodexCliProvider{}
messages := []Message{
{Role: "system", Content: "Be concise."},
{Role: "user", Content: "Do something"},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "my_tool",
Description: "A tool",
},
},
}
prompt := p.buildPrompt(messages, tools)
// System instructions should come first
sysIdx := strings.Index(prompt, "## System Instructions")
toolIdx := strings.Index(prompt, "## Available Tools")
taskIdx := strings.Index(prompt, "## Task")
if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 {
t.Fatal("prompt should contain all sections")
}
if sysIdx >= taskIdx {
t.Error("system instructions should come before task")
}
if taskIdx >= toolIdx {
t.Error("task section should come before tools in the output")
}
}
// --- CLI Argument Tests ---
func TestCodexCliProvider_GetDefaultModel(t *testing.T) {
p := NewCodexCliProvider("")
if got := p.GetDefaultModel(); got != "codex-cli" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli")
}
}
// --- Mock CLI Integration Test ---
func createMockCodexCLI(t *testing.T, events []string) string {
t.Helper()
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "codex")
var sb strings.Builder
sb.WriteString("#!/bin/bash\n")
for _, event := range events {
sb.WriteString(fmt.Sprintf("echo '%s'\n", event))
}
if err := os.WriteFile(scriptPath, []byte(sb.String()), 0755); err != nil {
t.Fatal(err)
}
return scriptPath
}
func TestCodexCliProvider_MockCLI_Success(t *testing.T) {
scriptPath := createMockCodexCLI(t, []string{
`{"type":"thread.started","thread_id":"test-123"}`,
`{"type":"turn.started"}`,
`{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`,
`{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`,
})
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Mock response from Codex CLI" {
t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI")
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 60 {
t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 15 {
t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens)
}
}
func TestCodexCliProvider_MockCLI_Error(t *testing.T) {
scriptPath := createMockCodexCLI(t, []string{
`{"type":"thread.started","thread_id":"test-err"}`,
`{"type":"turn.started"}`,
`{"type":"error","message":"auth token expired"}`,
`{"type":"turn.failed","error":{"message":"auth token expired"}}`,
})
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
messages := []Message{{Role: "user", Content: "Hello"}}
_, err := p.Chat(context.Background(), messages, nil, "", nil)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "auth token expired") {
t.Errorf("error = %q, want to contain 'auth token expired'", err.Error())
}
}
func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) {
// Mock script that captures args to verify model flag is passed
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "codex")
script := `#!/bin/bash
# Write args to a file for verification
echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `"
echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}'
echo '{"type":"turn.completed"}'`
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
t.Fatal(err)
}
p := &CodexCliProvider{
command: scriptPath,
workspace: "/tmp/test-workspace",
}
messages := []Message{{Role: "user", Content: "test"}}
_, err := p.Chat(context.Background(), messages, nil, "gpt-5.2-codex", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
// Verify the args
argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt"))
if err != nil {
t.Fatalf("reading args: %v", err)
}
args := string(argsData)
if !strings.Contains(args, "-m gpt-5.2-codex") {
t.Errorf("args should contain model flag, got: %s", args)
}
if !strings.Contains(args, "-C /tmp/test-workspace") {
t.Errorf("args should contain workspace flag, got: %s", args)
}
if !strings.Contains(args, "--json") {
t.Errorf("args should contain --json, got: %s", args)
}
if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") {
t.Errorf("args should contain bypass flag, got: %s", args)
}
}
func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) {
// Script that sleeps forever
tmpDir := t.TempDir()
scriptPath := filepath.Join(tmpDir, "codex")
script := "#!/bin/bash\nsleep 60"
if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil {
t.Fatal(err)
}
p := &CodexCliProvider{
command: scriptPath,
workspace: "",
}
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
messages := []Message{{Role: "user", Content: "test"}}
_, err := p.Chat(ctx, messages, nil, "", nil)
if err == nil {
t.Fatal("expected error on canceled context")
}
}
func TestCodexCliProvider_EmptyCommand(t *testing.T) {
p := &CodexCliProvider{command: ""}
messages := []Message{{Role: "user", Content: "test"}}
_, err := p.Chat(context.Background(), messages, nil, "", nil)
if err == nil {
t.Fatal("expected error for empty command")
}
}
// --- Integration Test (requires real codex CLI with valid auth) ---
func TestCodexCliProvider_Integration(t *testing.T) {
if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" {
t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)")
}
// Verify codex is available
codexPath, err := exec.LookPath("codex")
if err != nil {
t.Skip("codex CLI not found in PATH")
}
p := &CodexCliProvider{
command: codexPath,
workspace: "",
}
messages := []Message{
{Role: "user", Content: "Respond with just the word 'hello' and nothing else."},
}
resp, err := p.Chat(context.Background(), messages, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
lower := strings.ToLower(strings.TrimSpace(resp.Content))
if !strings.Contains(lower, "hello") {
t.Errorf("Content = %q, expected to contain 'hello'", resp.Content)
}
}
+123 -9
View File
@@ -3,6 +3,7 @@ package providers
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
@@ -10,8 +11,12 @@ import (
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/logger"
)
const codexDefaultModel = "gpt-5.2"
const codexDefaultInstructions = "You are Codex, a coding assistant."
type CodexProvider struct {
client *openai.Client
accountID string
@@ -24,6 +29,8 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
option.WithHeader("originator", "codex_cli_rs"),
option.WithHeader("OpenAI-Beta", "responses=experimental"),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
@@ -43,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
accountID := p.accountID
resolvedModel, fallbackReason := resolveCodexModel(model)
if fallbackReason != "" {
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"reason": fallbackReason,
})
}
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
@@ -50,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
accountID = accID
}
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
} else {
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
})
}
params := buildCodexParams(messages, tools, model, options)
params := buildCodexParams(messages, tools, resolvedModel, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
defer stream.Close()
var resp *responses.Response
for stream.Next() {
evt := stream.Current()
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
evtResp := evt.Response
if evtResp.ID != "" {
copy := evtResp
resp = &copy
}
}
}
err := stream.Err()
if err != nil {
fields := map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
"tools_count": len(tools),
"account_id_present": accountID != "",
"error": err.Error(),
}
var apiErr *openai.Error
if errors.As(err, &apiErr) {
fields["status_code"] = apiErr.StatusCode
fields["api_type"] = apiErr.Type
fields["api_code"] = apiErr.Code
fields["api_param"] = apiErr.Param
fields["api_message"] = apiErr.Message
if apiErr.StatusCode == 400 {
fields["hint"] = "verify account id header and model compatibility for codex backend"
}
if apiErr.Response != nil {
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
}
}
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
return nil, fmt.Errorf("codex API call: %w", err)
}
if resp == nil {
fields := map[string]interface{}{
"requested_model": model,
"resolved_model": resolvedModel,
"messages_count": len(messages),
"tools_count": len(tools),
"account_id_present": accountID != "",
}
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
return nil, fmt.Errorf("codex API call: stream ended without completed response")
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
return codexDefaultModel
}
func resolveCodexModel(model string) (string, string) {
m := strings.ToLower(strings.TrimSpace(model))
if m == "" {
return codexDefaultModel, "empty model"
}
if strings.HasPrefix(m, "openai/") {
m = strings.TrimPrefix(m, "openai/")
} else if strings.Contains(m, "/") {
return codexDefaultModel, "non-openai model namespace"
}
unsupportedPrefixes := []string{
"glm",
"claude",
"anthropic",
"gemini",
"google",
"moonshot",
"kimi",
"qwen",
"deepseek",
"llama",
"meta-llama",
"mistral",
"grok",
"xai",
"zhipu",
}
for _, prefix := range unsupportedPrefixes {
if strings.HasPrefix(m, prefix) {
return codexDefaultModel, "unsupported model prefix"
}
}
if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") {
return m, ""
}
return codexDefaultModel, "unsupported model family"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
@@ -135,7 +249,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
Instructions: openai.Opt(instructions),
Store: openai.Opt(false),
}
if instructions != "" {
@@ -149,10 +264,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
@@ -242,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) {
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
+204 -5
View File
@@ -2,6 +2,7 @@ package providers
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -16,7 +17,8 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
"max_tokens": 2048,
"temperature": 0.7,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
@@ -203,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
@@ -226,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
writeCompletedSSE(w, resp)
}))
defer server.Close()
@@ -250,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
}
}
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if _, ok := reqBody["instructions"]; !ok {
http.Error(w, "missing instructions", http.StatusBadRequest)
return
}
if reqBody["instructions"] == "" {
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
return
}
if _, ok := reqBody["temperature"]; ok {
http.Error(w, "temperature is not supported", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
provider := NewCodexProvider("stale-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
provider.tokenSource = func() (string, string, error) {
return "refreshed-token", "", nil
}
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
var reqBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
if reqBody["model"] != codexDefaultModel {
http.Error(w, "unsupported model", http.StatusBadRequest)
return
}
if reqBody["stream"] != true {
http.Error(w, "stream must be true", http.StatusBadRequest)
return
}
if reqBody["instructions"] != codexDefaultInstructions {
http.Error(w, "missing default instructions", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 8,
"output_tokens": 4,
"total_tokens": 12,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
writeCompletedSSE(w, resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil)
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
if got := p.GetDefaultModel(); got != codexDefaultModel {
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
}
}
func TestResolveCodexModel(t *testing.T) {
tests := []struct {
name string
input string
wantModel string
wantFallback bool
}{
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotModel, reason := resolveCodexModel(tt.input)
if gotModel != tt.wantModel {
t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel)
}
if tt.wantFallback && reason == "" {
t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input)
}
if !tt.wantFallback && reason != "" {
t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason)
}
})
}
}
@@ -268,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
c := openai.NewClient(opts...)
return &c
}
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
event := map[string]interface{}{
"type": "response.completed",
"sequence_number": 1,
"response": response,
}
b, _ := json.Marshal(event)
w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprintf(w, "event: response.completed\n")
fmt.Fprintf(w, "data: %s\n\n", string(b))
fmt.Fprintf(w, "data: [DONE]\n\n")
}
+21 -4
View File
@@ -53,10 +53,10 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" {
if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" {
model = model[idx+1:]
}
}
@@ -240,6 +240,9 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "codex-cli" {
return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil
}
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
@@ -299,11 +302,17 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
case "codex-cli", "codex-code":
workspace := cfg.WorkspacePath()
if workspace == "" {
workspace = "."
}
return NewCodexCliProvider(workspace), nil
case "deepseek":
if cfg.Providers.DeepSeek.APIKey != "" {
apiKey = cfg.Providers.DeepSeek.APIKey
@@ -400,7 +409,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "":
fmt.Println("Ollama provider selected based on model name prefix")
apiKey = cfg.Providers.Ollama.APIKey
apiBase = cfg.Providers.Ollama.APIBase
proxy = cfg.Providers.Ollama.Proxy
if apiBase == "" {
apiBase = "http://localhost:11434/v1"
}
fmt.Println("Ollama apiBase:", apiBase)
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
+72
View File
@@ -0,0 +1,72 @@
package providers
import (
"encoding/json"
"strings"
)
// extractToolCallsFromText parses tool call JSON from response text.
// Both ClaudeCliProvider and CodexCliProvider use this to extract
// tool calls that the model outputs in its response text.
func extractToolCallsFromText(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
}
// stripToolCallsFromText removes tool call JSON from response text.
func stripToolCallsFromText(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
}
+16
View File
@@ -264,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
return nil
}
// SetHistory updates the messages of a session.
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
sm.mu.Lock()
defer sm.mu.Unlock()
session, ok := sm.sessions[key]
if ok {
// Create a deep copy to strictly isolate internal state
// from the caller's slice.
msgs := make([]providers.Message, len(history))
copy(msgs, history)
session.Messages = msgs
session.Updated = time.Now()
}
}
+45
View File
@@ -2,13 +2,22 @@ package skills
import (
"encoding/json"
"errors"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
)
var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`)
const (
MaxNameLength = 64
MaxDescriptionLength = 1024
)
type SkillMetadata struct {
Name string `json:"name"`
Description string `json:"description"`
@@ -21,6 +30,27 @@ type SkillInfo struct {
Description string `json:"description"`
}
func (info SkillInfo) validate() error {
var errs error
if info.Name == "" {
errs = errors.Join(errs, errors.New("name is required"))
} else {
if len(info.Name) > MaxNameLength {
errs = errors.Join(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength))
}
if !namePattern.MatchString(info.Name) {
errs = errors.Join(errs, errors.New("name must be alphanumeric with hyphens"))
}
}
if info.Description == "" {
errs = errors.Join(errs, errors.New("description is required"))
} else if len(info.Description) > MaxDescriptionLength {
errs = errors.Join(errs, fmt.Errorf("description exceeds %d character", MaxDescriptionLength))
}
return errs
}
type SkillsLoader struct {
workspace string
workspaceSkills string // workspace skills (项目级别)
@@ -54,6 +84,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from workspace", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
@@ -89,6 +124,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from global", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
@@ -123,6 +163,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo {
metadata := sl.getSkillMetadata(skillFile)
if metadata != nil {
info.Description = metadata.Description
info.Name = metadata.Name
}
if err := info.validate(); err != nil {
slog.Warn("invalid skill from builtin", "name", info.Name, "error", err)
continue
}
skills = append(skills, info)
}
+77
View File
@@ -0,0 +1,77 @@
package skills
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestSkillsInfoValidate(t *testing.T) {
testcases := []struct {
name string
skillName string
description string
wantErr bool
errContains []string
}{
{
name: "valid-skill",
skillName: "valid-skill",
description: "a valid skill description",
wantErr: false,
},
{
name: "empty-name",
skillName: "",
description: "description without name",
wantErr: true,
errContains: []string{"name is required"},
},
{
name: "empty-description",
skillName: "skill-without-description",
description: "",
wantErr: true,
errContains: []string{"description is required"},
},
{
name: "empty-both",
skillName: "",
description: "",
wantErr: true,
errContains: []string{"name is required", "description is required"},
},
{
name: "name-with-spaces",
skillName: "skill with spaces",
description: "invalid name with spaces",
wantErr: true,
errContains: []string{"name must be alphanumeric with hyphens"},
},
{
name: "name-with-underscore",
skillName: "skill_underscore",
description: "invalid name with underscore",
wantErr: true,
errContains: []string{"name must be alphanumeric with hyphens"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
info := SkillInfo{
Name: tc.skillName,
Description: tc.description,
}
err := info.validate()
if tc.wantErr {
assert.Error(t, err)
for _, msg := range tc.errContains {
assert.ErrorContains(t, err, msg)
}
} else {
assert.NoError(t, err)
}
})
}
}
+2 -2
View File
@@ -28,12 +28,12 @@ type CronTool struct {
}
// NewCronTool creates a new CronTool
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: NewExecTool(workspace, false),
execTool: NewExecTool(workspace, restrict),
}
}
+43 -2
View File
@@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
}
}
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
if restrict {
if !isWithinWorkspace(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
workspaceReal := absWorkspace
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
workspaceReal = resolved
}
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
if !isWithinWorkspace(resolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if os.IsNotExist(err) {
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
if !isWithinWorkspace(parentResolved, workspaceReal) {
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
}
} else if !os.IsNotExist(err) {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
} else {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
}
return absPath, nil
}
func resolveExistingAncestor(path string) (string, error) {
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
if resolved, err := filepath.EvalSymlinks(current); err == nil {
return resolved, nil
} else if !os.IsNotExist(err) {
return "", err
}
if filepath.Dir(current) == current {
return "", os.ErrNotExist
}
}
}
func isWithinWorkspace(candidate, workspace string) bool {
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
}
type ReadFileTool struct {
workspace string
restrict bool
+32
View File
@@ -247,3 +247,35 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}
// Block paths that look inside workspace but point outside via symlink.
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
root := t.TempDir()
workspace := filepath.Join(root, "workspace")
if err := os.MkdirAll(workspace, 0755); err != nil {
t.Fatalf("failed to create workspace: %v", err)
}
secret := filepath.Join(root, "secret.txt")
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
t.Fatalf("failed to write secret file: %v", err)
}
link := filepath.Join(workspace, "leak.txt")
if err := os.Symlink(secret, link); err != nil {
t.Skipf("symlink not supported in this environment: %v", err)
}
tool := NewReadFileTool(workspace, true)
result := tool.Execute(context.Background(), map[string]interface{}{
"path": link,
})
if !result.IsError {
t.Fatalf("expected symlink escape to be blocked")
}
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
}
}
+10 -6
View File
@@ -173,19 +173,23 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) {
}
}
// TestWebTool_WebSearch_NoApiKey verifies that nil is returned when no provider is configured
// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "", BraveMaxResults: 5})
// Should return nil when no provider is enabled
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""})
if tool != nil {
t.Errorf("Expected nil when no search provider is configured")
t.Errorf("Expected nil tool when Brave API key is empty")
}
// Also nil when nothing is enabled
tool = NewWebSearchTool(WebSearchToolOptions{})
if tool != nil {
t.Errorf("Expected nil tool when no provider is enabled")
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "test-key", BraveMaxResults: 5, BraveEnabled: true})
tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5})
ctx := context.Background()
args := map[string]interface{}{}
+1 -2
View File
@@ -73,9 +73,8 @@ func DownloadFile(url, filename string, opts DownloadOptions) string {
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)