Merge upstream/main into feat/subturn-poc

Includes JSONL session persistence (#1170), spawn_status tool, Azure provider,
credential encryption, and various fixes. SubTurn features preserved and
integrated with new spawn_status functionality.
This commit is contained in:
Administrator
2026-03-17 21:55:20 +08:00
110 changed files with 7413 additions and 1547 deletions
+27
View File
@@ -0,0 +1,27 @@
version: 2
updates:
# Go dependencies (entire repo)
- package-ecosystem: "gomod"
directory: "/"
schedule:
interval: "weekly"
labels:
- "dependencies"
- "go"
# Frontend dependencies
- package-ecosystem: "npm"
directory: "/web/frontend"
schedule:
interval: "weekly"
labels:
- "dependencies"
- "frontend"
# GitHub Actions
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
+4 -4
View File
@@ -31,11 +31,11 @@ jobs:
# ── Docker Buildx ─────────────────────────
- name: 🔧 Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4
# ── Login to GHCR ─────────────────────────
- name: 🔑 Login to GitHub Container Registry
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
registry: ${{ env.GHCR_REGISTRY }}
username: ${{ github.actor }}
@@ -43,7 +43,7 @@ jobs:
# ── Login to Docker Hub ────────────────────
- name: 🔑 Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
registry: ${{ env.DOCKERHUB_REGISTRY }}
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -62,7 +62,7 @@ jobs:
# ── Build & Push ──────────────────────────
- name: 🚀 Build and push Docker image
uses: docker/build-push-action@v6
uses: docker/build-push-action@v7
with:
context: .
push: true
+4 -4
View File
@@ -48,7 +48,7 @@ jobs:
go-version-file: go.mod
- name: Setup Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: 22
@@ -59,17 +59,17 @@ jobs:
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
+1 -1
View File
@@ -34,7 +34,7 @@ jobs:
persist-credentials: false
- name: Setup Go
uses: actions/setup-go@v5
uses: actions/setup-go@v6
with:
go-version-file: go.mod
+4 -4
View File
@@ -66,7 +66,7 @@ jobs:
go-version-file: go.mod
- name: Setup Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: 22
@@ -77,17 +77,17 @@ jobs:
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4
- name: Login to GitHub Container Registry
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v4
with:
registry: docker.io
username: ${{ secrets.DOCKERHUB_USERNAME }}
+33 -28
View File
@@ -12,10 +12,11 @@ GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
CONFIG_PKG=github.com/sipeed/picoclaw/pkg/config
LDFLAGS=-ldflags "-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COMMIT) -X $(CONFIG_PKG).BuildTime=$(BUILD_TIME) -X $(CONFIG_PKG).GoVersion=$(GO_VERSION) -s -w"
LDFLAGS=-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COMMIT) -X $(CONFIG_PKG).BuildTime=$(BUILD_TIME) -X $(CONFIG_PKG).GoVersion=$(GO_VERSION) -s -w
# Go variables
GO?=CGO_ENABLED=0 go
WEB_GO?=$(GO)
GOFLAGS?=-v -tags stdjson
# Patch MIPS LE ELF e_flags (offset 36) for NaN2008-only kernels (e.g. Ingenic X2600).
@@ -79,6 +80,7 @@ ifeq ($(UNAME_S),Linux)
endif
else ifeq ($(UNAME_S),Darwin)
PLATFORM=darwin
WEB_GO=CGO_ENABLED=1 go
ifeq ($(UNAME_M),x86_64)
ARCH=amd64
else ifeq ($(UNAME_M),arm64)
@@ -107,7 +109,7 @@ generate:
build: generate
@echo "Building $(BINARY_NAME) for $(PLATFORM)/$(ARCH)..."
@mkdir -p $(BUILD_DIR)
@$(GO) build $(GOFLAGS) $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR)
@$(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH) ./$(CMD_DIR)
@echo "Build complete: $(BINARY_PATH)"
@ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME)
@@ -119,7 +121,7 @@ build-launcher:
echo "Building frontend..."; \
cd web/frontend && pnpm install && pnpm build:backend; \
fi
@$(GO) build $(GOFLAGS) -o $(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH) ./web/backend
@$(WEB_GO) build $(GOFLAGS) -o $(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH) ./web/backend
@ln -sf picoclaw-launcher-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/picoclaw-launcher
@echo "Build complete: $(BUILD_DIR)/picoclaw-launcher"
@@ -128,16 +130,16 @@ build-whatsapp-native: generate
## @echo "Building $(BINARY_NAME) with WhatsApp native for $(PLATFORM)/$(ARCH)..."
@echo "Building for multiple platforms..."
@mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
GOOS=linux GOARCH=amd64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
$(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle)
GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
## @$(GO) build $(GOFLAGS) -tags whatsapp_native $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR)
GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
## @$(GO) build $(GOFLAGS) -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BINARY_PATH) ./$(CMD_DIR)
@echo "Build complete"
## @ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME)
@@ -145,21 +147,21 @@ build-whatsapp-native: generate
build-linux-arm: generate
@echo "Building for linux/arm (GOARM=7)..."
@mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm"
## build-linux-arm64: Build for Linux ARM64 (e.g. Raspberry Pi Zero 2 W 64-bit)
build-linux-arm64: generate
@echo "Building for linux/arm64..."
@mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64"
## build-linux-mipsle: Build for Linux MIPS32 LE
build-linux-mipsle: generate
@echo "Building for linux/mipsle (softfloat)..."
@mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
$(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle)
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle"
@@ -171,18 +173,18 @@ build-pi-zero: build-linux-arm build-linux-arm64
build-all: generate
@echo "Building for multiple platforms..."
@mkdir -p $(BUILD_DIR)
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
GOOS=linux GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=loong64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
$(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR)
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
GOOS=netbsd GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-amd64 ./$(CMD_DIR)
GOOS=netbsd GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR)
GOOS=darwin GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
GOOS=netbsd GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-amd64 ./$(CMD_DIR)
GOOS=netbsd GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-arm64 ./$(CMD_DIR)
@echo "All builds complete"
## install: Install picoclaw to system and copy builtin skills
@@ -219,11 +221,14 @@ clean:
## vet: Run go vet for static analysis
vet: generate
@$(GO) vet ./...
@packages="$$(go list ./...)" && \
$(GO) vet $$(printf '%s\n' "$$packages" | grep -v '^github.com/sipeed/picoclaw/web/')
@cd web/backend && $(WEB_GO) vet ./...
## test: Test Go code
test: generate
@$(GO) test ./...
@$(GO) test $$(go list ./... | grep -v github.com/sipeed/picoclaw/web/)
@cd web && make test
## fmt: Format Go code
fmt:
+4 -1
View File
@@ -23,7 +23,9 @@
---
🦐 **PicoClaw** est un assistant personnel IA ultra-léger inspiré de [nanobot](https://github.com/HKUDS/nanobot), entièrement écrit en **Go** via un processus d'auto-amorçage (self-bootstrapping) — où l'agent IA lui-même a piloté l'intégralité de la migration architecturale et de l'optimisation du code.
> **PicoClaw** est un projet open-source indépendant initié par [Sipeed](https://sipeed.com). Il est entièrement écrit en **Go** — ce n'est pas un fork d'OpenClaw, de NanoBot ou de tout autre projet.
🦐 **PicoClaw** est un assistant personnel IA ultra-léger inspiré de [NanoBot](https://github.com/HKUDS/nanobot), entièrement réécrit en **Go** via un processus d'auto-amorçage (self-bootstrapping) — où l'agent IA lui-même a piloté l'intégralité de la migration architecturale et de l'optimisation du code.
⚡️ **Extrêmement léger :** Fonctionne sur du matériel à seulement **10$** avec **<10 Mo** de RAM. C'est 99% de mémoire en moins qu'OpenClaw et 98% moins cher qu'un Mac mini !
@@ -991,6 +993,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obtenir Clé](https://www.byteplus.com/) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obtenir un Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obtenir Clé](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+4 -1
View File
@@ -26,7 +26,9 @@
---
🦐 PicoClaw は [nanobot](https://github.com/HKUDS/nanobot) にインスパイアされた超軽量パーソナル AI アシスタントです。Go でゼロからリファクタリングされ、AI エージェント自身がアーキテクチャの移行とコード最適化を推進するセルフブートストラッピングプロセスで構築されました
> **PicoClaw** は [Sipeed](https://sipeed.com) が立ち上げた独立したオープンソースプロジェクトです。完全に **Go 言語**で一から書かれており、OpenClaw、NanoBot、その他のプロジェクトのフォークではありません
🦐 PicoClaw は [NanoBot](https://github.com/HKUDS/nanobot) にインスパイアされた超軽量パーソナル AI アシスタントです。Go でゼロからリファクタリングされ、AI エージェント自身がアーキテクチャの移行とコード最適化を推進するセルフブートストラッピングプロセスで構築されました。
⚡️ $10 のハードウェアで 10MB 未満の RAM で動作:OpenClaw より 99% 少ないメモリ、Mac mini より 98% 安い!
@@ -935,6 +937,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [キーを取得](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [トークンを取得](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [キーを取得](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+20 -1
View File
@@ -24,7 +24,9 @@
---
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
> **PicoClaw** is an independent open-source project initiated by [Sipeed](https://sipeed.com). It is written entirely in **Go** — not a fork of OpenClaw, NanoBot, or any other project.
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [NanoBot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini!
@@ -861,6 +863,21 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous
* `shutdown`, `reboot`, `poweroff` — System shutdown
* Fork bomb `:(){ :|:& };:`
#### Known Limitation: Child Processes From Build Tools
The exec safety guard only inspects the command line PicoClaw launches directly. It does not recursively inspect child
processes spawned by allowed developer tools such as `make`, `go run`, `cargo`, `npm run`, or custom build scripts.
That means a top-level command can still compile or launch other binaries after it passes the initial guard check. In
practice, treat build scripts, Makefiles, package scripts, and generated binaries as executable code that needs the same
level of review as a direct shell command.
For higher-risk environments:
* Review build scripts before execution.
* Prefer approval/manual review for compile-and-run workflows.
* Run PicoClaw inside a container or VM if you need stronger isolation than the built-in guard provides.
#### Error Examples
```
@@ -1006,6 +1023,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) |
| `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) |
| `azure` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
### Model Configuration (model_list)
@@ -1042,6 +1060,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Get Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Get Key](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+4 -1
View File
@@ -23,7 +23,9 @@
---
🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [nanobot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código.
> **PicoClaw** é um projeto open-source independente iniciado pela [Sipeed](https://sipeed.com). É escrito inteiramente em **Go** — não é um fork do OpenClaw, NanoBot ou qualquer outro projeto.
🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [NanoBot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código.
⚡️ **Extremamente leve:** Roda em hardware de apenas **$10** com **<10MB** de RAM. Isso é 99% menos memória que o OpenClaw e 98% mais barato que um Mac mini!
@@ -987,6 +989,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obter Chave](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obter Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obter Chave](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+4 -1
View File
@@ -23,7 +23,9 @@
---
🦐 **PicoClaw**trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [nanobot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn.
> **PicoClaw**dự án mã nguồn mở độc lập được khởi xướng bởi [Sipeed](https://sipeed.com). Được viết hoàn toàn bằng **Go** — không phải là bản fork của OpenClaw, NanoBot hay bất kỳ dự án nào khác.
🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [NanoBot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn.
⚡️ **Cực kỳ nhẹ:** Chạy trên phần cứng chỉ **$10** với RAM **<10MB**. Tiết kiệm 99% bộ nhớ so với OpenClaw và rẻ hơn 98% so với Mac mini!
@@ -956,6 +958,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Lấy Khóa](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Lấy Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Lấy Khóa](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+4 -1
View File
@@ -24,7 +24,9 @@
---
🦐 **PicoClaw**一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采**Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化
> **PicoClaw**由 [矽速科技 (Sipeed)](https://sipeed.com) 发起的独立开源项目,完全使用 **Go 语言**从零编写——不是 OpenClaw、NanoBot 或其他项目的分支
🦐 **PicoClaw** 是一个受 [NanoBot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。
⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%
@@ -528,6 +530,7 @@ Agent 读取 HEARTBEAT.md
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [获取密钥](https://www.byteplus.com) |
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) |
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [获取 Token](https://modelscope.cn/my/tokens) |
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [获取密钥](https://portal.azure.com) |
| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
+11 -1
View File
@@ -5,6 +5,8 @@ import (
"github.com/spf13/cobra"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/gateway"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -12,6 +14,7 @@ import (
func NewGatewayCommand() *cobra.Command {
var debug bool
var noTruncate bool
var allowEmpty bool
cmd := &cobra.Command{
Use: "gateway",
@@ -31,12 +34,19 @@ func NewGatewayCommand() *cobra.Command {
return nil
},
RunE: func(_ *cobra.Command, _ []string) error {
return gatewayCmd(debug)
return gateway.Run(debug, internal.GetConfigPath(), allowEmpty)
},
}
cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
cmd.Flags().BoolVarP(&noTruncate, "no-truncate", "T", false, "Disable string truncation in debug logs")
cmd.Flags().BoolVarP(
&allowEmpty,
"allow-empty",
"E",
false,
"Continue starting even when no default model is configured",
)
return cmd
}
@@ -28,4 +28,5 @@ func TestNewGatewayCommand(t *testing.T) {
assert.True(t, cmd.HasFlags())
assert.NotNil(t, cmd.Flags().Lookup("debug"))
assert.NotNil(t, cmd.Flags().Lookup("allow-empty"))
}
+6 -1
View File
@@ -11,14 +11,19 @@ import (
var embeddedFiles embed.FS
func NewOnboardCommand() *cobra.Command {
var encrypt bool
cmd := &cobra.Command{
Use: "onboard",
Aliases: []string{"o"},
Short: "Initialize picoclaw configuration and workspace",
Run: func(cmd *cobra.Command, args []string) {
onboard()
onboard(encrypt)
},
}
cmd.Flags().BoolVar(&encrypt, "enc", false,
"Enable credential encryption (generates SSH key and prompts for passphrase)")
return cmd
}
@@ -24,6 +24,9 @@ func TestNewOnboardCommand(t *testing.T) {
assert.Nil(t, cmd.PersistentPreRun)
assert.Nil(t, cmd.PersistentPostRun)
assert.False(t, cmd.HasFlags())
assert.True(t, cmd.HasFlags())
encFlag := cmd.Flags().Lookup("enc")
require.NotNil(t, encFlag, "expected --enc flag to be registered")
assert.Equal(t, "false", encFlag.DefValue, "--enc should default to false")
assert.False(t, cmd.HasSubCommands())
}
+121 -12
View File
@@ -6,25 +6,71 @@ import (
"os"
"path/filepath"
"golang.org/x/term"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/credential"
)
func onboard() {
func onboard(encrypt bool) {
configPath := internal.GetConfigPath()
configExists := false
if _, err := os.Stat(configPath); err == nil {
fmt.Printf("Config already exists at %s\n", configPath)
fmt.Print("Overwrite? (y/n): ")
var response string
fmt.Scanln(&response)
if response != "y" {
fmt.Println("Aborted.")
return
configExists = true
if encrypt {
// Only ask for confirmation when *both* config and SSH key already exist,
// indicating a full re-onboard that would reset the config to defaults.
sshKeyPath, _ := credential.DefaultSSHKeyPath()
if _, err := os.Stat(sshKeyPath); err == nil {
// Both exist — confirm a full reset.
fmt.Printf("Config already exists at %s\n", configPath)
fmt.Print("Overwrite config with defaults? (y/n): ")
var response string
fmt.Scanln(&response)
if response != "y" {
fmt.Println("Aborted.")
return
}
configExists = false // user agreed to reset; treat as fresh
}
// Config exists but SSH key is missing — keep existing config, only add SSH key.
}
}
cfg := config.DefaultConfig()
var err error
if encrypt {
fmt.Println("\nSet up credential encryption")
fmt.Println("-----------------------------")
passphrase, pErr := promptPassphrase()
if pErr != nil {
fmt.Printf("Error: %v\n", pErr)
os.Exit(1)
}
// Expose the passphrase to credential.PassphraseProvider (which calls
// os.Getenv by default) so that SaveConfig can encrypt api_keys.
// This process is a one-shot CLI tool; the env var is never exposed outside
// the current process and disappears when it exits.
os.Setenv(credential.PassphraseEnvVar, passphrase)
if err = setupSSHKey(); err != nil {
fmt.Printf("Error generating SSH key: %v\n", err)
os.Exit(1)
}
}
var cfg *config.Config
if configExists {
// Preserve the existing config; SaveConfig will re-encrypt api_keys with the new passphrase.
cfg, err = config.LoadConfig(configPath)
if err != nil {
fmt.Printf("Error loading existing config: %v\n", err)
os.Exit(1)
}
} else {
cfg = config.DefaultConfig()
}
if err := config.SaveConfig(configPath, cfg); err != nil {
fmt.Printf("Error saving config: %v\n", err)
os.Exit(1)
@@ -33,9 +79,17 @@ func onboard() {
workspace := cfg.WorkspacePath()
createWorkspaceTemplates(workspace)
fmt.Printf("%s picoclaw is ready!\n", internal.Logo)
fmt.Printf("\n%s picoclaw is ready!\n", internal.Logo)
fmt.Println("\nNext steps:")
fmt.Println(" 1. Add your API key to", configPath)
if encrypt {
fmt.Println(" 1. Set your encryption passphrase before starting picoclaw:")
fmt.Println(" export PICOCLAW_KEY_PASSPHRASE=<your-passphrase> # Linux/macOS")
fmt.Println(" set PICOCLAW_KEY_PASSPHRASE=<your-passphrase> # Windows cmd")
fmt.Println("")
fmt.Println(" 2. Add your API key to", configPath)
} else {
fmt.Println(" 1. Add your API key to", configPath)
}
fmt.Println("")
fmt.Println(" Recommended:")
fmt.Println(" - OpenRouter: https://openrouter.ai/keys (access 100+ models)")
@@ -43,7 +97,62 @@ func onboard() {
fmt.Println("")
fmt.Println(" See README.md for 17+ supported providers.")
fmt.Println("")
fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"")
fmt.Println(" 3. Chat: picoclaw agent -m \"Hello!\"")
}
// promptPassphrase reads the encryption passphrase twice from the terminal
// (with echo disabled) and returns it. Returns an error if the passphrase is
// empty or if the two inputs do not match.
func promptPassphrase() (string, error) {
fmt.Print("Enter passphrase for credential encryption: ")
p1, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
if err != nil {
return "", fmt.Errorf("reading passphrase: %w", err)
}
if len(p1) == 0 {
return "", fmt.Errorf("passphrase must not be empty")
}
fmt.Print("Confirm passphrase: ")
p2, err := term.ReadPassword(int(os.Stdin.Fd()))
fmt.Println()
if err != nil {
return "", fmt.Errorf("reading passphrase confirmation: %w", err)
}
if string(p1) != string(p2) {
return "", fmt.Errorf("passphrases do not match")
}
return string(p1), nil
}
// setupSSHKey generates the picoclaw-specific SSH key at ~/.ssh/picoclaw_ed25519.key.
// If the key already exists the user is warned and asked to confirm overwrite.
// Answering anything other than "y" keeps the existing key (not an error).
func setupSSHKey() error {
keyPath, err := credential.DefaultSSHKeyPath()
if err != nil {
return fmt.Errorf("cannot determine SSH key path: %w", err)
}
if _, err := os.Stat(keyPath); err == nil {
fmt.Printf("\n⚠️ WARNING: %s already exists.\n", keyPath)
fmt.Println(" Overwriting will invalidate any credentials previously encrypted with this key.")
fmt.Print(" Overwrite? (y/n): ")
var response string
fmt.Scanln(&response)
if response != "y" {
fmt.Println("Keeping existing SSH key.")
return nil
}
}
if err := credential.GenerateSSHKey(keyPath); err != nil {
return err
}
fmt.Printf("SSH key generated: %s\n", keyPath)
return nil
}
func createWorkspaceTemplates(workspace string) {
+8 -1
View File
@@ -53,6 +53,12 @@
"api_key": "your-modelscope-access-token",
"api_base": "https://api-inference.modelscope.cn/v1"
},
{
"model_name": "azure-gpt5",
"model": "azure/my-gpt5-deployment",
"api_key": "your-azure-api-key",
"api_base": "https://your-resource.openai.azure.com"
},
{
"model_name": "loadbalanced-gpt-5.4",
"model": "openai/gpt-5.4",
@@ -512,6 +518,7 @@
},
"gateway": {
"host": "127.0.0.1",
"port": 18790
"port": 18790,
"hot_reload": false
}
}
+168
View File
@@ -0,0 +1,168 @@
# Credential Encryption
PicoClaw supports encrypting `api_key` values in `model_list` configuration entries.
Encrypted keys are stored as `enc://<base64>` strings and decrypted automatically at startup.
---
## Quick Start
**1. Set your passphrase**
```bash
export PICOCLAW_KEY_PASSPHRASE="your-passphrase"
```
**2. Encrypt an API key**
Run `picoclaw onboard` — it prompts for your passphrase and generates the SSH key,
then automatically re-encrypts any plaintext `api_key` entries in your config on
the next `SaveConfig` call. The resulting `enc://` value will look like:
```
enc://AAAA...base64...
```
**3. Paste the output into your config**
```json
{
"model_list": [
{
"model_name": "gpt-4o",
"api_key": "enc://AAAA...base64...",
"base_url": "https://api.openai.com/v1"
}
]
}
```
---
## Supported `api_key` Formats
| Format | Example | Behaviour |
|--------|---------|-----------|
| Plaintext | `sk-abc123` | Used as-is |
| File reference | `file://openai.key` | Content read from the same directory as the config file |
| Encrypted | `enc://<base64>` | Decrypted at startup using `PICOCLAW_KEY_PASSPHRASE` |
| Empty | `""` | Passed through unchanged (used with `auth_method: oauth`) |
---
## Cryptographic Design
### Key Derivation
Encryption uses **HKDF-SHA256** with an optional SSH private key as a second factor.
```
Without SSH key (passphrase only):
ikm = SHA256(passphrase)
aes_key = HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
With SSH key (recommended):
sshHash = SHA256(ssh_private_key_file_bytes)
ikm = HMAC-SHA256(key=sshHash, message=passphrase)
aes_key = HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
```
### Encryption
```
AES-256-GCM(key=aes_key, nonce=random[12], plaintext=api_key)
```
### Wire Format
```
enc://<base64( salt[16] + nonce[12] + ciphertext )>
```
| Field | Size | Description |
|-------|------|-------------|
| `salt` | 16 bytes | Random per encryption; fed into HKDF |
| `nonce` | 12 bytes | Random per encryption; AES-GCM IV |
| `ciphertext` | variable | AES-256-GCM ciphertext + 16-byte authentication tag |
The GCM authentication tag is appended to the ciphertext automatically. Any tampering causes decryption to fail with an error rather than returning corrupt plaintext.
### Performance
| Operation | Time (ARM Cortex-A) |
|-----------|---------------------|
| Key derivation (HKDF) | < 1 ms |
| AES-256-GCM decrypt | < 1 ms |
| **Total startup overhead** | **< 2 ms per key** |
---
## Two-Factor Security with SSH Key
When a SSH private key is provided, breaking the encryption requires **both**:
1. The **passphrase** (`PICOCLAW_KEY_PASSPHRASE`)
2. The **SSH private key file**
This means a leaked config file alone is not sufficient to recover the API key, even if the passphrase is weak. The SSH key contributes 256 bits of entropy (Ed25519) regardless of passphrase strength.
### Threat Model
| Attacker Has | Can Decrypt? |
|---|---|
| Config file only | No — needs passphrase + SSH key |
| SSH key only | No — needs passphrase |
| Passphrase only | No — needs SSH key |
| Config file + SSH key + passphrase | Yes — full compromise |
---
## Environment Variables
| Variable | Required | Description |
|----------|----------|-------------|
| `PICOCLAW_KEY_PASSPHRASE` | Yes (for `enc://`) | Passphrase used for key derivation |
| `PICOCLAW_SSH_KEY_PATH` | No | Path to SSH private key. Set to `""` to disable auto-detection and use passphrase-only mode |
### SSH Key Auto-Detection
If `PICOCLAW_SSH_KEY_PATH` is not set, PicoClaw looks for the picoclaw-specific key:
```
~/.ssh/picoclaw_ed25519.key
```
This dedicated file avoids conflicts with the user's existing SSH keys.
Run `picoclaw onboard` to generate it automatically.
`os.UserHomeDir()` is used for cross-platform home directory resolution (reads `USERPROFILE` on Windows, `HOME` on Unix/macOS).
To explicitly disable SSH key usage and use passphrase-only mode:
```bash
export PICOCLAW_SSH_KEY_PATH=""
```
---
## Migration
Because the only secret material is `PICOCLAW_KEY_PASSPHRASE` and the SSH private key file, migration is straightforward:
1. Copy the config file to the new machine.
2. Set `PICOCLAW_KEY_PASSPHRASE` to the same value.
3. Copy the SSH private key file to the same path (or set `PICOCLAW_SSH_KEY_PATH` to its new location).
No re-encryption is needed.
---
## Security Considerations
- **Passphrase strength matters in passphrase-only mode.** Without an SSH key, a weak passphrase can be brute-forced offline. Use `PICOCLAW_SSH_KEY_PATH=""` only in environments where no SSH key is available and the passphrase is sufficiently strong (≥ 32 random characters).
- **The SSH key is read-only at runtime.** PicoClaw never writes to or modifies the SSH key file.
- **Plaintext keys remain supported.** Existing configs without `enc://` are unaffected.
- **The `enc://` format is versioned** via the HKDF `info` field (`picoclaw-credential-v1`), allowing future algorithm upgrades without breaking existing encrypted values.
+16
View File
@@ -84,6 +84,22 @@ By default, PicoClaw blocks the following dangerous commands:
- Git: `git push`, `git force`
- Other: `eval`, `source *.sh`
### Known Architectural Limitation
The exec guard only validates the top-level command sent to PicoClaw. It does **not** recursively inspect child
processes spawned by build tools or scripts after that command starts running.
Examples of workflows that can bypass the direct command guard once the initial command is allowed:
- `make run`
- `go run ./cmd/...`
- `cargo run`
- `npm run build`
This means the guard is useful for blocking obviously dangerous direct commands, but it is **not** a full sandbox for
unreviewed build pipelines. If your threat model includes untrusted code in the workspace, use stronger isolation such
as containers, VMs, or an approval flow around build-and-run commands.
### Configuration Example
```json
+10 -8
View File
@@ -3,10 +3,11 @@ module github.com/sipeed/picoclaw
go 1.25.7
require (
fyne.io/systray v1.12.0
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.22.1
github.com/anthropics/anthropic-sdk-go v1.26.0
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/caarlos0/env/v11 v11.4.0
github.com/ergochat/irc-go v0.5.0
github.com/ergochat/readline v0.1.3
github.com/gdamore/tcell/v2 v2.13.8
@@ -17,7 +18,7 @@ require (
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mdp/qrterminal/v3 v3.2.1
github.com/modelcontextprotocol/go-sdk v1.3.1
github.com/mymmrac/telego v1.6.0
github.com/mymmrac/telego v1.7.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
github.com/rivo/tview v0.42.0
@@ -27,7 +28,8 @@ require (
github.com/stretchr/testify v1.11.1
github.com/tencent-connect/botgo v0.2.1
go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4
golang.org/x/oauth2 v0.35.0
golang.org/x/oauth2 v0.36.0
golang.org/x/term v0.40.0
golang.org/x/time v0.14.0
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
@@ -43,6 +45,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
github.com/gdamore/encoding v1.0.1 // indirect
github.com/godbus/dbus/v5 v5.1.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
@@ -59,7 +62,6 @@ require (
go.mau.fi/libsignal v0.2.1 // indirect
go.mau.fi/util v0.9.6 // indirect
golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect
golang.org/x/term v0.40.0 // indirect
golang.org/x/text v0.34.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
@@ -73,7 +75,7 @@ require (
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/github/copilot-sdk/go v0.1.23
github.com/github/copilot-sdk/go v0.1.32
github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
@@ -87,10 +89,10 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
github.com/valyala/fastjson v1.6.10 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/crypto v0.48.0
golang.org/x/net v0.51.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
+19 -12
View File
@@ -1,6 +1,8 @@
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
fyne.io/systray v1.12.0 h1:CA1Kk0e2zwFlxtc02L3QFSiIbxJ/P0n582YrZHT7aTM=
fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
@@ -11,8 +13,8 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNg
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
@@ -23,8 +25,8 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/caarlos0/env/v11 v11.4.0 h1:Kcb6t5kIIr4XkoQC9AF2j+8E1Jsrl3Wz/hhm1LtoGAc=
github.com/caarlos0/env/v11 v11.4.0/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
@@ -38,6 +40,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elliotchance/orderedmap/v3 v3.1.0 h1:j4DJ5ObEmMBt/lcwIecKcoRxIQUEnw0L804lXYDt/pg=
@@ -52,8 +56,8 @@ github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uh
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw=
github.com/github/copilot-sdk/go v0.1.23/go.mod h1:GdwwBfMbm9AABLEM3x5IZKw4ZfwCYxZ1BgyytmZenQ0=
github.com/github/copilot-sdk/go v0.1.32 h1:wc9SFWwxXhJts6vyzzboPLJqcEJGnHE8rMCAY1RrUgo=
github.com/github/copilot-sdk/go v0.1.32/go.mod h1:qc2iEF7hdO8kzSvbyGvrcGhuk2fzdW4xTtT0+1EH2ts=
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
@@ -62,6 +66,8 @@ github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg78
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
@@ -136,8 +142,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI=
github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw=
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
github.com/mymmrac/telego v1.7.0 h1:yRO/l00tFGG4nY66ufUKb4ARqv7qx9+LsjQv/b0NEyo=
github.com/mymmrac/telego v1.7.0/go.mod h1:pdLV346EgVuq7Xrh3kMggeBiazeHhsdEoK0RTEOPXRM=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -216,8 +222,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM=
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADTh4=
github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE=
github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTdwFp0s=
github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
@@ -270,8 +276,8 @@ golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -354,6 +360,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+25 -2
View File
@@ -10,6 +10,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
@@ -66,7 +67,7 @@ func NewAgentInstance(
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
// Compile path whitelist patterns from config.
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
allowReadPaths := buildAllowReadPatterns(cfg)
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
toolsRegistry := tools.NewToolRegistry()
@@ -82,7 +83,7 @@ func NewAgentInstance(
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
}
if cfg.Tools.IsToolEnabled("exec") {
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
if err != nil {
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
}
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
return compiled
}
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
var configured []string
if cfg != nil {
configured = cfg.Tools.AllowReadPaths
}
compiled := compilePatterns(configured)
mediaDirPattern := regexp.MustCompile(mediaTempDirPattern())
for _, pattern := range compiled {
if pattern.String() == mediaDirPattern.String() {
return compiled
}
}
return append(compiled, mediaDirPattern)
}
func mediaTempDirPattern() string {
sep := regexp.QuoteMeta(string(os.PathSeparator))
return "^" + regexp.QuoteMeta(filepath.Clean(media.TempDir())) + "(?:" + sep + "|$)"
}
// Close releases resources held by the agent's session store.
func (a *AgentInstance) Close() error {
if a.Sessions != nil {
+86
View File
@@ -1,10 +1,14 @@
package agent
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
})
}
}
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
}
mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
if err != nil {
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
}
mediaPath := mediaFile.Name()
if _, err := mediaFile.WriteString("attachment content"); err != nil {
mediaFile.Close()
t.Fatalf("WriteString(mediaFile) error = %v", err)
}
if err := mediaFile.Close(); err != nil {
t.Fatalf("Close(mediaFile) error = %v", err)
}
t.Cleanup(func() { _ = os.Remove(mediaPath) })
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "test-model",
RestrictToWorkspace: true,
},
},
Tools: config.ToolsConfig{
ReadFile: config.ReadFileToolConfig{Enabled: true},
ListDir: config.ToolConfig{Enabled: true},
Exec: config.ExecConfig{
ToolConfig: config.ToolConfig{Enabled: true},
EnableDenyPatterns: true,
AllowRemote: true,
},
},
}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
readTool, ok := agent.Tools.Get("read_file")
if !ok {
t.Fatal("read_file tool not registered")
}
readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
if readResult.IsError {
t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
}
if !strings.Contains(readResult.ForLLM, "attachment content") {
t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
}
listTool, ok := agent.Tools.Get("list_dir")
if !ok {
t.Fatal("list_dir tool not registered")
}
listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
if listResult.IsError {
t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
}
if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
}
execTool, ok := agent.Tools.Get("exec")
if !ok {
t.Fatal("exec tool not registered")
}
execResult := execTool.Execute(context.Background(), map[string]any{
"command": "cat " + filepath.Base(mediaPath),
"working_dir": mediaDir,
})
if execResult.IsError {
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
}
if !strings.Contains(execResult.ForLLM, "attachment content") {
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
}
}
+70 -61
View File
@@ -124,6 +124,8 @@ func registerSharedTools(
registry *AgentRegistry,
provider providers.LLMProvider,
) {
allowReadPaths := buildAllowReadPatterns(cfg)
for _, agentID := range registry.ListAgentIDs() {
agent, ok := registry.GetAgent(agentID)
if !ok {
@@ -202,6 +204,7 @@ func registerSharedTools(
cfg.Agents.Defaults.RestrictToWorkspace,
cfg.Agents.Defaults.GetMaxMediaSize(),
nil,
allowReadPaths,
)
agent.Tools.Register(sendFileTool)
}
@@ -229,72 +232,75 @@ func registerSharedTools(
}
}
// Spawn tool with allowlist checker
if cfg.Tools.IsToolEnabled("spawn") {
if cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// Set the spawner that links into AgentLoop's turnState
subagentManager.SetSpawner(func(
ctx context.Context,
task, label, targetAgentID string,
tls *tools.ToolRegistry,
maxTokens int,
temperature float64,
hasMaxTokens, hasTemperature bool,
) (*tools.ToolResult, error) {
// 1. Recover parent Turn State from Context
parentTS := turnStateFromContext(ctx)
if parentTS == nil {
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
parentTS = &turnState{
ctx: ctx,
turnID: "adhoc-root",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
}
// 2. Build Tools slice from registry
var tlSlice []tools.Tool
for _, name := range tls.List() {
if t, ok := tls.Get(name); ok {
tlSlice = append(tlSlice, t)
}
}
// Spawn and spawn_status tools share a SubagentManager.
// Construct it when either tool is enabled (both require subagent).
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
// 3. System Prompt
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
"You have access to tools - use them as needed to complete your task.\n" +
"After completing the task, provide a clear summary of what was done.\n\n" +
"Task: " + task
// 4. Resolve Model
modelToUse := agent.Model
if targetAgentID != "" {
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
modelToUse = targetAgent.Model
}
// Set the spawner that links into AgentLoop's turnState
subagentManager.SetSpawner(func(
ctx context.Context,
task, label, targetAgentID string,
tls *tools.ToolRegistry,
maxTokens int,
temperature float64,
hasMaxTokens, hasTemperature bool,
) (*tools.ToolResult, error) {
// 1. Recover parent Turn State from Context
parentTS := turnStateFromContext(ctx)
if parentTS == nil {
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
parentTS = &turnState{
ctx: ctx,
turnID: "adhoc-root",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
}
// 5. Build SubTurnConfig
cfg := SubTurnConfig{
Model: modelToUse,
Tools: tlSlice,
SystemPrompt: systemPrompt,
// 2. Build Tools slice from registry
var tlSlice []tools.Tool
for _, name := range tls.List() {
if t, ok := tls.Get(name); ok {
tlSlice = append(tlSlice, t)
}
if hasMaxTokens {
cfg.MaxTokens = maxTokens
}
// 3. System Prompt
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
"You have access to tools - use them as needed to complete your task.\n" +
"After completing the task, provide a clear summary of what was done.\n\n" +
"Task: " + task
// 4. Resolve Model
modelToUse := agent.Model
if targetAgentID != "" {
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
modelToUse = targetAgent.Model
}
}
// 6. Spawn SubTurn
return spawnSubTurn(ctx, al, parentTS, cfg)
})
// 5. Build SubTurnConfig
cfg := SubTurnConfig{
Model: modelToUse,
Tools: tlSlice,
SystemPrompt: systemPrompt,
}
if hasMaxTokens {
cfg.MaxTokens = maxTokens
}
// 6. Spawn SubTurn
return spawnSubTurn(ctx, al, parentTS, cfg)
})
if spawnEnabled {
spawnTool := tools.NewSpawnTool(subagentManager)
currentAgentID := agentID
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
@@ -311,9 +317,12 @@ func registerSharedTools(
subagentTool := tools.NewSubagentTool(subagentManager)
subagentTool.SetSpawner(spawner)
agent.Tools.Register(subagentTool)
} else {
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
}
if spawnStatusEnabled {
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
}
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
}
}
}
+1 -1
View File
@@ -618,7 +618,7 @@ func (c *FeishuChannel) downloadResource(
}
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
mediaDir := media.TempDir()
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
"error": mkdirErr.Error(),
+1 -2
View File
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
if len(m.channels) == 0 {
logger.WarnC("channels", "No channels enabled")
return errors.New("no channels enabled")
}
logger.InfoC("channels", "Starting all channels")
@@ -397,7 +396,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
"addr": m.httpServer.Addr,
})
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
"error": err.Error(),
})
}
+1 -3
View File
@@ -35,8 +35,6 @@ const (
roomKindCacheTTL = 5 * time.Minute
roomKindCacheCleanupPeriod = 1 * time.Minute
roomKindCacheMaxEntries = 2048
matrixMediaTempDirName = "picoclaw_media"
)
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
@@ -1105,7 +1103,7 @@ func (c *MatrixChannel) stripSelfMention(text string) string {
}
func matrixMediaTempDir() (string, error) {
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
return "", err
}
+2 -1
View File
@@ -15,6 +15,7 @@ import (
"maunium.net/go/mautrix/id"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
)
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
@@ -165,7 +166,7 @@ func TestMatrixMediaTempDir(t *testing.T) {
if err != nil {
t.Fatalf("matrixMediaTempDir failed: %v", err)
}
if filepath.Base(dir) != matrixMediaTempDirName {
if filepath.Base(dir) != media.TempDirName {
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
}
+28 -3
View File
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return
}
conn, err := c.upgrader.Upgrade(w, r, nil)
// Echo the matched subprotocol back so the browser accepts the upgrade.
var responseHeader http.Header
if proto := c.matchedSubprotocol(r); proto != "" {
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
}
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
if err != nil {
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
"error": err.Error(),
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
go c.readLoop(pc)
}
// authenticate checks the Bearer token from the Authorization header.
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
// authenticate checks the request for a valid token:
// 1. Authorization: Bearer <token> header
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
// 3. Query parameter "token" (only when AllowTokenQuery is on)
func (c *PicoChannel) authenticate(r *http.Request) bool {
token := c.config.Token
if token == "" {
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
}
}
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
if c.matchedSubprotocol(r) != "" {
return true
}
// Check query parameter only when explicitly allowed
if c.config.AllowTokenQuery {
if r.URL.Query().Get("token") == token {
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
return false
}
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
// the configured token, or "" if none do.
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
token := c.config.Token
for _, proto := range websocket.Subprotocols(r) {
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
return proto
}
}
return ""
}
// readLoop reads messages from a WebSocket connection.
func (c *PicoChannel) readLoop(pc *picoConn) {
defer func() {
+79 -6
View File
@@ -4,11 +4,13 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync/atomic"
"github.com/caarlos0/env/v11"
"github.com/sipeed/picoclaw/pkg/credential"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
@@ -624,8 +626,9 @@ func (c *ModelConfig) Validate() error {
}
type GatewayConfig struct {
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
}
type ToolDiscoveryConfig struct {
@@ -698,8 +701,9 @@ type WebToolsConfig struct {
}
type CronToolsConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
AllowCommand bool ` env:"PICOCLAW_TOOLS_CRON_ALLOW_COMMAND" json:"allow_command"`
}
type ExecConfig struct {
@@ -749,6 +753,7 @@ type ToolsConfig struct {
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
@@ -838,10 +843,24 @@ func LoadConfig(path string) (*Config, error) {
return nil, err
}
if passphrase := credential.PassphraseProvider(); passphrase != "" {
for _, m := range cfg.ModelList {
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
fmt.Fprintf(os.Stderr,
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
m.ModelName)
}
}
}
if err := env.Parse(cfg); err != nil {
return nil, err
}
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
return nil, err
}
// Migrate legacy channel config fields to new unified structures
cfg.migrateChannelConfigs()
@@ -858,6 +877,48 @@ func LoadConfig(path string) (*Config, error) {
return cfg, nil
}
// encryptPlaintextAPIKeys returns a copy of models with plaintext api_key values
// encrypted. Returns (nil, nil) when nothing changed (all keys already sealed or
// empty). Returns (nil, error) if any key fails to encrypt — callers must treat
// this as a hard failure to prevent a mixed plaintext/ciphertext state on disk.
// Symmetric counterpart of resolveAPIKeys: both operate purely on []ModelConfig
// and leave JSON marshaling to the caller.
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
sealed := make([]ModelConfig, len(models))
copy(sealed, models)
changed := false
for i := range sealed {
m := &sealed[i]
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
continue
}
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
if err != nil {
return nil, fmt.Errorf("cannot seal api_key for model %q: %w", m.ModelName, err)
}
m.APIKey = encrypted
changed = true
}
if !changed {
return nil, nil
}
return sealed, nil
}
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
func resolveAPIKeys(models []ModelConfig, configDir string) error {
cr := credential.NewResolver(configDir)
for i := range models {
resolved, err := cr.Resolve(models[i].APIKey)
if err != nil {
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
}
models[i].APIKey = resolved
}
return nil
}
func (c *Config) migrateChannelConfigs() {
// Discord: mention_only -> group_trigger.mention_only
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
@@ -872,12 +933,22 @@ func (c *Config) migrateChannelConfigs() {
}
func SaveConfig(path string, cfg *Config) error {
if passphrase := credential.PassphraseProvider(); passphrase != "" {
sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
if err != nil {
return err
}
if sealed != nil {
tmp := *cfg
tmp.ModelList = sealed
cfg = &tmp
}
}
data, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return err
}
// Use unified atomic write utility with explicit sync for flash storage reliability.
return fileutil.WriteFileAtomic(path, data, 0o600)
}
@@ -1044,6 +1115,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
return t.ReadFile.Enabled
case "spawn":
return t.Spawn.Enabled
case "spawn_status":
return t.SpawnStatus.Enabled
case "spi":
return t.SPI.Enabled
case "subagent":
+386 -5
View File
@@ -7,8 +7,22 @@ import (
"runtime"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/credential"
)
// mustSetupSSHKey generates a temporary Ed25519 SSH key in t.TempDir() and sets
// PICOCLAW_SSH_KEY_PATH to its path for the duration of the test. This is required
// whenever a test exercises encryption/decryption via credential.Encrypt or SaveConfig.
func mustSetupSSHKey(t *testing.T) {
t.Helper()
keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
if err := credential.GenerateSSHKey(keyPath); err != nil {
t.Fatalf("mustSetupSSHKey: %v", err)
}
t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
}
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
var m AgentModelConfig
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
@@ -253,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if cfg.Gateway.HotReload {
t.Error("Gateway hot reload should be disabled by default")
}
}
// TestDefaultConfig_Providers verifies provider structure
@@ -391,6 +408,13 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
}
}
func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Tools.Cron.AllowCommand {
t.Fatal("DefaultConfig().Tools.Cron.AllowCommand should be true")
}
}
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -423,6 +447,22 @@ func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
}
}
func TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
if err := os.WriteFile(configPath, []byte(`{"tools":{"cron":{"exec_timeout_minutes":5}}}`), 0o600); err != nil {
t.Fatalf("WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if !cfg.Tools.Cron.AllowCommand {
t.Fatal("tools.cron.allow_command should remain true when unset in config file")
}
}
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -482,13 +522,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
}
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
// Unset to ensure we test the default
t.Setenv("PICOCLAW_HOME", "")
// Set a known home for consistent test results
t.Setenv("HOME", "/tmp/home")
var fakeHome string
if runtime.GOOS == "windows" {
fakeHome = `C:\tmp\home`
t.Setenv("USERPROFILE", fakeHome)
} else {
fakeHome = "/tmp/home"
t.Setenv("HOME", fakeHome)
}
cfg := DefaultConfig()
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
@@ -499,7 +545,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
cfg := DefaultConfig()
want := "/custom/picoclaw/home/workspace"
want := filepath.Join("/custom/picoclaw/home", "workspace")
if cfg.Agents.Defaults.Workspace != want {
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
@@ -621,3 +667,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
}
})
}
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
// api_key into memory but does NOT rewrite the config file. File writes are the sole
// responsibility of SaveConfig.
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
cfg, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
// In-memory value must be the resolved plaintext.
if cfg.ModelList[0].APIKey != "sk-plaintext" {
t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
}
// The file on disk must remain unchanged — LoadConfig must not write anything.
raw, _ := os.ReadFile(cfgPath)
if string(raw) != original {
t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
}
}
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
mustSetupSSHKey(t)
cfg := DefaultConfig()
cfg.ModelList = []ModelConfig{
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig: %v", err)
}
// Disk must contain enc://, not the raw key.
raw, _ := os.ReadFile(cfgPath)
if !strings.Contains(string(raw), "enc://") {
t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
}
if strings.Contains(string(raw), "sk-plaintext") {
t.Errorf("saved file must not contain the plaintext key")
}
// A fresh load must decrypt back to the original plaintext.
cfg2, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig after SaveConfig: %v", err)
}
if cfg2.ModelList[0].APIKey != "sk-plaintext" {
t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
}
}
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
if _, err := LoadConfig(cfgPath); err != nil {
t.Fatalf("LoadConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
if strings.Contains(string(raw), "enc://") {
t.Error("config file must not be modified when no passphrase is set")
}
}
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
// converted to enc:// values (they are resolved at runtime by the Resolver).
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
keyFile := filepath.Join(dir, "openai.key")
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
if _, err := LoadConfig(cfgPath); err != nil {
t.Fatalf("LoadConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
if !strings.Contains(string(raw), "file://openai.key") {
t.Error("file:// reference should be preserved unchanged in the config file")
}
if strings.Contains(string(raw), "enc://") {
t.Error("file:// reference must not be converted to enc://")
}
}
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
// and leaves already-encrypted (enc://) and file:// entries unchanged.
func TestSaveConfig_MixedKeys(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
mustSetupSSHKey(t)
// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
if err := SaveConfig(cfgPath, &Config{
ModelList: []ModelConfig{
{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
},
}); err != nil {
t.Fatalf("setup SaveConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
// Extract the enc:// value from the saved file.
var tmp struct {
ModelList []struct {
APIKey string `json:"api_key"`
} `json:"model_list"`
}
if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
t.Fatalf("setup: could not parse saved config: %v", err)
}
alreadyEncrypted := tmp.ModelList[0].APIKey
if !strings.HasPrefix(alreadyEncrypted, "enc://") {
t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
}
// Build a config with three models:
// 1. plaintext → must be encrypted by SaveConfig
// 2. enc:// → must be left unchanged (already encrypted)
// 3. file:// → must be left unchanged (file reference)
keyFile := filepath.Join(dir, "api.key")
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
cfg := &Config{
ModelList: []ModelConfig{
{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig: %v", err)
}
raw, _ = os.ReadFile(cfgPath)
s := string(raw)
// 1. Plaintext must be encrypted.
if strings.Contains(s, "sk-new-plaintext") {
t.Error("plaintext key must not appear in saved file")
}
// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
if !strings.Contains(s, alreadyEncrypted) {
t.Error("pre-existing enc:// entry must be preserved unchanged")
}
// 3. file:// must be preserved.
if !strings.Contains(s, "file://api.key") {
t.Error("file:// reference must be preserved unchanged")
}
// Now load and verify all three decrypt/resolve correctly.
cfg2, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig after SaveConfig: %v", err)
}
byName := make(map[string]string)
for _, m := range cfg2.ModelList {
byName[m.ModelName] = m.APIKey
}
if byName["plain"] != "sk-new-plaintext" {
t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
}
if byName["enc"] != "sk-already-plain" {
t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
}
if byName["file"] != "sk-from-file" {
t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
}
}
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
// is not set, enc:// entries cause LoadConfig to return an error, while plaintext
// and file:// entries in the same config are not affected.
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// First encrypt a key so we have a real enc:// value.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
mustSetupSSHKey(t)
if err := SaveConfig(cfgPath, &Config{
ModelList: []ModelConfig{
{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
},
}); err != nil {
t.Fatalf("setup SaveConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
var tmp struct {
ModelList []struct {
APIKey string `json:"api_key"`
} `json:"model_list"`
}
if err := json.Unmarshal(raw, &tmp); err != nil {
t.Fatalf("setup parse: %v", err)
}
encValue := tmp.ModelList[0].APIKey
// Write a mixed config: enc:// + plaintext + file://
keyFile := filepath.Join(dir, "api.key")
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
mixed, _ := json.Marshal(map[string]any{
"model_list": []map[string]any{
{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
},
})
if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
t.Fatalf("setup write: %v", err)
}
// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
_, err := LoadConfig(cfgPath)
if err == nil {
t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
}
if !strings.Contains(err.Error(), "passphrase required") {
t.Errorf("error should mention passphrase required, got: %v", err)
}
}
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
// This matters for the launcher, which clears the environment variable and redirects
// PassphraseProvider to an in-memory SecureStore.
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
mustSetupSSHKey(t)
// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
const testPassphrase = "provider-passphrase"
orig := credential.PassphraseProvider
credential.PassphraseProvider = func() string { return testPassphrase }
t.Cleanup(func() { credential.PassphraseProvider = orig })
cfg := DefaultConfig()
cfg.ModelList = []ModelConfig{
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig: %v", err)
}
raw, _ := os.ReadFile(cfgPath)
if !strings.Contains(string(raw), "enc://") {
t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
}
}
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
// using credential.PassphraseProvider() rather than os.Getenv directly.
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// Ensure the env var is empty throughout.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
mustSetupSSHKey(t)
const testPassphrase = "provider-passphrase"
const plainKey = "sk-secret"
// First, encrypt the key using the same passphrase.
encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
raw, _ := json.Marshal(map[string]any{
"model_list": []map[string]any{
{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
},
})
if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
orig := credential.PassphraseProvider
credential.PassphraseProvider = func() string { return testPassphrase }
t.Cleanup(func() { credential.PassphraseProvider = orig })
cfg, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig: %v", err)
}
if cfg.ModelList[0].APIKey != plainKey {
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
}
}
+16 -2
View File
@@ -385,10 +385,20 @@ func DefaultConfig() *Config {
APIBase: "http://localhost:8000/v1",
APIKey: "",
},
// Azure OpenAI - https://portal.azure.com
// model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name
{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIBase: "https://your-resource.openai.azure.com",
APIKey: "",
},
},
Gateway: GatewayConfig{
Host: "127.0.0.1",
Port: 18790,
Host: "127.0.0.1",
Port: 18790,
HotReload: false,
},
Tools: ToolsConfig{
MediaCleanup: MediaCleanupConfig{
@@ -444,6 +454,7 @@ func DefaultConfig() *Config {
Enabled: true,
},
ExecTimeoutMinutes: 5,
AllowCommand: true,
},
Exec: ExecConfig{
ToolConfig: ToolConfig{
@@ -513,6 +524,9 @@ func DefaultConfig() *Config {
Spawn: ToolConfig{
Enabled: true,
},
SpawnStatus: ToolConfig{
Enabled: false,
},
SPI: ToolConfig{
Enabled: false, // Hardware tool - Linux only
},
+335
View File
@@ -0,0 +1,335 @@
// Package credential resolves API credential values for model_list entries.
//
// An API key is a form of authorization credential. This package centralizes
// how raw credential strings—plaintext or file references—are resolved into
// their actual values, keeping that logic out of the config loader.
//
// Supported formats for the api_key field:
//
// - Plaintext: "sk-abc123" → returned as-is
// - File ref: "file://filename.key" → content read from configDir/filename.key
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
// - Empty: "" → returned as-is (auth_method=oauth etc.)
//
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
// An SSH private key is required for both encryption and decryption.
// Key derivation:
//
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
//
// SSH key path resolution priority:
//
// 1. sshKeyPath argument to Encrypt (explicit)
// 2. PICOCLAW_SSH_KEY_PATH env var
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
package credential
import (
"crypto/aes"
"crypto/cipher"
"crypto/hkdf"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
// PassphraseProvider is the function used to retrieve the passphrase for enc://
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
// process environment. Replace it at startup to use a different source, such as
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
// same passphrase source without needing os.Environ.
//
// Example (launcher main.go):
//
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
var PassphraseProvider func() string = func() string {
return os.Getenv(PassphraseEnvVar)
}
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
// no passphrase is available from PassphraseProvider. Callers can detect this
// with errors.Is to distinguish a missing-passphrase condition from other errors.
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
const (
fileScheme = "file://"
encScheme = "enc://"
hkdfInfo = "picoclaw-credential-v1"
saltLen = 16
nonceLen = 12
keyLen = 32
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
)
// Resolver resolves raw credential strings for model_list api_key fields.
// File references are resolved relative to the directory of the config file.
type Resolver struct {
configDir string
resolvedConfigDir string // symlink-resolved form of configDir
}
// NewResolver returns a Resolver that resolves file:// references relative to
// configDir (typically filepath.Dir of the config file path).
func NewResolver(configDir string) *Resolver {
resolved := configDir
if configDir != "" {
if linkedPath, err := filepath.EvalSymlinks(configDir); err == nil {
resolved = linkedPath
}
}
return &Resolver{configDir: configDir, resolvedConfigDir: resolved}
}
// Resolve returns the actual credential value for raw:
//
// - "" → "" (no error; auth_method=oauth needs no key)
// - "file://name.key" → trimmed content of configDir/name.key
// - anything else → raw unchanged (plaintext credential)
func (r *Resolver) Resolve(raw string) (string, error) {
if raw == "" {
return "", nil
}
if strings.HasPrefix(raw, fileScheme) {
fileName := strings.TrimSpace(strings.TrimPrefix(raw, fileScheme))
if fileName == "" {
return "", fmt.Errorf("credential: file:// reference has no filename")
}
baseDir := r.resolvedConfigDir
if baseDir == "" {
baseDir = r.configDir
}
keyPath := filepath.Join(baseDir, fileName)
// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
realKeyPath, err := filepath.EvalSymlinks(keyPath)
if err != nil {
return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
}
if !isWithinDir(realKeyPath, baseDir) {
return "", fmt.Errorf("credential: file:// path escapes config directory")
}
data, err := os.ReadFile(realKeyPath)
if err != nil {
return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
}
value := strings.TrimSpace(string(data))
if value == "" {
return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
}
return value, nil
}
if strings.HasPrefix(raw, encScheme) {
return resolveEncrypted(raw)
}
// Plaintext credential — return unchanged.
return raw, nil
}
// resolveEncrypted decrypts an enc:// credential using PassphraseProvider.
func resolveEncrypted(raw string) (string, error) {
passphrase := PassphraseProvider()
if passphrase == "" {
return "", ErrPassphraseRequired
}
sshKeyPath := pickSSHKeyPath("") // override="": consult env then auto-detect
b64 := strings.TrimPrefix(raw, encScheme)
blob, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
}
if len(blob) < saltLen+nonceLen+1 {
return "", fmt.Errorf("credential: enc:// payload too short")
}
salt := blob[:saltLen]
nonce := blob[saltLen : saltLen+nonceLen]
ciphertext := blob[saltLen+nonceLen:]
key, err := deriveKey(passphrase, sshKeyPath, salt)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
}
return string(plaintext), nil
}
// Encrypt encrypts plaintext and returns an enc:// credential string.
//
// passphrase is required (PICOCLAW_KEY_PASSPHRASE value).
// sshKeyPath is the SSH private key file to use; pass "" to auto-detect via
// PICOCLAW_SSH_KEY_PATH env var or ~/.ssh/picoclaw_ed25519.key.
// An SSH private key must be resolvable or Encrypt returns an error.
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
if passphrase == "" {
return "", fmt.Errorf("credential: passphrase must not be empty")
}
sshKeyPath = pickSSHKeyPath(sshKeyPath)
salt := make([]byte, saltLen)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return "", fmt.Errorf("credential: failed to generate salt: %w", err)
}
key, err := deriveKey(passphrase, sshKeyPath, salt)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", fmt.Errorf("credential: cipher init: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("credential: gcm init: %w", err)
}
nonce := make([]byte, nonceLen)
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
}
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
blob := make([]byte, 0, saltLen+nonceLen+len(ciphertext))
blob = append(blob, salt...)
blob = append(blob, nonce...)
blob = append(blob, ciphertext...)
return encScheme + base64.StdEncoding.EncodeToString(blob), nil
}
// isWithinDir reports whether path is contained within (or equal to) dir.
// Uses filepath.IsLocal on the relative path for robust cross-platform traversal detection.
func isWithinDir(path, dir string) bool {
rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(path))
return err == nil && filepath.IsLocal(rel)
}
// allowedSSHKeyPath reports whether path is in a permitted location for SSH key files:
// - exact match with PICOCLAW_SSH_KEY_PATH env var
// - within the PICOCLAW_HOME env var directory
// - within ~/.ssh/
func allowedSSHKeyPath(path string) bool {
if path == "" {
return true // passphrase-only mode; no file will be read
}
clean := filepath.Clean(path)
// Exact match with PICOCLAW_SSH_KEY_PATH.
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
if clean == filepath.Clean(envPath) {
return true
}
}
// Within PICOCLAW_HOME.
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
if isWithinDir(clean, picoHome) {
return true
}
}
// Within ~/.ssh/.
if userHome, err := os.UserHomeDir(); err == nil {
if isWithinDir(clean, filepath.Join(userHome, ".ssh")) {
return true
}
}
return false
}
// deriveKey derives a 32-byte AES-256 key from passphrase and SSH private key.
//
// ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
// Final key: HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
// sshKeyPath must be non-empty; returns an error otherwise.
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
if sshKeyPath == "" {
return nil, fmt.Errorf(
"credential: SSH private key is required but not found" +
" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
}
if !allowedSSHKeyPath(sshKeyPath) {
return nil, fmt.Errorf(
"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
sshKeyPath,
)
}
sshBytes, err := os.ReadFile(sshKeyPath)
if err != nil {
return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
}
sshHash := sha256.Sum256(sshBytes)
mac := hmac.New(sha256.New, sshHash[:])
mac.Write([]byte(passphrase))
ikm := mac.Sum(nil)
key, err := hkdf.Key(sha256.New, ikm, salt, hkdfInfo, keyLen)
if err != nil {
return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
}
return key, nil
}
// pickSSHKeyPath returns the SSH private key path to use for encryption/decryption.
//
// Priority:
// 1. override (non-empty explicit argument)
// 2. PICOCLAW_SSH_KEY_PATH env var
// 3. ~/.ssh/picoclaw_ed25519.key (auto-detection)
//
// Returns "" when no key is found; deriveKey will return an error in that case.
func pickSSHKeyPath(override string) string {
if override != "" {
return override
}
if p, ok := os.LookupEnv(sshKeyEnv); ok {
return p // respect explicit setting, even if ""
}
return findDefaultSSHKey()
}
// findDefaultSSHKey returns the picoclaw-specific SSH key path if it exists.
func findDefaultSSHKey() string {
p, err := DefaultSSHKeyPath()
if err != nil {
return ""
}
if _, err := os.Stat(p); err == nil {
return p
}
return ""
}
+283
View File
@@ -0,0 +1,283 @@
package credential_test
import (
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/credential"
)
func TestResolve_PlainKey(t *testing.T) {
r := credential.NewResolver(t.TempDir())
got, err := r.Resolve("sk-plaintext-key")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "sk-plaintext-key" {
t.Fatalf("got %q, want %q", got, "sk-plaintext-key")
}
}
func TestResolve_FileKey_Success(t *testing.T) {
dir := t.TempDir()
keyFile := "openai_plain.key"
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte("sk-from-file\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(dir)
got, err := r.Resolve("file://" + keyFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "sk-from-file" {
t.Fatalf("got %q, want %q", got, "sk-from-file")
}
}
func TestResolve_FileKey_NotFound(t *testing.T) {
r := credential.NewResolver(t.TempDir())
_, err := r.Resolve("file://missing.key")
if err == nil {
t.Fatal("expected error for missing file, got nil")
}
}
func TestResolve_FileKey_Empty(t *testing.T) {
dir := t.TempDir()
keyFile := "empty.key"
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(dir)
_, err := r.Resolve("file://" + keyFile)
if err == nil {
t.Fatal("expected error for empty credential file, got nil")
}
}
// TestResolve_EncKey_RoundTrip tests basic encryption/decryption round-trip with an SSH key.
func TestResolve_EncKey_RoundTrip(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
const passphrase = "test-passphrase-32bytes-long-ok!"
const plaintext = "sk-encrypted-secret"
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt(passphrase, "", plaintext)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
r := credential.NewResolver(t.TempDir())
got, err := r.Resolve(enc)
if err != nil {
t.Fatalf("Resolve: %v", err)
}
if got != plaintext {
t.Fatalf("got %q, want %q", got, plaintext)
}
}
// TestResolve_EncKey_WithSSHKey tests that the SSH key file is incorporated into key derivation.
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
const passphrase = "test-passphrase"
const plaintext = "sk-ssh-protected-secret"
// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt(passphrase, sshKeyPath, plaintext)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
r := credential.NewResolver(t.TempDir())
got, err := r.Resolve(enc)
if err != nil {
t.Fatalf("Resolve: %v", err)
}
if got != plaintext {
t.Fatalf("got %q, want %q", got, plaintext)
}
}
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
r := credential.NewResolver(t.TempDir())
_, err = r.Resolve(enc)
if err == nil {
t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
}
}
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
r := credential.NewResolver(t.TempDir())
_, err := r.Resolve("enc://!!not-valid-base64!!")
if err == nil {
t.Fatal("expected error for invalid enc:// payload, got nil")
}
}
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
// Valid base64 but fewer bytes than salt(16)+nonce(12)+1 minimum.
import64 := "dG9vc2hvcnQ=" // "tooshort" = 8 bytes
r := credential.NewResolver(t.TempDir())
_, err := r.Resolve("enc://" + import64)
if err == nil {
t.Fatal("expected error for too-short enc:// payload, got nil")
}
}
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
r := credential.NewResolver(t.TempDir())
_, err = r.Resolve(enc)
if err == nil {
t.Fatal("expected decryption error for wrong passphrase, got nil")
}
}
func TestEncrypt_EmptyPassphrase(t *testing.T) {
_, err := credential.Encrypt("", "", "sk-secret")
if err == nil {
t.Fatal("expected error for empty passphrase, got nil")
}
}
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
// Encrypt with a real SSH key path, then try to decrypt with a missing path.
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
enc, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
// Point to a non-existent SSH key so deriveKey's ReadFile fails.
// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "passphrase")
t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))
r := credential.NewResolver(t.TempDir())
_, err = r.Resolve(enc)
if err == nil {
t.Fatal("expected error when SSH key file is missing, got nil")
}
}
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
// via relative traversal ("../../etc/passwd") or absolute paths ("/abs/path").
func TestResolve_FileRef_PathTraversal(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
// Create a file outside configDir that the traversal would point to.
outsideFile := filepath.Join(t.TempDir(), "secret.key")
if err := os.WriteFile(outsideFile, []byte("stolen"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(filepath.Dir(cfgPath))
cases := []string{
"file://../../secret.key",
"file://../secret.key",
"file://" + outsideFile, // absolute path
}
for _, raw := range cases {
_, err := r.Resolve(raw)
if err == nil {
t.Errorf("Resolve(%q): expected path traversal error, got nil", raw)
}
}
}
// TestResolve_FileRef_withinConfigDir verifies that a legitimate relative file:// ref works.
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
dir := t.TempDir()
if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte("sk-valid\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
r := credential.NewResolver(dir)
got, err := r.Resolve("file://my.key")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != "sk-valid" {
t.Fatalf("got %q, want %q", got, "sk-valid")
}
}
// TestEncrypt_SSHKeyOutsideAllowedDirs verifies that Encrypt rejects SSH key paths
// that are not under PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
dir := t.TempDir()
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
t.Fatalf("setup: %v", err)
}
// Make sure none of the allowed env vars point here.
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
t.Setenv("PICOCLAW_HOME", "")
_, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
if err == nil {
t.Fatal("expected error for SSH key outside allowed directories, got nil")
}
}
+62
View File
@@ -0,0 +1,62 @@
package credential
import (
"crypto/ed25519"
"crypto/rand"
"encoding/pem"
"fmt"
"os"
"path/filepath"
"golang.org/x/crypto/ssh"
)
// DefaultSSHKeyPath returns the canonical path for the picoclaw-specific SSH key.
// The path is always ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform).
func DefaultSSHKeyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
}
return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
}
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
// The ~/.ssh/ directory is created with 0700 if it does not exist.
// If the files already exist they are overwritten.
func GenerateSSHKey(path string) error {
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
}
pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
}
// Marshal private key as OpenSSH PEM.
block, err := ssh.MarshalPrivateKey(privRaw, "")
if err != nil {
return fmt.Errorf("credential: keygen: marshal private key: %w", err)
}
privPEM := pem.EncodeToMemory(block)
if err = os.WriteFile(path, privPEM, 0o600); err != nil {
return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
}
// Marshal public key as authorized_keys line.
sshPub, err := ssh.NewPublicKey(pubRaw)
if err != nil {
return fmt.Errorf("credential: keygen: marshal public key: %w", err)
}
pubLine := ssh.MarshalAuthorizedKey(sshPub)
pubPath := path + ".pub"
if err := os.WriteFile(pubPath, pubLine, 0o644); err != nil {
return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
}
return nil
}
+115
View File
@@ -0,0 +1,115 @@
package credential
import (
"crypto/ed25519"
"os"
"path/filepath"
"runtime"
"testing"
"golang.org/x/crypto/ssh"
)
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "test_ed25519.key")
if err := GenerateSSHKey(keyPath); err != nil {
t.Fatalf("GenerateSSHKey() error = %v", err)
}
// Private key must exist.
privInfo, err := os.Stat(keyPath)
if err != nil {
t.Fatalf("private key file missing: %v", err)
}
// Check permissions on non-Windows (Windows does not support Unix permission bits).
if runtime.GOOS != "windows" {
if got := privInfo.Mode().Perm(); got != 0o600 {
t.Errorf("private key permissions = %04o, want 0600", got)
}
}
// Public key must exist.
pubPath := keyPath + ".pub"
pubInfo, err := os.Stat(pubPath)
if err != nil {
t.Fatalf("public key file missing: %v", err)
}
if runtime.GOOS != "windows" {
if got := pubInfo.Mode().Perm(); got != 0o644 {
t.Errorf("public key permissions = %04o, want 0644", got)
}
}
// Private key must be parseable as an OpenSSH ed25519 key.
privPEM, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read private key: %v", err)
}
privKey, err := ssh.ParseRawPrivateKey(privPEM)
if err != nil {
t.Fatalf("parse private key: %v", err)
}
if _, ok := privKey.(*ed25519.PrivateKey); !ok {
t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
}
// Public key must be parseable as authorized_keys line.
pubBytes, err := os.ReadFile(pubPath)
if err != nil {
t.Fatalf("read public key: %v", err)
}
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(pubBytes)
if err != nil {
t.Fatalf("parse public key: %v", err)
}
if pubKey == nil {
t.Fatal("expected non-nil public key")
}
if len(rest) > 0 {
t.Errorf("unexpected trailing bytes after public key: %d bytes", len(rest))
}
}
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
dir := t.TempDir()
keyPath := filepath.Join(dir, "test_ed25519.key")
// Generate twice; second call must not error and must produce a different key.
if err := GenerateSSHKey(keyPath); err != nil {
t.Fatalf("first GenerateSSHKey() error = %v", err)
}
first, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read first key: %v", err)
}
if err = GenerateSSHKey(keyPath); err != nil {
t.Fatalf("second GenerateSSHKey() error = %v", err)
}
second, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("read second key: %v", err)
}
// Two independently generated Ed25519 keys must differ.
if string(first) == string(second) {
t.Error("expected overwritten key to differ from original")
}
}
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
dir := t.TempDir()
// Nested directory that does not yet exist.
keyPath := filepath.Join(dir, "subdir", ".ssh", "picoclaw_ed25519.key")
if err := GenerateSSHKey(keyPath); err != nil {
t.Fatalf("GenerateSSHKey() error = %v", err)
}
if _, err := os.Stat(keyPath); err != nil {
t.Fatalf("private key not created: %v", err)
}
}
+44
View File
@@ -0,0 +1,44 @@
package credential
import "sync/atomic"
// SecureStore holds a passphrase in memory.
//
// Uses atomic.Pointer so reads and writes are lock-free.
// The passphrase is never written to disk; callers decide how to
// transport it outside this store (e.g., via cmd.Env or os.Environ).
type SecureStore struct {
val atomic.Pointer[string]
}
// NewSecureStore creates an empty SecureStore.
func NewSecureStore() *SecureStore {
return &SecureStore{}
}
// SetString stores the passphrase. An empty string clears the store.
func (s *SecureStore) SetString(passphrase string) {
if passphrase == "" {
s.val.Store(nil)
return
}
s.val.Store(&passphrase)
}
// Get returns the stored passphrase, or "" if not set.
func (s *SecureStore) Get() string {
if p := s.val.Load(); p != nil {
return *p
}
return ""
}
// IsSet reports whether a passphrase is currently stored.
func (s *SecureStore) IsSet() bool {
return s.val.Load() != nil
}
// Clear removes the stored passphrase.
func (s *SecureStore) Clear() {
s.val.Store(nil)
}
+81
View File
@@ -0,0 +1,81 @@
package credential
import (
"sync"
"testing"
)
func TestSecureStore_SetGet(t *testing.T) {
s := NewSecureStore()
if s.IsSet() {
t.Error("expected empty store")
}
s.SetString("hunter2")
if !s.IsSet() {
t.Error("expected store to be set")
}
if got := s.Get(); got != "hunter2" {
t.Errorf("Get() = %q, want %q", got, "hunter2")
}
}
func TestSecureStore_Clear(t *testing.T) {
s := NewSecureStore()
s.SetString("secret")
s.Clear()
if s.IsSet() {
t.Error("expected store to be empty after Clear()")
}
if got := s.Get(); got != "" {
t.Errorf("Get() after Clear() = %q, want empty", got)
}
}
func TestSecureStore_SetOverwrites(t *testing.T) {
s := NewSecureStore()
s.SetString("first")
s.SetString("second")
if got := s.Get(); got != "second" {
t.Errorf("Get() = %q, want %q", got, "second")
}
}
func TestSecureStore_EmptyPassphrase(t *testing.T) {
s := NewSecureStore()
s.SetString("") // empty → should not mark as set
if s.IsSet() {
t.Error("empty passphrase should not mark store as set")
}
}
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
s := NewSecureStore()
const goroutines = 10
const iterations = 1000
var wg sync.WaitGroup
wg.Add(goroutines)
for i := 0; i < goroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < iterations; j++ {
if id%2 == 0 {
s.SetString("even")
} else {
s.SetString("odd")
}
_ = s.Get()
}
}(i)
}
wg.Wait()
final := s.Get()
if final != "" && final != "even" && final != "odd" {
t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
}
}
@@ -7,9 +7,9 @@ import (
"os/signal"
"path/filepath"
"sync"
"syscall"
"time"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -41,16 +41,13 @@ import (
"github.com/sipeed/picoclaw/pkg/voice"
)
// Timeout constants for service operations
const (
serviceRestartTimeout = 30 * time.Second
serviceShutdownTimeout = 30 * time.Second
providerReloadTimeout = 30 * time.Second
gracefulShutdownTimeout = 15 * time.Second
)
// gatewayServices holds references to all running services
type gatewayServices struct {
type services struct {
CronService *cron.CronService
HeartbeatService *heartbeat.HeartbeatService
MediaStore media.MediaStore
@@ -59,24 +56,41 @@ type gatewayServices struct {
HealthServer *health.Server
}
func gatewayCmd(debug bool) error {
type startupBlockedProvider struct {
reason string
}
func (p *startupBlockedProvider) Chat(
_ context.Context,
_ []providers.Message,
_ []providers.ToolDefinition,
_ string,
_ map[string]any,
) (*providers.LLMResponse, error) {
return nil, fmt.Errorf("%s", p.reason)
}
func (p *startupBlockedProvider) GetDefaultModel() string {
return ""
}
// Run starts the gateway runtime using the configuration loaded from configPath.
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
if debug {
logger.SetLevel(logger.DEBUG)
fmt.Println("🔍 Debug mode enabled")
}
configPath := internal.GetConfigPath()
cfg, err := internal.LoadConfig()
cfg, err := config.LoadConfig(configPath)
if err != nil {
return fmt.Errorf("error loading config: %w", err)
}
provider, modelID, err := providers.CreateProvider(cfg)
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
}
// Use the resolved model ID from provider creation
if modelID != "" {
cfg.Agents.Defaults.ModelName = modelID
}
@@ -84,17 +98,13 @@ func gatewayCmd(debug bool) error {
msgBus := bus.NewMessageBus()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
// Print agent startup info
fmt.Println("\n📦 Agent Status:")
startupInfo := agentLoop.GetStartupInfo()
toolsInfo := startupInfo["tools"].(map[string]any)
skillsInfo := startupInfo["skills"].(map[string]any)
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
fmt.Printf(" • Skills: %d/%d available\n",
skillsInfo["available"],
skillsInfo["total"])
fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
// Log to file as well
logger.InfoCF("agent", "Agent initialized",
map[string]any{
"tools_count": toolsInfo["count"],
@@ -102,8 +112,7 @@ func gatewayCmd(debug bool) error {
"skills_available": skillsInfo["available"],
})
// Setup and start all services
services, err := setupAndStartServices(cfg, agentLoop, msgBus)
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
if err != nil {
return err
}
@@ -116,23 +125,25 @@ func gatewayCmd(debug bool) error {
go agentLoop.Run(ctx)
// Setup config file watcher for hot reload
configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
var configReloadChan <-chan *config.Config
stopWatch := func() {}
if cfg.Gateway.HotReload {
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
logger.Info("Config hot reload enabled")
}
defer stopWatch()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
// Main event loop - wait for signals or config changes
for {
select {
case <-sigChan:
logger.Info("Shutting down...")
shutdownGateway(services, agentLoop, provider, true)
shutdownGateway(runningServices, agentLoop, provider, true)
return nil
case newCfg := <-configReloadChan:
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus)
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
if err != nil {
logger.Errorf("Config reload failed: %v", err)
}
@@ -140,17 +151,33 @@ func gatewayCmd(debug bool) error {
}
}
// setupAndStartServices initializes and starts all services
func createStartupProvider(
cfg *config.Config,
allowEmptyStartup bool,
) (providers.LLMProvider, string, error) {
modelName := cfg.Agents.Defaults.GetModelName()
if modelName == "" && allowEmptyStartup {
reason := "no default model configured; gateway started in limited mode"
fmt.Printf("⚠ Warning: %s\n", reason)
logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
"limited_mode": true,
})
return &startupBlockedProvider{reason: reason}, "", nil
}
return providers.CreateProvider(cfg)
}
func setupAndStartServices(
cfg *config.Config,
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
) (*gatewayServices, error) {
services := &gatewayServices{}
) (*services, error) {
runningServices := &services{}
// Setup cron tool and service
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
services.CronService = setupCronTool(
var err error
runningServices.CronService, err = setupCronTool(
agentLoop,
msgBus,
cfg.WorkspacePath(),
@@ -158,139 +185,108 @@ func setupAndStartServices(
execTimeout,
cfg,
)
if err := services.CronService.Start(); err != nil {
if err != nil {
return nil, fmt.Errorf("error setting up cron service: %w", err)
}
if err = runningServices.CronService.Start(); err != nil {
return nil, fmt.Errorf("error starting cron service: %w", err)
}
fmt.Println("✓ Cron service started")
// Setup heartbeat service
services.HeartbeatService = heartbeat.NewHeartbeatService(
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
services.HeartbeatService.SetBus(msgBus)
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
// Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
var response string
var err error
response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
// For heartbeat, always return silent - the subagent result will be
// sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
})
if err := services.HeartbeatService.Start(); err != nil {
runningServices.HeartbeatService.SetBus(msgBus)
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
if err = runningServices.HeartbeatService.Start(); err != nil {
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
}
fmt.Println("✓ Heartbeat service started")
// Create media store for file lifecycle management with TTL cleanup
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
// Start the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
// Create channel manager
var err error
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
// Stop the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
return nil, fmt.Errorf("error creating channel manager: %w", err)
}
// Inject channel manager and media store into agent loop
agentLoop.SetChannelManager(services.ChannelManager)
agentLoop.SetMediaStore(services.MediaStore)
agentLoop.SetChannelManager(runningServices.ChannelManager)
agentLoop.SetMediaStore(runningServices.MediaStore)
// Wire up voice transcription if a supported provider is configured.
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
agentLoop.SetTranscriber(transcriber)
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
enabledChannels := services.ChannelManager.GetEnabledChannels()
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println("⚠ Warning: No channels enabled")
}
// Setup shared HTTP server with health endpoints and webhook handlers
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err := services.ChannelManager.StartAll(context.Background()); err != nil {
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return nil, fmt.Errorf("error starting channels: %w", err)
}
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
// Setup state manager and device service
stateManager := state.NewManager(cfg.WorkspacePath())
services.DeviceService = devices.NewService(devices.Config{
runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
services.DeviceService.SetBus(msgBus)
if err := services.DeviceService.Start(context.Background()); err != nil {
runningServices.DeviceService.SetBus(msgBus)
if err = runningServices.DeviceService.Start(context.Background()); err != nil {
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println("✓ Device event service started")
}
return services, nil
return runningServices, nil
}
// stopAndCleanupServices stops all services and cleans up resources
func stopAndCleanupServices(
services *gatewayServices,
shutdownTimeout time.Duration,
) {
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer shutdownCancel()
if services.ChannelManager != nil {
services.ChannelManager.StopAll(shutdownCtx)
if runningServices.ChannelManager != nil {
runningServices.ChannelManager.StopAll(shutdownCtx)
}
if services.DeviceService != nil {
services.DeviceService.Stop()
if runningServices.DeviceService != nil {
runningServices.DeviceService.Stop()
}
if services.HeartbeatService != nil {
services.HeartbeatService.Stop()
if runningServices.HeartbeatService != nil {
runningServices.HeartbeatService.Stop()
}
if services.CronService != nil {
services.CronService.Stop()
if runningServices.CronService != nil {
runningServices.CronService.Stop()
}
if services.MediaStore != nil {
// Stop the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
if runningServices.MediaStore != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
}
}
// shutdownGateway performs a complete gateway shutdown
func shutdownGateway(
services *gatewayServices,
runningServices *services,
agentLoop *agent.AgentLoop,
provider providers.LLMProvider,
fullShutdown bool,
@@ -299,7 +295,7 @@ func shutdownGateway(
cp.Close()
}
stopAndCleanupServices(services, gracefulShutdownTimeout)
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
agentLoop.Stop()
agentLoop.Close()
@@ -307,15 +303,14 @@ func shutdownGateway(
logger.Info("✓ Gateway stopped")
}
// handleConfigReload handles config file reload by stopping all services,
// reloading the provider and config, and restarting services with the new config.
func handleConfigReload(
ctx context.Context,
al *agent.AgentLoop,
newCfg *config.Config,
providerRef *providers.LLMProvider,
services *gatewayServices,
runningServices *services,
msgBus *bus.MessageBus,
allowEmptyStartup bool,
) error {
logger.Info("🔄 Config file changed, reloading...")
@@ -326,18 +321,14 @@ func handleConfigReload(
logger.Infof(" New model is '%s', recreating provider...", newModel)
// Stop all services before reloading
logger.Info(" Stopping all services...")
stopAndCleanupServices(services, serviceShutdownTimeout)
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
// Create new provider from updated config first to ensure validity
// This will use the correct API key and settings from newCfg.ModelList
newProvider, newModelID, err := providers.CreateProvider(newCfg)
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
if err != nil {
logger.Errorf(" ⚠ Error creating new provider: %v", err)
logger.Warn(" Attempting to restart services with old provider and config...")
// Try to restart services with old configuration
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error creating new provider: %w", err)
@@ -347,31 +338,25 @@ func handleConfigReload(
newCfg.Agents.Defaults.ModelName = newModelID
}
// Use the atomic reload method on AgentLoop to safely swap provider and config.
// This handles locking internally to prevent races with in-flight LLM calls
// and concurrent reads of registry/config while the swap occurs.
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
defer reloadCancel()
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
// Close the newly created provider since it wasn't adopted
if cp, ok := newProvider.(providers.StatefulProvider); ok {
cp.Close()
}
logger.Warn(" Attempting to restart services with old provider and config...")
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error reloading agent loop: %w", err)
}
// Update local provider reference only after successful atomic reload
*providerRef = newProvider
// Restart all services with new config
logger.Info(" Restarting all services with new configuration...")
if err := restartServices(al, services, msgBus); err != nil {
if err := restartServices(al, runningServices, msgBus); err != nil {
logger.Errorf(" ⚠ Error restarting services: %v", err)
return fmt.Errorf("error restarting services: %w", err)
}
@@ -380,23 +365,16 @@ func handleConfigReload(
return nil
}
// restartServices restarts all services after a config reload
func restartServices(
al *agent.AgentLoop,
services *gatewayServices,
runningServices *services,
msgBus *bus.MessageBus,
) error {
// Create an independent context with timeout for service restart
// This prevents cancellation from the main loop context during reload
ctx, cancel := context.WithTimeout(context.Background(), serviceRestartTimeout)
defer cancel()
// Get current config from agent loop (which has been updated if this is a reload)
cfg := al.GetConfig()
// Re-create and start cron service with new config
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
services.CronService = setupCronTool(
var err error
runningServices.CronService, err = setupCronTool(
al,
msgBus,
cfg.WorkspacePath(),
@@ -404,80 +382,54 @@ func restartServices(
execTimeout,
cfg,
)
if err := services.CronService.Start(); err != nil {
if err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
if err = runningServices.CronService.Start(); err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
fmt.Println(" ✓ Cron service restarted")
// Re-create and start heartbeat service with new config
services.HeartbeatService = heartbeat.NewHeartbeatService(
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
services.HeartbeatService.SetBus(msgBus)
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
var response string
var err error
response, err = al.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
return tools.SilentResult(response)
})
if err := services.HeartbeatService.Start(); err != nil {
runningServices.HeartbeatService.SetBus(msgBus)
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
if err = runningServices.HeartbeatService.Start(); err != nil {
return fmt.Errorf("error restarting heartbeat service: %w", err)
}
fmt.Println(" ✓ Heartbeat service restarted")
// Stop the old media store before creating a new one
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
// Re-create media store with new config
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
// Start the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
al.SetMediaStore(services.MediaStore)
al.SetMediaStore(runningServices.MediaStore)
// Re-create channel manager with new config
var err error
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
// Stop the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
return fmt.Errorf("error recreating channel manager: %w", err)
}
al.SetChannelManager(services.ChannelManager)
al.SetChannelManager(runningServices.ChannelManager)
enabledChannels := services.ChannelManager.GetEnabledChannels()
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println(" ⚠ Warning: No channels enabled")
}
// Setup HTTP server with new config
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err := services.ChannelManager.StartAll(ctx); err != nil {
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return fmt.Errorf("error restarting channels: %w", err)
}
fmt.Printf(
@@ -486,22 +438,20 @@ func restartServices(
cfg.Gateway.Port,
)
// Re-create device service with new config
stateManager := state.NewManager(cfg.WorkspacePath())
services.DeviceService = devices.NewService(devices.Config{
runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
services.DeviceService.SetBus(msgBus)
if err := services.DeviceService.Start(ctx); err != nil {
runningServices.DeviceService.SetBus(msgBus)
if err := runningServices.DeviceService.Start(context.Background()); err != nil {
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println(" ✓ Device event service restarted")
}
// Wire up voice transcription with new config
transcriber := voice.DetectTranscriber(cfg)
al.SetTranscriber(transcriber) // This will set it to nil if disabled
al.SetTranscriber(transcriber)
if transcriber != nil {
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
} else {
@@ -511,8 +461,6 @@ func restartServices(
return nil
}
// setupConfigWatcherPolling sets up a simple polling-based config file watcher
// Returns a channel for config updates and a stop function
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
configChan := make(chan *config.Config, 1)
stop := make(chan struct{})
@@ -522,11 +470,10 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
go func() {
defer wg.Done()
// Get initial file info
lastModTime := getFileModTime(configPath)
lastSize := getFileSize(configPath)
ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
@@ -535,16 +482,16 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
currentModTime := getFileModTime(configPath)
currentSize := getFileSize(configPath)
// Check if file changed (modification time or size changed)
if currentModTime.After(lastModTime) || currentSize != lastSize {
if debug {
logger.Debugf("🔍 Config file change detected")
}
// Debounce - wait a bit to ensure file write is complete
time.Sleep(500 * time.Millisecond)
// Validate and load new config
lastModTime = currentModTime
lastSize = currentSize
newCfg, err := config.LoadConfig(configPath)
if err != nil {
logger.Errorf("⚠ Error loading new config: %v", err)
@@ -552,7 +499,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
continue
}
// Validate the new config
if err := newCfg.ValidateModelList(); err != nil {
logger.Errorf(" ⚠ New config validation failed: %v", err)
logger.Warn(" Using previous valid config")
@@ -561,19 +507,12 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
logger.Info("✓ Config file validated and loaded")
// Update last known state
lastModTime = currentModTime
lastSize = currentSize
// Send new config to main loop (non-blocking)
select {
case configChan <- newCfg:
default:
// Channel full, skip this update
logger.Warn("⚠ Previous config reload still in progress, skipping")
}
}
case <-stop:
return
}
@@ -588,7 +527,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
return configChan, stopFunc
}
// getFileModTime returns the modification time of a file, or zero time if file doesn't exist
func getFileModTime(path string) time.Time {
info, err := os.Stat(path)
if err != nil {
@@ -597,7 +535,6 @@ func getFileModTime(path string) time.Time {
return info.ModTime()
}
// getFileSize returns the size of a file, or 0 if file doesn't exist
func getFileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {
@@ -613,25 +550,22 @@ func setupCronTool(
restrict bool,
execTimeout time.Duration,
cfg *config.Config,
) *cron.CronService {
) (*cron.CronService, error) {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool if enabled
var cronTool *tools.CronTool
if cfg.Tools.IsToolEnabled("cron") {
var err error
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
if err != nil {
logger.Fatalf("Critical error during CronTool initialization: %v", err)
return nil, fmt.Errorf("critical error during CronTool initialization: %w", err)
}
agentLoop.RegisterTool(cronTool)
}
// Set onJob handler
if cronTool != nil {
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
result := cronTool.ExecuteJob(context.Background(), job)
@@ -639,5 +573,22 @@ func setupCronTool(
})
}
return cronService
return cronService, nil
}
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
return func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
return tools.SilentResult(response)
}
}
+3
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"maps"
"net/http"
"os"
"sync"
"time"
)
@@ -29,6 +30,7 @@ type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
Pid int `json:"pid"`
}
func NewServer(host string, port int) *Server {
@@ -112,6 +114,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
Pid: os.Getpid(),
}
json.NewEncoder(w).Encode(resp)
+58 -7
View File
@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
@@ -45,6 +46,9 @@ func init() {
consoleWriter := zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: "15:04:05", // TODO: make it configurable???
// Custom formatter to handle multiline strings and JSON objects
FormatFieldValue: formatFieldValue,
}
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
@@ -52,6 +56,37 @@ func init() {
})
}
func formatFieldValue(i any) string {
var s string
switch val := i.(type) {
case string:
s = val
case []byte:
s = string(val)
default:
return fmt.Sprintf("%v", i)
}
if unquoted, err := strconv.Unquote(s); err == nil {
s = unquoted
}
if strings.Contains(s, "\n") {
return fmt.Sprintf("\n%s", s)
}
if strings.Contains(s, " ") {
if (strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}")) ||
(strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]")) {
return s
}
return fmt.Sprintf("%q", s)
}
return s
}
func SetLevel(level LogLevel) {
mu.Lock()
defer mu.Unlock()
@@ -163,10 +198,7 @@ func logMessage(level LogLevel, component string, message string, fields map[str
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
}
for k, v := range fields {
event.Interface(k, v)
}
appendFields(event, fields)
event.Msg(message)
// Also log to file if enabled
@@ -176,9 +208,8 @@ func logMessage(level LogLevel, component string, message string, fields map[str
if component != "" {
fileEvent.Str("component", component)
}
for k, v := range fields {
fileEvent.Interface(k, v)
}
appendFields(event, fields)
fileEvent.Msg(message)
}
@@ -187,6 +218,26 @@ func logMessage(level LogLevel, component string, message string, fields map[str
}
}
func appendFields(event *zerolog.Event, fields map[string]any) {
for k, v := range fields {
// Type switch to avoid double JSON serialization of strings
switch val := v.(type) {
case string:
event.Str(k, val)
case int:
event.Int(k, val)
case int64:
event.Int64(k, val)
case float64:
event.Float64(k, val)
case bool:
event.Bool(k, val)
default:
event.Interface(k, v) // Fallback for struct, slice and maps
}
}
}
func Debug(message string) {
logMessage(DEBUG, "", message, nil)
}
+25 -12
View File
@@ -2,7 +2,20 @@
package logger
import "fmt"
import (
"fmt"
"regexp"
)
// botTokenRe matches the bot ID prefix and the secret part of a Telegram bot token.
// Groups: 1 = "bot<id>:", 2 = first 4 chars of secret, 3 = middle, 4 = last 4 chars.
var botTokenRe = regexp.MustCompile(`(bot\d+:)([A-Za-z0-9_-]{4})[A-Za-z0-9_-]{12,}([A-Za-z0-9_-]{4})`)
// maskSecrets replaces any embedded bot tokens in s with a redacted placeholder
// that keeps the first and last 4 characters of the secret for identification.
func maskSecrets(s string) string {
return botTokenRe.ReplaceAllString(s, "${1}${2}****${3}")
}
// Logger implements common Logger interface
type Logger struct {
@@ -12,52 +25,52 @@ type Logger struct {
// Debug logs debug messages
func (b *Logger) Debug(v ...any) {
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Info logs info messages
func (b *Logger) Info(v ...any) {
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
logMessage(INFO, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Warn logs warning messages
func (b *Logger) Warn(v ...any) {
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Error logs error messages
func (b *Logger) Error(v ...any) {
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
logMessage(ERROR, b.component, maskSecrets(fmt.Sprint(v...)), nil)
}
// Debugf logs formatted debug messages
func (b *Logger) Debugf(format string, v ...any) {
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Infof logs formatted info messages
func (b *Logger) Infof(format string, v ...any) {
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
logMessage(INFO, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Warnf logs formatted warning messages
func (b *Logger) Warnf(format string, v ...any) {
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Warningf logs formatted warning messages
func (b *Logger) Warningf(format string, v ...any) {
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Errorf logs formatted error messages
func (b *Logger) Errorf(format string, v ...any) {
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
logMessage(ERROR, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Fatalf logs formatted fatal messages and exits
func (b *Logger) Fatalf(format string, v ...any) {
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
logMessage(FATAL, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
}
// Log logs a message at a given level with caller information
@@ -75,7 +88,7 @@ func (b *Logger) Log(msgL, caller int, format string, a ...any) {
level = lvl
}
}
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
logMessage(level, b.component, maskSecrets(fmt.Sprintf(format, a...)), nil)
}
// Sync flushes log buffer (no-op for this implementation)
+111
View File
@@ -141,3 +141,114 @@ func TestLoggerHelperFunctions(t *testing.T) {
Debugf("test from %v", "Debugf")
WarnF("Warning with fields", map[string]any{"key": "value"})
}
func TestFormatFieldValue(t *testing.T) {
tests := []struct {
name string
input any
expected string
}{
// Basic types test (default case of the switch)
{
name: "Integer Type",
input: 42,
expected: "42",
},
{
name: "Boolean Type",
input: true,
expected: "true",
},
{
name: "Unsupported Struct Type",
input: struct{ A int }{A: 1},
expected: "{1}",
},
// Simple strings and byte slices test
{
name: "Simple string without spaces",
input: "simple_value",
expected: "simple_value",
},
{
name: "Simple byte slice",
input: []byte("byte_value"),
expected: "byte_value",
},
// Unquoting test (strconv.Unquote)
{
name: "Quoted string",
input: `"quoted_value"`,
expected: "quoted_value",
},
// Strings with newline (\n) test
{
name: "String with newline",
input: "line1\nline2",
expected: "\nline1\nline2",
},
{
name: "Quoted string with newline (Unquote -> newline)",
input: `"line1\nline2"`, // Escaped \n that Unquote will resolve
expected: "\nline1\nline2",
},
// Strings with spaces test (which should be quoted)
{
name: "String with spaces",
input: "hello world",
expected: `"hello world"`,
},
{
name: "Quoted string with spaces (Unquote -> has spaces -> Re-quote)",
input: `"hello world"`,
expected: `"hello world"`,
},
// JSON formats test (strings with spaces that start/end with brackets)
{
name: "Valid JSON object",
input: `{"key": "value"}`,
expected: `{"key": "value"}`,
},
{
name: "Valid JSON array",
input: `[1, 2, "three"]`,
expected: `[1, 2, "three"]`,
},
{
name: "Fake JSON (starts with { but doesn't end with })",
input: `{"key": "value"`, // Missing closing bracket, has spaces
expected: `"{\"key\": \"value\""`,
},
{
name: "Empty JSON (object)",
input: `{ }`,
expected: `{ }`,
},
// 7. Edge Cases
{
name: "Empty string",
input: "",
expected: "",
},
{
name: "Whitespace only string",
input: " ",
expected: `" "`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := formatFieldValue(tt.input)
if actual != tt.expected {
t.Errorf("formatFieldValue() = %q, expected %q", actual, tt.expected)
}
})
}
}
+13
View File
@@ -0,0 +1,13 @@
package media
import (
"os"
"path/filepath"
)
const TempDirName = "picoclaw_media"
// TempDir returns the shared temporary directory used for downloaded media.
func TempDir() string {
return filepath.Join(os.TempDir(), TempDirName)
}
+7 -1
View File
@@ -221,11 +221,17 @@ func buildRequestBody(
// Add tool_use blocks
for _, tc := range msg.ToolCalls {
// Handle nil Arguments (GLM-4 may return null input)
input := tc.Arguments
if input == nil {
input = map[string]any{}
}
toolUse := map[string]any{
"type": "tool_use",
"id": tc.ID,
"name": tc.Name,
"input": tc.Arguments,
"input": input,
}
content = append(content, toolUse)
}
+150
View File
@@ -0,0 +1,150 @@
package azure
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
type (
LLMResponse = protocoltypes.LLMResponse
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
)
const (
// azureAPIVersion is the Azure OpenAI API version used for all requests.
azureAPIVersion = "2024-10-21"
defaultRequestTimeout = common.DefaultRequestTimeout
)
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
// It handles Azure-specific authentication (api-key header), URL construction
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
type Provider struct {
apiKey string
apiBase string
httpClient *http.Client
}
// Option configures the Azure Provider.
type Option func(*Provider)
// WithRequestTimeout sets the HTTP request timeout.
func WithRequestTimeout(timeout time.Duration) Option {
return func(p *Provider) {
if timeout > 0 {
p.httpClient.Timeout = timeout
}
}
}
// NewProvider creates a new Azure OpenAI provider.
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: common.NewHTTPClient(proxy),
}
for _, opt := range opts {
if opt != nil {
opt(p)
}
}
return p
}
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
return NewProvider(
apiKey, apiBase, proxy,
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
)
}
// Chat sends a chat completion request to the Azure OpenAI endpoint.
// The model parameter is used as the Azure deployment name in the URL.
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
tools []ToolDefinition,
model string,
options map[string]any,
) (*LLMResponse, error) {
if p.apiBase == "" {
return nil, fmt.Errorf("Azure API base not configured")
}
// model is the deployment name for Azure OpenAI
deployment := model
// Build Azure-specific URL safely using url.JoinPath and query encoding
// to prevent path traversal or query injection via deployment names.
base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
if err != nil {
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
}
requestURL := base + "?api-version=" + azureAPIVersion
// Build request body — no "model" field (Azure infers from deployment URL)
requestBody := map[string]any{
"messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
requestBody["tools"] = tools
requestBody["tool_choice"] = "auto"
}
// Azure OpenAI always uses max_completion_tokens
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
requestBody["max_completion_tokens"] = maxTokens
}
if temperature, ok := common.AsFloat(options["temperature"]); ok {
requestBody["temperature"] = temperature
}
jsonData, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Azure uses api-key header instead of Authorization: Bearer
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
req.Header.Set("Api-Key", p.apiKey)
}
resp, err := p.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
return common.ReadAndParseResponse(resp, p.apiBase)
}
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
func (p *Provider) GetDefaultModel() string {
return ""
}
+232
View File
@@ -0,0 +1,232 @@
package azure
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
func writeValidResponse(w http.ResponseWriter) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
func TestProviderChat_AzureURLConstruction(t *testing.T) {
var capturedPath string
var capturedAPIVersion string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.Path
capturedAPIVersion = r.URL.Query().Get("api-version")
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
if capturedPath != wantPath {
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
}
if capturedAPIVersion != azureAPIVersion {
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
}
}
func TestProviderChat_AzureAuthHeader(t *testing.T) {
var capturedAPIKey string
var capturedAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAPIKey = r.Header.Get("Api-Key")
capturedAuth = r.Header.Get("Authorization")
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-azure-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if capturedAPIKey != "test-azure-key" {
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
}
if capturedAuth != "" {
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
}
}
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["model"]; exists {
t.Error("request body should not contain 'model' field for Azure OpenAI")
}
}
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
var requestBody map[string]any
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewDecoder(r.Body).Decode(&requestBody)
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"deployment",
map[string]any{"max_tokens": 2048},
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if _, exists := requestBody["max_completion_tokens"]; !exists {
t.Error("request body should contain 'max_completion_tokens'")
}
if _, exists := requestBody["max_tokens"]; exists {
t.Error("request body should not contain 'max_tokens'")
}
}
func TestProviderChat_AzureHTTPError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
}))
defer server.Close()
p := NewProvider("bad-key", server.URL, "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{
"content": "",
"tool_calls": []map[string]any{
{
"id": "call_1",
"type": "function",
"function": map[string]any{
"name": "get_weather",
"arguments": `{"city":"Seattle"}`,
},
},
},
},
"finish_reason": "tool_calls",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
}
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
p := NewProvider("test-key", "", "")
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
if err == nil {
t.Fatal("expected error for empty API base")
}
}
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "")
if p.httpClient.Timeout != defaultRequestTimeout {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
}
}
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
if p.httpClient.Timeout != 300*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
}
}
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
if p.httpClient.Timeout != 180*time.Second {
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
}
}
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
var capturedPath string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
if capturedPath == "" {
capturedPath = r.URL.Path
}
writeValidResponse(w)
}))
defer server.Close()
p := NewProvider("test-key", server.URL, "")
// Deployment name with characters that could cause path injection
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
// The slash and special chars in the deployment name must be escaped, not treated as path separators
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
t.Fatal("deployment name was interpolated without escaping — path injection possible")
}
}
+380
View File
@@ -0,0 +1,380 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
// Package common provides shared utilities used by multiple LLM provider
// implementations (openai_compat, azure, etc.).
package common
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// Re-export protocol types used across providers.
type (
ToolCall = protocoltypes.ToolCall
FunctionCall = protocoltypes.FunctionCall
LLMResponse = protocoltypes.LLMResponse
UsageInfo = protocoltypes.UsageInfo
Message = protocoltypes.Message
ToolDefinition = protocoltypes.ToolDefinition
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
ExtraContent = protocoltypes.ExtraContent
GoogleExtra = protocoltypes.GoogleExtra
ReasoningDetail = protocoltypes.ReasoningDetail
)
const DefaultRequestTimeout = 120 * time.Second
// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
func NewHTTPClient(proxy string) *http.Client {
client := &http.Client{
Timeout: DefaultRequestTimeout,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
// Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
if base, ok := http.DefaultTransport.(*http.Transport); ok {
tr := base.Clone()
tr.Proxy = http.ProxyURL(parsed)
client.Transport = tr
} else {
// Fallback: minimal transport if DefaultTransport is not *http.Transport.
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
}
} else {
log.Printf("common: invalid proxy URL %q: %v", proxy, err)
}
}
return client
}
// --- Message serialization ---
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// SerializeMessages converts internal Message structs to the OpenAI wire format.
// - Strips SystemParts (unknown to third-party endpoints)
// - Converts messages with Media to multipart content format (text + image_url parts)
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
func SerializeMessages(messages []Message) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
out = append(out, openaiMessage{
Role: m.Role,
Content: m.Content,
ReasoningContent: m.ReasoningContent,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
})
continue
}
// Multipart content format for messages with media
parts := make([]map[string]any, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]any{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
out = append(out, msg)
}
return out
}
// --- Response parsing ---
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
func ParseResponse(body io.Reader) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
Reasoning string `json:"reasoning"`
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
ThoughtSignature string `json:"thought_signature"`
} `json:"google"`
} `json:"extra_content"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]any)
name := ""
// Extract thought_signature from Gemini/Google-specific extra content
thoughtSignature := ""
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
if tc.Function != nil {
name = tc.Function.Name
arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
}
toolCall := ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
ThoughtSignature: thoughtSignature,
}
if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{
ThoughtSignature: thoughtSignature,
},
}
}
toolCalls = append(toolCalls, toolCall)
}
return &LLMResponse{
Content: choice.Message.Content,
ReasoningContent: choice.Message.ReasoningContent,
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
return arguments
}
var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
arguments["raw"] = string(raw)
return arguments
}
switch v := decoded.(type) {
case string:
if strings.TrimSpace(v) == "" {
return arguments
}
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = v
}
return arguments
case map[string]any:
return v
default:
log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
arguments["raw"] = string(raw)
return arguments
}
}
// --- HTTP response helpers ---
// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
func HandleErrorResponse(resp *http.Response, apiBase string) error {
contentType := resp.Header.Get("Content-Type")
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return fmt.Errorf("failed to read response: %w", readErr)
}
if LooksLikeHTML(body, contentType) {
return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
}
return fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
ResponsePreview(body, 128),
)
}
// ReadAndParseResponse peeks at the response body to detect HTML errors,
// then parses the JSON response into an LLMResponse.
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
contentType := resp.Header.Get("Content-Type")
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256)
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if LooksLikeHTML(prefix, contentType) {
return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
}
out, err := ParseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
// LooksLikeHTML checks if the response body appears to be HTML.
func LooksLikeHTML(body []byte, contentType string) bool {
contentType = strings.ToLower(strings.TrimSpace(contentType))
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
return true
}
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
bytes.HasPrefix(prefix, []byte("<html")) ||
bytes.HasPrefix(prefix, []byte("<head")) ||
bytes.HasPrefix(prefix, []byte("<body"))
}
// WrapHTMLResponseError creates a descriptive error for HTML responses.
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
respPreview := ResponsePreview(body, 128)
return fmt.Errorf(
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
apiBase,
contentType,
statusCode,
respPreview,
)
}
// ResponsePreview returns a truncated preview of response body for error messages.
func ResponsePreview(body []byte, maxLen int) string {
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
return "<empty>"
}
if len(trimmed) <= maxLen {
return string(trimmed)
}
return string(trimmed[:maxLen]) + "..."
}
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
i := 0
for i < len(body) {
switch body[i] {
case ' ', '\t', '\n', '\r', '\f', '\v':
i++
default:
end := i + maxLen
if end > len(body) {
end = len(body)
}
return body[i:end]
}
}
return nil
}
// --- Numeric helpers ---
// AsInt converts various numeric types to int.
func AsInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
// AsFloat converts various numeric types to float64.
func AsFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
+558
View File
@@ -0,0 +1,558 @@
package common
import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
// --- NewHTTPClient tests ---
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
client := NewHTTPClient("")
if client.Timeout != DefaultRequestTimeout {
t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
}
}
func TestNewHTTPClient_WithProxy(t *testing.T) {
client := NewHTTPClient("http://127.0.0.1:8080")
transport, ok := client.Transport.(*http.Transport)
if !ok || transport == nil {
t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
}
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
gotProxy, err := transport.Proxy(req)
if err != nil {
t.Fatalf("proxy function error: %v", err)
}
if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
}
}
func TestNewHTTPClient_NoProxy(t *testing.T) {
client := NewHTTPClient("")
if client.Transport != nil {
t.Errorf("expected nil transport without proxy, got %T", client.Transport)
}
}
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
// Should not panic, just log and return client without proxy
client := NewHTTPClient("://bad-url")
if client == nil {
t.Fatal("expected non-nil client even with invalid proxy")
}
}
// --- SerializeMessages tests ---
func TestSerializeMessages_PlainText(t *testing.T) {
messages := []Message{
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
if msgs[0]["content"] != "hello" {
t.Errorf("expected plain string content, got %v", msgs[0]["content"])
}
if msgs[1]["reasoning_content"] != "thinking..." {
t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
}
}
func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
content, ok := msgs[0]["content"].([]any)
if !ok {
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
}
if len(content) != 2 {
t.Fatalf("expected 2 content parts, got %d", len(content))
}
}
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []Message{
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
json.Unmarshal(data, &msgs)
if msgs[0]["tool_call_id"] != "call_1" {
t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
}
}
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
messages := []Message{
{
Role: "system",
Content: "you are helpful",
SystemParts: []protocoltypes.ContentBlock{
{Type: "text", Text: "you are helpful"},
},
},
}
result := SerializeMessages(messages)
data, _ := json.Marshal(result)
if strings.Contains(string(data), "system_parts") {
t.Error("system_parts should not appear in serialized output")
}
}
// --- ParseResponse tests ---
func TestParseResponse_BasicContent(t *testing.T) {
body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.Content != "hello world" {
t.Errorf("Content = %q, want %q", out.Content, "hello world")
}
if out.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
}
}
func TestParseResponse_EmptyChoices(t *testing.T) {
body := `{"choices":[]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.Content != "" {
t.Errorf("Content = %q, want empty", out.Content)
}
if out.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
}
}
func TestParseResponse_WithToolCalls(t *testing.T) {
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
}
if out.ToolCalls[0].Arguments["city"] != "SF" {
t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
}
}
func TestParseResponse_WithUsage(t *testing.T) {
body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.Usage == nil {
t.Fatal("Usage is nil")
}
if out.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
}
}
func TestParseResponse_WithReasoningContent(t *testing.T) {
body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if out.ReasoningContent != "Let me think... 1+1=2" {
t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
}
}
func TestParseResponse_InvalidJSON(t *testing.T) {
_, err := ParseResponse(strings.NewReader("not json"))
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}
// --- DecodeToolCallArguments tests ---
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
args := DecodeToolCallArguments(raw, "test")
if args["city"] != "Seattle" {
t.Errorf("city = %v, want Seattle", args["city"])
}
if args["units"] != "metric" {
t.Errorf("units = %v, want metric", args["units"])
}
}
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
args := DecodeToolCallArguments(raw, "test")
if args["city"] != "SF" {
t.Errorf("city = %v, want SF", args["city"])
}
}
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
args := DecodeToolCallArguments(nil, "test")
if len(args) != 0 {
t.Errorf("expected empty map, got %v", args)
}
}
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
if len(args) != 0 {
t.Errorf("expected empty map, got %v", args)
}
}
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
if _, ok := args["raw"]; !ok {
t.Error("expected 'raw' fallback key for invalid JSON")
}
}
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
if len(args) != 0 {
t.Errorf("expected empty map for whitespace string, got %v", args)
}
}
// --- HandleErrorResponse tests ---
func TestHandleErrorResponse_JSONError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"bad request"}`))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
err = HandleErrorResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "400") {
t.Errorf("error should contain status code, got %v", err)
}
if strings.Contains(err.Error(), "HTML") {
t.Errorf("should not mention HTML for JSON error, got %v", err)
}
}
func TestHandleErrorResponse_HTMLError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusBadGateway)
w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
err = HandleErrorResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "HTML instead of JSON") {
t.Errorf("expected HTML error message, got %v", err)
}
}
// --- ReadAndParseResponse tests ---
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
out, err := ReadAndParseResponse(resp, server.URL)
if err != nil {
t.Fatalf("ReadAndParseResponse() error = %v", err)
}
if out.Content != "ok" {
t.Errorf("Content = %q, want %q", out.Content, "ok")
}
}
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
_, err = ReadAndParseResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error for HTML response")
}
if !strings.Contains(err.Error(), "HTML instead of JSON") {
t.Errorf("expected HTML error, got %v", err)
}
}
// --- LooksLikeHTML tests ---
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
t.Error("expected true for text/html content type")
}
}
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
if !LooksLikeHTML(nil, "application/xhtml+xml") {
t.Error("expected true for xhtml content type")
}
}
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
tests := []struct {
name string
body string
}{
{"doctype", "<!DOCTYPE html><html>"},
{"html tag", "<html><body>"},
{"head tag", "<head><title>"},
{"body tag", "<body>content"},
{"whitespace before", " \n\t<!DOCTYPE html>"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if !LooksLikeHTML([]byte(tt.body), "application/json") {
t.Errorf("expected true for body %q", tt.body)
}
})
}
}
func TestLooksLikeHTML_NotHTML(t *testing.T) {
if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
t.Error("expected false for JSON body")
}
}
// --- ResponsePreview tests ---
func TestResponsePreview_Short(t *testing.T) {
got := ResponsePreview([]byte("hello"), 128)
if got != "hello" {
t.Errorf("got %q, want %q", got, "hello")
}
}
func TestResponsePreview_Truncated(t *testing.T) {
body := strings.Repeat("a", 200)
got := ResponsePreview([]byte(body), 128)
if len(got) != 131 { // 128 + "..."
t.Errorf("len = %d, want 131", len(got))
}
if !strings.HasSuffix(got, "...") {
t.Error("expected ... suffix")
}
}
func TestResponsePreview_Empty(t *testing.T) {
got := ResponsePreview([]byte(""), 128)
if got != "<empty>" {
t.Errorf("got %q, want %q", got, "<empty>")
}
}
func TestResponsePreview_Whitespace(t *testing.T) {
got := ResponsePreview([]byte(" \n\t "), 128)
if got != "<empty>" {
t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>")
}
}
// --- AsInt tests ---
func TestAsInt(t *testing.T) {
tests := []struct {
name string
val any
want int
ok bool
}{
{"int", 42, 42, true},
{"int64", int64(99), 99, true},
{"float64", float64(512), 512, true},
{"float32", float32(256), 256, true},
{"string", "nope", 0, false},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := AsInt(tt.val)
if ok != tt.ok || got != tt.want {
t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
}
})
}
}
// --- AsFloat tests ---
func TestAsFloat(t *testing.T) {
tests := []struct {
name string
val any
want float64
ok bool
}{
{"float64", float64(0.7), 0.7, true},
{"float32", float32(0.5), float64(float32(0.5)), true},
{"int", 1, 1.0, true},
{"int64", int64(100), 100.0, true},
{"string", "nope", 0, false},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := AsFloat(tt.val)
if ok != tt.ok || got != tt.want {
t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
}
})
}
}
// --- WrapHTMLResponseError tests ---
func TestWrapHTMLResponseError(t *testing.T) {
err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
if err == nil {
t.Fatal("expected error")
}
msg := err.Error()
if !strings.Contains(msg, "502") {
t.Errorf("expected status code in error, got %v", msg)
}
if !strings.Contains(msg, "https://api.example.com") {
t.Errorf("expected api base in error, got %v", msg)
}
if !strings.Contains(msg, "HTML instead of JSON") {
t.Errorf("expected HTML mention in error, got %v", msg)
}
}
// --- HandleErrorResponse with read failure ---
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusInternalServerError)
// empty body
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
err = HandleErrorResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "500") {
t.Errorf("expected status code, got %v", err)
}
}
// --- ReadAndParseResponse with invalid JSON ---
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("not valid json"))
}))
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("http.Get() error = %v", err)
}
defer resp.Body.Close()
_, err = ReadAndParseResponse(resp, server.URL)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}
// --- ParseResponse with thought_signature (Google/Gemini) ---
func TestParseResponse_WithThoughtSignature(t *testing.T) {
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
out, err := ParseResponse(strings.NewReader(body))
if err != nil {
t.Fatalf("ParseResponse() error = %v", err)
}
if len(out.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
}
if out.ToolCalls[0].ThoughtSignature != "sig123" {
t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
}
if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
t.Fatal("ExtraContent.Google is nil")
}
if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
}
}
+19
View File
@@ -11,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
"github.com/sipeed/picoclaw/pkg/providers/azure"
)
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
cfg.RequestTimeout,
), modelID, nil
case "azure", "azure-openai":
// Azure OpenAI uses deployment-based URLs, api-key header auth,
// and always sends max_completion_tokens.
if cfg.APIKey == "" {
return nil, "", fmt.Errorf("api_key is required for azure protocol")
}
if cfg.APIBase == "" {
return nil, "", fmt.Errorf(
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
)
}
return azure.NewProviderWithTimeout(
cfg.APIKey,
cfg.APIBase,
cfg.Proxy,
cfg.RequestTimeout,
), modelID, nil
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
+72
View File
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
wantProtocol: "nvidia",
wantModelID: "meta/llama-3.1-8b",
},
{
name: "azure with prefix",
model: "azure/my-gpt5-deployment",
wantProtocol: "azure",
wantModelID: "my-gpt5-deployment",
},
}
for _, tt := range tests {
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
}
}
func TestCreateProviderFromConfig_Azure(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIKey: "test-azure-key",
APIBase: "https://my-resource.openai.azure.com",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "my-gpt5-deployment" {
t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
}
}
func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt4",
Model: "azure-openai/my-deployment",
APIKey: "test-azure-key",
APIBase: "https://my-resource.openai.azure.com",
}
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
if provider == nil {
t.Fatal("CreateProviderFromConfig() returned nil provider")
}
if modelID != "my-deployment" {
t.Errorf("modelID = %q, want %q", modelID, "my-deployment")
}
}
func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIBase: "https://my-resource.openai.azure.com",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
}
}
func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
cfg := &config.ModelConfig{
ModelName: "azure-gpt5",
Model: "azure/my-gpt5-deployment",
APIKey: "test-azure-key",
}
_, _, err := CreateProviderFromConfig(cfg)
if err == nil {
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
}
}
+8 -319
View File
@@ -1,18 +1,16 @@
package openai_compat
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -38,7 +36,7 @@ type Provider struct {
type Option func(*Provider)
const defaultRequestTimeout = 120 * time.Second
const defaultRequestTimeout = common.DefaultRequestTimeout
func WithMaxTokensField(maxTokensField string) Option {
return func(p *Provider) {
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
client := &http.Client{
Timeout: defaultRequestTimeout,
}
if proxy != "" {
parsed, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(parsed),
}
} else {
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
}
}
p := &Provider{
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
httpClient: common.NewHTTPClient(proxy),
}
for _, opt := range opts {
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
requestBody := map[string]any{
"model": model,
"messages": serializeMessages(messages),
"messages": common.SerializeMessages(messages),
}
if len(tools) > 0 {
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
requestBody["tool_choice"] = "auto"
}
if maxTokens, ok := asInt(options["max_tokens"]); ok {
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
fieldName := p.maxTokensField
if fieldName == "" {
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
requestBody[fieldName] = maxTokens
}
if temperature, ok := asFloat(options["temperature"]); ok {
if temperature, ok := common.AsFloat(options["temperature"]); ok {
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1.
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
}
defer resp.Body.Close()
contentType := resp.Header.Get("Content-Type")
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
if resp.StatusCode != http.StatusOK {
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
if readErr != nil {
return nil, fmt.Errorf("failed to read response: %w", readErr)
}
if looksLikeHTML(body, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
}
return nil, fmt.Errorf(
"API request failed:\n Status: %d\n Body: %s",
resp.StatusCode,
responsePreview(body, 128),
)
return nil, common.HandleErrorResponse(resp, p.apiBase)
}
// Peek without consuming so the full stream reaches the JSON decoder.
reader := bufio.NewReader(resp.Body)
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
return nil, fmt.Errorf("failed to inspect response: %w", err)
}
if looksLikeHTML(prefix, contentType) {
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
}
out, err := parseResponse(reader)
if err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return out, nil
}
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
respPreview := responsePreview(body, 128)
return fmt.Errorf(
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
apiBase,
contentType,
statusCode,
respPreview,
)
}
func looksLikeHTML(body []byte, contentType string) bool {
contentType = strings.ToLower(strings.TrimSpace(contentType))
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
return true
}
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
bytes.HasPrefix(prefix, []byte("<html")) ||
bytes.HasPrefix(prefix, []byte("<head")) ||
bytes.HasPrefix(prefix, []byte("<body"))
}
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
i := 0
for i < len(body) {
switch body[i] {
case ' ', '\t', '\n', '\r', '\f', '\v':
i++
default:
end := i + maxLen
if end > len(body) {
end = len(body)
}
return body[i:end]
}
}
return nil
}
func responsePreview(body []byte, maxLen int) string {
trimmed := bytes.TrimSpace(body)
if len(trimmed) == 0 {
return "<empty>"
}
if len(trimmed) <= maxLen {
return string(trimmed)
}
return string(trimmed[:maxLen]) + "..."
}
func parseResponse(body io.Reader) (*LLMResponse, error) {
var apiResponse struct {
Choices []struct {
Message struct {
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content"`
Reasoning string `json:"reasoning"`
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function *struct {
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
} `json:"function"`
ExtraContent *struct {
Google *struct {
ThoughtSignature string `json:"thought_signature"`
} `json:"google"`
} `json:"extra_content"`
} `json:"tool_calls"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage *UsageInfo `json:"usage"`
}
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(apiResponse.Choices) == 0 {
return &LLMResponse{
Content: "",
FinishReason: "stop",
}, nil
}
choice := apiResponse.Choices[0]
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
for _, tc := range choice.Message.ToolCalls {
arguments := make(map[string]any)
name := ""
// Extract thought_signature from Gemini/Google-specific extra content
thoughtSignature := ""
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
}
if tc.Function != nil {
name = tc.Function.Name
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
}
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
toolCall := ToolCall{
ID: tc.ID,
Name: name,
Arguments: arguments,
ThoughtSignature: thoughtSignature,
}
if thoughtSignature != "" {
toolCall.ExtraContent = &ExtraContent{
Google: &GoogleExtra{
ThoughtSignature: thoughtSignature,
},
}
}
toolCalls = append(toolCalls, toolCall)
}
return &LLMResponse{
Content: choice.Message.Content,
ReasoningContent: choice.Message.ReasoningContent,
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
Usage: apiResponse.Usage,
}, nil
}
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
raw = bytes.TrimSpace(raw)
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
return arguments
}
var decoded any
if err := json.Unmarshal(raw, &decoded); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
arguments["raw"] = string(raw)
return arguments
}
switch v := decoded.(type) {
case string:
if strings.TrimSpace(v) == "" {
return arguments
}
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
arguments["raw"] = v
}
return arguments
case map[string]any:
return v
default:
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
arguments["raw"] = string(raw)
return arguments
}
}
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
// It mirrors protocoltypes.Message but omits SystemParts, which is an
// internal field that would be unknown to third-party endpoints.
type openaiMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoning_content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
// serializeMessages converts internal Message structs to the OpenAI wire format.
// - Strips SystemParts (unknown to third-party endpoints)
// - Converts messages with Media to multipart content format (text + image_url parts)
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
func serializeMessages(messages []Message) []any {
out := make([]any, 0, len(messages))
for _, m := range messages {
if len(m.Media) == 0 {
out = append(out, openaiMessage{
Role: m.Role,
Content: m.Content,
ReasoningContent: m.ReasoningContent,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
})
continue
}
// Multipart content format for messages with media
parts := make([]map[string]any, 0, 1+len(m.Media))
if m.Content != "" {
parts = append(parts, map[string]any{
"type": "text",
"text": m.Content,
})
}
for _, mediaURL := range m.Media {
if strings.HasPrefix(mediaURL, "data:image/") {
parts = append(parts, map[string]any{
"type": "image_url",
"image_url": map[string]any{
"url": mediaURL,
},
})
}
}
msg := map[string]any{
"role": m.Role,
"content": parts,
}
if m.ToolCallID != "" {
msg["tool_call_id"] = m.ToolCallID
}
if len(m.ToolCalls) > 0 {
msg["tool_calls"] = m.ToolCalls
}
if m.ReasoningContent != "" {
msg["reasoning_content"] = m.ReasoningContent
}
out = append(out, msg)
}
return out
return common.ReadAndParseResponse(resp, p.apiBase)
}
func normalizeModel(model, apiBase string) string {
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
}
}
func asInt(v any) (int, bool) {
switch val := v.(type) {
case int:
return val, true
case int64:
return int(val), true
case float64:
return int(val), true
case float32:
return int(val), true
default:
return 0, false
}
}
func asFloat(v any) (float64, bool) {
switch val := v.(type) {
case float64:
return val, true
case float32:
return float64(val), true
case int:
return float64(val), true
case int64:
return float64(val), true
default:
return 0, false
}
}
// supportsPromptCacheKey reports whether the given API base is known to
// support the prompt_cache_key request field. Currently only OpenAI's own
// API and Azure OpenAI support this. All other OpenAI-compatible providers
+5 -4
View File
@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers/common"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
)
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
{Role: "user", Content: "hello"},
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, err := json.Marshal(result)
if err != nil {
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
messages := []protocoltypes.Message{
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
var msgs []map[string]any
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
},
},
}
result := serializeMessages(messages)
result := common.SerializeMessages(messages)
data, _ := json.Marshal(result)
raw := string(data)
+53 -20
View File
@@ -20,10 +20,12 @@ type JobExecutor interface {
// CronTool provides scheduling capabilities for the agent
type CronTool struct {
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
allowCommand bool
execEnabled bool
}
// NewCronTool creates a new CronTool
@@ -32,17 +34,32 @@ func NewCronTool(
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
execTimeout time.Duration, config *config.Config,
) (*CronTool, error) {
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
allowCommand := true
execEnabled := true
if config != nil {
allowCommand = config.Tools.Cron.AllowCommand
execEnabled = config.Tools.Exec.Enabled
}
execTool.SetTimeout(execTimeout)
var execTool *ExecTool
if execEnabled {
var err error
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
if err != nil {
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
}
}
if execTool != nil {
execTool.SetTimeout(execTimeout)
}
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: execTool,
allowCommand: allowCommand,
execEnabled: execEnabled,
}, nil
}
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
},
"command_confirm": map[string]any{
"type": "boolean",
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
},
"at_seconds": map[string]any{
"type": "integer",
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
},
"deliver": map[string]any{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
},
},
"required": []string{"action"},
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
}
// Read deliver parameter, default to true
deliver := true
// Read deliver parameter, default to false so scheduled tasks execute through the agent
deliver := false
if d, ok := args["deliver"].(bool); ok {
deliver = d
}
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
// Non-command reminders (plain messages) remain open to all channels.
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
// allow_command is disabled, explicit confirmation is required as an override.
// Non-command reminders remain open to all channels.
command, _ := args["command"].(string)
commandConfirm, _ := args["command_confirm"].(bool)
if command != "" {
if !t.execEnabled {
return ErrorResult("command execution is disabled")
}
if !constants.IsInternalChannel(channel) {
return ErrorResult("scheduling command execution is restricted to internal channels")
}
if !commandConfirm {
return ErrorResult("command_confirm=true is required to schedule command execution")
if !t.allowCommand && !commandConfirm {
return ErrorResult("command_confirm=true is required when allow_command is disabled")
}
deliver = false
}
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Execute command if present
if job.Payload.Command != "" {
if !t.execEnabled || t.execTool == nil {
output := "Error executing scheduled command: command execution is disabled"
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
})
return "ok"
}
args := map[string]any{
"command": job.Payload.Command,
"__channel": channel,
+126 -6
View File
@@ -5,18 +5,18 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
)
func newTestCronTool(t *testing.T) *CronTool {
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
t.Helper()
storePath := filepath.Join(t.TempDir(), "cron.json")
cronService := cron.NewCronService(storePath, nil)
msgBus := bus.NewMessageBus()
cfg := config.DefaultConfig()
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
if err != nil {
t.Fatalf("NewCronTool() error: %v", err)
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
return tool
}
func newTestCronTool(t *testing.T) *CronTool {
t.Helper()
return newTestCronToolWithConfig(t, config.DefaultConfig())
}
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
tool := newTestCronTool(t)
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
}
}
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
"at_seconds": float64(60),
})
if result.IsError {
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
}
}
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Cron.AllowCommand = false
tool := newTestCronToolWithConfig(t, cfg)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected error when command_confirm is missing")
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
}
if !strings.Contains(result.ForLLM, "command_confirm=true") {
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
}
}
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Cron.AllowCommand = false
tool := newTestCronToolWithConfig(t, cfg)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"command_confirm": true,
"at_seconds": float64(60),
})
if result.IsError {
t.Fatalf(
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
result.ForLLM,
)
}
if !strings.Contains(result.ForLLM, "Cron job added") {
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
}
}
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Exec.Enabled = false
tool := newTestCronToolWithConfig(t, cfg)
ctx := WithToolContext(context.Background(), "cli", "direct")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "check disk",
"command": "df -h",
"command_confirm": true,
"at_seconds": float64(60),
})
if !result.IsError {
t.Fatal("expected command scheduling to be blocked when exec is disabled")
}
if !strings.Contains(result.ForLLM, "command execution is disabled") {
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
}
}
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
}
}
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
tool := newTestCronTool(t)
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
result := tool.Execute(ctx, map[string]any{
"action": "add",
"message": "send me a poem",
"at_seconds": float64(600),
})
if result.IsError {
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
}
jobs := tool.cronService.ListJobs(false)
if len(jobs) != 1 {
t.Fatalf("expected 1 job, got %d", len(jobs))
}
if jobs[0].Payload.Deliver {
t.Fatal("expected deliver=false by default for non-command jobs")
}
}
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Tools.Exec.Enabled = false
tool := newTestCronToolWithConfig(t, cfg)
job := &cron.CronJob{}
job.Payload.Channel = "cli"
job.Payload.To = "direct"
job.Payload.Command = "df -h"
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
t.Fatalf("ExecuteJob() = %q, want ok", got)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
}
if !strings.Contains(msg.Content, "command execution is disabled") {
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
}
}
+161 -9
View File
@@ -20,8 +20,7 @@ import (
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
if workspace == "" {
return path, fmt.Errorf("workspace is not defined")
}
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
}
if restrict {
if isAllowedPath(absPath, patterns) {
return absPath, nil
}
if !isWithinWorkspace(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
return absPath, nil
}
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
if len(patterns) == 0 {
return false
}
cleaned := filepath.Clean(path)
if !filepath.IsAbs(cleaned) {
return false
}
if !matchesAllowedPath(cleaned, patterns) {
return false
}
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
if err != nil {
return false
}
return matchesAllowedPath(resolved, patterns)
}
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
cleaned := filepath.Clean(path)
for _, pattern := range patterns {
if pattern.MatchString(cleaned) {
return true
}
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
return true
}
}
return false
}
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
raw := pattern.String()
if !strings.HasPrefix(raw, "^") {
return "", false
}
literal := strings.TrimPrefix(raw, "^")
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
literal = strings.TrimSuffix(literal, "(?:/|$)")
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
// Reject patterns that still contain regex operators after removing the
// optional anchored-directory suffix. That keeps arbitrary regex behavior
// unchanged and only enables normalized prefix matching for literal paths.
if containsUnescapedRegexMeta(literal) {
return "", false
}
unescaped, ok := unescapeRegexLiteral(literal)
if !ok || unescaped == "" {
return "", false
}
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
}
func appendUniquePath(paths []string, path string) []string {
for _, existing := range paths {
if existing == path {
return paths
}
}
return append(paths, path)
}
func containsUnescapedRegexMeta(s string) bool {
escaped := false
for _, r := range s {
if escaped {
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
switch r {
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
return true
}
}
return escaped
}
func unescapeRegexLiteral(s string) (string, bool) {
var b strings.Builder
b.Grow(len(s))
escaped := false
for _, r := range s {
if escaped {
b.WriteRune(r)
escaped = false
continue
}
if r == '\\' {
escaped = true
continue
}
b.WriteRune(r)
}
if escaped {
return "", false
}
return b.String(), true
}
func isWithinAllowedRoot(path, root string) bool {
candidate := filepath.Clean(path)
allowedVariants := []string{filepath.Clean(root)}
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
}
for _, allowedRoot := range allowedVariants {
if isWithinWorkspace(candidate, allowedRoot) {
return true
}
}
return false
}
func resolveExistingAncestor(path string) (string, error) {
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
if resolved, err := filepath.EvalSymlinks(current); err == nil {
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
}
}
func resolvePathAgainstExistingAncestor(path string) (string, error) {
cleaned := filepath.Clean(path)
for current := cleaned; ; current = filepath.Dir(current) {
resolved, err := filepath.EvalSymlinks(current)
if err == nil {
suffix, relErr := filepath.Rel(current, cleaned)
if relErr != nil {
return "", relErr
}
if suffix == "." {
return filepath.Clean(resolved), nil
}
return filepath.Clean(filepath.Join(resolved, suffix)), nil
}
if !os.IsNotExist(err) {
return "", err
}
if filepath.Dir(current) == current {
return "", os.ErrNotExist
}
}
}
func isWithinWorkspace(candidate, workspace string) bool {
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
return err == nil && filepath.IsLocal(rel)
return err == nil && (rel == "." || filepath.IsLocal(rel))
}
type ReadFileTool struct {
@@ -625,12 +782,7 @@ type whitelistFs struct {
}
func (w *whitelistFs) matches(path string) bool {
for _, p := range w.patterns {
if p.MatchString(path) {
return true
}
}
return false
return isAllowedPath(path, w.patterns)
}
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
+84
View File
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
}
}
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
workspace := t.TempDir()
allowedDir := t.TempDir()
secretDir := t.TempDir()
secretFile := filepath.Join(secretDir, "secret.txt")
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
t.Fatalf("WriteFile(secretFile) error = %v", err)
}
linkPath := filepath.Join(allowedDir, "link_out")
if err := os.Symlink(secretDir, linkPath); err != nil {
t.Skipf("symlink not supported in this environment: %v", err)
}
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
if !result.IsError {
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
}
}
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
workspace := t.TempDir()
rootDir := t.TempDir()
allowedDir := filepath.Join(rootDir, "allowed")
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
tool := NewWriteFileTool(workspace, true, patterns)
result := tool.Execute(context.Background(), map[string]any{
"path": targetFile,
"content": "outside write",
})
if result.IsError {
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
}
data, err := os.ReadFile(targetFile)
if err != nil {
t.Fatalf("ReadFile(targetFile) error = %v", err)
}
if string(data) != "outside write" {
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
}
}
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
workspace := t.TempDir()
realDir := t.TempDir()
linkParent := t.TempDir()
allowedAlias := filepath.Join(linkParent, "allowed-link")
if err := os.Symlink(realDir, allowedAlias); err != nil {
t.Skipf("symlink not supported in this environment: %v", err)
}
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
}
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
t.Fatalf("WriteFile(targetFile) error = %v", err)
}
patterns := []*regexp.Regexp{
regexp.MustCompile(
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
),
}
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
if result.IsError {
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "through alias") {
t.Fatalf("expected file content, got: %s", result.ForLLM)
}
}
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
// by reading a file in multiple chunks using 'offset' and 'length'.
func TestReadFileTool_ChunkedReading(t *testing.T) {
+15 -2
View File
@@ -6,6 +6,7 @@ import (
"mime"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/h2non/filetype"
@@ -21,20 +22,32 @@ type SendFileTool struct {
restrict bool
maxFileSize int
mediaStore media.MediaStore
allowPaths []*regexp.Regexp
defaultChannel string
defaultChatID string
}
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
func NewSendFileTool(
workspace string,
restrict bool,
maxFileSize int,
store media.MediaStore,
allowPaths ...[]*regexp.Regexp,
) *SendFileTool {
if maxFileSize <= 0 {
maxFileSize = config.DefaultMaxMediaSize
}
var patterns []*regexp.Regexp
if len(allowPaths) > 0 {
patterns = allowPaths[0]
}
return &SendFileTool{
workspace: workspace,
restrict: restrict,
maxFileSize: maxFileSize,
mediaStore: store,
allowPaths: patterns,
}
}
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
return ErrorResult("media store not configured")
}
resolved, err := validatePath(path, t.workspace, t.restrict)
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
if err != nil {
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
}
+39
View File
@@ -4,6 +4,7 @@ import (
"context"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
}
}
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
workspace := t.TempDir()
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
}
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
if err != nil {
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
}
testPath := testFile.Name()
if _, err := testFile.WriteString("forward me"); err != nil {
testFile.Close()
t.Fatalf("WriteString(testFile) error = %v", err)
}
if err := testFile.Close(); err != nil {
t.Fatalf("Close(testFile) error = %v", err)
}
t.Cleanup(func() { _ = os.Remove(testPath) })
pattern := regexp.MustCompile(
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
)
store := media.NewFileMediaStore()
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
tool.SetContext("feishu", "chat123")
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
if result.IsError {
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
}
if len(result.Media) != 1 {
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
}
}
func TestDetectMediaType_MagicBytes(t *testing.T) {
dir := t.TempDir()
+31 -13
View File
@@ -23,6 +23,7 @@ type ExecTool struct {
denyPatterns []*regexp.Regexp
allowPatterns []*regexp.Regexp
customAllowPatterns []*regexp.Regexp
allowedPathPatterns []*regexp.Regexp
restrictToWorkspace bool
allowRemote bool
}
@@ -95,14 +96,23 @@ var (
}
)
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil)
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
}
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
func NewExecToolWithConfig(
workingDir string,
restrict bool,
config *config.Config,
allowPaths ...[]*regexp.Regexp,
) (*ExecTool, error) {
denyPatterns := make([]*regexp.Regexp, 0)
customAllowPatterns := make([]*regexp.Regexp, 0)
var allowedPathPatterns []*regexp.Regexp
allowRemote := true
if len(allowPaths) > 0 {
allowedPathPatterns = allowPaths[0]
}
if config != nil {
execConfig := config.Tools.Exec
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
denyPatterns: denyPatterns,
allowPatterns: nil,
customAllowPatterns: customAllowPatterns,
allowedPathPatterns: allowedPathPatterns,
restrictToWorkspace: restrict,
allowRemote: allowRemote,
}, nil
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
cwd := t.workingDir
if wd, ok := args["working_dir"].(string); ok && wd != "" {
if t.restrictToWorkspace && t.workingDir != "" {
resolvedWD, err := validatePath(wd, t.workingDir, true)
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
if err != nil {
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
}
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
if err != nil {
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
}
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
if isAllowedPath(resolved, t.allowedPathPatterns) {
cwd = resolved
} else {
absWorkspace, _ := filepath.Abs(t.workingDir)
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
if wsResolved == "" {
wsResolved = absWorkspace
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
rel, err := filepath.Rel(wsResolved, resolved)
if err != nil || !filepath.IsLocal(rel) {
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
}
cwd = resolved
}
// timeout == 0 means no timeout
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
if safePaths[p] {
continue
}
if isAllowedPath(p, t.allowedPathPatterns) {
continue
}
rel, err := filepath.Rel(cwdPath, p)
if err != nil {
+178
View File
@@ -0,0 +1,178 @@
package tools
import (
"context"
"fmt"
"sort"
"strings"
"time"
)
// SpawnStatusTool reports the status of subagents that were spawned via the
// spawn tool. It can query a specific task by ID, or list every known task with
// a summary count broken-down by status.
type SpawnStatusTool struct {
manager *SubagentManager
}
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
return &SpawnStatusTool{manager: manager}
}
func (t *SpawnStatusTool) Name() string {
return "spawn_status"
}
func (t *SpawnStatusTool) Description() string {
return "Get the status of spawned subagents. " +
"Returns a list of all subagents and their current state " +
"(running, completed, failed, or canceled), or retrieves details " +
"for a specific subagent task when task_id is provided. " +
"Results are scoped to the current conversation's channel and chat ID; " +
"all tasks are listed only when no channel/chat context is injected " +
"(e.g. direct programmatic calls via Execute)."
}
func (t *SpawnStatusTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"task_id": map[string]any{
"type": "string",
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
"subagent. When omitted, all visible subagents are listed.",
},
},
"required": []string{},
}
}
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
if t.manager == nil {
return ErrorResult("Subagent manager not configured")
}
// Derive the calling conversation's identity so we can scope results to the
// current chat only — preventing cross-conversation task leakage in
// multi-user deployments.
callerChannel := ToolChannel(ctx)
callerChatID := ToolChatID(ctx)
var taskID string
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
taskIDStr, ok := rawTaskID.(string)
if !ok {
return ErrorResult("task_id must be a string")
}
taskID = strings.TrimSpace(taskIDStr)
}
if taskID != "" {
// GetTaskCopy returns a consistent snapshot under the manager lock,
// eliminating any data race with the concurrent subagent goroutine.
taskCopy, ok := t.manager.GetTaskCopy(taskID)
if !ok {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
// Restrict lookup to tasks that belong to this conversation.
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
}
return NewToolResult(spawnStatusFormatTask(&taskCopy))
}
// ListTaskCopies returns consistent snapshots under the manager lock.
origTasks := t.manager.ListTaskCopies()
if len(origTasks) == 0 {
return NewToolResult("No subagents have been spawned yet.")
}
tasks := make([]*SubagentTask, 0, len(origTasks))
for i := range origTasks {
cpy := &origTasks[i]
// Filter to tasks that originate from the current conversation only.
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
continue
}
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
continue
}
tasks = append(tasks, cpy)
}
if len(tasks) == 0 {
return NewToolResult("No subagents found for this conversation.")
}
// Order by creation time (ascending) so spawning order is preserved.
// Fall back to ID string for tasks created in the same millisecond.
sort.Slice(tasks, func(i, j int) bool {
if tasks[i].Created != tasks[j].Created {
return tasks[i].Created < tasks[j].Created
}
return tasks[i].ID < tasks[j].ID
})
counts := map[string]int{}
for _, task := range tasks {
counts[task.Status]++
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
for _, status := range []string{"running", "completed", "failed", "canceled"} {
if n := counts[status]; n > 0 {
label := strings.ToUpper(status[:1]) + status[1:] + ":"
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
}
}
sb.WriteString("\n")
for _, task := range tasks {
sb.WriteString(spawnStatusFormatTask(task))
sb.WriteString("\n\n")
}
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
}
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
func spawnStatusFormatTask(task *SubagentTask) string {
var sb strings.Builder
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
if task.Label != "" {
header += fmt.Sprintf(" label=%q", task.Label)
}
if task.AgentID != "" {
header += fmt.Sprintf(" agent=%s", task.AgentID)
}
if task.Created > 0 {
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
header += fmt.Sprintf(" created=%s", created)
}
sb.WriteString(header)
if task.Task != "" {
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
}
if task.Result != "" {
result := task.Result
const maxResultLen = 300
runes := []rune(result)
if len(runes) > maxResultLen {
result = string(runes[:maxResultLen]) + "…"
}
sb.WriteString(fmt.Sprintf("\n result: %s", result))
}
return sb.String()
}
+406
View File
@@ -0,0 +1,406 @@
package tools
import (
"context"
"fmt"
"strings"
"testing"
"time"
)
func TestSpawnStatusTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
if tool.Name() != "spawn_status" {
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
}
}
func TestSpawnStatusTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(strings.ToLower(desc), "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
func TestSpawnStatusTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
params := tool.Parameters()
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatal("Expected 'properties' to be a map")
}
if _, hasTaskID := props["task_id"]; !hasTaskID {
t.Error("Expected 'task_id' parameter in properties")
}
}
func TestSpawnStatusTool_NilManager(t *testing.T) {
tool := &SpawnStatusTool{manager: nil}
result := tool.Execute(context.Background(), map[string]any{})
if !result.IsError {
t.Error("Expected error result when manager is nil")
}
}
func TestSpawnStatusTool_Empty(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "No subagents") {
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ListAll(t *testing.T) {
provider := &MockLLMProvider{}
workspace := t.TempDir()
manager := NewSubagentManager(provider, "test-model", workspace)
now := time.Now().UnixMilli()
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Do task A",
Label: "task-a",
Status: "running",
Created: now,
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2",
Task: "Do task B",
Label: "task-b",
Status: "completed",
Result: "Done successfully",
Created: now,
}
manager.tasks["subagent-3"] = &SubagentTask{
ID: "subagent-3",
Task: "Do task C",
Status: "failed",
Result: "Error: something went wrong",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
// Summary header
if !strings.Contains(result.ForLLM, "3 total") {
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
}
// Individual task IDs
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
if !strings.Contains(result.ForLLM, id) {
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
}
}
// Status values
for _, status := range []string{"running", "completed", "failed"} {
if !strings.Contains(result.ForLLM, status) {
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
}
}
// Result content
if !strings.Contains(result.ForLLM, "Done successfully") {
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
}
}
func TestSpawnStatusTool_GetByID(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-42"] = &SubagentTask{
ID: "subagent-42",
Task: "Specific task",
Label: "my-task",
Status: "failed",
Result: "Something went wrong",
Created: time.Now().UnixMilli(),
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
if result.IsError {
t.Fatalf("Expected success, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subagent-42") {
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "failed") {
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Something went wrong") {
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "my-task") {
t.Errorf("Expected label in output, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
if !result.IsError {
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "nonexistent-999") {
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnStatusTool(manager)
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
if !result.IsError {
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
}
if !strings.Contains(result.ForLLM, "task_id must be a string") {
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
}
}
}
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
longResult := strings.Repeat("X", 500)
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Long task",
Status: "completed",
Result: longResult,
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
// Output should be shorter than the raw result due to truncation
if len(result.ForLLM) >= len(longResult) {
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
}
if !strings.Contains(result.ForLLM, "…") {
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
cjkChar := string(rune(0x5b57))
longResult := strings.Repeat(cjkChar, 400)
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1",
Task: "Unicode task",
Status: "completed",
Result: longResult,
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "…") {
t.Errorf("Expected truncation indicator in output")
}
// The truncated result must be valid UTF-8 (no split rune boundaries).
if !strings.Contains(result.ForLLM, cjkChar) {
t.Errorf("Expected CJK runes to appear intact in output")
}
}
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
id := fmt.Sprintf("subagent-%d", i+1)
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
// The summary line should mention all statuses that have counts
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
if !strings.Contains(result.ForLLM, want) {
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
}
}
}
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
now := time.Now().UnixMilli()
manager.mu.Lock()
// Intentionally insert with out-of-order IDs and timestamps that reflect
// true spawn order: subagent-2 was spawned first, subagent-10 second.
manager.tasks["subagent-10"] = &SubagentTask{
ID: "subagent-10", Task: "second", Status: "running",
Created: now + 1,
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2", Task: "first", Status: "running",
Created: now,
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
pos2 := strings.Index(result.ForLLM, "subagent-2")
pos10 := strings.Index(result.ForLLM, "subagent-10")
if pos2 < 0 || pos10 < 0 {
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
}
if pos2 > pos10 {
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
}
}
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1", Task: "mine", Status: "running",
OriginChannel: "telegram", OriginChatID: "chat-A",
}
manager.tasks["subagent-2"] = &SubagentTask{
ID: "subagent-2", Task: "other user", Status: "running",
OriginChannel: "telegram", OriginChatID: "chat-B",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
// Caller is chat-A — should only see subagent-1.
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
result := tool.Execute(ctx, map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subagent-1") {
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
}
if strings.Contains(result.ForLLM, "subagent-2") {
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
}
}
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-99"] = &SubagentTask{
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
OriginChannel: "slack", OriginChatID: "room-Z",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
// Different chat trying to look up subagent-99 by ID.
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
if !result.IsError {
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
}
}
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.mu.Lock()
manager.tasks["subagent-1"] = &SubagentTask{
ID: "subagent-1", Task: "t", Status: "completed",
OriginChannel: "telegram", OriginChatID: "chat-A",
}
manager.mu.Unlock()
tool := NewSpawnStatusTool(manager)
// No ToolContext injected (e.g. a direct programmatic call that bypasses
// WithToolContext entirely) — callerChannel and callerChatID are both "".
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
// which *does* inject a non-empty context; this test covers the case where
// no context injection happens at all.
// The filter conditions require a non-empty caller value, so all tasks pass through.
result := tool.Execute(context.Background(), map[string]any{})
if result.IsError {
t.Fatalf("Unexpected error: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subagent-1") {
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
}
}
+25
View File
@@ -255,6 +255,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
return task, ok
}
// GetTaskCopy returns a copy of the task with the given ID, taken under the
// read lock, so the caller receives a consistent snapshot with no data race.
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
sm.mu.RLock()
defer sm.mu.RUnlock()
task, ok := sm.tasks[taskID]
if !ok {
return SubagentTask{}, false
}
return *task, true
}
func (sm *SubagentManager) ListTasks() []*SubagentTask {
sm.mu.RLock()
defer sm.mu.RUnlock()
@@ -266,6 +278,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
return tasks
}
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
// so callers receive consistent snapshots with no data race.
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
sm.mu.RLock()
defer sm.mu.RUnlock()
copies := make([]SubagentTask, 0, len(sm.tasks))
for _, task := range sm.tasks {
copies = append(copies, *task)
}
return copies
}
// SubagentTool executes a subagent task synchronously and returns the result.
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
type SubagentTool struct {
+2 -1
View File
@@ -12,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
@@ -67,7 +68,7 @@ func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
mediaDir := media.TempDir()
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
"error": err.Error(),
+59 -4
View File
@@ -1,5 +1,60 @@
.PHONY: dev dev-frontend dev-backend build test lint clean
# Go variables
GO?=CGO_ENABLED=0 go
WEB_GO?=$(GO)
GOFLAGS?=-v -tags stdjson
# Version
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
GO_VERSION=$(shell $(WEB_GO) version | awk '{print $$3}')
CONFIG_PKG=github.com/sipeed/picoclaw/pkg/config
LDFLAGS=-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COMMIT) -X $(CONFIG_PKG).BuildTime=$(BUILD_TIME) -X $(CONFIG_PKG).GoVersion=$(GO_VERSION) -s -w
# OS detection
UNAME_S:=$(shell uname -s)
UNAME_M:=$(shell uname -m)
# Platform-specific settings
ifeq ($(UNAME_S),Linux)
PLATFORM=linux
ifeq ($(UNAME_M),x86_64)
ARCH=amd64
else ifeq ($(UNAME_M),aarch64)
ARCH=arm64
else ifeq ($(UNAME_M),armv81)
ARCH=arm64
else ifeq ($(UNAME_M),loongarch64)
ARCH=loong64
else ifeq ($(UNAME_M),riscv64)
ARCH=riscv64
else ifeq ($(UNAME_M),mipsel)
ARCH=mipsle
else
ARCH=$(UNAME_M)
endif
else ifeq ($(UNAME_S),Darwin)
PLATFORM=darwin
WEB_GO=CGO_ENABLED=1 go
ifeq ($(UNAME_M),x86_64)
ARCH=amd64
else ifeq ($(UNAME_M),arm64)
ARCH=arm64
else
ARCH=$(UNAME_M)
endif
else ifeq ($(UNAME_S),Windows)
PLATFORM=windows
ARCH=$(UNAME_M)
LDFLAGS=-H=windowsgui $(LDFLAGS)
else
PLATFORM=$(UNAME_S)
ARCH=$(UNAME_M)
endif
# Run both frontend and backend dev servers
dev:
@if [ ! -f backend/picoclaw-web ] || [ ! -d backend/dist ]; then \
@@ -15,21 +70,21 @@ dev-frontend:
# Start backend dev server
dev-backend:
cd backend && go run .
cd backend && ${WEB_GO} run -ldflags "$(LDFLAGS)" .
# Build frontend and embed into Go binary
build:
cd frontend && pnpm build:backend
cd backend && go build -o picoclaw-web .
cd backend && ${WEB_GO} build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o picoclaw-web .
# Run all tests
test:
cd backend && go test ./...
cd backend && ${WEB_GO} test ./...
cd frontend && pnpm lint
# Lint and format
lint:
cd backend && go vet ./...
cd backend && ${WEB_GO} vet ./...
cd frontend && pnpm check
# Clean build artifacts
+22
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -188,6 +189,27 @@ func validateConfig(cfg *config.Config) []string {
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
}
if cfg.Tools.Exec.Enabled {
if cfg.Tools.Exec.EnableDenyPatterns {
errs = append(
errs,
validateRegexPatterns("tools.exec.custom_deny_patterns", cfg.Tools.Exec.CustomDenyPatterns)...)
}
errs = append(
errs,
validateRegexPatterns("tools.exec.custom_allow_patterns", cfg.Tools.Exec.CustomAllowPatterns)...)
}
return errs
}
func validateRegexPatterns(field string, patterns []string) []string {
var errs []string
for index, pattern := range patterns {
if _, err := regexp.Compile(pattern); err != nil {
errs = append(errs, fmt.Sprintf("%s[%d] is not a valid regular expression: %v", field, index, err))
}
}
return errs
}
+79
View File
@@ -86,3 +86,82 @@ func TestHandleUpdateConfig_DoesNotInheritDefaultModelFields(t *testing.T) {
t.Fatalf("model_list[0].api_base = %q, want empty string", got)
}
}
func TestHandlePatchConfig_RejectsInvalidExecRegexPatterns(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
"tools": {
"exec": {
"custom_deny_patterns": ["("]
}
}
}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
if !bytes.Contains(rec.Body.Bytes(), []byte("custom_deny_patterns")) {
t.Fatalf("expected validation error mentioning custom_deny_patterns, body=%s", rec.Body.String())
}
}
func TestHandlePatchConfig_AllowsInvalidExecRegexPatternsWhenExecDisabled(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
"tools": {
"exec": {
"enabled": false,
"custom_deny_patterns": ["("],
"custom_allow_patterns": ["("]
}
}
}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
}
func TestHandlePatchConfig_AllowsInvalidDenyRegexPatternsWhenDenyPatternsDisabled(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
"tools": {
"exec": {
"enabled": true,
"enable_deny_patterns": false,
"custom_deny_patterns": ["("]
}
}
}`))
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
}
-65
View File
@@ -1,65 +0,0 @@
package api
import (
"encoding/json"
"sync"
)
// GatewayEvent represents a state change event for the gateway process.
type GatewayEvent struct {
Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
PID int `json:"pid,omitempty"`
BootDefaultModel string `json:"boot_default_model,omitempty"`
ConfigDefaultModel string `json:"config_default_model,omitempty"`
RestartRequired bool `json:"gateway_restart_required,omitempty"`
}
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
type EventBroadcaster struct {
mu sync.RWMutex
clients map[chan string]struct{}
}
// NewEventBroadcaster creates a new broadcaster.
func NewEventBroadcaster() *EventBroadcaster {
return &EventBroadcaster{
clients: make(map[chan string]struct{}),
}
}
// Subscribe adds a new listener channel and returns it.
// The caller must call Unsubscribe when done.
func (b *EventBroadcaster) Subscribe() chan string {
ch := make(chan string, 8)
b.mu.Lock()
b.clients[ch] = struct{}{}
b.mu.Unlock()
return ch
}
// Unsubscribe removes a listener channel and closes it.
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
b.mu.Lock()
delete(b.clients, ch)
b.mu.Unlock()
close(ch)
}
// Broadcast sends a GatewayEvent to all connected SSE clients.
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
data, err := json.Marshal(event)
if err != nil {
return
}
b.mu.RLock()
defer b.mu.RUnlock()
for ch := range b.clients {
// Non-blocking send; drop event if client is slow
select {
case ch <- string(data):
default:
}
}
}
+250 -216
View File
@@ -3,6 +3,7 @@ package api
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"log"
@@ -18,6 +19,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/web/backend/utils"
)
@@ -29,11 +31,9 @@ var gateway = struct {
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
events *EventBroadcaster
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
events: NewEventBroadcaster(),
}
var (
@@ -48,10 +48,38 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
return client.Get(url)
}
// getGatewayHealth checks the gateway health endpoint and returns the status response
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
port := 18790
if cfg != nil && cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
}
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
return getGatewayHealthByURL(url, timeout)
}
func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
resp, err := gatewayHealthGet(url, timeout)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()
var healthResponse health.StatusResponse
if decErr := json.NewDecoder(resp.Body).Decode(&healthResponse); decErr != nil {
return nil, resp.StatusCode, decErr
}
return &healthResponse, resp.StatusCode, nil
}
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
@@ -62,12 +90,35 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
// TryAutoStartGateway checks whether gateway start preconditions are met and
// starts it when possible. Intended to be called by the backend at startup.
func (h *Handler) TryAutoStartGateway() {
// Check if gateway is already running via health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
gateway.mu.Lock()
defer gateway.mu.Unlock()
ready, reason, err := h.gatewayStartReady()
if err != nil {
log.Printf("Skip auto-starting gateway: %v", err)
return
}
if !ready {
log.Printf("Skip auto-starting gateway: %s", reason)
return
}
_, err = h.startGatewayLocked("starting", pid)
if err != nil {
log.Printf("Failed to attach to running gateway (PID: %d): %v", pid, err)
}
return
}
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if isGatewayProcessAliveLocked() {
return
}
if gateway.cmd != nil && gateway.cmd.Process != nil {
gateway.cmd = nil
}
@@ -82,7 +133,7 @@ func (h *Handler) TryAutoStartGateway() {
return
}
pid, err := h.startGatewayLocked("starting")
pid, err := h.startGatewayLocked("starting", 0)
if err != nil {
log.Printf("Failed to auto-start gateway: %v", err)
return
@@ -125,8 +176,14 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
return modelCfg
}
func isGatewayProcessAliveLocked() bool {
return isCmdProcessAliveLocked(gateway.cmd)
func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
if gatewayStatus != "running" {
return false
}
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
return false
}
return configDefaultModel != bootDefaultModel
}
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
@@ -157,7 +214,29 @@ func setGatewayRuntimeStatusLocked(status string) {
gateway.startupDeadline = time.Time{}
}
func gatewayStatusOnHealthFailureLocked() string {
// attachToGatewayProcess attaches to an existing gateway process by PID
// and updates the gateway state accordingly.
// Assumes gateway.mu is held by the caller.
func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
process, err := os.FindProcess(pid)
if err != nil {
return fmt.Errorf("failed to find process for PID %d: %w", pid, err)
}
gateway.cmd = &exec.Cmd{Process: process}
setGatewayRuntimeStatusLocked("running")
// Update bootDefaultModel from config
if cfg != nil {
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
gateway.bootDefaultModel = defaultModelName
}
log.Printf("Attached to gateway process (PID: %d)", pid)
return nil
}
func gatewayStatusWithoutHealthLocked() string {
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return gateway.runtimeStatus
@@ -170,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string {
if gateway.runtimeStatus == "error" {
return "error"
}
return "error"
}
func currentGatewayStatusLocked(processAlive bool) string {
if !processAlive {
if gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return "restarting"
}
return "error"
}
if gateway.runtimeStatus == "error" {
return "error"
}
return "stopped"
}
return gatewayStatusOnHealthFailureLocked()
return "stopped"
}
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
@@ -238,24 +301,32 @@ func stopGatewayProcessForRestart(cmd *exec.Cmd) error {
return fmt.Errorf("existing gateway did not exit before restart")
}
func gatewayRestartRequired(status, bootDefaultModel, configDefaultModel string) bool {
return status == "running" &&
bootDefaultModel != "" &&
configDefaultModel != "" &&
bootDefaultModel != configDefaultModel
}
func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int, error) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
return 0, fmt.Errorf("failed to load config: %w", err)
}
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
var cmd *exec.Cmd
var pid int
if existingPid > 0 {
// Attach to existing process
pid = existingPid
gateway.cmd = nil // Clear first to ensure clean state
if err = attachToGatewayProcessLocked(pid, cfg); err != nil {
return 0, err
}
return pid, nil
}
// Start new process
// Locate the picoclaw executable
execPath := utils.FindPicoclawBinary()
cmd := exec.Command(execPath, "gateway")
cmd = exec.Command(execPath, "gateway", "-E")
cmd.Env = os.Environ()
// Forward the launcher's config path via the environment variable that
// GetConfigPath() already reads, so the gateway sub-process uses the same
@@ -281,7 +352,7 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
gateway.logs.Reset()
// Ensure Pico Channel is configured before starting gateway
if _, err := h.ensurePicoChannel(); err != nil {
if _, err := h.ensurePicoChannel(""); err != nil {
log.Printf("Warning: failed to ensure pico channel: %v", err)
// Non-fatal: gateway can still start without pico channel
}
@@ -293,18 +364,9 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
gateway.cmd = cmd
gateway.bootDefaultModel = defaultModelName
setGatewayRuntimeStatusLocked(initialStatus)
pid := cmd.Process.Pid
pid = cmd.Process.Pid
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
// Broadcast the launch state immediately so clients can reflect it without polling.
gateway.events.Broadcast(GatewayEvent{
Status: initialStatus,
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
// Capture stdout/stderr in background
go scanPipe(stdoutPipe, gateway.logs)
go scanPipe(stderrPipe, gateway.logs)
@@ -318,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
}
gateway.mu.Lock()
shouldBroadcastStopped := false
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
if gateway.runtimeStatus != "restarting" {
setGatewayRuntimeStatusLocked("stopped")
shouldBroadcastStopped = true
}
}
gateway.mu.Unlock()
if shouldBroadcastStopped {
gateway.events.Broadcast(GatewayEvent{
Status: "stopped",
RestartRequired: false,
})
}
}()
// Start a goroutine to probe health and broadcast "running" once ready
// Start a goroutine to probe health and update the runtime state once ready.
go func() {
for i := 0; i < 30; i++ { // try for up to 15 seconds
time.Sleep(500 * time.Millisecond)
@@ -351,30 +404,15 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
if err != nil {
continue
}
healthHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
healthPort := cfg.Gateway.Port
if healthPort == 0 {
healthPort = 18790
}
healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort)))
resp, err := gatewayHealthGet(healthURL, 1*time.Second)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
gateway.mu.Lock()
if gateway.cmd == cmd {
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
gateway.events.Broadcast(GatewayEvent{
Status: "running",
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
return
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
// Verify the health endpoint returns the expected pid
gateway.mu.Lock()
if gateway.cmd == cmd {
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
return
}
}
}()
@@ -386,19 +424,54 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
//
// POST /api/gateway/start
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
// Prevent duplicate starts by checking health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
gateway.mu.Lock()
ready, reason, err := h.gatewayStartReady()
if err != nil {
gateway.mu.Unlock()
http.Error(
w,
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
http.StatusInternalServerError,
)
return
}
if !ready {
gateway.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
"status": "precondition_failed",
"message": reason,
})
return
}
_, err = h.startGatewayLocked("starting", pid)
gateway.mu.Unlock()
if err != nil {
log.Printf("Failed to attach to running gateway (PID: %d): %v", pid, err)
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"status": "ok",
"pid": pid,
})
return
}
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
// Prevent duplicate starts
if isGatewayProcessAliveLocked() {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusConflict)
json.NewEncoder(w).Encode(map[string]any{
"status": "already_running",
"pid": gateway.cmd.Process.Pid,
})
return
}
if gateway.cmd != nil && gateway.cmd.Process != nil {
gateway.cmd = nil
setGatewayRuntimeStatusLocked("stopped")
@@ -423,7 +496,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
return
}
pid, err := h.startGatewayLocked("starting")
pid, err := h.startGatewayLocked("starting", 0)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
return
@@ -475,36 +548,21 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
})
}
// handleGatewayRestart stops the gateway (if running) and starts a new instance.
//
// POST /api/gateway/restart
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
// RestartGateway restarts the gateway process. This is a non-blocking operation
// that stops the current gateway (if running) and starts a new one.
// Returns the PID of the new gateway process or an error.
func (h *Handler) RestartGateway() (int, error) {
ready, reason, err := h.gatewayStartReady()
if err != nil {
http.Error(
w,
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
http.StatusInternalServerError,
)
return
return 0, fmt.Errorf("failed to validate gateway start conditions: %w", err)
}
if !ready {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
"status": "precondition_failed",
"message": reason,
})
return
return 0, &preconditionFailedError{reason: reason}
}
gateway.mu.Lock()
previousCmd := gateway.cmd
setGatewayRuntimeStatusLocked("restarting")
gateway.events.Broadcast(GatewayEvent{
Status: "restarting",
RestartRequired: false,
})
gateway.mu.Unlock()
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
@@ -519,8 +577,7 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
}
}
gateway.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
return
return 0, fmt.Errorf("failed to stop gateway: %w", err)
}
gateway.mu.Lock()
@@ -528,7 +585,7 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
gateway.cmd = nil
gateway.bootDefaultModel = ""
}
pid, err := h.startGatewayLocked("restarting")
pid, err := h.startGatewayLocked("restarting", 0)
if err != nil {
gateway.cmd = nil
gateway.bootDefaultModel = ""
@@ -536,6 +593,43 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
}
gateway.mu.Unlock()
if err != nil {
return 0, fmt.Errorf("failed to start gateway: %w", err)
}
return pid, nil
}
// preconditionFailedError is returned when gateway restart preconditions are not met
type preconditionFailedError struct {
reason string
}
func (e *preconditionFailedError) Error() string {
return e.reason
}
// IsBadRequest returns true if the error should result in a 400 Bad Request status
func (e *preconditionFailedError) IsBadRequest() bool {
return true
}
// handleGatewayRestart stops the gateway (if running) and starts a new instance.
//
// POST /api/gateway/restart
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
pid, err := h.RestartGateway()
if err != nil {
// Check if it's a precondition failed error
var precondErr *preconditionFailedError
if errors.As(err, &precondErr) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
"status": "precondition_failed",
"message": precondErr.reason,
})
return
}
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
return
}
@@ -572,8 +666,8 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
func (h *Handler) gatewayStatusData() map[string]any {
data := map[string]any{}
cfg, cfgErr := config.LoadConfig(h.configPath)
configDefaultModel := ""
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
if configDefaultModel != "" {
@@ -581,74 +675,59 @@ func (h *Handler) gatewayStatusData() map[string]any {
}
}
// Check process state
gateway.mu.Lock()
processAlive := isGatewayProcessAliveLocked()
bootDefaultModel := ""
if processAlive {
data["pid"] = gateway.cmd.Process.Pid
if gateway.bootDefaultModel != "" {
data["boot_default_model"] = gateway.bootDefaultModel
bootDefaultModel = gateway.bootDefaultModel
}
}
gateway.mu.Unlock()
if !processAlive {
// Probe health endpoint to get pid and status
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err != nil {
gateway.mu.Lock()
data["gateway_status"] = currentGatewayStatusLocked(false)
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
gateway.mu.Unlock()
log.Printf("Gateway health check failed: %v", err)
} else {
// Process is alive — probe its health endpoint
host := "127.0.0.1"
port := 18790
if cfgErr == nil && cfg != nil {
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
if cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
}
}
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port)))
resp, err := gatewayHealthGet(url, 2*time.Second)
if err != nil {
log.Printf("Gateway health status: %d", statusCode)
if statusCode != http.StatusOK {
gateway.mu.Lock()
data["gateway_status"] = currentGatewayStatusLocked(true)
setGatewayRuntimeStatusLocked("error")
gateway.mu.Unlock()
data["gateway_status"] = "error"
data["status_code"] = statusCode
} else {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
gateway.mu.Unlock()
data["gateway_status"] = "error"
data["status_code"] = resp.StatusCode
} else {
var healthData map[string]any
if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
gateway.mu.Unlock()
data["gateway_status"] = "error"
} else {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
for k, v := range healthData {
data[k] = v
}
data["gateway_status"] = "running"
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("running")
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
oldPid := "none"
if gateway.cmd != nil && gateway.cmd.Process != nil {
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
}
log.Printf(
"Detected gateway PID from health (old: %s, new: %d), attempting to attach",
oldPid,
healthResp.Pid,
)
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
log.Printf(
"Failed to attach to gateway process reported by health (PID: %d): %v",
healthResp.Pid,
err,
)
}
}
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = healthResp.Pid
gateway.mu.Unlock()
}
}
status, _ := data["gateway_status"].(string)
bootDefaultModel, _ := data["boot_default_model"].(string)
gatewayStatus, _ := data["gateway_status"].(string)
data["gateway_restart_required"] = gatewayRestartRequired(
status,
bootDefaultModel,
configDefaultModel,
bootDefaultModel,
gatewayStatus,
)
ready, reason, readyErr := h.gatewayStartReady()
@@ -719,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any {
return data
}
// handleGatewayEvents serves an SSE stream of gateway state change events.
//
// GET /api/gateway/events
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "SSE not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
// Subscribe to gateway events
ch := gateway.events.Subscribe()
defer gateway.events.Unsubscribe(ch)
// Send initial status so the client doesn't start blank
initial := h.currentGatewayStatus()
fmt.Fprintf(w, "data: %s\n\n", initial)
flusher.Flush()
for {
select {
case <-r.Context().Done():
return
case data, ok := <-ch:
if !ok {
return
}
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
}
}
// currentGatewayStatus returns the current gateway status as a JSON string.
func (h *Handler) currentGatewayStatus() string {
data := h.gatewayStatusData()
encoded, _ := json.Marshal(data)
return string(encoded)
}
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
func scanPipe(r io.Reader, buf *LogBuffer) {
scanner := bufio.NewScanner(r)
+43 -1
View File
@@ -3,6 +3,7 @@ package api
import (
"net"
"net/http"
"net/url"
"strconv"
"strings"
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
return bindHost
}
func (h *Handler) gatewayProxyURL() *url.URL {
cfg, err := config.LoadConfig(h.configPath)
port := 18790
bindHost := ""
if err == nil && cfg != nil {
if cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
}
bindHost = h.effectiveGatewayBindHost(cfg)
}
return &url.URL{
Scheme: "http",
Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
}
}
func requestHostName(r *http.Request) string {
reqHost, _, err := net.SplitHostPort(r.Host)
if err == nil {
@@ -57,10 +75,34 @@ func requestHostName(r *http.Request) string {
return "127.0.0.1"
}
func requestWSScheme(r *http.Request) string {
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
if proto == "https" || proto == "wss" {
return "wss"
}
if proto == "http" || proto == "ws" {
return "ws"
}
}
if r.TLS != nil {
return "wss"
}
return "ws"
}
func (h *Handler) buildWsURL(r *http.Request, cfg *config.Config) string {
host := h.effectiveGatewayBindHost(cfg)
if host == "" || host == "0.0.0.0" {
host = requestHostName(r)
}
return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
// Use web server port instead of gateway port to avoid exposing extra ports
// The WebSocket connection will be proxied by the backend to the gateway
wsPort := h.serverPort
if wsPort == 0 {
wsPort = 18800 // default web server port
}
return requestWSScheme(r) + "://" + net.JoinHostPort(host, strconv.Itoa(wsPort)) + "/pico/ws"
}
+131 -2
View File
@@ -1,9 +1,13 @@
package api
import (
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
@@ -47,8 +51,8 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
req.Host = "192.168.1.9:18800"
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18790/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18790/pico/ws")
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18800/pico/ws")
}
}
@@ -57,3 +61,128 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1")
}
}
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "192.168.1.10"
cfg.Gateway.Port = 18791
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
}
}
func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "192.168.1.10"
cfg.Gateway.Port = 18791
originalHealthGet := gatewayHealthGet
t.Cleanup(func() {
gatewayHealthGet = originalHealthGet
})
var requestedURL string
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
requestedURL = url
return nil, errors.New("probe failed")
}
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
_ = statusCode
_ = err
if requestedURL != "http://192.168.1.10:18791/health" {
t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
}
}
func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
h.SetServerOptions(18800, true, true, nil)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = 18791
originalHealthGet := gatewayHealthGet
t.Cleanup(func() {
gatewayHealthGet = originalHealthGet
})
var requestedURL string
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
requestedURL = url
return nil, errors.New("probe failed")
}
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
_ = statusCode
_ = err
if requestedURL != "http://127.0.0.1:18791/health" {
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
}
}
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
req.Host = "chat.example.com"
req.Header.Set("X-Forwarded-Proto", "https")
if got := h.buildWsURL(req, cfg); got != "wss://chat.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
}
}
func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
req.Host = "secure.example.com"
req.TLS = &tls.ConnectionState{}
if got := h.buildWsURL(req, cfg); got != "wss://secure.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
}
}
func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "0.0.0.0"
cfg.Gateway.Port = 18790
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
req.Host = "chat.example.com"
req.TLS = &tls.ConnectionState{}
req.Header.Set("X-Forwarded-Proto", "http")
if got := h.buildWsURL(req, cfg); got != "ws://chat.example.com:18800/pico/ws" {
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
}
}
+129 -54
View File
@@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
return cmd
}
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: io.NopCloser(strings.NewReader(
`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
)),
}
}
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
t.Helper()
@@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
}
}
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["pid"]; got != float64(cmd.Process.Pid) {
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
}
if got := body["gateway_restart_required"]; got != false {
t.Fatalf("gateway_restart_required = %#v, want false", got)
}
}
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = "test-key"
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
ModelName: "second-model",
Model: "openai/gpt-4.1",
APIKey: "second-key",
})
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
process, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
updatedCfg.Agents.Defaults.ModelName = "second-model"
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
}
if got := body["config_default_model"]; got != "second-model" {
t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
resetGatewayTestState(t)
@@ -494,60 +623,6 @@ func TestGatewayStatusReturnsRestartingDuringRestartGap(t *testing.T) {
}
}
func TestGatewayStatusIncludesRestartRequiredWhenModelsDiffer(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = "test-key"
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
gateway.cmd = cmd
gateway.bootDefaultModel = "previous-model"
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
_, _ = rec.WriteString(`{"ok":true}`)
return rec.Result(), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}
func TestGatewayRestartKeepsRunningProcessWhenPreconditionsFail(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
+37 -12
View File
@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/http/httputil"
"time"
"github.com/sipeed/picoclaw/pkg/config"
@@ -16,6 +17,30 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
// WebSocket proxy: forward /pico/ws to gateway
// This allows the frontend to connect via the same port as the web UI,
// avoiding the need to expose extra ports for WebSocket communication.
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
}
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
// The gateway bind host and port are resolved from the latest configuration.
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
}
return wsProxy
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
// The reverse proxy forwards the incoming upgrade handshake as-is.
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
proxy := h.createWsProxy()
proxy.ServeHTTP(w, r)
}
}
// handleGetPicoToken returns the current WS token and URL for the frontend.
@@ -65,9 +90,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
})
}
// ensurePicoChannel checks if the Pico Channel is properly configured and
// enables it with sensible defaults if not. Returns true if config was changed.
func (h *Handler) ensurePicoChannel() (bool, error) {
// ensurePicoChannel enables the Pico channel with sane defaults if it isn't
// already configured. Returns true when the config was modified.
//
// callerOrigin is the Origin header from the setup request. If non-empty and
// no origins are configured yet, it's written as the allowed origin so the
// WebSocket handshake works for whatever host the caller is on (LAN, custom
// port, etc.). Pass "" when there's no request context.
func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
return false, fmt.Errorf("failed to load config: %w", err)
@@ -85,14 +115,9 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
changed = true
}
if !cfg.Channels.Pico.AllowTokenQuery {
cfg.Channels.Pico.AllowTokenQuery = true
changed = true
}
// Make sure origins are allowed (frontend might be running on a different port like 5173 during dev)
if len(cfg.Channels.Pico.AllowOrigins) == 0 {
cfg.Channels.Pico.AllowOrigins = []string{"*"}
// Seed origins from the request instead of hardcoding ports.
if len(cfg.Channels.Pico.AllowOrigins) == 0 && callerOrigin != "" {
cfg.Channels.Pico.AllowOrigins = []string{callerOrigin}
changed = true
}
@@ -109,7 +134,7 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
//
// POST /api/pico/setup
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
changed, err := h.ensurePicoChannel()
changed, err := h.ensurePicoChannel(r.Header.Get("Origin"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
+314
View File
@@ -0,0 +1,314 @@
package api
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"strconv"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
changed, err := h.ensurePicoChannel("")
if err != nil {
t.Fatalf("ensurePicoChannel() error = %v", err)
}
if !changed {
t.Fatal("ensurePicoChannel() should report changed on a fresh config")
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if !cfg.Channels.Pico.Enabled {
t.Error("expected Pico to be enabled after setup")
}
if cfg.Channels.Pico.Token == "" {
t.Error("expected a non-empty token after setup")
}
}
func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.ensurePicoChannel(""); err != nil {
t.Fatalf("ensurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if cfg.Channels.Pico.AllowTokenQuery {
t.Error("setup must not enable allow_token_query by default")
}
}
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil {
t.Fatalf("ensurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
for _, origin := range cfg.Channels.Pico.AllowOrigins {
if origin == "*" {
t.Error("setup must not set wildcard origin '*'")
}
}
}
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
if _, err := h.ensurePicoChannel(""); err != nil {
t.Fatalf("ensurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
// Without a caller origin, allow_origins stays empty (CheckOrigin
// allows all when the list is empty, so the channel still works).
if len(cfg.Channels.Pico.AllowOrigins) != 0 {
t.Errorf("allow_origins = %v, want empty when no caller origin", cfg.Channels.Pico.AllowOrigins)
}
}
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
lanOrigin := "http://192.168.1.9:18800"
if _, err := h.ensurePicoChannel(lanOrigin); err != nil {
t.Fatalf("ensurePicoChannel() error = %v", err)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != lanOrigin {
t.Errorf("allow_origins = %v, want [%s]", cfg.Channels.Pico.AllowOrigins, lanOrigin)
}
}
func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
// Pre-configure with custom user settings
cfg := config.DefaultConfig()
cfg.Channels.Pico.Enabled = true
cfg.Channels.Pico.Token = "user-custom-token"
cfg.Channels.Pico.AllowTokenQuery = true
cfg.Channels.Pico.AllowOrigins = []string{"https://myapp.example.com"}
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
changed, err := h.ensurePicoChannel("")
if err != nil {
t.Fatalf("ensurePicoChannel() error = %v", err)
}
if changed {
t.Error("ensurePicoChannel() should not change a fully configured config")
}
cfg, err = config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if cfg.Channels.Pico.Token != "user-custom-token" {
t.Errorf("token = %q, want %q", cfg.Channels.Pico.Token, "user-custom-token")
}
if !cfg.Channels.Pico.AllowTokenQuery {
t.Error("user's allow_token_query=true must be preserved")
}
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "https://myapp.example.com" {
t.Errorf("allow_origins = %v, want [https://myapp.example.com]", cfg.Channels.Pico.AllowOrigins)
}
}
func TestEnsurePicoChannel_Idempotent(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
origin := "http://localhost:18800"
// First call sets things up
if _, err := h.ensurePicoChannel(origin); err != nil {
t.Fatalf("first ensurePicoChannel() error = %v", err)
}
cfg1, _ := config.LoadConfig(configPath)
token1 := cfg1.Channels.Pico.Token
// Second call should be a no-op
changed, err := h.ensurePicoChannel(origin)
if err != nil {
t.Fatalf("second ensurePicoChannel() error = %v", err)
}
if changed {
t.Error("second ensurePicoChannel() should not report changed")
}
cfg2, _ := config.LoadConfig(configPath)
if cfg2.Channels.Pico.Token != token1 {
t.Error("token should not change on subsequent calls")
}
}
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
req.Header.Set("Origin", "http://10.0.0.5:3000")
rec := httptest.NewRecorder()
h.handlePicoSetup(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "http://10.0.0.5:3000" {
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", cfg.Channels.Pico.AllowOrigins)
}
}
func TestHandlePicoSetup_Response(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
rec := httptest.NewRecorder()
h.handlePicoSetup(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if resp["token"] == nil || resp["token"] == "" {
t.Error("response should contain a non-empty token")
}
if resp["ws_url"] == nil || resp["ws_url"] == "" {
t.Error("response should contain ws_url")
}
if resp["enabled"] != true {
t.Error("response should have enabled=true")
}
if resp["changed"] != true {
t.Error("response should have changed=true on first setup")
}
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "server1")
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "server2")
}))
defer server2.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
rec1 := httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
}
if body := rec1.Body.String(); body != "server1" {
t.Fatalf("first body = %q, want %q", body, "server1")
}
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
rec2 := httptest.NewRecorder()
handler(rec2, req2)
if rec2.Code != http.StatusOK {
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
}
if body := rec2.Body.String(); body != "server2" {
t.Fatalf("second body = %q, want %q", body, "server2")
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
port, err := strconv.Atoi(parsed.Port())
if err != nil {
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
}
return port
}
+2
View File
@@ -70,3 +70,5 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
// Launcher service parameters (port/public)
h.registerLauncherConfigRoutes(mux)
}
func (h *Handler) Shutdown() {}
+13 -1
View File
@@ -118,6 +118,12 @@ var toolCatalog = []toolCatalogEntry{
Category: "agents",
ConfigKey: "spawn",
},
{
Name: "spawn_status",
Description: "Query the status of spawned subagents.",
Category: "agents",
ConfigKey: "spawn_status",
},
{
Name: "i2c",
Description: "Interact with I2C hardware devices exposed on the host.",
@@ -205,7 +211,7 @@ func buildToolSupport(cfg *config.Config) []toolSupportItem {
reasonCode = "requires_skills"
}
}
case "spawn":
case "spawn", "spawn_status":
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
if cfg.Tools.IsToolEnabled("subagent") {
status = "enabled"
@@ -300,6 +306,12 @@ func applyToolState(cfg *config.Config, toolName string, enabled bool) error {
if enabled {
cfg.Tools.Subagent.Enabled = true
}
case "spawn_status":
cfg.Tools.SpawnStatus.Enabled = enabled
if enabled {
cfg.Tools.Spawn.Enabled = true
cfg.Tools.Subagent.Enabled = true
}
case "i2c":
cfg.Tools.I2C.Enabled = enabled
case "spi":
+46
View File
@@ -0,0 +1,46 @@
package main
import (
"context"
"fmt"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/web/backend/utils"
)
const (
browserDelay = 500 * time.Millisecond
shutdownTimeout = 15 * time.Second
)
func shutdownApp() {
fmt.Println(T(Exiting))
if apiHandler != nil {
apiHandler.Shutdown()
}
if server != nil {
server.SetKeepAlivesEnabled(false)
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
if err == context.DeadlineExceeded {
logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout)
} else {
logger.Errorf("Server shutdown error: %v", err)
}
} else {
logger.Infof("Server shutdown completed successfully")
}
}
}
func openBrowser() error {
if serverAddr == "" {
return fmt.Errorf("server address not set")
}
return utils.OpenBrowser(serverAddr)
}
+120
View File
@@ -0,0 +1,120 @@
package main
import (
"fmt"
"os"
"strings"
)
// Language represents the supported languages
type Language string
const (
LanguageEnglish Language = "en"
LanguageChinese Language = "zh"
)
// current language (default: English)
var currentLang Language = LanguageEnglish
// TranslationKey represents a translation key used for i18n
type TranslationKey string
const (
AppTooltip TranslationKey = "AppTooltip"
MenuOpen TranslationKey = "MenuOpen"
MenuOpenTooltip TranslationKey = "MenuOpenTooltip"
MenuAbout TranslationKey = "MenuAbout"
MenuAboutTooltip TranslationKey = "MenuAboutTooltip"
MenuVersion TranslationKey = "MenuVersion"
MenuVersionTooltip TranslationKey = "MenuVersionTooltip"
MenuGitHub TranslationKey = "MenuGitHub"
MenuDocs TranslationKey = "MenuDocs"
MenuRestart TranslationKey = "MenuRestart"
MenuRestartTooltip TranslationKey = "MenuRestartTooltip"
MenuQuit TranslationKey = "MenuQuit"
MenuQuitTooltip TranslationKey = "MenuQuitTooltip"
Exiting TranslationKey = "Exiting"
DocUrl TranslationKey = "DocUrl"
)
// Translation tables
// Chinese translations intentionally contain Han script
//
//nolint:gosmopolitan
var translations = map[Language]map[TranslationKey]string{
LanguageEnglish: {
AppTooltip: "%s - Web Console",
MenuOpen: "Open Console",
MenuOpenTooltip: "Open PicoClaw console in browser",
MenuAbout: "About",
MenuAboutTooltip: "About PicoClaw",
MenuVersion: "Version: %s",
MenuVersionTooltip: "Current version number",
MenuGitHub: "GitHub",
MenuDocs: "Documentation",
MenuRestart: "Restart Service",
MenuRestartTooltip: "Restart Gateway service",
MenuQuit: "Quit",
MenuQuitTooltip: "Exit PicoClaw",
Exiting: "Exiting PicoClaw...",
DocUrl: "https://docs.picoclaw.io/docs/",
},
LanguageChinese: {
AppTooltip: "%s - Web Console",
MenuOpen: "打开控制台",
MenuOpenTooltip: "在浏览器中打开 PicoClaw 控制台",
MenuAbout: "关于",
MenuAboutTooltip: "关于 PicoClaw",
MenuVersion: "版本: %s",
MenuVersionTooltip: "当前版本号",
MenuGitHub: "GitHub",
MenuDocs: "文档",
MenuRestart: "重启服务",
MenuRestartTooltip: "重启核心服务",
MenuQuit: "退出",
MenuQuitTooltip: "退出 PicoClaw",
Exiting: "正在退出 PicoClaw...",
DocUrl: "https://docs.picoclaw.io/zh-Hans/docs/",
},
}
// SetLanguage sets the current language
func SetLanguage(lang string) {
lang = strings.ToLower(strings.TrimSpace(lang))
// Extract language code before first underscore or dot
// e.g., "en_US.UTF-8" -> "en", "zh_CN" -> "zh"
if idx := strings.IndexAny(lang, "_."); idx > 0 {
lang = lang[:idx]
}
if lang == "zh" || lang == "zh-cn" || lang == "chinese" {
currentLang = LanguageChinese
} else {
currentLang = LanguageEnglish
}
}
// GetLanguage returns the current language
func GetLanguage() Language {
return currentLang
}
// T translates a key to the current language
func T(key TranslationKey, args ...any) string {
if trans, ok := translations[currentLang][key]; ok {
if len(args) > 0 {
return fmt.Sprintf(trans, args...)
}
return trans
}
return string(key)
}
// Initialize i18n from environment variable
func init() {
if lang := os.Getenv("LANG"); lang != "" {
SetLanguage(lang)
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

+37 -16
View File
@@ -22,16 +22,32 @@ import (
"strconv"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/web/backend/api"
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
"github.com/sipeed/picoclaw/web/backend/middleware"
"github.com/sipeed/picoclaw/web/backend/utils"
)
const (
appName = "PicoClaw"
)
var (
appVersion = config.Version
server *http.Server
serverAddr string
apiHandler *api.Handler
noBrowser *bool
)
func main() {
port := flag.String("port", "18800", "Port to listen on")
public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only")
noBrowser := flag.Bool("no-browser", false, "Do not auto-open browser on startup")
noBrowser = flag.Bool("no-browser", false, "Do not auto-open browser on startup")
lang := flag.String("lang", "", "Language: en (English) or zh (Chinese). Default: auto-detect from system locale")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "PicoClaw Launcher - A web-based configuration editor\n\n")
@@ -51,6 +67,11 @@ func main() {
}
flag.Parse()
// Set language from command line or auto-detect
if *lang != "" {
SetLanguage(*lang)
}
// Resolve config path
configPath := utils.GetDefaultConfigPath()
if flag.NArg() > 0 {
@@ -113,7 +134,7 @@ func main() {
mux := http.NewServeMux()
// API Routes (e.g. /api/status)
apiHandler := api.NewHandler(absPath)
apiHandler = api.NewHandler(absPath)
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
apiHandler.RegisterRoutes(mux)
@@ -145,16 +166,10 @@ func main() {
}
fmt.Println()
// Auto-open browser
if !*noBrowser {
go func() {
time.Sleep(500 * time.Millisecond)
url := "http://localhost:" + effectivePort
if err := utils.OpenBrowser(url); err != nil {
log.Printf("Warning: Failed to auto-open browser: %v", err)
}
}()
}
// Share the local URL with the launcher runtime.
serverAddr = fmt.Sprintf("http://localhost:%s", effectivePort)
// Auto-open browser will be handled by the launcher runtime.
// Auto-start gateway after backend starts listening.
go func() {
@@ -162,8 +177,14 @@ func main() {
apiHandler.TryAutoStartGateway()
}()
// Start the Server
if err := http.ListenAndServe(addr, handler); err != nil {
log.Fatalf("Server failed to start: %v", err)
}
// Start the Server in a goroutine
server = &http.Server{Addr: addr, Handler: handler}
go func() {
log.Printf("Server listening on %s", addr)
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed to start: %v", err)
}
}()
runTray()
}
+1 -4
View File
@@ -4,16 +4,14 @@ import (
"log"
"net/http"
"runtime/debug"
"strings"
"time"
)
// JSONContentType sets the Content-Type header to application/json for
// API requests handled by the wrapped handler.
// SSE endpoints (text/event-stream) are excluded.
func JSONContentType(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
w.Header().Set("Content-Type", "application/json")
}
next.ServeHTTP(w, r)
@@ -32,7 +30,6 @@ func (rr *responseRecorder) WriteHeader(code int) {
}
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
// This is required for SSE (Server-Sent Events) to work through the middleware.
func (rr *responseRecorder) Flush() {
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
f.Flush()
+95
View File
@@ -0,0 +1,95 @@
//go:build (!darwin && !freebsd) || cgo
package main
import (
_ "embed"
"fmt"
"fyne.io/systray"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/web/backend/utils"
)
func runTray() {
systray.Run(onReady, shutdownApp)
}
// onReady is called when the system tray is ready
func onReady() {
// Set icon and tooltip
systray.SetIcon(getIcon())
systray.SetTooltip(fmt.Sprintf(T(AppTooltip), appName))
// Create menu items
mOpen := systray.AddMenuItem(T(MenuOpen), T(MenuOpenTooltip))
mAbout := systray.AddMenuItem(T(MenuAbout), T(MenuAboutTooltip))
// Add version info under About menu
mVersion := mAbout.AddSubMenuItem(fmt.Sprintf(T(MenuVersion), appVersion), T(MenuVersionTooltip))
mVersion.Disable()
mRepo := mAbout.AddSubMenuItem(T(MenuGitHub), "")
mDocs := mAbout.AddSubMenuItem(T(MenuDocs), "")
systray.AddSeparator()
// Add restart option
mRestart := systray.AddMenuItem(T(MenuRestart), T(MenuRestartTooltip))
systray.AddSeparator()
// Quit option
mQuit := systray.AddMenuItem(T(MenuQuit), T(MenuQuitTooltip))
// Handle menu clicks
go func() {
for {
select {
case <-mOpen.ClickedCh:
if err := openBrowser(); err != nil {
logger.Errorf("Failed to open browser: %v", err)
}
case <-mVersion.ClickedCh:
// Version info - do nothing, just shows current version
case <-mRepo.ClickedCh:
if err := utils.OpenBrowser("https://github.com/sipeed/picoclaw"); err != nil {
logger.Errorf("Failed to open GitHub: %v", err)
}
case <-mDocs.ClickedCh:
if err := utils.OpenBrowser(T(DocUrl)); err != nil {
logger.Errorf("Failed to open docs: %v", err)
}
case <-mRestart.ClickedCh:
fmt.Println("Restart request received...")
if apiHandler != nil {
if pid, err := apiHandler.RestartGateway(); err != nil {
logger.Errorf("Failed to restart gateway: %v", err)
} else {
logger.Infof("Gateway restarted (PID: %d)", pid)
}
}
case <-mQuit.ClickedCh:
systray.Quit()
}
}
}()
if !*noBrowser {
// Auto-open browser after systray is ready (if not disabled)
// Check no-browser flag via environment or pass as parameter if needed
if err := openBrowser(); err != nil {
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
}
}
}
// getIcon returns the system tray icon
func getIcon() []byte {
return iconData
}
+8
View File
@@ -0,0 +1,8 @@
//go:build !windows
package main
import _ "embed"
//go:embed icon.png
var iconData []byte
+8
View File
@@ -0,0 +1,8 @@
//go:build windows
package main
import _ "embed"
//go:embed icon.ico
var iconData []byte
+33
View File
@@ -0,0 +1,33 @@
//go:build (darwin || freebsd) && !cgo
package main
import (
"context"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
func runTray() {
logger.Infof("System tray is unavailable in %s builds without cgo; running without tray", runtime.GOOS)
if !*noBrowser {
go func() {
time.Sleep(browserDelay)
if err := openBrowser(); err != nil {
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
}
}()
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
<-ctx.Done()
shutdownApp()
}
+8 -8
View File
@@ -6,7 +6,7 @@
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build",
"build:backend": "tsc -b && vite build --outDir ../backend/dist --emptyOutDir",
"build:backend": "tsc -b && vite build --outDir ../backend/dist --emptyOutDir && node ./scripts/ensure-backend-gitkeep.cjs",
"lint": "eslint .",
"preview": "vite preview",
"format": "prettier --check .",
@@ -17,18 +17,18 @@
"@tabler/icons-react": "^3.38.0",
"@tailwindcss/vite": "^4.2.1",
"@tanstack/react-query": "^5.90.21",
"@tanstack/react-router": "^1.163.3",
"@tanstack/react-router": "^1.167.0",
"@tanstack/react-router-devtools": "^1.163.3",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"dayjs": "^1.11.19",
"dayjs": "^1.11.20",
"i18next": "^25.8.14",
"i18next-browser-languagedetector": "^8.2.1",
"jotai": "^2.18.0",
"jotai": "^2.18.1",
"radix-ui": "^1.4.3",
"react": "^19.2.0",
"react-dom": "^19.2.0",
"react-i18next": "^16.5.4",
"react-i18next": "^16.5.8",
"react-markdown": "^10.1.0",
"react-textarea-autosize": "^8.5.9",
"remark-gfm": "^4.0.1",
@@ -40,7 +40,7 @@
"wrap-ansi": "^10.0.0"
},
"devDependencies": {
"@eslint/js": "^9.39.1",
"@eslint/js": "^9.39.3",
"@tailwindcss/typography": "^0.5.19",
"@tanstack/router-plugin": "^1.164.0",
"@trivago/prettier-plugin-sort-imports": "^6.0.2",
@@ -48,8 +48,8 @@
"@types/react": "^19.2.7",
"@types/react-dom": "^19.2.3",
"@typescript-eslint/eslint-plugin": "^8.56.1",
"@vitejs/plugin-react": "^5.1.1",
"eslint": "^9.39.1",
"@vitejs/plugin-react": "^5.2.0",
"eslint": "^9.39.3",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-react-hooks": "^7.0.1",
"eslint-plugin-react-refresh": "^0.4.24",
+95 -60
View File
@@ -21,11 +21,11 @@ importers:
specifier: ^5.90.21
version: 5.90.21(react@19.2.4)
'@tanstack/react-router':
specifier: ^1.163.3
version: 1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
specifier: ^1.167.0
version: 1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
'@tanstack/react-router-devtools':
specifier: ^1.163.3
version: 1.163.3(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.163.3)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
version: 1.163.3(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.167.0)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
class-variance-authority:
specifier: ^0.7.1
version: 0.7.1
@@ -33,8 +33,8 @@ importers:
specifier: ^2.1.1
version: 2.1.1
dayjs:
specifier: ^1.11.19
version: 1.11.19
specifier: ^1.11.20
version: 1.11.20
i18next:
specifier: ^25.8.14
version: 25.8.14(typescript@5.9.3)
@@ -42,8 +42,8 @@ importers:
specifier: ^8.2.1
version: 8.2.1
jotai:
specifier: ^2.18.0
version: 2.18.0(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4)
specifier: ^2.18.1
version: 2.18.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4)
radix-ui:
specifier: ^1.4.3
version: 1.4.3(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
@@ -54,8 +54,8 @@ importers:
specifier: ^19.2.0
version: 19.2.4(react@19.2.4)
react-i18next:
specifier: ^16.5.4
version: 16.5.4(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)
specifier: ^16.5.8
version: 16.5.8(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)
react-markdown:
specifier: ^10.1.0
version: 10.1.0(@types/react@19.2.14)(react@19.2.4)
@@ -85,14 +85,14 @@ importers:
version: 10.0.0
devDependencies:
'@eslint/js':
specifier: ^9.39.1
specifier: ^9.39.3
version: 9.39.3
'@tailwindcss/typography':
specifier: ^0.5.19
version: 0.5.19(tailwindcss@4.2.1)
'@tanstack/router-plugin':
specifier: ^1.164.0
version: 1.164.0(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
version: 1.164.0(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
'@trivago/prettier-plugin-sort-imports':
specifier: ^6.0.2
version: 6.0.2(prettier@3.8.1)
@@ -109,10 +109,10 @@ importers:
specifier: ^8.56.1
version: 8.56.1(@typescript-eslint/parser@8.56.1(eslint@9.39.3(jiti@2.6.1))(typescript@5.9.3))(eslint@9.39.3(jiti@2.6.1))(typescript@5.9.3)
'@vitejs/plugin-react':
specifier: ^5.1.1
version: 5.1.4(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
specifier: ^5.2.0
version: 5.2.0(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
eslint:
specifier: ^9.39.1
specifier: ^9.39.3
version: 9.39.3(jiti@2.6.1)
eslint-config-prettier:
specifier: ^10.1.8
@@ -469,8 +469,8 @@ packages:
resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==}
engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0}
'@eslint/config-array@0.21.1':
resolution: {integrity: sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==}
'@eslint/config-array@0.21.2':
resolution: {integrity: sha512-nJl2KGTlrf9GjLimgIru+V/mzgSK0ABCDQRvxw5BjURL7WfH5uoWmizbH7QB6MmnMBd8cIC9uceWnezL1VZWWw==}
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
'@eslint/config-helpers@0.4.2':
@@ -481,8 +481,8 @@ packages:
resolution: {integrity: sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==}
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
'@eslint/eslintrc@3.3.4':
resolution: {integrity: sha512-4h4MVF8pmBsncB60r0wSJiIeUKTSD4m7FmTFThG8RHlsg9ajqckLm9OraguFGZE4vVdpiI1Q4+hFnisopmG6gQ==}
'@eslint/eslintrc@3.3.5':
resolution: {integrity: sha512-4IlJx0X0qftVsN5E+/vGujTRIFtwuLbNsVUe7TO6zYPDR1O6nFwvwhIKEKSrl6dZchmYBITazxKoUYOjdtjlRg==}
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
'@eslint/js@9.39.3':
@@ -1587,15 +1587,15 @@ packages:
'@tanstack/router-core':
optional: true
'@tanstack/react-router@1.163.3':
resolution: {integrity: sha512-hheBbFVb+PbxtrWp8iy6+TTRTbhx3Pn6hKo8Tv/sWlG89ZMcD1xpQWzx8ukHN9K8YWbh5rdzt4kv6u8X4kB28Q==}
'@tanstack/react-router@1.167.0':
resolution: {integrity: sha512-U7CamtXjuC8ixg1c32Rj/4A2OFBnjtMLdbgbyOGHrFHE7ULWS/yhnZLVXff0QSyn6qF92Oecek9mDMHCaTnB2Q==}
engines: {node: '>=20.19'}
peerDependencies:
react: '>=18.0.0 || >=19.0.0'
react-dom: '>=18.0.0 || >=19.0.0'
'@tanstack/react-store@0.9.1':
resolution: {integrity: sha512-YzJLnRvy5lIEFTLWBAZmcOjK3+2AepnBv/sr6NZmiqJvq7zTQggyK99Gw8fqYdMdHPQWXjz0epFKJXC+9V2xDA==}
'@tanstack/react-store@0.9.2':
resolution: {integrity: sha512-Vt5usJE5sHG/cMechQfmwvwne6ktGCELe89Lmvoxe3LKRoFrhPa8OCKWs0NliG8HTJElEIj7PLtaBQIcux5pAQ==}
peerDependencies:
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
@@ -1604,6 +1604,10 @@ packages:
resolution: {integrity: sha512-jPptiGq/w3nuPzcMC7RNa79aU+b6OjaDzWJnBcV2UAwL4ThJamRS4h42TdhJE+oF5yH9IEnCOGQdfnbw45LbfA==}
engines: {node: '>=20.19'}
'@tanstack/router-core@1.167.0':
resolution: {integrity: sha512-pnaaUP+vMQEyL2XjZGe2PXmtzulxvXfGyvEMUs+AEBaNEk77xWA88bl3ujiBRbUxzpK0rxfJf+eSKPdZmBMFdQ==}
engines: {node: '>=20.19'}
'@tanstack/router-devtools-core@1.163.3':
resolution: {integrity: sha512-FPi64IP0PT1IkoeyGmsD6JoOVOYAb85VCH0mUbSdD90yV0+1UB6oT+D7K27GXkp7SXMJN3mBEjU5rKnNnmSCIw==}
engines: {node: '>=20.19'}
@@ -1646,6 +1650,9 @@ packages:
'@tanstack/store@0.9.1':
resolution: {integrity: sha512-+qcNkOy0N1qSGsP7omVCW0SDrXtaDcycPqBDE726yryiA5eTDFpjBReaYjghVJwNf1pcPMyzIwTGlYjCSQR0Fg==}
'@tanstack/store@0.9.2':
resolution: {integrity: sha512-K013lUJEFJK2ofFQ/hZKJUmCnpcV00ebLyOyFOWQvyQHUOZp/iYO84BM6aOGiV81JzwbX0APTVmW8YI7yiG5oA==}
'@tanstack/virtual-file-routes@1.161.4':
resolution: {integrity: sha512-42WoRePf8v690qG8yGRe/YOh+oHni9vUaUUfoqlS91U2scd3a5rkLtVsc6b7z60w3RogH0I00vdrC5AaeiZ18w==}
engines: {node: '>=20.19'}
@@ -1790,11 +1797,11 @@ packages:
'@ungap/structured-clone@1.3.0':
resolution: {integrity: sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==}
'@vitejs/plugin-react@5.1.4':
resolution: {integrity: sha512-VIcFLdRi/VYRU8OL/puL7QXMYafHmqOnwTZY50U1JPlCNj30PxCMx65c494b1K9be9hX83KVt0+gTEwTWLqToA==}
'@vitejs/plugin-react@5.2.0':
resolution: {integrity: sha512-YmKkfhOAi3wsB1PhJq5Scj3GXMn3WvtQ/JC0xoopuHoXSdmtdStOpFrYaT1kie2YgFBcIe64ROzMYRjCrYOdYw==}
engines: {node: ^20.19.0 || >=22.12.0}
peerDependencies:
vite: ^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0
vite: ^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0
accepts@2.0.0:
resolution: {integrity: sha512-5cvg6CtKwfgdmVqY1WIiXKc3Q1bkRqGLi+2W/6ao+6Y7gu/RCwRuAhGEzh5B4KlszSuTLgZYuqFqo5bImjNKng==}
@@ -2060,8 +2067,8 @@ packages:
resolution: {integrity: sha512-0R9ikRb668HB7QDxT1vkpuUBtqc53YyAwMwGeUFKRojY/NWKvdZ+9UYtRfGmhqNbRkTSVpMbmyhXipFFv2cb/A==}
engines: {node: '>= 12'}
dayjs@1.11.19:
resolution: {integrity: sha512-t5EcLVS6QPBNqM2z8fakk/NKel+Xzshgt8FFKAn+qwlD1pzZWxh0nVCrvFK7ZDb6XucZeF9z8C7CBWTRIVApAw==}
dayjs@1.11.20:
resolution: {integrity: sha512-YbwwqR/uYpeoP4pu043q+LTDLFBLApUP6VxRihdfNTqu4ubqMlGDLd6ErXhEgsyvY0K6nCs7nggYumAN+9uEuQ==}
debug@4.4.3:
resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==}
@@ -2355,8 +2362,8 @@ packages:
resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==}
engines: {node: '>=16'}
flatted@3.3.3:
resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==}
flatted@3.4.1:
resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==}
formdata-polyfill@4.0.10:
resolution: {integrity: sha512-buewHzMvYL29jdeQTVILecSaZKnt/RJWjoZCF5OW60Z67/GmSLBkOFM7qh1PI3zFNtJbaZL5eQu1vLfazOwj4g==}
@@ -2648,8 +2655,8 @@ packages:
resolution: {integrity: sha512-e6rvdUCiQCAuumZslxRJWR/Doq4VpPR82kqclvcS0efgt430SlGIk05vdCN58+VrzgtIcfNODjozVielycD4Sw==}
engines: {node: '>=16'}
isbot@5.1.35:
resolution: {integrity: sha512-waFfC72ZNfwLLuJ2iLaoVaqcNo+CAaLR7xCpAn0Y5WfGzkNHv7ZN39Vbi1y+kb+Zs46XHOX3tZNExroFUPX+Kg==}
isbot@5.1.36:
resolution: {integrity: sha512-C/ZtXyJqDPZ7G7JPr06ApWyYoHjYexQbS6hPYD4WYCzpv2Qes6Z+CCEfTX4Owzf+1EJ933PoI2p+B9v7wpGZBQ==}
engines: {node: '>=18'}
isexe@2.0.0:
@@ -2669,8 +2676,8 @@ packages:
jose@6.1.3:
resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==}
jotai@2.18.0:
resolution: {integrity: sha512-XI38kGWAvtxAZ+cwHcTgJsd+kJOJGf3OfL4XYaXWZMZ7IIY8e53abpIHvtVn1eAgJ5dlgwlGFnP4psrZ/vZbtA==}
jotai@2.18.1:
resolution: {integrity: sha512-e0NOzK+yRFwHo7DOp0DS0Ycq74KMEAObDWFGmfEL28PD9nLqBTt3/Ug7jf9ca72x0gC9LQZG9zH+0ISICmy3iA==}
engines: {node: '>=12.20.0'}
peerDependencies:
'@babel/core': '>=7.0.0'
@@ -3323,8 +3330,8 @@ packages:
peerDependencies:
react: ^19.2.4
react-i18next@16.5.4:
resolution: {integrity: sha512-6yj+dcfMncEC21QPhOTsW8mOSO+pzFmT6uvU7XXdvM/Cp38zJkmTeMeKmTrmCMD5ToT79FmiE/mRWiYWcJYW4g==}
react-i18next@16.5.8:
resolution: {integrity: sha512-2ABeHHlakxVY+LSirD+OiERxFL6+zip0PaHo979bgwzeHg27Sqc82xxXWIrSFmfWX0ZkrvXMHwhsi/NGUf5VQg==}
peerDependencies:
i18next: '>= 25.6.2'
react: '>= 16.8.0'
@@ -3476,10 +3483,20 @@ packages:
peerDependencies:
seroval: ^1.0
seroval-plugins@1.5.1:
resolution: {integrity: sha512-4FbuZ/TMl02sqv0RTFexu0SP6V+ywaIe5bAWCCEik0fk17BhALgwvUDVF7e3Uvf9pxmwCEJsRPmlkUE6HdzLAw==}
engines: {node: '>=10'}
peerDependencies:
seroval: ^1.0
seroval@1.5.0:
resolution: {integrity: sha512-OE4cvmJ1uSPrKorFIH9/w/Qwuvi/IMcGbv5RKgcJ/zjA/IohDLU6SVaxFN9FwajbP7nsX0dQqMDes1whk3y+yw==}
engines: {node: '>=10'}
seroval@1.5.1:
resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==}
engines: {node: '>=10'}
serve-static@2.2.1:
resolution: {integrity: sha512-xRXBn0pPqQTVQiC8wyQrKs2MOlX24zQ0POGaj0kultvoOCstBQM5yvOhAVSUwOMjQtTvsPWoNCHfPGwaaQJhTw==}
engines: {node: '>= 18'}
@@ -4268,7 +4285,7 @@ snapshots:
'@eslint-community/regexpp@4.12.2': {}
'@eslint/config-array@0.21.1':
'@eslint/config-array@0.21.2':
dependencies:
'@eslint/object-schema': 2.1.7
debug: 4.4.3
@@ -4284,7 +4301,7 @@ snapshots:
dependencies:
'@types/json-schema': 7.0.15
'@eslint/eslintrc@3.3.4':
'@eslint/eslintrc@3.3.5':
dependencies:
ajv: 6.14.0
debug: 4.4.3
@@ -5365,31 +5382,31 @@ snapshots:
'@tanstack/query-core': 5.90.20
react: 19.2.4
'@tanstack/react-router-devtools@1.163.3(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.163.3)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
'@tanstack/react-router-devtools@1.163.3(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.167.0)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
dependencies:
'@tanstack/react-router': 1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
'@tanstack/router-devtools-core': 1.163.3(@tanstack/router-core@1.163.3)(csstype@3.2.3)
'@tanstack/react-router': 1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
'@tanstack/router-devtools-core': 1.163.3(@tanstack/router-core@1.167.0)(csstype@3.2.3)
react: 19.2.4
react-dom: 19.2.4(react@19.2.4)
optionalDependencies:
'@tanstack/router-core': 1.163.3
'@tanstack/router-core': 1.167.0
transitivePeerDependencies:
- csstype
'@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
'@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
dependencies:
'@tanstack/history': 1.161.4
'@tanstack/react-store': 0.9.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
'@tanstack/router-core': 1.163.3
isbot: 5.1.35
'@tanstack/react-store': 0.9.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
'@tanstack/router-core': 1.167.0
isbot: 5.1.36
react: 19.2.4
react-dom: 19.2.4(react@19.2.4)
tiny-invariant: 1.3.3
tiny-warning: 1.0.3
'@tanstack/react-store@0.9.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
'@tanstack/react-store@0.9.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
dependencies:
'@tanstack/store': 0.9.1
'@tanstack/store': 0.9.2
react: 19.2.4
react-dom: 19.2.4(react@19.2.4)
use-sync-external-store: 1.6.0(react@19.2.4)
@@ -5404,9 +5421,19 @@ snapshots:
tiny-invariant: 1.3.3
tiny-warning: 1.0.3
'@tanstack/router-devtools-core@1.163.3(@tanstack/router-core@1.163.3)(csstype@3.2.3)':
'@tanstack/router-core@1.167.0':
dependencies:
'@tanstack/router-core': 1.163.3
'@tanstack/history': 1.161.4
'@tanstack/store': 0.9.2
cookie-es: 2.0.0
seroval: 1.5.1
seroval-plugins: 1.5.1(seroval@1.5.1)
tiny-invariant: 1.3.3
tiny-warning: 1.0.3
'@tanstack/router-devtools-core@1.163.3(@tanstack/router-core@1.167.0)(csstype@3.2.3)':
dependencies:
'@tanstack/router-core': 1.167.0
clsx: 2.1.1
goober: 2.1.18(csstype@3.2.3)
tiny-invariant: 1.3.3
@@ -5426,7 +5453,7 @@ snapshots:
transitivePeerDependencies:
- supports-color
'@tanstack/router-plugin@1.164.0(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
'@tanstack/router-plugin@1.164.0(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
dependencies:
'@babel/core': 7.29.0
'@babel/plugin-syntax-jsx': 7.28.6(@babel/core@7.29.0)
@@ -5442,7 +5469,7 @@ snapshots:
unplugin: 2.3.11
zod: 3.25.76
optionalDependencies:
'@tanstack/react-router': 1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
'@tanstack/react-router': 1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
vite: 7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0)
transitivePeerDependencies:
- supports-color
@@ -5463,6 +5490,8 @@ snapshots:
'@tanstack/store@0.9.1': {}
'@tanstack/store@0.9.2': {}
'@tanstack/virtual-file-routes@1.161.4': {}
'@trivago/prettier-plugin-sort-imports@6.0.2(prettier@3.8.1)':
@@ -5641,7 +5670,7 @@ snapshots:
'@ungap/structured-clone@1.3.0': {}
'@vitejs/plugin-react@5.1.4(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
'@vitejs/plugin-react@5.2.0(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
dependencies:
'@babel/core': 7.29.0
'@babel/plugin-transform-react-jsx-self': 7.27.1(@babel/core@7.29.0)
@@ -5894,7 +5923,7 @@ snapshots:
data-uri-to-buffer@4.0.1: {}
dayjs@1.11.19: {}
dayjs@1.11.20: {}
debug@4.4.3:
dependencies:
@@ -6048,10 +6077,10 @@ snapshots:
dependencies:
'@eslint-community/eslint-utils': 4.9.1(eslint@9.39.3(jiti@2.6.1))
'@eslint-community/regexpp': 4.12.2
'@eslint/config-array': 0.21.1
'@eslint/config-array': 0.21.2
'@eslint/config-helpers': 0.4.2
'@eslint/core': 0.17.0
'@eslint/eslintrc': 3.3.4
'@eslint/eslintrc': 3.3.5
'@eslint/js': 9.39.3
'@eslint/plugin-kit': 0.4.1
'@humanfs/node': 0.16.7
@@ -6241,10 +6270,10 @@ snapshots:
flat-cache@4.0.1:
dependencies:
flatted: 3.3.3
flatted: 3.4.1
keyv: 4.5.4
flatted@3.3.3: {}
flatted@3.4.1: {}
formdata-polyfill@4.0.10:
dependencies:
@@ -6489,7 +6518,7 @@ snapshots:
dependencies:
is-inside-container: 1.0.0
isbot@5.1.35: {}
isbot@5.1.36: {}
isexe@2.0.0: {}
@@ -6501,7 +6530,7 @@ snapshots:
jose@6.1.3: {}
jotai@2.18.0(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4):
jotai@2.18.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4):
optionalDependencies:
'@babel/core': 7.29.0
'@babel/template': 7.28.6
@@ -7310,7 +7339,7 @@ snapshots:
react: 19.2.4
scheduler: 0.27.0
react-i18next@16.5.4(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3):
react-i18next@16.5.8(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3):
dependencies:
'@babel/runtime': 7.28.6
html-parse-stringify: 3.0.1
@@ -7517,8 +7546,14 @@ snapshots:
dependencies:
seroval: 1.5.0
seroval-plugins@1.5.1(seroval@1.5.1):
dependencies:
seroval: 1.5.1
seroval@1.5.0: {}
seroval@1.5.1: {}
serve-static@2.2.1:
dependencies:
encodeurl: 2.0.0
@@ -0,0 +1,9 @@
const fs = require("node:fs")
const path = require("node:path")
const gitkeepPath = path.resolve(__dirname, "../../backend/dist/.gitkeep")
const gitkeepContents =
"# Keep the embedded web backend dist directory in version control.\n"
fs.mkdirSync(path.dirname(gitkeepPath), { recursive: true })
fs.writeFileSync(gitkeepPath, gitkeepContents)
+23 -11
View File
@@ -56,14 +56,20 @@ export function AppHeader() {
const isRunning = gwState === "running"
const isStarting = gwState === "starting"
const isRestarting = gwState === "restarting"
const isStopping = gwState === "stopping"
const isStopped = gwState === "stopped" || gwState === "unknown"
const showNotConnectedHint =
!isRestarting && canStart && (gwState === "stopped" || gwState === "error")
!isRestarting &&
!isStopping &&
canStart &&
(gwState === "stopped" || gwState === "error")
const [showStopDialog, setShowStopDialog] = React.useState(false)
const handleGatewayToggle = () => {
if (gwLoading || isRestarting || (!isRunning && !canStart)) return
if (gwLoading || isRestarting || isStopping || (!isRunning && !canStart)) {
return
}
if (isRunning) {
setShowStopDialog(true)
} else {
@@ -137,7 +143,7 @@ export function AppHeader() {
size="icon-sm"
className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25"
onClick={handleGatewayRestart}
disabled={gwLoading || isRestarting || !canStart}
disabled={gwLoading || isRestarting || isStopping || !canStart}
aria-label={t("header.gateway.action.restart")}
>
<IconRefresh className="size-4" />
@@ -168,25 +174,31 @@ export function AppHeader() {
</Tooltip>
) : (
<Button
variant={isStarting || isRestarting ? "secondary" : "default"}
variant={
isStarting || isRestarting || isStopping ? "secondary" : "default"
}
size="sm"
className={`h-8 gap-2 px-3 ${
isStopped ? "bg-green-500 text-white hover:bg-green-600" : ""
}`}
onClick={handleGatewayToggle}
disabled={gwLoading || isStarting || isRestarting || !canStart}
disabled={
gwLoading || isStarting || isRestarting || isStopping || !canStart
}
>
{gwLoading || isStarting || isRestarting ? (
{gwLoading || isStarting || isRestarting || isStopping ? (
<IconLoader2 className="h-4 w-4 animate-spin opacity-70" />
) : (
<IconPlayerPlay className="h-4 w-4 opacity-80" />
)}
<span className="text-xs font-semibold">
{isRestarting
? t("header.gateway.status.restarting")
: isStarting
? t("header.gateway.status.starting")
: t("header.gateway.action.start")}
{isStopping
? t("header.gateway.status.stopping")
: isRestarting
? t("header.gateway.status.restarting")
: isStarting
? t("header.gateway.status.starting")
: t("header.gateway.action.start")}
</span>
</Button>
)}
@@ -42,7 +42,7 @@ export function ChatComposer({
placeholder={t("chat.placeholder")}
disabled={!canInput}
className={cn(
"max-h-[200px] min-h-[60px] resize-none border-0 bg-transparent px-2 py-1 text-[15px] shadow-none transition-colors focus-visible:ring-0 focus-visible:outline-none dark:bg-transparent",
"placeholder:text-muted-foreground max-h-[200px] min-h-[60px] resize-none border-0 bg-transparent px-2 py-1 text-[15px] shadow-none transition-colors focus-visible:ring-0 focus-visible:outline-none dark:bg-transparent",
!canInput && "cursor-not-allowed",
)}
minRows={1}
@@ -56,7 +56,7 @@ export function ChatComposer({
size="icon"
className="size-8 rounded-full bg-violet-500 text-white transition-transform hover:bg-violet-600 active:scale-95"
onClick={onSend}
disabled={!input.trim() || !isConnected}
disabled={!input.trim() || !canInput}
>
<IconArrowUp className="size-4" />
</Button>
@@ -34,7 +34,7 @@ export function ChatEmptyState({
<p className="text-muted-foreground mb-4 text-center text-sm">
{t("chat.empty.noConfiguredModelDescription")}
</p>
<Button asChild variant="secondary" size="sm" className="px-4">
<Button asChild variant="outline" size="sm" className="px-4">
<Link to="/models">{t("chat.empty.goToModels")}</Link>
</Button>
</div>
+11 -12
View File
@@ -15,7 +15,6 @@ import { useChatModels } from "@/hooks/use-chat-models"
import { useGateway } from "@/hooks/use-gateway"
import { usePicoChat } from "@/hooks/use-pico-chat"
import { useSessionHistory } from "@/hooks/use-session-history"
import { hydrateActiveSession } from "@/lib/pico-chat-controller"
export function ChatPage() {
const { t } = useTranslation()
@@ -26,6 +25,7 @@ export function ChatPage() {
const {
messages,
connectionState,
isTyping,
activeSessionId,
sendMessage,
@@ -34,7 +34,8 @@ export function ChatPage() {
} = usePicoChat()
const { state: gwState } = useGateway()
const isConnected = gwState === "running"
const isGatewayRunning = gwState === "running"
const isChatConnected = connectionState === "connected"
const {
defaultModelName,
@@ -43,7 +44,8 @@ export function ChatPage() {
oauthModels,
localModels,
handleSetDefault,
} = useChatModels({ isConnected })
} = useChatModels({ isConnected: isGatewayRunning })
const canSend = isChatConnected && Boolean(defaultModelName)
const {
sessions,
@@ -68,10 +70,6 @@ export function ChatPage() {
syncScrollState(e.currentTarget)
}
useEffect(() => {
void hydrateActiveSession()
}, [])
useEffect(() => {
if (scrollRef.current) {
if (isAtBottom) {
@@ -82,9 +80,10 @@ export function ChatPage() {
}, [messages, isTyping, isAtBottom])
const handleSend = () => {
if (!input.trim() || !isConnected) return
sendMessage(input.trim())
setInput("")
if (!input.trim() || !canSend) return
if (sendMessage(input.trim())) {
setInput("")
}
}
return (
@@ -143,7 +142,7 @@ export function ChatPage() {
<ChatEmptyState
hasConfiguredModels={hasConfiguredModels}
defaultModelName={defaultModelName}
isConnected={isConnected}
isConnected={isGatewayRunning}
/>
)}
@@ -168,7 +167,7 @@ export function ChatPage() {
input={input}
onInputChange={setInput}
onSend={handleSend}
isConnected={isConnected}
isConnected={isChatConnected}
hasDefaultModel={Boolean(defaultModelName)}
/>
</div>
@@ -14,7 +14,9 @@ import {
} from "@/api/system"
import {
AgentDefaultsSection,
CronSection,
DevicesSection,
ExecSection,
LauncherSection,
RuntimeSection,
} from "@/components/config/config-sections"
@@ -26,6 +28,7 @@ import {
buildFormFromConfig,
parseCIDRText,
parseIntField,
parseMultilineList,
} from "@/components/config/form-model"
import { PageHeader } from "@/components/page-header"
import { Button } from "@/components/ui/button"
@@ -164,6 +167,33 @@ export function ConfigPage() {
"Heartbeat interval",
{ min: 1 },
)
const cronExecTimeoutMinutes = parseIntField(
form.cronExecTimeoutMinutes,
"Cron exec timeout",
{ min: 0 },
)
const execConfigPatch: Record<string, unknown> = {
enabled: form.execEnabled,
}
if (form.execEnabled) {
execConfigPatch.allow_remote = form.allowRemote
execConfigPatch.enable_deny_patterns = form.enableDenyPatterns
execConfigPatch.custom_allow_patterns = parseMultilineList(
form.customAllowPatternsText,
)
execConfigPatch.timeout_seconds = parseIntField(
form.execTimeoutSeconds,
"Exec timeout",
{ min: 0 },
)
if (form.enableDenyPatterns) {
execConfigPatch.custom_deny_patterns = parseMultilineList(
form.customDenyPatternsText,
)
}
}
await patchAppConfig({
agents: {
@@ -180,9 +210,11 @@ export function ConfigPage() {
dm_scope: dmScope,
},
tools: {
exec: {
allow_remote: form.allowRemote,
cron: {
allow_command: form.allowCommand,
exec_timeout_minutes: cronExecTimeoutMinutes,
},
exec: execConfigPatch,
},
heartbeat: {
enabled: form.heartbeatEnabled,
@@ -279,6 +311,10 @@ export function ConfigPage() {
<RuntimeSection form={form} onFieldChange={updateField} />
<ExecSection form={form} onFieldChange={updateField} />
<CronSection form={form} onFieldChange={updateField} />
<LauncherSection
launcherForm={launcherForm}
onFieldChange={updateLauncherField}
@@ -93,14 +93,6 @@ export function AgentDefaultsSection({
}
/>
<SwitchCardField
label={t("pages.config.allow_remote")}
hint={t("pages.config.allow_remote_hint")}
layout="setting-row"
checked={form.allowRemote}
onCheckedChange={(checked) => onFieldChange("allowRemote", checked)}
/>
<Field
label={t("pages.config.max_tokens")}
hint={t("pages.config.max_tokens_hint")}
@@ -161,6 +153,98 @@ export function AgentDefaultsSection({
)
}
interface ExecSectionProps {
form: CoreConfigForm
onFieldChange: UpdateCoreField
}
export function ExecSection({ form, onFieldChange }: ExecSectionProps) {
const { t } = useTranslation()
return (
<ConfigSectionCard title={t("pages.config.sections.exec")}>
<SwitchCardField
label={t("pages.config.exec_enabled")}
hint={t("pages.config.exec_enabled_hint")}
layout="setting-row"
checked={form.execEnabled}
onCheckedChange={(checked) => onFieldChange("execEnabled", checked)}
/>
{form.execEnabled && (
<>
<SwitchCardField
label={t("pages.config.allow_remote")}
hint={t("pages.config.allow_remote_hint")}
layout="setting-row"
checked={form.allowRemote}
onCheckedChange={(checked) => onFieldChange("allowRemote", checked)}
/>
<SwitchCardField
label={t("pages.config.enable_deny_patterns")}
hint={t("pages.config.enable_deny_patterns_hint")}
layout="setting-row"
checked={form.enableDenyPatterns}
onCheckedChange={(checked) =>
onFieldChange("enableDenyPatterns", checked)
}
/>
{form.enableDenyPatterns && (
<Field
label={t("pages.config.custom_deny_patterns")}
hint={t("pages.config.custom_deny_patterns_hint")}
layout="setting-row"
controlClassName="md:max-w-md"
>
<Textarea
value={form.customDenyPatternsText}
placeholder={t("pages.config.custom_patterns_placeholder")}
className="min-h-[88px]"
onChange={(e) =>
onFieldChange("customDenyPatternsText", e.target.value)
}
/>
</Field>
)}
<Field
label={t("pages.config.custom_allow_patterns")}
hint={t("pages.config.custom_allow_patterns_hint")}
layout="setting-row"
controlClassName="md:max-w-md"
>
<Textarea
value={form.customAllowPatternsText}
placeholder={t("pages.config.custom_patterns_placeholder")}
className="min-h-[88px]"
onChange={(e) =>
onFieldChange("customAllowPatternsText", e.target.value)
}
/>
</Field>
<Field
label={t("pages.config.exec_timeout_seconds")}
hint={t("pages.config.exec_timeout_seconds_hint")}
layout="setting-row"
>
<Input
type="number"
min={0}
value={form.execTimeoutSeconds}
onChange={(e) =>
onFieldChange("execTimeoutSeconds", e.target.value)
}
/>
</Field>
</>
)}
</ConfigSectionCard>
)
}
interface RuntimeSectionProps {
form: CoreConfigForm
onFieldChange: UpdateCoreField
@@ -236,6 +320,44 @@ export function RuntimeSection({ form, onFieldChange }: RuntimeSectionProps) {
)
}
interface CronSectionProps {
form: CoreConfigForm
onFieldChange: UpdateCoreField
}
export function CronSection({ form, onFieldChange }: CronSectionProps) {
const { t } = useTranslation()
return (
<ConfigSectionCard title={t("pages.config.sections.cron")}>
<SwitchCardField
label={t("pages.config.allow_shell_execution")}
hint={t("pages.config.allow_shell_execution_hint")}
layout="setting-row"
checked={form.allowCommand}
disabled={!form.execEnabled}
onCheckedChange={(checked) => onFieldChange("allowCommand", checked)}
/>
<Field
label={t("pages.config.cron_exec_timeout")}
hint={t("pages.config.cron_exec_timeout_hint")}
layout="setting-row"
>
<Input
type="number"
min={0}
disabled={!form.execEnabled}
value={form.cronExecTimeoutMinutes}
onChange={(e) =>
onFieldChange("cronExecTimeoutMinutes", e.target.value)
}
/>
</Field>
</ConfigSectionCard>
)
}
interface LauncherSectionProps {
launcherForm: LauncherForm
onFieldChange: UpdateLauncherField
@@ -3,7 +3,14 @@ export type JsonRecord = Record<string, unknown>
export interface CoreConfigForm {
workspace: string
restrictToWorkspace: boolean
execEnabled: boolean
allowRemote: boolean
enableDenyPatterns: boolean
customDenyPatternsText: string
customAllowPatternsText: string
execTimeoutSeconds: string
allowCommand: boolean
cronExecTimeoutMinutes: string
maxTokens: string
maxToolIterations: string
summarizeMessageThreshold: string
@@ -55,7 +62,14 @@ export const DM_SCOPE_OPTIONS = [
export const EMPTY_FORM: CoreConfigForm = {
workspace: "",
restrictToWorkspace: true,
execEnabled: true,
allowRemote: true,
enableDenyPatterns: true,
customDenyPatternsText: "",
customAllowPatternsText: "",
execTimeoutSeconds: "0",
allowCommand: true,
cronExecTimeoutMinutes: "5",
maxTokens: "32768",
maxToolIterations: "50",
summarizeMessageThreshold: "20",
@@ -106,6 +120,7 @@ export function buildFormFromConfig(config: unknown): CoreConfigForm {
const heartbeat = asRecord(root.heartbeat)
const devices = asRecord(root.devices)
const tools = asRecord(root.tools)
const cron = asRecord(tools.cron)
const exec = asRecord(tools.exec)
return {
@@ -114,10 +129,40 @@ export function buildFormFromConfig(config: unknown): CoreConfigForm {
defaults.restrict_to_workspace === undefined
? EMPTY_FORM.restrictToWorkspace
: asBool(defaults.restrict_to_workspace),
execEnabled:
exec.enabled === undefined
? EMPTY_FORM.execEnabled
: asBool(exec.enabled),
allowRemote:
exec.allow_remote === undefined
? EMPTY_FORM.allowRemote
: asBool(exec.allow_remote),
enableDenyPatterns:
exec.enable_deny_patterns === undefined
? EMPTY_FORM.enableDenyPatterns
: asBool(exec.enable_deny_patterns),
customDenyPatternsText: Array.isArray(exec.custom_deny_patterns)
? exec.custom_deny_patterns
.filter((value): value is string => typeof value === "string")
.join("\n")
: EMPTY_FORM.customDenyPatternsText,
customAllowPatternsText: Array.isArray(exec.custom_allow_patterns)
? exec.custom_allow_patterns
.filter((value): value is string => typeof value === "string")
.join("\n")
: EMPTY_FORM.customAllowPatternsText,
execTimeoutSeconds: asNumberString(
exec.timeout_seconds,
EMPTY_FORM.execTimeoutSeconds,
),
allowCommand:
cron.allow_command === undefined
? EMPTY_FORM.allowCommand
: asBool(cron.allow_command),
cronExecTimeoutMinutes: asNumberString(
cron.exec_timeout_minutes,
EMPTY_FORM.cronExecTimeoutMinutes,
),
maxTokens: asNumberString(defaults.max_tokens, EMPTY_FORM.maxTokens),
maxToolIterations: asNumberString(
defaults.max_tool_iterations,
@@ -178,3 +223,13 @@ export function parseCIDRText(raw: string): string[] {
.map((v) => v.trim())
.filter((v) => v.length > 0)
}
export function parseMultilineList(raw: string): string[] {
if (!raw.trim()) {
return []
}
return raw
.split("\n")
.map((value) => value.trim())
.filter((value) => value.length > 0)
}
@@ -2,24 +2,24 @@ import { getDefaultStore } from "jotai"
import { toast } from "sonner"
import { getPicoToken } from "@/api/pico"
import { getSessionHistory } from "@/api/sessions"
import i18n from "@/i18n"
import {
loadSessionMessages,
mergeHistoryMessages,
} from "@/features/chat/history"
import { type PicoMessage, handlePicoMessage } from "@/features/chat/protocol"
import {
clearStoredSessionId,
generateSessionId,
normalizeUnixTimestamp,
readStoredSessionId,
} from "@/lib/pico-chat-state"
import { type ChatMessage, getChatState, updateChatStore } from "@/store/chat"
import { gatewayAtom } from "@/store/gateway"
interface PicoMessage {
type: string
id?: string
session_id?: string
timestamp?: number | string
payload?: Record<string, unknown>
}
} from "@/features/chat/state"
import {
invalidateSocket,
isCurrentSocket,
normalizeWsUrlForBrowser,
} from "@/features/chat/websocket"
import i18n from "@/i18n"
import { getChatState, updateChatStore } from "@/store/chat"
import { type GatewayState, gatewayAtom } from "@/store/gateway"
const store = getDefaultStore()
@@ -31,81 +31,51 @@ let initialized = false
let unsubscribeGateway: (() => void) | null = null
let hydratePromise: Promise<void> | null = null
let connectionGeneration = 0
let reconnectTimer: number | null = null
let reconnectAttempts = 0
let shouldMaintainConnection = false
async function loadSessionMessages(sessionId: string): Promise<ChatMessage[]> {
const detail = await getSessionHistory(sessionId)
const fallbackTime = detail.updated
return detail.messages.map((message, index) => ({
id: `hist-${index}-${Date.now()}`,
role: message.role,
content: message.content,
timestamp: fallbackTime,
}))
function clearReconnectTimer() {
if (reconnectTimer !== null) {
window.clearTimeout(reconnectTimer)
reconnectTimer = null
}
}
function handlePicoMessage(message: PicoMessage) {
const payload = message.payload || {}
function shouldReconnectFor(generation: number, sessionId: string): boolean {
return (
shouldMaintainConnection &&
generation === connectionGeneration &&
sessionId === activeSessionIdRef &&
store.get(gatewayAtom).status === "running"
)
}
switch (message.type) {
case "message.create": {
const content = (payload.content as string) || ""
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
const timestamp =
message.timestamp !== undefined &&
Number.isFinite(Number(message.timestamp))
? normalizeUnixTimestamp(Number(message.timestamp))
: Date.now()
updateChatStore((prev) => ({
messages: [
...prev.messages,
{
id: messageId,
role: "assistant",
content,
timestamp,
},
],
isTyping: false,
}))
break
}
case "message.update": {
const content = (payload.content as string) || ""
const messageId = payload.message_id as string
if (!messageId) {
break
}
updateChatStore((prev) => ({
messages: prev.messages.map((msg) =>
msg.id === messageId ? { ...msg, content } : msg,
),
}))
break
}
case "typing.start":
updateChatStore({ isTyping: true })
break
case "typing.stop":
updateChatStore({ isTyping: false })
break
case "error":
console.error("Pico error:", payload)
updateChatStore({ isTyping: false })
break
case "pong":
break
default:
console.log("Unknown pico message type:", message.type)
function scheduleReconnect(generation: number, sessionId: string) {
if (!shouldReconnectFor(generation, sessionId) || reconnectTimer !== null) {
return
}
const delay = Math.min(1000 * 2 ** reconnectAttempts, 5000)
reconnectAttempts += 1
reconnectTimer = window.setTimeout(() => {
reconnectTimer = null
if (!shouldReconnectFor(generation, sessionId)) {
return
}
void connectChat()
}, delay)
}
function needsActiveSessionHydration(): boolean {
const state = getChatState()
const storedSessionId = readStoredSessionId()
return Boolean(
storedSessionId &&
storedSessionId === state.activeSessionId &&
!state.hasHydratedActiveSession,
)
}
function setActiveSessionId(sessionId: string) {
@@ -113,8 +83,35 @@ function setActiveSessionId(sessionId: string) {
updateChatStore({ activeSessionId: sessionId })
}
function disconnectChatInternal({
clearDesiredConnection,
}: {
clearDesiredConnection: boolean
}) {
connectionGeneration += 1
clearReconnectTimer()
if (clearDesiredConnection) {
shouldMaintainConnection = false
}
const socket = wsRef
wsRef = null
isConnecting = false
invalidateSocket(socket)
updateChatStore({
connectionState: "disconnected",
isTyping: false,
})
}
export async function connectChat() {
if (store.get(gatewayAtom).status !== "running") {
if (
store.get(gatewayAtom).status !== "running" ||
needsActiveSessionHydration()
) {
return
}
@@ -130,12 +127,15 @@ export async function connectChat() {
const generation = connectionGeneration + 1
connectionGeneration = generation
isConnecting = true
clearReconnectTimer()
updateChatStore({ connectionState: "connecting" })
try {
const { token, ws_url } = await getPicoToken()
const sessionId = activeSessionIdRef
if (generation !== connectionGeneration) {
isConnecting = false
return
}
@@ -143,55 +143,71 @@ export async function connectChat() {
console.error("No pico token available")
updateChatStore({ connectionState: "error" })
isConnecting = false
scheduleReconnect(generation, sessionId)
return
}
let finalWsUrl = ws_url
try {
const parsedUrl = new URL(ws_url)
const isLocalHost =
parsedUrl.hostname === "localhost" ||
parsedUrl.hostname === "127.0.0.1" ||
parsedUrl.hostname === "0.0.0.0"
const isBrowserLocal =
window.location.hostname === "localhost" ||
window.location.hostname === "127.0.0.1"
if (isLocalHost && !isBrowserLocal) {
parsedUrl.hostname = window.location.hostname
finalWsUrl = parsedUrl.toString()
}
} catch (error) {
console.warn("Could not parse ws_url:", error)
}
const url = `${finalWsUrl}?token=${encodeURIComponent(token)}&session_id=${encodeURIComponent(activeSessionIdRef)}`
const socket = new WebSocket(url)
const finalWsUrl = normalizeWsUrlForBrowser(ws_url)
const url = `${finalWsUrl}?session_id=${encodeURIComponent(sessionId)}`
const socket = new WebSocket(url, [`token.${token}`])
if (generation !== connectionGeneration) {
socket.close()
isConnecting = false
invalidateSocket(socket)
return
}
socket.onopen = () => {
if (wsRef !== socket) {
if (
!isCurrentSocket({
socket,
currentSocket: wsRef,
generation,
currentGeneration: connectionGeneration,
sessionId,
currentSessionId: activeSessionIdRef,
})
) {
return
}
updateChatStore({ connectionState: "connected" })
isConnecting = false
reconnectAttempts = 0
}
socket.onmessage = (event) => {
if (
!isCurrentSocket({
socket,
currentSocket: wsRef,
generation,
currentGeneration: connectionGeneration,
sessionId,
currentSessionId: activeSessionIdRef,
})
) {
return
}
try {
const message: PicoMessage = JSON.parse(event.data)
handlePicoMessage(message)
const message = JSON.parse(event.data) as PicoMessage
handlePicoMessage(message, sessionId)
} catch {
console.warn("Non-JSON message from pico:", event.data)
}
}
socket.onclose = () => {
if (wsRef !== socket) {
if (
!isCurrentSocket({
socket,
currentSocket: wsRef,
generation,
currentGeneration: connectionGeneration,
sessionId,
currentSessionId: activeSessionIdRef,
})
) {
return
}
wsRef = null
@@ -200,42 +216,42 @@ export async function connectChat() {
connectionState: "disconnected",
isTyping: false,
})
scheduleReconnect(generation, sessionId)
}
socket.onerror = () => {
if (wsRef !== socket) {
if (
!isCurrentSocket({
socket,
currentSocket: wsRef,
generation,
currentGeneration: connectionGeneration,
sessionId,
currentSessionId: activeSessionIdRef,
})
) {
return
}
isConnecting = false
updateChatStore({ connectionState: "error" })
scheduleReconnect(generation, sessionId)
}
wsRef = socket
} catch (error) {
if (generation !== connectionGeneration) {
isConnecting = false
return
}
console.error("Failed to connect to pico:", error)
updateChatStore({ connectionState: "error" })
isConnecting = false
scheduleReconnect(generation, activeSessionIdRef)
}
}
export function disconnectChat() {
connectionGeneration += 1
const socket = wsRef
wsRef = null
isConnecting = false
if (socket) {
socket.close()
}
updateChatStore({
connectionState: "disconnected",
isTyping: false,
})
disconnectChatInternal({ clearDesiredConnection: true })
}
export async function hydrateActiveSession() {
@@ -249,7 +265,6 @@ export async function hydrateActiveSession() {
if (
!storedSessionId ||
state.hasHydratedActiveSession ||
state.messages.length > 0 ||
storedSessionId !== state.activeSessionId
) {
if (!state.hasHydratedActiveSession) {
@@ -266,7 +281,13 @@ export async function hydrateActiveSession() {
}
if (currentState.messages.length > 0) {
updateChatStore({ hasHydratedActiveSession: true })
updateChatStore({
messages: mergeHistoryMessages(
historyMessages,
currentState.messages,
),
hasHydratedActiveSession: true,
})
return
}
@@ -306,9 +327,10 @@ export async function hydrateActiveSession() {
export function sendChatMessage(content: string) {
if (!wsRef || wsRef.readyState !== WebSocket.OPEN) {
console.warn("WebSocket not connected")
return
return false
}
const socket = wsRef
const id = `msg-${++msgIdCounter}-${Date.now()}`
updateChatStore((prev) => ({
@@ -319,13 +341,23 @@ export function sendChatMessage(content: string) {
isTyping: true,
}))
wsRef.send(
JSON.stringify({
type: "message.send",
id,
payload: { content },
}),
)
try {
socket.send(
JSON.stringify({
type: "message.send",
id,
payload: { content },
}),
)
return true
} catch (error) {
console.error("Failed to send pico message:", error)
updateChatStore((prev) => ({
messages: prev.messages.filter((message) => message.id !== id),
isTyping: false,
}))
return false
}
}
export async function switchChatSession(sessionId: string) {
@@ -336,7 +368,7 @@ export async function switchChatSession(sessionId: string) {
try {
const historyMessages = await loadSessionMessages(sessionId)
disconnectChat()
disconnectChatInternal({ clearDesiredConnection: false })
setActiveSessionId(sessionId)
updateChatStore({
messages: historyMessages,
@@ -345,6 +377,7 @@ export async function switchChatSession(sessionId: string) {
})
if (store.get(gatewayAtom).status === "running") {
shouldMaintainConnection = true
await connectChat()
}
} catch (error) {
@@ -358,7 +391,7 @@ export async function newChatSession() {
return
}
disconnectChat()
disconnectChatInternal({ clearDesiredConnection: false })
setActiveSessionId(generateSessionId())
updateChatStore({
messages: [],
@@ -367,6 +400,7 @@ export async function newChatSession() {
})
if (store.get(gatewayAtom).status === "running") {
shouldMaintainConnection = true
await connectChat()
}
}
@@ -378,23 +412,43 @@ export function initializeChatStore() {
initialized = true
activeSessionIdRef = getChatState().activeSessionId
let lastGatewayStatus: GatewayState | null = null
const syncConnectionWithGateway = () => {
if (store.get(gatewayAtom).status === "running") {
const syncConnectionWithGateway = (force: boolean = false) => {
const gatewayStatus = store.get(gatewayAtom).status
if (!force && gatewayStatus === lastGatewayStatus) {
return
}
lastGatewayStatus = gatewayStatus
if (gatewayStatus === "running") {
shouldMaintainConnection = true
if (needsActiveSessionHydration()) {
return
}
void connectChat()
return
}
disconnectChat()
if (gatewayStatus === "stopped" || gatewayStatus === "error") {
disconnectChatInternal({ clearDesiredConnection: true })
}
}
unsubscribeGateway = store.sub(gatewayAtom, syncConnectionWithGateway)
if (!readStoredSessionId()) {
updateChatStore({ hasHydratedActiveSession: true })
syncConnectionWithGateway(true)
return
}
syncConnectionWithGateway()
void hydrateActiveSession().finally(() => {
if (!initialized) {
return
}
syncConnectionWithGateway(true)
})
}
export function teardownChatStore() {
+68
View File
@@ -0,0 +1,68 @@
import { getSessionHistory } from "@/api/sessions"
import { normalizeUnixTimestamp } from "@/features/chat/state"
import type { ChatMessage } from "@/store/chat"
export async function loadSessionMessages(
sessionId: string,
): Promise<ChatMessage[]> {
const detail = await getSessionHistory(sessionId)
const fallbackTime = detail.updated
return detail.messages.map((message, index) => ({
id: `hist-${index}-${Date.now()}`,
role: message.role,
content: message.content,
timestamp: fallbackTime,
}))
}
function normalizeMessageTimestamp(timestamp: number | string): string {
if (typeof timestamp === "number") {
return String(normalizeUnixTimestamp(timestamp))
}
const trimmed = timestamp.trim()
if (/^-?\d+(\.\d+)?$/.test(trimmed)) {
return String(normalizeUnixTimestamp(Number(trimmed)))
}
const parsed = Date.parse(trimmed)
return Number.isNaN(parsed) ? trimmed : String(parsed)
}
function messageSignature(message: ChatMessage): string {
return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp(
message.timestamp,
)}`
}
function comparableTimestamp(timestamp: number | string): number {
const normalized = normalizeMessageTimestamp(timestamp)
const numeric = Number(normalized)
return Number.isFinite(numeric) ? numeric : 0
}
export function mergeHistoryMessages(
historyMessages: ChatMessage[],
currentMessages: ChatMessage[],
): ChatMessage[] {
const currentIds = new Set(currentMessages.map((message) => message.id))
const currentSignatures = new Set(
currentMessages.map((message) => messageSignature(message)),
)
const merged = [
...historyMessages.filter(
(message) =>
!currentIds.has(message.id) &&
!currentSignatures.has(messageSignature(message)),
),
...currentMessages,
]
return merged.sort(
(left, right) =>
comparableTimestamp(left.timestamp) -
comparableTimestamp(right.timestamp),
)
}
@@ -0,0 +1,81 @@
import { normalizeUnixTimestamp } from "@/features/chat/state"
import { updateChatStore } from "@/store/chat"
export interface PicoMessage {
type: string
id?: string
session_id?: string
timestamp?: number | string
payload?: Record<string, unknown>
}
export function handlePicoMessage(
message: PicoMessage,
expectedSessionId: string,
) {
if (message.session_id && message.session_id !== expectedSessionId) {
return
}
const payload = message.payload || {}
switch (message.type) {
case "message.create": {
const content = (payload.content as string) || ""
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
const timestamp =
message.timestamp !== undefined &&
Number.isFinite(Number(message.timestamp))
? normalizeUnixTimestamp(Number(message.timestamp))
: Date.now()
updateChatStore((prev) => ({
messages: [
...prev.messages,
{
id: messageId,
role: "assistant",
content,
timestamp,
},
],
isTyping: false,
}))
break
}
case "message.update": {
const content = (payload.content as string) || ""
const messageId = payload.message_id as string
if (!messageId) {
break
}
updateChatStore((prev) => ({
messages: prev.messages.map((msg) =>
msg.id === messageId ? { ...msg, content } : msg,
),
}))
break
}
case "typing.start":
updateChatStore({ isTyping: true })
break
case "typing.stop":
updateChatStore({ isTyping: false })
break
case "error":
console.error("Pico error:", payload)
updateChatStore({ isTyping: false })
break
case "pong":
break
default:
console.log("Unknown pico message type:", message.type)
}
}

Some files were not shown because too many files have changed in this diff Show More