mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(web): add agent management UI and improve launcher integration (#1358)
* Improve the web launcher and gateway integration across backend and frontend. - add runtime model availability checks for local and OAuth-backed models - support launcher-driven gateway host overrides and websocket URL resolution - add gateway log clearing and keep incremental log sync consistent after resets - migrate session history APIs to JSONL metadata-backed storage with legacy fallback - expose session titles and improve chat history loading and error handling - move shared backend runtime helpers into the web utils package - avoid blocking web startup when automatic onboard initialization fails - add backend tests covering gateway readiness, host resolution, models, logs, and sessions * feat(agent): add skills and tools management APIs and UI - add backend APIs to list, view, import, and delete skills - add tool status and toggle endpoints with dependency-aware config updates - add agent skills/tools pages, routes, sidebar entries, and i18n strings - add backend tests for the new skills and tools flows * chore(frontend): upgrade shadcn to 4.0.5 and refresh lockfile * chore(web): keep backend dist placeholder tracked
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -17,36 +16,11 @@ func (h *Handler) registerConfigRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("PATCH /api/config", h.handlePatchConfig)
|
||||
}
|
||||
|
||||
// loadFilteredConfig loads the configuration and filters out default placeholder credentials
|
||||
// (like API limits/keys) if the configuration file has not been created yet by the user.
|
||||
func (h *Handler) loadFilteredConfig() (*config.Config, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
configExists := false
|
||||
if h.configPath != "" {
|
||||
if _, err := os.Stat(h.configPath); err == nil {
|
||||
configExists = true
|
||||
}
|
||||
}
|
||||
|
||||
if !configExists {
|
||||
for i := range cfg.ModelList {
|
||||
cfg.ModelList[i].APIKey = ""
|
||||
cfg.ModelList[i].AuthMethod = ""
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// handleGetConfig returns the complete system configuration.
|
||||
//
|
||||
// GET /api/config
|
||||
func (h *Handler) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := h.loadFilteredConfig()
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
+28
-43
@@ -10,7 +10,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
// gateway holds the state for the managed gateway process.
|
||||
@@ -36,6 +36,7 @@ var gateway = struct {
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
|
||||
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
|
||||
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
|
||||
mux.HandleFunc("POST /api/gateway/stop", h.handleGatewayStop)
|
||||
mux.HandleFunc("POST /api/gateway/restart", h.handleGatewayRestart)
|
||||
@@ -89,11 +90,12 @@ func (h *Handler) gatewayStartReady() (bool, string, error) {
|
||||
return false, fmt.Sprintf("default model %q is invalid", modelName), nil
|
||||
}
|
||||
|
||||
hasCredential := strings.TrimSpace(modelCfg.APIKey) != "" ||
|
||||
strings.TrimSpace(modelCfg.AuthMethod) != ""
|
||||
if !hasCredential {
|
||||
if !hasModelConfiguration(*modelCfg) {
|
||||
return false, fmt.Sprintf("default model %q has no credentials configured", modelName), nil
|
||||
}
|
||||
if requiresRuntimeProbe(*modelCfg) && !probeLocalModelAvailability(*modelCfg) {
|
||||
return false, fmt.Sprintf("default model %q is not reachable", modelName), nil
|
||||
}
|
||||
|
||||
return true, "", nil
|
||||
}
|
||||
@@ -131,14 +133,18 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
|
||||
func (h *Handler) startGatewayLocked() (int, error) {
|
||||
// Locate the picoclaw executable
|
||||
execPath := findPicoclawBinary()
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
|
||||
cmd := exec.Command(execPath, "gateway")
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
// config file without requiring a --config flag on the gateway subcommand.
|
||||
if h.configPath != "" {
|
||||
cmd.Env = append(os.Environ(), "PICOCLAW_CONFIG="+h.configPath)
|
||||
cmd.Env = append(cmd.Env, "PICOCLAW_CONFIG="+h.configPath)
|
||||
}
|
||||
if host := h.gatewayHostOverride(); host != "" {
|
||||
cmd.Env = append(cmd.Env, "PICOCLAW_GATEWAY_HOST="+host)
|
||||
}
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
@@ -207,10 +213,7 @@ func (h *Handler) startGatewayLocked() (int, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthHost := "127.0.0.1"
|
||||
if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" {
|
||||
healthHost = cfg.Gateway.Host
|
||||
}
|
||||
healthHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
healthPort := cfg.Gateway.Port
|
||||
if healthPort == 0 {
|
||||
healthPort = 18790
|
||||
@@ -353,6 +356,20 @@ func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
h.handleGatewayStart(w, r)
|
||||
}
|
||||
|
||||
// handleGatewayClearLogs clears the in-memory gateway log buffer.
|
||||
//
|
||||
// POST /api/gateway/logs/clear
|
||||
func (h *Handler) handleGatewayClearLogs(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.logs.Clear()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "cleared",
|
||||
"log_total": 0,
|
||||
"log_run_id": gateway.logs.RunID(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleGatewayStatus returns the gateway run status, health info, and logs.
|
||||
//
|
||||
// GET /api/gateway/status
|
||||
@@ -375,9 +392,7 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
host := "127.0.0.1"
|
||||
port := 18790
|
||||
if err == nil && cfg != nil {
|
||||
if cfg.Gateway.Host != "" && cfg.Gateway.Host != "0.0.0.0" {
|
||||
host = cfg.Gateway.Host
|
||||
}
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
@@ -535,36 +550,6 @@ func (h *Handler) currentGatewayStatus() string {
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
// findPicoclawBinary locates the picoclaw executable.
|
||||
// Search order:
|
||||
// 1. PICOCLAW_BINARY environment variable (explicit override)
|
||||
// 2. Same directory as the current executable
|
||||
// 3. Falls back to "picoclaw" and relies on $PATH
|
||||
func findPicoclawBinary() string {
|
||||
binaryName := "picoclaw"
|
||||
if runtime.GOOS == "windows" {
|
||||
binaryName = "picoclaw.exe"
|
||||
}
|
||||
|
||||
// 1. Explicit override via environment variable
|
||||
if p := os.Getenv("PICOCLAW_BINARY"); p != "" {
|
||||
if info, _ := os.Stat(p); info != nil && !info.IsDir() {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Same directory as the launcher executable
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
candidate := filepath.Join(filepath.Dir(exe), binaryName)
|
||||
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fall back to PATH lookup
|
||||
return "picoclaw"
|
||||
}
|
||||
|
||||
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
|
||||
func scanPipe(r io.Reader, buf *LogBuffer) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func (h *Handler) effectiveLauncherPublic() bool {
|
||||
if h.serverPublicExplicit {
|
||||
return h.serverPublic
|
||||
}
|
||||
|
||||
cfg, err := h.loadLauncherConfig()
|
||||
if err == nil {
|
||||
return cfg.Public
|
||||
}
|
||||
|
||||
return h.serverPublic
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayHostOverride() string {
|
||||
if h.effectiveLauncherPublic() {
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h *Handler) effectiveGatewayBindHost(cfg *config.Config) string {
|
||||
if override := h.gatewayHostOverride(); override != "" {
|
||||
return override
|
||||
}
|
||||
if cfg == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(cfg.Gateway.Host)
|
||||
}
|
||||
|
||||
func gatewayProbeHost(bindHost string) string {
|
||||
if bindHost == "" || bindHost == "0.0.0.0" {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
return bindHost
|
||||
}
|
||||
|
||||
func requestHostName(r *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err == nil {
|
||||
return reqHost
|
||||
}
|
||||
if strings.TrimSpace(r.Host) != "" {
|
||||
return r.Host
|
||||
}
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
func (h *Handler) buildWsURL(r *http.Request, cfg *config.Config) string {
|
||||
host := h.effectiveGatewayBindHost(cfg)
|
||||
if host == "" || host == "0.0.0.0" {
|
||||
host = requestHostName(r)
|
||||
}
|
||||
return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
func TestGatewayHostOverrideUsesExplicitRuntimePublic(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
launcherPath := launcherconfig.PathForAppConfig(configPath)
|
||||
if err := launcherconfig.Save(launcherPath, launcherconfig.Config{
|
||||
Port: 18800,
|
||||
Public: false,
|
||||
}); err != nil {
|
||||
t.Fatalf("launcherconfig.Save() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
if got := h.gatewayHostOverride(); got != "0.0.0.0" {
|
||||
t.Fatalf("gatewayHostOverride() = %q, want %q", got, "0.0.0.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
launcherPath := launcherconfig.PathForAppConfig(configPath)
|
||||
if err := launcherconfig.Save(launcherPath, launcherconfig.Config{
|
||||
Port: 18800,
|
||||
Public: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("launcherconfig.Save() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, false, false, nil)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req.Host = "192.168.1.9:18800"
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18790/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18790/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
if got := gatewayProbeHost("0.0.0.0"); got != "127.0.0.1" {
|
||||
t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1")
|
||||
}
|
||||
}
|
||||
@@ -6,10 +6,13 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
@@ -32,7 +35,8 @@ func TestGatewayStartReady_InvalidDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.Model = "missing-model"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
err := config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -54,7 +58,8 @@ func TestGatewayStartReady_ValidDefaultModel(t *testing.T) {
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
err := config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -74,7 +79,8 @@ func TestGatewayStartReady_DefaultModelWithoutCredential(t *testing.T) {
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = ""
|
||||
cfg.ModelList[0].AuthMethod = ""
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
err := config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -91,6 +97,195 @@ func TestGatewayStartReady_DefaultModelWithoutCredential(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_LocalModelWithoutAPIKey(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://localhost:8000/v1",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "local-vllm"
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if ready {
|
||||
t.Fatalf("gatewayStartReady() ready = true, want false without a running local service")
|
||||
}
|
||||
if !strings.Contains(reason, "not reachable") {
|
||||
t.Fatalf("gatewayStartReady() reason = %q, want contains %q", reason, "not reachable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_LocalModelWithRunningService(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
|
||||
return apiBase == "http://127.0.0.1:8000/v1" && modelID == "custom-model"
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "local-vllm"
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if !ready {
|
||||
t.Fatalf("gatewayStartReady() ready = false, want true with a running local service (reason=%q)", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_RemoteVLLMWithAPIKeyDoesNotProbe(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
|
||||
t.Fatalf("unexpected OpenAI-compatible probe for %q (%q)", apiBase, modelID)
|
||||
return false
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "remote-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "https://models.example.com/v1",
|
||||
APIKey: "remote-key",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "remote-vllm"
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if !ready {
|
||||
t.Fatalf("gatewayStartReady() ready = false, want true for remote vllm with api key (reason=%q)", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_LocalOllamaUsesDefaultProbeBase(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
probeOllamaModelFunc = func(apiBase, modelID string) bool {
|
||||
return apiBase == "http://localhost:11434/v1" && modelID == "llama3"
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "local-ollama",
|
||||
Model: "ollama/llama3",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "local-ollama"
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if !ready {
|
||||
t.Fatalf("gatewayStartReady() ready = false, want true with default Ollama probe base (reason=%q)", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_OAuthModelRequiresStoredCredential(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "openai-oauth",
|
||||
Model: "openai/gpt-5.2",
|
||||
AuthMethod: "oauth",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "openai-oauth"
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if ready {
|
||||
t.Fatalf("gatewayStartReady() ready = true, want false without stored credential")
|
||||
}
|
||||
if !strings.Contains(reason, "no credentials configured") {
|
||||
t.Fatalf("gatewayStartReady() reason = %q, want contains %q", reason, "no credentials configured")
|
||||
}
|
||||
|
||||
err = auth.SetCredential(oauthProviderOpenAI, &auth.AuthCredential{
|
||||
AccessToken: "openai-token",
|
||||
Provider: oauthProviderOpenAI,
|
||||
AuthMethod: "oauth",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetCredential() error = %v", err)
|
||||
}
|
||||
|
||||
ready, reason, err = h.gatewayStartReady()
|
||||
if err != nil {
|
||||
t.Fatalf("gatewayStartReady() error = %v", err)
|
||||
}
|
||||
if !ready {
|
||||
t.Fatalf("gatewayStartReady() ready = false, want true with stored credential (reason=%q)", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
@@ -122,6 +317,71 @@ func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayClearLogsResetsBufferedHistory(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
gateway.logs.Clear()
|
||||
gateway.logs.Append("first line")
|
||||
gateway.logs.Append("second line")
|
||||
previousRunID := gateway.logs.RunID()
|
||||
|
||||
clearRec := httptest.NewRecorder()
|
||||
clearReq := httptest.NewRequest(http.MethodPost, "/api/gateway/logs/clear", nil)
|
||||
mux.ServeHTTP(clearRec, clearReq)
|
||||
|
||||
if clearRec.Code != http.StatusOK {
|
||||
t.Fatalf("clear status = %d, want %d", clearRec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var clearBody map[string]any
|
||||
if err := json.Unmarshal(clearRec.Body.Bytes(), &clearBody); err != nil {
|
||||
t.Fatalf("unmarshal clear response: %v", err)
|
||||
}
|
||||
|
||||
if got := clearBody["status"]; got != "cleared" {
|
||||
t.Fatalf("clear status body = %#v, want %q", got, "cleared")
|
||||
}
|
||||
|
||||
clearRunID, ok := clearBody["log_run_id"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("log_run_id missing or not number: %#v", clearBody["log_run_id"])
|
||||
}
|
||||
if int(clearRunID) <= previousRunID {
|
||||
t.Fatalf("log_run_id = %d, want > %d", int(clearRunID), previousRunID)
|
||||
}
|
||||
|
||||
statusRec := httptest.NewRecorder()
|
||||
statusReq := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/gateway/status?log_offset=0&log_run_id="+strconv.Itoa(previousRunID),
|
||||
nil,
|
||||
)
|
||||
mux.ServeHTTP(statusRec, statusReq)
|
||||
|
||||
if statusRec.Code != http.StatusOK {
|
||||
t.Fatalf("status code = %d, want %d", statusRec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var statusBody map[string]any
|
||||
if err := json.Unmarshal(statusRec.Body.Bytes(), &statusBody); err != nil {
|
||||
t.Fatalf("unmarshal status response: %v", err)
|
||||
}
|
||||
|
||||
logs, ok := statusBody["logs"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("logs missing or not array: %#v", statusBody["logs"])
|
||||
}
|
||||
if len(logs) != 0 {
|
||||
t.Fatalf("logs len = %d, want 0", len(logs))
|
||||
}
|
||||
if got := statusBody["log_total"]; got != float64(0) {
|
||||
t.Fatalf("log_total = %#v, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPicoclawBinary_EnvOverride(t *testing.T) {
|
||||
// Create a temporary file to act as the mock binary
|
||||
tmpDir := t.TempDir()
|
||||
@@ -132,9 +392,9 @@ func TestFindPicoclawBinary_EnvOverride(t *testing.T) {
|
||||
|
||||
t.Setenv("PICOCLAW_BINARY", mockBinary)
|
||||
|
||||
got := findPicoclawBinary()
|
||||
got := utils.FindPicoclawBinary()
|
||||
if got != mockBinary {
|
||||
t.Errorf("findPicoclawBinary() = %q, want %q", got, mockBinary)
|
||||
t.Errorf("FindPicoclawBinary() = %q, want %q", got, mockBinary)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,9 +402,9 @@ func TestFindPicoclawBinary_EnvOverride_InvalidPath(t *testing.T) {
|
||||
// When PICOCLAW_BINARY points to a non-existent path, fall through to next strategy
|
||||
t.Setenv("PICOCLAW_BINARY", "/nonexistent/picoclaw-binary")
|
||||
|
||||
got := findPicoclawBinary()
|
||||
got := utils.FindPicoclawBinary()
|
||||
// Should not return the invalid path; falls back to "picoclaw" or another found path
|
||||
if got == "/nonexistent/picoclaw-binary" {
|
||||
t.Errorf("findPicoclawBinary() returned invalid env path %q, expected fallback", got)
|
||||
t.Errorf("FindPicoclawBinary() returned invalid env path %q, expected fallback", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
func TestGetLauncherConfigUsesRuntimeFallback(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(19999, true, []string{"192.168.1.0/24"})
|
||||
h.SetServerOptions(19999, true, false, []string{"192.168.1.0/24"})
|
||||
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
@@ -4,7 +4,7 @@ import "sync"
|
||||
|
||||
// LogBuffer is a thread-safe ring buffer that stores the most recent N log lines.
|
||||
// It supports incremental reads via LinesSince and tracks a runID that increments
|
||||
// on each Reset (used to detect gateway restarts).
|
||||
// whenever the buffer is reset or cleared so clients can detect log history resets.
|
||||
type LogBuffer struct {
|
||||
mu sync.RWMutex
|
||||
lines []string
|
||||
@@ -45,6 +45,12 @@ func (b *LogBuffer) Reset() {
|
||||
b.runID++
|
||||
}
|
||||
|
||||
// Clear removes all buffered lines and increments the runID so clients treat
|
||||
// subsequent reads as a new log stream.
|
||||
func (b *LogBuffer) Clear() {
|
||||
b.Reset()
|
||||
}
|
||||
|
||||
// LinesSince returns lines appended after the given offset, the current total count, and the runID.
|
||||
// If offset >= total, no lines are returned. If offset is too old (evicted), all buffered lines are returned.
|
||||
func (b *LogBuffer) LinesSince(offset int) (lines []string, total int, runID int) {
|
||||
|
||||
@@ -0,0 +1,324 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
const modelProbeTimeout = 800 * time.Millisecond
|
||||
|
||||
var (
|
||||
probeTCPServiceFunc = probeTCPService
|
||||
probeOllamaModelFunc = probeOllamaModel
|
||||
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
|
||||
)
|
||||
|
||||
func hasModelConfiguration(m config.ModelConfig) bool {
|
||||
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
|
||||
apiKey := strings.TrimSpace(m.APIKey)
|
||||
|
||||
if authMethod == "oauth" || authMethod == "token" {
|
||||
if provider, ok := oauthProviderForModel(m.Model); ok {
|
||||
cred, err := oauthGetCredential(provider)
|
||||
if err != nil || cred == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(cred.AccessToken) != "" || strings.TrimSpace(cred.RefreshToken) != ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
if requiresRuntimeProbe(m) {
|
||||
return true
|
||||
}
|
||||
|
||||
return apiKey != ""
|
||||
}
|
||||
|
||||
// isModelConfigured reports whether a model is currently available to use.
|
||||
// Local models must be reachable; remote/API-key models only need saved config.
|
||||
func isModelConfigured(m config.ModelConfig) bool {
|
||||
if !hasModelConfiguration(m) {
|
||||
return false
|
||||
}
|
||||
if requiresRuntimeProbe(m) {
|
||||
return probeLocalModelAvailability(m)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func requiresRuntimeProbe(m config.ModelConfig) bool {
|
||||
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
|
||||
if authMethod == "local" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch modelProtocol(m.Model) {
|
||||
case "claude-cli", "claudecli", "codex-cli", "codexcli", "github-copilot", "copilot":
|
||||
return true
|
||||
case "ollama", "vllm":
|
||||
apiBase := strings.TrimSpace(m.APIBase)
|
||||
return apiBase == "" || hasLocalAPIBase(apiBase)
|
||||
}
|
||||
|
||||
if hasLocalAPIBase(m.APIBase) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func probeLocalModelAvailability(m config.ModelConfig) bool {
|
||||
apiBase := modelProbeAPIBase(m)
|
||||
protocol, modelID := splitModel(m.Model)
|
||||
switch protocol {
|
||||
case "ollama":
|
||||
return probeOllamaModelFunc(apiBase, modelID)
|
||||
case "vllm":
|
||||
return probeOpenAICompatibleModelFunc(apiBase, modelID)
|
||||
case "github-copilot", "copilot":
|
||||
return probeTCPServiceFunc(apiBase)
|
||||
case "claude-cli", "claudecli", "codex-cli", "codexcli":
|
||||
return true
|
||||
default:
|
||||
if hasLocalAPIBase(apiBase) {
|
||||
return probeOpenAICompatibleModelFunc(apiBase, modelID)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func modelProbeAPIBase(m config.ModelConfig) string {
|
||||
if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
|
||||
return normalizeModelProbeAPIBase(apiBase)
|
||||
}
|
||||
|
||||
switch modelProtocol(m.Model) {
|
||||
case "ollama":
|
||||
return "http://localhost:11434/v1"
|
||||
case "vllm":
|
||||
return "http://localhost:8000/v1"
|
||||
case "github-copilot", "copilot":
|
||||
return "localhost:4321"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeModelProbeAPIBase(raw string) string {
|
||||
u, err := parseAPIBase(raw)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
switch strings.ToLower(u.Hostname()) {
|
||||
case "0.0.0.0":
|
||||
u.Host = net.JoinHostPort("127.0.0.1", u.Port())
|
||||
case "::":
|
||||
u.Host = net.JoinHostPort("::1", u.Port())
|
||||
default:
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
|
||||
if u.Port() == "" {
|
||||
u.Host = u.Hostname()
|
||||
}
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func oauthProviderForModel(model string) (string, bool) {
|
||||
switch modelProtocol(model) {
|
||||
case "openai":
|
||||
return oauthProviderOpenAI, true
|
||||
case "anthropic":
|
||||
return oauthProviderAnthropic, true
|
||||
case "antigravity", "google-antigravity":
|
||||
return oauthProviderGoogleAntigravity, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func modelProtocol(model string) string {
|
||||
protocol, _ := splitModel(model)
|
||||
return protocol
|
||||
}
|
||||
|
||||
func splitModel(model string) (protocol, modelID string) {
|
||||
model = strings.ToLower(strings.TrimSpace(model))
|
||||
protocol, _, found := strings.Cut(model, "/")
|
||||
if !found {
|
||||
return "openai", model
|
||||
}
|
||||
return protocol, strings.TrimSpace(model[strings.Index(model, "/")+1:])
|
||||
}
|
||||
|
||||
func hasLocalAPIBase(raw string) bool {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err != nil || u.Hostname() == "" {
|
||||
u, err = url.Parse("//" + raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
switch strings.ToLower(u.Hostname()) {
|
||||
case "localhost", "127.0.0.1", "::1", "0.0.0.0":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func probeTCPService(raw string) bool {
|
||||
hostPort, err := hostPortFromAPIBase(raw)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func probeOllamaModel(apiBase, modelID string) bool {
|
||||
root, err := apiRootFromAPIBase(apiBase)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Models []struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
} `json:"models"`
|
||||
}
|
||||
if err := getJSON(root+"/api/tags", &resp); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, model := range resp.Models {
|
||||
if ollamaModelMatches(model.Name, modelID) || ollamaModelMatches(model.Model, modelID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func probeOpenAICompatibleModel(apiBase, modelID string) bool {
|
||||
if strings.TrimSpace(apiBase) == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := getJSON(strings.TrimRight(strings.TrimSpace(apiBase), "/")+"/models", &resp); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, model := range resp.Data {
|
||||
if strings.EqualFold(strings.TrimSpace(model.ID), modelID) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getJSON(rawURL string, out any) error {
|
||||
req, err := http.NewRequest(http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: modelProbeTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return json.NewDecoder(resp.Body).Decode(out)
|
||||
}
|
||||
|
||||
func apiRootFromAPIBase(raw string) (string, error) {
|
||||
u, err := parseAPIBase(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return (&url.URL{Scheme: u.Scheme, Host: u.Host}).String(), nil
|
||||
}
|
||||
|
||||
func hostPortFromAPIBase(raw string) (string, error) {
|
||||
u, err := parseAPIBase(raw)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if port := u.Port(); port != "" {
|
||||
return u.Host, nil
|
||||
}
|
||||
switch strings.ToLower(u.Scheme) {
|
||||
case "https":
|
||||
return net.JoinHostPort(u.Hostname(), "443"), nil
|
||||
default:
|
||||
return net.JoinHostPort(u.Hostname(), "80"), nil
|
||||
}
|
||||
}
|
||||
|
||||
func parseAPIBase(raw string) (*url.URL, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil, fmt.Errorf("empty api base")
|
||||
}
|
||||
|
||||
u, err := url.Parse(raw)
|
||||
if err == nil && u.Hostname() != "" {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
u, err = url.Parse("//" + raw)
|
||||
if err != nil || u.Hostname() == "" {
|
||||
return nil, fmt.Errorf("invalid api base %q", raw)
|
||||
}
|
||||
if u.Scheme == "" {
|
||||
u.Scheme = "http"
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func ollamaModelMatches(candidate, want string) bool {
|
||||
candidate = strings.TrimSpace(candidate)
|
||||
want = strings.TrimSpace(want)
|
||||
if candidate == "" || want == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(candidate, want) {
|
||||
return true
|
||||
}
|
||||
|
||||
base, _, _ := strings.Cut(candidate, ":")
|
||||
return strings.EqualFold(base, want)
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -45,13 +46,24 @@ type modelResponse struct {
|
||||
//
|
||||
// GET /api/models
|
||||
func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := h.loadFilteredConfig()
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
defaultModel := cfg.Agents.Defaults.GetModelName()
|
||||
configured := make([]bool, len(cfg.ModelList))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(cfg.ModelList))
|
||||
for i, m := range cfg.ModelList {
|
||||
go func(i int, m config.ModelConfig) {
|
||||
defer wg.Done()
|
||||
configured[i] = isModelConfigured(m)
|
||||
}(i, m)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
models := make([]modelResponse, 0, len(cfg.ModelList))
|
||||
for i, m := range cfg.ModelList {
|
||||
@@ -69,7 +81,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
MaxTokensField: m.MaxTokensField,
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
Configured: m.APIKey != "" || m.AuthMethod != "",
|
||||
Configured: configured[i],
|
||||
IsDefault: m.ModelName == defaultModel,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,313 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func resetModelProbeHooks(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
origTCPProbe := probeTCPServiceFunc
|
||||
origOllamaProbe := probeOllamaModelFunc
|
||||
origOpenAIProbe := probeOpenAICompatibleModelFunc
|
||||
t.Cleanup(func() {
|
||||
probeTCPServiceFunc = origTCPProbe
|
||||
probeOllamaModelFunc = origOllamaProbe
|
||||
probeOpenAICompatibleModelFunc = origOpenAIProbe
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleListModels_ConfiguredStatusUsesRuntimeProbesForLocalModels(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
var mu sync.Mutex
|
||||
var openAIProbes []string
|
||||
var ollamaProbes []string
|
||||
var tcpProbes []string
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
|
||||
mu.Lock()
|
||||
openAIProbes = append(openAIProbes, apiBase+"|"+modelID)
|
||||
mu.Unlock()
|
||||
return apiBase == "http://127.0.0.1:8000/v1" && modelID == "custom-model"
|
||||
}
|
||||
probeOllamaModelFunc = func(apiBase, modelID string) bool {
|
||||
mu.Lock()
|
||||
ollamaProbes = append(ollamaProbes, apiBase+"|"+modelID)
|
||||
mu.Unlock()
|
||||
return apiBase == "http://localhost:11434/v1" && modelID == "llama3"
|
||||
}
|
||||
probeTCPServiceFunc = func(apiBase string) bool {
|
||||
mu.Lock()
|
||||
tcpProbes = append(tcpProbes, apiBase)
|
||||
mu.Unlock()
|
||||
return apiBase == "http://127.0.0.1:4321"
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{
|
||||
ModelName: "openai-oauth",
|
||||
Model: "openai/gpt-5.2",
|
||||
AuthMethod: "oauth",
|
||||
},
|
||||
{
|
||||
ModelName: "vllm-local",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
},
|
||||
{
|
||||
ModelName: "ollama-default",
|
||||
Model: "ollama/llama3",
|
||||
},
|
||||
{
|
||||
ModelName: "vllm-remote",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "https://models.example.com/v1",
|
||||
APIKey: "remote-key",
|
||||
},
|
||||
{
|
||||
ModelName: "copilot-gpt-5.2",
|
||||
Model: "github-copilot/gpt-5.2",
|
||||
APIBase: "http://127.0.0.1:4321",
|
||||
AuthMethod: "oauth",
|
||||
},
|
||||
}
|
||||
cfg.Agents.Defaults.ModelName = "openai-oauth"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Models []modelResponse `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
got := make(map[string]bool, len(resp.Models))
|
||||
for _, model := range resp.Models {
|
||||
got[model.ModelName] = model.Configured
|
||||
}
|
||||
|
||||
if got["openai-oauth"] {
|
||||
t.Fatalf("openai oauth model configured = true, want false without stored credential")
|
||||
}
|
||||
if !got["vllm-local"] {
|
||||
t.Fatalf("vllm local model configured = false, want true when local probe succeeds")
|
||||
}
|
||||
if !got["ollama-default"] {
|
||||
t.Fatalf("ollama default model configured = false, want true when default local probe succeeds")
|
||||
}
|
||||
if !got["vllm-remote"] {
|
||||
t.Fatalf("remote vllm model configured = false, want true with api_key")
|
||||
}
|
||||
if !got["copilot-gpt-5.2"] {
|
||||
t.Fatalf("copilot model configured = false, want true when local bridge probe succeeds")
|
||||
}
|
||||
if len(openAIProbes) != 1 || openAIProbes[0] != "http://127.0.0.1:8000/v1|custom-model" {
|
||||
t.Fatalf("openAI probes = %#v, want only local vllm probe", openAIProbes)
|
||||
}
|
||||
if len(ollamaProbes) != 1 || ollamaProbes[0] != "http://localhost:11434/v1|llama3" {
|
||||
t.Fatalf("ollama probes = %#v, want default local probe", ollamaProbes)
|
||||
}
|
||||
if len(tcpProbes) != 1 || tcpProbes[0] != "http://127.0.0.1:4321" {
|
||||
t.Fatalf("tcp probes = %#v, want only local copilot probe", tcpProbes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListModels_ConfiguredStatusForOAuthModelWithCredential(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "claude-oauth",
|
||||
Model: "anthropic/claude-sonnet-4.6",
|
||||
AuthMethod: "oauth",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "claude-oauth"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if err := auth.SetCredential(oauthProviderAnthropic, &auth.AuthCredential{
|
||||
AccessToken: "anthropic-token",
|
||||
Provider: oauthProviderAnthropic,
|
||||
AuthMethod: "oauth",
|
||||
}); err != nil {
|
||||
t.Fatalf("SetCredential() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Models []modelResponse `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Models) != 1 {
|
||||
t.Fatalf("len(models) = %d, want 1", len(resp.Models))
|
||||
}
|
||||
if !resp.Models[0].Configured {
|
||||
t.Fatalf("oauth model configured = false, want true with stored credential")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListModels_ProbesLocalModelsConcurrently(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
started := make(chan string, 2)
|
||||
release := make(chan struct{})
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
|
||||
started <- apiBase + "|" + modelID
|
||||
<-release
|
||||
return true
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{
|
||||
ModelName: "local-vllm-a",
|
||||
Model: "vllm/custom-a",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
},
|
||||
{
|
||||
ModelName: "local-vllm-b",
|
||||
Model: "vllm/custom-b",
|
||||
APIBase: "http://127.0.0.1:8001/v1",
|
||||
},
|
||||
}
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
recCh := make(chan *httptest.ResponseRecorder, 1)
|
||||
go func() {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
recCh <- rec
|
||||
}()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("expected both local probes to start before the first one completed")
|
||||
}
|
||||
}
|
||||
close(release)
|
||||
|
||||
rec := <-recCh
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListModels_NormalizesWildcardLocalAPIBaseForProbe(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
resetOAuthHooks(t)
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
var gotProbe string
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID string) bool {
|
||||
gotProbe = apiBase + "|" + modelID
|
||||
return apiBase == "http://127.0.0.1:8000/v1" && modelID == "custom-model"
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "vllm-local",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://0.0.0.0:8000/v1",
|
||||
}}
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/models", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Models []modelResponse `json:"models"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Models) != 1 {
|
||||
t.Fatalf("len(models) = %d, want 1", len(resp.Models))
|
||||
}
|
||||
if !resp.Models[0].Configured {
|
||||
t.Fatal("wildcard-bound local model configured = false, want true after probe host normalization")
|
||||
}
|
||||
if gotProbe != "http://127.0.0.1:8000/v1|custom-model" {
|
||||
t.Fatalf("probe api base = %q, want %q", gotProbe, "http://127.0.0.1:8000/v1|custom-model")
|
||||
}
|
||||
}
|
||||
+3
-21
@@ -5,9 +5,7 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -30,7 +28,7 @@ func (h *Handler) handleGetPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := buildWsURL(r, cfg)
|
||||
wsURL := h.buildWsURL(r, cfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
@@ -58,7 +56,7 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := fmt.Sprintf("ws://%s/pico/ws", net.JoinHostPort(cfg.Gateway.Host, strconv.Itoa(cfg.Gateway.Port)))
|
||||
wsURL := h.buildWsURL(r, cfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
@@ -123,7 +121,7 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
wsURL := buildWsURL(r, cfg)
|
||||
wsURL := h.buildWsURL(r, cfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
@@ -134,22 +132,6 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// buildWsURL creates a WebSocket URL for the Pico Channel.
|
||||
// When the gateway host is "0.0.0.0" or empty, it uses the hostname from the
|
||||
// incoming HTTP request so the browser gets a connectable address.
|
||||
func buildWsURL(r *http.Request, cfg *config.Config) string {
|
||||
host := cfg.Gateway.Host
|
||||
if host == "" || host == "0.0.0.0" {
|
||||
// Use the hostname the browser used to reach this backend
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
reqHost = r.Host // r.Host might not have a port
|
||||
}
|
||||
host = reqHost
|
||||
}
|
||||
return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
|
||||
}
|
||||
|
||||
// generateSecureToken creates a random 32-character hex string.
|
||||
func generateSecureToken() string {
|
||||
b := make([]byte, 16)
|
||||
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
|
||||
// Handler serves HTTP API requests.
|
||||
type Handler struct {
|
||||
configPath string
|
||||
serverPort int
|
||||
serverPublic bool
|
||||
serverCIDRs []string
|
||||
oauthMu sync.Mutex
|
||||
oauthFlows map[string]*oauthFlow
|
||||
oauthState map[string]string
|
||||
configPath string
|
||||
serverPort int
|
||||
serverPublic bool
|
||||
serverPublicExplicit bool
|
||||
serverCIDRs []string
|
||||
oauthMu sync.Mutex
|
||||
oauthFlows map[string]*oauthFlow
|
||||
oauthState map[string]string
|
||||
}
|
||||
|
||||
// NewHandler creates an instance of the API handler.
|
||||
@@ -29,9 +30,10 @@ func NewHandler(configPath string) *Handler {
|
||||
}
|
||||
|
||||
// SetServerOptions stores current backend listen options for fallback behavior.
|
||||
func (h *Handler) SetServerOptions(port int, public bool, allowedCIDRs []string) {
|
||||
func (h *Handler) SetServerOptions(port int, public bool, publicExplicit bool, allowedCIDRs []string) {
|
||||
h.serverPort = port
|
||||
h.serverPublic = public
|
||||
h.serverPublicExplicit = publicExplicit
|
||||
h.serverCIDRs = append([]string(nil), allowedCIDRs...)
|
||||
}
|
||||
|
||||
@@ -58,6 +60,10 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Channel catalog (for frontend navigation/config pages)
|
||||
h.registerChannelRoutes(mux)
|
||||
|
||||
// Skills and tools support/actions
|
||||
h.registerSkillRoutes(mux)
|
||||
h.registerToolRoutes(mux)
|
||||
|
||||
// OS startup / launch-at-login
|
||||
h.registerStartupRoutes(mux)
|
||||
|
||||
|
||||
+284
-64
@@ -1,7 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -33,12 +35,22 @@ type sessionFile struct {
|
||||
// sessionListItem is a lightweight summary returned by GET /api/sessions.
|
||||
type sessionListItem struct {
|
||||
ID string `json:"id"`
|
||||
Title string `json:"title"`
|
||||
Preview string `json:"preview"`
|
||||
MessageCount int `json:"message_count"`
|
||||
Created string `json:"created"`
|
||||
Updated string `json:"updated"`
|
||||
}
|
||||
|
||||
type sessionMetaFile struct {
|
||||
Key string `json:"key"`
|
||||
Summary string `json:"summary"`
|
||||
Skip int `json:"skip"`
|
||||
Count int `json:"count"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// picoSessionPrefix is the key prefix used by the gateway's routing for Pico
|
||||
// channel sessions. The full key format is:
|
||||
//
|
||||
@@ -47,7 +59,12 @@ type sessionListItem struct {
|
||||
// The sanitized filename replaces ':' with '_', so on disk it becomes:
|
||||
//
|
||||
// agent_main_pico_direct_pico_<session-uuid>.json
|
||||
const picoSessionPrefix = "agent:main:pico:direct:pico:"
|
||||
const (
|
||||
picoSessionPrefix = "agent:main:pico:direct:pico:"
|
||||
sanitizedPicoSessionPrefix = "agent_main_pico_direct_pico_"
|
||||
maxSessionJSONLLineSize = 10 * 1024 * 1024 // 10 MB
|
||||
maxSessionTitleRunes = 60
|
||||
)
|
||||
|
||||
// extractPicoSessionID extracts the session UUID from a full session key.
|
||||
// Returns the UUID and true if the key matches the Pico session pattern.
|
||||
@@ -58,6 +75,178 @@ func extractPicoSessionID(key string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
func extractPicoSessionIDFromSanitizedKey(key string) (string, bool) {
|
||||
if strings.HasPrefix(key, sanitizedPicoSessionPrefix) {
|
||||
return strings.TrimPrefix(key, sanitizedPicoSessionPrefix), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func sanitizeSessionKey(key string) string {
|
||||
return strings.ReplaceAll(key, ":", "_")
|
||||
}
|
||||
|
||||
func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error) {
|
||||
path := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)+".json")
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return sessionFile{}, err
|
||||
}
|
||||
|
||||
var sess sessionFile
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
return sessionFile{}, err
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (h *Handler) readSessionMeta(path, sessionKey string) (sessionMetaFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if os.IsNotExist(err) {
|
||||
return sessionMetaFile{Key: sessionKey}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return sessionMetaFile{}, err
|
||||
}
|
||||
|
||||
var meta sessionMetaFile
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return sessionMetaFile{}, err
|
||||
}
|
||||
if meta.Key == "" {
|
||||
meta.Key = sessionKey
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (h *Handler) readSessionMessages(path string, skip int) ([]providers.Message, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
msgs := make([]providers.Message, 0)
|
||||
scanner := bufio.NewScanner(f)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), maxSessionJSONLLineSize)
|
||||
|
||||
seen := 0
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
seen++
|
||||
if seen <= skip {
|
||||
continue
|
||||
}
|
||||
|
||||
var msg providers.Message
|
||||
if err := json.Unmarshal(line, &msg); err != nil {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) {
|
||||
sessionKey := picoSessionPrefix + sessionID
|
||||
base := filepath.Join(dir, sanitizeSessionKey(sessionKey))
|
||||
jsonlPath := base + ".jsonl"
|
||||
metaPath := base + ".meta.json"
|
||||
|
||||
meta, err := h.readSessionMeta(metaPath, sessionKey)
|
||||
if err != nil {
|
||||
return sessionFile{}, err
|
||||
}
|
||||
|
||||
messages, err := h.readSessionMessages(jsonlPath, meta.Skip)
|
||||
if err != nil {
|
||||
return sessionFile{}, err
|
||||
}
|
||||
|
||||
updated := meta.UpdatedAt
|
||||
created := meta.CreatedAt
|
||||
if created.IsZero() || updated.IsZero() {
|
||||
if info, statErr := os.Stat(jsonlPath); statErr == nil {
|
||||
if created.IsZero() {
|
||||
created = info.ModTime()
|
||||
}
|
||||
if updated.IsZero() {
|
||||
updated = info.ModTime()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sessionFile{
|
||||
Key: meta.Key,
|
||||
Messages: messages,
|
||||
Summary: meta.Summary,
|
||||
Created: created,
|
||||
Updated: updated,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem {
|
||||
preview := ""
|
||||
for _, msg := range sess.Messages {
|
||||
if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" {
|
||||
preview = msg.Content
|
||||
break
|
||||
}
|
||||
}
|
||||
title := strings.TrimSpace(sess.Summary)
|
||||
if title == "" {
|
||||
title = preview
|
||||
}
|
||||
|
||||
title = truncateRunes(title, maxSessionTitleRunes)
|
||||
preview = truncateRunes(preview, maxSessionTitleRunes)
|
||||
|
||||
if preview == "" {
|
||||
preview = "(empty)"
|
||||
}
|
||||
if title == "" {
|
||||
title = preview
|
||||
}
|
||||
|
||||
validMessageCount := 0
|
||||
for _, msg := range sess.Messages {
|
||||
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
|
||||
validMessageCount++
|
||||
}
|
||||
}
|
||||
|
||||
return sessionListItem{
|
||||
ID: sessionID,
|
||||
Title: title,
|
||||
Preview: preview,
|
||||
MessageCount: validMessageCount,
|
||||
Created: sess.Created.Format(time.RFC3339),
|
||||
Updated: sess.Updated.Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
|
||||
func isEmptySession(sess sessionFile) bool {
|
||||
return len(sess.Messages) == 0 && strings.TrimSpace(sess.Summary) == ""
|
||||
}
|
||||
|
||||
func truncateRunes(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return ""
|
||||
}
|
||||
runes := []rune(strings.TrimSpace(s))
|
||||
if len(runes) <= maxLen {
|
||||
return string(runes)
|
||||
}
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
// sessionsDir resolves the path to the gateway's session storage directory.
|
||||
// It reads the workspace from config, falling back to ~/.picoclaw/workspace.
|
||||
func (h *Handler) sessionsDir() (string, error) {
|
||||
@@ -104,58 +293,76 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
items := []sessionListItem{}
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
var (
|
||||
sessionID string
|
||||
sess sessionFile
|
||||
loadErr error
|
||||
ok bool
|
||||
)
|
||||
|
||||
var sess sessionFile
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Only include Pico channel sessions
|
||||
sessionID, ok := extractPicoSessionID(sess.Key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build a preview from the first user message
|
||||
preview := ""
|
||||
for _, msg := range sess.Messages {
|
||||
if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" {
|
||||
preview = msg.Content
|
||||
break
|
||||
switch {
|
||||
case strings.HasSuffix(name, ".jsonl"):
|
||||
sessionID, ok = extractPicoSessionIDFromSanitizedKey(strings.TrimSuffix(name, ".jsonl"))
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len([]rune(preview)) > 60 {
|
||||
preview = string([]rune(preview)[:60]) + "..."
|
||||
}
|
||||
if preview == "" {
|
||||
preview = "(empty)"
|
||||
}
|
||||
|
||||
// Only count non-empty user and assistant messages
|
||||
validMessageCount := 0
|
||||
for _, msg := range sess.Messages {
|
||||
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
|
||||
validMessageCount++
|
||||
sess, loadErr = h.readJSONLSession(dir, sessionID)
|
||||
if loadErr == nil && isEmptySession(sess) {
|
||||
continue
|
||||
}
|
||||
case strings.HasSuffix(name, ".meta.json"):
|
||||
continue
|
||||
case filepath.Ext(name) == ".json":
|
||||
base := strings.TrimSuffix(name, ".json")
|
||||
if _, statErr := os.Stat(filepath.Join(dir, base+".jsonl")); statErr == nil {
|
||||
if jsonlSessionID, found := extractPicoSessionIDFromSanitizedKey(base); found {
|
||||
if jsonlSess, jsonlErr := h.readJSONLSession(
|
||||
dir,
|
||||
jsonlSessionID,
|
||||
); jsonlErr == nil &&
|
||||
!isEmptySession(jsonlSess) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, name))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
continue
|
||||
}
|
||||
if isEmptySession(sess) {
|
||||
continue
|
||||
}
|
||||
sessionID, ok = extractPicoSessionID(sess.Key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[sessionID]; exists {
|
||||
continue
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
items = append(items, sessionListItem{
|
||||
ID: sessionID,
|
||||
Preview: preview,
|
||||
MessageCount: validMessageCount,
|
||||
Created: sess.Created.Format(time.RFC3339),
|
||||
Updated: sess.Updated.Format(time.RFC3339),
|
||||
})
|
||||
if loadErr != nil {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[sessionID]; exists {
|
||||
continue
|
||||
}
|
||||
|
||||
seen[sessionID] = struct{}{}
|
||||
items = append(items, buildSessionListItem(sessionID, sess))
|
||||
}
|
||||
|
||||
// Sort by updated descending (most recent first)
|
||||
@@ -209,20 +416,25 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// The sanitized filename replaces ':' with '_':
|
||||
// agent:main:pico:direct:pico:<uuid> -> agent_main_pico_direct_pico_<uuid>.json
|
||||
filename := strings.ReplaceAll(picoSessionPrefix+sessionID, ":", "_") + ".json"
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(dir, filename))
|
||||
if err != nil {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
sess, err := h.readJSONLSession(dir, sessionID)
|
||||
if err == nil && isEmptySession(sess) {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
|
||||
var sess sessionFile
|
||||
if err := json.Unmarshal(data, &sess); err != nil {
|
||||
http.Error(w, "failed to parse session", http.StatusInternalServerError)
|
||||
return
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
sess, err = h.readLegacySession(dir, sessionID)
|
||||
if err == nil && isEmptySession(sess) {
|
||||
err = os.ErrNotExist
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
} else {
|
||||
http.Error(w, "failed to parse session", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to a simpler format for the frontend
|
||||
@@ -268,17 +480,25 @@ func (h *Handler) handleDeleteSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// The sanitized filename replaces ':' with '_':
|
||||
// agent:main:pico:direct:pico:<uuid> -> agent_main_pico_direct_pico_<uuid>.json
|
||||
filename := strings.ReplaceAll(picoSessionPrefix+sessionID, ":", "_") + ".json"
|
||||
filePath := filepath.Join(dir, filename)
|
||||
base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID))
|
||||
jsonlPath := base + ".jsonl"
|
||||
metaPath := base + ".meta.json"
|
||||
legacyPath := base + ".json"
|
||||
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
} else {
|
||||
removed := false
|
||||
for _, path := range []string{jsonlPath, metaPath, legacyPath} {
|
||||
if err := os.Remove(path); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
http.Error(w, "failed to delete session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
removed = true
|
||||
}
|
||||
|
||||
if !removed {
|
||||
http.Error(w, "session not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
func sessionsTestDir(t *testing.T, configPath string) string {
|
||||
t.Helper()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
dir := filepath.Join(cfg.Agents.Defaults.Workspace, "sessions")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
return dir
|
||||
}
|
||||
|
||||
func TestHandleListSessions_JSONLStorage(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "history-jsonl"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "Explain why the history API is empty after migration.",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage(user) error = %v", err)
|
||||
}
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "assistant",
|
||||
Content: "Because the API still reads only legacy JSON session files.",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage(assistant) error = %v", err)
|
||||
}
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "tool",
|
||||
Content: "ignored",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage(tool) error = %v", err)
|
||||
}
|
||||
if err := store.SetSummary(nil, sessionKey, "JSONL-backed session"); err != nil {
|
||||
t.Fatalf("SetSummary() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].ID != "history-jsonl" {
|
||||
t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "history-jsonl")
|
||||
}
|
||||
if items[0].MessageCount != 2 {
|
||||
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
|
||||
}
|
||||
if items[0].Title != "JSONL-backed session" {
|
||||
t.Fatalf("items[0].Title = %q, want %q", items[0].Title, "JSONL-backed session")
|
||||
}
|
||||
if items[0].Preview != "Explain why the history API is empty after migration." {
|
||||
t.Fatalf("items[0].Preview = %q", items[0].Preview)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "summary-title"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "fallback preview",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
if err := store.SetSummary(
|
||||
nil,
|
||||
sessionKey,
|
||||
" This summary is intentionally longer than sixty characters so it must be truncated in the history menu. ",
|
||||
); err != nil {
|
||||
t.Fatalf("SetSummary() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
expectedTitle := truncateRunes(
|
||||
"This summary is intentionally longer than sixty characters so it must be truncated in the history menu.",
|
||||
maxSessionTitleRunes,
|
||||
)
|
||||
if items[0].Title != expectedTitle {
|
||||
t.Fatalf("items[0].Title = %q", items[0].Title)
|
||||
}
|
||||
if items[0].Preview != "fallback preview" {
|
||||
t.Fatalf("items[0].Preview = %q, want %q", items[0].Preview, "fallback preview")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_JSONLStorage(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-jsonl"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "first"},
|
||||
{Role: "assistant", Content: "second"},
|
||||
{Role: "tool", Content: "ignored"},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
if err := store.SetSummary(nil, sessionKey, "detail summary"); err != nil {
|
||||
t.Fatalf("SetSummary() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-jsonl", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
ID string `json:"id"`
|
||||
Summary string `json:"summary"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if resp.ID != "detail-jsonl" {
|
||||
t.Fatalf("resp.ID = %q, want %q", resp.ID, "detail-jsonl")
|
||||
}
|
||||
if resp.Summary != "detail summary" {
|
||||
t.Fatalf("resp.Summary = %q, want %q", resp.Summary, "detail summary")
|
||||
}
|
||||
if len(resp.Messages) != 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "first" {
|
||||
t.Fatalf("first message = %#v, want user/first", resp.Messages[0])
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "second" {
|
||||
t.Fatalf("second message = %#v, want assistant/second", resp.Messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDeleteSession_JSONLStorage(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "delete-jsonl"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "delete me",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
if err := store.SetSummary(nil, sessionKey, "delete summary"); err != nil {
|
||||
t.Fatalf("SetSummary() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/sessions/delete-jsonl", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusNoContent, rec.Body.String())
|
||||
}
|
||||
|
||||
base := filepath.Join(dir, sanitizeSessionKey(sessionKey))
|
||||
for _, path := range []string{base + ".jsonl", base + ".meta.json"} {
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected %s to be removed, stat err = %v", path, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_LegacyJSONFallback(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
manager := session.NewSessionManager(dir)
|
||||
sessionKey := picoSessionPrefix + "legacy-json"
|
||||
manager.AddMessage(sessionKey, "user", "legacy user")
|
||||
manager.AddMessage(sessionKey, "assistant", "legacy assistant")
|
||||
if err := manager.Save(sessionKey); err != nil {
|
||||
t.Fatalf("Save() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/legacy-json", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+"empty-jsonl"))
|
||||
if err := os.WriteFile(base+".jsonl", []byte{}, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(jsonl) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(listRec, listReq)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal(list) error = %v", err)
|
||||
}
|
||||
if len(items) != 0 {
|
||||
t.Fatalf("len(items) = %d, want 0", len(items))
|
||||
}
|
||||
|
||||
detailRec := httptest.NewRecorder()
|
||||
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/empty-jsonl", nil)
|
||||
mux.ServeHTTP(detailRec, detailReq)
|
||||
|
||||
if detailRec.Code != http.StatusNotFound {
|
||||
t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusNotFound, detailRec.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,331 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
)
|
||||
|
||||
type skillSupportResponse struct {
|
||||
Skills []skills.SkillInfo `json:"skills"`
|
||||
}
|
||||
|
||||
type skillDetailResponse struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Source string `json:"source"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
var (
|
||||
skillNameSanitizer = regexp.MustCompile(`[^a-z0-9-]+`)
|
||||
importedSkillFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
|
||||
skillFrontmatterStripper = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
|
||||
)
|
||||
|
||||
func (h *Handler) registerSkillRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/skills", h.handleListSkills)
|
||||
mux.HandleFunc("GET /api/skills/{name}", h.handleGetSkill)
|
||||
mux.HandleFunc("POST /api/skills/import", h.handleImportSkill)
|
||||
mux.HandleFunc("DELETE /api/skills/{name}", h.handleDeleteSkill)
|
||||
}
|
||||
|
||||
func (h *Handler) handleListSkills(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(cfg.WorkspacePath())
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skillSupportResponse{
|
||||
Skills: loader.ListSkills(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(cfg.WorkspacePath())
|
||||
name := r.PathValue("name")
|
||||
allSkills := loader.ListSkills()
|
||||
|
||||
for _, skill := range allSkills {
|
||||
if skill.Name != name {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := loadSkillContent(skill.Path)
|
||||
if err != nil {
|
||||
http.Error(w, "Skill content not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skillDetailResponse{
|
||||
Name: skill.Name,
|
||||
Path: skill.Path,
|
||||
Source: skill.Source,
|
||||
Description: skill.Description,
|
||||
Content: content,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Skill not found", http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = r.ParseMultipartForm(2 << 20)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid multipart form: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
uploadedFile, fileHeader, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
http.Error(w, "file is required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer uploadedFile.Close()
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(uploadedFile, (1<<20)+1))
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to read file: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(content) > 1<<20 {
|
||||
http.Error(w, "file exceeds 1MB limit", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
skillName, err := normalizeImportedSkillName(fileHeader.Filename, content)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
content = normalizeImportedSkillContent(content, skillName)
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
skillDir := filepath.Join(workspace, "skills", skillName)
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillDir); err == nil {
|
||||
http.Error(w, "skill already exists", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create skill directory: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(skillFile, content, 0o644); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save skill: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(workspace)
|
||||
for _, skill := range loader.ListSkills() {
|
||||
if skill.Path == skillFile || (skill.Name == skillName && skill.Source == "workspace") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skill)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"name": skillName,
|
||||
"path": skillFile,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(cfg.WorkspacePath())
|
||||
name := r.PathValue("name")
|
||||
for _, skill := range loader.ListSkills() {
|
||||
if skill.Name != name {
|
||||
continue
|
||||
}
|
||||
if skill.Source != "workspace" {
|
||||
http.Error(w, "only workspace skills can be deleted", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := os.RemoveAll(filepath.Dir(skill.Path)); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to delete skill: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Skill not found", http.StatusNotFound)
|
||||
}
|
||||
|
||||
func newSkillsLoader(workspace string) *skills.SkillsLoader {
|
||||
return skills.NewSkillsLoader(
|
||||
workspace,
|
||||
filepath.Join(globalConfigDir(), "skills"),
|
||||
builtinSkillsDir(),
|
||||
)
|
||||
}
|
||||
|
||||
func normalizeImportedSkillName(filename string, content []byte) (string, error) {
|
||||
rawContent := strings.ReplaceAll(string(content), "\r\n", "\n")
|
||||
rawContent = strings.ReplaceAll(rawContent, "\r", "\n")
|
||||
metadata, _ := extractImportedSkillMetadata(rawContent)
|
||||
|
||||
raw := strings.TrimSpace(metadata["name"])
|
||||
if raw == "" {
|
||||
raw = strings.TrimSpace(strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename)))
|
||||
}
|
||||
raw = strings.ToLower(raw)
|
||||
raw = strings.ReplaceAll(raw, "_", "-")
|
||||
raw = strings.ReplaceAll(raw, " ", "-")
|
||||
raw = skillNameSanitizer.ReplaceAllString(raw, "-")
|
||||
raw = strings.Trim(raw, "-")
|
||||
raw = strings.Join(strings.FieldsFunc(raw, func(r rune) bool { return r == '-' }), "-")
|
||||
|
||||
if raw == "" {
|
||||
return "", fmt.Errorf("skill name is required in frontmatter or filename")
|
||||
}
|
||||
if len(raw) > 64 {
|
||||
return "", fmt.Errorf("skill name exceeds 64 characters")
|
||||
}
|
||||
matched, err := regexp.MatchString(`^[a-z0-9]+(-[a-z0-9]+)*$`, raw)
|
||||
if err != nil || !matched {
|
||||
return "", fmt.Errorf("skill name must be alphanumeric with hyphens")
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func normalizeImportedSkillContent(content []byte, skillName string) []byte {
|
||||
raw := strings.ReplaceAll(string(content), "\r\n", "\n")
|
||||
raw = strings.ReplaceAll(raw, "\r", "\n")
|
||||
|
||||
metadata, body := extractImportedSkillMetadata(raw)
|
||||
description := strings.TrimSpace(metadata["description"])
|
||||
if description == "" {
|
||||
description = inferImportedSkillDescription(body)
|
||||
}
|
||||
if description == "" {
|
||||
description = "Imported skill"
|
||||
}
|
||||
if len(description) > 1024 {
|
||||
description = strings.TrimSpace(description[:1024])
|
||||
}
|
||||
|
||||
body = strings.TrimLeft(body, "\n")
|
||||
var builder strings.Builder
|
||||
builder.WriteString("---\n")
|
||||
builder.WriteString("name: ")
|
||||
builder.WriteString(skillName)
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("description: ")
|
||||
builder.WriteString(description)
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString("---\n\n")
|
||||
builder.WriteString(body)
|
||||
if !strings.HasSuffix(builder.String(), "\n") {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
return []byte(builder.String())
|
||||
}
|
||||
|
||||
func extractImportedSkillMetadata(raw string) (map[string]string, string) {
|
||||
matches := importedSkillFrontmatter.FindStringSubmatch(raw)
|
||||
if len(matches) != 2 {
|
||||
return map[string]string{}, raw
|
||||
}
|
||||
meta := parseImportedSkillYAML(matches[1])
|
||||
body := importedSkillFrontmatter.ReplaceAllString(raw, "")
|
||||
return meta, body
|
||||
}
|
||||
|
||||
func parseImportedSkillYAML(frontmatter string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
key, value, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
result[strings.TrimSpace(key)] = strings.Trim(strings.TrimSpace(value), `"'`)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func inferImportedSkillDescription(body string) string {
|
||||
for _, line := range strings.Split(body, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimLeft(line, "#-*0123456789. ")
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" {
|
||||
return line
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func loadSkillContent(path string) (string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return skillFrontmatterStripper.ReplaceAllString(string(content), ""), nil
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw")
|
||||
}
|
||||
|
||||
func builtinSkillsDir() string {
|
||||
if path := os.Getenv("PICOCLAW_BUILTIN_SKILLS"); path != "" {
|
||||
return path
|
||||
}
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(wd, "skills")
|
||||
}
|
||||
@@ -0,0 +1,336 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleListSkills(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
workspace := filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.Workspace = workspace
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Join(workspace, "skills", "workspace-skill"), 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(workspace skill) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(workspace, "skills", "workspace-skill", "SKILL.md"),
|
||||
[]byte("---\nname: workspace-skill\ndescription: Workspace skill\n---\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(workspace skill) error = %v", err)
|
||||
}
|
||||
|
||||
globalSkillDir := filepath.Join(globalConfigDir(), "skills", "global-skill")
|
||||
if err := os.MkdirAll(globalSkillDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(global skill) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(globalSkillDir, "SKILL.md"),
|
||||
[]byte("---\nname: global-skill\ndescription: Global skill\n---\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(global skill) error = %v", err)
|
||||
}
|
||||
|
||||
builtinRoot := filepath.Join(t.TempDir(), "builtin-skills")
|
||||
oldBuiltin := os.Getenv("PICOCLAW_BUILTIN_SKILLS")
|
||||
if err := os.Setenv("PICOCLAW_BUILTIN_SKILLS", builtinRoot); err != nil {
|
||||
t.Fatalf("Setenv(PICOCLAW_BUILTIN_SKILLS) error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if oldBuiltin == "" {
|
||||
_ = os.Unsetenv("PICOCLAW_BUILTIN_SKILLS")
|
||||
} else {
|
||||
_ = os.Setenv("PICOCLAW_BUILTIN_SKILLS", oldBuiltin)
|
||||
}
|
||||
}()
|
||||
|
||||
builtinSkillDir := filepath.Join(builtinRoot, "builtin-skill")
|
||||
if err := os.MkdirAll(builtinSkillDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll(builtin skill) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(builtinSkillDir, "SKILL.md"),
|
||||
[]byte("---\nname: builtin-skill\ndescription: Builtin skill\n---\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile(builtin skill) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/skills", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp skillSupportResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Skills) != 3 {
|
||||
t.Fatalf("skills count = %d, want 3", len(resp.Skills))
|
||||
}
|
||||
|
||||
gotSkills := make(map[string]string, len(resp.Skills))
|
||||
for _, skill := range resp.Skills {
|
||||
gotSkills[skill.Name] = skill.Source
|
||||
}
|
||||
if gotSkills["workspace-skill"] != "workspace" {
|
||||
t.Fatalf("workspace-skill source = %q, want workspace", gotSkills["workspace-skill"])
|
||||
}
|
||||
if gotSkills["global-skill"] != "global" {
|
||||
t.Fatalf("global-skill source = %q, want global", gotSkills["global-skill"])
|
||||
}
|
||||
if gotSkills["builtin-skill"] != "builtin" {
|
||||
t.Fatalf("builtin-skill source = %q, want builtin", gotSkills["builtin-skill"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSkill(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
workspace := filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.Workspace = workspace
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(workspace, "skills", "viewer-skill")
|
||||
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte(
|
||||
"---\nname: viewer-skill\ndescription: Viewable skill\n---\n# Viewer Skill\n\nThis is visible content.\n",
|
||||
),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/skills/viewer-skill", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp skillDetailResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if resp.Name != "viewer-skill" || resp.Source != "workspace" || resp.Description != "Viewable skill" {
|
||||
t.Fatalf("unexpected response: %#v", resp)
|
||||
}
|
||||
if resp.Content != "# Viewer Skill\n\nThis is visible content.\n" {
|
||||
t.Fatalf("content = %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSkillUsesResolvedPath(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
workspace := filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.Workspace = workspace
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(workspace, "skills", "folder-name")
|
||||
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: display-name\ndescription: Mismatched path skill\n---\n# Display Name\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/skills/display-name", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp skillDetailResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if resp.Name != "display-name" {
|
||||
t.Fatalf("resp.Name = %q, want display-name", resp.Name)
|
||||
}
|
||||
if resp.Content != "# Display Name\n" {
|
||||
t.Fatalf("content = %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleImportSkill(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
workspace := filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.Workspace = workspace
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
part, err := writer.CreateFormFile("file", "Plain Skill.md")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile() error = %v", err)
|
||||
}
|
||||
_, err = io.WriteString(part, "# Plain Skill\n\nUse this skill to test imports.\n")
|
||||
if err != nil {
|
||||
t.Fatalf("WriteString() error = %v", err)
|
||||
}
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Close() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/skills/import", &body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
skillFile := filepath.Join(workspace, "skills", "plain-skill", "SKILL.md")
|
||||
content, err := os.ReadFile(skillFile)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
expected := "---\nname: plain-skill\ndescription: Plain Skill\n---\n\n# Plain Skill\n\nUse this skill to test imports.\n"
|
||||
if string(content) != expected {
|
||||
t.Fatalf("saved skill content mismatch:\n%s", string(content))
|
||||
}
|
||||
|
||||
rec2 := httptest.NewRecorder()
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/api/skills", nil)
|
||||
mux.ServeHTTP(rec2, req2)
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("list status = %d, want %d, body=%s", rec2.Code, http.StatusOK, rec2.Body.String())
|
||||
}
|
||||
var listResp skillSupportResponse
|
||||
if err := json.Unmarshal(rec2.Body.Bytes(), &listResp); err != nil {
|
||||
t.Fatalf("Unmarshal list response error = %v", err)
|
||||
}
|
||||
found := false
|
||||
for _, skill := range listResp.Skills {
|
||||
if skill.Name == "plain-skill" && skill.Source == "workspace" && skill.Description == "Plain Skill" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("plain-skill should be listed after import, got %#v", listResp.Skills)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDeleteSkill(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
workspace := filepath.Join(t.TempDir(), "workspace")
|
||||
cfg.Agents.Defaults.Workspace = workspace
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
skillDir := filepath.Join(workspace, "skills", "delete-me")
|
||||
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(
|
||||
filepath.Join(skillDir, "SKILL.md"),
|
||||
[]byte("---\nname: delete-me\ndescription: delete me\n---\n"),
|
||||
0o644,
|
||||
); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodDelete, "/api/skills/delete-me", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
if _, err := os.Stat(skillDir); !os.IsNotExist(err) {
|
||||
t.Fatalf("skill directory should be removed, stat err=%v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type toolCatalogEntry struct {
|
||||
Name string
|
||||
Description string
|
||||
Category string
|
||||
ConfigKey string
|
||||
}
|
||||
|
||||
type toolSupportItem struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Category string `json:"category"`
|
||||
ConfigKey string `json:"config_key"`
|
||||
Status string `json:"status"`
|
||||
ReasonCode string `json:"reason_code,omitempty"`
|
||||
}
|
||||
|
||||
type toolSupportResponse struct {
|
||||
Tools []toolSupportItem `json:"tools"`
|
||||
}
|
||||
|
||||
type toolStateRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
var toolCatalog = []toolCatalogEntry{
|
||||
{
|
||||
Name: "read_file",
|
||||
Description: "Read file content from the workspace or explicitly allowed paths.",
|
||||
Category: "filesystem",
|
||||
ConfigKey: "read_file",
|
||||
},
|
||||
{
|
||||
Name: "write_file",
|
||||
Description: "Create or overwrite files within the writable workspace scope.",
|
||||
Category: "filesystem",
|
||||
ConfigKey: "write_file",
|
||||
},
|
||||
{
|
||||
Name: "list_dir",
|
||||
Description: "Inspect directories and enumerate files available to the agent.",
|
||||
Category: "filesystem",
|
||||
ConfigKey: "list_dir",
|
||||
},
|
||||
{
|
||||
Name: "edit_file",
|
||||
Description: "Apply targeted edits to existing files without rewriting everything.",
|
||||
Category: "filesystem",
|
||||
ConfigKey: "edit_file",
|
||||
},
|
||||
{
|
||||
Name: "append_file",
|
||||
Description: "Append content to the end of an existing file.",
|
||||
Category: "filesystem",
|
||||
ConfigKey: "append_file",
|
||||
},
|
||||
{
|
||||
Name: "exec",
|
||||
Description: "Run shell commands inside the configured workspace sandbox.",
|
||||
Category: "filesystem",
|
||||
ConfigKey: "exec",
|
||||
},
|
||||
{
|
||||
Name: "cron",
|
||||
Description: "Schedule one-time or recurring reminders, jobs, and shell commands.",
|
||||
Category: "automation",
|
||||
ConfigKey: "cron",
|
||||
},
|
||||
{
|
||||
Name: "web_search",
|
||||
Description: "Search the web using the configured providers.",
|
||||
Category: "web",
|
||||
ConfigKey: "web",
|
||||
},
|
||||
{
|
||||
Name: "web_fetch",
|
||||
Description: "Fetch and summarize the contents of a webpage.",
|
||||
Category: "web",
|
||||
ConfigKey: "web_fetch",
|
||||
},
|
||||
{
|
||||
Name: "message",
|
||||
Description: "Send a follow-up message back to the active user or chat.",
|
||||
Category: "communication",
|
||||
ConfigKey: "message",
|
||||
},
|
||||
{
|
||||
Name: "send_file",
|
||||
Description: "Send an outbound file or media attachment to the active chat.",
|
||||
Category: "communication",
|
||||
ConfigKey: "send_file",
|
||||
},
|
||||
{
|
||||
Name: "find_skills",
|
||||
Description: "Search external skill registries for installable skills.",
|
||||
Category: "skills",
|
||||
ConfigKey: "find_skills",
|
||||
},
|
||||
{
|
||||
Name: "install_skill",
|
||||
Description: "Install a skill into the current workspace from a registry.",
|
||||
Category: "skills",
|
||||
ConfigKey: "install_skill",
|
||||
},
|
||||
{
|
||||
Name: "spawn",
|
||||
Description: "Launch a background subagent for long-running or delegated work.",
|
||||
Category: "agents",
|
||||
ConfigKey: "spawn",
|
||||
},
|
||||
{
|
||||
Name: "i2c",
|
||||
Description: "Interact with I2C hardware devices exposed on the host.",
|
||||
Category: "hardware",
|
||||
ConfigKey: "i2c",
|
||||
},
|
||||
{
|
||||
Name: "spi",
|
||||
Description: "Interact with SPI hardware devices exposed on the host.",
|
||||
Category: "hardware",
|
||||
ConfigKey: "spi",
|
||||
},
|
||||
{
|
||||
Name: "tool_search_tool_regex",
|
||||
Description: "Discover hidden MCP tools by regex search when tool discovery is enabled.",
|
||||
Category: "discovery",
|
||||
ConfigKey: "mcp.discovery.use_regex",
|
||||
},
|
||||
{
|
||||
Name: "tool_search_tool_bm25",
|
||||
Description: "Discover hidden MCP tools by semantic ranking when tool discovery is enabled.",
|
||||
Category: "discovery",
|
||||
ConfigKey: "mcp.discovery.use_bm25",
|
||||
},
|
||||
}
|
||||
|
||||
func (h *Handler) registerToolRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/tools", h.handleListTools)
|
||||
mux.HandleFunc("PUT /api/tools/{name}/state", h.handleUpdateToolState)
|
||||
}
|
||||
|
||||
func (h *Handler) handleListTools(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(toolSupportResponse{
|
||||
Tools: buildToolSupport(cfg),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleUpdateToolState(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req toolStateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := applyToolState(cfg, r.PathValue("name"), req.Enabled); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := config.SaveConfig(h.configPath, cfg); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
func buildToolSupport(cfg *config.Config) []toolSupportItem {
|
||||
items := make([]toolSupportItem, 0, len(toolCatalog))
|
||||
for _, entry := range toolCatalog {
|
||||
status := "disabled"
|
||||
reasonCode := ""
|
||||
|
||||
switch entry.Name {
|
||||
case "find_skills", "install_skill":
|
||||
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
|
||||
if cfg.Tools.IsToolEnabled("skills") {
|
||||
status = "enabled"
|
||||
} else {
|
||||
status = "blocked"
|
||||
reasonCode = "requires_skills"
|
||||
}
|
||||
}
|
||||
case "spawn":
|
||||
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
status = "enabled"
|
||||
} else {
|
||||
status = "blocked"
|
||||
reasonCode = "requires_subagent"
|
||||
}
|
||||
}
|
||||
case "tool_search_tool_regex":
|
||||
status, reasonCode = resolveDiscoveryToolSupport(cfg, cfg.Tools.MCP.Discovery.UseRegex)
|
||||
case "tool_search_tool_bm25":
|
||||
status, reasonCode = resolveDiscoveryToolSupport(cfg, cfg.Tools.MCP.Discovery.UseBM25)
|
||||
case "i2c", "spi":
|
||||
status, reasonCode = resolveHardwareToolSupport(cfg.Tools.IsToolEnabled(entry.ConfigKey))
|
||||
default:
|
||||
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
|
||||
status = "enabled"
|
||||
}
|
||||
}
|
||||
|
||||
items = append(items, toolSupportItem{
|
||||
Name: entry.Name,
|
||||
Description: entry.Description,
|
||||
Category: entry.Category,
|
||||
ConfigKey: entry.ConfigKey,
|
||||
Status: status,
|
||||
ReasonCode: reasonCode,
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func resolveHardwareToolSupport(enabled bool) (string, string) {
|
||||
if !enabled {
|
||||
return "disabled", ""
|
||||
}
|
||||
if runtime.GOOS != "linux" {
|
||||
return "blocked", "requires_linux"
|
||||
}
|
||||
return "enabled", ""
|
||||
}
|
||||
|
||||
func resolveDiscoveryToolSupport(cfg *config.Config, methodEnabled bool) (string, string) {
|
||||
if !cfg.Tools.IsToolEnabled("mcp") {
|
||||
return "disabled", ""
|
||||
}
|
||||
if !cfg.Tools.MCP.Discovery.Enabled {
|
||||
return "blocked", "requires_mcp_discovery"
|
||||
}
|
||||
if !methodEnabled {
|
||||
return "disabled", ""
|
||||
}
|
||||
return "enabled", ""
|
||||
}
|
||||
|
||||
func applyToolState(cfg *config.Config, toolName string, enabled bool) error {
|
||||
switch toolName {
|
||||
case "read_file":
|
||||
cfg.Tools.ReadFile.Enabled = enabled
|
||||
case "write_file":
|
||||
cfg.Tools.WriteFile.Enabled = enabled
|
||||
case "list_dir":
|
||||
cfg.Tools.ListDir.Enabled = enabled
|
||||
case "edit_file":
|
||||
cfg.Tools.EditFile.Enabled = enabled
|
||||
case "append_file":
|
||||
cfg.Tools.AppendFile.Enabled = enabled
|
||||
case "exec":
|
||||
cfg.Tools.Exec.Enabled = enabled
|
||||
case "cron":
|
||||
cfg.Tools.Cron.Enabled = enabled
|
||||
case "web_search":
|
||||
cfg.Tools.Web.Enabled = enabled
|
||||
case "web_fetch":
|
||||
cfg.Tools.WebFetch.Enabled = enabled
|
||||
case "message":
|
||||
cfg.Tools.Message.Enabled = enabled
|
||||
case "send_file":
|
||||
cfg.Tools.SendFile.Enabled = enabled
|
||||
case "find_skills":
|
||||
cfg.Tools.FindSkills.Enabled = enabled
|
||||
if enabled {
|
||||
cfg.Tools.Skills.Enabled = true
|
||||
}
|
||||
case "install_skill":
|
||||
cfg.Tools.InstallSkill.Enabled = enabled
|
||||
if enabled {
|
||||
cfg.Tools.Skills.Enabled = true
|
||||
}
|
||||
case "spawn":
|
||||
cfg.Tools.Spawn.Enabled = enabled
|
||||
if enabled {
|
||||
cfg.Tools.Subagent.Enabled = true
|
||||
}
|
||||
case "i2c":
|
||||
cfg.Tools.I2C.Enabled = enabled
|
||||
case "spi":
|
||||
cfg.Tools.SPI.Enabled = enabled
|
||||
case "tool_search_tool_regex":
|
||||
cfg.Tools.MCP.Discovery.UseRegex = enabled
|
||||
if enabled {
|
||||
cfg.Tools.MCP.Enabled = true
|
||||
cfg.Tools.MCP.Discovery.Enabled = true
|
||||
}
|
||||
case "tool_search_tool_bm25":
|
||||
cfg.Tools.MCP.Discovery.UseBM25 = enabled
|
||||
if enabled {
|
||||
cfg.Tools.MCP.Enabled = true
|
||||
cfg.Tools.MCP.Discovery.Enabled = true
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("tool %q cannot be updated", toolName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,198 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleListTools(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.Tools.ReadFile.Enabled = true
|
||||
cfg.Tools.WriteFile.Enabled = false
|
||||
cfg.Tools.Cron.Enabled = true
|
||||
cfg.Tools.FindSkills.Enabled = true
|
||||
cfg.Tools.Skills.Enabled = true
|
||||
cfg.Tools.Spawn.Enabled = true
|
||||
cfg.Tools.Subagent.Enabled = false
|
||||
cfg.Tools.MCP.Enabled = true
|
||||
cfg.Tools.MCP.Discovery.Enabled = true
|
||||
cfg.Tools.MCP.Discovery.UseRegex = true
|
||||
cfg.Tools.MCP.Discovery.UseBM25 = false
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/tools", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp toolSupportResponse
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
gotTools := make(map[string]toolSupportItem, len(resp.Tools))
|
||||
for _, tool := range resp.Tools {
|
||||
gotTools[tool.Name] = tool
|
||||
}
|
||||
if gotTools["read_file"].Status != "enabled" {
|
||||
t.Fatalf("read_file status = %q, want enabled", gotTools["read_file"].Status)
|
||||
}
|
||||
if gotTools["write_file"].Status != "disabled" {
|
||||
t.Fatalf("write_file status = %q, want disabled", gotTools["write_file"].Status)
|
||||
}
|
||||
if gotTools["cron"].Status != "enabled" {
|
||||
t.Fatalf("cron status = %q, want enabled", gotTools["cron"].Status)
|
||||
}
|
||||
if gotTools["spawn"].Status != "blocked" || gotTools["spawn"].ReasonCode != "requires_subagent" {
|
||||
t.Fatalf("spawn = %#v, want blocked/requires_subagent", gotTools["spawn"])
|
||||
}
|
||||
if gotTools["find_skills"].Status != "enabled" {
|
||||
t.Fatalf("find_skills status = %q, want enabled", gotTools["find_skills"].Status)
|
||||
}
|
||||
if gotTools["tool_search_tool_regex"].Status != "enabled" {
|
||||
t.Fatalf("tool_search_tool_regex status = %q, want enabled", gotTools["tool_search_tool_regex"].Status)
|
||||
}
|
||||
if gotTools["tool_search_tool_regex"].ConfigKey != "mcp.discovery.use_regex" {
|
||||
t.Fatalf(
|
||||
"tool_search_tool_regex config_key = %q, want mcp.discovery.use_regex",
|
||||
gotTools["tool_search_tool_regex"].ConfigKey,
|
||||
)
|
||||
}
|
||||
if gotTools["tool_search_tool_bm25"].Status != "disabled" {
|
||||
t.Fatalf("tool_search_tool_bm25 status = %q, want disabled", gotTools["tool_search_tool_bm25"].Status)
|
||||
}
|
||||
if gotTools["tool_search_tool_bm25"].ConfigKey != "mcp.discovery.use_bm25" {
|
||||
t.Fatalf(
|
||||
"tool_search_tool_bm25 config_key = %q, want mcp.discovery.use_bm25",
|
||||
gotTools["tool_search_tool_bm25"].ConfigKey,
|
||||
)
|
||||
}
|
||||
if runtime.GOOS == "linux" {
|
||||
if gotTools["i2c"].Status != "disabled" {
|
||||
t.Fatalf("i2c status = %q, want disabled on linux when config is off", gotTools["i2c"].Status)
|
||||
}
|
||||
} else {
|
||||
cfg.Tools.I2C.Enabled = true
|
||||
cfg.Tools.SPI.Enabled = true
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req = httptest.NewRequest(http.MethodGet, "/api/tools", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
gotTools = make(map[string]toolSupportItem, len(resp.Tools))
|
||||
for _, tool := range resp.Tools {
|
||||
gotTools[tool.Name] = tool
|
||||
}
|
||||
|
||||
if gotTools["i2c"].Status != "blocked" || gotTools["i2c"].ReasonCode != "requires_linux" {
|
||||
t.Fatalf("i2c = %#v, want blocked/requires_linux", gotTools["i2c"])
|
||||
}
|
||||
if gotTools["spi"].Status != "blocked" || gotTools["spi"].ReasonCode != "requires_linux" {
|
||||
t.Fatalf("spi = %#v, want blocked/requires_linux", gotTools["spi"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpdateToolState(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.Tools.Spawn.Enabled = false
|
||||
cfg.Tools.Subagent.Enabled = false
|
||||
cfg.Tools.Cron.Enabled = false
|
||||
cfg.Tools.MCP.Enabled = false
|
||||
cfg.Tools.MCP.Discovery.Enabled = false
|
||||
cfg.Tools.MCP.Discovery.UseRegex = false
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/tools/spawn/state",
|
||||
bytes.NewBufferString(`{"enabled":true}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("spawn status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
rec2 := httptest.NewRecorder()
|
||||
req2 := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/tools/tool_search_tool_regex/state",
|
||||
bytes.NewBufferString(`{"enabled":true}`),
|
||||
)
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec2, req2)
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("regex status = %d, want %d, body=%s", rec2.Code, http.StatusOK, rec2.Body.String())
|
||||
}
|
||||
|
||||
rec3 := httptest.NewRecorder()
|
||||
req3 := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/tools/cron/state",
|
||||
bytes.NewBufferString(`{"enabled":true}`),
|
||||
)
|
||||
req3.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec3, req3)
|
||||
if rec3.Code != http.StatusOK {
|
||||
t.Fatalf("cron status = %d, want %d, body=%s", rec3.Code, http.StatusOK, rec3.Body.String())
|
||||
}
|
||||
|
||||
updated, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig(updated) error = %v", err)
|
||||
}
|
||||
if !updated.Tools.Spawn.Enabled || !updated.Tools.Subagent.Enabled {
|
||||
t.Fatalf("spawn/subagent should both be enabled: %#v", updated.Tools)
|
||||
}
|
||||
if !updated.Tools.MCP.Enabled || !updated.Tools.MCP.Discovery.Enabled || !updated.Tools.MCP.Discovery.UseRegex {
|
||||
t.Fatalf("mcp regex discovery should be enabled: %#v", updated.Tools.MCP)
|
||||
}
|
||||
if !updated.Tools.Cron.Enabled {
|
||||
t.Fatalf("cron should be enabled: %#v", updated.Tools.Cron)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user