diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 90ff635da..2d1aa9ffc 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -1,12 +1,18 @@ name: ๐Ÿณ Build & Push Docker Image on: - release: - types: [published] + workflow_call: + inputs: + tag: + description: "Release tag" + required: true + type: string env: - REGISTRY: ghcr.io - IMAGE_NAME: ${{ github.repository_owner }}/picoclaw + GHCR_REGISTRY: ghcr.io + GHCR_IMAGE_NAME: ${{ github.repository_owner }}/picoclaw + DOCKERHUB_REGISTRY: docker.io + DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }} jobs: build: @@ -20,6 +26,8 @@ jobs: # โ”€โ”€ Checkout โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - name: ๐Ÿ“ฅ Checkout repository uses: actions/checkout@v4 + with: + ref: ${{ inputs.tag }} # โ”€โ”€ Docker Buildx โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - name: ๐Ÿ”ง Set up Docker Buildx @@ -27,36 +35,42 @@ jobs: # โ”€โ”€ Login to GHCR โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - name: ๐Ÿ”‘ Login to GitHub Container Registry - if: github.event_name != 'pull_request' uses: docker/login-action@v3 with: - registry: ${{ env.REGISTRY }} + registry: ${{ env.GHCR_REGISTRY }} username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - # โ”€โ”€ Metadata (tags & labels) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - - name: ๐Ÿท๏ธ Extract Docker metadata - id: meta - uses: docker/metadata-action@v5 + # โ”€โ”€ Login to Docker Hub โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + - name: ๐Ÿ”‘ Login to Docker Hub + uses: docker/login-action@v3 with: - images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} - tags: | - type=ref,event=branch - type=ref,event=pr - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=sha,prefix= - type=raw,value=latest,enable={{is_default_branch}} - type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}} + registry: ${{ env.DOCKERHUB_REGISTRY }} + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # โ”€โ”€ Metadata (tags & labels) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + - name: ๐Ÿท๏ธ Prepare image tags + id: tags + shell: bash + run: | + tag="${{ inputs.tag }}" + echo "ghcr_tag=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT" + echo "ghcr_latest=${{ env.GHCR_REGISTRY }}/${{ env.GHCR_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT" + echo "dockerhub_tag=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:${tag}" >> "$GITHUB_OUTPUT" + echo "dockerhub_latest=${{ env.DOCKERHUB_REGISTRY }}/${{ env.DOCKERHUB_IMAGE_NAME }}:latest" >> "$GITHUB_OUTPUT" # โ”€โ”€ Build & Push โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ - name: ๐Ÿš€ Build and push Docker image uses: docker/build-push-action@v6 with: context: . - push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} + push: true + tags: | + ${{ steps.tags.outputs.ghcr_tag }} + ${{ steps.tags.outputs.ghcr_latest }} + ${{ steps.tags.outputs.dockerhub_tag }} + ${{ steps.tags.outputs.dockerhub_latest }} cache-from: type=gha cache-to: type=gha,mode=max - platforms: linux/amd64,linux/arm64 + platforms: linux/amd64,linux/arm64,linux/riscv64 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 59cc6caeb..f9987b35f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -38,14 +38,18 @@ jobs: git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}" git push origin "${{ inputs.tag }}" - build-binaries: - name: Build Release Binaries + release: + name: GoReleaser Release needs: create-tag runs-on: ubuntu-latest + permissions: + contents: write + packages: write steps: - name: Checkout tag uses: actions/checkout@v4 with: + fetch-depth: 0 ref: ${{ inputs.tag }} - name: Setup Go from go.mod @@ -53,47 +57,42 @@ jobs: with: go-version-file: go.mod - - name: Build all binaries - run: make build-all + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 - - name: Generate checksums - shell: bash - run: | - shasum -a 256 build/picoclaw-* > build/sha256sums.txt + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 - - name: Upload release binaries artifact - uses: actions/upload-artifact@v4 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 with: - name: picoclaw-binaries - path: | - build/picoclaw-* - build/sha256sums.txt - if-no-files-found: error + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} - create-release: - name: Create GitHub Release - needs: [create-tag, build-binaries] - runs-on: ubuntu-latest - permissions: - contents: write - steps: - - name: Download all artifacts - uses: actions/download-artifact@v4 + - name: Login to Docker Hub + uses: docker/login-action@v3 with: - path: release-artifacts + registry: docker.io + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Show downloaded files - run: ls -R release-artifacts - - - name: Create release - uses: softprops/action-gh-release@v2 + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 with: - tag_name: ${{ inputs.tag }} - name: ${{ inputs.tag }} - draft: ${{ inputs.draft }} - prerelease: ${{ inputs.prerelease }} - files: | - release-artifacts/**/* - generate_release_notes: true + distribution: goreleaser + version: ~> v2 + args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + GITHUB_REPOSITORY_OWNER: ${{ github.repository_owner }} + DOCKERHUB_IMAGE_NAME: ${{ vars.DOCKERHUB_REPOSITORY }} + + - name: Apply release flags + shell: bash + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + gh release edit "${{ inputs.tag }}" \ + --draft=${{ inputs.draft }} \ + --prerelease=${{ inputs.prerelease }} diff --git a/.goreleaser.yaml b/.goreleaser.yaml index a2c158331..368a0f06b 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -5,9 +5,11 @@ version: 2 before: hooks: - go mod tidy + - go generate ./cmd/picoclaw builds: - - env: + - id: picoclaw + env: - CGO_ENABLED=0 goos: - linux @@ -26,6 +28,22 @@ builds: - goos: windows goarch: arm +dockers_v2: + - id: picoclaw + dockerfile: Dockerfile.goreleaser + ids: + - picoclaw + images: + - "ghcr.io/{{ .Env.GITHUB_REPOSITORY_OWNER }}/picoclaw" + - "docker.io/{{ .Env.DOCKERHUB_IMAGE_NAME }}" + tags: + - "{{ .Tag }}" + - "latest" + platforms: + - linux/amd64 + - linux/arm64 + - linux/riscv64 + archives: - formats: [tar.gz] # this name template makes the OS and Arch compatible with the results of `uname`. @@ -48,10 +66,10 @@ changelog: - "^docs:" - "^test:" -upx: - - enabled: true - compress: best - lzma: true +# upx: +# - enabled: true +# compress: best +# lzma: true release: footer: >- diff --git a/Dockerfile b/Dockerfile index 433d962f2..dd98ec0bd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -22,6 +22,10 @@ FROM alpine:3.23 RUN apk add --no-cache ca-certificates tzdata curl +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD wget -q --spider http://localhost:18790/health || exit 1 + # Copy binary COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw diff --git a/Dockerfile.goreleaser b/Dockerfile.goreleaser new file mode 100644 index 000000000..0cdc8c6bd --- /dev/null +++ b/Dockerfile.goreleaser @@ -0,0 +1,10 @@ +FROM alpine:3.21 + +ARG TARGETPLATFORM + +RUN apk add --no-cache ca-certificates tzdata + +COPY $TARGETPLATFORM/picoclaw /usr/local/bin/picoclaw + +ENTRYPOINT ["picoclaw"] +CMD ["gateway"] diff --git a/Makefile b/Makefile index a97f17799..05551bedc 100644 --- a/Makefile +++ b/Makefile @@ -119,7 +119,7 @@ clean: @rm -rf $(BUILD_DIR) @echo "Clean complete" -## fmt: Format Go code +## vet: Run go vet for static analysis vet: @$(GO) vet ./... @@ -131,11 +131,19 @@ test: fmt: @$(GO) fmt ./... -## deps: Update dependencies +## deps: Download dependencies deps: + @$(GO) mod download + @$(GO) mod verify + +## update-deps: Update dependencies +update-deps: @$(GO) get -u ./... @$(GO) mod tidy +## check: Run vet, fmt, and verify dependencies +check: deps fmt vet test + ## run: Build and run picoclaw run: build @$(BUILD_DIR)/$(BINARY_NAME) $(ARGS) diff --git a/README.md b/README.md index a61a1abf4..091af2811 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ > * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**. > * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)** > * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties. -> +> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release. ## ๐Ÿ“ข News diff --git a/README.zh.md b/README.zh.md index f94abce88..5a1c3c50b 100644 --- a/README.zh.md +++ b/README.zh.md @@ -45,8 +45,8 @@ > * **ๆ— ๅŠ ๅฏ†่ดงๅธ (NO CRYPTO):** PicoClaw **ๆฒกๆœ‰** ๅ‘่กŒไปปไฝ•ๅฎ˜ๆ–นไปฃๅธใ€Token ๆˆ–่™šๆ‹Ÿ่ดงๅธใ€‚ๆ‰€ๆœ‰ๅœจ `pump.fun` ๆˆ–ๅ…ถไป–ไบคๆ˜“ๅนณๅฐไธŠ็š„็›ธๅ…ณๅฃฐ็งฐๅ‡ไธบ **่ฏˆ้ช—**ใ€‚ > * **ๅฎ˜ๆ–นๅŸŸๅ:** ๅ”ฏไธ€็š„ๅฎ˜ๆ–น็ฝ‘็ซ™ๆ˜ฏ **[picoclaw.io](https://picoclaw.io)**๏ผŒๅ…ฌๅธๅฎ˜็ฝ‘ๆ˜ฏ **[sipeed.com](https://sipeed.com)**ใ€‚ > * **่ญฆๆƒ•:** ่ฎธๅคš `.ai/.org/.com/.net/...` ๅŽ็ผ€็š„ๅŸŸๅ่ขซ็ฌฌไธ‰ๆ–นๆŠขๆณจ๏ผŒ่ฏทๅ‹ฟ่ฝปไฟกใ€‚ -> -> +> * **ๆณจๆ„:** picoclawๆญฃๅœจๅˆๆœŸ็š„ๅฟซ้€ŸๅŠŸ่ƒฝๅผ€ๅ‘้˜ถๆฎต๏ผŒๅฏ่ƒฝๆœ‰ๅฐšๆœชไฟฎๅค็š„็ฝ‘็ปœๅฎ‰ๅ…จ้—ฎ้ข˜๏ผŒๅœจ1.0ๆญฃๅผ็‰ˆๅ‘ๅธƒๅ‰๏ผŒ่ฏทไธ่ฆๅฐ†ๅ…ถ้ƒจ็ฝฒๅˆฐ็”Ÿไบง็Žฏๅขƒไธญ + ## ๐Ÿ“ข ๆ–ฐ้—ป (News) diff --git a/assets/wechat.png b/assets/wechat.png index 0f97fa3ee..d62c8d09d 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 2129662d7..10b53948b 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/fs" + "net/http" "os" "os/signal" "path/filepath" @@ -28,6 +29,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" + "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/migrate" @@ -560,7 +562,7 @@ func gatewayCmd() { }) // Setup cron tool and service - cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath()) + cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace) heartbeatService := heartbeat.NewHeartbeatService( cfg.WorkspacePath(), @@ -592,6 +594,9 @@ func gatewayCmd() { os.Exit(1) } + // Inject channel manager into agent loop for command handling + agentLoop.SetChannelManager(channelManager) + var transcriber *voice.GroqTranscriber if cfg.Providers.Groq.APIKey != "" { transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey) @@ -658,6 +663,14 @@ func gatewayCmd() { fmt.Printf("Error starting channels: %v\n", err) } + healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + go func() { + if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("health", "Health server error", map[string]interface{}{"error": err.Error()}) + } + }() + fmt.Printf("โœ“ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + go agentLoop.Run(ctx) sigChan := make(chan os.Signal, 1) @@ -666,6 +679,7 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() + healthServer.Stop(context.Background()) deviceService.Stop() heartbeatService.Stop() cronService.Stop() @@ -973,14 +987,14 @@ func getConfigPath() string { return filepath.Join(home, ".picoclaw", "config.json") } -func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService { +func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool) *cron.CronService { cronStorePath := filepath.Join(workspace, "cron", "jobs.json") // Create cron service cronService := cron.NewCronService(cronStorePath, nil) // Create and register CronTool - cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace) + cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict) agentLoop.RegisterTool(cronTool) // Set the onJob handler diff --git a/config/config.example.json b/config/config.example.json index 75478af4d..fe1e88279 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -109,6 +109,10 @@ "moonshot": { "api_key": "sk-xxx", "api_base": "" + }, + "ollama": { + "api_key": "", + "api_base": "http://localhost:11434/v1" } }, "tools": { diff --git a/go.mod b/go.mod index 528bc40b9..833093f7c 100644 --- a/go.mod +++ b/go.mod @@ -16,10 +16,16 @@ require ( github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/openai/openai-go/v3 v3.22.0 github.com/slack-go/slack v0.17.3 + github.com/stretchr/testify v1.11.1 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) require github.com/yosida95/uritemplate/v3 v3.0.2 // indirect require ( diff --git a/go.sum b/go.sum index 5469dd7dd..450d449bf 100644 --- a/go.sum +++ b/go.sum @@ -82,9 +82,11 @@ github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzh github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= @@ -108,6 +110,7 @@ github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsK github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk= @@ -252,6 +255,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 742ea5496..9349f86f3 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -19,6 +19,7 @@ import ( "unicode/utf8" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" @@ -44,6 +45,7 @@ type AgentLoop struct { mcpManager *mcp.Manager // MCP server manager for resource cleanup running atomic.Bool summarizing sync.Map // Tracks which sessions are currently being summarized + channelManager *channels.Manager } // processOptions configures how a message is processed @@ -239,6 +241,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) { al.tools.Register(tool) } +func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { + al.channelManager = cm +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -303,6 +309,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } + // Check for commands + if response, handled := al.handleCommand(ctx, msg); handled { + return response, nil + } + // Process as user message return al.runAgentLoop(ctx, processOptions{ SessionKey: msg.SessionKey, @@ -423,7 +434,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(opts.SessionKey) + al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID) } // 8. Optional: send response via bus @@ -485,11 +496,131 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "tools_json": formatToolsForLog(providerToolDefs), }) - // Call LLM - response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) + var response *providers.LLMResponse + var err error + + // Retry loop for context/token errors + maxRetries := 2 + for retry := 0; retry <= maxRetries; retry++ { + response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + + if err == nil { + break // Success + } + + errMsg := strings.ToLower(err.Error()) + // Check for context window errors (provider specific, but usually contain "token" or "invalid") + isContextError := strings.Contains(errMsg, "token") || + strings.Contains(errMsg, "context") || + strings.Contains(errMsg, "invalidparameter") || + strings.Contains(errMsg, "length") + + if isContextError && retry < maxRetries { + logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{ + "error": err.Error(), + "retry": retry, + }) + + // Notify user on first retry only + if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: "โš ๏ธ Context window exceeded. Compressing history and retrying...", + }) + } + + // Force compression + al.forceCompression(opts.SessionKey) + + // Rebuild messages with compressed history + // Note: We need to reload history from session manager because forceCompression changed it + newHistory := al.sessions.GetHistory(opts.SessionKey) + newSummary := al.sessions.GetSummary(opts.SessionKey) + + // Re-create messages for the next attempt + // We keep the current user message (opts.UserMessage) effectively + messages = al.contextBuilder.BuildMessages( + newHistory, + newSummary, + opts.UserMessage, + nil, + opts.Channel, + opts.ChatID, + ) + + // Important: If we are in the middle of a tool loop (iteration > 1), + // rebuilding messages from session history might duplicate the flow or miss context + // if intermediate steps weren't saved correctly. + // However, al.sessions.AddFullMessage is called after every tool execution, + // so GetHistory should reflect the current state including partial tool execution. + // But we need to ensure we don't duplicate the user message which is appended in BuildMessages. + // BuildMessages(history...) takes the stored history and appends the *current* user message. + // If iteration > 1, the "current user message" was already added to history in step 3 of runAgentLoop. + // So if we pass opts.UserMessage again, we might duplicate it? + // Actually, step 3 is: al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + // So GetHistory ALREADY contains the user message! + + // CORRECTION: + // BuildMessages combines: [System] + [History] + [CurrentMessage] + // But Step 3 added CurrentMessage to History. + // So if we use GetHistory now, it has the user message. + // If we pass opts.UserMessage to BuildMessages, it adds it AGAIN. + + // For retry in the middle of a loop, we should rely on what's in the session. + // BUT checking BuildMessages implementation: + // It appends history... then appends currentMessage. + + // Logic fix for retry: + // If iteration == 1, opts.UserMessage corresponds to the user input. + // If iteration > 1, we are processing tool results. The "messages" passed to Chat + // already accumulated tool outputs. + // Rebuilding from session history is safest because it persists state. + // Start fresh with rebuilt history. + + // Special case: standard BuildMessages appends "currentMessage". + // If we are strictly retrying the *LLM call*, we want the exact same state as before but compressed. + // However, the "messages" argument passed to runLLMIteration is constructed by the caller. + // If we rebuild from Session, we need to know if "currentMessage" should be appended or is already in history. + + // In runAgentLoop: + // 3. sessions.AddMessage(userMsg) + // 4. runLLMIteration(..., UserMessage) + + // So History contains the user message. + // BuildMessages typically appends the user message as a *new* pending message. + // Wait, standard BuildMessages usage in runAgentLoop: + // messages := BuildMessages(history (has old), UserMessage) + // THEN AddMessage(UserMessage). + // So "history" passed to BuildMessages does NOT contain the current UserMessage yet. + + // But here, inside the loop, we have already saved it. + // So GetHistory() includes the current user message. + // If we call BuildMessages(GetHistory(), UserMessage), we get duplicates. + + // Hack/Fix: + // If we are retrying, we rebuild from Session History ONLY. + // We pass empty string as "currentMessage" to BuildMessages + // because the "current message" is already saved in history (step 3). + + messages = al.contextBuilder.BuildMessages( + newHistory, + newSummary, + "", // Empty because history already contains the relevant messages + nil, + opts.Channel, + opts.ChatID, + ) + + continue + } + + // Real error or success, break loop + break + } if err != nil { logger.ErrorCF("agent", "LLM call failed", @@ -497,7 +628,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "iteration": iteration, "error": err.Error(), }) - return "", iteration, fmt.Errorf("LLM call failed: %w", err) + return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } // Check if no tool calls - we're done @@ -629,7 +760,7 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) { } // maybeSummarize triggers summarization if the session history exceeds thresholds. -func (al *AgentLoop) maybeSummarize(sessionKey string) { +func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) { newHistory := al.sessions.GetHistory(sessionKey) tokenEstimate := al.estimateTokens(newHistory) threshold := al.contextWindow * 75 / 100 @@ -638,12 +769,80 @@ func (al *AgentLoop) maybeSummarize(sessionKey string) { if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading { go func() { defer al.summarizing.Delete(sessionKey) + // Notify user about optimization if not an internal channel + if !constants.IsInternalChannel(channel) { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: "โš ๏ธ Memory threshold reached. Optimizing conversation history...", + }) + } al.summarizeSession(sessionKey) }() } } } +// forceCompression aggressively reduces context when the limit is hit. +// It drops the oldest 50% of messages (keeping system prompt and last user message). +func (al *AgentLoop) forceCompression(sessionKey string) { + history := al.sessions.GetHistory(sessionKey) + if len(history) <= 4 { + return + } + + // Keep system prompt (usually [0]) and the very last message (user's trigger) + // We want to drop the oldest half of the *conversation* + // Assuming [0] is system, [1:] is conversation + conversation := history[1 : len(history)-1] + if len(conversation) == 0 { + return + } + + // Helper to find the mid-point of the conversation + mid := len(conversation) / 2 + + // New history structure: + // 1. System Prompt + // 2. [Summary of dropped part] - synthesized + // 3. Second half of conversation + // 4. Last message + + // Simplified approach for emergency: Drop first half of conversation + // and rely on existing summary if present, or create a placeholder. + + droppedCount := mid + keptConversation := conversation[mid:] + + newHistory := make([]providers.Message, 0) + newHistory = append(newHistory, history[0]) // System prompt + + // Add a note about compression + compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount) + // If there was an existing summary, we might lose it if it was in the dropped part (which is just messages). + // The summary is stored separately in session.Summary, so it persists! + // We just need to ensure the user knows there's a gap. + + // We only modify the messages list here + newHistory = append(newHistory, providers.Message{ + Role: "system", + Content: compressionNote, + }) + + newHistory = append(newHistory, keptConversation...) + newHistory = append(newHistory, history[len(history)-1]) // Last message + + // Update session + al.sessions.SetHistory(sessionKey, newHistory) + al.sessions.Save(sessionKey) + + logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{ + "session_key": sessionKey, + "dropped_msgs": droppedCount, + "new_count": len(newHistory), + }) +} + // GetStartupInfo returns information about loaded tools and skills for logging. func (al *AgentLoop) GetStartupInfo() map[string]interface{} { info := make(map[string]interface{}) @@ -671,7 +870,7 @@ func formatMessagesForLog(messages []providers.Message) string { result += "[\n" for i, msg := range messages { result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role) - if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 { + if len(msg.ToolCalls) > 0 { result += " ToolCalls:\n" for _, tc := range msg.ToolCalls { result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name) @@ -738,7 +937,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { continue } // Estimate tokens for this message - msgTokens := len(m.Content) / 4 + msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety) if msgTokens > maxMessageTokens { omitted = true continue @@ -809,13 +1008,96 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa } // estimateTokens estimates the number of tokens in a message list. -// Uses rune count instead of byte length so that CJK and other multi-byte -// characters are not over-counted (a Chinese character is 3 bytes but roughly -// one token). +// Uses a safe heuristic of 2.5 characters per token to account for CJK and other +// overheads better than the previous 3 chars/token. func (al *AgentLoop) estimateTokens(messages []providers.Message) int { - total := 0 + totalChars := 0 for _, m := range messages { - total += utf8.RuneCountInString(m.Content) / 3 + totalChars += utf8.RuneCountInString(m.Content) } - return total + // 2.5 chars per token = totalChars * 2 / 5 + return totalChars * 2 / 5 +} + +func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) { + content := strings.TrimSpace(msg.Content) + if !strings.HasPrefix(content, "/") { + return "", false + } + + parts := strings.Fields(content) + if len(parts) == 0 { + return "", false + } + + cmd := parts[0] + args := parts[1:] + + switch cmd { + case "/show": + if len(args) < 1 { + return "Usage: /show [model|channel]", true + } + switch args[0] { + case "model": + return fmt.Sprintf("Current model: %s", al.model), true + case "channel": + return fmt.Sprintf("Current channel: %s", msg.Channel), true + default: + return fmt.Sprintf("Unknown show target: %s", args[0]), true + } + + case "/list": + if len(args) < 1 { + return "Usage: /list [models|channels]", true + } + switch args[0] { + case "models": + // TODO: Fetch available models dynamically if possible + return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true + case "channels": + if al.channelManager == nil { + return "Channel manager not initialized", true + } + channels := al.channelManager.GetEnabledChannels() + if len(channels) == 0 { + return "No channels enabled", true + } + return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true + default: + return fmt.Sprintf("Unknown list target: %s", args[0]), true + } + + case "/switch": + if len(args) < 3 || args[1] != "to" { + return "Usage: /switch [model|channel] to ", true + } + target := args[0] + value := args[2] + + switch target { + case "model": + oldModel := al.model + al.model = value + return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true + case "channel": + // This changes the 'default' channel for some operations, or effectively redirects output? + // For now, let's just validate if the channel exists + if al.channelManager == nil { + return "Channel manager not initialized", true + } + if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" { + return fmt.Sprintf("Channel '%s' not found or not enabled", value), true + } + + // If message came from CLI, maybe we want to redirect CLI output to this channel? + // That would require state persistence about "redirected channel" + // For now, just acknowledged. + return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true + default: + return fmt.Sprintf("Unknown switch target: %s", target), true + } + } + + return "", false } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index c18220258..0bd38abf4 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -2,6 +2,7 @@ package agent import ( "context" + "fmt" "os" "path/filepath" "testing" @@ -527,3 +528,99 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { t.Errorf("Expected 'Command output: hello world', got: %s", response) } } + +// failFirstMockProvider fails on the first N calls with a specific error +type failFirstMockProvider struct { + failures int + currentCall int + failError error + successResp string +} + +func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + m.currentCall++ + if m.currentCall <= m.failures { + return nil, m.failError + } + return &providers.LLMResponse{ + Content: m.successResp, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *failFirstMockProvider) GetDefaultModel() string { + return "mock-fail-model" +} + +// TestAgentLoop_ContextExhaustionRetry verify that the agent retries on context errors +func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + + // Create a provider that fails once with a context error + contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens") + provider := &failFirstMockProvider{ + failures: 1, + failError: contextErr, + successResp: "Recovered from context error", + } + + al := NewAgentLoop(cfg, msgBus, provider) + + // Inject some history to simulate a full context + sessionKey := "test-session-context" + // Create dummy history + history := []providers.Message{ + {Role: "system", Content: "System prompt"}, + {Role: "user", Content: "Old message 1"}, + {Role: "assistant", Content: "Old response 1"}, + {Role: "user", Content: "Old message 2"}, + {Role: "assistant", Content: "Old response 2"}, + {Role: "user", Content: "Trigger message"}, + } + al.sessions.SetHistory(sessionKey, history) + + // Call ProcessDirectWithChannel + // Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration + response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat") + + if err != nil { + t.Fatalf("Expected success after retry, got error: %v", err) + } + + if response != "Recovered from context error" { + t.Errorf("Expected 'Recovered from context error', got '%s'", response) + } + + // We expect 2 calls: 1st failed, 2nd succeeded + if provider.currentCall != 2 { + t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall) + } + + // Check final history length + finalHistory := al.sessions.GetHistory(sessionKey) + // We verify that the history has been modified (compressed) + // Original length: 6 + // Expected behavior: compression drops ~50% of history (mid slice) + // We can assert that the length is NOT what it would be without compression. + // Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8 + if len(finalHistory) >= 8 { + t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) + } +} diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 1a6589641..dcd91bebd 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre return nil, fmt.Errorf("token refresh failed: %s", string(body)) } - return parseTokenResponse(body, cred.Provider) + refreshed, err := parseTokenResponse(body, cred.Provider) + if err != nil { + return nil, err + } + if refreshed.RefreshToken == "" { + refreshed.RefreshToken = cred.RefreshToken + } + if refreshed.AccountID == "" { + refreshed.AccountID = cred.AccountID + } + return refreshed, nil } func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { @@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU "codex_cli_simplified_flow": {"true"}, "state": {state}, } + if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") { + params.Set("originator", "picoclaw") + } if cfg.Originator != "" { params.Set("originator", cfg.Originator) } @@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { AuthMethod: "oauth", } - if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + if accountID := extractAccountID(tokenResp.IDToken); accountID != "" { + cred.AccountID = accountID + } else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { cred.AccountID = accountID } else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" { // Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims. @@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { return cred, nil } -func extractAccountID(accessToken string) string { - parts := strings.Split(accessToken, ".") - if len(parts) < 2 { +func extractAccountID(token string) string { + claims, err := parseJWTClaims(token) + if err != nil { return "" } + if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + + if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + } + + if orgs, ok := claims["organizations"].([]interface{}); ok { + for _, org := range orgs { + if orgMap, ok := org.(map[string]interface{}); ok { + if accountID, ok := orgMap["id"].(string); ok && accountID != "" { + return accountID + } + } + } + } + + return "" +} + +func parseJWTClaims(token string) (map[string]interface{}, error) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("token is not a JWT") + } + payload := parts[1] switch len(payload) % 4 { case 2: @@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string { decoded, err := base64URLDecode(payload) if err != nil { - return "" + return nil, err } var claims map[string]interface{} if err := json.Unmarshal(decoded, &claims); err != nil { - return "" + return nil, err } - if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { - if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { - return accountID - } - } - - return "" + return claims, nil } func base64URLDecode(s string) ([]byte, error) { diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go index 0d2ccc9a5..5deb17805 100644 --- a/pkg/auth/oauth_test.go +++ b/pkg/auth/oauth_test.go @@ -5,10 +5,23 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" "testing" ) +func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string { + t.Helper() + + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadJSON, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadJSON) + return header + "." + payload + ".sig" +} + func TestBuildAuthorizeURL(t *testing.T) { cfg := OAuthProviderConfig{ Issuer: "https://auth.example.com", @@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) { } } +func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) { + cfg := OpenAIOAuthConfig() + pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"} + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + parsed, err := url.Parse(u) + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + q := parsed.Query() + + if q.Get("id_token_add_organizations") != "true" { + t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations")) + } + if q.Get("codex_cli_simplified_flow") != "true" { + t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow")) + } + if q.Get("originator") != "codex_cli_rs" { + t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator")) + } +} + func TestParseTokenResponse(t *testing.T) { resp := map[string]interface{}{ "access_token": "test-access-token", @@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) { } } +func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) { + idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"}) + resp := map[string]interface{}{ + "access_token": "opaque-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": idToken, + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + if cred.AccountID != "acc-id-from-id-token" { + t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token") + } +} + +func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) { + token := makeJWTForClaims(t, map[string]interface{}{ + "organizations": []interface{}{ + map[string]interface{}{"id": "org_from_orgs"}, + }, + }) + + if got := extractAccountID(token); got != "org_from_orgs" { + t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs") + } +} + func TestParseTokenResponseNoAccessToken(t *testing.T) { body := []byte(`{"refresh_token": "test"}`) _, err := parseTokenResponse(body, "openai") @@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { } } +func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]interface{}{ + "access_token": "new-access-token-only", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"} + cred := &AuthCredential{ + AccessToken: "old-access", + RefreshToken: "existing-refresh", + AccountID: "acc_existing", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + if refreshed.RefreshToken != "existing-refresh" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh") + } + if refreshed.AccountID != "acc_existing" { + t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing") + } +} + func TestOpenAIOAuthConfig(t *testing.T) { cfg := OpenAIOAuthConfig() if cfg.Issuer != "https://auth.openai.com" { diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 6283251a4..58c0a25d5 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -9,6 +9,7 @@ type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage handlers map[string]MessageHandler + closed bool mu sync.RWMutex } @@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus { } func (mb *MessageBus) PublishInbound(msg InboundMessage) { + mb.mu.RLock() + defer mb.mu.RUnlock() + if mb.closed { + return + } mb.inbound <- msg } @@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) } func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { + mb.mu.RLock() + defer mb.mu.RUnlock() + if mb.closed { + return + } mb.outbound <- msg } @@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { } func (mb *MessageBus) Close() { + mb.mu.Lock() + defer mb.mu.Unlock() + if mb.closed { + return + } + mb.closed = true close(mb.inbound) close(mb.outbound) } diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index e65c99eec..00aa8ab4d 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "strings" "time" "github.com/bwmarrin/discordgo" @@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("channel ID is empty") } - message := msg.Content + runes := []rune(msg.Content) + if len(runes) == 0 { + return nil + } + chunks := splitMessage(msg.Content, 1500) // Discord has a limit of 2000 characters per message, leave 500 for natural split e.g. code blocks + + for _, chunk := range chunks { + if err := c.sendChunk(ctx, channelID, chunk); err != nil { + return err + } + } + + return nil +} + +// splitMessage splits long messages into chunks, preserving code block integrity +// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks +func splitMessage(content string, limit int) []string { + var messages []string + + for len(content) > 0 { + if len(content) <= limit { + messages = append(messages, content) + break + } + + msgEnd := limit + + // Find natural split point within the limit + msgEnd = findLastNewline(content[:limit], 200) + if msgEnd <= 0 { + msgEnd = findLastSpace(content[:limit], 100) + } + if msgEnd <= 0 { + msgEnd = limit + } + + // Check if this would end with an incomplete code block + candidate := content[:msgEnd] + unclosedIdx := findLastUnclosedCodeBlock(candidate) + + if unclosedIdx >= 0 { + // Message would end with incomplete code block + // Try to extend to include the closing ``` (with some buffer) + extendedLimit := limit + 500 // Allow 500 char buffer for code blocks + if len(content) > extendedLimit { + closingIdx := findNextClosingCodeBlock(content, msgEnd) + if closingIdx > 0 && closingIdx <= extendedLimit { + // Extend to include the closing ``` + msgEnd = closingIdx + } else { + // Can't find closing, split before the code block + msgEnd = findLastNewline(content[:unclosedIdx], 200) + if msgEnd <= 0 { + msgEnd = findLastSpace(content[:unclosedIdx], 100) + } + if msgEnd <= 0 { + msgEnd = unclosedIdx + } + } + } else { + // Remaining content fits within extended limit + msgEnd = len(content) + } + } + + if msgEnd <= 0 { + msgEnd = limit + } + + messages = append(messages, content[:msgEnd]) + content = strings.TrimSpace(content[msgEnd:]) + } + + return messages +} + +// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` +// Returns the position of the opening ``` or -1 if all code blocks are complete +func findLastUnclosedCodeBlock(text string) int { + count := 0 + lastOpenIdx := -1 + + for i := 0; i < len(text); i++ { + if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + if count == 0 { + lastOpenIdx = i + } + count++ + i += 2 + } + } + + // If odd number of ``` markers, last one is unclosed + if count%2 == 1 { + return lastOpenIdx + } + return -1 +} + +// findNextClosingCodeBlock finds the next closing ``` starting from a position +// Returns the position after the closing ``` or -1 if not found +func findNextClosingCodeBlock(text string, startIdx int) int { + for i := startIdx; i < len(text); i++ { + if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + return i + 3 + } + } + return -1 +} + +// findLastNewline finds the last newline character within the last N characters +// Returns the position of the newline or -1 if not found +func findLastNewline(s string, searchWindow int) int { + searchStart := len(s) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(s) - 1; i >= searchStart; i-- { + if s[i] == '\n' { + return i + } + } + return -1 +} + +// findLastSpace finds the last space character within the last N characters +// Returns the position of the space or -1 if not found +func findLastSpace(s string, searchWindow int) int { + searchStart := len(s) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(s) - 1; i >= searchStart; i-- { + if s[i] == ' ' || s[i] == '\t' { + return i + } + } + return -1 +} + +func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { // ไฝฟ็”จไผ ๅ…ฅ็š„ ctx ่ฟ›่กŒ่ถ…ๆ—ถๆŽงๅˆถ sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) defer cancel() done := make(chan error, 1) go func() { - _, err := c.session.ChannelMessageSend(channelID, message) + _, err := c.session.ChannelMessageSend(channelID, content) done <- err }() @@ -140,6 +282,12 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + if err := c.session.ChannelTyping(m.ChannelID); err != nil { + logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{ + "error": err.Error(), + }) + } + // ๆฃ€ๆŸฅ็™ฝๅๅ•๏ผŒ้ฟๅ…ไธบ่ขซๆ‹’็ป็š„็”จๆˆทไธ‹่ฝฝ้™„ไปถๅ’Œ่ฝฌๅฝ• if !c.IsAllowed(m.Author.ID) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 15f8c6037..7f6abc4cb 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -48,7 +48,7 @@ func (m *Manager) initChannels() error { if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { logger.DebugC("channels", "Attempting to initialize Telegram channel") - telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus) + telegram, err := NewTelegramChannel(m.config, m.bus) if err != nil { logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{ "error": err.Error(), diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index d86d08a9d..5387e9213 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -296,6 +296,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + senderID := ev.User channelID := ev.Channel threadTS := ev.ThreadTimeStamp @@ -345,6 +352,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { c.socketClient.Ack(*event.Request) } + if !c.IsAllowed(cmd.UserID) { + logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{ + "user_id": cmd.UserID, + }) + return + } + senderID := cmd.UserID channelID := cmd.ChannelID chatID := channelID diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index b14b1632e..5601d508c 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -11,7 +11,10 @@ import ( "sync" "time" + th "github.com/mymmrac/telego/telegohandler" + "github.com/mymmrac/telego" + "github.com/mymmrac/telego/telegohandler" tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" @@ -24,7 +27,8 @@ import ( type TelegramChannel struct { *BaseChannel bot *telego.Bot - config config.TelegramConfig + commands TelegramCommander + config *config.Config chatIDs map[string]int64 transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID @@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() { } } -func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { var opts []telego.BotOption + telegramCfg := cfg.Channels.Telegram - if cfg.Proxy != "" { - proxyURL, parseErr := url.Parse(cfg.Proxy) + if telegramCfg.Proxy != "" { + proxyURL, parseErr := url.Parse(telegramCfg.Proxy) if parseErr != nil { - return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr) + return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr) } opts = append(opts, telego.WithHTTPClient(&http.Client{ Transport: &http.Transport{ @@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr })) } - bot, err := telego.NewBot(cfg.Token, opts...) + bot, err := telego.NewBot(telegramCfg.Token, opts...) if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) } - base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom) + base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) return &TelegramChannel{ BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), bot: bot, config: cfg, chatIDs: make(map[string]int64), @@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start long polling: %w", err) } + bh, err := telegohandler.NewBotHandler(c.bot, updates) + if err != nil { + return fmt.Errorf("failed to create bot handler: %w", err) + } + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + c.commands.Help(ctx, message) + return nil + }, th.CommandEqual("help")) + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Start(ctx, message) + }, th.CommandEqual("start")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Show(ctx, message) + }, th.CommandEqual("show")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.List(ctx, message) + }, th.CommandEqual("list")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.handleMessage(ctx, &message) + }, th.AnyMessage()) + c.setRunning(true) logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ "username": c.bot.Username(), }) + go bh.Start() + go func() { - for { - select { - case <-ctx.Done(): - return - case update, ok := <-updates: - if !ok { - logger.InfoC("telegram", "Updates channel closed, reconnecting...") - return - } - if update.Message != nil { - c.handleMessage(ctx, update) - } - } - } + <-ctx.Done() + bh.Stop() }() return nil } - func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.setRunning(false) @@ -166,30 +186,27 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return nil } -func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) { - message := update.Message +func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { if message == nil { - return + return fmt.Errorf("message is nil") } user := message.From if user == nil { - return + return fmt.Errorf("message sender (user) is nil") } - userID := fmt.Sprintf("%d", user.ID) - senderID := userID + senderID := fmt.Sprintf("%d", user.ID) if user.Username != "" { - senderID = fmt.Sprintf("%s|%s", userID, user.Username) + senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) } // ๆฃ€ๆŸฅ็™ฝๅๅ•๏ผŒ้ฟๅ…ไธบ่ขซๆ‹’็ป็š„็”จๆˆทไธ‹่ฝฝ้™„ไปถ - if !c.IsAllowed(userID) && !c.IsAllowed(senderID) { + if !c.IsAllowed(senderID) { logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ - "user_id": userID, - "username": user.Username, + "user_id": senderID, }) - return + return nil } chatID := message.Chat.ID @@ -222,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat content += message.Caption } - if message.Photo != nil && len(message.Photo) > 0 { + if len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { @@ -231,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if content != "" { content += "\n" } - content += fmt.Sprintf("[image: photo]") + content += "[image: photo]" } } @@ -252,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat "error": err.Error(), "path": voicePath, }) - transcribedText = fmt.Sprintf("[voice (transcription failed)]") + transcribedText = "[voice (transcription failed)]" } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ @@ -260,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat }) } } else { - transcribedText = fmt.Sprintf("[voice]") + transcribedText = "[voice]" } if content != "" { @@ -278,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if content != "" { content += "\n" } - content += fmt.Sprintf("[audio]") + content += "[audio]" } } @@ -290,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if content != "" { content += "\n" } - content += fmt.Sprintf("[file]") + content += "[file]" } } @@ -338,7 +355,8 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + return nil } func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram_commands.go new file mode 100644 index 000000000..df245e156 --- /dev/null +++ b/pkg/channels/telegram_commands.go @@ -0,0 +1,153 @@ +package channels + +import ( + "context" + "fmt" + "strings" + + "github.com/mymmrac/telego" + "github.com/sipeed/picoclaw/pkg/config" +) + +type TelegramCommander interface { + Help(ctx context.Context, message telego.Message) error + Start(ctx context.Context, message telego.Message) error + Show(ctx context.Context, message telego.Message) error + List(ctx context.Context, message telego.Message) error +} + +type cmd struct { + bot *telego.Bot + config *config.Config +} + +func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { + return &cmd{ + bot: bot, + config: cfg, + } +} + +func commandArgs(text string) string { + parts := strings.SplitN(text, " ", 2) + if len(parts) < 2 { + return "" + } + return strings.TrimSpace(parts[1]) +} +func (c *cmd) Help(ctx context.Context, message telego.Message) error { + msg := `/start - Start the bot +/help - Show this help message +/show [model|channel] - Show current configuration +/list [models|channels] - List available options + ` + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: msg, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Start(ctx context.Context, message telego.Message) error { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Hello! I am PicoClaw ๐Ÿฆž", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Show(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /show [model|channel]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "model": + response = fmt.Sprintf("Current Model: %s (Provider: %s)", + c.config.Agents.Defaults.Model, + c.config.Agents.Defaults.Provider) + case "channel": + response = "Current Channel: telegram" + default: + response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} +func (c *cmd) List(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /list [models|channels]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "models": + provider := c.config.Agents.Defaults.Provider + if provider == "" { + provider = "configured default" + } + response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", + c.config.Agents.Defaults.Model, provider) + + case "channels": + var enabled []string + if c.config.Channels.Telegram.Enabled { + enabled = append(enabled, "telegram") + } + if c.config.Channels.WhatsApp.Enabled { + enabled = append(enabled, "whatsapp") + } + if c.config.Channels.Feishu.Enabled { + enabled = append(enabled, "feishu") + } + if c.config.Channels.Discord.Enabled { + enabled = append(enabled, "discord") + } + if c.config.Channels.Slack.Enabled { + enabled = append(enabled, "slack") + } + response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) + + default: + response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 237eade65..ddc376645 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -175,6 +175,7 @@ type ProvidersConfig struct { VLLM ProviderConfig `json:"vllm"` Gemini ProviderConfig `json:"gemini"` Nvidia ProviderConfig `json:"nvidia"` + Ollama ProviderConfig `json:"ollama"` Moonshot ProviderConfig `json:"moonshot"` ShengSuanYun ProviderConfig `json:"shengsuanyun"` DeepSeek ProviderConfig `json:"deepseek"` @@ -430,7 +431,7 @@ func SaveConfig(path string, cfg *Config) error { return err } - return os.WriteFile(path, data, 0644) + return os.WriteFile(path, data, 0600) } func (c *Config) WorkspacePath() string { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 14618b109..febfd0456 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -1,6 +1,9 @@ package config import ( + "os" + "path/filepath" + "runtime" "testing" ) @@ -147,6 +150,30 @@ func TestDefaultConfig_WebTools(t *testing.T) { } } +func TestSaveConfig_FilePermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("file permission bits are not enforced on Windows") + } + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "config.json") + + cfg := DefaultConfig() + if err := SaveConfig(path, cfg); err != nil { + t.Fatalf("SaveConfig failed: %v", err) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("config file has permission %04o, want 0600", perm) + } +} + // TestConfig_Complete verifies all config fields are set func TestConfig_Complete(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/cron/service.go b/pkg/cron/service.go index ddd680e74..9f62c743b 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -340,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error { return err } - return os.WriteFile(cs.storePath, data, 0644) + return os.WriteFile(cs.storePath, data, 0600) } func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) { diff --git a/pkg/cron/service_test.go b/pkg/cron/service_test.go new file mode 100644 index 000000000..53d69f6a9 --- /dev/null +++ b/pkg/cron/service_test.go @@ -0,0 +1,38 @@ +package cron + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +func TestSaveStore_FilePermissions(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("file permission bits are not enforced on Windows") + } + + tmpDir := t.TempDir() + storePath := filepath.Join(tmpDir, "cron", "jobs.json") + + cs := NewCronService(storePath, nil) + + _, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct") + if err != nil { + t.Fatalf("AddJob failed: %v", err) + } + + info, err := os.Stat(storePath) + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("cron store has permission %04o, want 0600", perm) + } +} + +func int64Ptr(v int64) *int64 { + return &v +} diff --git a/pkg/health/server.go b/pkg/health/server.go new file mode 100644 index 000000000..77b36034d --- /dev/null +++ b/pkg/health/server.go @@ -0,0 +1,164 @@ +package health + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" +) + +type Server struct { + server *http.Server + mu sync.RWMutex + ready bool + checks map[string]Check + startTime time.Time +} + +type Check struct { + Name string `json:"name"` + Status string `json:"status"` + Message string `json:"message,omitempty"` + Timestamp time.Time `json:"timestamp"` +} + +type StatusResponse struct { + Status string `json:"status"` + Uptime string `json:"uptime"` + Checks map[string]Check `json:"checks,omitempty"` +} + +func NewServer(host string, port int) *Server { + mux := http.NewServeMux() + s := &Server{ + ready: false, + checks: make(map[string]Check), + startTime: time.Now(), + } + + mux.HandleFunc("/health", s.healthHandler) + mux.HandleFunc("/ready", s.readyHandler) + + addr := fmt.Sprintf("%s:%d", host, port) + s.server = &http.Server{ + Addr: addr, + Handler: mux, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + + return s +} + +func (s *Server) Start() error { + s.mu.Lock() + s.ready = true + s.mu.Unlock() + return s.server.ListenAndServe() +} + +func (s *Server) StartContext(ctx context.Context) error { + s.mu.Lock() + s.ready = true + s.mu.Unlock() + + errCh := make(chan error, 1) + go func() { + errCh <- s.server.ListenAndServe() + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + return s.server.Shutdown(context.Background()) + } +} + +func (s *Server) Stop(ctx context.Context) error { + s.mu.Lock() + s.ready = false + s.mu.Unlock() + return s.server.Shutdown(ctx) +} + +func (s *Server) SetReady(ready bool) { + s.mu.Lock() + s.ready = ready + s.mu.Unlock() +} + +func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) { + s.mu.Lock() + defer s.mu.Unlock() + + status, msg := checkFn() + s.checks[name] = Check{ + Name: name, + Status: statusString(status), + Message: msg, + Timestamp: time.Now(), + } +} + +func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + uptime := time.Since(s.startTime) + resp := StatusResponse{ + Status: "ok", + Uptime: uptime.String(), + } + + json.NewEncoder(w).Encode(resp) +} + +func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + s.mu.RLock() + ready := s.ready + checks := make(map[string]Check) + for k, v := range s.checks { + checks[k] = v + } + s.mu.RUnlock() + + if !ready { + w.WriteHeader(http.StatusServiceUnavailable) + json.NewEncoder(w).Encode(StatusResponse{ + Status: "not ready", + Checks: checks, + }) + return + } + + for _, check := range checks { + if check.Status == "fail" { + w.WriteHeader(http.StatusServiceUnavailable) + json.NewEncoder(w).Encode(StatusResponse{ + Status: "not ready", + Checks: checks, + }) + return + } + } + + w.WriteHeader(http.StatusOK) + uptime := time.Since(s.startTime) + json.NewEncoder(w).Encode(StatusResponse{ + Status: "ready", + Uptime: uptime.String(), + Checks: checks, + }) +} + +func statusString(ok bool) string { + if ok { + return "ok" + } + return "fail" +} diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go index a91795715..58ba3647d 100644 --- a/pkg/providers/claude_cli_provider.go +++ b/pkg/providers/claude_cli_provider.go @@ -171,68 +171,14 @@ func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, }, nil } -// extractToolCalls parses tool call JSON from the response text. +// extractToolCalls delegates to the shared extractToolCallsFromText function. func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall { - start := strings.Index(text, `{"tool_calls"`) - if start == -1 { - return nil - } - - end := findMatchingBrace(text, start) - if end == start { - return nil - } - - jsonStr := text[start:end] - - var wrapper struct { - ToolCalls []struct { - ID string `json:"id"` - Type string `json:"type"` - Function struct { - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"function"` - } `json:"tool_calls"` - } - - if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil { - return nil - } - - var result []ToolCall - for _, tc := range wrapper.ToolCalls { - var args map[string]interface{} - json.Unmarshal([]byte(tc.Function.Arguments), &args) - - result = append(result, ToolCall{ - ID: tc.ID, - Type: tc.Type, - Name: tc.Function.Name, - Arguments: args, - Function: &FunctionCall{ - Name: tc.Function.Name, - Arguments: tc.Function.Arguments, - }, - }) - } - - return result + return extractToolCallsFromText(text) } -// stripToolCallsJSON removes tool call JSON from response text. +// stripToolCallsJSON delegates to the shared stripToolCallsFromText function. func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string { - start := strings.Index(text, `{"tool_calls"`) - if start == -1 { - return text - } - - end := findMatchingBrace(text, start) - if end == start { - return text - } - - return strings.TrimSpace(text[:start] + text[end:]) + return stripToolCallsFromText(text) } // findMatchingBrace finds the index after the closing brace matching the opening brace at pos. diff --git a/pkg/providers/codex_cli_credentials.go b/pkg/providers/codex_cli_credentials.go new file mode 100644 index 000000000..7ad39ce8e --- /dev/null +++ b/pkg/providers/codex_cli_credentials.go @@ -0,0 +1,79 @@ +package providers + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" +) + +// CodexCliAuth represents the ~/.codex/auth.json file structure. +type CodexCliAuth struct { + Tokens struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + AccountID string `json:"account_id"` + } `json:"tokens"` +} + +// ReadCodexCliCredentials reads OAuth tokens from the Codex CLI's auth.json file. +// Expiry is estimated as file modification time + 1 hour (same approach as moltbot). +func ReadCodexCliCredentials() (accessToken, accountID string, expiresAt time.Time, err error) { + authPath, err := resolveCodexAuthPath() + if err != nil { + return "", "", time.Time{}, err + } + + data, err := os.ReadFile(authPath) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("reading %s: %w", authPath, err) + } + + var auth CodexCliAuth + if err := json.Unmarshal(data, &auth); err != nil { + return "", "", time.Time{}, fmt.Errorf("parsing %s: %w", authPath, err) + } + + if auth.Tokens.AccessToken == "" { + return "", "", time.Time{}, fmt.Errorf("no access_token in %s", authPath) + } + + stat, err := os.Stat(authPath) + if err != nil { + expiresAt = time.Now().Add(time.Hour) + } else { + expiresAt = stat.ModTime().Add(time.Hour) + } + + return auth.Tokens.AccessToken, auth.Tokens.AccountID, expiresAt, nil +} + +// CreateCodexCliTokenSource creates a token source that reads from ~/.codex/auth.json. +// This allows the existing CodexProvider to reuse Codex CLI credentials. +func CreateCodexCliTokenSource() func() (string, string, error) { + return func() (string, string, error) { + token, accountID, expiresAt, err := ReadCodexCliCredentials() + if err != nil { + return "", "", fmt.Errorf("reading codex cli credentials: %w", err) + } + + if time.Now().After(expiresAt) { + return "", "", fmt.Errorf("codex cli credentials expired (auth.json last modified > 1h ago). Run: codex login") + } + + return token, accountID, nil + } +} + +func resolveCodexAuthPath() (string, error) { + codexHome := os.Getenv("CODEX_HOME") + if codexHome == "" { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("getting home dir: %w", err) + } + codexHome = filepath.Join(home, ".codex") + } + return filepath.Join(codexHome, "auth.json"), nil +} diff --git a/pkg/providers/codex_cli_credentials_test.go b/pkg/providers/codex_cli_credentials_test.go new file mode 100644 index 000000000..3267f2d16 --- /dev/null +++ b/pkg/providers/codex_cli_credentials_test.go @@ -0,0 +1,181 @@ +package providers + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestReadCodexCliCredentials_Valid(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{ + "tokens": { + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "account_id": "org-test123" + } + }` + if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + token, accountID, expiresAt, err := ReadCodexCliCredentials() + if err != nil { + t.Fatalf("ReadCodexCliCredentials() error: %v", err) + } + if token != "test-access-token" { + t.Errorf("token = %q, want %q", token, "test-access-token") + } + if accountID != "org-test123" { + t.Errorf("accountID = %q, want %q", accountID, "org-test123") + } + // Expiry should be within ~1 hour from now (file was just written) + if expiresAt.Before(time.Now()) { + t.Errorf("expiresAt = %v, should be in the future", expiresAt) + } + if expiresAt.After(time.Now().Add(2 * time.Hour)) { + t.Errorf("expiresAt = %v, should be within ~1 hour", expiresAt) + } +} + +func TestReadCodexCliCredentials_MissingFile(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("CODEX_HOME", tmpDir) + + _, _, _, err := ReadCodexCliCredentials() + if err == nil { + t.Fatal("expected error for missing auth.json") + } +} + +func TestReadCodexCliCredentials_EmptyToken(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "", "refresh_token": "r", "account_id": "a"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + _, _, _, err := ReadCodexCliCredentials() + if err == nil { + t.Fatal("expected error for empty access_token") + } +} + +func TestReadCodexCliCredentials_InvalidJSON(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + if err := os.WriteFile(authPath, []byte("not json"), 0600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + _, _, _, err := ReadCodexCliCredentials() + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestReadCodexCliCredentials_NoAccountID(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "tok123", "refresh_token": "ref456"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + token, accountID, _, err := ReadCodexCliCredentials() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "tok123" { + t.Errorf("token = %q, want %q", token, "tok123") + } + if accountID != "" { + t.Errorf("accountID = %q, want empty", accountID) + } +} + +func TestReadCodexCliCredentials_CodexHomeEnv(t *testing.T) { + tmpDir := t.TempDir() + customDir := filepath.Join(tmpDir, "custom-codex") + if err := os.MkdirAll(customDir, 0755); err != nil { + t.Fatal(err) + } + + authJSON := `{"tokens": {"access_token": "custom-token", "refresh_token": "r"}}` + if err := os.WriteFile(filepath.Join(customDir, "auth.json"), []byte(authJSON), 0600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", customDir) + + token, _, _, err := ReadCodexCliCredentials() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if token != "custom-token" { + t.Errorf("token = %q, want %q", token, "custom-token") + } +} + +func TestCreateCodexCliTokenSource_Valid(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "fresh-token", "refresh_token": "r", "account_id": "acc"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + source := CreateCodexCliTokenSource() + token, accountID, err := source() + if err != nil { + t.Fatalf("token source error: %v", err) + } + if token != "fresh-token" { + t.Errorf("token = %q, want %q", token, "fresh-token") + } + if accountID != "acc" { + t.Errorf("accountID = %q, want %q", accountID, "acc") + } +} + +func TestCreateCodexCliTokenSource_Expired(t *testing.T) { + tmpDir := t.TempDir() + authPath := filepath.Join(tmpDir, "auth.json") + + authJSON := `{"tokens": {"access_token": "old-token", "refresh_token": "r"}}` + if err := os.WriteFile(authPath, []byte(authJSON), 0600); err != nil { + t.Fatal(err) + } + + // Set file modification time to 2 hours ago + oldTime := time.Now().Add(-2 * time.Hour) + if err := os.Chtimes(authPath, oldTime, oldTime); err != nil { + t.Fatal(err) + } + + t.Setenv("CODEX_HOME", tmpDir) + + source := CreateCodexCliTokenSource() + _, _, err := source() + if err == nil { + t.Fatal("expected error for expired credentials") + } +} diff --git a/pkg/providers/codex_cli_provider.go b/pkg/providers/codex_cli_provider.go new file mode 100644 index 000000000..8886406b4 --- /dev/null +++ b/pkg/providers/codex_cli_provider.go @@ -0,0 +1,251 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess. +type CodexCliProvider struct { + command string + workspace string +} + +// NewCodexCliProvider creates a new Codex CLI provider. +func NewCodexCliProvider(workspace string) *CodexCliProvider { + return &CodexCliProvider{ + command: "codex", + workspace: workspace, + } +} + +// Chat implements LLMProvider.Chat by executing the codex CLI in non-interactive mode. +func (p *CodexCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p.command == "" { + return nil, fmt.Errorf("codex command not configured") + } + + prompt := p.buildPrompt(messages, tools) + + args := []string{ + "exec", + "--json", + "--dangerously-bypass-approvals-and-sandbox", + "--skip-git-repo-check", + "--color", "never", + } + if model != "" && model != "codex-cli" { + args = append(args, "-m", model) + } + if p.workspace != "" { + args = append(args, "-C", p.workspace) + } + args = append(args, "-") // read prompt from stdin + + cmd := exec.CommandContext(ctx, p.command, args...) + cmd.Stdin = bytes.NewReader([]byte(prompt)) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + + // Parse JSONL from stdout even if exit code is non-zero, + // because codex writes diagnostic noise to stderr (e.g. rollout errors) + // but still produces valid JSONL output. + if stdoutStr := stdout.String(); stdoutStr != "" { + resp, parseErr := p.parseJSONLEvents(stdoutStr) + if parseErr == nil && resp != nil && (resp.Content != "" || len(resp.ToolCalls) > 0) { + return resp, nil + } + } + + if err != nil { + if ctx.Err() == context.Canceled { + return nil, ctx.Err() + } + if stderrStr := stderr.String(); stderrStr != "" { + return nil, fmt.Errorf("codex cli error: %s", stderrStr) + } + return nil, fmt.Errorf("codex cli error: %w", err) + } + + return p.parseJSONLEvents(stdout.String()) +} + +// GetDefaultModel returns the default model identifier. +func (p *CodexCliProvider) GetDefaultModel() string { + return "codex-cli" +} + +// buildPrompt converts messages to a prompt string for the Codex CLI. +// System messages are prepended as instructions since Codex CLI has no --system-prompt flag. +func (p *CodexCliProvider) buildPrompt(messages []Message, tools []ToolDefinition) string { + var systemParts []string + var conversationParts []string + + for _, msg := range messages { + switch msg.Role { + case "system": + systemParts = append(systemParts, msg.Content) + case "user": + conversationParts = append(conversationParts, msg.Content) + case "assistant": + conversationParts = append(conversationParts, "Assistant: "+msg.Content) + case "tool": + conversationParts = append(conversationParts, + fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content)) + } + } + + var sb strings.Builder + + if len(systemParts) > 0 { + sb.WriteString("## System Instructions\n\n") + sb.WriteString(strings.Join(systemParts, "\n\n")) + sb.WriteString("\n\n## Task\n\n") + } + + if len(tools) > 0 { + sb.WriteString(p.buildToolsPrompt(tools)) + sb.WriteString("\n\n") + } + + // Simplify single user message (no prefix) + if len(conversationParts) == 1 && len(systemParts) == 0 && len(tools) == 0 { + return conversationParts[0] + } + + sb.WriteString(strings.Join(conversationParts, "\n")) + return sb.String() +} + +// buildToolsPrompt creates a tool definitions section for the prompt. +func (p *CodexCliProvider) buildToolsPrompt(tools []ToolDefinition) string { + var sb strings.Builder + + sb.WriteString("## Available Tools\n\n") + sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") + sb.WriteString("```json\n") + sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`) + sb.WriteString("\n```\n\n") + sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") + sb.WriteString("### Tool Definitions:\n\n") + + for _, tool := range tools { + if tool.Type != "function" { + continue + } + sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) + if tool.Function.Description != "" { + sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) + } + if len(tool.Function.Parameters) > 0 { + paramsJSON, _ := json.Marshal(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) + } + sb.WriteString("\n") + } + + return sb.String() +} + +// codexEvent represents a single JSONL event from `codex exec --json`. +type codexEvent struct { + Type string `json:"type"` + ThreadID string `json:"thread_id,omitempty"` + Message string `json:"message,omitempty"` + Item *codexEventItem `json:"item,omitempty"` + Usage *codexUsage `json:"usage,omitempty"` + Error *codexEventErr `json:"error,omitempty"` +} + +type codexEventItem struct { + ID string `json:"id"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Command string `json:"command,omitempty"` + Status string `json:"status,omitempty"` + ExitCode *int `json:"exit_code,omitempty"` + Output string `json:"output,omitempty"` +} + +type codexUsage struct { + InputTokens int `json:"input_tokens"` + CachedInputTokens int `json:"cached_input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type codexEventErr struct { + Message string `json:"message"` +} + +// parseJSONLEvents processes the JSONL output from codex exec --json. +func (p *CodexCliProvider) parseJSONLEvents(output string) (*LLMResponse, error) { + var contentParts []string + var usage *UsageInfo + var lastError string + + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var event codexEvent + if err := json.Unmarshal([]byte(line), &event); err != nil { + continue // skip malformed lines + } + + switch event.Type { + case "item.completed": + if event.Item != nil && event.Item.Type == "agent_message" && event.Item.Text != "" { + contentParts = append(contentParts, event.Item.Text) + } + case "turn.completed": + if event.Usage != nil { + promptTokens := event.Usage.InputTokens + event.Usage.CachedInputTokens + usage = &UsageInfo{ + PromptTokens: promptTokens, + CompletionTokens: event.Usage.OutputTokens, + TotalTokens: promptTokens + event.Usage.OutputTokens, + } + } + case "error": + lastError = event.Message + case "turn.failed": + if event.Error != nil { + lastError = event.Error.Message + } + } + } + + if lastError != "" && len(contentParts) == 0 { + return nil, fmt.Errorf("codex cli: %s", lastError) + } + + content := strings.Join(contentParts, "\n") + + // Extract tool calls from response text (same pattern as ClaudeCliProvider) + toolCalls := extractToolCallsFromText(content) + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + content = stripToolCallsFromText(content) + } + + return &LLMResponse{ + Content: strings.TrimSpace(content), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} diff --git a/pkg/providers/codex_cli_provider_test.go b/pkg/providers/codex_cli_provider_test.go new file mode 100644 index 000000000..7e4e1bc15 --- /dev/null +++ b/pkg/providers/codex_cli_provider_test.go @@ -0,0 +1,585 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +// --- JSONL Event Parsing Tests --- + +func TestParseJSONLEvents_AgentMessage(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"thread.started","thread_id":"abc-123"} +{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Hello from Codex!"}} +{"type":"turn.completed","usage":{"input_tokens":100,"cached_input_tokens":50,"output_tokens":20}}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.Content != "Hello from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 150 { + t.Errorf("PromptTokens = %d, want 150", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens) + } + if resp.Usage.TotalTokens != 170 { + t.Errorf("TotalTokens = %d, want 170", resp.Usage.TotalTokens) + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls should be empty, got %d", len(resp.ToolCalls)) + } +} + +func TestParseJSONLEvents_ToolCallExtraction(t *testing.T) { + p := &CodexCliProvider{} + toolCallText := `Let me read that file. +{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test.txt\"}"}}]}` + // Build valid JSONL by marshaling the event + item := codexEvent{ + Type: "item.completed", + Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText}, + } + itemJSON, _ := json.Marshal(item) + usageEvt := `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":0,"output_tokens":20}}` + events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + usageEvt + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls count = %d, want 1", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "read_file" { + t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file") + } + if resp.ToolCalls[0].ID != "call_1" { + t.Errorf("ToolCalls[0].ID = %q, want %q", resp.ToolCalls[0].ID, "call_1") + } + if resp.ToolCalls[0].Function.Arguments != `{"path":"/tmp/test.txt"}` { + t.Errorf("ToolCalls[0].Function.Arguments = %q", resp.ToolCalls[0].Function.Arguments) + } + // Content should have the tool call JSON stripped + if strings.Contains(resp.Content, "tool_calls") { + t.Errorf("Content should not contain tool_calls JSON, got: %q", resp.Content) + } +} + +func TestParseJSONLEvents_MultipleToolCalls(t *testing.T) { + p := &CodexCliProvider{} + toolCallText := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"a.txt\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"b.txt\",\"content\":\"hello\"}"}}]}` + item := codexEvent{ + Type: "item.completed", + Item: &codexEventItem{ID: "item_1", Type: "agent_message", Text: toolCallText}, + } + itemJSON, _ := json.Marshal(item) + events := `{"type":"turn.started"}` + "\n" + string(itemJSON) + "\n" + `{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if len(resp.ToolCalls) != 2 { + t.Fatalf("ToolCalls count = %d, want 2", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "read_file" { + t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "read_file") + } + if resp.ToolCalls[1].Name != "write_file" { + t.Errorf("ToolCalls[1].Name = %q, want %q", resp.ToolCalls[1].Name, "write_file") + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } +} + +func TestParseJSONLEvents_MultipleMessages(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"First part."}} +{"type":"item.completed","item":{"id":"item_2","type":"command_execution","command":"ls","status":"completed"}} +{"type":"item.completed","item":{"id":"item_3","type":"agent_message","text":"Second part."}} +{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.Content != "First part.\nSecond part." { + t.Errorf("Content = %q, want %q", resp.Content, "First part.\nSecond part.") + } +} + +func TestParseJSONLEvents_ErrorEvent(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"thread.started","thread_id":"abc"} +{"type":"turn.started"} +{"type":"error","message":"token expired"} +{"type":"turn.failed","error":{"message":"token expired"}}` + + _, err := p.parseJSONLEvents(events) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "token expired") { + t.Errorf("error = %q, want to contain 'token expired'", err.Error()) + } +} + +func TestParseJSONLEvents_TurnFailed(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"turn.failed","error":{"message":"rate limit exceeded"}}` + + _, err := p.parseJSONLEvents(events) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "rate limit exceeded") { + t.Errorf("error = %q, want to contain 'rate limit exceeded'", err.Error()) + } +} + +func TestParseJSONLEvents_ErrorWithContent(t *testing.T) { + p := &CodexCliProvider{} + // If there's an error but also content, return the content (partial success) + events := `{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Partial result."}} +{"type":"error","message":"connection reset"} +{"type":"turn.failed","error":{"message":"connection reset"}}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("should not error when content exists: %v", err) + } + if resp.Content != "Partial result." { + t.Errorf("Content = %q, want %q", resp.Content, "Partial result.") + } +} + +func TestParseJSONLEvents_EmptyOutput(t *testing.T) { + p := &CodexCliProvider{} + resp, err := p.parseJSONLEvents("") + if err != nil { + t.Fatalf("empty output should not error: %v", err) + } + if resp.Content != "" { + t.Errorf("Content = %q, want empty", resp.Content) + } +} + +func TestParseJSONLEvents_MalformedLines(t *testing.T) { + p := &CodexCliProvider{} + events := `not json at all +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Good line."}} +another bad line +{"type":"turn.completed","usage":{"input_tokens":10,"output_tokens":5}}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("should skip malformed lines: %v", err) + } + if resp.Content != "Good line." { + t.Errorf("Content = %q, want %q", resp.Content, "Good line.") + } + if resp.Usage == nil || resp.Usage.TotalTokens != 15 { + t.Errorf("Usage.TotalTokens = %v, want 15", resp.Usage) + } +} + +func TestParseJSONLEvents_CommandExecution(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"item.started","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"in_progress"}} +{"type":"item.completed","item":{"id":"item_1","type":"command_execution","command":"bash -lc ls","status":"completed","exit_code":0,"output":"file1.go\nfile2.go"}} +{"type":"item.completed","item":{"id":"item_2","type":"agent_message","text":"Found 2 files."}} +{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + // command_execution items should be skipped; only agent_message text is returned + if resp.Content != "Found 2 files." { + t.Errorf("Content = %q, want %q", resp.Content, "Found 2 files.") + } +} + +func TestParseJSONLEvents_NoUsage(t *testing.T) { + p := &CodexCliProvider{} + events := `{"type":"turn.started"} +{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"No usage info."}} +{"type":"turn.completed"}` + + resp, err := p.parseJSONLEvents(events) + if err != nil { + t.Fatalf("parseJSONLEvents() error: %v", err) + } + if resp.Usage != nil { + t.Errorf("Usage should be nil when turn.completed has no usage, got %+v", resp.Usage) + } +} + +// --- Prompt Building Tests --- + +func TestBuildPrompt_SystemAsInstructions(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hi there"}, + } + + prompt := p.buildPrompt(messages, nil) + + if !strings.Contains(prompt, "## System Instructions") { + t.Error("prompt should contain '## System Instructions'") + } + if !strings.Contains(prompt, "You are helpful.") { + t.Error("prompt should contain system content") + } + if !strings.Contains(prompt, "## Task") { + t.Error("prompt should contain '## Task'") + } + if !strings.Contains(prompt, "Hi there") { + t.Error("prompt should contain user message") + } +} + +func TestBuildPrompt_NoSystem(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Just a question"}, + } + + prompt := p.buildPrompt(messages, nil) + + if strings.Contains(prompt, "## System Instructions") { + t.Error("prompt should not contain system instructions header") + } + if prompt != "Just a question" { + t.Errorf("prompt = %q, want %q", prompt, "Just a question") + } +} + +func TestBuildPrompt_WithTools(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Get weather"}, + } + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get current weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + + prompt := p.buildPrompt(messages, tools) + + if !strings.Contains(prompt, "## Available Tools") { + t.Error("prompt should contain tools section") + } + if !strings.Contains(prompt, "get_weather") { + t.Error("prompt should contain tool name") + } + if !strings.Contains(prompt, "Get current weather") { + t.Error("prompt should contain tool description") + } +} + +func TestBuildPrompt_MultipleMessages(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Hello"}, + {Role: "assistant", Content: "Hi! How can I help?"}, + {Role: "user", Content: "Tell me about Go"}, + } + + prompt := p.buildPrompt(messages, nil) + + if !strings.Contains(prompt, "Hello") { + t.Error("prompt should contain first user message") + } + if !strings.Contains(prompt, "Assistant: Hi! How can I help?") { + t.Error("prompt should contain assistant message with prefix") + } + if !strings.Contains(prompt, "Tell me about Go") { + t.Error("prompt should contain second user message") + } +} + +func TestBuildPrompt_ToolResults(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "user", Content: "Weather?"}, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + + prompt := p.buildPrompt(messages, nil) + + if !strings.Contains(prompt, "[Tool Result for call_1]") { + t.Error("prompt should contain tool result") + } + if !strings.Contains(prompt, `{"temp": 72}`) { + t.Error("prompt should contain tool result content") + } +} + +func TestBuildPrompt_SystemAndTools(t *testing.T) { + p := &CodexCliProvider{} + messages := []Message{ + {Role: "system", Content: "Be concise."}, + {Role: "user", Content: "Do something"}, + } + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "my_tool", + Description: "A tool", + }, + }, + } + + prompt := p.buildPrompt(messages, tools) + + // System instructions should come first + sysIdx := strings.Index(prompt, "## System Instructions") + toolIdx := strings.Index(prompt, "## Available Tools") + taskIdx := strings.Index(prompt, "## Task") + + if sysIdx == -1 || toolIdx == -1 || taskIdx == -1 { + t.Fatal("prompt should contain all sections") + } + if sysIdx >= taskIdx { + t.Error("system instructions should come before task") + } + if taskIdx >= toolIdx { + t.Error("task section should come before tools in the output") + } +} + +// --- CLI Argument Tests --- + +func TestCodexCliProvider_GetDefaultModel(t *testing.T) { + p := NewCodexCliProvider("") + if got := p.GetDefaultModel(); got != "codex-cli" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "codex-cli") + } +} + +// --- Mock CLI Integration Test --- + +func createMockCodexCLI(t *testing.T, events []string) string { + t.Helper() + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "codex") + + var sb strings.Builder + sb.WriteString("#!/bin/bash\n") + for _, event := range events { + sb.WriteString(fmt.Sprintf("echo '%s'\n", event)) + } + + if err := os.WriteFile(scriptPath, []byte(sb.String()), 0755); err != nil { + t.Fatal(err) + } + return scriptPath +} + +func TestCodexCliProvider_MockCLI_Success(t *testing.T) { + scriptPath := createMockCodexCLI(t, []string{ + `{"type":"thread.started","thread_id":"test-123"}`, + `{"type":"turn.started"}`, + `{"type":"item.completed","item":{"id":"item_1","type":"agent_message","text":"Mock response from Codex CLI"}}`, + `{"type":"turn.completed","usage":{"input_tokens":50,"cached_input_tokens":10,"output_tokens":15}}`, + }) + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "", + } + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := p.Chat(context.Background(), messages, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Mock response from Codex CLI" { + t.Errorf("Content = %q, want %q", resp.Content, "Mock response from Codex CLI") + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 60 { + t.Errorf("PromptTokens = %d, want 60", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 15 { + t.Errorf("CompletionTokens = %d, want 15", resp.Usage.CompletionTokens) + } +} + +func TestCodexCliProvider_MockCLI_Error(t *testing.T) { + scriptPath := createMockCodexCLI(t, []string{ + `{"type":"thread.started","thread_id":"test-err"}`, + `{"type":"turn.started"}`, + `{"type":"error","message":"auth token expired"}`, + `{"type":"turn.failed","error":{"message":"auth token expired"}}`, + }) + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "", + } + + messages := []Message{{Role: "user", Content: "Hello"}} + _, err := p.Chat(context.Background(), messages, nil, "", nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "auth token expired") { + t.Errorf("error = %q, want to contain 'auth token expired'", err.Error()) + } +} + +func TestCodexCliProvider_MockCLI_WithModel(t *testing.T) { + // Mock script that captures args to verify model flag is passed + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "codex") + script := `#!/bin/bash +# Write args to a file for verification +echo "$@" > "` + filepath.Join(tmpDir, "args.txt") + `" +echo '{"type":"item.completed","item":{"id":"1","type":"agent_message","text":"ok"}}' +echo '{"type":"turn.completed"}'` + + if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil { + t.Fatal(err) + } + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "/tmp/test-workspace", + } + + messages := []Message{{Role: "user", Content: "test"}} + _, err := p.Chat(context.Background(), messages, nil, "gpt-5.2-codex", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + // Verify the args + argsData, err := os.ReadFile(filepath.Join(tmpDir, "args.txt")) + if err != nil { + t.Fatalf("reading args: %v", err) + } + args := string(argsData) + + if !strings.Contains(args, "-m gpt-5.2-codex") { + t.Errorf("args should contain model flag, got: %s", args) + } + if !strings.Contains(args, "-C /tmp/test-workspace") { + t.Errorf("args should contain workspace flag, got: %s", args) + } + if !strings.Contains(args, "--json") { + t.Errorf("args should contain --json, got: %s", args) + } + if !strings.Contains(args, "--dangerously-bypass-approvals-and-sandbox") { + t.Errorf("args should contain bypass flag, got: %s", args) + } +} + +func TestCodexCliProvider_MockCLI_ContextCancel(t *testing.T) { + // Script that sleeps forever + tmpDir := t.TempDir() + scriptPath := filepath.Join(tmpDir, "codex") + script := "#!/bin/bash\nsleep 60" + + if err := os.WriteFile(scriptPath, []byte(script), 0755); err != nil { + t.Fatal(err) + } + + p := &CodexCliProvider{ + command: scriptPath, + workspace: "", + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + messages := []Message{{Role: "user", Content: "test"}} + _, err := p.Chat(ctx, messages, nil, "", nil) + if err == nil { + t.Fatal("expected error on canceled context") + } +} + +func TestCodexCliProvider_EmptyCommand(t *testing.T) { + p := &CodexCliProvider{command: ""} + + messages := []Message{{Role: "user", Content: "test"}} + _, err := p.Chat(context.Background(), messages, nil, "", nil) + if err == nil { + t.Fatal("expected error for empty command") + } +} + +// --- Integration Test (requires real codex CLI with valid auth) --- + +func TestCodexCliProvider_Integration(t *testing.T) { + if os.Getenv("PICOCLAW_INTEGRATION_TESTS") == "" { + t.Skip("skipping integration test (set PICOCLAW_INTEGRATION_TESTS=1 to enable)") + } + + // Verify codex is available + codexPath, err := exec.LookPath("codex") + if err != nil { + t.Skip("codex CLI not found in PATH") + } + + p := &CodexCliProvider{ + command: codexPath, + workspace: "", + } + + messages := []Message{ + {Role: "user", Content: "Respond with just the word 'hello' and nothing else."}, + } + + resp, err := p.Chat(context.Background(), messages, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + + lower := strings.ToLower(strings.TrimSpace(resp.Content)) + if !strings.Contains(lower, "hello") { + t.Errorf("Content = %q, expected to contain 'hello'", resp.Content) + } +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index c0b10bd5b..6dff3a52e 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -3,6 +3,7 @@ package providers import ( "context" "encoding/json" + "errors" "fmt" "strings" @@ -10,8 +11,12 @@ import ( "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/logger" ) +const codexDefaultModel = "gpt-5.2" +const codexDefaultInstructions = "You are Codex, a coding assistant." + type CodexProvider struct { client *openai.Client accountID string @@ -24,6 +29,8 @@ func NewCodexProvider(token, accountID string) *CodexProvider { opts := []option.RequestOption{ option.WithBaseURL("https://chatgpt.com/backend-api/codex"), option.WithAPIKey(token), + option.WithHeader("originator", "codex_cli_rs"), + option.WithHeader("OpenAI-Beta", "responses=experimental"), } if accountID != "" { opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) @@ -43,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { var opts []option.RequestOption + accountID := p.accountID + resolvedModel, fallbackReason := resolveCodexModel(model) + if fallbackReason != "" { + logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{ + "requested_model": model, + "resolved_model": resolvedModel, + "reason": fallbackReason, + }) + } if p.tokenSource != nil { tok, accID, err := p.tokenSource() if err != nil { @@ -50,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To } opts = append(opts, option.WithAPIKey(tok)) if accID != "" { - opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + accountID = accID } } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } else { + logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{ + "requested_model": model, + "resolved_model": resolvedModel, + }) + } - params := buildCodexParams(messages, tools, model, options) + params := buildCodexParams(messages, tools, resolvedModel, options) - resp, err := p.client.Responses.New(ctx, params, opts...) + stream := p.client.Responses.NewStreaming(ctx, params, opts...) + defer stream.Close() + + var resp *responses.Response + for stream.Next() { + evt := stream.Current() + if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" { + evtResp := evt.Response + if evtResp.ID != "" { + copy := evtResp + resp = © + } + } + } + err := stream.Err() if err != nil { + fields := map[string]interface{}{ + "requested_model": model, + "resolved_model": resolvedModel, + "messages_count": len(messages), + "tools_count": len(tools), + "account_id_present": accountID != "", + "error": err.Error(), + } + var apiErr *openai.Error + if errors.As(err, &apiErr) { + fields["status_code"] = apiErr.StatusCode + fields["api_type"] = apiErr.Type + fields["api_code"] = apiErr.Code + fields["api_param"] = apiErr.Param + fields["api_message"] = apiErr.Message + if apiErr.StatusCode == 400 { + fields["hint"] = "verify account id header and model compatibility for codex backend" + } + if apiErr.Response != nil { + fields["request_id"] = apiErr.Response.Header.Get("x-request-id") + } + } + logger.ErrorCF("provider.codex", "Codex API call failed", fields) return nil, fmt.Errorf("codex API call: %w", err) } + if resp == nil { + fields := map[string]interface{}{ + "requested_model": model, + "resolved_model": resolvedModel, + "messages_count": len(messages), + "tools_count": len(tools), + "account_id_present": accountID != "", + } + logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields) + return nil, fmt.Errorf("codex API call: stream ended without completed response") + } return parseCodexResponse(resp), nil } func (p *CodexProvider) GetDefaultModel() string { - return "gpt-4o" + return codexDefaultModel +} + +func resolveCodexModel(model string) (string, string) { + m := strings.ToLower(strings.TrimSpace(model)) + if m == "" { + return codexDefaultModel, "empty model" + } + + if strings.HasPrefix(m, "openai/") { + m = strings.TrimPrefix(m, "openai/") + } else if strings.Contains(m, "/") { + return codexDefaultModel, "non-openai model namespace" + } + + unsupportedPrefixes := []string{ + "glm", + "claude", + "anthropic", + "gemini", + "google", + "moonshot", + "kimi", + "qwen", + "deepseek", + "llama", + "meta-llama", + "mistral", + "grok", + "xai", + "zhipu", + } + for _, prefix := range unsupportedPrefixes { + if strings.HasPrefix(m, prefix) { + return codexDefaultModel, "unsupported model prefix" + } + } + + if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") { + return m, "" + } + + return codexDefaultModel, "unsupported model family" } func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { @@ -135,7 +249,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, Input: responses.ResponseNewParamsInputUnion{ OfInputItemList: inputItems, }, - Store: openai.Opt(false), + Instructions: openai.Opt(instructions), + Store: openai.Opt(false), } if instructions != "" { @@ -149,10 +264,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, params.MaxOutputTokens = openai.Opt(int64(maxTokens)) } - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = openai.Opt(temp) - } - if len(tools) > 0 { params.Tools = translateToolsForCodex(tools) } @@ -242,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) { if err != nil { return "", "", fmt.Errorf("refreshing token: %w", err) } + if refreshed.AccountID == "" { + refreshed.AccountID = cred.AccountID + } if err := auth.SetCredential("openai", refreshed); err != nil { return "", "", fmt.Errorf("saving refreshed token: %w", err) } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 1a5a8cafa..317b1a5de 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -2,6 +2,7 @@ package providers import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -16,7 +17,8 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) { {Role: "user", Content: "Hello"}, } params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ - "max_tokens": 2048, + "max_tokens": 2048, + "temperature": 0.7, }) if params.Model != "gpt-4o" { t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") @@ -203,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { return } + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + resp := map[string]interface{}{ "id": "resp_test", "object": "response", @@ -226,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, }, } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + writeCompletedSSE(w, resp) })) defer server.Close() @@ -250,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { } } +func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if _, ok := reqBody["instructions"]; !ok { + http.Error(w, "missing instructions", http.StatusBadRequest) + return + } + if reqBody["instructions"] == "" { + http.Error(w, "instructions must not be empty", http.StatusBadRequest) + return + } + if _, ok := reqBody["temperature"]; ok { + http.Error(w, "temperature is not supported", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 8, + "output_tokens": 4, + "total_tokens": 12, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("stale-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "stale-token", "") + provider.tokenSource = func() (string, string, error) { + return "refreshed-token", "", nil + } + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + +func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["model"] != codexDefaultModel { + http.Error(w, "unsupported model", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + if reqBody["instructions"] != codexDefaultInstructions { + http.Error(w, "missing default instructions", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 8, + "output_tokens": 4, + "total_tokens": 12, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + func TestCodexProvider_GetDefaultModel(t *testing.T) { p := NewCodexProvider("test-token", "") - if got := p.GetDefaultModel(); got != "gpt-4o" { - t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + if got := p.GetDefaultModel(); got != codexDefaultModel { + t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel) + } +} + +func TestResolveCodexModel(t *testing.T) { + tests := []struct { + name string + input string + wantModel string + wantFallback bool + }{ + {name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true}, + {name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true}, + {name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true}, + {name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false}, + {name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotModel, reason := resolveCodexModel(tt.input) + if gotModel != tt.wantModel { + t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel) + } + if tt.wantFallback && reason == "" { + t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input) + } + if !tt.wantFallback && reason != "" { + t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason) + } + }) } } @@ -268,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { c := openai.NewClient(opts...) return &c } + +func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) { + event := map[string]interface{}{ + "type": "response.completed", + "sequence_number": 1, + "response": response, + } + b, _ := json.Marshal(event) + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "event: response.completed\n") + fmt.Fprintf(w, "data: %s\n\n", string(b)) + fmt.Fprintf(w, "data: [DONE]\n\n") +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 17eb6214c..4cf2c6db2 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -53,10 +53,10 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too return nil, fmt.Errorf("API base not configured") } - // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5) + // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5, groq/openai/gpt-oss-120b -> openai/gpt-oss-120b, ollama/qwen2.5:14b -> qwen2.5:14b) if idx := strings.Index(model, "/"); idx != -1 { prefix := model[:idx] - if prefix == "moonshot" || prefix == "nvidia" { + if prefix == "moonshot" || prefix == "nvidia" || prefix == "groq" || prefix == "ollama" { model = model[idx+1:] } } @@ -240,6 +240,9 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { } case "openai", "gpt": if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "codex-cli" { + return NewCodexProviderWithTokenSource("", "", CreateCodexCliTokenSource()), nil + } if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { return createCodexAuthProvider() } @@ -299,11 +302,17 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { } } case "claude-cli", "claudecode", "claude-code": - workspace := cfg.Agents.Defaults.Workspace + workspace := cfg.WorkspacePath() if workspace == "" { workspace = "." } return NewClaudeCliProvider(workspace), nil + case "codex-cli", "codex-code": + workspace := cfg.WorkspacePath() + if workspace == "" { + workspace = "." + } + return NewCodexCliProvider(workspace), nil case "deepseek": if cfg.Providers.DeepSeek.APIKey != "" { apiKey = cfg.Providers.DeepSeek.APIKey @@ -400,7 +409,15 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { if apiBase == "" { apiBase = "https://integrate.api.nvidia.com/v1" } - + case (strings.Contains(lowerModel, "ollama") || strings.HasPrefix(model, "ollama/")) && cfg.Providers.Ollama.APIKey != "": + fmt.Println("Ollama provider selected based on model name prefix") + apiKey = cfg.Providers.Ollama.APIKey + apiBase = cfg.Providers.Ollama.APIBase + proxy = cfg.Providers.Ollama.Proxy + if apiBase == "" { + apiBase = "http://localhost:11434/v1" + } + fmt.Println("Ollama apiBase:", apiBase) case cfg.Providers.VLLM.APIBase != "": apiKey = cfg.Providers.VLLM.APIKey apiBase = cfg.Providers.VLLM.APIBase diff --git a/pkg/providers/tool_call_extract.go b/pkg/providers/tool_call_extract.go new file mode 100644 index 000000000..97a219283 --- /dev/null +++ b/pkg/providers/tool_call_extract.go @@ -0,0 +1,72 @@ +package providers + +import ( + "encoding/json" + "strings" +) + +// extractToolCallsFromText parses tool call JSON from response text. +// Both ClaudeCliProvider and CodexCliProvider use this to extract +// tool calls that the model outputs in its response text. +func extractToolCallsFromText(text string) []ToolCall { + start := strings.Index(text, `{"tool_calls"`) + if start == -1 { + return nil + } + + end := findMatchingBrace(text, start) + if end == start { + return nil + } + + jsonStr := text[start:end] + + var wrapper struct { + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } + + if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil { + return nil + } + + var result []ToolCall + for _, tc := range wrapper.ToolCalls { + var args map[string]interface{} + json.Unmarshal([]byte(tc.Function.Arguments), &args) + + result = append(result, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Name: tc.Function.Name, + Arguments: args, + Function: &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + return result +} + +// stripToolCallsFromText removes tool call JSON from response text. +func stripToolCallsFromText(text string) string { + start := strings.Index(text, `{"tool_calls"`) + if start == -1 { + return text + } + + end := findMatchingBrace(text, start) + if end == start { + return text + } + + return strings.TrimSpace(text[:start] + text[end:]) +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 9981d4901..12bf33df0 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -264,3 +264,19 @@ func (sm *SessionManager) loadSessions() error { return nil } + +// SetHistory updates the messages of a session. +func (sm *SessionManager) SetHistory(key string, history []providers.Message) { + sm.mu.Lock() + defer sm.mu.Unlock() + + session, ok := sm.sessions[key] + if ok { + // Create a deep copy to strictly isolate internal state + // from the caller's slice. + msgs := make([]providers.Message, len(history)) + copy(msgs, history) + session.Messages = msgs + session.Updated = time.Now() + } +} diff --git a/pkg/skills/loader.go b/pkg/skills/loader.go index 1f952c1f5..0c63ae067 100644 --- a/pkg/skills/loader.go +++ b/pkg/skills/loader.go @@ -2,13 +2,22 @@ package skills import ( "encoding/json" + "errors" "fmt" + "log/slog" "os" "path/filepath" "regexp" "strings" ) +var namePattern = regexp.MustCompile(`^[a-zA-Z0-9]+(-[a-zA-Z0-9]+)*$`) + +const ( + MaxNameLength = 64 + MaxDescriptionLength = 1024 +) + type SkillMetadata struct { Name string `json:"name"` Description string `json:"description"` @@ -21,6 +30,27 @@ type SkillInfo struct { Description string `json:"description"` } +func (info SkillInfo) validate() error { + var errs error + if info.Name == "" { + errs = errors.Join(errs, errors.New("name is required")) + } else { + if len(info.Name) > MaxNameLength { + errs = errors.Join(errs, fmt.Errorf("name exceeds %d characters", MaxNameLength)) + } + if !namePattern.MatchString(info.Name) { + errs = errors.Join(errs, errors.New("name must be alphanumeric with hyphens")) + } + } + + if info.Description == "" { + errs = errors.Join(errs, errors.New("description is required")) + } else if len(info.Description) > MaxDescriptionLength { + errs = errors.Join(errs, fmt.Errorf("description exceeds %d character", MaxDescriptionLength)) + } + return errs +} + type SkillsLoader struct { workspace string workspaceSkills string // workspace skills (้กน็›ฎ็บงๅˆซ) @@ -54,6 +84,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo { metadata := sl.getSkillMetadata(skillFile) if metadata != nil { info.Description = metadata.Description + info.Name = metadata.Name + } + if err := info.validate(); err != nil { + slog.Warn("invalid skill from workspace", "name", info.Name, "error", err) + continue } skills = append(skills, info) } @@ -89,6 +124,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo { metadata := sl.getSkillMetadata(skillFile) if metadata != nil { info.Description = metadata.Description + info.Name = metadata.Name + } + if err := info.validate(); err != nil { + slog.Warn("invalid skill from global", "name", info.Name, "error", err) + continue } skills = append(skills, info) } @@ -123,6 +163,11 @@ func (sl *SkillsLoader) ListSkills() []SkillInfo { metadata := sl.getSkillMetadata(skillFile) if metadata != nil { info.Description = metadata.Description + info.Name = metadata.Name + } + if err := info.validate(); err != nil { + slog.Warn("invalid skill from builtin", "name", info.Name, "error", err) + continue } skills = append(skills, info) } diff --git a/pkg/skills/loader_test.go b/pkg/skills/loader_test.go new file mode 100644 index 000000000..e0e7109cf --- /dev/null +++ b/pkg/skills/loader_test.go @@ -0,0 +1,77 @@ +package skills + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSkillsInfoValidate(t *testing.T) { + testcases := []struct { + name string + skillName string + description string + wantErr bool + errContains []string + }{ + { + name: "valid-skill", + skillName: "valid-skill", + description: "a valid skill description", + wantErr: false, + }, + { + name: "empty-name", + skillName: "", + description: "description without name", + wantErr: true, + errContains: []string{"name is required"}, + }, + { + name: "empty-description", + skillName: "skill-without-description", + description: "", + wantErr: true, + errContains: []string{"description is required"}, + }, + { + name: "empty-both", + skillName: "", + description: "", + wantErr: true, + errContains: []string{"name is required", "description is required"}, + }, + { + name: "name-with-spaces", + skillName: "skill with spaces", + description: "invalid name with spaces", + wantErr: true, + errContains: []string{"name must be alphanumeric with hyphens"}, + }, + { + name: "name-with-underscore", + skillName: "skill_underscore", + description: "invalid name with underscore", + wantErr: true, + errContains: []string{"name must be alphanumeric with hyphens"}, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + info := SkillInfo{ + Name: tc.skillName, + Description: tc.description, + } + err := info.validate() + if tc.wantErr { + assert.Error(t, err) + for _, msg := range tc.errContains { + assert.ErrorContains(t, err, msg) + } + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 0ef745e2b..4b6f973d8 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -28,12 +28,12 @@ type CronTool struct { } // NewCronTool creates a new CronTool -func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool { +func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool { return &CronTool{ cronService: cronService, executor: executor, msgBus: msgBus, - execTool: NewExecTool(workspace, false), + execTool: NewExecTool(workspace, restrict), } } diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 237687734..09063ea0a 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) { } } - if restrict && !strings.HasPrefix(absPath, absWorkspace) { - return "", fmt.Errorf("access denied: path is outside the workspace") + if restrict { + if !isWithinWorkspace(absPath, absWorkspace) { + return "", fmt.Errorf("access denied: path is outside the workspace") + } + + workspaceReal := absWorkspace + if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil { + workspaceReal = resolved + } + + if resolved, err := filepath.EvalSymlinks(absPath); err == nil { + if !isWithinWorkspace(resolved, workspaceReal) { + return "", fmt.Errorf("access denied: symlink resolves outside workspace") + } + } else if os.IsNotExist(err) { + if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil { + if !isWithinWorkspace(parentResolved, workspaceReal) { + return "", fmt.Errorf("access denied: symlink resolves outside workspace") + } + } else if !os.IsNotExist(err) { + return "", fmt.Errorf("failed to resolve path: %w", err) + } + } else { + return "", fmt.Errorf("failed to resolve path: %w", err) + } } return absPath, nil } +func resolveExistingAncestor(path string) (string, error) { + for current := filepath.Clean(path); ; current = filepath.Dir(current) { + if resolved, err := filepath.EvalSymlinks(current); err == nil { + return resolved, nil + } else if !os.IsNotExist(err) { + return "", err + } + if filepath.Dir(current) == current { + return "", os.ErrNotExist + } + } +} + +func isWithinWorkspace(candidate, workspace string) bool { + rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate)) + return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + type ReadFileTool struct { workspace string restrict bool diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go index 2707f29b5..958036419 100644 --- a/pkg/tools/filesystem_test.go +++ b/pkg/tools/filesystem_test.go @@ -247,3 +247,35 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) } } + +// Block paths that look inside workspace but point outside via symlink. +func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) { + + root := t.TempDir() + workspace := filepath.Join(root, "workspace") + if err := os.MkdirAll(workspace, 0755); err != nil { + t.Fatalf("failed to create workspace: %v", err) + } + + secret := filepath.Join(root, "secret.txt") + if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil { + t.Fatalf("failed to write secret file: %v", err) + } + + link := filepath.Join(workspace, "leak.txt") + if err := os.Symlink(secret, link); err != nil { + t.Skipf("symlink not supported in this environment: %v", err) + } + + tool := NewReadFileTool(workspace, true) + result := tool.Execute(context.Background(), map[string]interface{}{ + "path": link, + }) + + if !result.IsError { + t.Fatalf("expected symlink escape to be blocked") + } + if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") { + t.Fatalf("expected symlink escape error, got: %s", result.ForLLM) + } +} diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 988eada16..a526ea34a 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -173,19 +173,23 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { } } -// TestWebTool_WebSearch_NoApiKey verifies that nil is returned when no provider is configured +// TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing func TestWebTool_WebSearch_NoApiKey(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "", BraveMaxResults: 5}) - - // Should return nil when no provider is enabled + tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) if tool != nil { - t.Errorf("Expected nil when no search provider is configured") + t.Errorf("Expected nil tool when Brave API key is empty") + } + + // Also nil when nothing is enabled + tool = NewWebSearchTool(WebSearchToolOptions{}) + if tool != nil { + t.Errorf("Expected nil tool when no provider is enabled") } } // TestWebTool_WebSearch_MissingQuery verifies error handling for missing query func TestWebTool_WebSearch_MissingQuery(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveAPIKey: "test-key", BraveMaxResults: 5, BraveEnabled: true}) + tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) ctx := context.Background() args := map[string]interface{}{} diff --git a/pkg/utils/media.go b/pkg/utils/media.go index 6345da8fc..2b184f2ec 100644 --- a/pkg/utils/media.go +++ b/pkg/utils/media.go @@ -73,9 +73,8 @@ func DownloadFile(url, filename string, opts DownloadOptions) string { } // Generate unique filename with UUID prefix to prevent conflicts - ext := filepath.Ext(filename) safeName := SanitizeFilename(filename) - localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext) + localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName) // Create HTTP request req, err := http.NewRequest("GET", url, nil)