mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge upstream/main into feat/subturn-poc
Includes JSONL session persistence (#1170), spawn_status tool, Azure provider, credential encryption, and various fixes. SubTurn features preserved and integrated with new spawn_status functionality.
This commit is contained in:
@@ -0,0 +1,27 @@
|
||||
version: 2
|
||||
|
||||
updates:
|
||||
|
||||
# Go dependencies (entire repo)
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "go"
|
||||
|
||||
# Frontend dependencies
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/web/frontend"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "frontend"
|
||||
|
||||
# GitHub Actions
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
@@ -31,11 +31,11 @@ jobs:
|
||||
|
||||
# ── Docker Buildx ─────────────────────────
|
||||
- name: 🔧 Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v4
|
||||
|
||||
# ── Login to GHCR ─────────────────────────
|
||||
- name: 🔑 Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ${{ env.GHCR_REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
|
||||
# ── Login to Docker Hub ────────────────────
|
||||
- name: 🔑 Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ${{ env.DOCKERHUB_REGISTRY }}
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
@@ -62,7 +62,7 @@ jobs:
|
||||
|
||||
# ── Build & Push ──────────────────────────
|
||||
- name: 🚀 Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@v7
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
|
||||
@@ -48,7 +48,7 @@ jobs:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
@@ -59,17 +59,17 @@ jobs:
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
@@ -77,17 +77,17 @@ jobs:
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to GitHub Container Registry
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
|
||||
@@ -12,10 +12,11 @@ GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
|
||||
BUILD_TIME=$(shell date +%FT%T%z)
|
||||
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
|
||||
CONFIG_PKG=github.com/sipeed/picoclaw/pkg/config
|
||||
LDFLAGS=-ldflags "-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COMMIT) -X $(CONFIG_PKG).BuildTime=$(BUILD_TIME) -X $(CONFIG_PKG).GoVersion=$(GO_VERSION) -s -w"
|
||||
LDFLAGS=-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COMMIT) -X $(CONFIG_PKG).BuildTime=$(BUILD_TIME) -X $(CONFIG_PKG).GoVersion=$(GO_VERSION) -s -w
|
||||
|
||||
# Go variables
|
||||
GO?=CGO_ENABLED=0 go
|
||||
WEB_GO?=$(GO)
|
||||
GOFLAGS?=-v -tags stdjson
|
||||
|
||||
# Patch MIPS LE ELF e_flags (offset 36) for NaN2008-only kernels (e.g. Ingenic X2600).
|
||||
@@ -79,6 +80,7 @@ ifeq ($(UNAME_S),Linux)
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
PLATFORM=darwin
|
||||
WEB_GO=CGO_ENABLED=1 go
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ARCH=amd64
|
||||
else ifeq ($(UNAME_M),arm64)
|
||||
@@ -107,7 +109,7 @@ generate:
|
||||
build: generate
|
||||
@echo "Building $(BINARY_NAME) for $(PLATFORM)/$(ARCH)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
@$(GO) build $(GOFLAGS) $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR)
|
||||
@$(GO) build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o $(BINARY_PATH) ./$(CMD_DIR)
|
||||
@echo "Build complete: $(BINARY_PATH)"
|
||||
@ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME)
|
||||
|
||||
@@ -119,7 +121,7 @@ build-launcher:
|
||||
echo "Building frontend..."; \
|
||||
cd web/frontend && pnpm install && pnpm build:backend; \
|
||||
fi
|
||||
@$(GO) build $(GOFLAGS) -o $(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH) ./web/backend
|
||||
@$(WEB_GO) build $(GOFLAGS) -o $(BUILD_DIR)/picoclaw-launcher-$(PLATFORM)-$(ARCH) ./web/backend
|
||||
@ln -sf picoclaw-launcher-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/picoclaw-launcher
|
||||
@echo "Build complete: $(BUILD_DIR)/picoclaw-launcher"
|
||||
|
||||
@@ -128,16 +130,16 @@ build-whatsapp-native: generate
|
||||
## @echo "Building $(BINARY_NAME) with WhatsApp native for $(PLATFORM)/$(ARCH)..."
|
||||
@echo "Building for multiple platforms..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
|
||||
$(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
## @$(GO) build $(GOFLAGS) -tags whatsapp_native $(LDFLAGS) -o $(BINARY_PATH) ./$(CMD_DIR)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
## @$(GO) build $(GOFLAGS) -tags whatsapp_native -ldflags "$(LDFLAGS)" -o $(BINARY_PATH) ./$(CMD_DIR)
|
||||
@echo "Build complete"
|
||||
## @ln -sf $(BINARY_NAME)-$(PLATFORM)-$(ARCH) $(BUILD_DIR)/$(BINARY_NAME)
|
||||
|
||||
@@ -145,21 +147,21 @@ build-whatsapp-native: generate
|
||||
build-linux-arm: generate
|
||||
@echo "Building for linux/arm (GOARM=7)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
|
||||
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm"
|
||||
|
||||
## build-linux-arm64: Build for Linux ARM64 (e.g. Raspberry Pi Zero 2 W 64-bit)
|
||||
build-linux-arm64: generate
|
||||
@echo "Building for linux/arm64..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64"
|
||||
|
||||
## build-linux-mipsle: Build for Linux MIPS32 LE
|
||||
build-linux-mipsle: generate
|
||||
@echo "Building for linux/mipsle (softfloat)..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
|
||||
$(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle)
|
||||
@echo "Build complete: $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle"
|
||||
|
||||
@@ -171,18 +173,18 @@ build-pi-zero: build-linux-arm build-linux-arm64
|
||||
build-all: generate
|
||||
@echo "Building for multiple platforms..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=loong64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-loong64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=riscv64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=mipsle GOMIPS=softfloat $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle ./$(CMD_DIR)
|
||||
$(call PATCH_MIPS_FLAGS,$(BUILD_DIR)/$(BINARY_NAME)-linux-mipsle)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
GOOS=netbsd GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-amd64 ./$(CMD_DIR)
|
||||
GOOS=netbsd GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-arm64 ./$(CMD_DIR)
|
||||
GOOS=linux GOARCH=arm GOARM=7 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-linux-armv7 ./$(CMD_DIR)
|
||||
GOOS=darwin GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
|
||||
GOOS=windows GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
|
||||
GOOS=netbsd GOARCH=amd64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-amd64 ./$(CMD_DIR)
|
||||
GOOS=netbsd GOARCH=arm64 $(GO) build -ldflags "$(LDFLAGS)" -o $(BUILD_DIR)/$(BINARY_NAME)-netbsd-arm64 ./$(CMD_DIR)
|
||||
@echo "All builds complete"
|
||||
|
||||
## install: Install picoclaw to system and copy builtin skills
|
||||
@@ -219,11 +221,14 @@ clean:
|
||||
|
||||
## vet: Run go vet for static analysis
|
||||
vet: generate
|
||||
@$(GO) vet ./...
|
||||
@packages="$$(go list ./...)" && \
|
||||
$(GO) vet $$(printf '%s\n' "$$packages" | grep -v '^github.com/sipeed/picoclaw/web/')
|
||||
@cd web/backend && $(WEB_GO) vet ./...
|
||||
|
||||
## test: Test Go code
|
||||
test: generate
|
||||
@$(GO) test ./...
|
||||
@$(GO) test $$(go list ./... | grep -v github.com/sipeed/picoclaw/web/)
|
||||
@cd web && make test
|
||||
|
||||
## fmt: Format Go code
|
||||
fmt:
|
||||
|
||||
+4
-1
@@ -23,7 +23,9 @@
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** est un assistant personnel IA ultra-léger inspiré de [nanobot](https://github.com/HKUDS/nanobot), entièrement réécrit en **Go** via un processus d'auto-amorçage (self-bootstrapping) — où l'agent IA lui-même a piloté l'intégralité de la migration architecturale et de l'optimisation du code.
|
||||
> **PicoClaw** est un projet open-source indépendant initié par [Sipeed](https://sipeed.com). Il est entièrement écrit en **Go** — ce n'est pas un fork d'OpenClaw, de NanoBot ou de tout autre projet.
|
||||
|
||||
🦐 **PicoClaw** est un assistant personnel IA ultra-léger inspiré de [NanoBot](https://github.com/HKUDS/nanobot), entièrement réécrit en **Go** via un processus d'auto-amorçage (self-bootstrapping) — où l'agent IA lui-même a piloté l'intégralité de la migration architecturale et de l'optimisation du code.
|
||||
|
||||
⚡️ **Extrêmement léger :** Fonctionne sur du matériel à seulement **10$** avec **<10 Mo** de RAM. C'est 99% de mémoire en moins qu'OpenClaw et 98% moins cher qu'un Mac mini !
|
||||
|
||||
@@ -991,6 +993,7 @@ Cette conception permet également le **support multi-agent** avec une sélectio
|
||||
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obtenir Clé](https://www.byteplus.com/) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obtenir une clé](https://longcat.chat/platform) |
|
||||
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obtenir un Token](https://modelscope.cn/my/tokens) |
|
||||
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obtenir Clé](https://portal.azure.com) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth uniquement |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
+4
-1
@@ -26,7 +26,9 @@
|
||||
|
||||
---
|
||||
|
||||
🦐 PicoClaw は [nanobot](https://github.com/HKUDS/nanobot) にインスパイアされた超軽量パーソナル AI アシスタントです。Go でゼロからリファクタリングされ、AI エージェント自身がアーキテクチャの移行とコード最適化を推進するセルフブートストラッピングプロセスで構築されました。
|
||||
> **PicoClaw** は [Sipeed](https://sipeed.com) が立ち上げた独立したオープンソースプロジェクトです。完全に **Go 言語**で一から書かれており、OpenClaw、NanoBot、その他のプロジェクトのフォークではありません。
|
||||
|
||||
🦐 PicoClaw は [NanoBot](https://github.com/HKUDS/nanobot) にインスパイアされた超軽量パーソナル AI アシスタントです。Go でゼロからリファクタリングされ、AI エージェント自身がアーキテクチャの移行とコード最適化を推進するセルフブートストラッピングプロセスで構築されました。
|
||||
|
||||
⚡️ $10 のハードウェアで 10MB 未満の RAM で動作:OpenClaw より 99% 少ないメモリ、Mac mini より 98% 安い!
|
||||
|
||||
@@ -935,6 +937,7 @@ HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [キーを取得](https://www.byteplus.com) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [キーを取得](https://longcat.chat/platform) |
|
||||
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [トークンを取得](https://modelscope.cn/my/tokens) |
|
||||
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [キーを取得](https://portal.azure.com) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | カスタム | OAuthのみ |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@
|
||||
|
||||
---
|
||||
|
||||
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
|
||||
> **PicoClaw** is an independent open-source project initiated by [Sipeed](https://sipeed.com). It is written entirely in **Go** — not a fork of OpenClaw, NanoBot, or any other project.
|
||||
|
||||
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [NanoBot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
|
||||
|
||||
⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini!
|
||||
|
||||
@@ -861,6 +863,21 @@ Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous
|
||||
* `shutdown`, `reboot`, `poweroff` — System shutdown
|
||||
* Fork bomb `:(){ :|:& };:`
|
||||
|
||||
#### Known Limitation: Child Processes From Build Tools
|
||||
|
||||
The exec safety guard only inspects the command line PicoClaw launches directly. It does not recursively inspect child
|
||||
processes spawned by allowed developer tools such as `make`, `go run`, `cargo`, `npm run`, or custom build scripts.
|
||||
|
||||
That means a top-level command can still compile or launch other binaries after it passes the initial guard check. In
|
||||
practice, treat build scripts, Makefiles, package scripts, and generated binaries as executable code that needs the same
|
||||
level of review as a direct shell command.
|
||||
|
||||
For higher-risk environments:
|
||||
|
||||
* Review build scripts before execution.
|
||||
* Prefer approval/manual review for compile-and-run workflows.
|
||||
* Run PicoClaw inside a container or VM if you need stronger isolation than the built-in guard provides.
|
||||
|
||||
#### Error Examples
|
||||
|
||||
```
|
||||
@@ -1006,6 +1023,7 @@ The subagent has access to tools (message, web_search, etc.) and can communicate
|
||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||
| `cerebras` | LLM (Cerebras direct) | [cerebras.ai](https://cerebras.ai) |
|
||||
| `vivgrid` | LLM (Vivgrid direct) | [vivgrid.com](https://vivgrid.com) |
|
||||
| `azure` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
|
||||
|
||||
### Model Configuration (model_list)
|
||||
|
||||
@@ -1042,6 +1060,7 @@ This design also enables **multi-agent support** with flexible provider selectio
|
||||
| **Vivgrid** | `vivgrid/` | `https://api.vivgrid.com/v1` | OpenAI | [Get Key](https://vivgrid.com) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Get Key](https://longcat.chat/platform) |
|
||||
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Get Token](https://modelscope.cn/my/tokens) |
|
||||
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Get Key](https://portal.azure.com) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Custom | OAuth only |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
+4
-1
@@ -23,7 +23,9 @@
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [nanobot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código.
|
||||
> **PicoClaw** é um projeto open-source independente iniciado pela [Sipeed](https://sipeed.com). É escrito inteiramente em **Go** — não é um fork do OpenClaw, NanoBot ou qualquer outro projeto.
|
||||
|
||||
🦐 **PicoClaw** é um assistente pessoal de IA ultra-leve inspirado no [NanoBot](https://github.com/HKUDS/nanobot), reescrito do zero em **Go** por meio de um processo de "auto-inicialização" (self-bootstrapping) — onde o próprio agente de IA conduziu toda a migração de arquitetura e otimização de código.
|
||||
|
||||
⚡️ **Extremamente leve:** Roda em hardware de apenas **$10** com **<10MB** de RAM. Isso é 99% menos memória que o OpenClaw e 98% mais barato que um Mac mini!
|
||||
|
||||
@@ -987,6 +989,7 @@ Este design também possibilita o **suporte multi-agent** com seleção flexíve
|
||||
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Obter Chave](https://www.byteplus.com) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Obter Chave](https://longcat.chat/platform) |
|
||||
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Obter Token](https://modelscope.cn/my/tokens) |
|
||||
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Obter Chave](https://portal.azure.com) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Custom | Apenas OAuth |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
+4
-1
@@ -23,7 +23,9 @@
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [nanobot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn.
|
||||
> **PicoClaw** là dự án mã nguồn mở độc lập được khởi xướng bởi [Sipeed](https://sipeed.com). Được viết hoàn toàn bằng **Go** — không phải là bản fork của OpenClaw, NanoBot hay bất kỳ dự án nào khác.
|
||||
|
||||
🦐 **PicoClaw** là trợ lý AI cá nhân siêu nhẹ, lấy cảm hứng từ [NanoBot](https://github.com/HKUDS/nanobot), được viết lại hoàn toàn bằng **Go** thông qua quá trình "tự khởi tạo" (self-bootstrapping) — nơi chính AI Agent đã tự dẫn dắt toàn bộ quá trình chuyển đổi kiến trúc và tối ưu hóa mã nguồn.
|
||||
|
||||
⚡️ **Cực kỳ nhẹ:** Chạy trên phần cứng chỉ **$10** với RAM **<10MB**. Tiết kiệm 99% bộ nhớ so với OpenClaw và rẻ hơn 98% so với Mac mini!
|
||||
|
||||
@@ -956,6 +958,7 @@ Thiết kế này cũng cho phép **hỗ trợ đa tác nhân** với lựa ch
|
||||
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [Lấy Khóa](https://www.byteplus.com) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [Lấy Key](https://longcat.chat/platform) |
|
||||
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [Lấy Token](https://modelscope.cn/my/tokens) |
|
||||
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [Lấy Khóa](https://portal.azure.com) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | Tùy chỉnh | Chỉ OAuth |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
+4
-1
@@ -24,7 +24,9 @@
|
||||
|
||||
---
|
||||
|
||||
🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。
|
||||
> **PicoClaw** 是由 [矽速科技 (Sipeed)](https://sipeed.com) 发起的独立开源项目,完全使用 **Go 语言**从零编写——不是 OpenClaw、NanoBot 或其他项目的分支。
|
||||
|
||||
🦐 **PicoClaw** 是一个受 [NanoBot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。
|
||||
|
||||
⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%!
|
||||
|
||||
@@ -528,6 +530,7 @@ Agent 读取 HEARTBEAT.md
|
||||
| **BytePlus** | `byteplus/` | `https://ark.ap-southeast.bytepluses.com/api/v3` | OpenAI | [获取密钥](https://www.byteplus.com) |
|
||||
| **LongCat** | `longcat/` | `https://api.longcat.chat/openai` | OpenAI | [获取密钥](https://longcat.chat/platform) |
|
||||
| **ModelScope (魔搭)**| `modelscope/` | `https://api-inference.modelscope.cn/v1` | OpenAI | [获取 Token](https://modelscope.cn/my/tokens) |
|
||||
| **Azure OpenAI** | `azure/` | `https://{resource}.openai.azure.com` | Azure | [获取密钥](https://portal.azure.com) |
|
||||
| **Antigravity** | `antigravity/` | Google Cloud | 自定义 | 仅 OAuth |
|
||||
| **GitHub Copilot** | `github-copilot/` | `localhost:4321` | gRPC | - |
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/gateway"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
func NewGatewayCommand() *cobra.Command {
|
||||
var debug bool
|
||||
var noTruncate bool
|
||||
var allowEmpty bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "gateway",
|
||||
@@ -31,12 +34,19 @@ func NewGatewayCommand() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
return gatewayCmd(debug)
|
||||
return gateway.Run(debug, internal.GetConfigPath(), allowEmpty)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
|
||||
cmd.Flags().BoolVarP(&noTruncate, "no-truncate", "T", false, "Disable string truncation in debug logs")
|
||||
cmd.Flags().BoolVarP(
|
||||
&allowEmpty,
|
||||
"allow-empty",
|
||||
"E",
|
||||
false,
|
||||
"Continue starting even when no default model is configured",
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -28,4 +28,5 @@ func TestNewGatewayCommand(t *testing.T) {
|
||||
|
||||
assert.True(t, cmd.HasFlags())
|
||||
assert.NotNil(t, cmd.Flags().Lookup("debug"))
|
||||
assert.NotNil(t, cmd.Flags().Lookup("allow-empty"))
|
||||
}
|
||||
|
||||
@@ -11,14 +11,19 @@ import (
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
func NewOnboardCommand() *cobra.Command {
|
||||
var encrypt bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "onboard",
|
||||
Aliases: []string{"o"},
|
||||
Short: "Initialize picoclaw configuration and workspace",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
onboard()
|
||||
onboard(encrypt)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVar(&encrypt, "enc", false,
|
||||
"Enable credential encryption (generates SSH key and prompts for passphrase)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -24,6 +24,9 @@ func TestNewOnboardCommand(t *testing.T) {
|
||||
assert.Nil(t, cmd.PersistentPreRun)
|
||||
assert.Nil(t, cmd.PersistentPostRun)
|
||||
|
||||
assert.False(t, cmd.HasFlags())
|
||||
assert.True(t, cmd.HasFlags())
|
||||
encFlag := cmd.Flags().Lookup("enc")
|
||||
require.NotNil(t, encFlag, "expected --enc flag to be registered")
|
||||
assert.Equal(t, "false", encFlag.DefValue, "--enc should default to false")
|
||||
assert.False(t, cmd.HasSubCommands())
|
||||
}
|
||||
|
||||
@@ -6,25 +6,71 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func onboard() {
|
||||
func onboard(encrypt bool) {
|
||||
configPath := internal.GetConfigPath()
|
||||
|
||||
configExists := false
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
fmt.Printf("Config already exists at %s\n", configPath)
|
||||
fmt.Print("Overwrite? (y/n): ")
|
||||
var response string
|
||||
fmt.Scanln(&response)
|
||||
if response != "y" {
|
||||
fmt.Println("Aborted.")
|
||||
return
|
||||
configExists = true
|
||||
if encrypt {
|
||||
// Only ask for confirmation when *both* config and SSH key already exist,
|
||||
// indicating a full re-onboard that would reset the config to defaults.
|
||||
sshKeyPath, _ := credential.DefaultSSHKeyPath()
|
||||
if _, err := os.Stat(sshKeyPath); err == nil {
|
||||
// Both exist — confirm a full reset.
|
||||
fmt.Printf("Config already exists at %s\n", configPath)
|
||||
fmt.Print("Overwrite config with defaults? (y/n): ")
|
||||
var response string
|
||||
fmt.Scanln(&response)
|
||||
if response != "y" {
|
||||
fmt.Println("Aborted.")
|
||||
return
|
||||
}
|
||||
configExists = false // user agreed to reset; treat as fresh
|
||||
}
|
||||
// Config exists but SSH key is missing — keep existing config, only add SSH key.
|
||||
}
|
||||
}
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
var err error
|
||||
if encrypt {
|
||||
fmt.Println("\nSet up credential encryption")
|
||||
fmt.Println("-----------------------------")
|
||||
passphrase, pErr := promptPassphrase()
|
||||
if pErr != nil {
|
||||
fmt.Printf("Error: %v\n", pErr)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Expose the passphrase to credential.PassphraseProvider (which calls
|
||||
// os.Getenv by default) so that SaveConfig can encrypt api_keys.
|
||||
// This process is a one-shot CLI tool; the env var is never exposed outside
|
||||
// the current process and disappears when it exits.
|
||||
os.Setenv(credential.PassphraseEnvVar, passphrase)
|
||||
|
||||
if err = setupSSHKey(); err != nil {
|
||||
fmt.Printf("Error generating SSH key: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var cfg *config.Config
|
||||
if configExists {
|
||||
// Preserve the existing config; SaveConfig will re-encrypt api_keys with the new passphrase.
|
||||
cfg, err = config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
fmt.Printf("Error loading existing config: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
cfg = config.DefaultConfig()
|
||||
}
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
fmt.Printf("Error saving config: %v\n", err)
|
||||
os.Exit(1)
|
||||
@@ -33,9 +79,17 @@ func onboard() {
|
||||
workspace := cfg.WorkspacePath()
|
||||
createWorkspaceTemplates(workspace)
|
||||
|
||||
fmt.Printf("%s picoclaw is ready!\n", internal.Logo)
|
||||
fmt.Printf("\n%s picoclaw is ready!\n", internal.Logo)
|
||||
fmt.Println("\nNext steps:")
|
||||
fmt.Println(" 1. Add your API key to", configPath)
|
||||
if encrypt {
|
||||
fmt.Println(" 1. Set your encryption passphrase before starting picoclaw:")
|
||||
fmt.Println(" export PICOCLAW_KEY_PASSPHRASE=<your-passphrase> # Linux/macOS")
|
||||
fmt.Println(" set PICOCLAW_KEY_PASSPHRASE=<your-passphrase> # Windows cmd")
|
||||
fmt.Println("")
|
||||
fmt.Println(" 2. Add your API key to", configPath)
|
||||
} else {
|
||||
fmt.Println(" 1. Add your API key to", configPath)
|
||||
}
|
||||
fmt.Println("")
|
||||
fmt.Println(" Recommended:")
|
||||
fmt.Println(" - OpenRouter: https://openrouter.ai/keys (access 100+ models)")
|
||||
@@ -43,7 +97,62 @@ func onboard() {
|
||||
fmt.Println("")
|
||||
fmt.Println(" See README.md for 17+ supported providers.")
|
||||
fmt.Println("")
|
||||
fmt.Println(" 2. Chat: picoclaw agent -m \"Hello!\"")
|
||||
fmt.Println(" 3. Chat: picoclaw agent -m \"Hello!\"")
|
||||
}
|
||||
|
||||
// promptPassphrase reads the encryption passphrase twice from the terminal
|
||||
// (with echo disabled) and returns it. Returns an error if the passphrase is
|
||||
// empty or if the two inputs do not match.
|
||||
func promptPassphrase() (string, error) {
|
||||
fmt.Print("Enter passphrase for credential encryption: ")
|
||||
p1, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading passphrase: %w", err)
|
||||
}
|
||||
if len(p1) == 0 {
|
||||
return "", fmt.Errorf("passphrase must not be empty")
|
||||
}
|
||||
|
||||
fmt.Print("Confirm passphrase: ")
|
||||
p2, err := term.ReadPassword(int(os.Stdin.Fd()))
|
||||
fmt.Println()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading passphrase confirmation: %w", err)
|
||||
}
|
||||
|
||||
if string(p1) != string(p2) {
|
||||
return "", fmt.Errorf("passphrases do not match")
|
||||
}
|
||||
return string(p1), nil
|
||||
}
|
||||
|
||||
// setupSSHKey generates the picoclaw-specific SSH key at ~/.ssh/picoclaw_ed25519.key.
|
||||
// If the key already exists the user is warned and asked to confirm overwrite.
|
||||
// Answering anything other than "y" keeps the existing key (not an error).
|
||||
func setupSSHKey() error {
|
||||
keyPath, err := credential.DefaultSSHKeyPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot determine SSH key path: %w", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(keyPath); err == nil {
|
||||
fmt.Printf("\n⚠️ WARNING: %s already exists.\n", keyPath)
|
||||
fmt.Println(" Overwriting will invalidate any credentials previously encrypted with this key.")
|
||||
fmt.Print(" Overwrite? (y/n): ")
|
||||
var response string
|
||||
fmt.Scanln(&response)
|
||||
if response != "y" {
|
||||
fmt.Println("Keeping existing SSH key.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := credential.GenerateSSHKey(keyPath); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("SSH key generated: %s\n", keyPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func createWorkspaceTemplates(workspace string) {
|
||||
|
||||
@@ -53,6 +53,12 @@
|
||||
"api_key": "your-modelscope-access-token",
|
||||
"api_base": "https://api-inference.modelscope.cn/v1"
|
||||
},
|
||||
{
|
||||
"model_name": "azure-gpt5",
|
||||
"model": "azure/my-gpt5-deployment",
|
||||
"api_key": "your-azure-api-key",
|
||||
"api_base": "https://your-resource.openai.azure.com"
|
||||
},
|
||||
{
|
||||
"model_name": "loadbalanced-gpt-5.4",
|
||||
"model": "openai/gpt-5.4",
|
||||
@@ -512,6 +518,7 @@
|
||||
},
|
||||
"gateway": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 18790
|
||||
"port": 18790,
|
||||
"hot_reload": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
# Credential Encryption
|
||||
|
||||
PicoClaw supports encrypting `api_key` values in `model_list` configuration entries.
|
||||
Encrypted keys are stored as `enc://<base64>` strings and decrypted automatically at startup.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
**1. Set your passphrase**
|
||||
|
||||
```bash
|
||||
export PICOCLAW_KEY_PASSPHRASE="your-passphrase"
|
||||
```
|
||||
|
||||
**2. Encrypt an API key**
|
||||
|
||||
Run `picoclaw onboard` — it prompts for your passphrase and generates the SSH key,
|
||||
then automatically re-encrypts any plaintext `api_key` entries in your config on
|
||||
the next `SaveConfig` call. The resulting `enc://` value will look like:
|
||||
|
||||
```
|
||||
enc://AAAA...base64...
|
||||
```
|
||||
|
||||
**3. Paste the output into your config**
|
||||
|
||||
```json
|
||||
{
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-4o",
|
||||
"api_key": "enc://AAAA...base64...",
|
||||
"base_url": "https://api.openai.com/v1"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Supported `api_key` Formats
|
||||
|
||||
| Format | Example | Behaviour |
|
||||
|--------|---------|-----------|
|
||||
| Plaintext | `sk-abc123` | Used as-is |
|
||||
| File reference | `file://openai.key` | Content read from the same directory as the config file |
|
||||
| Encrypted | `enc://<base64>` | Decrypted at startup using `PICOCLAW_KEY_PASSPHRASE` |
|
||||
| Empty | `""` | Passed through unchanged (used with `auth_method: oauth`) |
|
||||
|
||||
---
|
||||
|
||||
## Cryptographic Design
|
||||
|
||||
### Key Derivation
|
||||
|
||||
Encryption uses **HKDF-SHA256** with an optional SSH private key as a second factor.
|
||||
|
||||
```
|
||||
Without SSH key (passphrase only):
|
||||
|
||||
ikm = SHA256(passphrase)
|
||||
aes_key = HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
|
||||
|
||||
|
||||
With SSH key (recommended):
|
||||
|
||||
sshHash = SHA256(ssh_private_key_file_bytes)
|
||||
ikm = HMAC-SHA256(key=sshHash, message=passphrase)
|
||||
aes_key = HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
|
||||
```
|
||||
|
||||
### Encryption
|
||||
|
||||
```
|
||||
AES-256-GCM(key=aes_key, nonce=random[12], plaintext=api_key)
|
||||
```
|
||||
|
||||
### Wire Format
|
||||
|
||||
```
|
||||
enc://<base64( salt[16] + nonce[12] + ciphertext )>
|
||||
```
|
||||
|
||||
| Field | Size | Description |
|
||||
|-------|------|-------------|
|
||||
| `salt` | 16 bytes | Random per encryption; fed into HKDF |
|
||||
| `nonce` | 12 bytes | Random per encryption; AES-GCM IV |
|
||||
| `ciphertext` | variable | AES-256-GCM ciphertext + 16-byte authentication tag |
|
||||
|
||||
The GCM authentication tag is appended to the ciphertext automatically. Any tampering causes decryption to fail with an error rather than returning corrupt plaintext.
|
||||
|
||||
### Performance
|
||||
|
||||
| Operation | Time (ARM Cortex-A) |
|
||||
|-----------|---------------------|
|
||||
| Key derivation (HKDF) | < 1 ms |
|
||||
| AES-256-GCM decrypt | < 1 ms |
|
||||
| **Total startup overhead** | **< 2 ms per key** |
|
||||
|
||||
---
|
||||
|
||||
## Two-Factor Security with SSH Key
|
||||
|
||||
When a SSH private key is provided, breaking the encryption requires **both**:
|
||||
|
||||
1. The **passphrase** (`PICOCLAW_KEY_PASSPHRASE`)
|
||||
2. The **SSH private key file**
|
||||
|
||||
This means a leaked config file alone is not sufficient to recover the API key, even if the passphrase is weak. The SSH key contributes 256 bits of entropy (Ed25519) regardless of passphrase strength.
|
||||
|
||||
### Threat Model
|
||||
|
||||
| Attacker Has | Can Decrypt? |
|
||||
|---|---|
|
||||
| Config file only | No — needs passphrase + SSH key |
|
||||
| SSH key only | No — needs passphrase |
|
||||
| Passphrase only | No — needs SSH key |
|
||||
| Config file + SSH key + passphrase | Yes — full compromise |
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `PICOCLAW_KEY_PASSPHRASE` | Yes (for `enc://`) | Passphrase used for key derivation |
|
||||
| `PICOCLAW_SSH_KEY_PATH` | No | Path to SSH private key. Set to `""` to disable auto-detection and use passphrase-only mode |
|
||||
|
||||
### SSH Key Auto-Detection
|
||||
|
||||
If `PICOCLAW_SSH_KEY_PATH` is not set, PicoClaw looks for the picoclaw-specific key:
|
||||
|
||||
```
|
||||
~/.ssh/picoclaw_ed25519.key
|
||||
```
|
||||
|
||||
This dedicated file avoids conflicts with the user's existing SSH keys.
|
||||
Run `picoclaw onboard` to generate it automatically.
|
||||
|
||||
`os.UserHomeDir()` is used for cross-platform home directory resolution (reads `USERPROFILE` on Windows, `HOME` on Unix/macOS).
|
||||
|
||||
To explicitly disable SSH key usage and use passphrase-only mode:
|
||||
|
||||
```bash
|
||||
export PICOCLAW_SSH_KEY_PATH=""
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Migration
|
||||
|
||||
Because the only secret material is `PICOCLAW_KEY_PASSPHRASE` and the SSH private key file, migration is straightforward:
|
||||
|
||||
1. Copy the config file to the new machine.
|
||||
2. Set `PICOCLAW_KEY_PASSPHRASE` to the same value.
|
||||
3. Copy the SSH private key file to the same path (or set `PICOCLAW_SSH_KEY_PATH` to its new location).
|
||||
|
||||
No re-encryption is needed.
|
||||
|
||||
---
|
||||
|
||||
## Security Considerations
|
||||
|
||||
- **Passphrase strength matters in passphrase-only mode.** Without an SSH key, a weak passphrase can be brute-forced offline. Use `PICOCLAW_SSH_KEY_PATH=""` only in environments where no SSH key is available and the passphrase is sufficiently strong (≥ 32 random characters).
|
||||
- **The SSH key is read-only at runtime.** PicoClaw never writes to or modifies the SSH key file.
|
||||
- **Plaintext keys remain supported.** Existing configs without `enc://` are unaffected.
|
||||
- **The `enc://` format is versioned** via the HKDF `info` field (`picoclaw-credential-v1`), allowing future algorithm upgrades without breaking existing encrypted values.
|
||||
@@ -84,6 +84,22 @@ By default, PicoClaw blocks the following dangerous commands:
|
||||
- Git: `git push`, `git force`
|
||||
- Other: `eval`, `source *.sh`
|
||||
|
||||
### Known Architectural Limitation
|
||||
|
||||
The exec guard only validates the top-level command sent to PicoClaw. It does **not** recursively inspect child
|
||||
processes spawned by build tools or scripts after that command starts running.
|
||||
|
||||
Examples of workflows that can bypass the direct command guard once the initial command is allowed:
|
||||
|
||||
- `make run`
|
||||
- `go run ./cmd/...`
|
||||
- `cargo run`
|
||||
- `npm run build`
|
||||
|
||||
This means the guard is useful for blocking obviously dangerous direct commands, but it is **not** a full sandbox for
|
||||
unreviewed build pipelines. If your threat model includes untrusted code in the workspace, use stronger isolation such
|
||||
as containers, VMs, or an approval flow around build-and-run commands.
|
||||
|
||||
### Configuration Example
|
||||
|
||||
```json
|
||||
|
||||
@@ -3,10 +3,11 @@ module github.com/sipeed/picoclaw
|
||||
go 1.25.7
|
||||
|
||||
require (
|
||||
fyne.io/systray v1.12.0
|
||||
github.com/adhocore/gronx v1.19.6
|
||||
github.com/anthropics/anthropic-sdk-go v1.22.1
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/caarlos0/env/v11 v11.3.1
|
||||
github.com/caarlos0/env/v11 v11.4.0
|
||||
github.com/ergochat/irc-go v0.5.0
|
||||
github.com/ergochat/readline v0.1.3
|
||||
github.com/gdamore/tcell/v2 v2.13.8
|
||||
@@ -17,7 +18,7 @@ require (
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mdp/qrterminal/v3 v3.2.1
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.1
|
||||
github.com/mymmrac/telego v1.6.0
|
||||
github.com/mymmrac/telego v1.7.0
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.22.0
|
||||
github.com/rivo/tview v0.42.0
|
||||
@@ -27,7 +28,8 @@ require (
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tencent-connect/botgo v0.2.1
|
||||
go.mau.fi/whatsmeow v0.0.0-20260219150138-7ae702b1eed4
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/term v0.40.0
|
||||
golang.org/x/time v0.14.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -43,6 +45,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/elliotchance/orderedmap/v3 v3.1.0 // indirect
|
||||
github.com/gdamore/encoding v1.0.1 // indirect
|
||||
github.com/godbus/dbus/v5 v5.1.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
@@ -59,7 +62,6 @@ require (
|
||||
go.mau.fi/libsignal v0.2.1 // indirect
|
||||
go.mau.fi/util v0.9.6 // indirect
|
||||
golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect
|
||||
golang.org/x/term v0.40.0 // indirect
|
||||
golang.org/x/text v0.34.0 // indirect
|
||||
modernc.org/libc v1.67.6 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
@@ -73,7 +75,7 @@ require (
|
||||
github.com/bytedance/sonic v1.15.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/github/copilot-sdk/go v0.1.23
|
||||
github.com/github/copilot-sdk/go v0.1.32
|
||||
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
@@ -87,10 +89,10 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.69.0 // indirect
|
||||
github.com/valyala/fastjson v1.6.7 // indirect
|
||||
github.com/valyala/fastjson v1.6.10 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
golang.org/x/arch v0.24.0 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/net v0.51.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
|
||||
filippo.io/edwards25519 v1.1.1/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
fyne.io/systray v1.12.0 h1:CA1Kk0e2zwFlxtc02L3QFSiIbxJ/P0n582YrZHT7aTM=
|
||||
fyne.io/systray v1.12.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
||||
@@ -11,8 +13,8 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNg
|
||||
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8=
|
||||
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
|
||||
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
||||
github.com/beeper/argo-go v1.1.2 h1:UQI2G8F+NLfGTOmTUI0254pGKx/HUU/etbUGTJv91Fs=
|
||||
github.com/beeper/argo-go v1.1.2/go.mod h1:M+LJAnyowKVQ6Rdj6XYGEn+qcVFkb3R/MUpqkGR0hM4=
|
||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||
@@ -23,8 +25,8 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
||||
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
|
||||
github.com/caarlos0/env/v11 v11.4.0 h1:Kcb6t5kIIr4XkoQC9AF2j+8E1Jsrl3Wz/hhm1LtoGAc=
|
||||
github.com/caarlos0/env/v11 v11.4.0/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
|
||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
@@ -38,6 +40,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/elliotchance/orderedmap/v3 v3.1.0 h1:j4DJ5ObEmMBt/lcwIecKcoRxIQUEnw0L804lXYDt/pg=
|
||||
@@ -52,8 +56,8 @@ github.com/gdamore/encoding v1.0.1 h1:YzKZckdBL6jVt2Gc+5p82qhrGiqMdG/eNs6Wy0u3Uh
|
||||
github.com/gdamore/encoding v1.0.1/go.mod h1:0Z0cMFinngz9kS1QfMjCP8TY7em3bZYeeklsSDPivEo=
|
||||
github.com/gdamore/tcell/v2 v2.13.8 h1:Mys/Kl5wfC/GcC5Cx4C2BIQH9dbnhnkPgS9/wF3RlfU=
|
||||
github.com/gdamore/tcell/v2 v2.13.8/go.mod h1:+Wfe208WDdB7INEtCsNrAN6O2m+wsTPk1RAovjaILlo=
|
||||
github.com/github/copilot-sdk/go v0.1.23 h1:uExtO/inZQndCZMiSAA1hvXINiz9tqo/MZgQzFzurxw=
|
||||
github.com/github/copilot-sdk/go v0.1.23/go.mod h1:GdwwBfMbm9AABLEM3x5IZKw4ZfwCYxZ1BgyytmZenQ0=
|
||||
github.com/github/copilot-sdk/go v0.1.32 h1:wc9SFWwxXhJts6vyzzboPLJqcEJGnHE8rMCAY1RrUgo=
|
||||
github.com/github/copilot-sdk/go v0.1.32/go.mod h1:qc2iEF7hdO8kzSvbyGvrcGhuk2fzdW4xTtT0+1EH2ts=
|
||||
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w=
|
||||
github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w/BIH7cC3Q=
|
||||
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
|
||||
@@ -62,6 +66,8 @@ github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg78
|
||||
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
|
||||
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
|
||||
github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
@@ -136,8 +142,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
|
||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.1 h1:TfqtNKOIWN4Z1oqmPAiWDC2Jq7K9OdJaooe0teoXASI=
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.1/go.mod h1:DgVX498dMD8UJlseK1S5i1T4tFz2fkBk4xogC3D15nw=
|
||||
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
|
||||
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
|
||||
github.com/mymmrac/telego v1.7.0 h1:yRO/l00tFGG4nY66ufUKb4ARqv7qx9+LsjQv/b0NEyo=
|
||||
github.com/mymmrac/telego v1.7.0/go.mod h1:pdLV346EgVuq7Xrh3kMggeBiazeHhsdEoK0RTEOPXRM=
|
||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||
@@ -216,8 +222,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
|
||||
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
|
||||
github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM=
|
||||
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
|
||||
github.com/valyala/fastjson v1.6.10 h1:/yjJg8jaVQdYR3arGxPE2X5z89xrlhS0eGXdv+ADTh4=
|
||||
github.com/valyala/fastjson v1.6.10/go.mod h1:e6FubmQouUNP73jtMLmcbxS6ydWIpOfhz34TSfO3JaE=
|
||||
github.com/vektah/gqlparser/v2 v2.5.27 h1:RHPD3JOplpk5mP5JGX8RKZkt2/Vwj/PZv0HxTdwFp0s=
|
||||
github.com/vektah/gqlparser/v2 v2.5.27/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
@@ -270,8 +276,8 @@ golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
|
||||
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||
golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
|
||||
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
|
||||
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -354,6 +360,7 @@ gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWD
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
+25
-2
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
@@ -66,7 +67,7 @@ func NewAgentInstance(
|
||||
readRestrict := restrict && !defaults.AllowReadOutsideWorkspace
|
||||
|
||||
// Compile path whitelist patterns from config.
|
||||
allowReadPaths := compilePatterns(cfg.Tools.AllowReadPaths)
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
allowWritePaths := compilePatterns(cfg.Tools.AllowWritePaths)
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
@@ -82,7 +83,7 @@ func NewAgentInstance(
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, readRestrict, allowReadPaths))
|
||||
}
|
||||
if cfg.Tools.IsToolEnabled("exec") {
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg)
|
||||
execTool, err := tools.NewExecToolWithConfig(workspace, restrict, cfg, allowReadPaths)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error: unable to initialize exec tool: %v", err)
|
||||
}
|
||||
@@ -282,6 +283,28 @@ func compilePatterns(patterns []string) []*regexp.Regexp {
|
||||
return compiled
|
||||
}
|
||||
|
||||
func buildAllowReadPatterns(cfg *config.Config) []*regexp.Regexp {
|
||||
var configured []string
|
||||
if cfg != nil {
|
||||
configured = cfg.Tools.AllowReadPaths
|
||||
}
|
||||
|
||||
compiled := compilePatterns(configured)
|
||||
mediaDirPattern := regexp.MustCompile(mediaTempDirPattern())
|
||||
for _, pattern := range compiled {
|
||||
if pattern.String() == mediaDirPattern.String() {
|
||||
return compiled
|
||||
}
|
||||
}
|
||||
|
||||
return append(compiled, mediaDirPattern)
|
||||
}
|
||||
|
||||
func mediaTempDirPattern() string {
|
||||
sep := regexp.QuoteMeta(string(os.PathSeparator))
|
||||
return "^" + regexp.QuoteMeta(filepath.Clean(media.TempDir())) + "(?:" + sep + "|$)"
|
||||
}
|
||||
|
||||
// Close releases resources held by the agent's session store.
|
||||
func (a *AgentInstance) Close() error {
|
||||
if a.Sessions != nil {
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
@@ -160,3 +164,85 @@ func TestNewAgentInstance_ResolveCandidatesFromModelListAlias(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
mediaFile, err := os.CreateTemp(mediaDir, "instance-tool-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
mediaPath := mediaFile.Name()
|
||||
if _, err := mediaFile.WriteString("attachment content"); err != nil {
|
||||
mediaFile.Close()
|
||||
t.Fatalf("WriteString(mediaFile) error = %v", err)
|
||||
}
|
||||
if err := mediaFile.Close(); err != nil {
|
||||
t.Fatalf("Close(mediaFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(mediaPath) })
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "test-model",
|
||||
RestrictToWorkspace: true,
|
||||
},
|
||||
},
|
||||
Tools: config.ToolsConfig{
|
||||
ReadFile: config.ReadFileToolConfig{Enabled: true},
|
||||
ListDir: config.ToolConfig{Enabled: true},
|
||||
Exec: config.ExecConfig{
|
||||
ToolConfig: config.ToolConfig{Enabled: true},
|
||||
EnableDenyPatterns: true,
|
||||
AllowRemote: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, &mockProvider{})
|
||||
|
||||
readTool, ok := agent.Tools.Get("read_file")
|
||||
if !ok {
|
||||
t.Fatal("read_file tool not registered")
|
||||
}
|
||||
readResult := readTool.Execute(context.Background(), map[string]any{"path": mediaPath})
|
||||
if readResult.IsError {
|
||||
t.Fatalf("read_file should allow media temp dir, got: %s", readResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(readResult.ForLLM, "attachment content") {
|
||||
t.Fatalf("read_file output missing media content: %s", readResult.ForLLM)
|
||||
}
|
||||
|
||||
listTool, ok := agent.Tools.Get("list_dir")
|
||||
if !ok {
|
||||
t.Fatal("list_dir tool not registered")
|
||||
}
|
||||
listResult := listTool.Execute(context.Background(), map[string]any{"path": mediaDir})
|
||||
if listResult.IsError {
|
||||
t.Fatalf("list_dir should allow media temp dir, got: %s", listResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(listResult.ForLLM, filepath.Base(mediaPath)) {
|
||||
t.Fatalf("list_dir output missing media file: %s", listResult.ForLLM)
|
||||
}
|
||||
|
||||
execTool, ok := agent.Tools.Get("exec")
|
||||
if !ok {
|
||||
t.Fatal("exec tool not registered")
|
||||
}
|
||||
execResult := execTool.Execute(context.Background(), map[string]any{
|
||||
"command": "cat " + filepath.Base(mediaPath),
|
||||
"working_dir": mediaDir,
|
||||
})
|
||||
if execResult.IsError {
|
||||
t.Fatalf("exec should allow media temp dir, got: %s", execResult.ForLLM)
|
||||
}
|
||||
if !strings.Contains(execResult.ForLLM, "attachment content") {
|
||||
t.Fatalf("exec output missing media content: %s", execResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
+70
-61
@@ -124,6 +124,8 @@ func registerSharedTools(
|
||||
registry *AgentRegistry,
|
||||
provider providers.LLMProvider,
|
||||
) {
|
||||
allowReadPaths := buildAllowReadPatterns(cfg)
|
||||
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
agent, ok := registry.GetAgent(agentID)
|
||||
if !ok {
|
||||
@@ -202,6 +204,7 @@ func registerSharedTools(
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
cfg.Agents.Defaults.GetMaxMediaSize(),
|
||||
nil,
|
||||
allowReadPaths,
|
||||
)
|
||||
agent.Tools.Register(sendFileTool)
|
||||
}
|
||||
@@ -229,72 +232,75 @@ func registerSharedTools(
|
||||
}
|
||||
}
|
||||
|
||||
// Spawn tool with allowlist checker
|
||||
if cfg.Tools.IsToolEnabled("spawn") {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
|
||||
// Set the spawner that links into AgentLoop's turnState
|
||||
subagentManager.SetSpawner(func(
|
||||
ctx context.Context,
|
||||
task, label, targetAgentID string,
|
||||
tls *tools.ToolRegistry,
|
||||
maxTokens int,
|
||||
temperature float64,
|
||||
hasMaxTokens, hasTemperature bool,
|
||||
) (*tools.ToolResult, error) {
|
||||
// 1. Recover parent Turn State from Context
|
||||
parentTS := turnStateFromContext(ctx)
|
||||
if parentTS == nil {
|
||||
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
|
||||
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
|
||||
parentTS = &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "adhoc-root",
|
||||
depth: 0,
|
||||
session: newEphemeralSession(nil),
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Build Tools slice from registry
|
||||
var tlSlice []tools.Tool
|
||||
for _, name := range tls.List() {
|
||||
if t, ok := tls.Get(name); ok {
|
||||
tlSlice = append(tlSlice, t)
|
||||
}
|
||||
}
|
||||
// Spawn and spawn_status tools share a SubagentManager.
|
||||
// Construct it when either tool is enabled (both require subagent).
|
||||
spawnEnabled := cfg.Tools.IsToolEnabled("spawn")
|
||||
spawnStatusEnabled := cfg.Tools.IsToolEnabled("spawn_status")
|
||||
if (spawnEnabled || spawnStatusEnabled) && cfg.Tools.IsToolEnabled("subagent") {
|
||||
subagentManager := tools.NewSubagentManager(provider, agent.Model, agent.Workspace)
|
||||
subagentManager.SetLLMOptions(agent.MaxTokens, agent.Temperature)
|
||||
|
||||
// 3. System Prompt
|
||||
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
|
||||
"You have access to tools - use them as needed to complete your task.\n" +
|
||||
"After completing the task, provide a clear summary of what was done.\n\n" +
|
||||
"Task: " + task
|
||||
|
||||
// 4. Resolve Model
|
||||
modelToUse := agent.Model
|
||||
if targetAgentID != "" {
|
||||
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
|
||||
modelToUse = targetAgent.Model
|
||||
}
|
||||
// Set the spawner that links into AgentLoop's turnState
|
||||
subagentManager.SetSpawner(func(
|
||||
ctx context.Context,
|
||||
task, label, targetAgentID string,
|
||||
tls *tools.ToolRegistry,
|
||||
maxTokens int,
|
||||
temperature float64,
|
||||
hasMaxTokens, hasTemperature bool,
|
||||
) (*tools.ToolResult, error) {
|
||||
// 1. Recover parent Turn State from Context
|
||||
parentTS := turnStateFromContext(ctx)
|
||||
if parentTS == nil {
|
||||
// Fallback: If no turnState exists in context, create an isolated ad-hoc root turn state
|
||||
// so that the tool can still function outside of an agent loop (e.g. tests, raw invocations).
|
||||
parentTS = &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "adhoc-root",
|
||||
depth: 0,
|
||||
session: newEphemeralSession(nil),
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Build SubTurnConfig
|
||||
cfg := SubTurnConfig{
|
||||
Model: modelToUse,
|
||||
Tools: tlSlice,
|
||||
SystemPrompt: systemPrompt,
|
||||
// 2. Build Tools slice from registry
|
||||
var tlSlice []tools.Tool
|
||||
for _, name := range tls.List() {
|
||||
if t, ok := tls.Get(name); ok {
|
||||
tlSlice = append(tlSlice, t)
|
||||
}
|
||||
if hasMaxTokens {
|
||||
cfg.MaxTokens = maxTokens
|
||||
}
|
||||
|
||||
// 3. System Prompt
|
||||
systemPrompt := "You are a subagent. Complete the given task independently and report the result.\n" +
|
||||
"You have access to tools - use them as needed to complete your task.\n" +
|
||||
"After completing the task, provide a clear summary of what was done.\n\n" +
|
||||
"Task: " + task
|
||||
|
||||
// 4. Resolve Model
|
||||
modelToUse := agent.Model
|
||||
if targetAgentID != "" {
|
||||
if targetAgent, ok := al.GetRegistry().GetAgent(targetAgentID); ok {
|
||||
modelToUse = targetAgent.Model
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Spawn SubTurn
|
||||
return spawnSubTurn(ctx, al, parentTS, cfg)
|
||||
})
|
||||
// 5. Build SubTurnConfig
|
||||
cfg := SubTurnConfig{
|
||||
Model: modelToUse,
|
||||
Tools: tlSlice,
|
||||
SystemPrompt: systemPrompt,
|
||||
}
|
||||
if hasMaxTokens {
|
||||
cfg.MaxTokens = maxTokens
|
||||
}
|
||||
|
||||
// 6. Spawn SubTurn
|
||||
return spawnSubTurn(ctx, al, parentTS, cfg)
|
||||
})
|
||||
|
||||
if spawnEnabled {
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
currentAgentID := agentID
|
||||
spawnTool.SetAllowlistChecker(func(targetAgentID string) bool {
|
||||
@@ -311,9 +317,12 @@ func registerSharedTools(
|
||||
subagentTool := tools.NewSubagentTool(subagentManager)
|
||||
subagentTool.SetSpawner(spawner)
|
||||
agent.Tools.Register(subagentTool)
|
||||
} else {
|
||||
logger.WarnCF("agent", "spawn tool requires subagent to be enabled", nil)
|
||||
}
|
||||
if spawnStatusEnabled {
|
||||
agent.Tools.Register(tools.NewSpawnStatusTool(subagentManager))
|
||||
}
|
||||
} else if (spawnEnabled || spawnStatusEnabled) && !cfg.Tools.IsToolEnabled("subagent") {
|
||||
logger.WarnCF("agent", "spawn/spawn_status tools require subagent to be enabled", nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,7 +618,7 @@ func (c *FeishuChannel) downloadResource(
|
||||
}
|
||||
|
||||
// Write to the shared picoclaw_media directory using a unique name to avoid collisions.
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
mediaDir := media.TempDir()
|
||||
if mkdirErr := os.MkdirAll(mediaDir, 0o700); mkdirErr != nil {
|
||||
logger.ErrorCF("feishu", "Failed to create media directory", map[string]any{
|
||||
"error": mkdirErr.Error(),
|
||||
|
||||
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
if len(m.channels) == 0 {
|
||||
logger.WarnC("channels", "No channels enabled")
|
||||
return errors.New("no channels enabled")
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
@@ -397,7 +396,7 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
"addr": m.httpServer.Addr,
|
||||
})
|
||||
if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{
|
||||
logger.FatalCF("channels", "Shared HTTP server error", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -35,8 +35,6 @@ const (
|
||||
roomKindCacheTTL = 5 * time.Minute
|
||||
roomKindCacheCleanupPeriod = 1 * time.Minute
|
||||
roomKindCacheMaxEntries = 2048
|
||||
|
||||
matrixMediaTempDirName = "picoclaw_media"
|
||||
)
|
||||
|
||||
var matrixMentionHrefRegexp = regexp.MustCompile(`(?i)<a[^>]+href=["']([^"']+)["']`)
|
||||
@@ -1105,7 +1103,7 @@ func (c *MatrixChannel) stripSelfMention(text string) string {
|
||||
}
|
||||
|
||||
func matrixMediaTempDir() (string, error) {
|
||||
mediaDir := filepath.Join(os.TempDir(), matrixMediaTempDirName)
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
func TestMatrixLocalpartMentionRegexp(t *testing.T) {
|
||||
@@ -165,7 +166,7 @@ func TestMatrixMediaTempDir(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("matrixMediaTempDir failed: %v", err)
|
||||
}
|
||||
if filepath.Base(dir) != matrixMediaTempDirName {
|
||||
if filepath.Base(dir) != media.TempDirName {
|
||||
t.Fatalf("unexpected media dir base: %q", filepath.Base(dir))
|
||||
}
|
||||
|
||||
|
||||
@@ -251,7 +251,13 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, nil)
|
||||
// Echo the matched subprotocol back so the browser accepts the upgrade.
|
||||
var responseHeader http.Header
|
||||
if proto := c.matchedSubprotocol(r); proto != "" {
|
||||
responseHeader = http.Header{"Sec-WebSocket-Protocol": {proto}}
|
||||
}
|
||||
|
||||
conn, err := c.upgrader.Upgrade(w, r, responseHeader)
|
||||
if err != nil {
|
||||
logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
@@ -282,8 +288,10 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
||||
go c.readLoop(pc)
|
||||
}
|
||||
|
||||
// authenticate checks the Bearer token from the Authorization header.
|
||||
// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled.
|
||||
// authenticate checks the request for a valid token:
|
||||
// 1. Authorization: Bearer <token> header
|
||||
// 2. Sec-WebSocket-Protocol "token.<value>" (for browsers that can't set headers)
|
||||
// 3. Query parameter "token" (only when AllowTokenQuery is on)
|
||||
func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
token := c.config.Token
|
||||
if token == "" {
|
||||
@@ -298,6 +306,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Check Sec-WebSocket-Protocol subprotocol ("token.<value>")
|
||||
if c.matchedSubprotocol(r) != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check query parameter only when explicitly allowed
|
||||
if c.config.AllowTokenQuery {
|
||||
if r.URL.Query().Get("token") == token {
|
||||
@@ -308,6 +321,18 @@ func (c *PicoChannel) authenticate(r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// matchedSubprotocol returns the "token.<value>" subprotocol that matches
|
||||
// the configured token, or "" if none do.
|
||||
func (c *PicoChannel) matchedSubprotocol(r *http.Request) string {
|
||||
token := c.config.Token
|
||||
for _, proto := range websocket.Subprotocols(r) {
|
||||
if after, ok := strings.CutPrefix(proto, "token."); ok && after == token {
|
||||
return proto
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// readLoop reads messages from a WebSocket connection.
|
||||
func (c *PicoChannel) readLoop(pc *picoConn) {
|
||||
defer func() {
|
||||
|
||||
+79
-6
@@ -4,11 +4,13 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/caarlos0/env/v11"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
|
||||
@@ -624,8 +626,9 @@ func (c *ModelConfig) Validate() error {
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
|
||||
}
|
||||
|
||||
type ToolDiscoveryConfig struct {
|
||||
@@ -698,8 +701,9 @@ type WebToolsConfig struct {
|
||||
}
|
||||
|
||||
type CronToolsConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_CRON_"`
|
||||
ExecTimeoutMinutes int ` env:"PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES" json:"exec_timeout_minutes"` // 0 means no timeout
|
||||
AllowCommand bool ` env:"PICOCLAW_TOOLS_CRON_ALLOW_COMMAND" json:"allow_command"`
|
||||
}
|
||||
|
||||
type ExecConfig struct {
|
||||
@@ -749,6 +753,7 @@ type ToolsConfig struct {
|
||||
ReadFile ReadFileToolConfig `json:"read_file" envPrefix:"PICOCLAW_TOOLS_READ_FILE_"`
|
||||
SendFile ToolConfig `json:"send_file" envPrefix:"PICOCLAW_TOOLS_SEND_FILE_"`
|
||||
Spawn ToolConfig `json:"spawn" envPrefix:"PICOCLAW_TOOLS_SPAWN_"`
|
||||
SpawnStatus ToolConfig `json:"spawn_status" envPrefix:"PICOCLAW_TOOLS_SPAWN_STATUS_"`
|
||||
SPI ToolConfig `json:"spi" envPrefix:"PICOCLAW_TOOLS_SPI_"`
|
||||
Subagent ToolConfig `json:"subagent" envPrefix:"PICOCLAW_TOOLS_SUBAGENT_"`
|
||||
WebFetch ToolConfig `json:"web_fetch" envPrefix:"PICOCLAW_TOOLS_WEB_FETCH_"`
|
||||
@@ -838,10 +843,24 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
for _, m := range cfg.ModelList {
|
||||
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
|
||||
fmt.Fprintf(os.Stderr,
|
||||
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
|
||||
m.ModelName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := env.Parse(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := resolveAPIKeys(cfg.ModelList, filepath.Dir(path)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Migrate legacy channel config fields to new unified structures
|
||||
cfg.migrateChannelConfigs()
|
||||
|
||||
@@ -858,6 +877,48 @@ func LoadConfig(path string) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// encryptPlaintextAPIKeys returns a copy of models with plaintext api_key values
|
||||
// encrypted. Returns (nil, nil) when nothing changed (all keys already sealed or
|
||||
// empty). Returns (nil, error) if any key fails to encrypt — callers must treat
|
||||
// this as a hard failure to prevent a mixed plaintext/ciphertext state on disk.
|
||||
// Symmetric counterpart of resolveAPIKeys: both operate purely on []ModelConfig
|
||||
// and leave JSON marshaling to the caller.
|
||||
func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelConfig, error) {
|
||||
sealed := make([]ModelConfig, len(models))
|
||||
copy(sealed, models)
|
||||
changed := false
|
||||
for i := range sealed {
|
||||
m := &sealed[i]
|
||||
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
|
||||
continue
|
||||
}
|
||||
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot seal api_key for model %q: %w", m.ModelName, err)
|
||||
}
|
||||
m.APIKey = encrypted
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
return nil, nil
|
||||
}
|
||||
return sealed, nil
|
||||
}
|
||||
|
||||
// resolveAPIKeys decrypts or dereferences each api_key in models in-place.
|
||||
// Supports plaintext (no-op), file:// (read from configDir), and enc:// (AES-GCM decrypt).
|
||||
func resolveAPIKeys(models []ModelConfig, configDir string) error {
|
||||
cr := credential.NewResolver(configDir)
|
||||
for i := range models {
|
||||
resolved, err := cr.Resolve(models[i].APIKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("model_list[%d] (%s): %w", i, models[i].ModelName, err)
|
||||
}
|
||||
models[i].APIKey = resolved
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Config) migrateChannelConfigs() {
|
||||
// Discord: mention_only -> group_trigger.mention_only
|
||||
if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly {
|
||||
@@ -872,12 +933,22 @@ func (c *Config) migrateChannelConfigs() {
|
||||
}
|
||||
|
||||
func SaveConfig(path string, cfg *Config) error {
|
||||
if passphrase := credential.PassphraseProvider(); passphrase != "" {
|
||||
sealed, err := encryptPlaintextAPIKeys(cfg.ModelList, passphrase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if sealed != nil {
|
||||
tmp := *cfg
|
||||
tmp.ModelList = sealed
|
||||
cfg = &tmp
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Use unified atomic write utility with explicit sync for flash storage reliability.
|
||||
return fileutil.WriteFileAtomic(path, data, 0o600)
|
||||
}
|
||||
|
||||
@@ -1044,6 +1115,8 @@ func (t *ToolsConfig) IsToolEnabled(name string) bool {
|
||||
return t.ReadFile.Enabled
|
||||
case "spawn":
|
||||
return t.Spawn.Enabled
|
||||
case "spawn_status":
|
||||
return t.SpawnStatus.Enabled
|
||||
case "spi":
|
||||
return t.SPI.Enabled
|
||||
case "subagent":
|
||||
|
||||
+386
-5
@@ -7,8 +7,22 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
// mustSetupSSHKey generates a temporary Ed25519 SSH key in t.TempDir() and sets
|
||||
// PICOCLAW_SSH_KEY_PATH to its path for the duration of the test. This is required
|
||||
// whenever a test exercises encryption/decryption via credential.Encrypt or SaveConfig.
|
||||
func mustSetupSSHKey(t *testing.T) {
|
||||
t.Helper()
|
||||
keyPath := filepath.Join(t.TempDir(), "picoclaw_ed25519.key")
|
||||
if err := credential.GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("mustSetupSSHKey: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", keyPath)
|
||||
}
|
||||
|
||||
func TestAgentModelConfig_UnmarshalString(t *testing.T) {
|
||||
var m AgentModelConfig
|
||||
if err := json.Unmarshal([]byte(`"gpt-4"`), &m); err != nil {
|
||||
@@ -253,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
if cfg.Gateway.HotReload {
|
||||
t.Error("Gateway hot reload should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Providers verifies provider structure
|
||||
@@ -391,6 +408,13 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if !cfg.Tools.Cron.AllowCommand {
|
||||
t.Fatal("DefaultConfig().Tools.Cron.AllowCommand should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -423,6 +447,22 @@ func TestLoadConfig_ExecAllowRemoteDefaultsTrueWhenUnset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_CronAllowCommandDefaultsTrueWhenUnset(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
if err := os.WriteFile(configPath, []byte(`{"tools":{"cron":{"exec_timeout_minutes":5}}}`), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile() error: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if !cfg.Tools.Cron.AllowCommand {
|
||||
t.Fatal("tools.cron.allow_command should remain true when unset in config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_OpenAIWebSearchCanBeDisabled(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
@@ -482,13 +522,19 @@ func TestDefaultConfig_DMScope(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDefaultConfig_WorkspacePath_Default(t *testing.T) {
|
||||
// Unset to ensure we test the default
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
// Set a known home for consistent test results
|
||||
t.Setenv("HOME", "/tmp/home")
|
||||
|
||||
var fakeHome string
|
||||
if runtime.GOOS == "windows" {
|
||||
fakeHome = `C:\tmp\home`
|
||||
t.Setenv("USERPROFILE", fakeHome)
|
||||
} else {
|
||||
fakeHome = "/tmp/home"
|
||||
t.Setenv("HOME", fakeHome)
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := filepath.Join("/tmp/home", ".picoclaw", "workspace")
|
||||
want := filepath.Join(fakeHome, ".picoclaw", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Default workspace path = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -499,7 +545,7 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_HOME", "/custom/picoclaw/home")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
want := "/custom/picoclaw/home/workspace"
|
||||
want := filepath.Join("/custom/picoclaw/home", "workspace")
|
||||
|
||||
if cfg.Agents.Defaults.Workspace != want {
|
||||
t.Errorf("Workspace path with PICOCLAW_HOME = %q, want %q", cfg.Agents.Defaults.Workspace, want)
|
||||
@@ -621,3 +667,338 @@ func TestFlexibleStringSlice_UnmarshalText_EmptySliceConsistency(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestLoadConfig_WarnsForPlaintextAPIKey verifies that LoadConfig resolves a plaintext
|
||||
// api_key into memory but does NOT rewrite the config file. File writes are the sole
|
||||
// responsibility of SaveConfig.
|
||||
func TestLoadConfig_WarnsForPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
const original = `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(original), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
// In-memory value must be the resolved plaintext.
|
||||
if cfg.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("in-memory api_key = %q, want %q", cfg.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
// The file on disk must remain unchanged — LoadConfig must not write anything.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if string(raw) != original {
|
||||
t.Errorf("LoadConfig must not modify the config file; got:\n%s", string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_EncryptsPlaintextAPIKey verifies that SaveConfig writes enc:// ciphertext
|
||||
// to disk and that a subsequent LoadConfig decrypts it back to the original plaintext.
|
||||
func TestSaveConfig_EncryptsPlaintextAPIKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
// Disk must contain enc://, not the raw key.
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("saved file should contain enc://, got:\n%s", string(raw))
|
||||
}
|
||||
if strings.Contains(string(raw), "sk-plaintext") {
|
||||
t.Errorf("saved file must not contain the plaintext key")
|
||||
}
|
||||
|
||||
// A fresh load must decrypt back to the original plaintext.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
if cfg2.ModelList[0].APIKey != "sk-plaintext" {
|
||||
t.Errorf("loaded api_key = %q, want %q", cfg2.ModelList[0].APIKey, "sk-plaintext")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_NoSealWithoutPassphrase verifies that api_key values are left
|
||||
// unchanged when PICOCLAW_KEY_PASSPHRASE is not set.
|
||||
func TestLoadConfig_NoSealWithoutPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"sk-plaintext"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("config file must not be modified when no passphrase is set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_FileRefNotSealed verifies that file:// api_key references are not
|
||||
// converted to enc:// values (they are resolved at runtime by the Resolver).
|
||||
func TestLoadConfig_FileRefNotSealed(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
keyFile := filepath.Join(dir, "openai.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
data := `{"model_list":[{"model_name":"test","model":"openai/gpt-4","api_key":"file://openai.key"}]}`
|
||||
if err := os.WriteFile(cfgPath, []byte(data), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
if _, err := LoadConfig(cfgPath); err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "file://openai.key") {
|
||||
t.Error("file:// reference should be preserved unchanged in the config file")
|
||||
}
|
||||
if strings.Contains(string(raw), "enc://") {
|
||||
t.Error("file:// reference must not be converted to enc://")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_MixedKeys verifies that SaveConfig encrypts only plaintext api_keys
|
||||
// and leaves already-encrypted (enc://) and file:// entries unchanged.
|
||||
func TestSaveConfig_MixedKeys(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Pre-encrypt one key so we have a genuine enc:// value to put in the config.
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "pre", Model: "openai/gpt-4", APIKey: "sk-already-plain"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
// Extract the enc:// value from the saved file.
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil || len(tmp.ModelList) == 0 {
|
||||
t.Fatalf("setup: could not parse saved config: %v", err)
|
||||
}
|
||||
alreadyEncrypted := tmp.ModelList[0].APIKey
|
||||
if !strings.HasPrefix(alreadyEncrypted, "enc://") {
|
||||
t.Fatalf("setup: expected enc:// key, got %q", alreadyEncrypted)
|
||||
}
|
||||
|
||||
// Build a config with three models:
|
||||
// 1. plaintext → must be encrypted by SaveConfig
|
||||
// 2. enc:// → must be left unchanged (already encrypted)
|
||||
// 3. file:// → must be left unchanged (file reference)
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
cfg := &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "plain", Model: "openai/gpt-4", APIKey: "sk-new-plaintext"},
|
||||
{ModelName: "enc", Model: "openai/gpt-4", APIKey: alreadyEncrypted},
|
||||
{ModelName: "file", Model: "openai/gpt-4", APIKey: "file://api.key"},
|
||||
},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ = os.ReadFile(cfgPath)
|
||||
s := string(raw)
|
||||
|
||||
// 1. Plaintext must be encrypted.
|
||||
if strings.Contains(s, "sk-new-plaintext") {
|
||||
t.Error("plaintext key must not appear in saved file")
|
||||
}
|
||||
// 2. The pre-existing enc:// value must still be present (byte-for-byte unchanged).
|
||||
if !strings.Contains(s, alreadyEncrypted) {
|
||||
t.Error("pre-existing enc:// entry must be preserved unchanged")
|
||||
}
|
||||
// 3. file:// must be preserved.
|
||||
if !strings.Contains(s, "file://api.key") {
|
||||
t.Error("file:// reference must be preserved unchanged")
|
||||
}
|
||||
|
||||
// Now load and verify all three decrypt/resolve correctly.
|
||||
cfg2, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig after SaveConfig: %v", err)
|
||||
}
|
||||
byName := make(map[string]string)
|
||||
for _, m := range cfg2.ModelList {
|
||||
byName[m.ModelName] = m.APIKey
|
||||
}
|
||||
if byName["plain"] != "sk-new-plaintext" {
|
||||
t.Errorf("plain model api_key = %q, want %q", byName["plain"], "sk-new-plaintext")
|
||||
}
|
||||
if byName["enc"] != "sk-already-plain" {
|
||||
t.Errorf("enc model api_key = %q, want %q", byName["enc"], "sk-already-plain")
|
||||
}
|
||||
if byName["file"] != "sk-from-file" {
|
||||
t.Errorf("file model api_key = %q, want %q", byName["file"], "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_MixedKeys_NoPassphrase verifies that when PICOCLAW_KEY_PASSPHRASE
|
||||
// is not set, enc:// entries cause LoadConfig to return an error, while plaintext
|
||||
// and file:// entries in the same config are not affected.
|
||||
func TestLoadConfig_MixedKeys_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// First encrypt a key so we have a real enc:// value.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "test-passphrase")
|
||||
mustSetupSSHKey(t)
|
||||
if err := SaveConfig(cfgPath, &Config{
|
||||
ModelList: []ModelConfig{
|
||||
{ModelName: "m", Model: "openai/gpt-4", APIKey: "sk-secret"},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("setup SaveConfig: %v", err)
|
||||
}
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
var tmp struct {
|
||||
ModelList []struct {
|
||||
APIKey string `json:"api_key"`
|
||||
} `json:"model_list"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &tmp); err != nil {
|
||||
t.Fatalf("setup parse: %v", err)
|
||||
}
|
||||
encValue := tmp.ModelList[0].APIKey
|
||||
|
||||
// Write a mixed config: enc:// + plaintext + file://
|
||||
keyFile := filepath.Join(dir, "api.key")
|
||||
if err := os.WriteFile(keyFile, []byte("sk-from-file"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
mixed, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "enc", "model": "openai/gpt-4", "api_key": encValue},
|
||||
{"model_name": "plain", "model": "openai/gpt-4", "api_key": "sk-plain"},
|
||||
{"model_name": "file", "model": "openai/gpt-4", "api_key": "file://api.key"},
|
||||
},
|
||||
})
|
||||
if err := os.WriteFile(cfgPath, mixed, 0o600); err != nil {
|
||||
t.Fatalf("setup write: %v", err)
|
||||
}
|
||||
|
||||
// Now clear the passphrase — LoadConfig must fail because enc:// cannot be decrypted.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
_, err := LoadConfig(cfgPath)
|
||||
if err == nil {
|
||||
t.Fatal("LoadConfig should fail when enc:// key is present and no passphrase is set")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "passphrase required") {
|
||||
t.Errorf("error should mention passphrase required, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveConfig_UsesPassphraseProvider verifies that SaveConfig encrypts plaintext
|
||||
// api_keys using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
// This matters for the launcher, which clears the environment variable and redirects
|
||||
// PassphraseProvider to an in-memory SecureStore.
|
||||
func TestSaveConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty — passphrase must come from PassphraseProvider only.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
// Replace PassphraseProvider with an in-memory function (simulating SecureStore).
|
||||
const testPassphrase = "provider-passphrase"
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg := DefaultConfig()
|
||||
cfg.ModelList = []ModelConfig{
|
||||
{ModelName: "test", Model: "openai/gpt-4", APIKey: "sk-plaintext"},
|
||||
}
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := os.ReadFile(cfgPath)
|
||||
if !strings.Contains(string(raw), "enc://") {
|
||||
t.Errorf("SaveConfig should have encrypted plaintext key via PassphraseProvider; got:\n%s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadConfig_UsesPassphraseProvider verifies that LoadConfig decrypts enc:// keys
|
||||
// using credential.PassphraseProvider() rather than os.Getenv directly.
|
||||
func TestLoadConfig_UsesPassphraseProvider(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
// Ensure the env var is empty throughout.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
mustSetupSSHKey(t)
|
||||
|
||||
const testPassphrase = "provider-passphrase"
|
||||
const plainKey = "sk-secret"
|
||||
|
||||
// First, encrypt the key using the same passphrase.
|
||||
encrypted, err := credential.Encrypt(testPassphrase, "", plainKey)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
raw, _ := json.Marshal(map[string]any{
|
||||
"model_list": []map[string]any{
|
||||
{"model_name": "test", "model": "openai/gpt-4", "api_key": encrypted},
|
||||
},
|
||||
})
|
||||
if err = os.WriteFile(cfgPath, raw, 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Redirect PassphraseProvider — env var is empty, so without this the load would fail.
|
||||
orig := credential.PassphraseProvider
|
||||
credential.PassphraseProvider = func() string { return testPassphrase }
|
||||
t.Cleanup(func() { credential.PassphraseProvider = orig })
|
||||
|
||||
cfg, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig: %v", err)
|
||||
}
|
||||
if cfg.ModelList[0].APIKey != plainKey {
|
||||
t.Errorf("api_key = %q, want %q", cfg.ModelList[0].APIKey, plainKey)
|
||||
}
|
||||
}
|
||||
|
||||
+16
-2
@@ -385,10 +385,20 @@ func DefaultConfig() *Config {
|
||||
APIBase: "http://localhost:8000/v1",
|
||||
APIKey: "",
|
||||
},
|
||||
|
||||
// Azure OpenAI - https://portal.azure.com
|
||||
// model_name is a user-friendly alias; the model field's path after "azure/" is your deployment name
|
||||
{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://your-resource.openai.azure.com",
|
||||
APIKey: "",
|
||||
},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
HotReload: false,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
MediaCleanup: MediaCleanupConfig{
|
||||
@@ -444,6 +454,7 @@ func DefaultConfig() *Config {
|
||||
Enabled: true,
|
||||
},
|
||||
ExecTimeoutMinutes: 5,
|
||||
AllowCommand: true,
|
||||
},
|
||||
Exec: ExecConfig{
|
||||
ToolConfig: ToolConfig{
|
||||
@@ -513,6 +524,9 @@ func DefaultConfig() *Config {
|
||||
Spawn: ToolConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
SpawnStatus: ToolConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
SPI: ToolConfig{
|
||||
Enabled: false, // Hardware tool - Linux only
|
||||
},
|
||||
|
||||
@@ -0,0 +1,335 @@
|
||||
// Package credential resolves API credential values for model_list entries.
|
||||
//
|
||||
// An API key is a form of authorization credential. This package centralizes
|
||||
// how raw credential strings—plaintext or file references—are resolved into
|
||||
// their actual values, keeping that logic out of the config loader.
|
||||
//
|
||||
// Supported formats for the api_key field:
|
||||
//
|
||||
// - Plaintext: "sk-abc123" → returned as-is
|
||||
// - File ref: "file://filename.key" → content read from configDir/filename.key
|
||||
// - Encrypted: "enc://<base64>" → AES-256-GCM decrypt via PICOCLAW_KEY_PASSPHRASE
|
||||
// - Empty: "" → returned as-is (auth_method=oauth etc.)
|
||||
//
|
||||
// Encryption uses AES-256-GCM with HKDF-SHA256 key derivation (< 1ms, safe for embedded Linux).
|
||||
// An SSH private key is required for both encryption and decryption.
|
||||
// Key derivation:
|
||||
//
|
||||
// HKDF-SHA256(ikm=HMAC-SHA256(SHA256(sshKeyBytes), passphrase), salt, info)
|
||||
//
|
||||
// SSH key path resolution priority:
|
||||
//
|
||||
// 1. sshKeyPath argument to Encrypt (explicit)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform)
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hkdf"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PassphraseEnvVar is the environment variable that holds the encryption passphrase.
|
||||
// Other packages (e.g. config) reference this constant to avoid duplicating the string.
|
||||
const PassphraseEnvVar = "PICOCLAW_KEY_PASSPHRASE"
|
||||
|
||||
// PassphraseProvider is the function used to retrieve the passphrase for enc://
|
||||
// credential decryption. It defaults to reading PICOCLAW_KEY_PASSPHRASE from the
|
||||
// process environment. Replace it at startup to use a different source, such as
|
||||
// an in-memory SecureStore, so that all LoadConfig() calls everywhere share the
|
||||
// same passphrase source without needing os.Environ.
|
||||
//
|
||||
// Example (launcher main.go):
|
||||
//
|
||||
// credential.PassphraseProvider = apiHandler.passphraseStore.Get
|
||||
var PassphraseProvider func() string = func() string {
|
||||
return os.Getenv(PassphraseEnvVar)
|
||||
}
|
||||
|
||||
// ErrPassphraseRequired is returned when an enc:// credential is encountered but
|
||||
// no passphrase is available from PassphraseProvider. Callers can detect this
|
||||
// with errors.Is to distinguish a missing-passphrase condition from other errors.
|
||||
var ErrPassphraseRequired = errors.New("credential: enc:// passphrase required")
|
||||
|
||||
// ErrDecryptionFailed is returned when an enc:// credential cannot be decrypted,
|
||||
// indicating a wrong passphrase or SSH key. Callers can detect this with errors.Is.
|
||||
var ErrDecryptionFailed = errors.New("credential: enc:// decryption failed (wrong passphrase or SSH key?)")
|
||||
|
||||
const (
|
||||
fileScheme = "file://"
|
||||
encScheme = "enc://"
|
||||
hkdfInfo = "picoclaw-credential-v1"
|
||||
saltLen = 16
|
||||
nonceLen = 12
|
||||
keyLen = 32
|
||||
sshKeyEnv = "PICOCLAW_SSH_KEY_PATH"
|
||||
)
|
||||
|
||||
// Resolver resolves raw credential strings for model_list api_key fields.
|
||||
// File references are resolved relative to the directory of the config file.
|
||||
type Resolver struct {
|
||||
configDir string
|
||||
resolvedConfigDir string // symlink-resolved form of configDir
|
||||
}
|
||||
|
||||
// NewResolver returns a Resolver that resolves file:// references relative to
|
||||
// configDir (typically filepath.Dir of the config file path).
|
||||
func NewResolver(configDir string) *Resolver {
|
||||
resolved := configDir
|
||||
if configDir != "" {
|
||||
if linkedPath, err := filepath.EvalSymlinks(configDir); err == nil {
|
||||
resolved = linkedPath
|
||||
}
|
||||
}
|
||||
return &Resolver{configDir: configDir, resolvedConfigDir: resolved}
|
||||
}
|
||||
|
||||
// Resolve returns the actual credential value for raw:
|
||||
//
|
||||
// - "" → "" (no error; auth_method=oauth needs no key)
|
||||
// - "file://name.key" → trimmed content of configDir/name.key
|
||||
// - anything else → raw unchanged (plaintext credential)
|
||||
func (r *Resolver) Resolve(raw string) (string, error) {
|
||||
if raw == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, fileScheme) {
|
||||
fileName := strings.TrimSpace(strings.TrimPrefix(raw, fileScheme))
|
||||
if fileName == "" {
|
||||
return "", fmt.Errorf("credential: file:// reference has no filename")
|
||||
}
|
||||
|
||||
baseDir := r.resolvedConfigDir
|
||||
if baseDir == "" {
|
||||
baseDir = r.configDir
|
||||
}
|
||||
keyPath := filepath.Join(baseDir, fileName)
|
||||
// Resolve symlinks before enforcing containment to prevent escaping via symlinks.
|
||||
realKeyPath, err := filepath.EvalSymlinks(keyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to resolve credential file path %q: %w", keyPath, err)
|
||||
}
|
||||
if !isWithinDir(realKeyPath, baseDir) {
|
||||
return "", fmt.Errorf("credential: file:// path escapes config directory")
|
||||
}
|
||||
data, err := os.ReadFile(realKeyPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: failed to read credential file %q: %w", realKeyPath, err)
|
||||
}
|
||||
|
||||
value := strings.TrimSpace(string(data))
|
||||
if value == "" {
|
||||
return "", fmt.Errorf("credential: credential file %q is empty", realKeyPath)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(raw, encScheme) {
|
||||
return resolveEncrypted(raw)
|
||||
}
|
||||
|
||||
// Plaintext credential — return unchanged.
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// resolveEncrypted decrypts an enc:// credential using PassphraseProvider.
|
||||
func resolveEncrypted(raw string) (string, error) {
|
||||
passphrase := PassphraseProvider()
|
||||
if passphrase == "" {
|
||||
return "", ErrPassphraseRequired
|
||||
}
|
||||
|
||||
sshKeyPath := pickSSHKeyPath("") // override="": consult env then auto-detect
|
||||
|
||||
b64 := strings.TrimPrefix(raw, encScheme)
|
||||
blob, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// invalid base64: %w", err)
|
||||
}
|
||||
if len(blob) < saltLen+nonceLen+1 {
|
||||
return "", fmt.Errorf("credential: enc:// payload too short")
|
||||
}
|
||||
|
||||
salt := blob[:saltLen]
|
||||
nonce := blob[saltLen : saltLen+nonceLen]
|
||||
ciphertext := blob[saltLen+nonceLen:]
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: enc:// gcm init: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %w", ErrDecryptionFailed, err)
|
||||
}
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext and returns an enc:// credential string.
|
||||
//
|
||||
// passphrase is required (PICOCLAW_KEY_PASSPHRASE value).
|
||||
// sshKeyPath is the SSH private key file to use; pass "" to auto-detect via
|
||||
// PICOCLAW_SSH_KEY_PATH env var or ~/.ssh/picoclaw_ed25519.key.
|
||||
// An SSH private key must be resolvable or Encrypt returns an error.
|
||||
func Encrypt(passphrase, sshKeyPath, plaintext string) (string, error) {
|
||||
if passphrase == "" {
|
||||
return "", fmt.Errorf("credential: passphrase must not be empty")
|
||||
}
|
||||
sshKeyPath = pickSSHKeyPath(sshKeyPath)
|
||||
|
||||
salt := make([]byte, saltLen)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
key, err := deriveKey(passphrase, sshKeyPath, salt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cipher init: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: gcm init: %w", err)
|
||||
}
|
||||
|
||||
nonce := make([]byte, nonceLen)
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("credential: failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
|
||||
blob := make([]byte, 0, saltLen+nonceLen+len(ciphertext))
|
||||
blob = append(blob, salt...)
|
||||
blob = append(blob, nonce...)
|
||||
blob = append(blob, ciphertext...)
|
||||
return encScheme + base64.StdEncoding.EncodeToString(blob), nil
|
||||
}
|
||||
|
||||
// isWithinDir reports whether path is contained within (or equal to) dir.
|
||||
// Uses filepath.IsLocal on the relative path for robust cross-platform traversal detection.
|
||||
func isWithinDir(path, dir string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(dir), filepath.Clean(path))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
}
|
||||
|
||||
// allowedSSHKeyPath reports whether path is in a permitted location for SSH key files:
|
||||
// - exact match with PICOCLAW_SSH_KEY_PATH env var
|
||||
// - within the PICOCLAW_HOME env var directory
|
||||
// - within ~/.ssh/
|
||||
func allowedSSHKeyPath(path string) bool {
|
||||
if path == "" {
|
||||
return true // passphrase-only mode; no file will be read
|
||||
}
|
||||
clean := filepath.Clean(path)
|
||||
|
||||
// Exact match with PICOCLAW_SSH_KEY_PATH.
|
||||
if envPath, ok := os.LookupEnv(sshKeyEnv); ok && envPath != "" {
|
||||
if clean == filepath.Clean(envPath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within PICOCLAW_HOME.
|
||||
if picoHome := os.Getenv("PICOCLAW_HOME"); picoHome != "" {
|
||||
if isWithinDir(clean, picoHome) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Within ~/.ssh/.
|
||||
if userHome, err := os.UserHomeDir(); err == nil {
|
||||
if isWithinDir(clean, filepath.Join(userHome, ".ssh")) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// deriveKey derives a 32-byte AES-256 key from passphrase and SSH private key.
|
||||
//
|
||||
// ikm = HMAC-SHA256(key=SHA256(sshKeyBytes), msg=passphrase)
|
||||
// Final key: HKDF-SHA256(ikm, salt, info="picoclaw-credential-v1", 32 bytes)
|
||||
// sshKeyPath must be non-empty; returns an error otherwise.
|
||||
func deriveKey(passphrase, sshKeyPath string, salt []byte) ([]byte, error) {
|
||||
if sshKeyPath == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH private key is required but not found" +
|
||||
" (set PICOCLAW_SSH_KEY_PATH or place key at ~/.ssh/picoclaw_ed25519.key)")
|
||||
}
|
||||
if !allowedSSHKeyPath(sshKeyPath) {
|
||||
return nil, fmt.Errorf(
|
||||
"credential: SSH key path %q is not in an allowed location (PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/)",
|
||||
sshKeyPath,
|
||||
)
|
||||
}
|
||||
sshBytes, err := os.ReadFile(sshKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: cannot read SSH key %q: %w", sshKeyPath, err)
|
||||
}
|
||||
sshHash := sha256.Sum256(sshBytes)
|
||||
mac := hmac.New(sha256.New, sshHash[:])
|
||||
mac.Write([]byte(passphrase))
|
||||
ikm := mac.Sum(nil)
|
||||
|
||||
key, err := hkdf.Key(sha256.New, ikm, salt, hkdfInfo, keyLen)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("credential: HKDF expand failed: %w", err)
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// pickSSHKeyPath returns the SSH private key path to use for encryption/decryption.
|
||||
//
|
||||
// Priority:
|
||||
// 1. override (non-empty explicit argument)
|
||||
// 2. PICOCLAW_SSH_KEY_PATH env var
|
||||
// 3. ~/.ssh/picoclaw_ed25519.key (auto-detection)
|
||||
//
|
||||
// Returns "" when no key is found; deriveKey will return an error in that case.
|
||||
func pickSSHKeyPath(override string) string {
|
||||
if override != "" {
|
||||
return override
|
||||
}
|
||||
if p, ok := os.LookupEnv(sshKeyEnv); ok {
|
||||
return p // respect explicit setting, even if ""
|
||||
}
|
||||
return findDefaultSSHKey()
|
||||
}
|
||||
|
||||
// findDefaultSSHKey returns the picoclaw-specific SSH key path if it exists.
|
||||
func findDefaultSSHKey() string {
|
||||
p, err := DefaultSSHKeyPath()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package credential_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/credential"
|
||||
)
|
||||
|
||||
func TestResolve_PlainKey(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve("sk-plaintext-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-plaintext-key" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-plaintext-key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Success(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "openai_plain.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte("sk-from-file\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://" + keyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-from-file" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-from-file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_NotFound(t *testing.T) {
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("file://missing.key")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_FileKey_Empty(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyFile := "empty.key"
|
||||
if err := os.WriteFile(filepath.Join(dir, keyFile), []byte(" \n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(dir)
|
||||
_, err := r.Resolve("file://" + keyFile)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty credential file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_RoundTrip tests basic encryption/decryption round-trip with an SSH key.
|
||||
func TestResolve_EncKey_RoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase-32bytes-long-ok!"
|
||||
const plaintext = "sk-encrypted-secret"
|
||||
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, "", plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_EncKey_WithSSHKey tests that the SSH key file is incorporated into key derivation.
|
||||
func TestResolve_EncKey_WithSSHKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-private-key-material\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
const passphrase = "test-passphrase"
|
||||
const plaintext = "sk-ssh-protected-secret"
|
||||
|
||||
// Set PICOCLAW_SSH_KEY_PATH before Encrypt so the path passes allowedSSHKeyPath validation.
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", passphrase)
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt(passphrase, sshKeyPath, plaintext)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
got, err := r.Resolve(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("Resolve: %v", err)
|
||||
}
|
||||
if got != plaintext {
|
||||
t.Fatalf("got %q, want %q", got, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_NoPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("some-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when PICOCLAW_KEY_PASSPHRASE is unset, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_BadCiphertext(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://!!not-valid-base64!!")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_PayloadTooShort(t *testing.T) {
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "some-passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
|
||||
// Valid base64 but fewer bytes than salt(16)+nonce(12)+1 minimum.
|
||||
import64 := "dG9vc2hvcnQ=" // "tooshort" = 8 bytes
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err := r.Resolve("enc://" + import64)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for too-short enc:// payload, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolve_EncKey_WrongPassphrase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-ssh-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("correct-passphrase", "", "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "wrong-passphrase")
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected decryption error for wrong passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncrypt_EmptyPassphrase(t *testing.T) {
|
||||
_, err := credential.Encrypt("", "", "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty passphrase, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveKey_SSHKeyNotFound(t *testing.T) {
|
||||
// Encrypt with a real SSH key path, then try to decrypt with a missing path.
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Register the real key path so allowedSSHKeyPath validation passes for Encrypt.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", sshKeyPath)
|
||||
|
||||
enc, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
|
||||
// Point to a non-existent SSH key so deriveKey's ReadFile fails.
|
||||
// The path is still under the same dir, so allowedSSHKeyPath passes (exact env match).
|
||||
t.Setenv("PICOCLAW_KEY_PASSPHRASE", "passphrase")
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", filepath.Join(dir, "nonexistent_key"))
|
||||
|
||||
r := credential.NewResolver(t.TempDir())
|
||||
_, err = r.Resolve(enc)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when SSH key file is missing, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_PathTraversal verifies that file:// references cannot escape configDir
|
||||
// via relative traversal ("../../etc/passwd") or absolute paths ("/abs/path").
|
||||
func TestResolve_FileRef_PathTraversal(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
// Create a file outside configDir that the traversal would point to.
|
||||
outsideFile := filepath.Join(t.TempDir(), "secret.key")
|
||||
if err := os.WriteFile(outsideFile, []byte("stolen"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
r := credential.NewResolver(filepath.Dir(cfgPath))
|
||||
|
||||
cases := []string{
|
||||
"file://../../secret.key",
|
||||
"file://../secret.key",
|
||||
"file://" + outsideFile, // absolute path
|
||||
}
|
||||
for _, raw := range cases {
|
||||
_, err := r.Resolve(raw)
|
||||
if err == nil {
|
||||
t.Errorf("Resolve(%q): expected path traversal error, got nil", raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolve_FileRef_withinConfigDir verifies that a legitimate relative file:// ref works.
|
||||
func TestResolve_FileRef_withinConfigDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "my.key"), []byte("sk-valid\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
r := credential.NewResolver(dir)
|
||||
got, err := r.Resolve("file://my.key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "sk-valid" {
|
||||
t.Fatalf("got %q, want %q", got, "sk-valid")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncrypt_SSHKeyOutsideAllowedDirs verifies that Encrypt rejects SSH key paths
|
||||
// that are not under PICOCLAW_SSH_KEY_PATH, PICOCLAW_HOME, or ~/.ssh/.
|
||||
func TestEncrypt_SSHKeyOutsideAllowedDirs(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sshKeyPath := filepath.Join(dir, "picoclaw_ed25519.key")
|
||||
if err := os.WriteFile(sshKeyPath, []byte("fake-key\n"), 0o600); err != nil {
|
||||
t.Fatalf("setup: %v", err)
|
||||
}
|
||||
|
||||
// Make sure none of the allowed env vars point here.
|
||||
t.Setenv("PICOCLAW_SSH_KEY_PATH", "")
|
||||
t.Setenv("PICOCLAW_HOME", "")
|
||||
|
||||
_, err := credential.Encrypt("passphrase", sshKeyPath, "sk-secret")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for SSH key outside allowed directories, got nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// DefaultSSHKeyPath returns the canonical path for the picoclaw-specific SSH key.
|
||||
// The path is always ~/.ssh/picoclaw_ed25519.key (os.UserHomeDir is cross-platform).
|
||||
func DefaultSSHKeyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("credential: cannot determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".ssh", "picoclaw_ed25519.key"), nil
|
||||
}
|
||||
|
||||
// GenerateSSHKey generates an Ed25519 SSH key pair and writes the private key
|
||||
// to path (permissions 0600) and the public key to path+".pub" (permissions 0644).
|
||||
// The ~/.ssh/ directory is created with 0700 if it does not exist.
|
||||
// If the files already exist they are overwritten.
|
||||
func GenerateSSHKey(path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("credential: keygen: cannot create directory %q: %w", filepath.Dir(path), err)
|
||||
}
|
||||
|
||||
pubRaw, privRaw, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: ed25519 key generation failed: %w", err)
|
||||
}
|
||||
|
||||
// Marshal private key as OpenSSH PEM.
|
||||
block, err := ssh.MarshalPrivateKey(privRaw, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal private key: %w", err)
|
||||
}
|
||||
privPEM := pem.EncodeToMemory(block)
|
||||
|
||||
if err = os.WriteFile(path, privPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write private key %q: %w", path, err)
|
||||
}
|
||||
|
||||
// Marshal public key as authorized_keys line.
|
||||
sshPub, err := ssh.NewPublicKey(pubRaw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("credential: keygen: marshal public key: %w", err)
|
||||
}
|
||||
pubLine := ssh.MarshalAuthorizedKey(sshPub)
|
||||
|
||||
pubPath := path + ".pub"
|
||||
if err := os.WriteFile(pubPath, pubLine, 0o644); err != nil {
|
||||
return fmt.Errorf("credential: keygen: write public key %q: %w", pubPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGenerateSSHKey_CreatesFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
// Private key must exist.
|
||||
privInfo, err := os.Stat(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("private key file missing: %v", err)
|
||||
}
|
||||
|
||||
// Check permissions on non-Windows (Windows does not support Unix permission bits).
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := privInfo.Mode().Perm(); got != 0o600 {
|
||||
t.Errorf("private key permissions = %04o, want 0600", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Public key must exist.
|
||||
pubPath := keyPath + ".pub"
|
||||
pubInfo, err := os.Stat(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("public key file missing: %v", err)
|
||||
}
|
||||
if runtime.GOOS != "windows" {
|
||||
if got := pubInfo.Mode().Perm(); got != 0o644 {
|
||||
t.Errorf("public key permissions = %04o, want 0644", got)
|
||||
}
|
||||
}
|
||||
|
||||
// Private key must be parseable as an OpenSSH ed25519 key.
|
||||
privPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read private key: %v", err)
|
||||
}
|
||||
privKey, err := ssh.ParseRawPrivateKey(privPEM)
|
||||
if err != nil {
|
||||
t.Fatalf("parse private key: %v", err)
|
||||
}
|
||||
if _, ok := privKey.(*ed25519.PrivateKey); !ok {
|
||||
t.Errorf("private key type = %T, want *ed25519.PrivateKey", privKey)
|
||||
}
|
||||
|
||||
// Public key must be parseable as authorized_keys line.
|
||||
pubBytes, err := os.ReadFile(pubPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read public key: %v", err)
|
||||
}
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(pubBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("parse public key: %v", err)
|
||||
}
|
||||
if pubKey == nil {
|
||||
t.Fatal("expected non-nil public key")
|
||||
}
|
||||
if len(rest) > 0 {
|
||||
t.Errorf("unexpected trailing bytes after public key: %d bytes", len(rest))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_OverwritesExisting(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
keyPath := filepath.Join(dir, "test_ed25519.key")
|
||||
|
||||
// Generate twice; second call must not error and must produce a different key.
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("first GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
first, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read first key: %v", err)
|
||||
}
|
||||
|
||||
if err = GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("second GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
second, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("read second key: %v", err)
|
||||
}
|
||||
|
||||
// Two independently generated Ed25519 keys must differ.
|
||||
if string(first) == string(second) {
|
||||
t.Error("expected overwritten key to differ from original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSSHKey_CreatesDirectory(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
// Nested directory that does not yet exist.
|
||||
keyPath := filepath.Join(dir, "subdir", ".ssh", "picoclaw_ed25519.key")
|
||||
|
||||
if err := GenerateSSHKey(keyPath); err != nil {
|
||||
t.Fatalf("GenerateSSHKey() error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(keyPath); err != nil {
|
||||
t.Fatalf("private key not created: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package credential
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// SecureStore holds a passphrase in memory.
|
||||
//
|
||||
// Uses atomic.Pointer so reads and writes are lock-free.
|
||||
// The passphrase is never written to disk; callers decide how to
|
||||
// transport it outside this store (e.g., via cmd.Env or os.Environ).
|
||||
type SecureStore struct {
|
||||
val atomic.Pointer[string]
|
||||
}
|
||||
|
||||
// NewSecureStore creates an empty SecureStore.
|
||||
func NewSecureStore() *SecureStore {
|
||||
return &SecureStore{}
|
||||
}
|
||||
|
||||
// SetString stores the passphrase. An empty string clears the store.
|
||||
func (s *SecureStore) SetString(passphrase string) {
|
||||
if passphrase == "" {
|
||||
s.val.Store(nil)
|
||||
return
|
||||
}
|
||||
s.val.Store(&passphrase)
|
||||
}
|
||||
|
||||
// Get returns the stored passphrase, or "" if not set.
|
||||
func (s *SecureStore) Get() string {
|
||||
if p := s.val.Load(); p != nil {
|
||||
return *p
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSet reports whether a passphrase is currently stored.
|
||||
func (s *SecureStore) IsSet() bool {
|
||||
return s.val.Load() != nil
|
||||
}
|
||||
|
||||
// Clear removes the stored passphrase.
|
||||
func (s *SecureStore) Clear() {
|
||||
s.val.Store(nil)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package credential
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSecureStore_SetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
if s.IsSet() {
|
||||
t.Error("expected empty store")
|
||||
}
|
||||
|
||||
s.SetString("hunter2")
|
||||
if !s.IsSet() {
|
||||
t.Error("expected store to be set")
|
||||
}
|
||||
if got := s.Get(); got != "hunter2" {
|
||||
t.Errorf("Get() = %q, want %q", got, "hunter2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_Clear(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("secret")
|
||||
s.Clear()
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("expected store to be empty after Clear()")
|
||||
}
|
||||
if got := s.Get(); got != "" {
|
||||
t.Errorf("Get() after Clear() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_SetOverwrites(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("first")
|
||||
s.SetString("second")
|
||||
|
||||
if got := s.Get(); got != "second" {
|
||||
t.Errorf("Get() = %q, want %q", got, "second")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_EmptyPassphrase(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
s.SetString("") // empty → should not mark as set
|
||||
|
||||
if s.IsSet() {
|
||||
t.Error("empty passphrase should not mark store as set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecureStore_ConcurrentSetGet(t *testing.T) {
|
||||
s := NewSecureStore()
|
||||
const goroutines = 10
|
||||
const iterations = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
if id%2 == 0 {
|
||||
s.SetString("even")
|
||||
} else {
|
||||
s.SetString("odd")
|
||||
}
|
||||
_ = s.Get()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
final := s.Get()
|
||||
if final != "" && final != "even" && final != "odd" {
|
||||
t.Errorf("Get() returned unexpected value %q after concurrent Set/Get", final)
|
||||
}
|
||||
}
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -41,16 +41,13 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
// Timeout constants for service operations
|
||||
const (
|
||||
serviceRestartTimeout = 30 * time.Second
|
||||
serviceShutdownTimeout = 30 * time.Second
|
||||
providerReloadTimeout = 30 * time.Second
|
||||
gracefulShutdownTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
// gatewayServices holds references to all running services
|
||||
type gatewayServices struct {
|
||||
type services struct {
|
||||
CronService *cron.CronService
|
||||
HeartbeatService *heartbeat.HeartbeatService
|
||||
MediaStore media.MediaStore
|
||||
@@ -59,24 +56,41 @@ type gatewayServices struct {
|
||||
HealthServer *health.Server
|
||||
}
|
||||
|
||||
func gatewayCmd(debug bool) error {
|
||||
type startupBlockedProvider struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []providers.Message,
|
||||
_ []providers.ToolDefinition,
|
||||
_ string,
|
||||
_ map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return nil, fmt.Errorf("%s", p.reason)
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
configPath := internal.GetConfigPath()
|
||||
cfg, err := internal.LoadConfig()
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
}
|
||||
|
||||
provider, modelID, err := providers.CreateProvider(cfg)
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
}
|
||||
|
||||
// Use the resolved model ID from provider creation
|
||||
if modelID != "" {
|
||||
cfg.Agents.Defaults.ModelName = modelID
|
||||
}
|
||||
@@ -84,17 +98,13 @@ func gatewayCmd(debug bool) error {
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Print agent startup info
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
toolsInfo := startupInfo["tools"].(map[string]any)
|
||||
skillsInfo := startupInfo["skills"].(map[string]any)
|
||||
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n",
|
||||
skillsInfo["available"],
|
||||
skillsInfo["total"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
|
||||
|
||||
// Log to file as well
|
||||
logger.InfoCF("agent", "Agent initialized",
|
||||
map[string]any{
|
||||
"tools_count": toolsInfo["count"],
|
||||
@@ -102,8 +112,7 @@ func gatewayCmd(debug bool) error {
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
// Setup and start all services
|
||||
services, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -116,23 +125,25 @@ func gatewayCmd(debug bool) error {
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
// Setup config file watcher for hot reload
|
||||
configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
|
||||
var configReloadChan <-chan *config.Config
|
||||
stopWatch := func() {}
|
||||
if cfg.Gateway.HotReload {
|
||||
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
|
||||
logger.Info("Config hot reload enabled")
|
||||
}
|
||||
defer stopWatch()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
// Main event loop - wait for signals or config changes
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
shutdownGateway(services, agentLoop, provider, true)
|
||||
shutdownGateway(runningServices, agentLoop, provider, true)
|
||||
return nil
|
||||
|
||||
case newCfg := <-configReloadChan:
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus)
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf("Config reload failed: %v", err)
|
||||
}
|
||||
@@ -140,17 +151,33 @@ func gatewayCmd(debug bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
// setupAndStartServices initializes and starts all services
|
||||
func createStartupProvider(
|
||||
cfg *config.Config,
|
||||
allowEmptyStartup bool,
|
||||
) (providers.LLMProvider, string, error) {
|
||||
modelName := cfg.Agents.Defaults.GetModelName()
|
||||
if modelName == "" && allowEmptyStartup {
|
||||
reason := "no default model configured; gateway started in limited mode"
|
||||
fmt.Printf("⚠ Warning: %s\n", reason)
|
||||
logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
|
||||
"limited_mode": true,
|
||||
})
|
||||
return &startupBlockedProvider{reason: reason}, "", nil
|
||||
}
|
||||
|
||||
return providers.CreateProvider(cfg)
|
||||
}
|
||||
|
||||
func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
) (*gatewayServices, error) {
|
||||
services := &gatewayServices{}
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
// Setup cron tool and service
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
services.CronService = setupCronTool(
|
||||
var err error
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
agentLoop,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
@@ -158,139 +185,108 @@ func setupAndStartServices(
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err := services.CronService.Start(); err != nil {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting up cron service: %w", err)
|
||||
}
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting cron service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
// Setup heartbeat service
|
||||
services.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
services.HeartbeatService.SetBus(msgBus)
|
||||
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
// Use cli:direct as fallback if no valid channel
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
// Use ProcessHeartbeat - no session history, each heartbeat is independent
|
||||
var response string
|
||||
var err error
|
||||
response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
// For heartbeat, always return silent - the subagent result will be
|
||||
// sent to user via processSystemMessage when the async task completes
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
if err := services.HeartbeatService.Start(); err != nil {
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
// Create media store for file lifecycle management with TTL cleanup
|
||||
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
// Start the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
// Create channel manager
|
||||
var err error
|
||||
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return nil, fmt.Errorf("error creating channel manager: %w", err)
|
||||
}
|
||||
|
||||
// Inject channel manager and media store into agent loop
|
||||
agentLoop.SetChannelManager(services.ChannelManager)
|
||||
agentLoop.SetMediaStore(services.MediaStore)
|
||||
agentLoop.SetChannelManager(runningServices.ChannelManager)
|
||||
agentLoop.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
// Wire up voice transcription if a supported provider is configured.
|
||||
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
|
||||
agentLoop.SetTranscriber(transcriber)
|
||||
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
}
|
||||
|
||||
enabledChannels := services.ChannelManager.GetEnabledChannels()
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
// Setup shared HTTP server with health endpoints and webhook handlers
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err := services.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
// Setup state manager and device service
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
services.DeviceService = devices.NewService(devices.Config{
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
services.DeviceService.SetBus(msgBus)
|
||||
if err := services.DeviceService.Start(context.Background()); err != nil {
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err = runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
return services, nil
|
||||
return runningServices, nil
|
||||
}
|
||||
|
||||
// stopAndCleanupServices stops all services and cleans up resources
|
||||
func stopAndCleanupServices(
|
||||
services *gatewayServices,
|
||||
shutdownTimeout time.Duration,
|
||||
) {
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if services.ChannelManager != nil {
|
||||
services.ChannelManager.StopAll(shutdownCtx)
|
||||
if runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if services.DeviceService != nil {
|
||||
services.DeviceService.Stop()
|
||||
if runningServices.DeviceService != nil {
|
||||
runningServices.DeviceService.Stop()
|
||||
}
|
||||
if services.HeartbeatService != nil {
|
||||
services.HeartbeatService.Stop()
|
||||
if runningServices.HeartbeatService != nil {
|
||||
runningServices.HeartbeatService.Stop()
|
||||
}
|
||||
if services.CronService != nil {
|
||||
services.CronService.Stop()
|
||||
if runningServices.CronService != nil {
|
||||
runningServices.CronService.Stop()
|
||||
}
|
||||
if services.MediaStore != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if runningServices.MediaStore != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownGateway performs a complete gateway shutdown
|
||||
func shutdownGateway(
|
||||
services *gatewayServices,
|
||||
runningServices *services,
|
||||
agentLoop *agent.AgentLoop,
|
||||
provider providers.LLMProvider,
|
||||
fullShutdown bool,
|
||||
@@ -299,7 +295,7 @@ func shutdownGateway(
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(services, gracefulShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
@@ -307,15 +303,14 @@ func shutdownGateway(
|
||||
logger.Info("✓ Gateway stopped")
|
||||
}
|
||||
|
||||
// handleConfigReload handles config file reload by stopping all services,
|
||||
// reloading the provider and config, and restarting services with the new config.
|
||||
func handleConfigReload(
|
||||
ctx context.Context,
|
||||
al *agent.AgentLoop,
|
||||
newCfg *config.Config,
|
||||
providerRef *providers.LLMProvider,
|
||||
services *gatewayServices,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
allowEmptyStartup bool,
|
||||
) error {
|
||||
logger.Info("🔄 Config file changed, reloading...")
|
||||
|
||||
@@ -326,18 +321,14 @@ func handleConfigReload(
|
||||
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
// Stop all services before reloading
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(services, serviceShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
|
||||
|
||||
// Create new provider from updated config first to ensure validity
|
||||
// This will use the correct API key and settings from newCfg.ModelList
|
||||
newProvider, newModelID, err := providers.CreateProvider(newCfg)
|
||||
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf(" ⚠ Error creating new provider: %v", err)
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
// Try to restart services with old configuration
|
||||
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error creating new provider: %w", err)
|
||||
@@ -347,31 +338,25 @@ func handleConfigReload(
|
||||
newCfg.Agents.Defaults.ModelName = newModelID
|
||||
}
|
||||
|
||||
// Use the atomic reload method on AgentLoop to safely swap provider and config.
|
||||
// This handles locking internally to prevent races with in-flight LLM calls
|
||||
// and concurrent reads of registry/config while the swap occurs.
|
||||
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
|
||||
defer reloadCancel()
|
||||
|
||||
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
|
||||
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
|
||||
// Close the newly created provider since it wasn't adopted
|
||||
if cp, ok := newProvider.(providers.StatefulProvider); ok {
|
||||
cp.Close()
|
||||
}
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error reloading agent loop: %w", err)
|
||||
}
|
||||
|
||||
// Update local provider reference only after successful atomic reload
|
||||
*providerRef = newProvider
|
||||
|
||||
// Restart all services with new config
|
||||
logger.Info(" Restarting all services with new configuration...")
|
||||
if err := restartServices(al, services, msgBus); err != nil {
|
||||
if err := restartServices(al, runningServices, msgBus); err != nil {
|
||||
logger.Errorf(" ⚠ Error restarting services: %v", err)
|
||||
return fmt.Errorf("error restarting services: %w", err)
|
||||
}
|
||||
@@ -380,23 +365,16 @@ func handleConfigReload(
|
||||
return nil
|
||||
}
|
||||
|
||||
// restartServices restarts all services after a config reload
|
||||
func restartServices(
|
||||
al *agent.AgentLoop,
|
||||
services *gatewayServices,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
) error {
|
||||
// Create an independent context with timeout for service restart
|
||||
// This prevents cancellation from the main loop context during reload
|
||||
ctx, cancel := context.WithTimeout(context.Background(), serviceRestartTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Get current config from agent loop (which has been updated if this is a reload)
|
||||
cfg := al.GetConfig()
|
||||
|
||||
// Re-create and start cron service with new config
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
services.CronService = setupCronTool(
|
||||
var err error
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
al,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
@@ -404,80 +382,54 @@ func restartServices(
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err := services.CronService.Start(); err != nil {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Cron service restarted")
|
||||
|
||||
// Re-create and start heartbeat service with new config
|
||||
services.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
services.HeartbeatService.SetBus(msgBus)
|
||||
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
var response string
|
||||
var err error
|
||||
response, err = al.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
if err := services.HeartbeatService.Start(); err != nil {
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Heartbeat service restarted")
|
||||
|
||||
// Stop the old media store before creating a new one
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
|
||||
// Re-create media store with new config
|
||||
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
// Start the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
al.SetMediaStore(services.MediaStore)
|
||||
al.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
// Re-create channel manager with new config
|
||||
var err error
|
||||
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return fmt.Errorf("error recreating channel manager: %w", err)
|
||||
}
|
||||
al.SetChannelManager(services.ChannelManager)
|
||||
al.SetChannelManager(runningServices.ChannelManager)
|
||||
|
||||
enabledChannels := services.ChannelManager.GetEnabledChannels()
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println(" ⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
// Setup HTTP server with new config
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err := services.ChannelManager.StartAll(ctx); err != nil {
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
}
|
||||
fmt.Printf(
|
||||
@@ -486,22 +438,20 @@ func restartServices(
|
||||
cfg.Gateway.Port,
|
||||
)
|
||||
|
||||
// Re-create device service with new config
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
services.DeviceService = devices.NewService(devices.Config{
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
services.DeviceService.SetBus(msgBus)
|
||||
if err := services.DeviceService.Start(ctx); err != nil {
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err := runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println(" ✓ Device event service restarted")
|
||||
}
|
||||
|
||||
// Wire up voice transcription with new config
|
||||
transcriber := voice.DetectTranscriber(cfg)
|
||||
al.SetTranscriber(transcriber) // This will set it to nil if disabled
|
||||
al.SetTranscriber(transcriber)
|
||||
if transcriber != nil {
|
||||
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
} else {
|
||||
@@ -511,8 +461,6 @@ func restartServices(
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupConfigWatcherPolling sets up a simple polling-based config file watcher
|
||||
// Returns a channel for config updates and a stop function
|
||||
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
|
||||
configChan := make(chan *config.Config, 1)
|
||||
stop := make(chan struct{})
|
||||
@@ -522,11 +470,10 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Get initial file info
|
||||
lastModTime := getFileModTime(configPath)
|
||||
lastSize := getFileSize(configPath)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@@ -535,16 +482,16 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
currentModTime := getFileModTime(configPath)
|
||||
currentSize := getFileSize(configPath)
|
||||
|
||||
// Check if file changed (modification time or size changed)
|
||||
if currentModTime.After(lastModTime) || currentSize != lastSize {
|
||||
if debug {
|
||||
logger.Debugf("🔍 Config file change detected")
|
||||
}
|
||||
|
||||
// Debounce - wait a bit to ensure file write is complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Validate and load new config
|
||||
lastModTime = currentModTime
|
||||
lastSize = currentSize
|
||||
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
logger.Errorf("⚠ Error loading new config: %v", err)
|
||||
@@ -552,7 +499,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate the new config
|
||||
if err := newCfg.ValidateModelList(); err != nil {
|
||||
logger.Errorf(" ⚠ New config validation failed: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
@@ -561,19 +507,12 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
|
||||
logger.Info("✓ Config file validated and loaded")
|
||||
|
||||
// Update last known state
|
||||
lastModTime = currentModTime
|
||||
lastSize = currentSize
|
||||
|
||||
// Send new config to main loop (non-blocking)
|
||||
select {
|
||||
case configChan <- newCfg:
|
||||
default:
|
||||
// Channel full, skip this update
|
||||
logger.Warn("⚠ Previous config reload still in progress, skipping")
|
||||
}
|
||||
}
|
||||
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
@@ -588,7 +527,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
return configChan, stopFunc
|
||||
}
|
||||
|
||||
// getFileModTime returns the modification time of a file, or zero time if file doesn't exist
|
||||
func getFileModTime(path string) time.Time {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
@@ -597,7 +535,6 @@ func getFileModTime(path string) time.Time {
|
||||
return info.ModTime()
|
||||
}
|
||||
|
||||
// getFileSize returns the size of a file, or 0 if file doesn't exist
|
||||
func getFileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
@@ -613,25 +550,22 @@ func setupCronTool(
|
||||
restrict bool,
|
||||
execTimeout time.Duration,
|
||||
cfg *config.Config,
|
||||
) *cron.CronService {
|
||||
) (*cron.CronService, error) {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool if enabled
|
||||
var cronTool *tools.CronTool
|
||||
if cfg.Tools.IsToolEnabled("cron") {
|
||||
var err error
|
||||
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
|
||||
if err != nil {
|
||||
logger.Fatalf("Critical error during CronTool initialization: %v", err)
|
||||
return nil, fmt.Errorf("critical error during CronTool initialization: %w", err)
|
||||
}
|
||||
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
}
|
||||
|
||||
// Set onJob handler
|
||||
if cronTool != nil {
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
@@ -639,5 +573,22 @@ func setupCronTool(
|
||||
})
|
||||
}
|
||||
|
||||
return cronService
|
||||
return cronService, nil
|
||||
}
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
return tools.SilentResult(response)
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -29,6 +30,7 @@ type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
Pid int `json:"pid"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
@@ -112,6 +114,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
Pid: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
+58
-7
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -45,6 +46,9 @@ func init() {
|
||||
consoleWriter := zerolog.ConsoleWriter{
|
||||
Out: os.Stdout,
|
||||
TimeFormat: "15:04:05", // TODO: make it configurable???
|
||||
|
||||
// Custom formatter to handle multiline strings and JSON objects
|
||||
FormatFieldValue: formatFieldValue,
|
||||
}
|
||||
|
||||
logger = zerolog.New(consoleWriter).With().Timestamp().Logger()
|
||||
@@ -52,6 +56,37 @@ func init() {
|
||||
})
|
||||
}
|
||||
|
||||
func formatFieldValue(i any) string {
|
||||
var s string
|
||||
|
||||
switch val := i.(type) {
|
||||
case string:
|
||||
s = val
|
||||
case []byte:
|
||||
s = string(val)
|
||||
default:
|
||||
return fmt.Sprintf("%v", i)
|
||||
}
|
||||
|
||||
if unquoted, err := strconv.Unquote(s); err == nil {
|
||||
s = unquoted
|
||||
}
|
||||
|
||||
if strings.Contains(s, "\n") {
|
||||
return fmt.Sprintf("\n%s", s)
|
||||
}
|
||||
|
||||
if strings.Contains(s, " ") {
|
||||
if (strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}")) ||
|
||||
(strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]")) {
|
||||
return s
|
||||
}
|
||||
return fmt.Sprintf("%q", s)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func SetLevel(level LogLevel) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
@@ -163,10 +198,7 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
event.Str("caller", fmt.Sprintf("<none> %s:%d (%s)", callerFile, callerLine, callerFunc))
|
||||
}
|
||||
|
||||
for k, v := range fields {
|
||||
event.Interface(k, v)
|
||||
}
|
||||
|
||||
appendFields(event, fields)
|
||||
event.Msg(message)
|
||||
|
||||
// Also log to file if enabled
|
||||
@@ -176,9 +208,8 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
if component != "" {
|
||||
fileEvent.Str("component", component)
|
||||
}
|
||||
for k, v := range fields {
|
||||
fileEvent.Interface(k, v)
|
||||
}
|
||||
|
||||
appendFields(event, fields)
|
||||
fileEvent.Msg(message)
|
||||
}
|
||||
|
||||
@@ -187,6 +218,26 @@ func logMessage(level LogLevel, component string, message string, fields map[str
|
||||
}
|
||||
}
|
||||
|
||||
func appendFields(event *zerolog.Event, fields map[string]any) {
|
||||
for k, v := range fields {
|
||||
// Type switch to avoid double JSON serialization of strings
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
event.Str(k, val)
|
||||
case int:
|
||||
event.Int(k, val)
|
||||
case int64:
|
||||
event.Int64(k, val)
|
||||
case float64:
|
||||
event.Float64(k, val)
|
||||
case bool:
|
||||
event.Bool(k, val)
|
||||
default:
|
||||
event.Interface(k, v) // Fallback for struct, slice and maps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Debug(message string) {
|
||||
logMessage(DEBUG, "", message, nil)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,20 @@
|
||||
|
||||
package logger
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// botTokenRe matches the bot ID prefix and the secret part of a Telegram bot token.
|
||||
// Groups: 1 = "bot<id>:", 2 = first 4 chars of secret, 3 = middle, 4 = last 4 chars.
|
||||
var botTokenRe = regexp.MustCompile(`(bot\d+:)([A-Za-z0-9_-]{4})[A-Za-z0-9_-]{12,}([A-Za-z0-9_-]{4})`)
|
||||
|
||||
// maskSecrets replaces any embedded bot tokens in s with a redacted placeholder
|
||||
// that keeps the first and last 4 characters of the secret for identification.
|
||||
func maskSecrets(s string) string {
|
||||
return botTokenRe.ReplaceAllString(s, "${1}${2}****${3}")
|
||||
}
|
||||
|
||||
// Logger implements common Logger interface
|
||||
type Logger struct {
|
||||
@@ -12,52 +25,52 @@ type Logger struct {
|
||||
|
||||
// Debug logs debug messages
|
||||
func (b *Logger) Debug(v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Info logs info messages
|
||||
func (b *Logger) Info(v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(INFO, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Warn logs warning messages
|
||||
func (b *Logger) Warn(v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Error logs error messages
|
||||
func (b *Logger) Error(v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprint(v...), nil)
|
||||
logMessage(ERROR, b.component, maskSecrets(fmt.Sprint(v...)), nil)
|
||||
}
|
||||
|
||||
// Debugf logs formatted debug messages
|
||||
func (b *Logger) Debugf(format string, v ...any) {
|
||||
logMessage(DEBUG, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(DEBUG, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Infof logs formatted info messages
|
||||
func (b *Logger) Infof(format string, v ...any) {
|
||||
logMessage(INFO, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(INFO, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Warnf logs formatted warning messages
|
||||
func (b *Logger) Warnf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Warningf logs formatted warning messages
|
||||
func (b *Logger) Warningf(format string, v ...any) {
|
||||
logMessage(WARN, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(WARN, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Errorf logs formatted error messages
|
||||
func (b *Logger) Errorf(format string, v ...any) {
|
||||
logMessage(ERROR, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(ERROR, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Fatalf logs formatted fatal messages and exits
|
||||
func (b *Logger) Fatalf(format string, v ...any) {
|
||||
logMessage(FATAL, b.component, fmt.Sprintf(format, v...), nil)
|
||||
logMessage(FATAL, b.component, maskSecrets(fmt.Sprintf(format, v...)), nil)
|
||||
}
|
||||
|
||||
// Log logs a message at a given level with caller information
|
||||
@@ -75,7 +88,7 @@ func (b *Logger) Log(msgL, caller int, format string, a ...any) {
|
||||
level = lvl
|
||||
}
|
||||
}
|
||||
logMessage(level, b.component, fmt.Sprintf(format, a...), nil)
|
||||
logMessage(level, b.component, maskSecrets(fmt.Sprintf(format, a...)), nil)
|
||||
}
|
||||
|
||||
// Sync flushes log buffer (no-op for this implementation)
|
||||
|
||||
@@ -141,3 +141,114 @@ func TestLoggerHelperFunctions(t *testing.T) {
|
||||
Debugf("test from %v", "Debugf")
|
||||
WarnF("Warning with fields", map[string]any{"key": "value"})
|
||||
}
|
||||
|
||||
func TestFormatFieldValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input any
|
||||
expected string
|
||||
}{
|
||||
// Basic types test (default case of the switch)
|
||||
{
|
||||
name: "Integer Type",
|
||||
input: 42,
|
||||
expected: "42",
|
||||
},
|
||||
{
|
||||
name: "Boolean Type",
|
||||
input: true,
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "Unsupported Struct Type",
|
||||
input: struct{ A int }{A: 1},
|
||||
expected: "{1}",
|
||||
},
|
||||
|
||||
// Simple strings and byte slices test
|
||||
{
|
||||
name: "Simple string without spaces",
|
||||
input: "simple_value",
|
||||
expected: "simple_value",
|
||||
},
|
||||
{
|
||||
name: "Simple byte slice",
|
||||
input: []byte("byte_value"),
|
||||
expected: "byte_value",
|
||||
},
|
||||
|
||||
// Unquoting test (strconv.Unquote)
|
||||
{
|
||||
name: "Quoted string",
|
||||
input: `"quoted_value"`,
|
||||
expected: "quoted_value",
|
||||
},
|
||||
|
||||
// Strings with newline (\n) test
|
||||
{
|
||||
name: "String with newline",
|
||||
input: "line1\nline2",
|
||||
expected: "\nline1\nline2",
|
||||
},
|
||||
{
|
||||
name: "Quoted string with newline (Unquote -> newline)",
|
||||
input: `"line1\nline2"`, // Escaped \n that Unquote will resolve
|
||||
expected: "\nline1\nline2",
|
||||
},
|
||||
|
||||
// Strings with spaces test (which should be quoted)
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "hello world",
|
||||
expected: `"hello world"`,
|
||||
},
|
||||
{
|
||||
name: "Quoted string with spaces (Unquote -> has spaces -> Re-quote)",
|
||||
input: `"hello world"`,
|
||||
expected: `"hello world"`,
|
||||
},
|
||||
|
||||
// JSON formats test (strings with spaces that start/end with brackets)
|
||||
{
|
||||
name: "Valid JSON object",
|
||||
input: `{"key": "value"}`,
|
||||
expected: `{"key": "value"}`,
|
||||
},
|
||||
{
|
||||
name: "Valid JSON array",
|
||||
input: `[1, 2, "three"]`,
|
||||
expected: `[1, 2, "three"]`,
|
||||
},
|
||||
{
|
||||
name: "Fake JSON (starts with { but doesn't end with })",
|
||||
input: `{"key": "value"`, // Missing closing bracket, has spaces
|
||||
expected: `"{\"key\": \"value\""`,
|
||||
},
|
||||
{
|
||||
name: "Empty JSON (object)",
|
||||
input: `{ }`,
|
||||
expected: `{ }`,
|
||||
},
|
||||
|
||||
// 7. Edge Cases
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Whitespace only string",
|
||||
input: " ",
|
||||
expected: `" "`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := formatFieldValue(tt.input)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("formatFieldValue() = %q, expected %q", actual, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package media
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const TempDirName = "picoclaw_media"
|
||||
|
||||
// TempDir returns the shared temporary directory used for downloaded media.
|
||||
func TempDir() string {
|
||||
return filepath.Join(os.TempDir(), TempDirName)
|
||||
}
|
||||
@@ -221,11 +221,17 @@ func buildRequestBody(
|
||||
|
||||
// Add tool_use blocks
|
||||
for _, tc := range msg.ToolCalls {
|
||||
// Handle nil Arguments (GLM-4 may return null input)
|
||||
input := tc.Arguments
|
||||
if input == nil {
|
||||
input = map[string]any{}
|
||||
}
|
||||
|
||||
toolUse := map[string]any{
|
||||
"type": "tool_use",
|
||||
"id": tc.ID,
|
||||
"name": tc.Name,
|
||||
"input": tc.Arguments,
|
||||
"input": input,
|
||||
}
|
||||
content = append(content, toolUse)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,150 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
type (
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
)
|
||||
|
||||
const (
|
||||
// azureAPIVersion is the Azure OpenAI API version used for all requests.
|
||||
azureAPIVersion = "2024-10-21"
|
||||
defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
)
|
||||
|
||||
// Provider implements the LLM provider interface for Azure OpenAI endpoints.
|
||||
// It handles Azure-specific authentication (api-key header), URL construction
|
||||
// (deployment-based), and request body formatting (max_completion_tokens, no model field).
|
||||
type Provider struct {
|
||||
apiKey string
|
||||
apiBase string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// Option configures the Azure Provider.
|
||||
type Option func(*Provider)
|
||||
|
||||
// WithRequestTimeout sets the HTTP request timeout.
|
||||
func WithRequestTimeout(timeout time.Duration) Option {
|
||||
return func(p *Provider) {
|
||||
if timeout > 0 {
|
||||
p.httpClient.Timeout = timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NewProvider creates a new Azure OpenAI provider.
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(p)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// NewProviderWithTimeout creates a new Azure OpenAI provider with a custom request timeout in seconds.
|
||||
func NewProviderWithTimeout(apiKey, apiBase, proxy string, requestTimeoutSeconds int) *Provider {
|
||||
return NewProvider(
|
||||
apiKey, apiBase, proxy,
|
||||
WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
)
|
||||
}
|
||||
|
||||
// Chat sends a chat completion request to the Azure OpenAI endpoint.
|
||||
// The model parameter is used as the Azure deployment name in the URL.
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
tools []ToolDefinition,
|
||||
model string,
|
||||
options map[string]any,
|
||||
) (*LLMResponse, error) {
|
||||
if p.apiBase == "" {
|
||||
return nil, fmt.Errorf("Azure API base not configured")
|
||||
}
|
||||
|
||||
// model is the deployment name for Azure OpenAI
|
||||
deployment := model
|
||||
|
||||
// Build Azure-specific URL safely using url.JoinPath and query encoding
|
||||
// to prevent path traversal or query injection via deployment names.
|
||||
base, err := url.JoinPath(p.apiBase, "openai/deployments", deployment, "chat/completions")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build Azure request URL: %w", err)
|
||||
}
|
||||
requestURL := base + "?api-version=" + azureAPIVersion
|
||||
|
||||
// Build request body — no "model" field (Azure infers from deployment URL)
|
||||
requestBody := map[string]any{
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
requestBody["tools"] = tools
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
// Azure OpenAI always uses max_completion_tokens
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
requestBody["max_completion_tokens"] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
requestBody["temperature"] = temperature
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Azure uses api-key header instead of Authorization: Bearer
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Api-Key", p.apiKey)
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// GetDefaultModel returns an empty string as Azure deployments are user-configured.
|
||||
func (p *Provider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package azure
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// writeValidResponse writes a minimal valid Azure OpenAI chat completion response.
|
||||
func writeValidResponse(w http.ResponseWriter) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureURLConstruction(t *testing.T) {
|
||||
var capturedPath string
|
||||
var capturedAPIVersion string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.Path
|
||||
capturedAPIVersion = r.URL.Query().Get("api-version")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my-gpt5-deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
wantPath := "/openai/deployments/my-gpt5-deployment/chat/completions"
|
||||
if capturedPath != wantPath {
|
||||
t.Errorf("URL path = %q, want %q", capturedPath, wantPath)
|
||||
}
|
||||
if capturedAPIVersion != azureAPIVersion {
|
||||
t.Errorf("api-version = %q, want %q", capturedAPIVersion, azureAPIVersion)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureAuthHeader(t *testing.T) {
|
||||
var capturedAPIKey string
|
||||
var capturedAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedAPIKey = r.Header.Get("Api-Key")
|
||||
capturedAuth = r.Header.Get("Authorization")
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-azure-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if capturedAPIKey != "test-azure-key" {
|
||||
t.Errorf("api-key header = %q, want %q", capturedAPIKey, "test-azure-key")
|
||||
}
|
||||
if capturedAuth != "" {
|
||||
t.Errorf("Authorization header should be empty, got %q", capturedAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureOmitsModelFromBody(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["model"]; exists {
|
||||
t.Error("request body should not contain 'model' field for Azure OpenAI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureUsesMaxCompletionTokens(t *testing.T) {
|
||||
var requestBody map[string]any
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewDecoder(r.Body).Decode(&requestBody)
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"deployment",
|
||||
map[string]any{"max_tokens": 2048},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if _, exists := requestBody["max_completion_tokens"]; !exists {
|
||||
t.Error("request body should contain 'max_completion_tokens'")
|
||||
}
|
||||
if _, exists := requestBody["max_tokens"]; exists {
|
||||
t.Error("request body should not contain 'max_tokens'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureHTTPError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("bad-key", server.URL, "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureParseToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "",
|
||||
"tool_calls": []map[string]any{
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": map[string]any{
|
||||
"name": "get_weather",
|
||||
"arguments": `{"city":"Seattle"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
out, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "weather?"}}, nil, "deployment", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureEmptyAPIBase(t *testing.T) {
|
||||
p := NewProvider("test-key", "", "")
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "deployment", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty API base")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutDefault(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "")
|
||||
if p.httpClient.Timeout != defaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, defaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureRequestTimeoutOverride(t *testing.T) {
|
||||
p := NewProvider("test-key", "https://example.com", "", WithRequestTimeout(300*time.Second))
|
||||
if p.httpClient.Timeout != 300*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 300*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProvider_AzureNewProviderWithTimeout(t *testing.T) {
|
||||
p := NewProviderWithTimeout("test-key", "https://example.com", "", 180)
|
||||
if p.httpClient.Timeout != 180*time.Second {
|
||||
t.Errorf("timeout = %v, want %v", p.httpClient.Timeout, 180*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_AzureDeploymentNameEscaped(t *testing.T) {
|
||||
var capturedPath string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedPath = r.URL.RawPath // use RawPath to see percent-encoding
|
||||
if capturedPath == "" {
|
||||
capturedPath = r.URL.Path
|
||||
}
|
||||
writeValidResponse(w)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider("test-key", server.URL, "")
|
||||
|
||||
// Deployment name with characters that could cause path injection
|
||||
_, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "my deploy/../../admin", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
// The slash and special chars in the deployment name must be escaped, not treated as path separators
|
||||
if capturedPath == "/openai/deployments/my deploy/../../admin/chat/completions" {
|
||||
t.Fatal("deployment name was interpolated without escaping — path injection possible")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,380 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
// Package common provides shared utilities used by multiple LLM provider
|
||||
// implementations (openai_compat, azure, etc.).
|
||||
package common
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// Re-export protocol types used across providers.
|
||||
type (
|
||||
ToolCall = protocoltypes.ToolCall
|
||||
FunctionCall = protocoltypes.FunctionCall
|
||||
LLMResponse = protocoltypes.LLMResponse
|
||||
UsageInfo = protocoltypes.UsageInfo
|
||||
Message = protocoltypes.Message
|
||||
ToolDefinition = protocoltypes.ToolDefinition
|
||||
ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition
|
||||
ExtraContent = protocoltypes.ExtraContent
|
||||
GoogleExtra = protocoltypes.GoogleExtra
|
||||
ReasoningDetail = protocoltypes.ReasoningDetail
|
||||
)
|
||||
|
||||
const DefaultRequestTimeout = 120 * time.Second
|
||||
|
||||
// NewHTTPClient creates an *http.Client with an optional proxy and the default timeout.
|
||||
func NewHTTPClient(proxy string) *http.Client {
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRequestTimeout,
|
||||
}
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
// Preserve http.DefaultTransport settings (TLS, HTTP/2, timeouts, etc.)
|
||||
if base, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
tr := base.Clone()
|
||||
tr.Proxy = http.ProxyURL(parsed)
|
||||
client.Transport = tr
|
||||
} else {
|
||||
// Fallback: minimal transport if DefaultTransport is not *http.Transport.
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Printf("common: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// --- Message serialization ---
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// SerializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func SerializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Response parsing ---
|
||||
|
||||
// ParseResponse parses a JSON chat completion response body into an LLMResponse.
|
||||
func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = DecodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("common: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("common: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// --- HTTP response helpers ---
|
||||
|
||||
// HandleErrorResponse reads a non-200 response body and returns an appropriate error.
|
||||
func HandleErrorResponse(resp *http.Response, apiBase string) error {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if LooksLikeHTML(body, contentType) {
|
||||
return WrapHTMLResponseError(resp.StatusCode, body, contentType, apiBase)
|
||||
}
|
||||
return fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
ResponsePreview(body, 128),
|
||||
)
|
||||
}
|
||||
|
||||
// ReadAndParseResponse peeks at the response body to detect HTML errors,
|
||||
// then parses the JSON response into an LLMResponse.
|
||||
func ReadAndParseResponse(resp *http.Response, apiBase string) (*LLMResponse, error) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256)
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if LooksLikeHTML(prefix, contentType) {
|
||||
return nil, WrapHTMLResponseError(resp.StatusCode, prefix, contentType, apiBase)
|
||||
}
|
||||
out, err := ParseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// LooksLikeHTML checks if the response body appears to be HTML.
|
||||
func LooksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
// WrapHTMLResponseError creates a descriptive error for HTML responses.
|
||||
func WrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := ResponsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
// ResponsePreview returns a truncated preview of response body for error messages.
|
||||
func ResponsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Numeric helpers ---
|
||||
|
||||
// AsInt converts various numeric types to int.
|
||||
func AsInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// AsFloat converts various numeric types to float64.
|
||||
func AsFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
// --- NewHTTPClient tests ---
|
||||
|
||||
func TestNewHTTPClient_DefaultTimeout(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Timeout != DefaultRequestTimeout {
|
||||
t.Errorf("timeout = %v, want %v", client.Timeout, DefaultRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_WithProxy(t *testing.T) {
|
||||
client := NewHTTPClient("http://127.0.0.1:8080")
|
||||
transport, ok := client.Transport.(*http.Transport)
|
||||
if !ok || transport == nil {
|
||||
t.Fatalf("expected http.Transport with proxy, got %T", client.Transport)
|
||||
}
|
||||
req := &http.Request{URL: &url.URL{Scheme: "https", Host: "api.example.com"}}
|
||||
gotProxy, err := transport.Proxy(req)
|
||||
if err != nil {
|
||||
t.Fatalf("proxy function error: %v", err)
|
||||
}
|
||||
if gotProxy == nil || gotProxy.String() != "http://127.0.0.1:8080" {
|
||||
t.Errorf("proxy = %v, want http://127.0.0.1:8080", gotProxy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_NoProxy(t *testing.T) {
|
||||
client := NewHTTPClient("")
|
||||
if client.Transport != nil {
|
||||
t.Errorf("expected nil transport without proxy, got %T", client.Transport)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewHTTPClient_InvalidProxy(t *testing.T) {
|
||||
// Should not panic, just log and return client without proxy
|
||||
client := NewHTTPClient("://bad-url")
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil client even with invalid proxy")
|
||||
}
|
||||
}
|
||||
|
||||
// --- SerializeMessages tests ---
|
||||
|
||||
func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["content"] != "hello" {
|
||||
t.Errorf("expected plain string content, got %v", msgs[0]["content"])
|
||||
}
|
||||
if msgs[1]["reasoning_content"] != "thinking..." {
|
||||
t.Errorf("reasoning_content not preserved, got %v", msgs[1]["reasoning_content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
content, ok := msgs[0]["content"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected array content for media message, got %T", msgs[0]["content"])
|
||||
}
|
||||
if len(content) != 2 {
|
||||
t.Fatalf("expected 2 content parts, got %d", len(content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []Message{
|
||||
{Role: "tool", Content: "result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
json.Unmarshal(data, &msgs)
|
||||
|
||||
if msgs[0]["tool_call_id"] != "call_1" {
|
||||
t.Errorf("tool_call_id not preserved, got %v", msgs[0]["tool_call_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "you are helpful",
|
||||
SystemParts: []protocoltypes.ContentBlock{
|
||||
{Type: "text", Text: "you are helpful"},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
if strings.Contains(string(data), "system_parts") {
|
||||
t.Error("system_parts should not appear in serialized output")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse tests ---
|
||||
|
||||
func TestParseResponse_BasicContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"hello world"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "hello world" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "hello world")
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_EmptyChoices(t *testing.T) {
|
||||
body := `{"choices":[]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "" {
|
||||
t.Errorf("Content = %q, want empty", out.Content)
|
||||
}
|
||||
if out.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", out.FinishReason, "stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithToolCalls(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{\"city\":\"SF\"}"}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].Name != "get_weather" {
|
||||
t.Errorf("ToolCalls[0].Name = %q, want %q", out.ToolCalls[0].Name, "get_weather")
|
||||
}
|
||||
if out.ToolCalls[0].Arguments["city"] != "SF" {
|
||||
t.Errorf("ToolCalls[0].Arguments[city] = %v, want SF", out.ToolCalls[0].Arguments["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithUsage(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Usage == nil {
|
||||
t.Fatal("Usage is nil")
|
||||
}
|
||||
if out.Usage.PromptTokens != 10 {
|
||||
t.Errorf("PromptTokens = %d, want 10", out.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_WithReasoningContent(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"2","reasoning_content":"Let me think... 1+1=2"},"finish_reason":"stop"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if out.ReasoningContent != "Let me think... 1+1=2" {
|
||||
t.Errorf("ReasoningContent = %q, want %q", out.ReasoningContent, "Let me think... 1+1=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseResponse_InvalidJSON(t *testing.T) {
|
||||
_, err := ParseResponse(strings.NewReader("not json"))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- DecodeToolCallArguments tests ---
|
||||
|
||||
func TestDecodeToolCallArguments_ObjectJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`{"city":"Seattle","units":"metric"}`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "Seattle" {
|
||||
t.Errorf("city = %v, want Seattle", args["city"])
|
||||
}
|
||||
if args["units"] != "metric" {
|
||||
t.Errorf("units = %v, want metric", args["units"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
|
||||
raw := json.RawMessage(`"{\"city\":\"SF\"}"`)
|
||||
args := DecodeToolCallArguments(raw, "test")
|
||||
if args["city"] != "SF" {
|
||||
t.Errorf("city = %v, want SF", args["city"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(nil, "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_NullInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`null`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_InvalidJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`not-json`), "test")
|
||||
if _, ok := args["raw"]; !ok {
|
||||
t.Error("expected 'raw' fallback key for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyStringJSON(t *testing.T) {
|
||||
args := DecodeToolCallArguments(json.RawMessage(`" "`), "test")
|
||||
if len(args) != 0 {
|
||||
t.Errorf("expected empty map for whitespace string, got %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse tests ---
|
||||
|
||||
func TestHandleErrorResponse_JSONError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{"error":"bad request"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "400") {
|
||||
t.Errorf("error should contain status code, got %v", err)
|
||||
}
|
||||
if strings.Contains(err.Error(), "HTML") {
|
||||
t.Errorf("should not mention HTML for JSON error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleErrorResponse_HTMLError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>bad gateway</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error message, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse tests ---
|
||||
|
||||
func TestReadAndParseResponse_ValidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, err := ReadAndParseResponse(resp, server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAndParseResponse() error = %v", err)
|
||||
}
|
||||
if out.Content != "ok" {
|
||||
t.Errorf("Content = %q, want %q", out.Content, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadAndParseResponse_HTMLResponse(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Write([]byte("<!DOCTYPE html><html><body>login page</body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for HTML response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- LooksLikeHTML tests ---
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "text/html; charset=utf-8") {
|
||||
t.Error("expected true for text/html content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_ContentTypeXHTML(t *testing.T) {
|
||||
if !LooksLikeHTML(nil, "application/xhtml+xml") {
|
||||
t.Error("expected true for xhtml content type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_BodyPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"doctype", "<!DOCTYPE html><html>"},
|
||||
{"html tag", "<html><body>"},
|
||||
{"head tag", "<head><title>"},
|
||||
{"body tag", "<body>content"},
|
||||
{"whitespace before", " \n\t<!DOCTYPE html>"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !LooksLikeHTML([]byte(tt.body), "application/json") {
|
||||
t.Errorf("expected true for body %q", tt.body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeHTML_NotHTML(t *testing.T) {
|
||||
if LooksLikeHTML([]byte(`{"error":"bad"}`), "application/json") {
|
||||
t.Error("expected false for JSON body")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ResponsePreview tests ---
|
||||
|
||||
func TestResponsePreview_Short(t *testing.T) {
|
||||
got := ResponsePreview([]byte("hello"), 128)
|
||||
if got != "hello" {
|
||||
t.Errorf("got %q, want %q", got, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Truncated(t *testing.T) {
|
||||
body := strings.Repeat("a", 200)
|
||||
got := ResponsePreview([]byte(body), 128)
|
||||
if len(got) != 131 { // 128 + "..."
|
||||
t.Errorf("len = %d, want 131", len(got))
|
||||
}
|
||||
if !strings.HasSuffix(got, "...") {
|
||||
t.Error("expected ... suffix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Empty(t *testing.T) {
|
||||
got := ResponsePreview([]byte(""), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsePreview_Whitespace(t *testing.T) {
|
||||
got := ResponsePreview([]byte(" \n\t "), 128)
|
||||
if got != "<empty>" {
|
||||
t.Errorf("got %q, want %q for whitespace-only body", got, "<empty>")
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsInt tests ---
|
||||
|
||||
func TestAsInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{"int", 42, 42, true},
|
||||
{"int64", int64(99), 99, true},
|
||||
{"float64", float64(512), 512, true},
|
||||
{"float32", float32(256), 256, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsInt(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsInt(%v) = (%d, %v), want (%d, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- AsFloat tests ---
|
||||
|
||||
func TestAsFloat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val any
|
||||
want float64
|
||||
ok bool
|
||||
}{
|
||||
{"float64", float64(0.7), 0.7, true},
|
||||
{"float32", float32(0.5), float64(float32(0.5)), true},
|
||||
{"int", 1, 1.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"string", "nope", 0, false},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := AsFloat(tt.val)
|
||||
if ok != tt.ok || got != tt.want {
|
||||
t.Errorf("AsFloat(%v) = (%f, %v), want (%f, %v)", tt.val, got, ok, tt.want, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- WrapHTMLResponseError tests ---
|
||||
|
||||
func TestWrapHTMLResponseError(t *testing.T) {
|
||||
err := WrapHTMLResponseError(502, []byte("<html>bad</html>"), "text/html", "https://api.example.com")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "502") {
|
||||
t.Errorf("expected status code in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "https://api.example.com") {
|
||||
t.Errorf("expected api base in error, got %v", msg)
|
||||
}
|
||||
if !strings.Contains(msg, "HTML instead of JSON") {
|
||||
t.Errorf("expected HTML mention in error, got %v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// --- HandleErrorResponse with read failure ---
|
||||
|
||||
func TestHandleErrorResponse_EmptyBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
// empty body
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
err = HandleErrorResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "500") {
|
||||
t.Errorf("expected status code, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ReadAndParseResponse with invalid JSON ---
|
||||
|
||||
func TestReadAndParseResponse_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("not valid json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("http.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ReadAndParseResponse(resp, server.URL)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseResponse with thought_signature (Google/Gemini) ---
|
||||
|
||||
func TestParseResponse_WithThoughtSignature(t *testing.T) {
|
||||
body := `{"choices":[{"message":{"content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"test_tool","arguments":"{}"},"extra_content":{"google":{"thought_signature":"sig123"}}}]},"finish_reason":"tool_calls"}]}`
|
||||
out, err := ParseResponse(strings.NewReader(body))
|
||||
if err != nil {
|
||||
t.Fatalf("ParseResponse() error = %v", err)
|
||||
}
|
||||
if len(out.ToolCalls) != 1 {
|
||||
t.Fatalf("len(ToolCalls) = %d, want 1", len(out.ToolCalls))
|
||||
}
|
||||
if out.ToolCalls[0].ThoughtSignature != "sig123" {
|
||||
t.Errorf("ThoughtSignature = %q, want %q", out.ToolCalls[0].ThoughtSignature, "sig123")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent == nil || out.ToolCalls[0].ExtraContent.Google == nil {
|
||||
t.Fatal("ExtraContent.Google is nil")
|
||||
}
|
||||
if out.ToolCalls[0].ExtraContent.Google.ThoughtSignature != "sig123" {
|
||||
t.Errorf("ExtraContent.Google.ThoughtSignature = %q, want %q",
|
||||
out.ToolCalls[0].ExtraContent.Google.ThoughtSignature, "sig123")
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
anthropicmessages "github.com/sipeed/picoclaw/pkg/providers/anthropic_messages"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/azure"
|
||||
)
|
||||
|
||||
// createClaudeAuthProvider creates a Claude provider using OAuth credentials from auth store.
|
||||
@@ -94,6 +95,24 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
// Azure OpenAI uses deployment-based URLs, api-key header auth,
|
||||
// and always sends max_completion_tokens.
|
||||
if cfg.APIKey == "" {
|
||||
return nil, "", fmt.Errorf("api_key is required for azure protocol")
|
||||
}
|
||||
if cfg.APIBase == "" {
|
||||
return nil, "", fmt.Errorf(
|
||||
"api_base is required for azure protocol (e.g., https://your-resource.openai.azure.com)",
|
||||
)
|
||||
}
|
||||
return azure.NewProviderWithTimeout(
|
||||
cfg.APIKey,
|
||||
cfg.APIBase,
|
||||
cfg.Proxy,
|
||||
cfg.RequestTimeout,
|
||||
), modelID, nil
|
||||
|
||||
case "litellm", "openrouter", "groq", "zhipu", "gemini", "nvidia",
|
||||
"ollama", "moonshot", "shengsuanyun", "deepseek", "cerebras",
|
||||
"vivgrid", "volcengine", "vllm", "qwen", "mistral", "avian",
|
||||
|
||||
@@ -64,6 +64,12 @@ func TestExtractProtocol(t *testing.T) {
|
||||
wantProtocol: "nvidia",
|
||||
wantModelID: "meta/llama-3.1-8b",
|
||||
},
|
||||
{
|
||||
name: "azure with prefix",
|
||||
model: "azure/my-gpt5-deployment",
|
||||
wantProtocol: "azure",
|
||||
wantModelID: "my-gpt5-deployment",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -371,3 +377,69 @@ func TestCreateProviderFromConfig_RequestTimeoutPropagation(t *testing.T) {
|
||||
t.Fatalf("Chat() error = %q, want timeout-related error", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_Azure(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-gpt5-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-gpt5-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureOpenAIAlias(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt4",
|
||||
Model: "azure-openai/my-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatal("CreateProviderFromConfig() returned nil provider")
|
||||
}
|
||||
if modelID != "my-deployment" {
|
||||
t.Errorf("modelID = %q, want %q", modelID, "my-deployment")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureMissingAPIKey(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIBase: "https://my-resource.openai.azure.com",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_AzureMissingAPIBase(t *testing.T) {
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "azure-gpt5",
|
||||
Model: "azure/my-gpt5-deployment",
|
||||
APIKey: "test-azure-key",
|
||||
}
|
||||
|
||||
_, _, err := CreateProviderFromConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("CreateProviderFromConfig() expected error for missing API base")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
package openai_compat
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -38,7 +36,7 @@ type Provider struct {
|
||||
|
||||
type Option func(*Provider)
|
||||
|
||||
const defaultRequestTimeout = 120 * time.Second
|
||||
const defaultRequestTimeout = common.DefaultRequestTimeout
|
||||
|
||||
func WithMaxTokensField(maxTokensField string) Option {
|
||||
return func(p *Provider) {
|
||||
@@ -55,25 +53,10 @@ func WithRequestTimeout(timeout time.Duration) Option {
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
client := &http.Client{
|
||||
Timeout: defaultRequestTimeout,
|
||||
}
|
||||
|
||||
if proxy != "" {
|
||||
parsed, err := url.Parse(proxy)
|
||||
if err == nil {
|
||||
client.Transport = &http.Transport{
|
||||
Proxy: http.ProxyURL(parsed),
|
||||
}
|
||||
} else {
|
||||
log.Printf("openai_compat: invalid proxy URL %q: %v", proxy, err)
|
||||
}
|
||||
}
|
||||
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
apiBase: strings.TrimRight(apiBase, "/"),
|
||||
httpClient: client,
|
||||
httpClient: common.NewHTTPClient(proxy),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
@@ -117,7 +100,7 @@ func (p *Provider) Chat(
|
||||
|
||||
requestBody := map[string]any{
|
||||
"model": model,
|
||||
"messages": serializeMessages(messages),
|
||||
"messages": common.SerializeMessages(messages),
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
@@ -125,7 +108,7 @@ func (p *Provider) Chat(
|
||||
requestBody["tool_choice"] = "auto"
|
||||
}
|
||||
|
||||
if maxTokens, ok := asInt(options["max_tokens"]); ok {
|
||||
if maxTokens, ok := common.AsInt(options["max_tokens"]); ok {
|
||||
// Use configured maxTokensField if specified, otherwise fallback to model-based detection
|
||||
fieldName := p.maxTokensField
|
||||
if fieldName == "" {
|
||||
@@ -141,7 +124,7 @@ func (p *Provider) Chat(
|
||||
requestBody[fieldName] = maxTokens
|
||||
}
|
||||
|
||||
if temperature, ok := asFloat(options["temperature"]); ok {
|
||||
if temperature, ok := common.AsFloat(options["temperature"]); ok {
|
||||
lowerModel := strings.ToLower(model)
|
||||
// Kimi k2 models only support temperature=1.
|
||||
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
|
||||
@@ -185,275 +168,11 @@ func (p *Provider) Chat(
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
|
||||
// Non-200: read a prefix to tell HTML error page apart from JSON error body.
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, readErr := io.ReadAll(io.LimitReader(resp.Body, 256))
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", readErr)
|
||||
}
|
||||
if looksLikeHTML(body, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, body, contentType, p.apiBase)
|
||||
}
|
||||
return nil, fmt.Errorf(
|
||||
"API request failed:\n Status: %d\n Body: %s",
|
||||
resp.StatusCode,
|
||||
responsePreview(body, 128),
|
||||
)
|
||||
return nil, common.HandleErrorResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
// Peek without consuming so the full stream reaches the JSON decoder.
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
prefix, err := reader.Peek(256) // io.EOF/ErrBufferFull are normal; only real errors abort
|
||||
if err != nil && err != io.EOF && err != bufio.ErrBufferFull {
|
||||
return nil, fmt.Errorf("failed to inspect response: %w", err)
|
||||
}
|
||||
if looksLikeHTML(prefix, contentType) {
|
||||
return nil, wrapHTMLResponseError(resp.StatusCode, prefix, contentType, p.apiBase)
|
||||
}
|
||||
|
||||
out, err := parseResponse(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func wrapHTMLResponseError(statusCode int, body []byte, contentType, apiBase string) error {
|
||||
respPreview := responsePreview(body, 128)
|
||||
return fmt.Errorf(
|
||||
"API request failed: %s returned HTML instead of JSON (content-type: %s); check api_base or proxy configuration.\n Status: %d\n Body: %s",
|
||||
apiBase,
|
||||
contentType,
|
||||
statusCode,
|
||||
respPreview,
|
||||
)
|
||||
}
|
||||
|
||||
func looksLikeHTML(body []byte, contentType string) bool {
|
||||
contentType = strings.ToLower(strings.TrimSpace(contentType))
|
||||
if strings.Contains(contentType, "text/html") || strings.Contains(contentType, "application/xhtml+xml") {
|
||||
return true
|
||||
}
|
||||
prefix := bytes.ToLower(leadingTrimmedPrefix(body, 128))
|
||||
return bytes.HasPrefix(prefix, []byte("<!doctype html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<html")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<head")) ||
|
||||
bytes.HasPrefix(prefix, []byte("<body"))
|
||||
}
|
||||
|
||||
func leadingTrimmedPrefix(body []byte, maxLen int) []byte {
|
||||
i := 0
|
||||
for i < len(body) {
|
||||
switch body[i] {
|
||||
case ' ', '\t', '\n', '\r', '\f', '\v':
|
||||
i++
|
||||
default:
|
||||
end := i + maxLen
|
||||
if end > len(body) {
|
||||
end = len(body)
|
||||
}
|
||||
return body[i:end]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func responsePreview(body []byte, maxLen int) string {
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
if len(trimmed) <= maxLen {
|
||||
return string(trimmed)
|
||||
}
|
||||
return string(trimmed[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func parseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
var apiResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
Reasoning string `json:"reasoning"`
|
||||
ReasoningDetails []ReasoningDetail `json:"reasoning_details"`
|
||||
ToolCalls []struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function *struct {
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
} `json:"function"`
|
||||
ExtraContent *struct {
|
||||
Google *struct {
|
||||
ThoughtSignature string `json:"thought_signature"`
|
||||
} `json:"google"`
|
||||
} `json:"extra_content"`
|
||||
} `json:"tool_calls"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage *UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(body).Decode(&apiResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
if len(apiResponse.Choices) == 0 {
|
||||
return &LLMResponse{
|
||||
Content: "",
|
||||
FinishReason: "stop",
|
||||
}, nil
|
||||
}
|
||||
|
||||
choice := apiResponse.Choices[0]
|
||||
toolCalls := make([]ToolCall, 0, len(choice.Message.ToolCalls))
|
||||
for _, tc := range choice.Message.ToolCalls {
|
||||
arguments := make(map[string]any)
|
||||
name := ""
|
||||
|
||||
// Extract thought_signature from Gemini/Google-specific extra content
|
||||
thoughtSignature := ""
|
||||
if tc.ExtraContent != nil && tc.ExtraContent.Google != nil {
|
||||
thoughtSignature = tc.ExtraContent.Google.ThoughtSignature
|
||||
}
|
||||
|
||||
if tc.Function != nil {
|
||||
name = tc.Function.Name
|
||||
arguments = decodeToolCallArguments(tc.Function.Arguments, name)
|
||||
}
|
||||
|
||||
// Build ToolCall with ExtraContent for Gemini 3 thought_signature persistence
|
||||
toolCall := ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: name,
|
||||
Arguments: arguments,
|
||||
ThoughtSignature: thoughtSignature,
|
||||
}
|
||||
|
||||
if thoughtSignature != "" {
|
||||
toolCall.ExtraContent = &ExtraContent{
|
||||
Google: &GoogleExtra{
|
||||
ThoughtSignature: thoughtSignature,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
return &LLMResponse{
|
||||
Content: choice.Message.Content,
|
||||
ReasoningContent: choice.Message.ReasoningContent,
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
raw = bytes.TrimSpace(raw)
|
||||
if len(raw) == 0 || bytes.Equal(raw, []byte("null")) {
|
||||
return arguments
|
||||
}
|
||||
|
||||
var decoded any
|
||||
if err := json.Unmarshal(raw, &decoded); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments payload for %q: %v", name, err)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
|
||||
switch v := decoded.(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return arguments
|
||||
}
|
||||
if err := json.Unmarshal([]byte(v), &arguments); err != nil {
|
||||
log.Printf("openai_compat: failed to decode tool call arguments for %q: %v", name, err)
|
||||
arguments["raw"] = v
|
||||
}
|
||||
return arguments
|
||||
case map[string]any:
|
||||
return v
|
||||
default:
|
||||
log.Printf("openai_compat: unsupported tool call arguments type for %q: %T", name, decoded)
|
||||
arguments["raw"] = string(raw)
|
||||
return arguments
|
||||
}
|
||||
}
|
||||
|
||||
// openaiMessage is the wire-format message for OpenAI-compatible APIs.
|
||||
// It mirrors protocoltypes.Message but omits SystemParts, which is an
|
||||
// internal field that would be unknown to third-party endpoints.
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
// serializeMessages converts internal Message structs to the OpenAI wire format.
|
||||
// - Strips SystemParts (unknown to third-party endpoints)
|
||||
// - Converts messages with Media to multipart content format (text + image_url parts)
|
||||
// - Preserves ToolCallID, ToolCalls, and ReasoningContent for all messages
|
||||
func serializeMessages(messages []Message) []any {
|
||||
out := make([]any, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
if len(m.Media) == 0 {
|
||||
out = append(out, openaiMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Multipart content format for messages with media
|
||||
parts := make([]map[string]any, 0, 1+len(m.Media))
|
||||
if m.Content != "" {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, mediaURL := range m.Media {
|
||||
if strings.HasPrefix(mediaURL, "data:image/") {
|
||||
parts = append(parts, map[string]any{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]any{
|
||||
"url": mediaURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
msg := map[string]any{
|
||||
"role": m.Role,
|
||||
"content": parts,
|
||||
}
|
||||
if m.ToolCallID != "" {
|
||||
msg["tool_call_id"] = m.ToolCallID
|
||||
}
|
||||
if len(m.ToolCalls) > 0 {
|
||||
msg["tool_calls"] = m.ToolCalls
|
||||
}
|
||||
if m.ReasoningContent != "" {
|
||||
msg["reasoning_content"] = m.ReasoningContent
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
return common.ReadAndParseResponse(resp, p.apiBase)
|
||||
}
|
||||
|
||||
func normalizeModel(model, apiBase string) string {
|
||||
@@ -476,36 +195,6 @@ func normalizeModel(model, apiBase string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func asInt(v any) (int, bool) {
|
||||
switch val := v.(type) {
|
||||
case int:
|
||||
return val, true
|
||||
case int64:
|
||||
return int(val), true
|
||||
case float64:
|
||||
return int(val), true
|
||||
case float32:
|
||||
return int(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func asFloat(v any) (float64, bool) {
|
||||
switch val := v.(type) {
|
||||
case float64:
|
||||
return val, true
|
||||
case float32:
|
||||
return float64(val), true
|
||||
case int:
|
||||
return float64(val), true
|
||||
case int64:
|
||||
return float64(val), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// supportsPromptCacheKey reports whether the given API base is known to
|
||||
// support the prompt_cache_key request field. Currently only OpenAI's own
|
||||
// API and Azure OpenAI support this. All other OpenAI-compatible providers
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers/common"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
)
|
||||
|
||||
@@ -648,7 +649,7 @@ func TestSerializeMessages_PlainText(t *testing.T) {
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "hi", ReasoningContent: "thinking..."},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
@@ -670,7 +671,7 @@ func TestSerializeMessages_WithMedia(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "user", Content: "describe this", Media: []string{"data:image/png;base64,abc123"}},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
@@ -703,7 +704,7 @@ func TestSerializeMessages_MediaWithToolCallID(t *testing.T) {
|
||||
messages := []protocoltypes.Message{
|
||||
{Role: "tool", Content: "image result", Media: []string{"data:image/png;base64,xyz"}, ToolCallID: "call_1"},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
var msgs []map[string]any
|
||||
@@ -833,7 +834,7 @@ func TestSerializeMessages_StripsSystemParts(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
result := serializeMessages(messages)
|
||||
result := common.SerializeMessages(messages)
|
||||
|
||||
data, _ := json.Marshal(result)
|
||||
raw := string(data)
|
||||
|
||||
+53
-20
@@ -20,10 +20,12 @@ type JobExecutor interface {
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
execTool *ExecTool
|
||||
allowCommand bool
|
||||
execEnabled bool
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
@@ -32,17 +34,32 @@ func NewCronTool(
|
||||
cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool,
|
||||
execTimeout time.Duration, config *config.Config,
|
||||
) (*CronTool, error) {
|
||||
execTool, err := NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
allowCommand := true
|
||||
execEnabled := true
|
||||
if config != nil {
|
||||
allowCommand = config.Tools.Cron.AllowCommand
|
||||
execEnabled = config.Tools.Exec.Enabled
|
||||
}
|
||||
|
||||
execTool.SetTimeout(execTimeout)
|
||||
var execTool *ExecTool
|
||||
if execEnabled {
|
||||
var err error
|
||||
execTool, err = NewExecToolWithConfig(workspace, restrict, config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to configure exec tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if execTool != nil {
|
||||
execTool.SetTimeout(execTimeout)
|
||||
}
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: execTool,
|
||||
allowCommand: allowCommand,
|
||||
execEnabled: execEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,7 +93,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"command_confirm": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Required when using command=true. Must be true to explicitly confirm scheduling a shell command.",
|
||||
"description": "Optional explicit confirmation flag for scheduling a shell command. Command execution must also be enabled via tools.cron.allow_command.",
|
||||
},
|
||||
"at_seconds": map[string]any{
|
||||
"type": "integer",
|
||||
@@ -96,7 +113,7 @@ func (t *CronTool) Parameters() map[string]any {
|
||||
},
|
||||
"deliver": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: false",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -174,22 +191,26 @@ func (t *CronTool) addJob(ctx context.Context, args map[string]any) *ToolResult
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
// Read deliver parameter, default to false so scheduled tasks execute through the agent
|
||||
deliver := false
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel + explicit confirm.
|
||||
// Non-command reminders (plain messages) remain open to all channels.
|
||||
// GHSA-pv8c-p6jf-3fpp: command scheduling requires internal channel. When
|
||||
// allow_command is disabled, explicit confirmation is required as an override.
|
||||
// Non-command reminders remain open to all channels.
|
||||
command, _ := args["command"].(string)
|
||||
commandConfirm, _ := args["command_confirm"].(bool)
|
||||
if command != "" {
|
||||
if !t.execEnabled {
|
||||
return ErrorResult("command execution is disabled")
|
||||
}
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
return ErrorResult("scheduling command execution is restricted to internal channels")
|
||||
}
|
||||
if !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required to schedule command execution")
|
||||
if !t.allowCommand && !commandConfirm {
|
||||
return ErrorResult("command_confirm=true is required when allow_command is disabled")
|
||||
}
|
||||
deliver = false
|
||||
}
|
||||
@@ -290,6 +311,18 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// Execute command if present
|
||||
if job.Payload.Command != "" {
|
||||
if !t.execEnabled || t.execTool == nil {
|
||||
output := "Error executing scheduled command: command execution is disabled"
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
args := map[string]any{
|
||||
"command": job.Payload.Command,
|
||||
"__channel": channel,
|
||||
|
||||
+126
-6
@@ -5,18 +5,18 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
)
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
func newTestCronToolWithConfig(t *testing.T, cfg *config.Config) *CronTool {
|
||||
t.Helper()
|
||||
storePath := filepath.Join(t.TempDir(), "cron.json")
|
||||
cronService := cron.NewCronService(storePath, nil)
|
||||
msgBus := bus.NewMessageBus()
|
||||
cfg := config.DefaultConfig()
|
||||
tool, err := NewCronTool(cronService, nil, msgBus, t.TempDir(), true, 0, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCronTool() error: %v", err)
|
||||
@@ -24,6 +24,11 @@ func newTestCronTool(t *testing.T) *CronTool {
|
||||
return tool
|
||||
}
|
||||
|
||||
func newTestCronTool(t *testing.T) *CronTool {
|
||||
t.Helper()
|
||||
return newTestCronToolWithConfig(t, config.DefaultConfig())
|
||||
}
|
||||
|
||||
// TestCronTool_CommandBlockedFromRemoteChannel verifies command scheduling is restricted to internal channels
|
||||
func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
@@ -44,8 +49,7 @@ func TestCronTool_CommandBlockedFromRemoteChannel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCronTool_CommandRequiresConfirm verifies command_confirm=true is required
|
||||
func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
func TestCronTool_CommandDoesNotRequireConfirmByDefault(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
@@ -55,11 +59,79 @@ func TestCronTool_CommandRequiresConfirm(t *testing.T) {
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected command scheduling without confirm to succeed by default, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandRequiresConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected error when command_confirm is missing")
|
||||
t.Fatal("expected command scheduling to require confirm when allow_command is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command_confirm=true") {
|
||||
t.Errorf("expected 'command_confirm=true' message, got: %s", result.ForLLM)
|
||||
t.Errorf("expected command_confirm requirement message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandAllowedWithConfirmWhenAllowCommandDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Cron.AllowCommand = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf(
|
||||
"expected command scheduling with confirm to succeed when allow_command is disabled, got: %s",
|
||||
result.ForLLM,
|
||||
)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Cron job added") {
|
||||
t.Errorf("expected 'Cron job added', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_CommandBlockedWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
ctx := WithToolContext(context.Background(), "cli", "direct")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "check disk",
|
||||
"command": "df -h",
|
||||
"command_confirm": true,
|
||||
"at_seconds": float64(60),
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatal("expected command scheduling to be blocked when exec is disabled")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "command execution is disabled") {
|
||||
t.Errorf("expected exec disabled message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,3 +186,51 @@ func TestCronTool_NonCommandJobAllowedFromRemoteChannel(t *testing.T) {
|
||||
t.Fatalf("expected non-command reminder to succeed from remote channel, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_NonCommandJobDefaultsDeliverToFalse(t *testing.T) {
|
||||
tool := newTestCronTool(t)
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"action": "add",
|
||||
"message": "send me a poem",
|
||||
"at_seconds": float64(600),
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("expected non-command reminder to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
jobs := tool.cronService.ListJobs(false)
|
||||
if len(jobs) != 1 {
|
||||
t.Fatalf("expected 1 job, got %d", len(jobs))
|
||||
}
|
||||
if jobs[0].Payload.Deliver {
|
||||
t.Fatal("expected deliver=false by default for non-command jobs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCronTool_ExecuteJobPublishesErrorWhenExecDisabled(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Tools.Exec.Enabled = false
|
||||
|
||||
tool := newTestCronToolWithConfig(t, cfg)
|
||||
job := &cron.CronJob{}
|
||||
job.Payload.Channel = "cli"
|
||||
job.Payload.To = "direct"
|
||||
job.Payload.Command = "df -h"
|
||||
|
||||
if got := tool.ExecuteJob(context.Background(), job); got != "ok" {
|
||||
t.Fatalf("ExecuteJob() = %q, want ok", got)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
msg, ok := tool.msgBus.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("expected outbound message")
|
||||
}
|
||||
if !strings.Contains(msg.Content, "command execution is disabled") {
|
||||
t.Fatalf("expected exec disabled message, got: %s", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
+161
-9
@@ -20,8 +20,7 @@ import (
|
||||
|
||||
const MaxReadFileSize = 64 * 1024 // 64KB limit to avoid context overflow
|
||||
|
||||
// validatePath ensures the given path is within the workspace if restrict is true.
|
||||
func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
func validatePathWithAllowPaths(path, workspace string, restrict bool, patterns []*regexp.Regexp) (string, error) {
|
||||
if workspace == "" {
|
||||
return path, fmt.Errorf("workspace is not defined")
|
||||
}
|
||||
@@ -42,6 +41,10 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
|
||||
if restrict {
|
||||
if isAllowedPath(absPath, patterns) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
@@ -73,6 +76,137 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func isAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
if len(patterns) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
cleaned := filepath.Clean(path)
|
||||
if !filepath.IsAbs(cleaned) {
|
||||
return false
|
||||
}
|
||||
if !matchesAllowedPath(cleaned, patterns) {
|
||||
return false
|
||||
}
|
||||
|
||||
resolved, err := resolvePathAgainstExistingAncestor(cleaned)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return matchesAllowedPath(resolved, patterns)
|
||||
}
|
||||
|
||||
func matchesAllowedPath(path string, patterns []*regexp.Regexp) bool {
|
||||
cleaned := filepath.Clean(path)
|
||||
for _, pattern := range patterns {
|
||||
if pattern.MatchString(cleaned) {
|
||||
return true
|
||||
}
|
||||
if root, ok := extractAllowedPathRoot(pattern); ok && isWithinAllowedRoot(cleaned, root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func extractAllowedPathRoot(pattern *regexp.Regexp) (string, bool) {
|
||||
raw := pattern.String()
|
||||
if !strings.HasPrefix(raw, "^") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
literal := strings.TrimPrefix(raw, "^")
|
||||
|
||||
// Recognize the common "directory prefix" form: ^<literal>(?:/|$)
|
||||
literal = strings.TrimSuffix(literal, "(?:/|$)")
|
||||
literal = strings.TrimSuffix(literal, `(?:\\|$)`)
|
||||
|
||||
// Reject patterns that still contain regex operators after removing the
|
||||
// optional anchored-directory suffix. That keeps arbitrary regex behavior
|
||||
// unchanged and only enables normalized prefix matching for literal paths.
|
||||
if containsUnescapedRegexMeta(literal) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
unescaped, ok := unescapeRegexLiteral(literal)
|
||||
if !ok || unescaped == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return filepath.Clean(unescaped), filepath.IsAbs(unescaped)
|
||||
}
|
||||
|
||||
func appendUniquePath(paths []string, path string) []string {
|
||||
for _, existing := range paths {
|
||||
if existing == path {
|
||||
return paths
|
||||
}
|
||||
}
|
||||
return append(paths, path)
|
||||
}
|
||||
|
||||
func containsUnescapedRegexMeta(s string) bool {
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '.', '+', '*', '?', '(', ')', '[', ']', '{', '}', '|':
|
||||
return true
|
||||
}
|
||||
}
|
||||
return escaped
|
||||
}
|
||||
|
||||
func unescapeRegexLiteral(s string) (string, bool) {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
|
||||
escaped := false
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
b.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
|
||||
if escaped {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func isWithinAllowedRoot(path, root string) bool {
|
||||
candidate := filepath.Clean(path)
|
||||
allowedVariants := []string{filepath.Clean(root)}
|
||||
|
||||
if resolvedRoot, err := resolvePathAgainstExistingAncestor(root); err == nil {
|
||||
allowedVariants = appendUniquePath(allowedVariants, filepath.Clean(resolvedRoot))
|
||||
}
|
||||
|
||||
for _, allowedRoot := range allowedVariants {
|
||||
if isWithinWorkspace(candidate, allowedRoot) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
@@ -86,9 +220,32 @@ func resolveExistingAncestor(path string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func resolvePathAgainstExistingAncestor(path string) (string, error) {
|
||||
cleaned := filepath.Clean(path)
|
||||
for current := cleaned; ; current = filepath.Dir(current) {
|
||||
resolved, err := filepath.EvalSymlinks(current)
|
||||
if err == nil {
|
||||
suffix, relErr := filepath.Rel(current, cleaned)
|
||||
if relErr != nil {
|
||||
return "", relErr
|
||||
}
|
||||
if suffix == "." {
|
||||
return filepath.Clean(resolved), nil
|
||||
}
|
||||
return filepath.Clean(filepath.Join(resolved, suffix)), nil
|
||||
}
|
||||
if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && filepath.IsLocal(rel)
|
||||
return err == nil && (rel == "." || filepath.IsLocal(rel))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
@@ -625,12 +782,7 @@ type whitelistFs struct {
|
||||
}
|
||||
|
||||
func (w *whitelistFs) matches(path string) bool {
|
||||
for _, p := range w.patterns {
|
||||
if p.MatchString(path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return isAllowedPath(path, w.patterns)
|
||||
}
|
||||
|
||||
func (w *whitelistFs) ReadFile(path string) ([]byte, error) {
|
||||
|
||||
@@ -521,6 +521,90 @@ func TestWhitelistFs_AllowsMatchingPaths(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_BlocksSymlinkEscapeInAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
allowedDir := t.TempDir()
|
||||
secretDir := t.TempDir()
|
||||
secretFile := filepath.Join(secretDir, "secret.txt")
|
||||
if err := os.WriteFile(secretFile, []byte("top secret"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(secretFile) error = %v", err)
|
||||
}
|
||||
|
||||
linkPath := filepath.Join(allowedDir, "link_out")
|
||||
if err := os.Symlink(secretDir, linkPath); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": filepath.Join(linkPath, "secret.txt")})
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape from allowed dir to be blocked, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_WriteAllowsNewFileUnderAllowedDir(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
rootDir := t.TempDir()
|
||||
allowedDir := filepath.Join(rootDir, "allowed")
|
||||
targetFile := filepath.Join(allowedDir, "nested", "file.txt")
|
||||
|
||||
patterns := []*regexp.Regexp{regexp.MustCompile(`^` + regexp.QuoteMeta(allowedDir))}
|
||||
tool := NewWriteFileTool(workspace, true, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{
|
||||
"path": targetFile,
|
||||
"content": "outside write",
|
||||
})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted write to succeed, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(targetFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile(targetFile) error = %v", err)
|
||||
}
|
||||
if string(data) != "outside write" {
|
||||
t.Fatalf("target file content = %q, want %q", string(data), "outside write")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhitelistFs_AllowsResolvedAllowedRootAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
realDir := t.TempDir()
|
||||
linkParent := t.TempDir()
|
||||
allowedAlias := filepath.Join(linkParent, "allowed-link")
|
||||
|
||||
if err := os.Symlink(realDir, allowedAlias); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
targetFile := filepath.Join(allowedAlias, "nested", "alias.txt")
|
||||
if err := os.MkdirAll(filepath.Dir(targetFile), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(targetFile dir) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(targetFile, []byte("through alias"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(targetFile) error = %v", err)
|
||||
}
|
||||
|
||||
patterns := []*regexp.Regexp{
|
||||
regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(allowedAlias)) +
|
||||
"(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
),
|
||||
}
|
||||
tool := NewReadFileTool(workspace, true, MaxReadFileSize, patterns)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": targetFile})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected symlink-backed allowed root to be readable, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "through alias") {
|
||||
t.Fatalf("expected file content, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadFileTool_ChunkedReading verifies the pagination logic of the tool
|
||||
// by reading a file in multiple chunks using 'offset' and 'length'.
|
||||
func TestReadFileTool_ChunkedReading(t *testing.T) {
|
||||
|
||||
+15
-2
@@ -6,6 +6,7 @@ import (
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/h2non/filetype"
|
||||
@@ -21,20 +22,32 @@ type SendFileTool struct {
|
||||
restrict bool
|
||||
maxFileSize int
|
||||
mediaStore media.MediaStore
|
||||
allowPaths []*regexp.Regexp
|
||||
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
}
|
||||
|
||||
func NewSendFileTool(workspace string, restrict bool, maxFileSize int, store media.MediaStore) *SendFileTool {
|
||||
func NewSendFileTool(
|
||||
workspace string,
|
||||
restrict bool,
|
||||
maxFileSize int,
|
||||
store media.MediaStore,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) *SendFileTool {
|
||||
if maxFileSize <= 0 {
|
||||
maxFileSize = config.DefaultMaxMediaSize
|
||||
}
|
||||
var patterns []*regexp.Regexp
|
||||
if len(allowPaths) > 0 {
|
||||
patterns = allowPaths[0]
|
||||
}
|
||||
return &SendFileTool{
|
||||
workspace: workspace,
|
||||
restrict: restrict,
|
||||
maxFileSize: maxFileSize,
|
||||
mediaStore: store,
|
||||
allowPaths: patterns,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,7 +105,7 @@ func (t *SendFileTool) Execute(ctx context.Context, args map[string]any) *ToolRe
|
||||
return ErrorResult("media store not configured")
|
||||
}
|
||||
|
||||
resolved, err := validatePath(path, t.workspace, t.restrict)
|
||||
resolved, err := validatePathWithAllowPaths(path, t.workspace, t.restrict, t.allowPaths)
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("invalid path: %v", err))
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -128,6 +129,44 @@ func TestSendFileTool_CustomFilename(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileTool_AllowsWhitelistedMediaTempPath(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(mediaDir) error = %v", err)
|
||||
}
|
||||
|
||||
testFile, err := os.CreateTemp(mediaDir, "send-file-*.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateTemp(mediaDir) error = %v", err)
|
||||
}
|
||||
testPath := testFile.Name()
|
||||
if _, err := testFile.WriteString("forward me"); err != nil {
|
||||
testFile.Close()
|
||||
t.Fatalf("WriteString(testFile) error = %v", err)
|
||||
}
|
||||
if err := testFile.Close(); err != nil {
|
||||
t.Fatalf("Close(testFile) error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = os.Remove(testPath) })
|
||||
|
||||
pattern := regexp.MustCompile(
|
||||
"^" + regexp.QuoteMeta(filepath.Clean(mediaDir)) + "(?:" + regexp.QuoteMeta(string(os.PathSeparator)) + "|$)",
|
||||
)
|
||||
|
||||
store := media.NewFileMediaStore()
|
||||
tool := NewSendFileTool(workspace, true, 0, store, []*regexp.Regexp{pattern})
|
||||
tool.SetContext("feishu", "chat123")
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"path": testPath})
|
||||
if result.IsError {
|
||||
t.Fatalf("expected whitelisted temp media file to be sendable, got: %s", result.ForLLM)
|
||||
}
|
||||
if len(result.Media) != 1 {
|
||||
t.Fatalf("expected 1 media ref, got %d", len(result.Media))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType_MagicBytes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
|
||||
+31
-13
@@ -23,6 +23,7 @@ type ExecTool struct {
|
||||
denyPatterns []*regexp.Regexp
|
||||
allowPatterns []*regexp.Regexp
|
||||
customAllowPatterns []*regexp.Regexp
|
||||
allowedPathPatterns []*regexp.Regexp
|
||||
restrictToWorkspace bool
|
||||
allowRemote bool
|
||||
}
|
||||
@@ -95,14 +96,23 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func NewExecTool(workingDir string, restrict bool) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil)
|
||||
func NewExecTool(workingDir string, restrict bool, allowPaths ...[]*regexp.Regexp) (*ExecTool, error) {
|
||||
return NewExecToolWithConfig(workingDir, restrict, nil, allowPaths...)
|
||||
}
|
||||
|
||||
func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Config) (*ExecTool, error) {
|
||||
func NewExecToolWithConfig(
|
||||
workingDir string,
|
||||
restrict bool,
|
||||
config *config.Config,
|
||||
allowPaths ...[]*regexp.Regexp,
|
||||
) (*ExecTool, error) {
|
||||
denyPatterns := make([]*regexp.Regexp, 0)
|
||||
customAllowPatterns := make([]*regexp.Regexp, 0)
|
||||
var allowedPathPatterns []*regexp.Regexp
|
||||
allowRemote := true
|
||||
if len(allowPaths) > 0 {
|
||||
allowedPathPatterns = allowPaths[0]
|
||||
}
|
||||
|
||||
if config != nil {
|
||||
execConfig := config.Tools.Exec
|
||||
@@ -146,6 +156,7 @@ func NewExecToolWithConfig(workingDir string, restrict bool, config *config.Conf
|
||||
denyPatterns: denyPatterns,
|
||||
allowPatterns: nil,
|
||||
customAllowPatterns: customAllowPatterns,
|
||||
allowedPathPatterns: allowedPathPatterns,
|
||||
restrictToWorkspace: restrict,
|
||||
allowRemote: allowRemote,
|
||||
}, nil
|
||||
@@ -198,7 +209,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
cwd := t.workingDir
|
||||
if wd, ok := args["working_dir"].(string); ok && wd != "" {
|
||||
if t.restrictToWorkspace && t.workingDir != "" {
|
||||
resolvedWD, err := validatePath(wd, t.workingDir, true)
|
||||
resolvedWD, err := validatePathWithAllowPaths(wd, t.workingDir, true, t.allowedPathPatterns)
|
||||
if err != nil {
|
||||
return ErrorResult("Command blocked by safety guard (" + err.Error() + ")")
|
||||
}
|
||||
@@ -226,16 +237,20 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]any) *ToolResult
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Command blocked by safety guard (path resolution failed: %v)", err))
|
||||
}
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
if isAllowedPath(resolved, t.allowedPathPatterns) {
|
||||
cwd = resolved
|
||||
} else {
|
||||
absWorkspace, _ := filepath.Abs(t.workingDir)
|
||||
wsResolved, _ := filepath.EvalSymlinks(absWorkspace)
|
||||
if wsResolved == "" {
|
||||
wsResolved = absWorkspace
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
rel, err := filepath.Rel(wsResolved, resolved)
|
||||
if err != nil || !filepath.IsLocal(rel) {
|
||||
return ErrorResult("Command blocked by safety guard (working directory escaped workspace)")
|
||||
}
|
||||
cwd = resolved
|
||||
}
|
||||
|
||||
// timeout == 0 means no timeout
|
||||
@@ -412,6 +427,9 @@ func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
if safePaths[p] {
|
||||
continue
|
||||
}
|
||||
if isAllowedPath(p, t.allowedPathPatterns) {
|
||||
continue
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(cwdPath, p)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SpawnStatusTool reports the status of subagents that were spawned via the
|
||||
// spawn tool. It can query a specific task by ID, or list every known task with
|
||||
// a summary count broken-down by status.
|
||||
type SpawnStatusTool struct {
|
||||
manager *SubagentManager
|
||||
}
|
||||
|
||||
// NewSpawnStatusTool creates a SpawnStatusTool backed by the given manager.
|
||||
func NewSpawnStatusTool(manager *SubagentManager) *SpawnStatusTool {
|
||||
return &SpawnStatusTool{manager: manager}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Name() string {
|
||||
return "spawn_status"
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Description() string {
|
||||
return "Get the status of spawned subagents. " +
|
||||
"Returns a list of all subagents and their current state " +
|
||||
"(running, completed, failed, or canceled), or retrieves details " +
|
||||
"for a specific subagent task when task_id is provided. " +
|
||||
"Results are scoped to the current conversation's channel and chat ID; " +
|
||||
"all tasks are listed only when no channel/chat context is injected " +
|
||||
"(e.g. direct programmatic calls via Execute)."
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"task_id": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional task ID (e.g. \"subagent-1\") to inspect a specific " +
|
||||
"subagent. When omitted, all visible subagents are listed.",
|
||||
},
|
||||
},
|
||||
"required": []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SpawnStatusTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
// Derive the calling conversation's identity so we can scope results to the
|
||||
// current chat only — preventing cross-conversation task leakage in
|
||||
// multi-user deployments.
|
||||
callerChannel := ToolChannel(ctx)
|
||||
callerChatID := ToolChatID(ctx)
|
||||
|
||||
var taskID string
|
||||
if rawTaskID, ok := args["task_id"]; ok && rawTaskID != nil {
|
||||
taskIDStr, ok := rawTaskID.(string)
|
||||
if !ok {
|
||||
return ErrorResult("task_id must be a string")
|
||||
}
|
||||
taskID = strings.TrimSpace(taskIDStr)
|
||||
}
|
||||
|
||||
if taskID != "" {
|
||||
// GetTaskCopy returns a consistent snapshot under the manager lock,
|
||||
// eliminating any data race with the concurrent subagent goroutine.
|
||||
taskCopy, ok := t.manager.GetTaskCopy(taskID)
|
||||
if !ok {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
// Restrict lookup to tasks that belong to this conversation.
|
||||
if callerChannel != "" && taskCopy.OriginChannel != "" && taskCopy.OriginChannel != callerChannel {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
if callerChatID != "" && taskCopy.OriginChatID != "" && taskCopy.OriginChatID != callerChatID {
|
||||
return ErrorResult(fmt.Sprintf("No subagent found with task ID: %s", taskID))
|
||||
}
|
||||
|
||||
return NewToolResult(spawnStatusFormatTask(&taskCopy))
|
||||
}
|
||||
|
||||
// ListTaskCopies returns consistent snapshots under the manager lock.
|
||||
origTasks := t.manager.ListTaskCopies()
|
||||
if len(origTasks) == 0 {
|
||||
return NewToolResult("No subagents have been spawned yet.")
|
||||
}
|
||||
|
||||
tasks := make([]*SubagentTask, 0, len(origTasks))
|
||||
for i := range origTasks {
|
||||
cpy := &origTasks[i]
|
||||
|
||||
// Filter to tasks that originate from the current conversation only.
|
||||
if callerChannel != "" && cpy.OriginChannel != "" && cpy.OriginChannel != callerChannel {
|
||||
continue
|
||||
}
|
||||
if callerChatID != "" && cpy.OriginChatID != "" && cpy.OriginChatID != callerChatID {
|
||||
continue
|
||||
}
|
||||
|
||||
tasks = append(tasks, cpy)
|
||||
}
|
||||
|
||||
if len(tasks) == 0 {
|
||||
return NewToolResult("No subagents found for this conversation.")
|
||||
}
|
||||
|
||||
// Order by creation time (ascending) so spawning order is preserved.
|
||||
// Fall back to ID string for tasks created in the same millisecond.
|
||||
sort.Slice(tasks, func(i, j int) bool {
|
||||
if tasks[i].Created != tasks[j].Created {
|
||||
return tasks[i].Created < tasks[j].Created
|
||||
}
|
||||
return tasks[i].ID < tasks[j].ID
|
||||
})
|
||||
|
||||
counts := map[string]int{}
|
||||
for _, task := range tasks {
|
||||
counts[task.Status]++
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Subagent status report (%d total):\n", len(tasks)))
|
||||
for _, status := range []string{"running", "completed", "failed", "canceled"} {
|
||||
if n := counts[status]; n > 0 {
|
||||
label := strings.ToUpper(status[:1]) + status[1:] + ":"
|
||||
sb.WriteString(fmt.Sprintf(" %-10s %d\n", label, n))
|
||||
}
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, task := range tasks {
|
||||
sb.WriteString(spawnStatusFormatTask(task))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
return NewToolResult(strings.TrimRight(sb.String(), "\n"))
|
||||
}
|
||||
|
||||
// spawnStatusFormatTask renders a single SubagentTask as a human-readable block.
|
||||
func spawnStatusFormatTask(task *SubagentTask) string {
|
||||
var sb strings.Builder
|
||||
|
||||
header := fmt.Sprintf("[%s] status=%s", task.ID, task.Status)
|
||||
if task.Label != "" {
|
||||
header += fmt.Sprintf(" label=%q", task.Label)
|
||||
}
|
||||
if task.AgentID != "" {
|
||||
header += fmt.Sprintf(" agent=%s", task.AgentID)
|
||||
}
|
||||
if task.Created > 0 {
|
||||
created := time.UnixMilli(task.Created).UTC().Format("2006-01-02 15:04:05 UTC")
|
||||
header += fmt.Sprintf(" created=%s", created)
|
||||
}
|
||||
sb.WriteString(header)
|
||||
|
||||
if task.Task != "" {
|
||||
sb.WriteString(fmt.Sprintf("\n task: %s", task.Task))
|
||||
}
|
||||
if task.Result != "" {
|
||||
result := task.Result
|
||||
const maxResultLen = 300
|
||||
runes := []rune(result)
|
||||
if len(runes) > maxResultLen {
|
||||
result = string(runes[:maxResultLen]) + "…"
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("\n result: %s", result))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSpawnStatusTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
if tool.Name() != "spawn_status" {
|
||||
t.Errorf("Expected name 'spawn_status', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(desc), "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
props, ok := params["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("Expected 'properties' to be a map")
|
||||
}
|
||||
if _, hasTaskID := props["task_id"]; !hasTaskID {
|
||||
t.Error("Expected 'task_id' parameter in properties")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_NilManager(t *testing.T) {
|
||||
tool := &SpawnStatusTool{manager: nil}
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if !result.IsError {
|
||||
t.Error("Expected error result when manager is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_Empty(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "No subagents") {
|
||||
t.Errorf("Expected 'No subagents' message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
workspace := t.TempDir()
|
||||
manager := NewSubagentManager(provider, "test-model", workspace)
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Do task A",
|
||||
Label: "task-a",
|
||||
Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2",
|
||||
Task: "Do task B",
|
||||
Label: "task-b",
|
||||
Status: "completed",
|
||||
Result: "Done successfully",
|
||||
Created: now,
|
||||
}
|
||||
manager.tasks["subagent-3"] = &SubagentTask{
|
||||
ID: "subagent-3",
|
||||
Task: "Do task C",
|
||||
Status: "failed",
|
||||
Result: "Error: something went wrong",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Summary header
|
||||
if !strings.Contains(result.ForLLM, "3 total") {
|
||||
t.Errorf("Expected total count in header, got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Individual task IDs
|
||||
for _, id := range []string{"subagent-1", "subagent-2", "subagent-3"} {
|
||||
if !strings.Contains(result.ForLLM, id) {
|
||||
t.Errorf("Expected task %s in output, got:\n%s", id, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Status values
|
||||
for _, status := range []string{"running", "completed", "failed"} {
|
||||
if !strings.Contains(result.ForLLM, status) {
|
||||
t.Errorf("Expected status '%s' in output, got:\n%s", status, result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Result content
|
||||
if !strings.Contains(result.ForLLM, "Done successfully") {
|
||||
t.Errorf("Expected result text in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-42"] = &SubagentTask{
|
||||
ID: "subagent-42",
|
||||
Task: "Specific task",
|
||||
Label: "my-task",
|
||||
Status: "failed",
|
||||
Result: "Something went wrong",
|
||||
Created: time.Now().UnixMilli(),
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-42"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-42") {
|
||||
t.Errorf("Expected task ID in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "failed") {
|
||||
t.Errorf("Expected status 'failed' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Something went wrong") {
|
||||
t.Errorf("Expected result text in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "my-task") {
|
||||
t.Errorf("Expected label in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_GetByID_NotFound(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "nonexistent-999"})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for nonexistent task, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "nonexistent-999") {
|
||||
t.Errorf("Expected task ID in error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_TaskID_NonString(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
for _, badVal := range []any{42, 3.14, true, map[string]any{"x": 1}, []string{"a"}} {
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": badVal})
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for task_id=%T(%v), got success: %s", badVal, badVal, result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "task_id must be a string") {
|
||||
t.Errorf("Expected type-error message, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
longResult := strings.Repeat("X", 500)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Long task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// Output should be shorter than the raw result due to truncation
|
||||
if len(result.ForLLM) >= len(longResult) {
|
||||
t.Errorf("Expected result to be truncated, but ForLLM is %d chars", len(result.ForLLM))
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator '…' in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ResultTruncation_Unicode(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
// Each CJK rune is 3 bytes; 400 runes = 1200 bytes — well over the 300-rune limit.
|
||||
cjkChar := string(rune(0x5b57))
|
||||
longResult := strings.Repeat(cjkChar, 400)
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1",
|
||||
Task: "Unicode task",
|
||||
Status: "completed",
|
||||
Result: longResult,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{"task_id": "subagent-1"})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "…") {
|
||||
t.Errorf("Expected truncation indicator in output")
|
||||
}
|
||||
// The truncated result must be valid UTF-8 (no split rune boundaries).
|
||||
if !strings.Contains(result.ForLLM, cjkChar) {
|
||||
t.Errorf("Expected CJK runes to appear intact in output")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_StatusCounts(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
for i, status := range []string{"running", "running", "completed", "failed", "canceled"} {
|
||||
id := fmt.Sprintf("subagent-%d", i+1)
|
||||
manager.tasks[id] = &SubagentTask{ID: id, Task: "t", Status: status}
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
// The summary line should mention all statuses that have counts
|
||||
for _, want := range []string{"Running:", "Completed:", "Failed:", "Canceled:"} {
|
||||
if !strings.Contains(result.ForLLM, want) {
|
||||
t.Errorf("Expected %q in summary, got:\n%s", want, result.ForLLM)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_SortByCreatedTimestamp(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
manager.mu.Lock()
|
||||
// Intentionally insert with out-of-order IDs and timestamps that reflect
|
||||
// true spawn order: subagent-2 was spawned first, subagent-10 second.
|
||||
manager.tasks["subagent-10"] = &SubagentTask{
|
||||
ID: "subagent-10", Task: "second", Status: "running",
|
||||
Created: now + 1,
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "first", Status: "running",
|
||||
Created: now,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
pos2 := strings.Index(result.ForLLM, "subagent-2")
|
||||
pos10 := strings.Index(result.ForLLM, "subagent-10")
|
||||
if pos2 < 0 || pos10 < 0 {
|
||||
t.Fatalf("Both task IDs should appear in output:\n%s", result.ForLLM)
|
||||
}
|
||||
if pos2 > pos10 {
|
||||
t.Errorf("Expected subagent-2 (created first) to appear before subagent-10, but got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_ListAll(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "mine", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.tasks["subagent-2"] = &SubagentTask{
|
||||
ID: "subagent-2", Task: "other user", Status: "running",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-B",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Caller is chat-A — should only see subagent-1.
|
||||
ctx := WithToolContext(context.Background(), "telegram", "chat-A")
|
||||
result := tool.Execute(ctx, map[string]any{})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected own task in output, got:\n%s", result.ForLLM)
|
||||
}
|
||||
if strings.Contains(result.ForLLM, "subagent-2") {
|
||||
t.Errorf("Should NOT see other chat's task, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_GetByID(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-99"] = &SubagentTask{
|
||||
ID: "subagent-99", Task: "secret", Status: "completed", Result: "private data",
|
||||
OriginChannel: "slack", OriginChatID: "room-Z",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// Different chat trying to look up subagent-99 by ID.
|
||||
ctx := WithToolContext(context.Background(), "slack", "room-OTHER")
|
||||
result := tool.Execute(ctx, map[string]any{"task_id": "subagent-99"})
|
||||
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error (cross-chat lookup blocked), got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnStatusTool_ChannelFiltering_NoContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.tasks["subagent-1"] = &SubagentTask{
|
||||
ID: "subagent-1", Task: "t", Status: "completed",
|
||||
OriginChannel: "telegram", OriginChatID: "chat-A",
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
tool := NewSpawnStatusTool(manager)
|
||||
|
||||
// No ToolContext injected (e.g. a direct programmatic call that bypasses
|
||||
// WithToolContext entirely) — callerChannel and callerChatID are both "".
|
||||
// Note: the normal CLI path uses ProcessDirectWithChannel("cli", "direct"),
|
||||
// which *does* inject a non-empty context; this test covers the case where
|
||||
// no context injection happens at all.
|
||||
// The filter conditions require a non-empty caller value, so all tasks pass through.
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
if result.IsError {
|
||||
t.Fatalf("Unexpected error: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subagent-1") {
|
||||
t.Errorf("Expected task visible from no-context caller, got:\n%s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -255,6 +255,18 @@ func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
return task, ok
|
||||
}
|
||||
|
||||
// GetTaskCopy returns a copy of the task with the given ID, taken under the
|
||||
// read lock, so the caller receives a consistent snapshot with no data race.
|
||||
func (sm *SubagentManager) GetTaskCopy(taskID string) (SubagentTask, bool) {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
task, ok := sm.tasks[taskID]
|
||||
if !ok {
|
||||
return SubagentTask{}, false
|
||||
}
|
||||
return *task, true
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
@@ -266,6 +278,19 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
return tasks
|
||||
}
|
||||
|
||||
// ListTaskCopies returns value copies of all tasks, taken under the read lock,
|
||||
// so callers receive consistent snapshots with no data race.
|
||||
func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
|
||||
copies := make([]SubagentTask, 0, len(sm.tasks))
|
||||
for _, task := range sm.tasks {
|
||||
copies = append(copies, *task)
|
||||
}
|
||||
return copies
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
|
||||
type SubagentTool struct {
|
||||
|
||||
+2
-1
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
)
|
||||
|
||||
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
|
||||
@@ -67,7 +68,7 @@ func DownloadFile(urlStr, filename string, opts DownloadOptions) string {
|
||||
opts.LoggerPrefix = "utils"
|
||||
}
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
mediaDir := media.TempDir()
|
||||
if err := os.MkdirAll(mediaDir, 0o700); err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]any{
|
||||
"error": err.Error(),
|
||||
|
||||
+59
-4
@@ -1,5 +1,60 @@
|
||||
.PHONY: dev dev-frontend dev-backend build test lint clean
|
||||
|
||||
# Go variables
|
||||
GO?=CGO_ENABLED=0 go
|
||||
WEB_GO?=$(GO)
|
||||
GOFLAGS?=-v -tags stdjson
|
||||
|
||||
# Version
|
||||
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
|
||||
BUILD_TIME=$(shell date +%FT%T%z)
|
||||
GO_VERSION=$(shell $(WEB_GO) version | awk '{print $$3}')
|
||||
CONFIG_PKG=github.com/sipeed/picoclaw/pkg/config
|
||||
LDFLAGS=-X $(CONFIG_PKG).Version=$(VERSION) -X $(CONFIG_PKG).GitCommit=$(GIT_COMMIT) -X $(CONFIG_PKG).BuildTime=$(BUILD_TIME) -X $(CONFIG_PKG).GoVersion=$(GO_VERSION) -s -w
|
||||
|
||||
|
||||
# OS detection
|
||||
UNAME_S:=$(shell uname -s)
|
||||
UNAME_M:=$(shell uname -m)
|
||||
|
||||
# Platform-specific settings
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
PLATFORM=linux
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ARCH=amd64
|
||||
else ifeq ($(UNAME_M),aarch64)
|
||||
ARCH=arm64
|
||||
else ifeq ($(UNAME_M),armv81)
|
||||
ARCH=arm64
|
||||
else ifeq ($(UNAME_M),loongarch64)
|
||||
ARCH=loong64
|
||||
else ifeq ($(UNAME_M),riscv64)
|
||||
ARCH=riscv64
|
||||
else ifeq ($(UNAME_M),mipsel)
|
||||
ARCH=mipsle
|
||||
else
|
||||
ARCH=$(UNAME_M)
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Darwin)
|
||||
PLATFORM=darwin
|
||||
WEB_GO=CGO_ENABLED=1 go
|
||||
ifeq ($(UNAME_M),x86_64)
|
||||
ARCH=amd64
|
||||
else ifeq ($(UNAME_M),arm64)
|
||||
ARCH=arm64
|
||||
else
|
||||
ARCH=$(UNAME_M)
|
||||
endif
|
||||
else ifeq ($(UNAME_S),Windows)
|
||||
PLATFORM=windows
|
||||
ARCH=$(UNAME_M)
|
||||
LDFLAGS=-H=windowsgui $(LDFLAGS)
|
||||
else
|
||||
PLATFORM=$(UNAME_S)
|
||||
ARCH=$(UNAME_M)
|
||||
endif
|
||||
|
||||
# Run both frontend and backend dev servers
|
||||
dev:
|
||||
@if [ ! -f backend/picoclaw-web ] || [ ! -d backend/dist ]; then \
|
||||
@@ -15,21 +70,21 @@ dev-frontend:
|
||||
|
||||
# Start backend dev server
|
||||
dev-backend:
|
||||
cd backend && go run .
|
||||
cd backend && ${WEB_GO} run -ldflags "$(LDFLAGS)" .
|
||||
|
||||
# Build frontend and embed into Go binary
|
||||
build:
|
||||
cd frontend && pnpm build:backend
|
||||
cd backend && go build -o picoclaw-web .
|
||||
cd backend && ${WEB_GO} build $(GOFLAGS) -ldflags "$(LDFLAGS)" -o picoclaw-web .
|
||||
|
||||
# Run all tests
|
||||
test:
|
||||
cd backend && go test ./...
|
||||
cd backend && ${WEB_GO} test ./...
|
||||
cd frontend && pnpm lint
|
||||
|
||||
# Lint and format
|
||||
lint:
|
||||
cd backend && go vet ./...
|
||||
cd backend && ${WEB_GO} vet ./...
|
||||
cd frontend && pnpm check
|
||||
|
||||
# Clean build artifacts
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -188,6 +189,27 @@ func validateConfig(cfg *config.Config) []string {
|
||||
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
|
||||
}
|
||||
|
||||
if cfg.Tools.Exec.Enabled {
|
||||
if cfg.Tools.Exec.EnableDenyPatterns {
|
||||
errs = append(
|
||||
errs,
|
||||
validateRegexPatterns("tools.exec.custom_deny_patterns", cfg.Tools.Exec.CustomDenyPatterns)...)
|
||||
}
|
||||
errs = append(
|
||||
errs,
|
||||
validateRegexPatterns("tools.exec.custom_allow_patterns", cfg.Tools.Exec.CustomAllowPatterns)...)
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func validateRegexPatterns(field string, patterns []string) []string {
|
||||
var errs []string
|
||||
for index, pattern := range patterns {
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("%s[%d] is not a valid regular expression: %v", field, index, err))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
|
||||
@@ -86,3 +86,82 @@ func TestHandleUpdateConfig_DoesNotInheritDefaultModelFields(t *testing.T) {
|
||||
t.Fatalf("model_list[0].api_base = %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePatchConfig_RejectsInvalidExecRegexPatterns(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"custom_deny_patterns": ["("]
|
||||
}
|
||||
}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
if !bytes.Contains(rec.Body.Bytes(), []byte("custom_deny_patterns")) {
|
||||
t.Fatalf("expected validation error mentioning custom_deny_patterns, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePatchConfig_AllowsInvalidExecRegexPatternsWhenExecDisabled(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enabled": false,
|
||||
"custom_deny_patterns": ["("],
|
||||
"custom_allow_patterns": ["("]
|
||||
}
|
||||
}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePatchConfig_AllowsInvalidDenyRegexPatternsWhenDenyPatternsDisabled(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enabled": true,
|
||||
"enable_deny_patterns": false,
|
||||
"custom_deny_patterns": ["("]
|
||||
}
|
||||
}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// GatewayEvent represents a state change event for the gateway process.
|
||||
type GatewayEvent struct {
|
||||
Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
|
||||
PID int `json:"pid,omitempty"`
|
||||
BootDefaultModel string `json:"boot_default_model,omitempty"`
|
||||
ConfigDefaultModel string `json:"config_default_model,omitempty"`
|
||||
RestartRequired bool `json:"gateway_restart_required,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
|
||||
type EventBroadcaster struct {
|
||||
mu sync.RWMutex
|
||||
clients map[chan string]struct{}
|
||||
}
|
||||
|
||||
// NewEventBroadcaster creates a new broadcaster.
|
||||
func NewEventBroadcaster() *EventBroadcaster {
|
||||
return &EventBroadcaster{
|
||||
clients: make(map[chan string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new listener channel and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (b *EventBroadcaster) Subscribe() chan string {
|
||||
ch := make(chan string, 8)
|
||||
b.mu.Lock()
|
||||
b.clients[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a listener channel and closes it.
|
||||
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
|
||||
b.mu.Lock()
|
||||
delete(b.clients, ch)
|
||||
b.mu.Unlock()
|
||||
close(ch)
|
||||
}
|
||||
|
||||
// Broadcast sends a GatewayEvent to all connected SSE clients.
|
||||
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
for ch := range b.clients {
|
||||
// Non-blocking send; drop event if client is slow
|
||||
select {
|
||||
case ch <- string(data):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
+250
-216
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
@@ -29,11 +31,9 @@ var gateway = struct {
|
||||
runtimeStatus string
|
||||
startupDeadline time.Time
|
||||
logs *LogBuffer
|
||||
events *EventBroadcaster
|
||||
}{
|
||||
runtimeStatus: "stopped",
|
||||
logs: NewLogBuffer(200),
|
||||
events: NewEventBroadcaster(),
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -48,10 +48,38 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
port := 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
|
||||
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
|
||||
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
|
||||
func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
resp, err := gatewayHealthGet(url, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var healthResponse health.StatusResponse
|
||||
if decErr := json.NewDecoder(resp.Body).Decode(&healthResponse); decErr != nil {
|
||||
return nil, resp.StatusCode, decErr
|
||||
}
|
||||
|
||||
return &healthResponse, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
|
||||
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
|
||||
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
|
||||
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
|
||||
@@ -62,12 +90,35 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
// TryAutoStartGateway checks whether gateway start preconditions are met and
|
||||
// starts it when possible. Intended to be called by the backend at startup.
|
||||
func (h *Handler) TryAutoStartGateway() {
|
||||
// Check if gateway is already running via health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
log.Printf("Skip auto-starting gateway: %v", err)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
log.Printf("Skip auto-starting gateway: %s", reason)
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
if err != nil {
|
||||
log.Printf("Failed to attach to running gateway (PID: %d): %v", pid, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
if isGatewayProcessAliveLocked() {
|
||||
return
|
||||
}
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
gateway.cmd = nil
|
||||
}
|
||||
@@ -82,7 +133,7 @@ func (h *Handler) TryAutoStartGateway() {
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked("starting")
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
log.Printf("Failed to auto-start gateway: %v", err)
|
||||
return
|
||||
@@ -125,8 +176,14 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
|
||||
return modelCfg
|
||||
}
|
||||
|
||||
func isGatewayProcessAliveLocked() bool {
|
||||
return isCmdProcessAliveLocked(gateway.cmd)
|
||||
func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
|
||||
if gatewayStatus != "running" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
|
||||
return false
|
||||
}
|
||||
return configDefaultModel != bootDefaultModel
|
||||
}
|
||||
|
||||
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
@@ -157,7 +214,29 @@ func setGatewayRuntimeStatusLocked(status string) {
|
||||
gateway.startupDeadline = time.Time{}
|
||||
}
|
||||
|
||||
func gatewayStatusOnHealthFailureLocked() string {
|
||||
// attachToGatewayProcess attaches to an existing gateway process by PID
|
||||
// and updates the gateway state accordingly.
|
||||
// Assumes gateway.mu is held by the caller.
|
||||
func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find process for PID %d: %w", pid, err)
|
||||
}
|
||||
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
// Update bootDefaultModel from config
|
||||
if cfg != nil {
|
||||
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
gateway.bootDefaultModel = defaultModelName
|
||||
}
|
||||
|
||||
log.Printf("Attached to gateway process (PID: %d)", pid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func gatewayStatusWithoutHealthLocked() string {
|
||||
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return gateway.runtimeStatus
|
||||
@@ -170,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string {
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
|
||||
func currentGatewayStatusLocked(processAlive bool) string {
|
||||
if !processAlive {
|
||||
if gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return "restarting"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "stopped"
|
||||
}
|
||||
return gatewayStatusOnHealthFailureLocked()
|
||||
return "stopped"
|
||||
}
|
||||
|
||||
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
|
||||
@@ -238,24 +301,32 @@ func stopGatewayProcessForRestart(cmd *exec.Cmd) error {
|
||||
return fmt.Errorf("existing gateway did not exit before restart")
|
||||
}
|
||||
|
||||
func gatewayRestartRequired(status, bootDefaultModel, configDefaultModel string) bool {
|
||||
return status == "running" &&
|
||||
bootDefaultModel != "" &&
|
||||
configDefaultModel != "" &&
|
||||
bootDefaultModel != configDefaultModel
|
||||
}
|
||||
|
||||
func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
|
||||
var cmd *exec.Cmd
|
||||
var pid int
|
||||
|
||||
if existingPid > 0 {
|
||||
// Attach to existing process
|
||||
pid = existingPid
|
||||
gateway.cmd = nil // Clear first to ensure clean state
|
||||
if err = attachToGatewayProcessLocked(pid, cfg); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// Start new process
|
||||
// Locate the picoclaw executable
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
|
||||
cmd := exec.Command(execPath, "gateway")
|
||||
cmd = exec.Command(execPath, "gateway", "-E")
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
@@ -281,7 +352,7 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
if _, err := h.ensurePicoChannel(); err != nil {
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
log.Printf("Warning: failed to ensure pico channel: %v", err)
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
}
|
||||
@@ -293,18 +364,9 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = defaultModelName
|
||||
setGatewayRuntimeStatusLocked(initialStatus)
|
||||
pid := cmd.Process.Pid
|
||||
pid = cmd.Process.Pid
|
||||
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
|
||||
|
||||
// Broadcast the launch state immediately so clients can reflect it without polling.
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: initialStatus,
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
|
||||
// Capture stdout/stderr in background
|
||||
go scanPipe(stdoutPipe, gateway.logs)
|
||||
go scanPipe(stderrPipe, gateway.logs)
|
||||
@@ -318,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
shouldBroadcastStopped := false
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
if gateway.runtimeStatus != "restarting" {
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
shouldBroadcastStopped = true
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if shouldBroadcastStopped {
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "stopped",
|
||||
RestartRequired: false,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// Start a goroutine to probe health and broadcast "running" once ready
|
||||
// Start a goroutine to probe health and update the runtime state once ready.
|
||||
go func() {
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -351,30 +404,15 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
healthPort := cfg.Gateway.Port
|
||||
if healthPort == 0 {
|
||||
healthPort = 18790
|
||||
}
|
||||
healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort)))
|
||||
resp, err := gatewayHealthGet(healthURL, 1*time.Second)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "running",
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
return
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
|
||||
// Verify the health endpoint returns the expected pid
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -386,19 +424,54 @@ func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
|
||||
//
|
||||
// POST /api/gateway/start
|
||||
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent duplicate starts by checking health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
gateway.mu.Unlock()
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
gateway.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
gateway.mu.Unlock()
|
||||
if err != nil {
|
||||
log.Printf("Failed to attach to running gateway (PID: %d): %v", pid, err)
|
||||
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
// Prevent duplicate starts
|
||||
if isGatewayProcessAliveLocked() {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "already_running",
|
||||
"pid": gateway.cmd.Process.Pid,
|
||||
})
|
||||
return
|
||||
}
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
gateway.cmd = nil
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
@@ -423,7 +496,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked("starting")
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -475,36 +548,21 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// handleGatewayRestart stops the gateway (if running) and starts a new instance.
|
||||
//
|
||||
// POST /api/gateway/restart
|
||||
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
// RestartGateway restarts the gateway process. This is a non-blocking operation
|
||||
// that stops the current gateway (if running) and starts a new one.
|
||||
// Returns the PID of the new gateway process or an error.
|
||||
func (h *Handler) RestartGateway() (int, error) {
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
return 0, fmt.Errorf("failed to validate gateway start conditions: %w", err)
|
||||
}
|
||||
if !ready {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": reason,
|
||||
})
|
||||
return
|
||||
return 0, &preconditionFailedError{reason: reason}
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
previousCmd := gateway.cmd
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "restarting",
|
||||
RestartRequired: false,
|
||||
})
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
|
||||
@@ -519,8 +577,7 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
return 0, fmt.Errorf("failed to stop gateway: %w", err)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
@@ -528,7 +585,7 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
}
|
||||
pid, err := h.startGatewayLocked("restarting")
|
||||
pid, err := h.startGatewayLocked("restarting", 0)
|
||||
if err != nil {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
@@ -536,6 +593,43 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to start gateway: %w", err)
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// preconditionFailedError is returned when gateway restart preconditions are not met
|
||||
type preconditionFailedError struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *preconditionFailedError) Error() string {
|
||||
return e.reason
|
||||
}
|
||||
|
||||
// IsBadRequest returns true if the error should result in a 400 Bad Request status
|
||||
func (e *preconditionFailedError) IsBadRequest() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// handleGatewayRestart stops the gateway (if running) and starts a new instance.
|
||||
//
|
||||
// POST /api/gateway/restart
|
||||
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
pid, err := h.RestartGateway()
|
||||
if err != nil {
|
||||
// Check if it's a precondition failed error
|
||||
var precondErr *preconditionFailedError
|
||||
if errors.As(err, &precondErr) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": precondErr.reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -572,8 +666,8 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data := map[string]any{}
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
configDefaultModel := ""
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
if configDefaultModel != "" {
|
||||
@@ -581,74 +675,59 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// Check process state
|
||||
gateway.mu.Lock()
|
||||
processAlive := isGatewayProcessAliveLocked()
|
||||
bootDefaultModel := ""
|
||||
if processAlive {
|
||||
data["pid"] = gateway.cmd.Process.Pid
|
||||
if gateway.bootDefaultModel != "" {
|
||||
data["boot_default_model"] = gateway.bootDefaultModel
|
||||
bootDefaultModel = gateway.bootDefaultModel
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if !processAlive {
|
||||
// Probe health endpoint to get pid and status
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = currentGatewayStatusLocked(false)
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
gateway.mu.Unlock()
|
||||
log.Printf("Gateway health check failed: %v", err)
|
||||
} else {
|
||||
// Process is alive — probe its health endpoint
|
||||
host := "127.0.0.1"
|
||||
port := 18790
|
||||
if cfgErr == nil && cfg != nil {
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
resp, err := gatewayHealthGet(url, 2*time.Second)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Gateway health status: %d", statusCode)
|
||||
if statusCode != http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = currentGatewayStatusLocked(true)
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = statusCode
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = resp.StatusCode
|
||||
} else {
|
||||
var healthData map[string]any
|
||||
if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
for k, v := range healthData {
|
||||
data[k] = v
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
|
||||
oldPid := "none"
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
|
||||
}
|
||||
log.Printf(
|
||||
"Detected gateway PID from health (old: %s, new: %d), attempting to attach",
|
||||
oldPid,
|
||||
healthResp.Pid,
|
||||
)
|
||||
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
|
||||
log.Printf(
|
||||
"Failed to attach to gateway process reported by health (PID: %d): %v",
|
||||
healthResp.Pid,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
status, _ := data["gateway_status"].(string)
|
||||
bootDefaultModel, _ := data["boot_default_model"].(string)
|
||||
gatewayStatus, _ := data["gateway_status"].(string)
|
||||
data["gateway_restart_required"] = gatewayRestartRequired(
|
||||
status,
|
||||
bootDefaultModel,
|
||||
configDefaultModel,
|
||||
bootDefaultModel,
|
||||
gatewayStatus,
|
||||
)
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
@@ -719,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any {
|
||||
return data
|
||||
}
|
||||
|
||||
// handleGatewayEvents serves an SSE stream of gateway state change events.
|
||||
//
|
||||
// GET /api/gateway/events
|
||||
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "SSE not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Subscribe to gateway events
|
||||
ch := gateway.events.Subscribe()
|
||||
defer gateway.events.Unsubscribe(ch)
|
||||
|
||||
// Send initial status so the client doesn't start blank
|
||||
initial := h.currentGatewayStatus()
|
||||
fmt.Fprintf(w, "data: %s\n\n", initial)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// currentGatewayStatus returns the current gateway status as a JSON string.
|
||||
func (h *Handler) currentGatewayStatus() string {
|
||||
data := h.gatewayStatusData()
|
||||
encoded, _ := json.Marshal(data)
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
|
||||
func scanPipe(r io.Reader, buf *LogBuffer) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
|
||||
return bindHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayProxyURL() *url.URL {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
port := 18790
|
||||
bindHost := ""
|
||||
if err == nil && cfg != nil {
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
bindHost = h.effectiveGatewayBindHost(cfg)
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
|
||||
}
|
||||
}
|
||||
|
||||
func requestHostName(r *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err == nil {
|
||||
@@ -57,10 +75,34 @@ func requestHostName(r *http.Request) string {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
func requestWSScheme(r *http.Request) string {
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
|
||||
if proto == "https" || proto == "wss" {
|
||||
return "wss"
|
||||
}
|
||||
if proto == "http" || proto == "ws" {
|
||||
return "ws"
|
||||
}
|
||||
}
|
||||
|
||||
if r.TLS != nil {
|
||||
return "wss"
|
||||
}
|
||||
|
||||
return "ws"
|
||||
}
|
||||
|
||||
func (h *Handler) buildWsURL(r *http.Request, cfg *config.Config) string {
|
||||
host := h.effectiveGatewayBindHost(cfg)
|
||||
if host == "" || host == "0.0.0.0" {
|
||||
host = requestHostName(r)
|
||||
}
|
||||
return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
|
||||
// Use web server port instead of gateway port to avoid exposing extra ports
|
||||
// The WebSocket connection will be proxied by the backend to the gateway
|
||||
wsPort := h.serverPort
|
||||
if wsPort == 0 {
|
||||
wsPort = 18800 // default web server port
|
||||
}
|
||||
return requestWSScheme(r) + "://" + net.JoinHostPort(host, strconv.Itoa(wsPort)) + "/pico/ws"
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
@@ -47,8 +51,8 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req.Host = "192.168.1.9:18800"
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18790/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18790/pico/ws")
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,3 +61,128 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
|
||||
t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://192.168.1.10:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://127.0.0.1:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req.Host = "chat.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "wss://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "wss://secure.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
|
||||
req.Host = "chat.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
+129
-54
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
@@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["pid"]; got != float64(cmd.Process.Pid) {
|
||||
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != false {
|
||||
t.Fatalf("gateway_restart_required = %#v, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
|
||||
ModelName: "second-model",
|
||||
Model: "openai/gpt-4.1",
|
||||
APIKey: "second-key",
|
||||
})
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
updatedCfg.Agents.Defaults.ModelName = "second-model"
|
||||
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
|
||||
t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
|
||||
}
|
||||
if got := body["config_default_model"]; got != "second-model" {
|
||||
t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
@@ -494,60 +623,6 @@ func TestGatewayStatusReturnsRestartingDuringRestartGap(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusIncludesRestartRequiredWhenModelsDiffer(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = "previous-model"
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
rec := httptest.NewRecorder()
|
||||
rec.WriteHeader(http.StatusOK)
|
||||
_, _ = rec.WriteString(`{"ok":true}`)
|
||||
return rec.Result(), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayRestartKeepsRunningProcessWhenPreconditionsFail(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
|
||||
+37
-12
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -16,6 +17,30 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
|
||||
|
||||
// WebSocket proxy: forward /pico/ws to gateway
|
||||
// This allows the frontend to connect via the same port as the web UI,
|
||||
// avoiding the need to expose extra ports for WebSocket communication.
|
||||
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
|
||||
}
|
||||
|
||||
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
|
||||
// The gateway bind host and port are resolved from the latest configuration.
|
||||
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
|
||||
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
return wsProxy
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// The reverse proxy forwards the incoming upgrade handshake as-is.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy := h.createWsProxy()
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetPicoToken returns the current WS token and URL for the frontend.
|
||||
@@ -65,9 +90,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// ensurePicoChannel checks if the Pico Channel is properly configured and
|
||||
// enables it with sensible defaults if not. Returns true if config was changed.
|
||||
func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
// ensurePicoChannel enables the Pico channel with sane defaults if it isn't
|
||||
// already configured. Returns true when the config was modified.
|
||||
//
|
||||
// callerOrigin is the Origin header from the setup request. If non-empty and
|
||||
// no origins are configured yet, it's written as the allowed origin so the
|
||||
// WebSocket handshake works for whatever host the caller is on (LAN, custom
|
||||
// port, etc.). Pass "" when there's no request context.
|
||||
func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load config: %w", err)
|
||||
@@ -85,14 +115,9 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Make sure origins are allowed (frontend might be running on a different port like 5173 during dev)
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"*"}
|
||||
// Seed origins from the request instead of hardcoding ports.
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 && callerOrigin != "" {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{callerOrigin}
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -109,7 +134,7 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
//
|
||||
// POST /api/pico/setup
|
||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
changed, err := h.ensurePicoChannel()
|
||||
changed, err := h.ensurePicoChannel(r.Header.Get("Origin"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.ensurePicoChannel("")
|
||||
if err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatal("ensurePicoChannel() should report changed on a fresh config")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.Enabled {
|
||||
t.Error("expected Pico to be enabled after setup")
|
||||
}
|
||||
if cfg.Channels.Pico.Token == "" {
|
||||
t.Error("expected a non-empty token after setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.AllowTokenQuery {
|
||||
t.Error("setup must not enable allow_token_query by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
for _, origin := range cfg.Channels.Pico.AllowOrigins {
|
||||
if origin == "*" {
|
||||
t.Error("setup must not set wildcard origin '*'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Without a caller origin, allow_origins stays empty (CheckOrigin
|
||||
// allows all when the list is empty, so the channel still works).
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty when no caller origin", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
lanOrigin := "http://192.168.1.9:18800"
|
||||
if _, err := h.ensurePicoChannel(lanOrigin); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != lanOrigin {
|
||||
t.Errorf("allow_origins = %v, want [%s]", cfg.Channels.Pico.AllowOrigins, lanOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Pre-configure with custom user settings
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels.Pico.Enabled = true
|
||||
cfg.Channels.Pico.Token = "user-custom-token"
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"https://myapp.example.com"}
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.ensurePicoChannel("")
|
||||
if err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("ensurePicoChannel() should not change a fully configured config")
|
||||
}
|
||||
|
||||
cfg, err = config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.Token != "user-custom-token" {
|
||||
t.Errorf("token = %q, want %q", cfg.Channels.Pico.Token, "user-custom-token")
|
||||
}
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
t.Error("user's allow_token_query=true must be preserved")
|
||||
}
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "https://myapp.example.com" {
|
||||
t.Errorf("allow_origins = %v, want [https://myapp.example.com]", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
origin := "http://localhost:18800"
|
||||
|
||||
// First call sets things up
|
||||
if _, err := h.ensurePicoChannel(origin); err != nil {
|
||||
t.Fatalf("first ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg1, _ := config.LoadConfig(configPath)
|
||||
token1 := cfg1.Channels.Pico.Token
|
||||
|
||||
// Second call should be a no-op
|
||||
changed, err := h.ensurePicoChannel(origin)
|
||||
if err != nil {
|
||||
t.Fatalf("second ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("second ensurePicoChannel() should not report changed")
|
||||
}
|
||||
|
||||
cfg2, _ := config.LoadConfig(configPath)
|
||||
if cfg2.Channels.Pico.Token != token1 {
|
||||
t.Error("token should not change on subsequent calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
||||
req.Header.Set("Origin", "http://10.0.0.5:3000")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handlePicoSetup(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "http://10.0.0.5:3000" {
|
||||
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handlePicoSetup(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp["token"] == nil || resp["token"] == "" {
|
||||
t.Error("response should contain a non-empty token")
|
||||
}
|
||||
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
||||
t.Error("response should contain ws_url")
|
||||
}
|
||||
if resp["enabled"] != true {
|
||||
t.Error("response should have enabled=true")
|
||||
}
|
||||
if resp["changed"] != true {
|
||||
t.Error("response should have changed=true on first setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server1")
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server2")
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec1.Body.String(); body != "server1" {
|
||||
t.Fatalf("first body = %q, want %q", body, "server1")
|
||||
}
|
||||
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec2.Body.String(); body != "server2" {
|
||||
t.Fatalf("second body = %q, want %q", body, "server2")
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(parsed.Port())
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
@@ -70,3 +70,5 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Launcher service parameters (port/public)
|
||||
h.registerLauncherConfigRoutes(mux)
|
||||
}
|
||||
|
||||
func (h *Handler) Shutdown() {}
|
||||
|
||||
@@ -118,6 +118,12 @@ var toolCatalog = []toolCatalogEntry{
|
||||
Category: "agents",
|
||||
ConfigKey: "spawn",
|
||||
},
|
||||
{
|
||||
Name: "spawn_status",
|
||||
Description: "Query the status of spawned subagents.",
|
||||
Category: "agents",
|
||||
ConfigKey: "spawn_status",
|
||||
},
|
||||
{
|
||||
Name: "i2c",
|
||||
Description: "Interact with I2C hardware devices exposed on the host.",
|
||||
@@ -205,7 +211,7 @@ func buildToolSupport(cfg *config.Config) []toolSupportItem {
|
||||
reasonCode = "requires_skills"
|
||||
}
|
||||
}
|
||||
case "spawn":
|
||||
case "spawn", "spawn_status":
|
||||
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
status = "enabled"
|
||||
@@ -300,6 +306,12 @@ func applyToolState(cfg *config.Config, toolName string, enabled bool) error {
|
||||
if enabled {
|
||||
cfg.Tools.Subagent.Enabled = true
|
||||
}
|
||||
case "spawn_status":
|
||||
cfg.Tools.SpawnStatus.Enabled = enabled
|
||||
if enabled {
|
||||
cfg.Tools.Spawn.Enabled = true
|
||||
cfg.Tools.Subagent.Enabled = true
|
||||
}
|
||||
case "i2c":
|
||||
cfg.Tools.I2C.Enabled = enabled
|
||||
case "spi":
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
browserDelay = 500 * time.Millisecond
|
||||
shutdownTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
func shutdownApp() {
|
||||
fmt.Println(T(Exiting))
|
||||
|
||||
if apiHandler != nil {
|
||||
apiHandler.Shutdown()
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
server.SetKeepAlivesEnabled(false)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
if err == context.DeadlineExceeded {
|
||||
logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout)
|
||||
} else {
|
||||
logger.Errorf("Server shutdown error: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Infof("Server shutdown completed successfully")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func openBrowser() error {
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address not set")
|
||||
}
|
||||
return utils.OpenBrowser(serverAddr)
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Language represents the supported languages
|
||||
type Language string
|
||||
|
||||
const (
|
||||
LanguageEnglish Language = "en"
|
||||
LanguageChinese Language = "zh"
|
||||
)
|
||||
|
||||
// current language (default: English)
|
||||
var currentLang Language = LanguageEnglish
|
||||
|
||||
// TranslationKey represents a translation key used for i18n
|
||||
type TranslationKey string
|
||||
|
||||
const (
|
||||
AppTooltip TranslationKey = "AppTooltip"
|
||||
MenuOpen TranslationKey = "MenuOpen"
|
||||
MenuOpenTooltip TranslationKey = "MenuOpenTooltip"
|
||||
MenuAbout TranslationKey = "MenuAbout"
|
||||
MenuAboutTooltip TranslationKey = "MenuAboutTooltip"
|
||||
MenuVersion TranslationKey = "MenuVersion"
|
||||
MenuVersionTooltip TranslationKey = "MenuVersionTooltip"
|
||||
MenuGitHub TranslationKey = "MenuGitHub"
|
||||
MenuDocs TranslationKey = "MenuDocs"
|
||||
MenuRestart TranslationKey = "MenuRestart"
|
||||
MenuRestartTooltip TranslationKey = "MenuRestartTooltip"
|
||||
MenuQuit TranslationKey = "MenuQuit"
|
||||
MenuQuitTooltip TranslationKey = "MenuQuitTooltip"
|
||||
Exiting TranslationKey = "Exiting"
|
||||
DocUrl TranslationKey = "DocUrl"
|
||||
)
|
||||
|
||||
// Translation tables
|
||||
// Chinese translations intentionally contain Han script
|
||||
//
|
||||
//nolint:gosmopolitan
|
||||
var translations = map[Language]map[TranslationKey]string{
|
||||
LanguageEnglish: {
|
||||
AppTooltip: "%s - Web Console",
|
||||
MenuOpen: "Open Console",
|
||||
MenuOpenTooltip: "Open PicoClaw console in browser",
|
||||
MenuAbout: "About",
|
||||
MenuAboutTooltip: "About PicoClaw",
|
||||
MenuVersion: "Version: %s",
|
||||
MenuVersionTooltip: "Current version number",
|
||||
MenuGitHub: "GitHub",
|
||||
MenuDocs: "Documentation",
|
||||
MenuRestart: "Restart Service",
|
||||
MenuRestartTooltip: "Restart Gateway service",
|
||||
MenuQuit: "Quit",
|
||||
MenuQuitTooltip: "Exit PicoClaw",
|
||||
Exiting: "Exiting PicoClaw...",
|
||||
DocUrl: "https://docs.picoclaw.io/docs/",
|
||||
},
|
||||
LanguageChinese: {
|
||||
AppTooltip: "%s - Web Console",
|
||||
MenuOpen: "打开控制台",
|
||||
MenuOpenTooltip: "在浏览器中打开 PicoClaw 控制台",
|
||||
MenuAbout: "关于",
|
||||
MenuAboutTooltip: "关于 PicoClaw",
|
||||
MenuVersion: "版本: %s",
|
||||
MenuVersionTooltip: "当前版本号",
|
||||
MenuGitHub: "GitHub",
|
||||
MenuDocs: "文档",
|
||||
MenuRestart: "重启服务",
|
||||
MenuRestartTooltip: "重启核心服务",
|
||||
MenuQuit: "退出",
|
||||
MenuQuitTooltip: "退出 PicoClaw",
|
||||
Exiting: "正在退出 PicoClaw...",
|
||||
DocUrl: "https://docs.picoclaw.io/zh-Hans/docs/",
|
||||
},
|
||||
}
|
||||
|
||||
// SetLanguage sets the current language
|
||||
func SetLanguage(lang string) {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
|
||||
// Extract language code before first underscore or dot
|
||||
// e.g., "en_US.UTF-8" -> "en", "zh_CN" -> "zh"
|
||||
if idx := strings.IndexAny(lang, "_."); idx > 0 {
|
||||
lang = lang[:idx]
|
||||
}
|
||||
|
||||
if lang == "zh" || lang == "zh-cn" || lang == "chinese" {
|
||||
currentLang = LanguageChinese
|
||||
} else {
|
||||
currentLang = LanguageEnglish
|
||||
}
|
||||
}
|
||||
|
||||
// GetLanguage returns the current language
|
||||
func GetLanguage() Language {
|
||||
return currentLang
|
||||
}
|
||||
|
||||
// T translates a key to the current language
|
||||
func T(key TranslationKey, args ...any) string {
|
||||
if trans, ok := translations[currentLang][key]; ok {
|
||||
if len(args) > 0 {
|
||||
return fmt.Sprintf(trans, args...)
|
||||
}
|
||||
return trans
|
||||
}
|
||||
return string(key)
|
||||
}
|
||||
|
||||
// Initialize i18n from environment variable
|
||||
func init() {
|
||||
if lang := os.Getenv("LANG"); lang != "" {
|
||||
SetLanguage(lang)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 102 KiB |
+37
-16
@@ -22,16 +22,32 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/api"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
"github.com/sipeed/picoclaw/web/backend/middleware"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
appName = "PicoClaw"
|
||||
)
|
||||
|
||||
var (
|
||||
appVersion = config.Version
|
||||
|
||||
server *http.Server
|
||||
serverAddr string
|
||||
apiHandler *api.Handler
|
||||
|
||||
noBrowser *bool
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.String("port", "18800", "Port to listen on")
|
||||
public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only")
|
||||
noBrowser := flag.Bool("no-browser", false, "Do not auto-open browser on startup")
|
||||
noBrowser = flag.Bool("no-browser", false, "Do not auto-open browser on startup")
|
||||
lang := flag.String("lang", "", "Language: en (English) or zh (Chinese). Default: auto-detect from system locale")
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "PicoClaw Launcher - A web-based configuration editor\n\n")
|
||||
@@ -51,6 +67,11 @@ func main() {
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Set language from command line or auto-detect
|
||||
if *lang != "" {
|
||||
SetLanguage(*lang)
|
||||
}
|
||||
|
||||
// Resolve config path
|
||||
configPath := utils.GetDefaultConfigPath()
|
||||
if flag.NArg() > 0 {
|
||||
@@ -113,7 +134,7 @@ func main() {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// API Routes (e.g. /api/status)
|
||||
apiHandler := api.NewHandler(absPath)
|
||||
apiHandler = api.NewHandler(absPath)
|
||||
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
|
||||
apiHandler.RegisterRoutes(mux)
|
||||
|
||||
@@ -145,16 +166,10 @@ func main() {
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Auto-open browser
|
||||
if !*noBrowser {
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
url := "http://localhost:" + effectivePort
|
||||
if err := utils.OpenBrowser(url); err != nil {
|
||||
log.Printf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Share the local URL with the launcher runtime.
|
||||
serverAddr = fmt.Sprintf("http://localhost:%s", effectivePort)
|
||||
|
||||
// Auto-open browser will be handled by the launcher runtime.
|
||||
|
||||
// Auto-start gateway after backend starts listening.
|
||||
go func() {
|
||||
@@ -162,8 +177,14 @@ func main() {
|
||||
apiHandler.TryAutoStartGateway()
|
||||
}()
|
||||
|
||||
// Start the Server
|
||||
if err := http.ListenAndServe(addr, handler); err != nil {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
// Start the Server in a goroutine
|
||||
server = &http.Server{Addr: addr, Handler: handler}
|
||||
go func() {
|
||||
log.Printf("Server listening on %s", addr)
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
runTray()
|
||||
}
|
||||
|
||||
@@ -4,16 +4,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JSONContentType sets the Content-Type header to application/json for
|
||||
// API requests handled by the wrapped handler.
|
||||
// SSE endpoints (text/event-stream) are excluded.
|
||||
func JSONContentType(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
|
||||
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -32,7 +30,6 @@ func (rr *responseRecorder) WriteHeader(code int) {
|
||||
}
|
||||
|
||||
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
|
||||
// This is required for SSE (Server-Sent Events) to work through the middleware.
|
||||
func (rr *responseRecorder) Flush() {
|
||||
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
//go:build (!darwin && !freebsd) || cgo
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
"fyne.io/systray"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
func runTray() {
|
||||
systray.Run(onReady, shutdownApp)
|
||||
}
|
||||
|
||||
// onReady is called when the system tray is ready
|
||||
func onReady() {
|
||||
// Set icon and tooltip
|
||||
systray.SetIcon(getIcon())
|
||||
systray.SetTooltip(fmt.Sprintf(T(AppTooltip), appName))
|
||||
|
||||
// Create menu items
|
||||
mOpen := systray.AddMenuItem(T(MenuOpen), T(MenuOpenTooltip))
|
||||
mAbout := systray.AddMenuItem(T(MenuAbout), T(MenuAboutTooltip))
|
||||
|
||||
// Add version info under About menu
|
||||
mVersion := mAbout.AddSubMenuItem(fmt.Sprintf(T(MenuVersion), appVersion), T(MenuVersionTooltip))
|
||||
mVersion.Disable()
|
||||
mRepo := mAbout.AddSubMenuItem(T(MenuGitHub), "")
|
||||
mDocs := mAbout.AddSubMenuItem(T(MenuDocs), "")
|
||||
|
||||
systray.AddSeparator()
|
||||
|
||||
// Add restart option
|
||||
mRestart := systray.AddMenuItem(T(MenuRestart), T(MenuRestartTooltip))
|
||||
|
||||
systray.AddSeparator()
|
||||
|
||||
// Quit option
|
||||
mQuit := systray.AddMenuItem(T(MenuQuit), T(MenuQuitTooltip))
|
||||
|
||||
// Handle menu clicks
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Failed to open browser: %v", err)
|
||||
}
|
||||
|
||||
case <-mVersion.ClickedCh:
|
||||
// Version info - do nothing, just shows current version
|
||||
|
||||
case <-mRepo.ClickedCh:
|
||||
if err := utils.OpenBrowser("https://github.com/sipeed/picoclaw"); err != nil {
|
||||
logger.Errorf("Failed to open GitHub: %v", err)
|
||||
}
|
||||
|
||||
case <-mDocs.ClickedCh:
|
||||
if err := utils.OpenBrowser(T(DocUrl)); err != nil {
|
||||
logger.Errorf("Failed to open docs: %v", err)
|
||||
}
|
||||
|
||||
case <-mRestart.ClickedCh:
|
||||
fmt.Println("Restart request received...")
|
||||
if apiHandler != nil {
|
||||
if pid, err := apiHandler.RestartGateway(); err != nil {
|
||||
logger.Errorf("Failed to restart gateway: %v", err)
|
||||
} else {
|
||||
logger.Infof("Gateway restarted (PID: %d)", pid)
|
||||
}
|
||||
}
|
||||
|
||||
case <-mQuit.ClickedCh:
|
||||
systray.Quit()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !*noBrowser {
|
||||
// Auto-open browser after systray is ready (if not disabled)
|
||||
// Check no-browser flag via environment or pass as parameter if needed
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getIcon returns the system tray icon
|
||||
func getIcon() []byte {
|
||||
return iconData
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed icon.png
|
||||
var iconData []byte
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed icon.ico
|
||||
var iconData []byte
|
||||
@@ -0,0 +1,33 @@
|
||||
//go:build (darwin || freebsd) && !cgo
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func runTray() {
|
||||
logger.Infof("System tray is unavailable in %s builds without cgo; running without tray", runtime.GOOS)
|
||||
|
||||
if !*noBrowser {
|
||||
go func() {
|
||||
time.Sleep(browserDelay)
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
<-ctx.Done()
|
||||
shutdownApp()
|
||||
}
|
||||
@@ -6,7 +6,7 @@
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "tsc -b && vite build",
|
||||
"build:backend": "tsc -b && vite build --outDir ../backend/dist --emptyOutDir",
|
||||
"build:backend": "tsc -b && vite build --outDir ../backend/dist --emptyOutDir && node ./scripts/ensure-backend-gitkeep.cjs",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview",
|
||||
"format": "prettier --check .",
|
||||
@@ -17,18 +17,18 @@
|
||||
"@tabler/icons-react": "^3.38.0",
|
||||
"@tailwindcss/vite": "^4.2.1",
|
||||
"@tanstack/react-query": "^5.90.21",
|
||||
"@tanstack/react-router": "^1.163.3",
|
||||
"@tanstack/react-router": "^1.167.0",
|
||||
"@tanstack/react-router-devtools": "^1.163.3",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"dayjs": "^1.11.19",
|
||||
"dayjs": "^1.11.20",
|
||||
"i18next": "^25.8.14",
|
||||
"i18next-browser-languagedetector": "^8.2.1",
|
||||
"jotai": "^2.18.0",
|
||||
"jotai": "^2.18.1",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "^19.2.0",
|
||||
"react-dom": "^19.2.0",
|
||||
"react-i18next": "^16.5.4",
|
||||
"react-i18next": "^16.5.8",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-textarea-autosize": "^8.5.9",
|
||||
"remark-gfm": "^4.0.1",
|
||||
@@ -40,7 +40,7 @@
|
||||
"wrap-ansi": "^10.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.39.1",
|
||||
"@eslint/js": "^9.39.3",
|
||||
"@tailwindcss/typography": "^0.5.19",
|
||||
"@tanstack/router-plugin": "^1.164.0",
|
||||
"@trivago/prettier-plugin-sort-imports": "^6.0.2",
|
||||
@@ -48,8 +48,8 @@
|
||||
"@types/react": "^19.2.7",
|
||||
"@types/react-dom": "^19.2.3",
|
||||
"@typescript-eslint/eslint-plugin": "^8.56.1",
|
||||
"@vitejs/plugin-react": "^5.1.1",
|
||||
"eslint": "^9.39.1",
|
||||
"@vitejs/plugin-react": "^5.2.0",
|
||||
"eslint": "^9.39.3",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-react-hooks": "^7.0.1",
|
||||
"eslint-plugin-react-refresh": "^0.4.24",
|
||||
|
||||
Generated
+95
-60
@@ -21,11 +21,11 @@ importers:
|
||||
specifier: ^5.90.21
|
||||
version: 5.90.21(react@19.2.4)
|
||||
'@tanstack/react-router':
|
||||
specifier: ^1.163.3
|
||||
version: 1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
specifier: ^1.167.0
|
||||
version: 1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
'@tanstack/react-router-devtools':
|
||||
specifier: ^1.163.3
|
||||
version: 1.163.3(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.163.3)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
version: 1.163.3(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.167.0)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
class-variance-authority:
|
||||
specifier: ^0.7.1
|
||||
version: 0.7.1
|
||||
@@ -33,8 +33,8 @@ importers:
|
||||
specifier: ^2.1.1
|
||||
version: 2.1.1
|
||||
dayjs:
|
||||
specifier: ^1.11.19
|
||||
version: 1.11.19
|
||||
specifier: ^1.11.20
|
||||
version: 1.11.20
|
||||
i18next:
|
||||
specifier: ^25.8.14
|
||||
version: 25.8.14(typescript@5.9.3)
|
||||
@@ -42,8 +42,8 @@ importers:
|
||||
specifier: ^8.2.1
|
||||
version: 8.2.1
|
||||
jotai:
|
||||
specifier: ^2.18.0
|
||||
version: 2.18.0(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4)
|
||||
specifier: ^2.18.1
|
||||
version: 2.18.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4)
|
||||
radix-ui:
|
||||
specifier: ^1.4.3
|
||||
version: 1.4.3(@types/react-dom@19.2.3(@types/react@19.2.14))(@types/react@19.2.14)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
@@ -54,8 +54,8 @@ importers:
|
||||
specifier: ^19.2.0
|
||||
version: 19.2.4(react@19.2.4)
|
||||
react-i18next:
|
||||
specifier: ^16.5.4
|
||||
version: 16.5.4(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)
|
||||
specifier: ^16.5.8
|
||||
version: 16.5.8(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3)
|
||||
react-markdown:
|
||||
specifier: ^10.1.0
|
||||
version: 10.1.0(@types/react@19.2.14)(react@19.2.4)
|
||||
@@ -85,14 +85,14 @@ importers:
|
||||
version: 10.0.0
|
||||
devDependencies:
|
||||
'@eslint/js':
|
||||
specifier: ^9.39.1
|
||||
specifier: ^9.39.3
|
||||
version: 9.39.3
|
||||
'@tailwindcss/typography':
|
||||
specifier: ^0.5.19
|
||||
version: 0.5.19(tailwindcss@4.2.1)
|
||||
'@tanstack/router-plugin':
|
||||
specifier: ^1.164.0
|
||||
version: 1.164.0(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
|
||||
version: 1.164.0(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
|
||||
'@trivago/prettier-plugin-sort-imports':
|
||||
specifier: ^6.0.2
|
||||
version: 6.0.2(prettier@3.8.1)
|
||||
@@ -109,10 +109,10 @@ importers:
|
||||
specifier: ^8.56.1
|
||||
version: 8.56.1(@typescript-eslint/parser@8.56.1(eslint@9.39.3(jiti@2.6.1))(typescript@5.9.3))(eslint@9.39.3(jiti@2.6.1))(typescript@5.9.3)
|
||||
'@vitejs/plugin-react':
|
||||
specifier: ^5.1.1
|
||||
version: 5.1.4(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
|
||||
specifier: ^5.2.0
|
||||
version: 5.2.0(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))
|
||||
eslint:
|
||||
specifier: ^9.39.1
|
||||
specifier: ^9.39.3
|
||||
version: 9.39.3(jiti@2.6.1)
|
||||
eslint-config-prettier:
|
||||
specifier: ^10.1.8
|
||||
@@ -469,8 +469,8 @@ packages:
|
||||
resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==}
|
||||
engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0}
|
||||
|
||||
'@eslint/config-array@0.21.1':
|
||||
resolution: {integrity: sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==}
|
||||
'@eslint/config-array@0.21.2':
|
||||
resolution: {integrity: sha512-nJl2KGTlrf9GjLimgIru+V/mzgSK0ABCDQRvxw5BjURL7WfH5uoWmizbH7QB6MmnMBd8cIC9uceWnezL1VZWWw==}
|
||||
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
|
||||
|
||||
'@eslint/config-helpers@0.4.2':
|
||||
@@ -481,8 +481,8 @@ packages:
|
||||
resolution: {integrity: sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==}
|
||||
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
|
||||
|
||||
'@eslint/eslintrc@3.3.4':
|
||||
resolution: {integrity: sha512-4h4MVF8pmBsncB60r0wSJiIeUKTSD4m7FmTFThG8RHlsg9ajqckLm9OraguFGZE4vVdpiI1Q4+hFnisopmG6gQ==}
|
||||
'@eslint/eslintrc@3.3.5':
|
||||
resolution: {integrity: sha512-4IlJx0X0qftVsN5E+/vGujTRIFtwuLbNsVUe7TO6zYPDR1O6nFwvwhIKEKSrl6dZchmYBITazxKoUYOjdtjlRg==}
|
||||
engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0}
|
||||
|
||||
'@eslint/js@9.39.3':
|
||||
@@ -1587,15 +1587,15 @@ packages:
|
||||
'@tanstack/router-core':
|
||||
optional: true
|
||||
|
||||
'@tanstack/react-router@1.163.3':
|
||||
resolution: {integrity: sha512-hheBbFVb+PbxtrWp8iy6+TTRTbhx3Pn6hKo8Tv/sWlG89ZMcD1xpQWzx8ukHN9K8YWbh5rdzt4kv6u8X4kB28Q==}
|
||||
'@tanstack/react-router@1.167.0':
|
||||
resolution: {integrity: sha512-U7CamtXjuC8ixg1c32Rj/4A2OFBnjtMLdbgbyOGHrFHE7ULWS/yhnZLVXff0QSyn6qF92Oecek9mDMHCaTnB2Q==}
|
||||
engines: {node: '>=20.19'}
|
||||
peerDependencies:
|
||||
react: '>=18.0.0 || >=19.0.0'
|
||||
react-dom: '>=18.0.0 || >=19.0.0'
|
||||
|
||||
'@tanstack/react-store@0.9.1':
|
||||
resolution: {integrity: sha512-YzJLnRvy5lIEFTLWBAZmcOjK3+2AepnBv/sr6NZmiqJvq7zTQggyK99Gw8fqYdMdHPQWXjz0epFKJXC+9V2xDA==}
|
||||
'@tanstack/react-store@0.9.2':
|
||||
resolution: {integrity: sha512-Vt5usJE5sHG/cMechQfmwvwne6ktGCELe89Lmvoxe3LKRoFrhPa8OCKWs0NliG8HTJElEIj7PLtaBQIcux5pAQ==}
|
||||
peerDependencies:
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
@@ -1604,6 +1604,10 @@ packages:
|
||||
resolution: {integrity: sha512-jPptiGq/w3nuPzcMC7RNa79aU+b6OjaDzWJnBcV2UAwL4ThJamRS4h42TdhJE+oF5yH9IEnCOGQdfnbw45LbfA==}
|
||||
engines: {node: '>=20.19'}
|
||||
|
||||
'@tanstack/router-core@1.167.0':
|
||||
resolution: {integrity: sha512-pnaaUP+vMQEyL2XjZGe2PXmtzulxvXfGyvEMUs+AEBaNEk77xWA88bl3ujiBRbUxzpK0rxfJf+eSKPdZmBMFdQ==}
|
||||
engines: {node: '>=20.19'}
|
||||
|
||||
'@tanstack/router-devtools-core@1.163.3':
|
||||
resolution: {integrity: sha512-FPi64IP0PT1IkoeyGmsD6JoOVOYAb85VCH0mUbSdD90yV0+1UB6oT+D7K27GXkp7SXMJN3mBEjU5rKnNnmSCIw==}
|
||||
engines: {node: '>=20.19'}
|
||||
@@ -1646,6 +1650,9 @@ packages:
|
||||
'@tanstack/store@0.9.1':
|
||||
resolution: {integrity: sha512-+qcNkOy0N1qSGsP7omVCW0SDrXtaDcycPqBDE726yryiA5eTDFpjBReaYjghVJwNf1pcPMyzIwTGlYjCSQR0Fg==}
|
||||
|
||||
'@tanstack/store@0.9.2':
|
||||
resolution: {integrity: sha512-K013lUJEFJK2ofFQ/hZKJUmCnpcV00ebLyOyFOWQvyQHUOZp/iYO84BM6aOGiV81JzwbX0APTVmW8YI7yiG5oA==}
|
||||
|
||||
'@tanstack/virtual-file-routes@1.161.4':
|
||||
resolution: {integrity: sha512-42WoRePf8v690qG8yGRe/YOh+oHni9vUaUUfoqlS91U2scd3a5rkLtVsc6b7z60w3RogH0I00vdrC5AaeiZ18w==}
|
||||
engines: {node: '>=20.19'}
|
||||
@@ -1790,11 +1797,11 @@ packages:
|
||||
'@ungap/structured-clone@1.3.0':
|
||||
resolution: {integrity: sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==}
|
||||
|
||||
'@vitejs/plugin-react@5.1.4':
|
||||
resolution: {integrity: sha512-VIcFLdRi/VYRU8OL/puL7QXMYafHmqOnwTZY50U1JPlCNj30PxCMx65c494b1K9be9hX83KVt0+gTEwTWLqToA==}
|
||||
'@vitejs/plugin-react@5.2.0':
|
||||
resolution: {integrity: sha512-YmKkfhOAi3wsB1PhJq5Scj3GXMn3WvtQ/JC0xoopuHoXSdmtdStOpFrYaT1kie2YgFBcIe64ROzMYRjCrYOdYw==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
peerDependencies:
|
||||
vite: ^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0
|
||||
vite: ^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0
|
||||
|
||||
accepts@2.0.0:
|
||||
resolution: {integrity: sha512-5cvg6CtKwfgdmVqY1WIiXKc3Q1bkRqGLi+2W/6ao+6Y7gu/RCwRuAhGEzh5B4KlszSuTLgZYuqFqo5bImjNKng==}
|
||||
@@ -2060,8 +2067,8 @@ packages:
|
||||
resolution: {integrity: sha512-0R9ikRb668HB7QDxT1vkpuUBtqc53YyAwMwGeUFKRojY/NWKvdZ+9UYtRfGmhqNbRkTSVpMbmyhXipFFv2cb/A==}
|
||||
engines: {node: '>= 12'}
|
||||
|
||||
dayjs@1.11.19:
|
||||
resolution: {integrity: sha512-t5EcLVS6QPBNqM2z8fakk/NKel+Xzshgt8FFKAn+qwlD1pzZWxh0nVCrvFK7ZDb6XucZeF9z8C7CBWTRIVApAw==}
|
||||
dayjs@1.11.20:
|
||||
resolution: {integrity: sha512-YbwwqR/uYpeoP4pu043q+LTDLFBLApUP6VxRihdfNTqu4ubqMlGDLd6ErXhEgsyvY0K6nCs7nggYumAN+9uEuQ==}
|
||||
|
||||
debug@4.4.3:
|
||||
resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==}
|
||||
@@ -2355,8 +2362,8 @@ packages:
|
||||
resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==}
|
||||
engines: {node: '>=16'}
|
||||
|
||||
flatted@3.3.3:
|
||||
resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==}
|
||||
flatted@3.4.1:
|
||||
resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==}
|
||||
|
||||
formdata-polyfill@4.0.10:
|
||||
resolution: {integrity: sha512-buewHzMvYL29jdeQTVILecSaZKnt/RJWjoZCF5OW60Z67/GmSLBkOFM7qh1PI3zFNtJbaZL5eQu1vLfazOwj4g==}
|
||||
@@ -2648,8 +2655,8 @@ packages:
|
||||
resolution: {integrity: sha512-e6rvdUCiQCAuumZslxRJWR/Doq4VpPR82kqclvcS0efgt430SlGIk05vdCN58+VrzgtIcfNODjozVielycD4Sw==}
|
||||
engines: {node: '>=16'}
|
||||
|
||||
isbot@5.1.35:
|
||||
resolution: {integrity: sha512-waFfC72ZNfwLLuJ2iLaoVaqcNo+CAaLR7xCpAn0Y5WfGzkNHv7ZN39Vbi1y+kb+Zs46XHOX3tZNExroFUPX+Kg==}
|
||||
isbot@5.1.36:
|
||||
resolution: {integrity: sha512-C/ZtXyJqDPZ7G7JPr06ApWyYoHjYexQbS6hPYD4WYCzpv2Qes6Z+CCEfTX4Owzf+1EJ933PoI2p+B9v7wpGZBQ==}
|
||||
engines: {node: '>=18'}
|
||||
|
||||
isexe@2.0.0:
|
||||
@@ -2669,8 +2676,8 @@ packages:
|
||||
jose@6.1.3:
|
||||
resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==}
|
||||
|
||||
jotai@2.18.0:
|
||||
resolution: {integrity: sha512-XI38kGWAvtxAZ+cwHcTgJsd+kJOJGf3OfL4XYaXWZMZ7IIY8e53abpIHvtVn1eAgJ5dlgwlGFnP4psrZ/vZbtA==}
|
||||
jotai@2.18.1:
|
||||
resolution: {integrity: sha512-e0NOzK+yRFwHo7DOp0DS0Ycq74KMEAObDWFGmfEL28PD9nLqBTt3/Ug7jf9ca72x0gC9LQZG9zH+0ISICmy3iA==}
|
||||
engines: {node: '>=12.20.0'}
|
||||
peerDependencies:
|
||||
'@babel/core': '>=7.0.0'
|
||||
@@ -3323,8 +3330,8 @@ packages:
|
||||
peerDependencies:
|
||||
react: ^19.2.4
|
||||
|
||||
react-i18next@16.5.4:
|
||||
resolution: {integrity: sha512-6yj+dcfMncEC21QPhOTsW8mOSO+pzFmT6uvU7XXdvM/Cp38zJkmTeMeKmTrmCMD5ToT79FmiE/mRWiYWcJYW4g==}
|
||||
react-i18next@16.5.8:
|
||||
resolution: {integrity: sha512-2ABeHHlakxVY+LSirD+OiERxFL6+zip0PaHo979bgwzeHg27Sqc82xxXWIrSFmfWX0ZkrvXMHwhsi/NGUf5VQg==}
|
||||
peerDependencies:
|
||||
i18next: '>= 25.6.2'
|
||||
react: '>= 16.8.0'
|
||||
@@ -3476,10 +3483,20 @@ packages:
|
||||
peerDependencies:
|
||||
seroval: ^1.0
|
||||
|
||||
seroval-plugins@1.5.1:
|
||||
resolution: {integrity: sha512-4FbuZ/TMl02sqv0RTFexu0SP6V+ywaIe5bAWCCEik0fk17BhALgwvUDVF7e3Uvf9pxmwCEJsRPmlkUE6HdzLAw==}
|
||||
engines: {node: '>=10'}
|
||||
peerDependencies:
|
||||
seroval: ^1.0
|
||||
|
||||
seroval@1.5.0:
|
||||
resolution: {integrity: sha512-OE4cvmJ1uSPrKorFIH9/w/Qwuvi/IMcGbv5RKgcJ/zjA/IohDLU6SVaxFN9FwajbP7nsX0dQqMDes1whk3y+yw==}
|
||||
engines: {node: '>=10'}
|
||||
|
||||
seroval@1.5.1:
|
||||
resolution: {integrity: sha512-OwrZRZAfhHww0WEnKHDY8OM0U/Qs8OTfIDWhUD4BLpNJUfXK4cGmjiagGze086m+mhI+V2nD0gfbHEnJjb9STA==}
|
||||
engines: {node: '>=10'}
|
||||
|
||||
serve-static@2.2.1:
|
||||
resolution: {integrity: sha512-xRXBn0pPqQTVQiC8wyQrKs2MOlX24zQ0POGaj0kultvoOCstBQM5yvOhAVSUwOMjQtTvsPWoNCHfPGwaaQJhTw==}
|
||||
engines: {node: '>= 18'}
|
||||
@@ -4268,7 +4285,7 @@ snapshots:
|
||||
|
||||
'@eslint-community/regexpp@4.12.2': {}
|
||||
|
||||
'@eslint/config-array@0.21.1':
|
||||
'@eslint/config-array@0.21.2':
|
||||
dependencies:
|
||||
'@eslint/object-schema': 2.1.7
|
||||
debug: 4.4.3
|
||||
@@ -4284,7 +4301,7 @@ snapshots:
|
||||
dependencies:
|
||||
'@types/json-schema': 7.0.15
|
||||
|
||||
'@eslint/eslintrc@3.3.4':
|
||||
'@eslint/eslintrc@3.3.5':
|
||||
dependencies:
|
||||
ajv: 6.14.0
|
||||
debug: 4.4.3
|
||||
@@ -5365,31 +5382,31 @@ snapshots:
|
||||
'@tanstack/query-core': 5.90.20
|
||||
react: 19.2.4
|
||||
|
||||
'@tanstack/react-router-devtools@1.163.3(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.163.3)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
|
||||
'@tanstack/react-router-devtools@1.163.3(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(@tanstack/router-core@1.167.0)(csstype@3.2.3)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
|
||||
dependencies:
|
||||
'@tanstack/react-router': 1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
'@tanstack/router-devtools-core': 1.163.3(@tanstack/router-core@1.163.3)(csstype@3.2.3)
|
||||
'@tanstack/react-router': 1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
'@tanstack/router-devtools-core': 1.163.3(@tanstack/router-core@1.167.0)(csstype@3.2.3)
|
||||
react: 19.2.4
|
||||
react-dom: 19.2.4(react@19.2.4)
|
||||
optionalDependencies:
|
||||
'@tanstack/router-core': 1.163.3
|
||||
'@tanstack/router-core': 1.167.0
|
||||
transitivePeerDependencies:
|
||||
- csstype
|
||||
|
||||
'@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
|
||||
'@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
|
||||
dependencies:
|
||||
'@tanstack/history': 1.161.4
|
||||
'@tanstack/react-store': 0.9.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
'@tanstack/router-core': 1.163.3
|
||||
isbot: 5.1.35
|
||||
'@tanstack/react-store': 0.9.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
'@tanstack/router-core': 1.167.0
|
||||
isbot: 5.1.36
|
||||
react: 19.2.4
|
||||
react-dom: 19.2.4(react@19.2.4)
|
||||
tiny-invariant: 1.3.3
|
||||
tiny-warning: 1.0.3
|
||||
|
||||
'@tanstack/react-store@0.9.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
|
||||
'@tanstack/react-store@0.9.2(react-dom@19.2.4(react@19.2.4))(react@19.2.4)':
|
||||
dependencies:
|
||||
'@tanstack/store': 0.9.1
|
||||
'@tanstack/store': 0.9.2
|
||||
react: 19.2.4
|
||||
react-dom: 19.2.4(react@19.2.4)
|
||||
use-sync-external-store: 1.6.0(react@19.2.4)
|
||||
@@ -5404,9 +5421,19 @@ snapshots:
|
||||
tiny-invariant: 1.3.3
|
||||
tiny-warning: 1.0.3
|
||||
|
||||
'@tanstack/router-devtools-core@1.163.3(@tanstack/router-core@1.163.3)(csstype@3.2.3)':
|
||||
'@tanstack/router-core@1.167.0':
|
||||
dependencies:
|
||||
'@tanstack/router-core': 1.163.3
|
||||
'@tanstack/history': 1.161.4
|
||||
'@tanstack/store': 0.9.2
|
||||
cookie-es: 2.0.0
|
||||
seroval: 1.5.1
|
||||
seroval-plugins: 1.5.1(seroval@1.5.1)
|
||||
tiny-invariant: 1.3.3
|
||||
tiny-warning: 1.0.3
|
||||
|
||||
'@tanstack/router-devtools-core@1.163.3(@tanstack/router-core@1.167.0)(csstype@3.2.3)':
|
||||
dependencies:
|
||||
'@tanstack/router-core': 1.167.0
|
||||
clsx: 2.1.1
|
||||
goober: 2.1.18(csstype@3.2.3)
|
||||
tiny-invariant: 1.3.3
|
||||
@@ -5426,7 +5453,7 @@ snapshots:
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
|
||||
'@tanstack/router-plugin@1.164.0(@tanstack/react-router@1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
|
||||
'@tanstack/router-plugin@1.164.0(@tanstack/react-router@1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
|
||||
dependencies:
|
||||
'@babel/core': 7.29.0
|
||||
'@babel/plugin-syntax-jsx': 7.28.6(@babel/core@7.29.0)
|
||||
@@ -5442,7 +5469,7 @@ snapshots:
|
||||
unplugin: 2.3.11
|
||||
zod: 3.25.76
|
||||
optionalDependencies:
|
||||
'@tanstack/react-router': 1.163.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
'@tanstack/react-router': 1.167.0(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
||||
vite: 7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0)
|
||||
transitivePeerDependencies:
|
||||
- supports-color
|
||||
@@ -5463,6 +5490,8 @@ snapshots:
|
||||
|
||||
'@tanstack/store@0.9.1': {}
|
||||
|
||||
'@tanstack/store@0.9.2': {}
|
||||
|
||||
'@tanstack/virtual-file-routes@1.161.4': {}
|
||||
|
||||
'@trivago/prettier-plugin-sort-imports@6.0.2(prettier@3.8.1)':
|
||||
@@ -5641,7 +5670,7 @@ snapshots:
|
||||
|
||||
'@ungap/structured-clone@1.3.0': {}
|
||||
|
||||
'@vitejs/plugin-react@5.1.4(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
|
||||
'@vitejs/plugin-react@5.2.0(vite@7.3.1(@types/node@24.11.0)(jiti@2.6.1)(lightningcss@1.31.1)(tsx@4.21.0))':
|
||||
dependencies:
|
||||
'@babel/core': 7.29.0
|
||||
'@babel/plugin-transform-react-jsx-self': 7.27.1(@babel/core@7.29.0)
|
||||
@@ -5894,7 +5923,7 @@ snapshots:
|
||||
|
||||
data-uri-to-buffer@4.0.1: {}
|
||||
|
||||
dayjs@1.11.19: {}
|
||||
dayjs@1.11.20: {}
|
||||
|
||||
debug@4.4.3:
|
||||
dependencies:
|
||||
@@ -6048,10 +6077,10 @@ snapshots:
|
||||
dependencies:
|
||||
'@eslint-community/eslint-utils': 4.9.1(eslint@9.39.3(jiti@2.6.1))
|
||||
'@eslint-community/regexpp': 4.12.2
|
||||
'@eslint/config-array': 0.21.1
|
||||
'@eslint/config-array': 0.21.2
|
||||
'@eslint/config-helpers': 0.4.2
|
||||
'@eslint/core': 0.17.0
|
||||
'@eslint/eslintrc': 3.3.4
|
||||
'@eslint/eslintrc': 3.3.5
|
||||
'@eslint/js': 9.39.3
|
||||
'@eslint/plugin-kit': 0.4.1
|
||||
'@humanfs/node': 0.16.7
|
||||
@@ -6241,10 +6270,10 @@ snapshots:
|
||||
|
||||
flat-cache@4.0.1:
|
||||
dependencies:
|
||||
flatted: 3.3.3
|
||||
flatted: 3.4.1
|
||||
keyv: 4.5.4
|
||||
|
||||
flatted@3.3.3: {}
|
||||
flatted@3.4.1: {}
|
||||
|
||||
formdata-polyfill@4.0.10:
|
||||
dependencies:
|
||||
@@ -6489,7 +6518,7 @@ snapshots:
|
||||
dependencies:
|
||||
is-inside-container: 1.0.0
|
||||
|
||||
isbot@5.1.35: {}
|
||||
isbot@5.1.36: {}
|
||||
|
||||
isexe@2.0.0: {}
|
||||
|
||||
@@ -6501,7 +6530,7 @@ snapshots:
|
||||
|
||||
jose@6.1.3: {}
|
||||
|
||||
jotai@2.18.0(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4):
|
||||
jotai@2.18.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.4):
|
||||
optionalDependencies:
|
||||
'@babel/core': 7.29.0
|
||||
'@babel/template': 7.28.6
|
||||
@@ -7310,7 +7339,7 @@ snapshots:
|
||||
react: 19.2.4
|
||||
scheduler: 0.27.0
|
||||
|
||||
react-i18next@16.5.4(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3):
|
||||
react-i18next@16.5.8(i18next@25.8.14(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3):
|
||||
dependencies:
|
||||
'@babel/runtime': 7.28.6
|
||||
html-parse-stringify: 3.0.1
|
||||
@@ -7517,8 +7546,14 @@ snapshots:
|
||||
dependencies:
|
||||
seroval: 1.5.0
|
||||
|
||||
seroval-plugins@1.5.1(seroval@1.5.1):
|
||||
dependencies:
|
||||
seroval: 1.5.1
|
||||
|
||||
seroval@1.5.0: {}
|
||||
|
||||
seroval@1.5.1: {}
|
||||
|
||||
serve-static@2.2.1:
|
||||
dependencies:
|
||||
encodeurl: 2.0.0
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
const fs = require("node:fs")
|
||||
const path = require("node:path")
|
||||
|
||||
const gitkeepPath = path.resolve(__dirname, "../../backend/dist/.gitkeep")
|
||||
const gitkeepContents =
|
||||
"# Keep the embedded web backend dist directory in version control.\n"
|
||||
|
||||
fs.mkdirSync(path.dirname(gitkeepPath), { recursive: true })
|
||||
fs.writeFileSync(gitkeepPath, gitkeepContents)
|
||||
@@ -56,14 +56,20 @@ export function AppHeader() {
|
||||
const isRunning = gwState === "running"
|
||||
const isStarting = gwState === "starting"
|
||||
const isRestarting = gwState === "restarting"
|
||||
const isStopping = gwState === "stopping"
|
||||
const isStopped = gwState === "stopped" || gwState === "unknown"
|
||||
const showNotConnectedHint =
|
||||
!isRestarting && canStart && (gwState === "stopped" || gwState === "error")
|
||||
!isRestarting &&
|
||||
!isStopping &&
|
||||
canStart &&
|
||||
(gwState === "stopped" || gwState === "error")
|
||||
|
||||
const [showStopDialog, setShowStopDialog] = React.useState(false)
|
||||
|
||||
const handleGatewayToggle = () => {
|
||||
if (gwLoading || isRestarting || (!isRunning && !canStart)) return
|
||||
if (gwLoading || isRestarting || isStopping || (!isRunning && !canStart)) {
|
||||
return
|
||||
}
|
||||
if (isRunning) {
|
||||
setShowStopDialog(true)
|
||||
} else {
|
||||
@@ -137,7 +143,7 @@ export function AppHeader() {
|
||||
size="icon-sm"
|
||||
className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25"
|
||||
onClick={handleGatewayRestart}
|
||||
disabled={gwLoading || isRestarting || !canStart}
|
||||
disabled={gwLoading || isRestarting || isStopping || !canStart}
|
||||
aria-label={t("header.gateway.action.restart")}
|
||||
>
|
||||
<IconRefresh className="size-4" />
|
||||
@@ -168,25 +174,31 @@ export function AppHeader() {
|
||||
</Tooltip>
|
||||
) : (
|
||||
<Button
|
||||
variant={isStarting || isRestarting ? "secondary" : "default"}
|
||||
variant={
|
||||
isStarting || isRestarting || isStopping ? "secondary" : "default"
|
||||
}
|
||||
size="sm"
|
||||
className={`h-8 gap-2 px-3 ${
|
||||
isStopped ? "bg-green-500 text-white hover:bg-green-600" : ""
|
||||
}`}
|
||||
onClick={handleGatewayToggle}
|
||||
disabled={gwLoading || isStarting || isRestarting || !canStart}
|
||||
disabled={
|
||||
gwLoading || isStarting || isRestarting || isStopping || !canStart
|
||||
}
|
||||
>
|
||||
{gwLoading || isStarting || isRestarting ? (
|
||||
{gwLoading || isStarting || isRestarting || isStopping ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin opacity-70" />
|
||||
) : (
|
||||
<IconPlayerPlay className="h-4 w-4 opacity-80" />
|
||||
)}
|
||||
<span className="text-xs font-semibold">
|
||||
{isRestarting
|
||||
? t("header.gateway.status.restarting")
|
||||
: isStarting
|
||||
? t("header.gateway.status.starting")
|
||||
: t("header.gateway.action.start")}
|
||||
{isStopping
|
||||
? t("header.gateway.status.stopping")
|
||||
: isRestarting
|
||||
? t("header.gateway.status.restarting")
|
||||
: isStarting
|
||||
? t("header.gateway.status.starting")
|
||||
: t("header.gateway.action.start")}
|
||||
</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@@ -42,7 +42,7 @@ export function ChatComposer({
|
||||
placeholder={t("chat.placeholder")}
|
||||
disabled={!canInput}
|
||||
className={cn(
|
||||
"max-h-[200px] min-h-[60px] resize-none border-0 bg-transparent px-2 py-1 text-[15px] shadow-none transition-colors focus-visible:ring-0 focus-visible:outline-none dark:bg-transparent",
|
||||
"placeholder:text-muted-foreground max-h-[200px] min-h-[60px] resize-none border-0 bg-transparent px-2 py-1 text-[15px] shadow-none transition-colors focus-visible:ring-0 focus-visible:outline-none dark:bg-transparent",
|
||||
!canInput && "cursor-not-allowed",
|
||||
)}
|
||||
minRows={1}
|
||||
@@ -56,7 +56,7 @@ export function ChatComposer({
|
||||
size="icon"
|
||||
className="size-8 rounded-full bg-violet-500 text-white transition-transform hover:bg-violet-600 active:scale-95"
|
||||
onClick={onSend}
|
||||
disabled={!input.trim() || !isConnected}
|
||||
disabled={!input.trim() || !canInput}
|
||||
>
|
||||
<IconArrowUp className="size-4" />
|
||||
</Button>
|
||||
|
||||
@@ -34,7 +34,7 @@ export function ChatEmptyState({
|
||||
<p className="text-muted-foreground mb-4 text-center text-sm">
|
||||
{t("chat.empty.noConfiguredModelDescription")}
|
||||
</p>
|
||||
<Button asChild variant="secondary" size="sm" className="px-4">
|
||||
<Button asChild variant="outline" size="sm" className="px-4">
|
||||
<Link to="/models">{t("chat.empty.goToModels")}</Link>
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
@@ -15,7 +15,6 @@ import { useChatModels } from "@/hooks/use-chat-models"
|
||||
import { useGateway } from "@/hooks/use-gateway"
|
||||
import { usePicoChat } from "@/hooks/use-pico-chat"
|
||||
import { useSessionHistory } from "@/hooks/use-session-history"
|
||||
import { hydrateActiveSession } from "@/lib/pico-chat-controller"
|
||||
|
||||
export function ChatPage() {
|
||||
const { t } = useTranslation()
|
||||
@@ -26,6 +25,7 @@ export function ChatPage() {
|
||||
|
||||
const {
|
||||
messages,
|
||||
connectionState,
|
||||
isTyping,
|
||||
activeSessionId,
|
||||
sendMessage,
|
||||
@@ -34,7 +34,8 @@ export function ChatPage() {
|
||||
} = usePicoChat()
|
||||
|
||||
const { state: gwState } = useGateway()
|
||||
const isConnected = gwState === "running"
|
||||
const isGatewayRunning = gwState === "running"
|
||||
const isChatConnected = connectionState === "connected"
|
||||
|
||||
const {
|
||||
defaultModelName,
|
||||
@@ -43,7 +44,8 @@ export function ChatPage() {
|
||||
oauthModels,
|
||||
localModels,
|
||||
handleSetDefault,
|
||||
} = useChatModels({ isConnected })
|
||||
} = useChatModels({ isConnected: isGatewayRunning })
|
||||
const canSend = isChatConnected && Boolean(defaultModelName)
|
||||
|
||||
const {
|
||||
sessions,
|
||||
@@ -68,10 +70,6 @@ export function ChatPage() {
|
||||
syncScrollState(e.currentTarget)
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
void hydrateActiveSession()
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (scrollRef.current) {
|
||||
if (isAtBottom) {
|
||||
@@ -82,9 +80,10 @@ export function ChatPage() {
|
||||
}, [messages, isTyping, isAtBottom])
|
||||
|
||||
const handleSend = () => {
|
||||
if (!input.trim() || !isConnected) return
|
||||
sendMessage(input.trim())
|
||||
setInput("")
|
||||
if (!input.trim() || !canSend) return
|
||||
if (sendMessage(input.trim())) {
|
||||
setInput("")
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -143,7 +142,7 @@ export function ChatPage() {
|
||||
<ChatEmptyState
|
||||
hasConfiguredModels={hasConfiguredModels}
|
||||
defaultModelName={defaultModelName}
|
||||
isConnected={isConnected}
|
||||
isConnected={isGatewayRunning}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -168,7 +167,7 @@ export function ChatPage() {
|
||||
input={input}
|
||||
onInputChange={setInput}
|
||||
onSend={handleSend}
|
||||
isConnected={isConnected}
|
||||
isConnected={isChatConnected}
|
||||
hasDefaultModel={Boolean(defaultModelName)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -14,7 +14,9 @@ import {
|
||||
} from "@/api/system"
|
||||
import {
|
||||
AgentDefaultsSection,
|
||||
CronSection,
|
||||
DevicesSection,
|
||||
ExecSection,
|
||||
LauncherSection,
|
||||
RuntimeSection,
|
||||
} from "@/components/config/config-sections"
|
||||
@@ -26,6 +28,7 @@ import {
|
||||
buildFormFromConfig,
|
||||
parseCIDRText,
|
||||
parseIntField,
|
||||
parseMultilineList,
|
||||
} from "@/components/config/form-model"
|
||||
import { PageHeader } from "@/components/page-header"
|
||||
import { Button } from "@/components/ui/button"
|
||||
@@ -164,6 +167,33 @@ export function ConfigPage() {
|
||||
"Heartbeat interval",
|
||||
{ min: 1 },
|
||||
)
|
||||
const cronExecTimeoutMinutes = parseIntField(
|
||||
form.cronExecTimeoutMinutes,
|
||||
"Cron exec timeout",
|
||||
{ min: 0 },
|
||||
)
|
||||
const execConfigPatch: Record<string, unknown> = {
|
||||
enabled: form.execEnabled,
|
||||
}
|
||||
|
||||
if (form.execEnabled) {
|
||||
execConfigPatch.allow_remote = form.allowRemote
|
||||
execConfigPatch.enable_deny_patterns = form.enableDenyPatterns
|
||||
execConfigPatch.custom_allow_patterns = parseMultilineList(
|
||||
form.customAllowPatternsText,
|
||||
)
|
||||
execConfigPatch.timeout_seconds = parseIntField(
|
||||
form.execTimeoutSeconds,
|
||||
"Exec timeout",
|
||||
{ min: 0 },
|
||||
)
|
||||
|
||||
if (form.enableDenyPatterns) {
|
||||
execConfigPatch.custom_deny_patterns = parseMultilineList(
|
||||
form.customDenyPatternsText,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
await patchAppConfig({
|
||||
agents: {
|
||||
@@ -180,9 +210,11 @@ export function ConfigPage() {
|
||||
dm_scope: dmScope,
|
||||
},
|
||||
tools: {
|
||||
exec: {
|
||||
allow_remote: form.allowRemote,
|
||||
cron: {
|
||||
allow_command: form.allowCommand,
|
||||
exec_timeout_minutes: cronExecTimeoutMinutes,
|
||||
},
|
||||
exec: execConfigPatch,
|
||||
},
|
||||
heartbeat: {
|
||||
enabled: form.heartbeatEnabled,
|
||||
@@ -279,6 +311,10 @@ export function ConfigPage() {
|
||||
|
||||
<RuntimeSection form={form} onFieldChange={updateField} />
|
||||
|
||||
<ExecSection form={form} onFieldChange={updateField} />
|
||||
|
||||
<CronSection form={form} onFieldChange={updateField} />
|
||||
|
||||
<LauncherSection
|
||||
launcherForm={launcherForm}
|
||||
onFieldChange={updateLauncherField}
|
||||
|
||||
@@ -93,14 +93,6 @@ export function AgentDefaultsSection({
|
||||
}
|
||||
/>
|
||||
|
||||
<SwitchCardField
|
||||
label={t("pages.config.allow_remote")}
|
||||
hint={t("pages.config.allow_remote_hint")}
|
||||
layout="setting-row"
|
||||
checked={form.allowRemote}
|
||||
onCheckedChange={(checked) => onFieldChange("allowRemote", checked)}
|
||||
/>
|
||||
|
||||
<Field
|
||||
label={t("pages.config.max_tokens")}
|
||||
hint={t("pages.config.max_tokens_hint")}
|
||||
@@ -161,6 +153,98 @@ export function AgentDefaultsSection({
|
||||
)
|
||||
}
|
||||
|
||||
interface ExecSectionProps {
|
||||
form: CoreConfigForm
|
||||
onFieldChange: UpdateCoreField
|
||||
}
|
||||
|
||||
export function ExecSection({ form, onFieldChange }: ExecSectionProps) {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<ConfigSectionCard title={t("pages.config.sections.exec")}>
|
||||
<SwitchCardField
|
||||
label={t("pages.config.exec_enabled")}
|
||||
hint={t("pages.config.exec_enabled_hint")}
|
||||
layout="setting-row"
|
||||
checked={form.execEnabled}
|
||||
onCheckedChange={(checked) => onFieldChange("execEnabled", checked)}
|
||||
/>
|
||||
|
||||
{form.execEnabled && (
|
||||
<>
|
||||
<SwitchCardField
|
||||
label={t("pages.config.allow_remote")}
|
||||
hint={t("pages.config.allow_remote_hint")}
|
||||
layout="setting-row"
|
||||
checked={form.allowRemote}
|
||||
onCheckedChange={(checked) => onFieldChange("allowRemote", checked)}
|
||||
/>
|
||||
|
||||
<SwitchCardField
|
||||
label={t("pages.config.enable_deny_patterns")}
|
||||
hint={t("pages.config.enable_deny_patterns_hint")}
|
||||
layout="setting-row"
|
||||
checked={form.enableDenyPatterns}
|
||||
onCheckedChange={(checked) =>
|
||||
onFieldChange("enableDenyPatterns", checked)
|
||||
}
|
||||
/>
|
||||
|
||||
{form.enableDenyPatterns && (
|
||||
<Field
|
||||
label={t("pages.config.custom_deny_patterns")}
|
||||
hint={t("pages.config.custom_deny_patterns_hint")}
|
||||
layout="setting-row"
|
||||
controlClassName="md:max-w-md"
|
||||
>
|
||||
<Textarea
|
||||
value={form.customDenyPatternsText}
|
||||
placeholder={t("pages.config.custom_patterns_placeholder")}
|
||||
className="min-h-[88px]"
|
||||
onChange={(e) =>
|
||||
onFieldChange("customDenyPatternsText", e.target.value)
|
||||
}
|
||||
/>
|
||||
</Field>
|
||||
)}
|
||||
|
||||
<Field
|
||||
label={t("pages.config.custom_allow_patterns")}
|
||||
hint={t("pages.config.custom_allow_patterns_hint")}
|
||||
layout="setting-row"
|
||||
controlClassName="md:max-w-md"
|
||||
>
|
||||
<Textarea
|
||||
value={form.customAllowPatternsText}
|
||||
placeholder={t("pages.config.custom_patterns_placeholder")}
|
||||
className="min-h-[88px]"
|
||||
onChange={(e) =>
|
||||
onFieldChange("customAllowPatternsText", e.target.value)
|
||||
}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Field
|
||||
label={t("pages.config.exec_timeout_seconds")}
|
||||
hint={t("pages.config.exec_timeout_seconds_hint")}
|
||||
layout="setting-row"
|
||||
>
|
||||
<Input
|
||||
type="number"
|
||||
min={0}
|
||||
value={form.execTimeoutSeconds}
|
||||
onChange={(e) =>
|
||||
onFieldChange("execTimeoutSeconds", e.target.value)
|
||||
}
|
||||
/>
|
||||
</Field>
|
||||
</>
|
||||
)}
|
||||
</ConfigSectionCard>
|
||||
)
|
||||
}
|
||||
|
||||
interface RuntimeSectionProps {
|
||||
form: CoreConfigForm
|
||||
onFieldChange: UpdateCoreField
|
||||
@@ -236,6 +320,44 @@ export function RuntimeSection({ form, onFieldChange }: RuntimeSectionProps) {
|
||||
)
|
||||
}
|
||||
|
||||
interface CronSectionProps {
|
||||
form: CoreConfigForm
|
||||
onFieldChange: UpdateCoreField
|
||||
}
|
||||
|
||||
export function CronSection({ form, onFieldChange }: CronSectionProps) {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<ConfigSectionCard title={t("pages.config.sections.cron")}>
|
||||
<SwitchCardField
|
||||
label={t("pages.config.allow_shell_execution")}
|
||||
hint={t("pages.config.allow_shell_execution_hint")}
|
||||
layout="setting-row"
|
||||
checked={form.allowCommand}
|
||||
disabled={!form.execEnabled}
|
||||
onCheckedChange={(checked) => onFieldChange("allowCommand", checked)}
|
||||
/>
|
||||
|
||||
<Field
|
||||
label={t("pages.config.cron_exec_timeout")}
|
||||
hint={t("pages.config.cron_exec_timeout_hint")}
|
||||
layout="setting-row"
|
||||
>
|
||||
<Input
|
||||
type="number"
|
||||
min={0}
|
||||
disabled={!form.execEnabled}
|
||||
value={form.cronExecTimeoutMinutes}
|
||||
onChange={(e) =>
|
||||
onFieldChange("cronExecTimeoutMinutes", e.target.value)
|
||||
}
|
||||
/>
|
||||
</Field>
|
||||
</ConfigSectionCard>
|
||||
)
|
||||
}
|
||||
|
||||
interface LauncherSectionProps {
|
||||
launcherForm: LauncherForm
|
||||
onFieldChange: UpdateLauncherField
|
||||
|
||||
@@ -3,7 +3,14 @@ export type JsonRecord = Record<string, unknown>
|
||||
export interface CoreConfigForm {
|
||||
workspace: string
|
||||
restrictToWorkspace: boolean
|
||||
execEnabled: boolean
|
||||
allowRemote: boolean
|
||||
enableDenyPatterns: boolean
|
||||
customDenyPatternsText: string
|
||||
customAllowPatternsText: string
|
||||
execTimeoutSeconds: string
|
||||
allowCommand: boolean
|
||||
cronExecTimeoutMinutes: string
|
||||
maxTokens: string
|
||||
maxToolIterations: string
|
||||
summarizeMessageThreshold: string
|
||||
@@ -55,7 +62,14 @@ export const DM_SCOPE_OPTIONS = [
|
||||
export const EMPTY_FORM: CoreConfigForm = {
|
||||
workspace: "",
|
||||
restrictToWorkspace: true,
|
||||
execEnabled: true,
|
||||
allowRemote: true,
|
||||
enableDenyPatterns: true,
|
||||
customDenyPatternsText: "",
|
||||
customAllowPatternsText: "",
|
||||
execTimeoutSeconds: "0",
|
||||
allowCommand: true,
|
||||
cronExecTimeoutMinutes: "5",
|
||||
maxTokens: "32768",
|
||||
maxToolIterations: "50",
|
||||
summarizeMessageThreshold: "20",
|
||||
@@ -106,6 +120,7 @@ export function buildFormFromConfig(config: unknown): CoreConfigForm {
|
||||
const heartbeat = asRecord(root.heartbeat)
|
||||
const devices = asRecord(root.devices)
|
||||
const tools = asRecord(root.tools)
|
||||
const cron = asRecord(tools.cron)
|
||||
const exec = asRecord(tools.exec)
|
||||
|
||||
return {
|
||||
@@ -114,10 +129,40 @@ export function buildFormFromConfig(config: unknown): CoreConfigForm {
|
||||
defaults.restrict_to_workspace === undefined
|
||||
? EMPTY_FORM.restrictToWorkspace
|
||||
: asBool(defaults.restrict_to_workspace),
|
||||
execEnabled:
|
||||
exec.enabled === undefined
|
||||
? EMPTY_FORM.execEnabled
|
||||
: asBool(exec.enabled),
|
||||
allowRemote:
|
||||
exec.allow_remote === undefined
|
||||
? EMPTY_FORM.allowRemote
|
||||
: asBool(exec.allow_remote),
|
||||
enableDenyPatterns:
|
||||
exec.enable_deny_patterns === undefined
|
||||
? EMPTY_FORM.enableDenyPatterns
|
||||
: asBool(exec.enable_deny_patterns),
|
||||
customDenyPatternsText: Array.isArray(exec.custom_deny_patterns)
|
||||
? exec.custom_deny_patterns
|
||||
.filter((value): value is string => typeof value === "string")
|
||||
.join("\n")
|
||||
: EMPTY_FORM.customDenyPatternsText,
|
||||
customAllowPatternsText: Array.isArray(exec.custom_allow_patterns)
|
||||
? exec.custom_allow_patterns
|
||||
.filter((value): value is string => typeof value === "string")
|
||||
.join("\n")
|
||||
: EMPTY_FORM.customAllowPatternsText,
|
||||
execTimeoutSeconds: asNumberString(
|
||||
exec.timeout_seconds,
|
||||
EMPTY_FORM.execTimeoutSeconds,
|
||||
),
|
||||
allowCommand:
|
||||
cron.allow_command === undefined
|
||||
? EMPTY_FORM.allowCommand
|
||||
: asBool(cron.allow_command),
|
||||
cronExecTimeoutMinutes: asNumberString(
|
||||
cron.exec_timeout_minutes,
|
||||
EMPTY_FORM.cronExecTimeoutMinutes,
|
||||
),
|
||||
maxTokens: asNumberString(defaults.max_tokens, EMPTY_FORM.maxTokens),
|
||||
maxToolIterations: asNumberString(
|
||||
defaults.max_tool_iterations,
|
||||
@@ -178,3 +223,13 @@ export function parseCIDRText(raw: string): string[] {
|
||||
.map((v) => v.trim())
|
||||
.filter((v) => v.length > 0)
|
||||
}
|
||||
|
||||
export function parseMultilineList(raw: string): string[] {
|
||||
if (!raw.trim()) {
|
||||
return []
|
||||
}
|
||||
return raw
|
||||
.split("\n")
|
||||
.map((value) => value.trim())
|
||||
.filter((value) => value.length > 0)
|
||||
}
|
||||
|
||||
+196
-142
@@ -2,24 +2,24 @@ import { getDefaultStore } from "jotai"
|
||||
import { toast } from "sonner"
|
||||
|
||||
import { getPicoToken } from "@/api/pico"
|
||||
import { getSessionHistory } from "@/api/sessions"
|
||||
import i18n from "@/i18n"
|
||||
import {
|
||||
loadSessionMessages,
|
||||
mergeHistoryMessages,
|
||||
} from "@/features/chat/history"
|
||||
import { type PicoMessage, handlePicoMessage } from "@/features/chat/protocol"
|
||||
import {
|
||||
clearStoredSessionId,
|
||||
generateSessionId,
|
||||
normalizeUnixTimestamp,
|
||||
readStoredSessionId,
|
||||
} from "@/lib/pico-chat-state"
|
||||
import { type ChatMessage, getChatState, updateChatStore } from "@/store/chat"
|
||||
import { gatewayAtom } from "@/store/gateway"
|
||||
|
||||
interface PicoMessage {
|
||||
type: string
|
||||
id?: string
|
||||
session_id?: string
|
||||
timestamp?: number | string
|
||||
payload?: Record<string, unknown>
|
||||
}
|
||||
} from "@/features/chat/state"
|
||||
import {
|
||||
invalidateSocket,
|
||||
isCurrentSocket,
|
||||
normalizeWsUrlForBrowser,
|
||||
} from "@/features/chat/websocket"
|
||||
import i18n from "@/i18n"
|
||||
import { getChatState, updateChatStore } from "@/store/chat"
|
||||
import { type GatewayState, gatewayAtom } from "@/store/gateway"
|
||||
|
||||
const store = getDefaultStore()
|
||||
|
||||
@@ -31,81 +31,51 @@ let initialized = false
|
||||
let unsubscribeGateway: (() => void) | null = null
|
||||
let hydratePromise: Promise<void> | null = null
|
||||
let connectionGeneration = 0
|
||||
let reconnectTimer: number | null = null
|
||||
let reconnectAttempts = 0
|
||||
let shouldMaintainConnection = false
|
||||
|
||||
async function loadSessionMessages(sessionId: string): Promise<ChatMessage[]> {
|
||||
const detail = await getSessionHistory(sessionId)
|
||||
const fallbackTime = detail.updated
|
||||
|
||||
return detail.messages.map((message, index) => ({
|
||||
id: `hist-${index}-${Date.now()}`,
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
timestamp: fallbackTime,
|
||||
}))
|
||||
function clearReconnectTimer() {
|
||||
if (reconnectTimer !== null) {
|
||||
window.clearTimeout(reconnectTimer)
|
||||
reconnectTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
function handlePicoMessage(message: PicoMessage) {
|
||||
const payload = message.payload || {}
|
||||
function shouldReconnectFor(generation: number, sessionId: string): boolean {
|
||||
return (
|
||||
shouldMaintainConnection &&
|
||||
generation === connectionGeneration &&
|
||||
sessionId === activeSessionIdRef &&
|
||||
store.get(gatewayAtom).status === "running"
|
||||
)
|
||||
}
|
||||
|
||||
switch (message.type) {
|
||||
case "message.create": {
|
||||
const content = (payload.content as string) || ""
|
||||
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
|
||||
const timestamp =
|
||||
message.timestamp !== undefined &&
|
||||
Number.isFinite(Number(message.timestamp))
|
||||
? normalizeUnixTimestamp(Number(message.timestamp))
|
||||
: Date.now()
|
||||
|
||||
updateChatStore((prev) => ({
|
||||
messages: [
|
||||
...prev.messages,
|
||||
{
|
||||
id: messageId,
|
||||
role: "assistant",
|
||||
content,
|
||||
timestamp,
|
||||
},
|
||||
],
|
||||
isTyping: false,
|
||||
}))
|
||||
break
|
||||
}
|
||||
|
||||
case "message.update": {
|
||||
const content = (payload.content as string) || ""
|
||||
const messageId = payload.message_id as string
|
||||
if (!messageId) {
|
||||
break
|
||||
}
|
||||
|
||||
updateChatStore((prev) => ({
|
||||
messages: prev.messages.map((msg) =>
|
||||
msg.id === messageId ? { ...msg, content } : msg,
|
||||
),
|
||||
}))
|
||||
break
|
||||
}
|
||||
|
||||
case "typing.start":
|
||||
updateChatStore({ isTyping: true })
|
||||
break
|
||||
|
||||
case "typing.stop":
|
||||
updateChatStore({ isTyping: false })
|
||||
break
|
||||
|
||||
case "error":
|
||||
console.error("Pico error:", payload)
|
||||
updateChatStore({ isTyping: false })
|
||||
break
|
||||
|
||||
case "pong":
|
||||
break
|
||||
|
||||
default:
|
||||
console.log("Unknown pico message type:", message.type)
|
||||
function scheduleReconnect(generation: number, sessionId: string) {
|
||||
if (!shouldReconnectFor(generation, sessionId) || reconnectTimer !== null) {
|
||||
return
|
||||
}
|
||||
|
||||
const delay = Math.min(1000 * 2 ** reconnectAttempts, 5000)
|
||||
reconnectAttempts += 1
|
||||
reconnectTimer = window.setTimeout(() => {
|
||||
reconnectTimer = null
|
||||
if (!shouldReconnectFor(generation, sessionId)) {
|
||||
return
|
||||
}
|
||||
void connectChat()
|
||||
}, delay)
|
||||
}
|
||||
|
||||
function needsActiveSessionHydration(): boolean {
|
||||
const state = getChatState()
|
||||
const storedSessionId = readStoredSessionId()
|
||||
|
||||
return Boolean(
|
||||
storedSessionId &&
|
||||
storedSessionId === state.activeSessionId &&
|
||||
!state.hasHydratedActiveSession,
|
||||
)
|
||||
}
|
||||
|
||||
function setActiveSessionId(sessionId: string) {
|
||||
@@ -113,8 +83,35 @@ function setActiveSessionId(sessionId: string) {
|
||||
updateChatStore({ activeSessionId: sessionId })
|
||||
}
|
||||
|
||||
function disconnectChatInternal({
|
||||
clearDesiredConnection,
|
||||
}: {
|
||||
clearDesiredConnection: boolean
|
||||
}) {
|
||||
connectionGeneration += 1
|
||||
clearReconnectTimer()
|
||||
|
||||
if (clearDesiredConnection) {
|
||||
shouldMaintainConnection = false
|
||||
}
|
||||
|
||||
const socket = wsRef
|
||||
wsRef = null
|
||||
isConnecting = false
|
||||
|
||||
invalidateSocket(socket)
|
||||
|
||||
updateChatStore({
|
||||
connectionState: "disconnected",
|
||||
isTyping: false,
|
||||
})
|
||||
}
|
||||
|
||||
export async function connectChat() {
|
||||
if (store.get(gatewayAtom).status !== "running") {
|
||||
if (
|
||||
store.get(gatewayAtom).status !== "running" ||
|
||||
needsActiveSessionHydration()
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -130,12 +127,15 @@ export async function connectChat() {
|
||||
const generation = connectionGeneration + 1
|
||||
connectionGeneration = generation
|
||||
isConnecting = true
|
||||
clearReconnectTimer()
|
||||
updateChatStore({ connectionState: "connecting" })
|
||||
|
||||
try {
|
||||
const { token, ws_url } = await getPicoToken()
|
||||
const sessionId = activeSessionIdRef
|
||||
|
||||
if (generation !== connectionGeneration) {
|
||||
isConnecting = false
|
||||
return
|
||||
}
|
||||
|
||||
@@ -143,55 +143,71 @@ export async function connectChat() {
|
||||
console.error("No pico token available")
|
||||
updateChatStore({ connectionState: "error" })
|
||||
isConnecting = false
|
||||
scheduleReconnect(generation, sessionId)
|
||||
return
|
||||
}
|
||||
|
||||
let finalWsUrl = ws_url
|
||||
try {
|
||||
const parsedUrl = new URL(ws_url)
|
||||
const isLocalHost =
|
||||
parsedUrl.hostname === "localhost" ||
|
||||
parsedUrl.hostname === "127.0.0.1" ||
|
||||
parsedUrl.hostname === "0.0.0.0"
|
||||
const isBrowserLocal =
|
||||
window.location.hostname === "localhost" ||
|
||||
window.location.hostname === "127.0.0.1"
|
||||
|
||||
if (isLocalHost && !isBrowserLocal) {
|
||||
parsedUrl.hostname = window.location.hostname
|
||||
finalWsUrl = parsedUrl.toString()
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn("Could not parse ws_url:", error)
|
||||
}
|
||||
|
||||
const url = `${finalWsUrl}?token=${encodeURIComponent(token)}&session_id=${encodeURIComponent(activeSessionIdRef)}`
|
||||
const socket = new WebSocket(url)
|
||||
const finalWsUrl = normalizeWsUrlForBrowser(ws_url)
|
||||
const url = `${finalWsUrl}?session_id=${encodeURIComponent(sessionId)}`
|
||||
const socket = new WebSocket(url, [`token.${token}`])
|
||||
|
||||
if (generation !== connectionGeneration) {
|
||||
socket.close()
|
||||
isConnecting = false
|
||||
invalidateSocket(socket)
|
||||
return
|
||||
}
|
||||
|
||||
socket.onopen = () => {
|
||||
if (wsRef !== socket) {
|
||||
if (
|
||||
!isCurrentSocket({
|
||||
socket,
|
||||
currentSocket: wsRef,
|
||||
generation,
|
||||
currentGeneration: connectionGeneration,
|
||||
sessionId,
|
||||
currentSessionId: activeSessionIdRef,
|
||||
})
|
||||
) {
|
||||
return
|
||||
}
|
||||
updateChatStore({ connectionState: "connected" })
|
||||
isConnecting = false
|
||||
reconnectAttempts = 0
|
||||
}
|
||||
|
||||
socket.onmessage = (event) => {
|
||||
if (
|
||||
!isCurrentSocket({
|
||||
socket,
|
||||
currentSocket: wsRef,
|
||||
generation,
|
||||
currentGeneration: connectionGeneration,
|
||||
sessionId,
|
||||
currentSessionId: activeSessionIdRef,
|
||||
})
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const message: PicoMessage = JSON.parse(event.data)
|
||||
handlePicoMessage(message)
|
||||
const message = JSON.parse(event.data) as PicoMessage
|
||||
handlePicoMessage(message, sessionId)
|
||||
} catch {
|
||||
console.warn("Non-JSON message from pico:", event.data)
|
||||
}
|
||||
}
|
||||
|
||||
socket.onclose = () => {
|
||||
if (wsRef !== socket) {
|
||||
if (
|
||||
!isCurrentSocket({
|
||||
socket,
|
||||
currentSocket: wsRef,
|
||||
generation,
|
||||
currentGeneration: connectionGeneration,
|
||||
sessionId,
|
||||
currentSessionId: activeSessionIdRef,
|
||||
})
|
||||
) {
|
||||
return
|
||||
}
|
||||
wsRef = null
|
||||
@@ -200,42 +216,42 @@ export async function connectChat() {
|
||||
connectionState: "disconnected",
|
||||
isTyping: false,
|
||||
})
|
||||
scheduleReconnect(generation, sessionId)
|
||||
}
|
||||
|
||||
socket.onerror = () => {
|
||||
if (wsRef !== socket) {
|
||||
if (
|
||||
!isCurrentSocket({
|
||||
socket,
|
||||
currentSocket: wsRef,
|
||||
generation,
|
||||
currentGeneration: connectionGeneration,
|
||||
sessionId,
|
||||
currentSessionId: activeSessionIdRef,
|
||||
})
|
||||
) {
|
||||
return
|
||||
}
|
||||
isConnecting = false
|
||||
updateChatStore({ connectionState: "error" })
|
||||
scheduleReconnect(generation, sessionId)
|
||||
}
|
||||
|
||||
wsRef = socket
|
||||
} catch (error) {
|
||||
if (generation !== connectionGeneration) {
|
||||
isConnecting = false
|
||||
return
|
||||
}
|
||||
console.error("Failed to connect to pico:", error)
|
||||
updateChatStore({ connectionState: "error" })
|
||||
isConnecting = false
|
||||
scheduleReconnect(generation, activeSessionIdRef)
|
||||
}
|
||||
}
|
||||
|
||||
export function disconnectChat() {
|
||||
connectionGeneration += 1
|
||||
|
||||
const socket = wsRef
|
||||
wsRef = null
|
||||
isConnecting = false
|
||||
|
||||
if (socket) {
|
||||
socket.close()
|
||||
}
|
||||
|
||||
updateChatStore({
|
||||
connectionState: "disconnected",
|
||||
isTyping: false,
|
||||
})
|
||||
disconnectChatInternal({ clearDesiredConnection: true })
|
||||
}
|
||||
|
||||
export async function hydrateActiveSession() {
|
||||
@@ -249,7 +265,6 @@ export async function hydrateActiveSession() {
|
||||
if (
|
||||
!storedSessionId ||
|
||||
state.hasHydratedActiveSession ||
|
||||
state.messages.length > 0 ||
|
||||
storedSessionId !== state.activeSessionId
|
||||
) {
|
||||
if (!state.hasHydratedActiveSession) {
|
||||
@@ -266,7 +281,13 @@ export async function hydrateActiveSession() {
|
||||
}
|
||||
|
||||
if (currentState.messages.length > 0) {
|
||||
updateChatStore({ hasHydratedActiveSession: true })
|
||||
updateChatStore({
|
||||
messages: mergeHistoryMessages(
|
||||
historyMessages,
|
||||
currentState.messages,
|
||||
),
|
||||
hasHydratedActiveSession: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -306,9 +327,10 @@ export async function hydrateActiveSession() {
|
||||
export function sendChatMessage(content: string) {
|
||||
if (!wsRef || wsRef.readyState !== WebSocket.OPEN) {
|
||||
console.warn("WebSocket not connected")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
const socket = wsRef
|
||||
const id = `msg-${++msgIdCounter}-${Date.now()}`
|
||||
|
||||
updateChatStore((prev) => ({
|
||||
@@ -319,13 +341,23 @@ export function sendChatMessage(content: string) {
|
||||
isTyping: true,
|
||||
}))
|
||||
|
||||
wsRef.send(
|
||||
JSON.stringify({
|
||||
type: "message.send",
|
||||
id,
|
||||
payload: { content },
|
||||
}),
|
||||
)
|
||||
try {
|
||||
socket.send(
|
||||
JSON.stringify({
|
||||
type: "message.send",
|
||||
id,
|
||||
payload: { content },
|
||||
}),
|
||||
)
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error("Failed to send pico message:", error)
|
||||
updateChatStore((prev) => ({
|
||||
messages: prev.messages.filter((message) => message.id !== id),
|
||||
isTyping: false,
|
||||
}))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export async function switchChatSession(sessionId: string) {
|
||||
@@ -336,7 +368,7 @@ export async function switchChatSession(sessionId: string) {
|
||||
try {
|
||||
const historyMessages = await loadSessionMessages(sessionId)
|
||||
|
||||
disconnectChat()
|
||||
disconnectChatInternal({ clearDesiredConnection: false })
|
||||
setActiveSessionId(sessionId)
|
||||
updateChatStore({
|
||||
messages: historyMessages,
|
||||
@@ -345,6 +377,7 @@ export async function switchChatSession(sessionId: string) {
|
||||
})
|
||||
|
||||
if (store.get(gatewayAtom).status === "running") {
|
||||
shouldMaintainConnection = true
|
||||
await connectChat()
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -358,7 +391,7 @@ export async function newChatSession() {
|
||||
return
|
||||
}
|
||||
|
||||
disconnectChat()
|
||||
disconnectChatInternal({ clearDesiredConnection: false })
|
||||
setActiveSessionId(generateSessionId())
|
||||
updateChatStore({
|
||||
messages: [],
|
||||
@@ -367,6 +400,7 @@ export async function newChatSession() {
|
||||
})
|
||||
|
||||
if (store.get(gatewayAtom).status === "running") {
|
||||
shouldMaintainConnection = true
|
||||
await connectChat()
|
||||
}
|
||||
}
|
||||
@@ -378,23 +412,43 @@ export function initializeChatStore() {
|
||||
|
||||
initialized = true
|
||||
activeSessionIdRef = getChatState().activeSessionId
|
||||
let lastGatewayStatus: GatewayState | null = null
|
||||
|
||||
const syncConnectionWithGateway = () => {
|
||||
if (store.get(gatewayAtom).status === "running") {
|
||||
const syncConnectionWithGateway = (force: boolean = false) => {
|
||||
const gatewayStatus = store.get(gatewayAtom).status
|
||||
if (!force && gatewayStatus === lastGatewayStatus) {
|
||||
return
|
||||
}
|
||||
lastGatewayStatus = gatewayStatus
|
||||
|
||||
if (gatewayStatus === "running") {
|
||||
shouldMaintainConnection = true
|
||||
if (needsActiveSessionHydration()) {
|
||||
return
|
||||
}
|
||||
void connectChat()
|
||||
return
|
||||
}
|
||||
|
||||
disconnectChat()
|
||||
if (gatewayStatus === "stopped" || gatewayStatus === "error") {
|
||||
disconnectChatInternal({ clearDesiredConnection: true })
|
||||
}
|
||||
}
|
||||
|
||||
unsubscribeGateway = store.sub(gatewayAtom, syncConnectionWithGateway)
|
||||
|
||||
if (!readStoredSessionId()) {
|
||||
updateChatStore({ hasHydratedActiveSession: true })
|
||||
syncConnectionWithGateway(true)
|
||||
return
|
||||
}
|
||||
|
||||
syncConnectionWithGateway()
|
||||
void hydrateActiveSession().finally(() => {
|
||||
if (!initialized) {
|
||||
return
|
||||
}
|
||||
syncConnectionWithGateway(true)
|
||||
})
|
||||
}
|
||||
|
||||
export function teardownChatStore() {
|
||||
@@ -0,0 +1,68 @@
|
||||
import { getSessionHistory } from "@/api/sessions"
|
||||
import { normalizeUnixTimestamp } from "@/features/chat/state"
|
||||
import type { ChatMessage } from "@/store/chat"
|
||||
|
||||
export async function loadSessionMessages(
|
||||
sessionId: string,
|
||||
): Promise<ChatMessage[]> {
|
||||
const detail = await getSessionHistory(sessionId)
|
||||
const fallbackTime = detail.updated
|
||||
|
||||
return detail.messages.map((message, index) => ({
|
||||
id: `hist-${index}-${Date.now()}`,
|
||||
role: message.role,
|
||||
content: message.content,
|
||||
timestamp: fallbackTime,
|
||||
}))
|
||||
}
|
||||
|
||||
function normalizeMessageTimestamp(timestamp: number | string): string {
|
||||
if (typeof timestamp === "number") {
|
||||
return String(normalizeUnixTimestamp(timestamp))
|
||||
}
|
||||
|
||||
const trimmed = timestamp.trim()
|
||||
if (/^-?\d+(\.\d+)?$/.test(trimmed)) {
|
||||
return String(normalizeUnixTimestamp(Number(trimmed)))
|
||||
}
|
||||
|
||||
const parsed = Date.parse(trimmed)
|
||||
return Number.isNaN(parsed) ? trimmed : String(parsed)
|
||||
}
|
||||
|
||||
function messageSignature(message: ChatMessage): string {
|
||||
return `${message.role}\u0000${message.content}\u0000${normalizeMessageTimestamp(
|
||||
message.timestamp,
|
||||
)}`
|
||||
}
|
||||
|
||||
function comparableTimestamp(timestamp: number | string): number {
|
||||
const normalized = normalizeMessageTimestamp(timestamp)
|
||||
const numeric = Number(normalized)
|
||||
return Number.isFinite(numeric) ? numeric : 0
|
||||
}
|
||||
|
||||
export function mergeHistoryMessages(
|
||||
historyMessages: ChatMessage[],
|
||||
currentMessages: ChatMessage[],
|
||||
): ChatMessage[] {
|
||||
const currentIds = new Set(currentMessages.map((message) => message.id))
|
||||
const currentSignatures = new Set(
|
||||
currentMessages.map((message) => messageSignature(message)),
|
||||
)
|
||||
|
||||
const merged = [
|
||||
...historyMessages.filter(
|
||||
(message) =>
|
||||
!currentIds.has(message.id) &&
|
||||
!currentSignatures.has(messageSignature(message)),
|
||||
),
|
||||
...currentMessages,
|
||||
]
|
||||
|
||||
return merged.sort(
|
||||
(left, right) =>
|
||||
comparableTimestamp(left.timestamp) -
|
||||
comparableTimestamp(right.timestamp),
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
import { normalizeUnixTimestamp } from "@/features/chat/state"
|
||||
import { updateChatStore } from "@/store/chat"
|
||||
|
||||
export interface PicoMessage {
|
||||
type: string
|
||||
id?: string
|
||||
session_id?: string
|
||||
timestamp?: number | string
|
||||
payload?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export function handlePicoMessage(
|
||||
message: PicoMessage,
|
||||
expectedSessionId: string,
|
||||
) {
|
||||
if (message.session_id && message.session_id !== expectedSessionId) {
|
||||
return
|
||||
}
|
||||
|
||||
const payload = message.payload || {}
|
||||
|
||||
switch (message.type) {
|
||||
case "message.create": {
|
||||
const content = (payload.content as string) || ""
|
||||
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
|
||||
const timestamp =
|
||||
message.timestamp !== undefined &&
|
||||
Number.isFinite(Number(message.timestamp))
|
||||
? normalizeUnixTimestamp(Number(message.timestamp))
|
||||
: Date.now()
|
||||
|
||||
updateChatStore((prev) => ({
|
||||
messages: [
|
||||
...prev.messages,
|
||||
{
|
||||
id: messageId,
|
||||
role: "assistant",
|
||||
content,
|
||||
timestamp,
|
||||
},
|
||||
],
|
||||
isTyping: false,
|
||||
}))
|
||||
break
|
||||
}
|
||||
|
||||
case "message.update": {
|
||||
const content = (payload.content as string) || ""
|
||||
const messageId = payload.message_id as string
|
||||
if (!messageId) {
|
||||
break
|
||||
}
|
||||
|
||||
updateChatStore((prev) => ({
|
||||
messages: prev.messages.map((msg) =>
|
||||
msg.id === messageId ? { ...msg, content } : msg,
|
||||
),
|
||||
}))
|
||||
break
|
||||
}
|
||||
|
||||
case "typing.start":
|
||||
updateChatStore({ isTyping: true })
|
||||
break
|
||||
|
||||
case "typing.stop":
|
||||
updateChatStore({ isTyping: false })
|
||||
break
|
||||
|
||||
case "error":
|
||||
console.error("Pico error:", payload)
|
||||
updateChatStore({ isTyping: false })
|
||||
break
|
||||
|
||||
case "pong":
|
||||
break
|
||||
|
||||
default:
|
||||
console.log("Unknown pico message type:", message.type)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user