mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into refactor-inbound-context-routing-session
# Conflicts: # pkg/agent/eventbus_test.go # pkg/agent/loop.go # pkg/bus/bus.go # pkg/bus/types.go # pkg/channels/pico/pico.go # pkg/channels/telegram/telegram.go # pkg/config/config.go # web/backend/api/session.go # web/backend/api/session_test.go
This commit is contained in:
@@ -23,6 +23,7 @@ type LauncherAuthRouteOpts struct {
|
||||
type LauncherAuthTokenHelp struct {
|
||||
EnvVarName string `json:"env_var_name"`
|
||||
LogFileAbs string `json:"log_file,omitempty"`
|
||||
ConfigFileAbs string `json:"config_file,omitempty"`
|
||||
TrayCopyMenu bool `json:"tray_copy_menu"`
|
||||
ConsoleStdout bool `json:"console_stdout"`
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type channelCatalogItem struct {
|
||||
@@ -30,9 +32,22 @@ var channelCatalog = []channelCatalogItem{
|
||||
{Name: "irc", ConfigKey: "irc"},
|
||||
}
|
||||
|
||||
type channelConfigResponse struct {
|
||||
Config any `json:"config"`
|
||||
ConfiguredSecrets []string `json:"configured_secrets"`
|
||||
ConfigKey string `json:"config_key"`
|
||||
Variant string `json:"variant,omitempty"`
|
||||
}
|
||||
|
||||
type channelSecretPresence struct {
|
||||
key string
|
||||
configured bool
|
||||
}
|
||||
|
||||
// registerChannelRoutes binds read-only channel catalog endpoints to the ServeMux.
|
||||
func (h *Handler) registerChannelRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/channels/catalog", h.handleListChannelCatalog)
|
||||
mux.HandleFunc("GET /api/channels/{name}/config", h.handleGetChannelConfig)
|
||||
}
|
||||
|
||||
// handleListChannelCatalog returns the channels supported by backend.
|
||||
@@ -44,3 +59,172 @@ func (h *Handler) handleListChannelCatalog(w http.ResponseWriter, r *http.Reques
|
||||
"channels": channelCatalog,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetChannelConfig returns safe channel config plus secret presence metadata.
|
||||
//
|
||||
// GET /api/channels/{name}/config
|
||||
func (h *Handler) handleGetChannelConfig(w http.ResponseWriter, r *http.Request) {
|
||||
channelName := r.PathValue("name")
|
||||
item, ok := findChannelCatalogItem(channelName)
|
||||
if !ok {
|
||||
http.Error(w, "Channel not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to load config", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := buildChannelConfigResponse(cfg, item)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func findChannelCatalogItem(name string) (channelCatalogItem, bool) {
|
||||
for _, item := range channelCatalog {
|
||||
if item.Name == name {
|
||||
return item, true
|
||||
}
|
||||
}
|
||||
return channelCatalogItem{}, false
|
||||
}
|
||||
|
||||
func buildChannelConfigResponse(cfg *config.Config, item channelCatalogItem) channelConfigResponse {
|
||||
resp := channelConfigResponse{
|
||||
ConfiguredSecrets: []string{},
|
||||
ConfigKey: item.ConfigKey,
|
||||
Variant: item.Variant,
|
||||
}
|
||||
|
||||
switch item.Name {
|
||||
case "weixin":
|
||||
channelCfg := cfg.Channels.Weixin
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
|
||||
)
|
||||
channelCfg.Token = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "telegram":
|
||||
channelCfg := cfg.Channels.Telegram
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
|
||||
)
|
||||
channelCfg.Token = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "discord":
|
||||
channelCfg := cfg.Channels.Discord
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
|
||||
)
|
||||
channelCfg.Token = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "slack":
|
||||
channelCfg := cfg.Channels.Slack
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "bot_token", configured: channelCfg.BotToken.String() != ""},
|
||||
channelSecretPresence{key: "app_token", configured: channelCfg.AppToken.String() != ""},
|
||||
)
|
||||
channelCfg.BotToken = config.SecureString{}
|
||||
channelCfg.AppToken = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "feishu":
|
||||
channelCfg := cfg.Channels.Feishu
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "app_secret", configured: channelCfg.AppSecret.String() != ""},
|
||||
channelSecretPresence{key: "encrypt_key", configured: channelCfg.EncryptKey.String() != ""},
|
||||
channelSecretPresence{key: "verification_token", configured: channelCfg.VerificationToken.String() != ""},
|
||||
)
|
||||
channelCfg.AppSecret = config.SecureString{}
|
||||
channelCfg.EncryptKey = config.SecureString{}
|
||||
channelCfg.VerificationToken = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "dingtalk":
|
||||
channelCfg := cfg.Channels.DingTalk
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "client_secret", configured: channelCfg.ClientSecret.String() != ""},
|
||||
)
|
||||
channelCfg.ClientSecret = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "line":
|
||||
channelCfg := cfg.Channels.LINE
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "channel_secret", configured: channelCfg.ChannelSecret.String() != ""},
|
||||
channelSecretPresence{
|
||||
key: "channel_access_token",
|
||||
configured: channelCfg.ChannelAccessToken.String() != "",
|
||||
},
|
||||
)
|
||||
channelCfg.ChannelSecret = config.SecureString{}
|
||||
channelCfg.ChannelAccessToken = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "qq":
|
||||
channelCfg := cfg.Channels.QQ
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "app_secret", configured: channelCfg.AppSecret.String() != ""},
|
||||
)
|
||||
channelCfg.AppSecret = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "onebot":
|
||||
channelCfg := cfg.Channels.OneBot
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "access_token", configured: channelCfg.AccessToken.String() != ""},
|
||||
)
|
||||
channelCfg.AccessToken = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "wecom":
|
||||
channelCfg := cfg.Channels.WeCom
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "secret", configured: channelCfg.Secret.String() != ""},
|
||||
)
|
||||
channelCfg.Secret = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "whatsapp", "whatsapp_native":
|
||||
resp.Config = cfg.Channels.WhatsApp
|
||||
case "pico":
|
||||
channelCfg := cfg.Channels.Pico
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
|
||||
)
|
||||
channelCfg.Token = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "maixcam":
|
||||
resp.Config = cfg.Channels.MaixCam
|
||||
case "matrix":
|
||||
channelCfg := cfg.Channels.Matrix
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "access_token", configured: channelCfg.AccessToken.String() != ""},
|
||||
)
|
||||
channelCfg.AccessToken = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
case "irc":
|
||||
channelCfg := cfg.Channels.IRC
|
||||
resp.ConfiguredSecrets = collectConfiguredSecrets(
|
||||
channelSecretPresence{key: "password", configured: channelCfg.Password.String() != ""},
|
||||
channelSecretPresence{key: "nickserv_password", configured: channelCfg.NickServPassword.String() != ""},
|
||||
channelSecretPresence{key: "sasl_password", configured: channelCfg.SASLPassword.String() != ""},
|
||||
)
|
||||
channelCfg.Password = config.SecureString{}
|
||||
channelCfg.NickServPassword = config.SecureString{}
|
||||
channelCfg.SASLPassword = config.SecureString{}
|
||||
resp.Config = channelCfg
|
||||
default:
|
||||
resp.Config = map[string]any{}
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func collectConfiguredSecrets(secrets ...channelSecretPresence) []string {
|
||||
configured := make([]string, 0, len(secrets))
|
||||
for _, secret := range secrets {
|
||||
if secret.configured {
|
||||
configured = append(configured, secret.key)
|
||||
}
|
||||
}
|
||||
return configured
|
||||
}
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestHandleGetChannelConfig_ReturnsSecretPresenceWithoutLeakingSecrets(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.Channels.Feishu.Enabled = true
|
||||
cfg.Channels.Feishu.AppID = "cli_test_app"
|
||||
cfg.Channels.Feishu.AppSecret = *config.NewSecureString("feishu-secret-from-security")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/channels/feishu/config", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf(
|
||||
"GET /api/channels/feishu/config status = %d, want %d, body=%s",
|
||||
rec.Code,
|
||||
http.StatusOK,
|
||||
rec.Body.String(),
|
||||
)
|
||||
}
|
||||
if strings.Contains(rec.Body.String(), "feishu-secret-from-security") {
|
||||
t.Fatalf("response leaked secret value: %s", rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Config map[string]any `json:"config"`
|
||||
ConfiguredSecrets []string `json:"configured_secrets"`
|
||||
ConfigKey string `json:"config_key"`
|
||||
Variant string `json:"variant"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("json.Unmarshal() error = %v", err)
|
||||
}
|
||||
|
||||
if got := resp.ConfigKey; got != "feishu" {
|
||||
t.Fatalf("config_key = %q, want %q", got, "feishu")
|
||||
}
|
||||
if got := resp.Config["app_id"]; got != "cli_test_app" {
|
||||
t.Fatalf("config.app_id = %#v, want %q", got, "cli_test_app")
|
||||
}
|
||||
if _, exists := resp.Config["app_secret"]; exists {
|
||||
t.Fatalf("config should omit app_secret, got %#v", resp.Config["app_secret"])
|
||||
}
|
||||
if len(resp.ConfiguredSecrets) != 1 || resp.ConfiguredSecrets[0] != "app_secret" {
|
||||
t.Fatalf("configured_secrets = %#v, want [\"app_secret\"]", resp.ConfiguredSecrets)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetChannelConfig_ReturnsNotFoundForUnknownChannel(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/channels/not-a-channel/config", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("GET /api/channels/not-a-channel/config status = %d, want %d", rec.Code, http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
+34
-28
@@ -357,7 +357,13 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return cmd.Process.Signal(syscall.Signal(0)) == nil
|
||||
err := cmd.Process.Signal(syscall.Signal(0))
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
var errno syscall.Errno
|
||||
// EPERM means the process exists but cannot be signaled by this user.
|
||||
return errors.As(err, &errno) && errno == syscall.EPERM
|
||||
}
|
||||
|
||||
func setGatewayRuntimeStatusLocked(status string) {
|
||||
@@ -401,6 +407,15 @@ func gatewayStatusWithoutHealthLocked() string {
|
||||
return "error"
|
||||
}
|
||||
if gateway.runtimeStatus == "running" {
|
||||
// For attached processes there is no waiter goroutine; degrade stale
|
||||
// running state once the tracked process exits.
|
||||
if !isCmdProcessAliveLocked(gateway.cmd) {
|
||||
gateway.cmd = nil
|
||||
gateway.owned = false
|
||||
gateway.bootDefaultModel = ""
|
||||
gateway.bootConfigSignature = ""
|
||||
return "stopped"
|
||||
}
|
||||
return "running"
|
||||
}
|
||||
if gateway.runtimeStatus == "error" {
|
||||
@@ -614,6 +629,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
|
||||
// Start a goroutine to probe pidFile and health, update runtime state once ready.
|
||||
go func() {
|
||||
healthConfirmed := false
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
gateway.mu.Lock()
|
||||
@@ -628,6 +644,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
gateway.pidData = pd
|
||||
gateway.picoToken = cfg.Channels.Pico.Token.String()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
@@ -647,7 +664,11 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return
|
||||
if !healthConfirmed {
|
||||
healthConfirmed = true
|
||||
logger.InfoC("gateway", "Gateway health endpoint reachable; waiting for pid file")
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -922,34 +943,19 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data["pid"] = pidData.PID
|
||||
gateway.mu.Unlock()
|
||||
} else {
|
||||
// Fallback: probe health endpoint to get pid and status
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
// Intentionally skip health probe here; the startup goroutine
|
||||
// (startGatewayLocked) already handles liveness detection via
|
||||
// pidFile polling and health fallback.
|
||||
gateway.mu.Lock()
|
||||
status := gatewayStatusWithoutHealthLocked()
|
||||
data["gateway_status"] = status
|
||||
// Keep last known pidData while gateway is still in a transient
|
||||
// running state; otherwise websocket proxy may lose auth token
|
||||
// during short pid-file races.
|
||||
if status == "stopped" || status == "error" {
|
||||
gateway.pidData = nil
|
||||
gateway.mu.Unlock()
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
|
||||
} else {
|
||||
logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
|
||||
if statusCode != http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
gateway.pidData = nil
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = statusCode
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
|
||||
gatewayStatus, _ := data["gateway_status"].(string)
|
||||
|
||||
@@ -15,8 +15,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
@@ -444,7 +447,93 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
func TestGatewayStatusKeepsPidDataWhileTrackedProcessAliveWhenPidFileUnavailable(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.pidData = &ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "existing-token",
|
||||
}
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
if gateway.pidData == nil {
|
||||
t.Fatal("gateway.pidData was cleared while runtime status remained running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusDowngradesRunningWhenTrackedProcessExitedAndPidFileMissing(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.pidData = &ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "stale-token",
|
||||
}
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if got := body["gateway_status"]; got != "stopped" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "stopped")
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
if gateway.pidData != nil {
|
||||
t.Fatal("gateway.pidData should be cleared when tracked process has exited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
@@ -468,6 +557,9 @@ func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
_, err := ppid.WritePidFile(globalConfigDir(), "localhost", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
@@ -513,6 +605,8 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
_, err = ppid.WritePidFile(globalConfigDir(), "localhost", 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
bootSignature := computeConfigSignature(cfg)
|
||||
gateway.mu.Lock()
|
||||
|
||||
@@ -4,14 +4,16 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
)
|
||||
|
||||
type launcherConfigPayload struct {
|
||||
Port int `json:"port"`
|
||||
Public bool `json:"public"`
|
||||
AllowedCIDRs []string `json:"allowed_cidrs"`
|
||||
Port int `json:"port"`
|
||||
Public bool `json:"public"`
|
||||
AllowedCIDRs []string `json:"allowed_cidrs"`
|
||||
LauncherToken string `json:"launcher_token"`
|
||||
}
|
||||
|
||||
func (h *Handler) registerLauncherConfigRoutes(mux *http.ServeMux) {
|
||||
@@ -48,9 +50,10 @@ func (h *Handler) handleGetLauncherConfig(w http.ResponseWriter, r *http.Request
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(launcherConfigPayload{
|
||||
Port: cfg.Port,
|
||||
Public: cfg.Public,
|
||||
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
|
||||
Port: cfg.Port,
|
||||
Public: cfg.Public,
|
||||
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
|
||||
LauncherToken: cfg.LauncherToken,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -62,9 +65,10 @@ func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
cfg := launcherconfig.Config{
|
||||
Port: payload.Port,
|
||||
Public: payload.Public,
|
||||
AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
|
||||
Port: payload.Port,
|
||||
Public: payload.Public,
|
||||
AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
|
||||
LauncherToken: strings.TrimSpace(payload.LauncherToken),
|
||||
}
|
||||
if err := launcherconfig.Validate(cfg); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
@@ -78,8 +82,9 @@ func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Requ
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(launcherConfigPayload{
|
||||
Port: cfg.Port,
|
||||
Public: cfg.Public,
|
||||
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
|
||||
Port: cfg.Port,
|
||||
Public: cfg.Public,
|
||||
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
|
||||
LauncherToken: cfg.LauncherToken,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ func TestGetLauncherConfigUsesRuntimeFallback(t *testing.T) {
|
||||
if got.Port != 19999 || !got.Public {
|
||||
t.Fatalf("response = %+v, want port=19999 public=true", got)
|
||||
}
|
||||
if got.LauncherToken != "" {
|
||||
t.Fatalf("response launcher_token = %q, want empty", got.LauncherToken)
|
||||
}
|
||||
if len(got.AllowedCIDRs) != 1 || got.AllowedCIDRs[0] != "192.168.1.0/24" {
|
||||
t.Fatalf("response allowed_cidrs = %v, want [192.168.1.0/24]", got.AllowedCIDRs)
|
||||
}
|
||||
@@ -50,7 +53,9 @@ func TestPutLauncherConfigPersists(t *testing.T) {
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPut,
|
||||
"/api/system/launcher-config",
|
||||
strings.NewReader(`{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"]}`),
|
||||
strings.NewReader(
|
||||
`{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"],"launcher_token":"saved-token"}`,
|
||||
),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
@@ -67,6 +72,9 @@ func TestPutLauncherConfigPersists(t *testing.T) {
|
||||
if cfg.Port != 18080 || !cfg.Public {
|
||||
t.Fatalf("saved config = %+v, want port=18080 public=true", cfg)
|
||||
}
|
||||
if cfg.LauncherToken != "saved-token" {
|
||||
t.Fatalf("saved launcher_token = %q, want %q", cfg.LauncherToken, "saved-token")
|
||||
}
|
||||
if len(cfg.AllowedCIDRs) != 1 || cfg.AllowedCIDRs[0] != "192.168.1.0/24" {
|
||||
t.Fatalf("saved config allowed_cidrs = %v, want [192.168.1.0/24]", cfg.AllowedCIDRs)
|
||||
}
|
||||
|
||||
@@ -1,19 +1,36 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
const modelProbeTimeout = 800 * time.Millisecond
|
||||
const (
|
||||
modelProbeTimeout = 800 * time.Millisecond
|
||||
modelProbeSuccessBaseInterval = 2 * time.Second
|
||||
modelProbeSuccessMaxInterval = 60 * time.Second
|
||||
modelProbeFailureBaseInterval = 1 * time.Second
|
||||
modelProbeFailureMaxInterval = 30 * time.Second
|
||||
modelProbeBackoffMaxShift = 8
|
||||
modelProbeCacheMaxEntries = 1024
|
||||
modelProbeCacheEntryTTL = 30 * time.Minute
|
||||
modelProbeCacheTrimToEntries = modelProbeCacheMaxEntries * 8 / 10
|
||||
modelProbeTTLGCInterval = 1 * time.Minute
|
||||
)
|
||||
|
||||
const (
|
||||
modelStatusAvailable = "available"
|
||||
@@ -30,8 +47,41 @@ var (
|
||||
probeTCPServiceFunc = probeTCPService
|
||||
probeOllamaModelFunc = probeOllamaModel
|
||||
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
|
||||
modelProbeNowFunc = time.Now
|
||||
modelProbeState = newModelProbeCacheState()
|
||||
)
|
||||
|
||||
type modelProbeCacheState struct {
|
||||
mu sync.RWMutex
|
||||
cache map[string]*modelProbeCacheEntry
|
||||
group singleflight.Group
|
||||
nextTTLGCAt time.Time
|
||||
}
|
||||
|
||||
type modelProbeCacheEntry struct {
|
||||
lastResult bool
|
||||
hasResult bool
|
||||
successStreak int
|
||||
failureStreak int
|
||||
nextProbeAt time.Time
|
||||
updatedAt time.Time
|
||||
}
|
||||
|
||||
func newModelProbeCacheState() *modelProbeCacheState {
|
||||
return &modelProbeCacheState{cache: map[string]*modelProbeCacheEntry{}}
|
||||
}
|
||||
|
||||
func resetModelProbeCache() {
|
||||
modelProbeState.resetForTest()
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) resetForTest() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cache = map[string]*modelProbeCacheEntry{}
|
||||
s.nextTTLGCAt = time.Time{}
|
||||
}
|
||||
|
||||
func hasModelConfiguration(m *config.ModelConfig) bool {
|
||||
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
|
||||
apiKey := strings.TrimSpace(m.APIKey())
|
||||
@@ -93,6 +143,34 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool {
|
||||
}
|
||||
|
||||
func probeLocalModelAvailability(m *config.ModelConfig) bool {
|
||||
cacheKey := modelProbeCacheKey(m)
|
||||
return modelProbeState.probe(cacheKey, func() bool {
|
||||
return runLocalModelProbe(m)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) bool {
|
||||
now := modelProbeNowFunc()
|
||||
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
|
||||
return cachedResult
|
||||
}
|
||||
|
||||
v, _, _ := s.group.Do(cacheKey, func() (any, error) {
|
||||
now = modelProbeNowFunc()
|
||||
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
|
||||
return cachedResult, nil
|
||||
}
|
||||
|
||||
result := probeFunc()
|
||||
s.setCachedResult(cacheKey, result, now)
|
||||
return result, nil
|
||||
})
|
||||
|
||||
result, _ := v.(bool)
|
||||
return result
|
||||
}
|
||||
|
||||
func runLocalModelProbe(m *config.ModelConfig) bool {
|
||||
apiBase := modelProbeAPIBase(m)
|
||||
protocol, modelID := splitModel(m.Model)
|
||||
switch protocol {
|
||||
@@ -112,6 +190,195 @@ func probeLocalModelAvailability(m *config.ModelConfig) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func modelProbeCacheKey(m *config.ModelConfig) string {
|
||||
protocol, modelID := splitModel(m.Model)
|
||||
|
||||
apiBaseRaw := modelProbeAPIBase(m)
|
||||
apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/"))
|
||||
apiKeyFingerprint := modelProbeAPIKeyFingerprint(m.APIKey())
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(protocol) + len(modelID) + len(apiBase) + len(apiKeyFingerprint) + 8)
|
||||
b.WriteString(protocol)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(modelID)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(apiBase)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(apiKeyFingerprint)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func modelProbeAPIKeyFingerprint(raw string) string {
|
||||
apiKey := strings.TrimSpace(raw)
|
||||
if apiKey == "" {
|
||||
return "none"
|
||||
}
|
||||
|
||||
h := fnv.New64a()
|
||||
_, _ = h.Write([]byte(apiKey))
|
||||
return strconv.FormatUint(h.Sum64(), 36)
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) getCachedResult(cacheKey string, now time.Time) (bool, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
entry, ok := s.cache[cacheKey]
|
||||
if !ok || !entry.hasResult {
|
||||
return false, false
|
||||
}
|
||||
if now.Before(entry.nextProbeAt) {
|
||||
return entry.lastResult, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) setCachedResult(cacheKey string, result bool, now time.Time) {
|
||||
s.mu.Lock()
|
||||
|
||||
entry, ok := s.cache[cacheKey]
|
||||
if !ok {
|
||||
entry = &modelProbeCacheEntry{}
|
||||
s.cache[cacheKey] = entry
|
||||
}
|
||||
|
||||
entry.lastResult = result
|
||||
entry.hasResult = true
|
||||
entry.updatedAt = now
|
||||
|
||||
var delay time.Duration
|
||||
if result {
|
||||
entry.successStreak++
|
||||
entry.failureStreak = 0
|
||||
delay = modelProbeBackoffDelay(
|
||||
modelProbeSuccessBaseInterval,
|
||||
modelProbeSuccessMaxInterval,
|
||||
entry.successStreak,
|
||||
)
|
||||
} else {
|
||||
entry.failureStreak++
|
||||
entry.successStreak = 0
|
||||
delay = modelProbeBackoffDelay(
|
||||
modelProbeFailureBaseInterval,
|
||||
modelProbeFailureMaxInterval,
|
||||
entry.failureStreak,
|
||||
)
|
||||
}
|
||||
|
||||
entry.nextProbeAt = now.Add(delay)
|
||||
|
||||
shouldRunTTLGC := modelProbeCacheEntryTTL > 0 && (s.nextTTLGCAt.IsZero() || !now.Before(s.nextTTLGCAt))
|
||||
if shouldRunTTLGC {
|
||||
s.nextTTLGCAt = now.Add(modelProbeTTLGCInterval)
|
||||
}
|
||||
shouldRunSizeGC := len(s.cache) > modelProbeCacheMaxEntries
|
||||
s.mu.Unlock()
|
||||
|
||||
if shouldRunTTLGC || shouldRunSizeGC {
|
||||
s.gc(now, shouldRunTTLGC)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *modelProbeCacheState) gc(now time.Time, runTTL bool) {
|
||||
type evictionCandidate struct {
|
||||
key string
|
||||
updatedAt time.Time
|
||||
}
|
||||
|
||||
var expireBefore time.Time
|
||||
if runTTL && modelProbeCacheEntryTTL > 0 {
|
||||
expireBefore = now.Add(-modelProbeCacheEntryTTL)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
cacheLen := len(s.cache)
|
||||
if cacheLen == 0 {
|
||||
s.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
expiredKeys := make([]string, 0)
|
||||
if !expireBefore.IsZero() {
|
||||
expiredKeys = make([]string, 0, min(cacheLen/8+1, 64))
|
||||
for key, entry := range s.cache {
|
||||
if entry.updatedAt.Before(expireBefore) {
|
||||
expiredKeys = append(expiredKeys, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
effectiveLen := cacheLen - len(expiredKeys)
|
||||
removeCount := max(effectiveLen-modelProbeCacheTrimToEntries, 0)
|
||||
|
||||
candidates := make([]evictionCandidate, 0)
|
||||
if removeCount > 0 {
|
||||
candidates = make([]evictionCandidate, 0, effectiveLen)
|
||||
for key, entry := range s.cache {
|
||||
if !expireBefore.IsZero() && entry.updatedAt.Before(expireBefore) {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, evictionCandidate{key: key, updatedAt: entry.updatedAt})
|
||||
}
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
if len(expiredKeys) == 0 && len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
toEvict := map[string]time.Time{}
|
||||
for i := 0; i < removeCount && len(candidates) > 0; i++ {
|
||||
oldest := 0
|
||||
for j := 1; j < len(candidates); j++ {
|
||||
if candidates[j].updatedAt.Before(candidates[oldest].updatedAt) {
|
||||
oldest = j
|
||||
}
|
||||
}
|
||||
victim := candidates[oldest]
|
||||
toEvict[victim.key] = victim.updatedAt
|
||||
candidates[oldest] = candidates[len(candidates)-1]
|
||||
candidates = candidates[:len(candidates)-1]
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !expireBefore.IsZero() {
|
||||
for _, key := range expiredKeys {
|
||||
entry, ok := s.cache[key]
|
||||
if ok && entry.updatedAt.Before(expireBefore) {
|
||||
delete(s.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for key, victimUpdatedAt := range toEvict {
|
||||
entry, ok := s.cache[key]
|
||||
if ok && !entry.updatedAt.After(victimUpdatedAt) {
|
||||
delete(s.cache, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func modelProbeBackoffDelay(base, maxDelay time.Duration, streak int) time.Duration {
|
||||
if streak <= 0 {
|
||||
streak = 1
|
||||
}
|
||||
|
||||
shift := min(streak-1, modelProbeBackoffMaxShift)
|
||||
|
||||
delay := base * time.Duration(1<<shift)
|
||||
if maxDelay > 0 && (delay > maxDelay || delay < 0) {
|
||||
return maxDelay
|
||||
}
|
||||
if delay <= 0 {
|
||||
return base
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func modelProbeAPIBase(m *config.ModelConfig) string {
|
||||
if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
|
||||
return normalizeModelProbeAPIBase(apiBase)
|
||||
@@ -207,7 +474,11 @@ func probeTCPService(raw string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", hostPort)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -262,7 +533,10 @@ func probeOpenAICompatibleModel(apiBase, modelID, apiKey string) bool {
|
||||
}
|
||||
|
||||
func getJSON(rawURL string, out any, apiKey string) error {
|
||||
req, err := http.NewRequest(http.MethodGet, rawURL, nil)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -270,7 +544,7 @@ func getJSON(rawURL string, out any, apiKey string) error {
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: modelProbeTimeout}
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -336,10 +610,29 @@ func ollamaModelMatches(candidate, want string) bool {
|
||||
if candidate == "" || want == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(candidate, want) {
|
||||
return true
|
||||
|
||||
candidateBase, candidateTag := splitOllamaModel(candidate)
|
||||
wantBase, wantTag := splitOllamaModel(want)
|
||||
if candidateBase == "" || wantBase == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
base, _, _ := strings.Cut(candidate, ":")
|
||||
return strings.EqualFold(base, want)
|
||||
if candidateTag == "" {
|
||||
candidateTag = "latest"
|
||||
}
|
||||
if wantTag == "" {
|
||||
wantTag = "latest"
|
||||
}
|
||||
|
||||
return strings.EqualFold(candidateBase, wantBase) && strings.EqualFold(candidateTag, wantTag)
|
||||
}
|
||||
|
||||
func splitOllamaModel(raw string) (base, tag string) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
base, tag, _ = strings.Cut(raw, ":")
|
||||
return strings.TrimSpace(base), strings.TrimSpace(tag)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@ package api
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -85,3 +88,307 @@ func TestProbeLocalModelAvailability_LMStudioUsesOpenAICompatibleProbe(t *testin
|
||||
t.Fatal("probeOpenAICompatibleModelFunc was not called for lmstudio")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProbeCacheKey_DifferentAPIKeysProduceDifferentKeys(t *testing.T) {
|
||||
base := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
AuthMethod: "local",
|
||||
ConnectMode: "",
|
||||
}
|
||||
|
||||
m1 := *base
|
||||
m1.SetAPIKey("key-a")
|
||||
m2 := *base
|
||||
m2.SetAPIKey("key-b")
|
||||
|
||||
k1 := modelProbeCacheKey(&m1)
|
||||
k2 := modelProbeCacheKey(&m2)
|
||||
if k1 == k2 {
|
||||
t.Fatal("modelProbeCacheKey() should differ when api key changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProbeCacheKey_NormalizesTrailingSlashInAPIBase(t *testing.T) {
|
||||
m1 := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
m2 := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1/",
|
||||
}
|
||||
|
||||
k1 := modelProbeCacheKey(m1)
|
||||
k2 := modelProbeCacheKey(m2)
|
||||
if k1 != k2 {
|
||||
t.Fatalf("modelProbeCacheKey() mismatch for equivalent api_base values: %q vs %q", k1, k2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelProbeCacheKey_IgnoresDisplayAndConnectionFields(t *testing.T) {
|
||||
base := &config.ModelConfig{
|
||||
ModelName: "vllm-one",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
AuthMethod: "none",
|
||||
ConnectMode: "http",
|
||||
}
|
||||
changed := &config.ModelConfig{
|
||||
ModelName: "vllm-two",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
AuthMethod: "token",
|
||||
ConnectMode: "ws",
|
||||
}
|
||||
|
||||
k1 := modelProbeCacheKey(base)
|
||||
k2 := modelProbeCacheKey(changed)
|
||||
if k1 != k2 {
|
||||
t.Fatalf("modelProbeCacheKey() should ignore non-probe fields, got %q vs %q", k1, k2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_SuccessBackoff(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000000, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
calls := 0
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
calls++
|
||||
return true
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("first probe result = false, want true")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after first probe = %d, want 1", calls)
|
||||
}
|
||||
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached probe result = false, want true")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after immediate re-check = %d, want 1", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("second probe result = false, want true")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls after success backoff window = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached result after doubled backoff = false, want true")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls before doubled backoff expires = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("third probe result = false, want true")
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("probe calls after doubled backoff expires = %d, want 3", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_FailureBackoff(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000100, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
calls := 0
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
calls++
|
||||
return false
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("first probe result = true, want false")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after first failure = %d, want 1", calls)
|
||||
}
|
||||
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached failed probe result = true, want false")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("probe calls after immediate failed re-check = %d, want 1", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("second failed probe result = true, want false")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls after failure backoff window = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("cached failure after doubled backoff = true, want false")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("probe calls before doubled failure backoff expires = %d, want 2", calls)
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("third failed probe result = true, want false")
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("probe calls after doubled failure backoff expires = %d, want 3", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_ResultFlipResetsBackoff(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000200, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
results := []bool{true, false, false}
|
||||
index := 0
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
if index >= len(results) {
|
||||
return false
|
||||
}
|
||||
result := results[index]
|
||||
index++
|
||||
return result
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
if !probeLocalModelAvailability(model) {
|
||||
t.Fatal("first probe result = false, want true")
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeSuccessBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("second probe result = true, want false")
|
||||
}
|
||||
|
||||
now = now.Add(modelProbeFailureBaseInterval)
|
||||
if probeLocalModelAvailability(model) {
|
||||
t.Fatal("third probe result = true, want false")
|
||||
}
|
||||
|
||||
if index != 3 {
|
||||
t.Fatalf("probe invocations = %d, want 3", index)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeLocalModelAvailability_DeduplicatesInflightProbe(t *testing.T) {
|
||||
resetModelProbeHooks(t)
|
||||
|
||||
now := time.Unix(1700000300, 0)
|
||||
modelProbeNowFunc = func() time.Time { return now }
|
||||
|
||||
var calls int32
|
||||
probeStarted := make(chan struct{})
|
||||
releaseProbe := make(chan struct{})
|
||||
|
||||
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
|
||||
if atomic.AddInt32(&calls, 1) == 1 {
|
||||
close(probeStarted)
|
||||
}
|
||||
<-releaseProbe
|
||||
return true
|
||||
}
|
||||
|
||||
model := &config.ModelConfig{
|
||||
ModelName: "local-vllm",
|
||||
Model: "vllm/custom-model",
|
||||
APIBase: "http://127.0.0.1:8000/v1",
|
||||
}
|
||||
|
||||
const workers = 8
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan bool, workers)
|
||||
workerStarted := make(chan struct{}, workers)
|
||||
|
||||
for range workers {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
workerStarted <- struct{}{}
|
||||
results <- probeLocalModelAvailability(model)
|
||||
}()
|
||||
}
|
||||
|
||||
for range workers {
|
||||
<-workerStarted
|
||||
}
|
||||
|
||||
select {
|
||||
case <-probeStarted:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("probe did not start in time")
|
||||
}
|
||||
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("concurrent probe calls = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(releaseProbe)
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
for result := range results {
|
||||
if !result {
|
||||
t.Fatal("deduplicated probe result = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("final probe calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOllamaModelMatches_WithTagRequiresExactTag(t *testing.T) {
|
||||
if ollamaModelMatches("llama3:8b", "llama3:7b") {
|
||||
t.Fatal("ollamaModelMatches() = true, want false for mismatched tags")
|
||||
}
|
||||
if !ollamaModelMatches("llama3:7b", "llama3:7b") {
|
||||
t.Fatal("ollamaModelMatches() = false, want true for exact tagged match")
|
||||
}
|
||||
if ollamaModelMatches("llama3:8b", "llama3") {
|
||||
t.Fatal("ollamaModelMatches() = true, want false when request omits tag (defaults to latest)")
|
||||
}
|
||||
if !ollamaModelMatches("llama3:latest", "llama3") {
|
||||
t.Fatal("ollamaModelMatches() = false, want true when request omits tag and candidate is latest")
|
||||
}
|
||||
if !ollamaModelMatches("llama3", "llama3") {
|
||||
t.Fatal("ollamaModelMatches() = false, want true when both candidate and request omit tag (latest)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,13 +32,14 @@ type modelResponse struct {
|
||||
Proxy string `json:"proxy,omitempty"`
|
||||
AuthMethod string `json:"auth_method,omitempty"`
|
||||
// Advanced fields
|
||||
ConnectMode string `json:"connect_mode,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
RPM int `json:"rpm,omitempty"`
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"`
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
ConnectMode string `json:"connect_mode,omitempty"`
|
||||
Workspace string `json:"workspace,omitempty"`
|
||||
RPM int `json:"rpm,omitempty"`
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"`
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"`
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"`
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"`
|
||||
// Meta
|
||||
Enabled bool `json:"enabled"`
|
||||
Available bool `json:"available"`
|
||||
@@ -87,6 +88,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
Enabled: m.Enabled,
|
||||
Available: modelStatuses[i].Available,
|
||||
Status: modelStatuses[i].Status,
|
||||
@@ -216,6 +218,14 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
|
||||
} else if len(mc.ExtraBody) == 0 {
|
||||
mc.ExtraBody = nil
|
||||
}
|
||||
// Preserve existing CustomHeaders when omitted (nil), but clear it when
|
||||
// the frontend sends an empty object {} to indicate the field should
|
||||
// be removed.
|
||||
if mc.CustomHeaders == nil {
|
||||
mc.CustomHeaders = cfg.ModelList[idx].CustomHeaders
|
||||
} else if len(mc.CustomHeaders) == 0 {
|
||||
mc.CustomHeaders = nil
|
||||
}
|
||||
|
||||
cfg.ModelList[idx] = &mc.ModelConfig
|
||||
|
||||
|
||||
@@ -20,10 +20,14 @@ func resetModelProbeHooks(t *testing.T) {
|
||||
origTCPProbe := probeTCPServiceFunc
|
||||
origOllamaProbe := probeOllamaModelFunc
|
||||
origOpenAIProbe := probeOpenAICompatibleModelFunc
|
||||
origNow := modelProbeNowFunc
|
||||
resetModelProbeCache()
|
||||
t.Cleanup(func() {
|
||||
probeTCPServiceFunc = origTCPProbe
|
||||
probeOllamaModelFunc = origOllamaProbe
|
||||
probeOpenAICompatibleModelFunc = origOpenAIProbe
|
||||
modelProbeNowFunc = origNow
|
||||
resetModelProbeCache()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -426,6 +430,112 @@ func TestHandleAddModel_PersistsAPIKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleAddModel_PersistsCustomHeaders(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{
|
||||
"model_name":"new-model-headers",
|
||||
"model":"openai/gpt-4o-mini",
|
||||
"custom_headers":{"X-Source":"coding-plan","X-Agent":"openclaw"}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
if len(cfg.ModelList) != 2 {
|
||||
t.Fatalf("len(model_list) = %d, want 2", len(cfg.ModelList))
|
||||
}
|
||||
|
||||
added := cfg.ModelList[1]
|
||||
if added.CustomHeaders == nil {
|
||||
t.Fatal("custom_headers should not be nil")
|
||||
}
|
||||
if got := added.CustomHeaders["X-Source"]; got != "coding-plan" {
|
||||
t.Fatalf("custom_headers[X-Source] = %q, want %q", got, "coding-plan")
|
||||
}
|
||||
if got := added.CustomHeaders["X-Agent"]; got != "openclaw" {
|
||||
t.Fatalf("custom_headers[X-Agent] = %q, want %q", got, "openclaw")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpdateModel_CustomHeadersPreserveAndClear(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.ModelList = []*config.ModelConfig{{
|
||||
ModelName: "editable",
|
||||
Model: "openai/gpt-4o-mini",
|
||||
APIKeys: config.SimpleSecureStrings("sk-existing"),
|
||||
CustomHeaders: map[string]string{"X-Source": "coding-plan"},
|
||||
}}
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
// Omitted custom_headers should preserve existing value.
|
||||
recPreserve := httptest.NewRecorder()
|
||||
reqPreserve := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
|
||||
"model_name":"editable",
|
||||
"model":"openai/gpt-4o-mini"
|
||||
}`))
|
||||
reqPreserve.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(recPreserve, reqPreserve)
|
||||
if recPreserve.Code != http.StatusOK {
|
||||
t.Fatalf("preserve status = %d, want %d, body=%s", recPreserve.Code, http.StatusOK, recPreserve.Body.String())
|
||||
}
|
||||
|
||||
afterPreserve, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() after preserve error = %v", err)
|
||||
}
|
||||
if got := afterPreserve.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
|
||||
t.Fatalf("preserved custom_headers[X-Source] = %q, want %q", got, "coding-plan")
|
||||
}
|
||||
|
||||
// Empty object should clear custom_headers.
|
||||
recClear := httptest.NewRecorder()
|
||||
reqClear := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
|
||||
"model_name":"editable",
|
||||
"model":"openai/gpt-4o-mini",
|
||||
"custom_headers":{}
|
||||
}`))
|
||||
reqClear.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(recClear, reqClear)
|
||||
if recClear.Code != http.StatusOK {
|
||||
t.Fatalf("clear status = %d, want %d, body=%s", recClear.Code, http.StatusOK, recClear.Body.String())
|
||||
}
|
||||
|
||||
afterClear, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() after clear error = %v", err)
|
||||
}
|
||||
if afterClear.ModelList[0].CustomHeaders != nil {
|
||||
t.Fatalf("custom_headers = %#v, want nil", afterClear.ModelList[0].CustomHeaders)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleSetDefaultModel_RejectsNonexistentModel tests that setting a non-existent
|
||||
// model as default returns 404. This covers the case where virtual models (which are
|
||||
// filtered by SaveConfig) cannot be set as default.
|
||||
|
||||
+27
-1
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
)
|
||||
|
||||
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
|
||||
@@ -57,9 +58,34 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.mu.Lock()
|
||||
ensurePicoTokenCachedLocked(h.configPath)
|
||||
gatewayAvailable := gateway.pidData != nil
|
||||
cachedPID := gateway.pidData
|
||||
trackedCmd := gateway.cmd
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayAvailable := false
|
||||
// Prefer fresh PID file data when available.
|
||||
if pidData := ppid.ReadPidFileWithCheck(globalConfigDir()); pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = pidData
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gatewayAvailable = true
|
||||
gateway.mu.Unlock()
|
||||
} else if cachedPID != nil {
|
||||
// No PID file now: keep availability only while tracked process is
|
||||
// still alive (covers short PID-file races at startup/restart).
|
||||
if isCmdProcessAliveLocked(trackedCmd) {
|
||||
gatewayAvailable = true
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == trackedCmd {
|
||||
gateway.pidData = nil
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
}
|
||||
gatewayAvailable = gateway.pidData != nil
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
if !gatewayAvailable {
|
||||
logger.Warnf("Gateway not available for WebSocket proxy")
|
||||
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
)
|
||||
@@ -307,6 +308,9 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
@@ -335,6 +339,16 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
gateway.pidData = origPidData
|
||||
gateway.picoToken = origPicoToken
|
||||
})
|
||||
|
||||
gateway.pidData = &ppid.PidFileData{}
|
||||
gateway.picoToken = "pico"
|
||||
@@ -378,6 +392,9 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
@@ -399,6 +416,12 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
@@ -426,6 +449,134 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, r.Header.Get(protocolKey))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
|
||||
cfg.Channels.Pico.Enabled = true
|
||||
cfg.Channels.Pico.SetToken("ui-token")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
pidData, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
origStatus := gateway.runtimeStatus
|
||||
t.Cleanup(func() {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = origPidData
|
||||
gateway.picoToken = origPicoToken
|
||||
gateway.runtimeStatus = origStatus
|
||||
gateway.mu.Unlock()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = nil
|
||||
gateway.picoToken = ""
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token"
|
||||
if got := rec.Body.String(); got != expected {
|
||||
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
if gateway.pidData == nil {
|
||||
t.Fatal("gateway.pidData should be loaded from pid file")
|
||||
}
|
||||
if gateway.runtimeStatus != "running" {
|
||||
t.Fatalf("runtimeStatus = %q, want %q", gateway.runtimeStatus, "running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels.Pico.Enabled = true
|
||||
cfg.Channels.Pico.SetToken("ui-token")
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
origCmd := gateway.cmd
|
||||
origStatus := gateway.runtimeStatus
|
||||
t.Cleanup(func() {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = origPidData
|
||||
gateway.picoToken = origPicoToken
|
||||
gateway.cmd = origCmd
|
||||
gateway.runtimeStatus = origStatus
|
||||
gateway.mu.Unlock()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid, Token: "stale-token"}
|
||||
gateway.picoToken = "ui-token"
|
||||
gateway.cmd = cmd
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
|
||||
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, req)
|
||||
|
||||
if rec.Code != http.StatusServiceUnavailable {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
|
||||
}
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
if gateway.pidData != nil {
|
||||
t.Fatal("gateway.pidData should be cleared after stale process exit is detected")
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -81,6 +81,9 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Launcher service parameters (port/public)
|
||||
h.registerLauncherConfigRoutes(mux)
|
||||
|
||||
// Self-update endpoint (requires dashboard auth)
|
||||
h.registerUpdateRoutes(mux)
|
||||
|
||||
// Runtime build/version metadata
|
||||
h.registerVersionRoutes(mux)
|
||||
|
||||
|
||||
+113
-34
@@ -44,12 +44,24 @@ type sessionListItem struct {
|
||||
Updated string `json:"updated"`
|
||||
}
|
||||
|
||||
type sessionChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
}
|
||||
|
||||
// legacyPicoSessionPrefix is the legacy key prefix used by older Pico JSON/JSONL
|
||||
// sessions before structured scope metadata existed.
|
||||
const (
|
||||
legacyPicoSessionPrefix = "agent:main:pico:direct:pico:"
|
||||
maxSessionJSONLLineSize = 10 * 1024 * 1024 // 10 MB
|
||||
picoSessionPrefix = legacyPicoSessionPrefix
|
||||
|
||||
// Keep the session API aligned with the shared JSONL store reader limit in
|
||||
// pkg/memory/jsonl.go so oversized lines fail consistently everywhere.
|
||||
maxSessionJSONLLineSize = 10 * 1024 * 1024
|
||||
maxSessionTitleRunes = 60
|
||||
|
||||
handledToolResponseSummaryText = "Requested output delivered via tool attachment."
|
||||
)
|
||||
|
||||
// extractLegacyPicoSessionID extracts the session UUID from an old Pico key.
|
||||
@@ -327,32 +339,21 @@ func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessio
|
||||
func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem {
|
||||
preview := ""
|
||||
for _, msg := range sess.Messages {
|
||||
if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" {
|
||||
preview = msg.Content
|
||||
if msg.Role == "user" {
|
||||
preview = sessionMessagePreview(msg)
|
||||
}
|
||||
if preview != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
title := strings.TrimSpace(sess.Summary)
|
||||
if title == "" {
|
||||
title = preview
|
||||
}
|
||||
|
||||
title = truncateRunes(title, maxSessionTitleRunes)
|
||||
preview = truncateRunes(preview, maxSessionTitleRunes)
|
||||
|
||||
if preview == "" {
|
||||
preview = "(empty)"
|
||||
}
|
||||
if title == "" {
|
||||
title = preview
|
||||
}
|
||||
title := preview
|
||||
|
||||
validMessageCount := 0
|
||||
for _, msg := range sess.Messages {
|
||||
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
|
||||
validMessageCount++
|
||||
}
|
||||
}
|
||||
validMessageCount := len(visibleSessionMessages(sess.Messages))
|
||||
|
||||
return sessionListItem{
|
||||
ID: sessionID,
|
||||
@@ -379,6 +380,99 @@ func truncateRunes(s string, maxLen int) string {
|
||||
return string(runes[:maxLen]) + "..."
|
||||
}
|
||||
|
||||
func sessionMessageVisible(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) != "" || len(msg.Media) > 0
|
||||
}
|
||||
|
||||
func sessionMessagePreview(msg providers.Message) string {
|
||||
if content := strings.TrimSpace(msg.Content); content != "" {
|
||||
return content
|
||||
}
|
||||
if len(msg.Media) > 0 {
|
||||
return "[image]"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
|
||||
transcript := make([]sessionChatMessage, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
if sessionMessageVisible(msg) {
|
||||
transcript = append(transcript, sessionChatMessage{
|
||||
Role: "user",
|
||||
Content: msg.Content,
|
||||
Media: append([]string(nil), msg.Media...),
|
||||
})
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls)
|
||||
if len(visibleToolMessages) > 0 {
|
||||
transcript = append(transcript, visibleToolMessages...)
|
||||
}
|
||||
|
||||
// Pico web chat can persist both visible `message` tool output and a
|
||||
// later plain assistant reply in the same turn. Hide only the fixed
|
||||
// internal summary that marks handled tool delivery.
|
||||
if len(visibleToolMessages) > 0 || !sessionMessageVisible(msg) || assistantMessageInternalOnly(msg) {
|
||||
continue
|
||||
}
|
||||
|
||||
transcript = append(transcript, sessionChatMessage{
|
||||
Role: "assistant",
|
||||
Content: msg.Content,
|
||||
Media: append([]string(nil), msg.Media...),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return transcript
|
||||
}
|
||||
|
||||
func assistantMessageInternalOnly(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText
|
||||
}
|
||||
|
||||
func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
messages := make([]sessionChatMessage, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
name := tc.Name
|
||||
argsJSON := ""
|
||||
if tc.Function != nil {
|
||||
if name == "" {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
argsJSON = tc.Function.Arguments
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "message":
|
||||
var args struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(args.Content) == "" {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, sessionChatMessage{
|
||||
Role: "assistant",
|
||||
Content: args.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// sessionsDir resolves the path to the gateway's session storage directory.
|
||||
// It reads the workspace from config, falling back to ~/.picoclaw/workspace.
|
||||
func (h *Handler) sessionsDir() (string, error) {
|
||||
@@ -530,22 +624,7 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to a simpler format for the frontend
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
messages := make([]chatMessage, 0, len(sess.Messages))
|
||||
for _, msg := range sess.Messages {
|
||||
// Only include user and assistant messages that have actual content
|
||||
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
|
||||
messages = append(messages, chatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
}
|
||||
messages := visibleSessionMessages(sess.Messages)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
|
||||
+376
-23
@@ -6,6 +6,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -34,9 +35,9 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
store, storeErr := memory.NewJSONLStore(dir)
|
||||
if storeErr != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", storeErr)
|
||||
}
|
||||
|
||||
sessionKey := legacyPicoSessionPrefix + "history-jsonl"
|
||||
@@ -87,22 +88,26 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
|
||||
if items[0].MessageCount != 2 {
|
||||
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
|
||||
}
|
||||
if items[0].Title != "JSONL-backed session" {
|
||||
t.Fatalf("items[0].Title = %q, want %q", items[0].Title, "JSONL-backed session")
|
||||
if items[0].Title != "Explain why the history API is empty after migration." {
|
||||
t.Fatalf(
|
||||
"items[0].Title = %q, want %q",
|
||||
items[0].Title,
|
||||
"Explain why the history API is empty after migration.",
|
||||
)
|
||||
}
|
||||
if items[0].Preview != "Explain why the history API is empty after migration." {
|
||||
t.Fatalf("items[0].Preview = %q", items[0].Preview)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
|
||||
func TestHandleListSessions_TitleUsesFirstUserMessage(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
store, storeErr := memory.NewJSONLStore(dir)
|
||||
if storeErr != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", storeErr)
|
||||
}
|
||||
|
||||
sessionKey := legacyPicoSessionPrefix + "summary-title"
|
||||
@@ -139,10 +144,7 @@ func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
expectedTitle := truncateRunes(
|
||||
"This summary is intentionally longer than sixty characters so it must be truncated in the history menu.",
|
||||
maxSessionTitleRunes,
|
||||
)
|
||||
expectedTitle := truncateRunes("fallback preview", maxSessionTitleRunes)
|
||||
if items[0].Title != expectedTitle {
|
||||
t.Fatalf("items[0].Title = %q", items[0].Title)
|
||||
}
|
||||
@@ -220,22 +222,20 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
store, storeErr := memory.NewJSONLStore(dir)
|
||||
if storeErr != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", storeErr)
|
||||
}
|
||||
|
||||
sessionKey := "sk_v1_scope_discovery"
|
||||
addErr := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: "scope discovered session",
|
||||
})
|
||||
if addErr != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", addErr)
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
summaryErr := store.SetSummary(nil, sessionKey, "scope summary")
|
||||
if summaryErr != nil {
|
||||
t.Fatalf("SetSummary() error = %v", summaryErr)
|
||||
if err := store.SetSummary(nil, sessionKey, "scope summary"); err != nil {
|
||||
t.Fatalf("SetSummary() error = %v", err)
|
||||
}
|
||||
|
||||
scopeData, err := json.Marshal(session.SessionScope{
|
||||
@@ -292,6 +292,359 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-message-tool"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "test"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "message",
|
||||
Arguments: `{"content":"visible tool output"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Message sent to pico:pico:detail-message-tool", ToolCallID: "call_1"},
|
||||
{Role: "assistant", Content: handledToolResponseSummaryText},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-message-tool", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
|
||||
t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_PreservesFinalAssistantReplyAfterMessageToolOutput(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-message-tool-final-reply"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "test"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "message",
|
||||
Arguments: `{"content":"visible tool output"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Message sent to pico:pico:detail-message-tool-final-reply", ToolCallID: "call_1"},
|
||||
{Role: "assistant", Content: "final assistant reply"},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-message-tool-final-reply", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
|
||||
t.Fatalf("interim assistant message = %#v, want visible tool output", resp.Messages[1])
|
||||
}
|
||||
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "final assistant reply" {
|
||||
t.Fatalf("final assistant message = %#v, want final assistant reply", resp.Messages[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListSessions_MessageCountUsesVisibleTranscript(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "list-visible-count"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "test"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "message",
|
||||
Arguments: `{"content":"visible tool output"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Message sent to pico:pico:list-visible-count", ToolCallID: "call_1"},
|
||||
{Role: "assistant", Content: handledToolResponseSummaryText},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].MessageCount != 2 {
|
||||
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_IncludesMediaOnlyMessages(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-media-only"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Media: []string{"data:image/png;base64,abc123"},
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage(user) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-media-only", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 1 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 1", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" || len(resp.Messages[0].Media) != 1 {
|
||||
t.Fatalf("message = %#v, want user message with media", resp.Messages[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSessions_SupportsJSONLMessagesUpToStoreCap(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-large-jsonl"
|
||||
largeContent := strings.Repeat("x", 9*1024*1024)
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Content: largeContent,
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
listRec := httptest.NewRecorder()
|
||||
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(listRec, listReq)
|
||||
|
||||
if listRec.Code != http.StatusOK {
|
||||
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("list Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
|
||||
detailRec := httptest.NewRecorder()
|
||||
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-large-jsonl", nil)
|
||||
mux.ServeHTTP(detailRec, detailReq)
|
||||
|
||||
if detailRec.Code != http.StatusOK {
|
||||
t.Fatalf(
|
||||
"detail status = %d, want %d, body=%s",
|
||||
detailRec.Code,
|
||||
http.StatusOK,
|
||||
detailRec.Body.String(),
|
||||
)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(detailRec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("detail Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 1 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 1", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" {
|
||||
t.Fatalf("resp.Messages[0].Role = %q, want %q", resp.Messages[0].Role, "user")
|
||||
}
|
||||
if got := len(resp.Messages[0].Content); got != len(largeContent) {
|
||||
t.Fatalf("len(resp.Messages[0].Content) = %d, want %d", got, len(largeContent))
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleListSessions_UsesImagePreviewForMediaOnlyMessage(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "preview-media-only"
|
||||
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "user",
|
||||
Media: []string{"data:image/png;base64,abc123"},
|
||||
}); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var items []sessionListItem
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].Preview != "[image]" {
|
||||
t.Fatalf("items[0].Preview = %q, want %q", items[0].Preview, "[image]")
|
||||
}
|
||||
if items[0].MessageCount != 1 {
|
||||
t.Fatalf("items[0].MessageCount = %d, want 1", items[0].MessageCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDeleteSession_JSONLStorage(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
+833
-54
@@ -1,40 +1,115 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
"github.com/sipeed/picoclaw/pkg/skills"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type skillSupportResponse struct {
|
||||
Skills []skills.SkillInfo `json:"skills"`
|
||||
Skills []skillSupportItem `json:"skills"`
|
||||
}
|
||||
|
||||
type skillSupportItem struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Source string `json:"source"`
|
||||
Description string `json:"description"`
|
||||
OriginKind string `json:"origin_kind"`
|
||||
RegistryName string `json:"registry_name,omitempty"`
|
||||
RegistryURL string `json:"registry_url,omitempty"`
|
||||
InstalledVersion string `json:"installed_version,omitempty"`
|
||||
InstalledAt int64 `json:"installed_at,omitempty"`
|
||||
}
|
||||
|
||||
type skillDetailResponse struct {
|
||||
Name string `json:"name"`
|
||||
Path string `json:"path"`
|
||||
Source string `json:"source"`
|
||||
Description string `json:"description"`
|
||||
Content string `json:"content"`
|
||||
skillSupportItem
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type skillSearchResultItem struct {
|
||||
Score float64 `json:"score"`
|
||||
Slug string `json:"slug"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Summary string `json:"summary"`
|
||||
Version string `json:"version"`
|
||||
RegistryName string `json:"registry_name"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Installed bool `json:"installed"`
|
||||
InstalledName string `json:"installed_name,omitempty"`
|
||||
}
|
||||
|
||||
type skillSearchResponse struct {
|
||||
Results []skillSearchResultItem `json:"results"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
NextOffset int `json:"next_offset,omitempty"`
|
||||
HasMore bool `json:"has_more"`
|
||||
}
|
||||
|
||||
type installSkillRequest struct {
|
||||
Slug string `json:"slug"`
|
||||
Registry string `json:"registry"`
|
||||
Version string `json:"version,omitempty"`
|
||||
Force bool `json:"force,omitempty"`
|
||||
}
|
||||
|
||||
type installSkillResponse struct {
|
||||
Status string `json:"status"`
|
||||
Slug string `json:"slug"`
|
||||
Registry string `json:"registry"`
|
||||
Version string `json:"version"`
|
||||
Summary string `json:"summary,omitempty"`
|
||||
IsSuspicious bool `json:"is_suspicious,omitempty"`
|
||||
InstalledSkill *skillSupportItem `json:"skill,omitempty"`
|
||||
}
|
||||
|
||||
type installedSkillOriginMeta struct {
|
||||
Version int `json:"version"`
|
||||
OriginKind string `json:"origin_kind,omitempty"`
|
||||
Registry string `json:"registry,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
RegistryURL string `json:"registry_url,omitempty"`
|
||||
InstalledVersion string `json:"installed_version,omitempty"`
|
||||
InstalledAt int64 `json:"installed_at"`
|
||||
}
|
||||
|
||||
var (
|
||||
skillNameSanitizer = regexp.MustCompile(`[^a-z0-9-]+`)
|
||||
importedSkillFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
|
||||
skillFrontmatterStripper = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
|
||||
persistSkillOriginMeta = writeSkillOriginMeta
|
||||
workspaceSkillWriteMu sync.Mutex
|
||||
errImportedSkillExists = errors.New("skill already exists")
|
||||
)
|
||||
|
||||
const (
|
||||
maxImportedSkillSize = 1 << 20
|
||||
maxRegistrySearchFanout = 1000
|
||||
)
|
||||
|
||||
func (h *Handler) registerSkillRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/skills", h.handleListSkills)
|
||||
mux.HandleFunc("GET /api/skills/{name}", h.handleGetSkill)
|
||||
mux.HandleFunc("GET /api/skills/search", h.handleSearchSkills)
|
||||
mux.HandleFunc("POST /api/skills/install", h.handleInstallSkill)
|
||||
mux.HandleFunc("POST /api/skills/import", h.handleImportSkill)
|
||||
mux.HandleFunc("DELETE /api/skills/{name}", h.handleDeleteSkill)
|
||||
}
|
||||
@@ -46,11 +121,15 @@ func (h *Handler) handleListSkills(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(cfg.WorkspacePath())
|
||||
items, err := buildSkillSupportItems(cfg)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to build skill list: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skillSupportResponse{
|
||||
Skills: loader.ListSkills(),
|
||||
Skills: items,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,16 +140,18 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(cfg.WorkspacePath())
|
||||
skillItems, err := buildSkillSupportItems(cfg)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to build skill list: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
name := r.PathValue("name")
|
||||
allSkills := loader.ListSkills()
|
||||
|
||||
for _, skill := range allSkills {
|
||||
if skill.Name != name {
|
||||
for _, skillItem := range skillItems {
|
||||
if skillItem.Name != name {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := loadSkillContent(skill.Path)
|
||||
content, err := loadSkillContent(skillItem.Path)
|
||||
if err != nil {
|
||||
http.Error(w, "Skill content not found", http.StatusNotFound)
|
||||
return
|
||||
@@ -78,11 +159,8 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skillDetailResponse{
|
||||
Name: skill.Name,
|
||||
Path: skill.Path,
|
||||
Source: skill.Source,
|
||||
Description: skill.Description,
|
||||
Content: content,
|
||||
skillSupportItem: skillItem,
|
||||
Content: content,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -90,6 +168,266 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Skill not found", http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *Handler) handleSearchSkills(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, loadErr := config.LoadConfig(h.configPath)
|
||||
if loadErr != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", loadErr), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if registryErr := ensureSkillRegistryToolEnabled(cfg, "find_skills"); registryErr != nil {
|
||||
http.Error(w, registryErr.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
query := strings.TrimSpace(r.URL.Query().Get("q"))
|
||||
|
||||
limit := 20
|
||||
if rawLimit := strings.TrimSpace(r.URL.Query().Get("limit")); rawLimit != "" {
|
||||
parsedLimit, parseErr := strconv.Atoi(rawLimit)
|
||||
if parseErr != nil || parsedLimit < 1 || parsedLimit > 50 {
|
||||
http.Error(w, "limit must be between 1 and 50", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
limit = parsedLimit
|
||||
}
|
||||
offset := 0
|
||||
if rawOffset := strings.TrimSpace(r.URL.Query().Get("offset")); rawOffset != "" {
|
||||
parsedOffset, parseErr := strconv.Atoi(rawOffset)
|
||||
if parseErr != nil || parsedOffset < 0 {
|
||||
http.Error(w, "offset must be 0 or greater", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
offset = parsedOffset
|
||||
}
|
||||
|
||||
installedSkills, err := buildOccupiedWorkspaceSkillsByDirectory(cfg)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to inspect installed skills: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if query == "" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skillSearchResponse{
|
||||
Results: []skillSearchResultItem{},
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
HasMore: false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
registryMgr := newSkillsRegistryManager(cfg)
|
||||
searchLimit := offset + limit + 1
|
||||
if searchLimit > maxRegistrySearchFanout {
|
||||
searchLimit = maxRegistrySearchFanout
|
||||
}
|
||||
results, err := registryMgr.SearchAll(r.Context(), query, searchLimit)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to search skills: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
if offset > len(results) {
|
||||
offset = len(results)
|
||||
}
|
||||
|
||||
end := offset + limit
|
||||
if end > len(results) {
|
||||
end = len(results)
|
||||
}
|
||||
|
||||
pageResults := results[offset:end]
|
||||
response := make([]skillSearchResultItem, 0, len(pageResults))
|
||||
for _, result := range pageResults {
|
||||
installedSkill, installed := installedSkills[result.Slug]
|
||||
item := skillSearchResultItem{
|
||||
Score: result.Score,
|
||||
Slug: result.Slug,
|
||||
DisplayName: result.DisplayName,
|
||||
Summary: result.Summary,
|
||||
Version: result.Version,
|
||||
RegistryName: result.RegistryName,
|
||||
URL: registrySkillURL(cfg, result.RegistryName, result.Slug),
|
||||
Installed: installed,
|
||||
}
|
||||
if installed {
|
||||
item.InstalledName = installedSkill.Name
|
||||
}
|
||||
response = append(response, item)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
nextOffset := 0
|
||||
hasMore := len(results) > end
|
||||
if hasMore {
|
||||
nextOffset = end
|
||||
}
|
||||
json.NewEncoder(w).Encode(skillSearchResponse{
|
||||
Results: response,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
NextOffset: nextOffset,
|
||||
HasMore: hasMore,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleInstallSkill(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, loadErr := config.LoadConfig(h.configPath)
|
||||
if loadErr != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", loadErr), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if registryErr := ensureSkillRegistryToolEnabled(cfg, "install_skill"); registryErr != nil {
|
||||
http.Error(w, registryErr.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req installSkillRequest
|
||||
if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
|
||||
http.Error(w, fmt.Sprintf("Invalid JSON: %v", decodeErr), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
req.Slug = strings.TrimSpace(req.Slug)
|
||||
req.Registry = strings.TrimSpace(req.Registry)
|
||||
req.Version = strings.TrimSpace(req.Version)
|
||||
|
||||
if validateErr := utils.ValidateSkillIdentifier(req.Slug); validateErr != nil {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("invalid slug %q: error: %s", req.Slug, validateErr.Error()),
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
return
|
||||
}
|
||||
if validateErr := utils.ValidateSkillIdentifier(req.Registry); validateErr != nil {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("invalid registry %q: error: %s", req.Registry, validateErr.Error()),
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
registryMgr := newSkillsRegistryManager(cfg)
|
||||
registry := registryMgr.GetRegistry(req.Registry)
|
||||
if registry == nil {
|
||||
http.Error(w, fmt.Sprintf("registry %q not found", req.Registry), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
skillsRoot := filepath.Join(workspace, "skills")
|
||||
targetDir := filepath.Join(workspace, "skills", req.Slug)
|
||||
workspaceSkillWriteMu.Lock()
|
||||
defer workspaceSkillWriteMu.Unlock()
|
||||
|
||||
targetExists := false
|
||||
if _, statErr := os.Stat(targetDir); statErr == nil {
|
||||
targetExists = true
|
||||
} else if !os.IsNotExist(statErr) {
|
||||
http.Error(w, fmt.Sprintf("Failed to inspect install target: %v", statErr), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !req.Force && targetExists {
|
||||
http.Error(w, fmt.Sprintf("skill %q already installed at %s", req.Slug, targetDir), http.StatusConflict)
|
||||
return
|
||||
}
|
||||
if err := os.MkdirAll(skillsRoot, 0o755); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create skills directory: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
stagedWorkspaceRoot, stagedTargetDir, err := createStagedSkillInstall(skillsRoot, req.Slug)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to prepare staged install: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer os.RemoveAll(stagedWorkspaceRoot)
|
||||
|
||||
result, err := registry.DownloadAndInstall(r.Context(), req.Slug, req.Version, stagedTargetDir)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to install skill: %v", err), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if result.IsMalwareBlocked {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", req.Slug),
|
||||
http.StatusForbidden,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if findWorkspaceSkillInfoByDirectory(stagedWorkspaceRoot, req.Slug) == nil {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to install skill: registry archive for %q is not a valid skill", req.Slug),
|
||||
http.StatusBadGateway,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
installedAt := time.Now().UnixMilli()
|
||||
if err := persistSkillOriginMeta(stagedTargetDir, installedSkillOriginMeta{
|
||||
Version: 1,
|
||||
OriginKind: "third_party",
|
||||
Registry: registry.Name(),
|
||||
Slug: req.Slug,
|
||||
RegistryURL: registrySkillURL(cfg, registry.Name(), req.Slug),
|
||||
InstalledVersion: result.Version,
|
||||
InstalledAt: installedAt,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to persist skill metadata: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := commitStagedSkillInstall(
|
||||
stagedWorkspaceRoot,
|
||||
stagedTargetDir,
|
||||
targetDir,
|
||||
req.Force && targetExists,
|
||||
); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to activate installed skill: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
validatedSkill := findWorkspaceSkillByDirectory(cfg, req.Slug)
|
||||
if validatedSkill == nil {
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to install skill: activated archive for %q is not a valid skill", req.Slug),
|
||||
http.StatusBadGateway,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
installedSkill := &skillSupportItem{
|
||||
Name: validatedSkill.Name,
|
||||
Path: validatedSkill.Path,
|
||||
Source: validatedSkill.Source,
|
||||
Description: validatedSkill.Description,
|
||||
OriginKind: "third_party",
|
||||
RegistryName: registry.Name(),
|
||||
RegistryURL: registrySkillURL(cfg, registry.Name(), req.Slug),
|
||||
InstalledVersion: result.Version,
|
||||
InstalledAt: installedAt,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(installSkillResponse{
|
||||
Status: "ok",
|
||||
Slug: req.Slug,
|
||||
Registry: registry.Name(),
|
||||
Version: result.Version,
|
||||
Summary: result.Summary,
|
||||
IsSuspicious: result.IsSuspicious,
|
||||
InstalledSkill: installedSkill,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
@@ -110,54 +448,26 @@ func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
defer uploadedFile.Close()
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(uploadedFile, (1<<20)+1))
|
||||
content, err := io.ReadAll(io.LimitReader(uploadedFile, maxImportedSkillSize+1))
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to read file: %v", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(content) > 1<<20 {
|
||||
if len(content) > maxImportedSkillSize {
|
||||
http.Error(w, "file exceeds 1MB limit", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
workspaceSkillWriteMu.Lock()
|
||||
defer workspaceSkillWriteMu.Unlock()
|
||||
|
||||
skillName, err := normalizeImportedSkillName(fileHeader.Filename, content)
|
||||
importedSkill, statusCode, err := importUploadedSkill(cfg, fileHeader.Filename, content)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
http.Error(w, err.Error(), statusCode)
|
||||
return
|
||||
}
|
||||
content = normalizeImportedSkillContent(content, skillName)
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
skillDir := filepath.Join(workspace, "skills", skillName)
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
if _, err := os.Stat(skillDir); err == nil {
|
||||
http.Error(w, "skill already exists", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create skill directory: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile(skillFile, content, 0o644); err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to save skill: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loader := newSkillsLoader(workspace)
|
||||
for _, skill := range loader.ListSkills() {
|
||||
if skill.Path == skillFile || (skill.Name == skillName && skill.Source == "workspace") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(skill)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"name": skillName,
|
||||
"path": skillFile,
|
||||
})
|
||||
json.NewEncoder(w).Encode(importedSkill)
|
||||
}
|
||||
|
||||
func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -169,6 +479,9 @@ func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
loader := newSkillsLoader(cfg.WorkspacePath())
|
||||
name := r.PathValue("name")
|
||||
workspaceSkillWriteMu.Lock()
|
||||
defer workspaceSkillWriteMu.Unlock()
|
||||
|
||||
for _, skill := range loader.ListSkills() {
|
||||
if skill.Name != name {
|
||||
continue
|
||||
@@ -197,12 +510,274 @@ func newSkillsLoader(workspace string) *skills.SkillsLoader {
|
||||
)
|
||||
}
|
||||
|
||||
func newSkillsRegistryManager(cfg *config.Config) *skills.RegistryManager {
|
||||
clawHubConfig := cfg.Tools.Skills.Registries.ClawHub
|
||||
return skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
|
||||
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
|
||||
ClawHub: skills.ClawHubConfig{
|
||||
Enabled: clawHubConfig.Enabled,
|
||||
BaseURL: clawHubConfig.BaseURL,
|
||||
AuthToken: clawHubConfig.AuthToken.String(),
|
||||
SearchPath: clawHubConfig.SearchPath,
|
||||
SkillsPath: clawHubConfig.SkillsPath,
|
||||
DownloadPath: clawHubConfig.DownloadPath,
|
||||
Timeout: clawHubConfig.Timeout,
|
||||
MaxZipSize: clawHubConfig.MaxZipSize,
|
||||
MaxResponseSize: clawHubConfig.MaxResponseSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func ensureSkillRegistryToolEnabled(cfg *config.Config, toolName string) error {
|
||||
if !cfg.Tools.IsToolEnabled("skills") {
|
||||
return fmt.Errorf("tools.skills is disabled")
|
||||
}
|
||||
if !cfg.Tools.IsToolEnabled(toolName) {
|
||||
return fmt.Errorf("%s is disabled", toolName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildSkillSupportItems(cfg *config.Config) ([]skillSupportItem, error) {
|
||||
rawSkills := newSkillsLoader(cfg.WorkspacePath()).ListSkills()
|
||||
items := make([]skillSupportItem, 0, len(rawSkills))
|
||||
for _, skill := range rawSkills {
|
||||
item, err := enrichSkillInfo(cfg, skill)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func buildWorkspaceSkillItemsByDirectory(cfg *config.Config) (map[string]skillSupportItem, error) {
|
||||
result := make(map[string]skillSupportItem)
|
||||
items, err := buildSkillSupportItems(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, skill := range items {
|
||||
if skill.Source != "workspace" {
|
||||
continue
|
||||
}
|
||||
dir := filepath.Base(filepath.Dir(skill.Path))
|
||||
if dir == "" {
|
||||
continue
|
||||
}
|
||||
result[dir] = skill
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func buildOccupiedWorkspaceSkillsByDirectory(cfg *config.Config) (map[string]skillSupportItem, error) {
|
||||
result := make(map[string]skillSupportItem)
|
||||
items, err := buildSkillSupportItems(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, skill := range items {
|
||||
if skill.Source != "workspace" {
|
||||
continue
|
||||
}
|
||||
|
||||
key := filepath.Base(filepath.Dir(skill.Path))
|
||||
if meta, err := readInstalledSkillOriginMeta(skill.Path); err == nil && meta != nil && meta.Slug != "" {
|
||||
key = meta.Slug
|
||||
}
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
result[key] = skill
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func findWorkspaceSkillByDirectory(cfg *config.Config, directory string) *skillSupportItem {
|
||||
items, err := buildWorkspaceSkillItemsByDirectory(cfg)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
skill, ok := items[directory]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &skill
|
||||
}
|
||||
|
||||
func findWorkspaceSkillInfoByDirectory(workspace, directory string) *skills.SkillInfo {
|
||||
loader := skills.NewSkillsLoader(workspace, "", "")
|
||||
for _, skill := range loader.ListSkills() {
|
||||
if skill.Source != "workspace" {
|
||||
continue
|
||||
}
|
||||
if filepath.Base(filepath.Dir(skill.Path)) != directory {
|
||||
continue
|
||||
}
|
||||
skillCopy := skill
|
||||
return &skillCopy
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createStagedSkillInstall(skillsRoot, slug string) (string, string, error) {
|
||||
stagedWorkspaceRoot, err := os.MkdirTemp(skillsRoot, "."+slug+"-install-*")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
stagedTargetDir := filepath.Join(stagedWorkspaceRoot, "skills", slug)
|
||||
return stagedWorkspaceRoot, stagedTargetDir, nil
|
||||
}
|
||||
|
||||
func commitStagedSkillInstall(stagedWorkspaceRoot, stagedTargetDir, targetDir string, replaceExisting bool) error {
|
||||
if !replaceExisting {
|
||||
return os.Rename(stagedTargetDir, targetDir)
|
||||
}
|
||||
|
||||
backupDir, err := reserveTempDirPath(filepath.Dir(targetDir), "."+filepath.Base(targetDir)+"-backup-*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(targetDir, backupDir); err != nil {
|
||||
return fmt.Errorf("failed to move existing skill aside: %w", err)
|
||||
}
|
||||
|
||||
if err := os.Rename(stagedTargetDir, targetDir); err != nil {
|
||||
if rollbackErr := os.Rename(backupDir, targetDir); rollbackErr != nil {
|
||||
return fmt.Errorf("failed to activate replacement: %w (rollback failed: %v)", err, rollbackErr)
|
||||
}
|
||||
return fmt.Errorf("failed to activate replacement: %w", err)
|
||||
}
|
||||
|
||||
_ = os.RemoveAll(backupDir)
|
||||
_ = os.RemoveAll(stagedWorkspaceRoot)
|
||||
return nil
|
||||
}
|
||||
|
||||
func reserveTempDirPath(parent, pattern string) (string, error) {
|
||||
tempDir, err := os.MkdirTemp(parent, pattern)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.Remove(tempDir); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return tempDir, nil
|
||||
}
|
||||
|
||||
func enrichSkillInfo(cfg *config.Config, skill skills.SkillInfo) (skillSupportItem, error) {
|
||||
item := skillSupportItem{
|
||||
Name: skill.Name,
|
||||
Path: skill.Path,
|
||||
Source: skill.Source,
|
||||
Description: skill.Description,
|
||||
OriginKind: "builtin",
|
||||
}
|
||||
|
||||
switch skill.Source {
|
||||
case "builtin":
|
||||
item.OriginKind = "builtin"
|
||||
case "global":
|
||||
item.OriginKind = "builtin"
|
||||
case "workspace":
|
||||
meta, err := readInstalledSkillOriginMeta(skill.Path)
|
||||
if err == nil && meta != nil {
|
||||
switch meta.OriginKind {
|
||||
case "manual":
|
||||
item.OriginKind = "manual"
|
||||
item.InstalledAt = meta.InstalledAt
|
||||
case "third_party":
|
||||
item.OriginKind = "third_party"
|
||||
item.RegistryName = meta.Registry
|
||||
item.RegistryURL = registrySkillURLFromMeta(cfg, meta)
|
||||
item.InstalledVersion = meta.InstalledVersion
|
||||
item.InstalledAt = meta.InstalledAt
|
||||
default:
|
||||
if meta.Registry != "" || meta.Slug != "" || meta.InstalledVersion != "" {
|
||||
item.OriginKind = "third_party"
|
||||
item.RegistryName = meta.Registry
|
||||
item.RegistryURL = registrySkillURLFromMeta(cfg, meta)
|
||||
item.InstalledVersion = meta.InstalledVersion
|
||||
item.InstalledAt = meta.InstalledAt
|
||||
} else {
|
||||
item.OriginKind = "builtin"
|
||||
item.InstalledAt = meta.InstalledAt
|
||||
}
|
||||
}
|
||||
} else {
|
||||
item.OriginKind = "builtin"
|
||||
}
|
||||
default:
|
||||
item.OriginKind = "builtin"
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func readInstalledSkillOriginMeta(skillPath string) (*installedSkillOriginMeta, error) {
|
||||
metaPath := filepath.Join(filepath.Dir(skillPath), ".skill-origin.json")
|
||||
data, err := os.ReadFile(metaPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var meta installedSkillOriginMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &meta, nil
|
||||
}
|
||||
|
||||
func writeSkillOriginMeta(targetDir string, meta installedSkillOriginMeta) error {
|
||||
data, err := json.MarshalIndent(meta, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileutil.WriteFileAtomic(filepath.Join(targetDir, ".skill-origin.json"), data, 0o600)
|
||||
}
|
||||
|
||||
func registrySkillURL(cfg *config.Config, registryName, slug string) string {
|
||||
switch registryName {
|
||||
case "clawhub":
|
||||
baseURL := strings.TrimRight(cfg.Tools.Skills.Registries.ClawHub.BaseURL, "/")
|
||||
if baseURL == "" {
|
||||
baseURL = "https://clawhub.ai"
|
||||
}
|
||||
return baseURL + "/skills/" + url.PathEscape(slug)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func registrySkillURLFromMeta(cfg *config.Config, meta *installedSkillOriginMeta) string {
|
||||
if meta == nil || meta.Slug == "" {
|
||||
return ""
|
||||
}
|
||||
if meta.RegistryURL != "" {
|
||||
return meta.RegistryURL
|
||||
}
|
||||
if cfg == nil || meta.Registry == "" {
|
||||
return ""
|
||||
}
|
||||
return registrySkillURL(cfg, meta.Registry, meta.Slug)
|
||||
}
|
||||
|
||||
func normalizeImportedSkillName(filename string, content []byte) (string, error) {
|
||||
return normalizeImportedSkillNameWithHint(filename, "", content)
|
||||
}
|
||||
|
||||
func normalizeImportedSkillNameWithHint(filename, directoryHint string, content []byte) (string, error) {
|
||||
rawContent := strings.ReplaceAll(string(content), "\r\n", "\n")
|
||||
rawContent = strings.ReplaceAll(rawContent, "\r", "\n")
|
||||
metadata, _ := extractImportedSkillMetadata(rawContent)
|
||||
|
||||
raw := strings.TrimSpace(metadata["name"])
|
||||
if raw == "" {
|
||||
raw = strings.TrimSpace(directoryHint)
|
||||
}
|
||||
if raw == "" {
|
||||
raw = strings.TrimSpace(strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename)))
|
||||
}
|
||||
@@ -259,6 +834,210 @@ func normalizeImportedSkillContent(content []byte, skillName string) []byte {
|
||||
return []byte(builder.String())
|
||||
}
|
||||
|
||||
func importUploadedSkill(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
|
||||
if isImportedSkillArchive(filename, content) {
|
||||
return importUploadedSkillArchive(cfg, filename, content)
|
||||
}
|
||||
return importUploadedMarkdownSkill(cfg, filename, content)
|
||||
}
|
||||
|
||||
func importUploadedMarkdownSkill(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
|
||||
skillName, err := normalizeImportedSkillName(filename, content)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
normalizedContent := normalizeImportedSkillContent(content, skillName)
|
||||
workspace := cfg.WorkspacePath()
|
||||
skillDir := filepath.Join(workspace, "skills", skillName)
|
||||
skillFile := filepath.Join(skillDir, "SKILL.md")
|
||||
|
||||
if err := ensureWorkspaceSkillDoesNotExist(skillDir); err != nil {
|
||||
return nil, statusCodeForImportedSkillWriteError(err), err
|
||||
}
|
||||
if err := os.MkdirAll(skillDir, 0o755); err != nil {
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to create skill directory: %v", err)
|
||||
}
|
||||
if err := fileutil.WriteFileAtomic(skillFile, normalizedContent, 0o644); err != nil {
|
||||
_ = os.RemoveAll(skillDir)
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to save skill: %v", err)
|
||||
}
|
||||
|
||||
return finalizeImportedSkill(cfg, skillDir, skillName, false)
|
||||
}
|
||||
|
||||
func importUploadedSkillArchive(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
|
||||
tmpDir, tempDirErr := os.MkdirTemp("", "picoclaw-skill-import-*")
|
||||
if tempDirErr != nil {
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to create temp directory: %v", tempDirErr)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
archivePath := filepath.Join(tmpDir, "import.zip")
|
||||
if writeErr := fileutil.WriteFileAtomic(archivePath, content, 0o600); writeErr != nil {
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to stage uploaded archive: %v", writeErr)
|
||||
}
|
||||
|
||||
extractDir := filepath.Join(tmpDir, "extract")
|
||||
if extractErr := utils.ExtractZipFile(archivePath, extractDir); extractErr != nil {
|
||||
return nil, http.StatusBadRequest, fmt.Errorf("invalid ZIP archive: %w", extractErr)
|
||||
}
|
||||
|
||||
skillRoot, err := findImportedSkillRoot(extractDir)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
skillFile := filepath.Join(skillRoot, "SKILL.md")
|
||||
skillContent, err := os.ReadFile(skillFile)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, fmt.Errorf("failed to read SKILL.md from archive: %w", err)
|
||||
}
|
||||
|
||||
directoryHint := ""
|
||||
if filepath.Clean(skillRoot) != filepath.Clean(extractDir) {
|
||||
directoryHint = filepath.Base(skillRoot)
|
||||
}
|
||||
skillName, err := normalizeImportedSkillNameWithHint(filename, directoryHint, skillContent)
|
||||
if err != nil {
|
||||
return nil, http.StatusBadRequest, err
|
||||
}
|
||||
|
||||
workspace := cfg.WorkspacePath()
|
||||
skillDir := filepath.Join(workspace, "skills", skillName)
|
||||
if err := ensureWorkspaceSkillDoesNotExist(skillDir); err != nil {
|
||||
return nil, statusCodeForImportedSkillWriteError(err), err
|
||||
}
|
||||
if err := copyImportedSkillTree(skillRoot, skillDir); err != nil {
|
||||
_ = os.RemoveAll(skillDir)
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to save skill: %v", err)
|
||||
}
|
||||
|
||||
normalizedContent := normalizeImportedSkillContent(skillContent, skillName)
|
||||
if err := fileutil.WriteFileAtomic(filepath.Join(skillDir, "SKILL.md"), normalizedContent, 0o644); err != nil {
|
||||
_ = os.RemoveAll(skillDir)
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to normalize skill: %v", err)
|
||||
}
|
||||
|
||||
return finalizeImportedSkill(cfg, skillDir, skillName, true)
|
||||
}
|
||||
|
||||
func isImportedSkillArchive(filename string, content []byte) bool {
|
||||
if strings.EqualFold(filepath.Ext(filename), ".zip") {
|
||||
return true
|
||||
}
|
||||
return len(content) >= 4 && bytes.HasPrefix(content, []byte("PK\x03\x04"))
|
||||
}
|
||||
|
||||
func ensureWorkspaceSkillDoesNotExist(skillDir string) error {
|
||||
if _, err := os.Stat(skillDir); err == nil {
|
||||
return errImportedSkillExists
|
||||
} else if !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to inspect skill directory: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func statusCodeForImportedSkillWriteError(err error) int {
|
||||
if err == nil {
|
||||
return http.StatusOK
|
||||
}
|
||||
if errors.Is(err, errImportedSkillExists) {
|
||||
return http.StatusConflict
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func finalizeImportedSkill(
|
||||
cfg *config.Config,
|
||||
skillDir string,
|
||||
skillName string,
|
||||
requireValidatedSkill bool,
|
||||
) (*skillSupportItem, int, error) {
|
||||
if err := persistSkillOriginMeta(skillDir, installedSkillOriginMeta{
|
||||
Version: 1,
|
||||
OriginKind: "manual",
|
||||
InstalledAt: time.Now().UnixMilli(),
|
||||
}); err != nil {
|
||||
_ = os.RemoveAll(skillDir)
|
||||
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to persist skill metadata: %v", err)
|
||||
}
|
||||
|
||||
if importedSkill := findWorkspaceSkillByDirectory(cfg, skillName); importedSkill != nil {
|
||||
return importedSkill, http.StatusOK, nil
|
||||
}
|
||||
|
||||
if requireValidatedSkill {
|
||||
_ = os.RemoveAll(skillDir)
|
||||
return nil, http.StatusBadRequest, fmt.Errorf("imported archive is not a valid skill")
|
||||
}
|
||||
|
||||
return &skillSupportItem{
|
||||
Name: skillName,
|
||||
Path: filepath.Join(skillDir, "SKILL.md"),
|
||||
Source: "workspace",
|
||||
Description: "Imported skill",
|
||||
OriginKind: "manual",
|
||||
}, http.StatusOK, nil
|
||||
}
|
||||
|
||||
func findImportedSkillRoot(extractDir string) (string, error) {
|
||||
skillFiles := make([]string, 0, 1)
|
||||
err := filepath.WalkDir(extractDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if d.Name() == "SKILL.md" {
|
||||
skillFiles = append(skillFiles, path)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to inspect ZIP archive: %w", err)
|
||||
}
|
||||
|
||||
switch len(skillFiles) {
|
||||
case 0:
|
||||
return "", fmt.Errorf("ZIP archive must contain a SKILL.md file")
|
||||
case 1:
|
||||
return filepath.Dir(skillFiles[0]), nil
|
||||
default:
|
||||
return "", fmt.Errorf("ZIP archive must contain exactly one SKILL.md file")
|
||||
}
|
||||
}
|
||||
|
||||
func copyImportedSkillTree(srcDir, destDir string) error {
|
||||
return filepath.WalkDir(srcDir, func(path string, d fs.DirEntry, walkErr error) error {
|
||||
if walkErr != nil {
|
||||
return walkErr
|
||||
}
|
||||
|
||||
relPath, err := filepath.Rel(srcDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if relPath == "." {
|
||||
return os.MkdirAll(destDir, 0o755)
|
||||
}
|
||||
|
||||
destPath := filepath.Join(destDir, relPath)
|
||||
info, err := d.Info()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return os.MkdirAll(destPath, 0o755)
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
return fmt.Errorf("archive contains unsupported file %q", relPath)
|
||||
}
|
||||
return fileutil.CopyFile(path, destPath, info.Mode().Perm())
|
||||
})
|
||||
}
|
||||
|
||||
func extractImportedSkillMetadata(raw string) (map[string]string, string) {
|
||||
matches := importedSkillFrontmatter.FindStringSubmatch(raw)
|
||||
if len(matches) != 2 {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,52 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/updater"
|
||||
)
|
||||
|
||||
// registerUpdateRoutes registers the self-update endpoint.
|
||||
func (h *Handler) registerUpdateRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/update", h.handleUpdate)
|
||||
}
|
||||
|
||||
type updateRequest struct {
|
||||
URL string `json:"url,omitempty"`
|
||||
Binary string `json:"binary,omitempty"`
|
||||
}
|
||||
|
||||
type updateResponse struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (h *Handler) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
_ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: "method not allowed"})
|
||||
return
|
||||
}
|
||||
|
||||
dec := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20))
|
||||
var req updateRequest
|
||||
if err := dec.Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: "invalid request body"})
|
||||
return
|
||||
}
|
||||
|
||||
binary := req.Binary
|
||||
if binary == "" {
|
||||
binary = "picoclaw-launcher"
|
||||
}
|
||||
|
||||
if err := updater.UpdateSelfFromRelease(req.URL, "", "", binary); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(updateResponse{Status: "ok", Message: "update applied; restart to use new version"})
|
||||
}
|
||||
Reference in New Issue
Block a user