Merge branch 'main' into refactor-inbound-context-routing-session

# Conflicts:
#	pkg/agent/eventbus_test.go
#	pkg/agent/loop.go
#	pkg/bus/bus.go
#	pkg/bus/types.go
#	pkg/channels/pico/pico.go
#	pkg/channels/telegram/telegram.go
#	pkg/config/config.go
#	web/backend/api/session.go
#	web/backend/api/session_test.go
This commit is contained in:
Hoshina
2026-04-07 21:41:02 +08:00
282 changed files with 33064 additions and 3251 deletions
+1
View File
@@ -23,6 +23,7 @@ type LauncherAuthRouteOpts struct {
type LauncherAuthTokenHelp struct {
EnvVarName string `json:"env_var_name"`
LogFileAbs string `json:"log_file,omitempty"`
ConfigFileAbs string `json:"config_file,omitempty"`
TrayCopyMenu bool `json:"tray_copy_menu"`
ConsoleStdout bool `json:"console_stdout"`
}
+184
View File
@@ -3,6 +3,8 @@ package api
import (
"encoding/json"
"net/http"
"github.com/sipeed/picoclaw/pkg/config"
)
type channelCatalogItem struct {
@@ -30,9 +32,22 @@ var channelCatalog = []channelCatalogItem{
{Name: "irc", ConfigKey: "irc"},
}
type channelConfigResponse struct {
Config any `json:"config"`
ConfiguredSecrets []string `json:"configured_secrets"`
ConfigKey string `json:"config_key"`
Variant string `json:"variant,omitempty"`
}
type channelSecretPresence struct {
key string
configured bool
}
// registerChannelRoutes binds read-only channel catalog endpoints to the ServeMux.
func (h *Handler) registerChannelRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/channels/catalog", h.handleListChannelCatalog)
mux.HandleFunc("GET /api/channels/{name}/config", h.handleGetChannelConfig)
}
// handleListChannelCatalog returns the channels supported by backend.
@@ -44,3 +59,172 @@ func (h *Handler) handleListChannelCatalog(w http.ResponseWriter, r *http.Reques
"channels": channelCatalog,
})
}
// handleGetChannelConfig returns safe channel config plus secret presence metadata.
//
// GET /api/channels/{name}/config
func (h *Handler) handleGetChannelConfig(w http.ResponseWriter, r *http.Request) {
channelName := r.PathValue("name")
item, ok := findChannelCatalogItem(channelName)
if !ok {
http.Error(w, "Channel not found", http.StatusNotFound)
return
}
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
http.Error(w, "Failed to load config", http.StatusInternalServerError)
return
}
resp := buildChannelConfigResponse(cfg, item)
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
}
}
func findChannelCatalogItem(name string) (channelCatalogItem, bool) {
for _, item := range channelCatalog {
if item.Name == name {
return item, true
}
}
return channelCatalogItem{}, false
}
func buildChannelConfigResponse(cfg *config.Config, item channelCatalogItem) channelConfigResponse {
resp := channelConfigResponse{
ConfiguredSecrets: []string{},
ConfigKey: item.ConfigKey,
Variant: item.Variant,
}
switch item.Name {
case "weixin":
channelCfg := cfg.Channels.Weixin
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
)
channelCfg.Token = config.SecureString{}
resp.Config = channelCfg
case "telegram":
channelCfg := cfg.Channels.Telegram
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
)
channelCfg.Token = config.SecureString{}
resp.Config = channelCfg
case "discord":
channelCfg := cfg.Channels.Discord
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
)
channelCfg.Token = config.SecureString{}
resp.Config = channelCfg
case "slack":
channelCfg := cfg.Channels.Slack
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "bot_token", configured: channelCfg.BotToken.String() != ""},
channelSecretPresence{key: "app_token", configured: channelCfg.AppToken.String() != ""},
)
channelCfg.BotToken = config.SecureString{}
channelCfg.AppToken = config.SecureString{}
resp.Config = channelCfg
case "feishu":
channelCfg := cfg.Channels.Feishu
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "app_secret", configured: channelCfg.AppSecret.String() != ""},
channelSecretPresence{key: "encrypt_key", configured: channelCfg.EncryptKey.String() != ""},
channelSecretPresence{key: "verification_token", configured: channelCfg.VerificationToken.String() != ""},
)
channelCfg.AppSecret = config.SecureString{}
channelCfg.EncryptKey = config.SecureString{}
channelCfg.VerificationToken = config.SecureString{}
resp.Config = channelCfg
case "dingtalk":
channelCfg := cfg.Channels.DingTalk
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "client_secret", configured: channelCfg.ClientSecret.String() != ""},
)
channelCfg.ClientSecret = config.SecureString{}
resp.Config = channelCfg
case "line":
channelCfg := cfg.Channels.LINE
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "channel_secret", configured: channelCfg.ChannelSecret.String() != ""},
channelSecretPresence{
key: "channel_access_token",
configured: channelCfg.ChannelAccessToken.String() != "",
},
)
channelCfg.ChannelSecret = config.SecureString{}
channelCfg.ChannelAccessToken = config.SecureString{}
resp.Config = channelCfg
case "qq":
channelCfg := cfg.Channels.QQ
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "app_secret", configured: channelCfg.AppSecret.String() != ""},
)
channelCfg.AppSecret = config.SecureString{}
resp.Config = channelCfg
case "onebot":
channelCfg := cfg.Channels.OneBot
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "access_token", configured: channelCfg.AccessToken.String() != ""},
)
channelCfg.AccessToken = config.SecureString{}
resp.Config = channelCfg
case "wecom":
channelCfg := cfg.Channels.WeCom
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "secret", configured: channelCfg.Secret.String() != ""},
)
channelCfg.Secret = config.SecureString{}
resp.Config = channelCfg
case "whatsapp", "whatsapp_native":
resp.Config = cfg.Channels.WhatsApp
case "pico":
channelCfg := cfg.Channels.Pico
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "token", configured: channelCfg.Token.String() != ""},
)
channelCfg.Token = config.SecureString{}
resp.Config = channelCfg
case "maixcam":
resp.Config = cfg.Channels.MaixCam
case "matrix":
channelCfg := cfg.Channels.Matrix
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "access_token", configured: channelCfg.AccessToken.String() != ""},
)
channelCfg.AccessToken = config.SecureString{}
resp.Config = channelCfg
case "irc":
channelCfg := cfg.Channels.IRC
resp.ConfiguredSecrets = collectConfiguredSecrets(
channelSecretPresence{key: "password", configured: channelCfg.Password.String() != ""},
channelSecretPresence{key: "nickserv_password", configured: channelCfg.NickServPassword.String() != ""},
channelSecretPresence{key: "sasl_password", configured: channelCfg.SASLPassword.String() != ""},
)
channelCfg.Password = config.SecureString{}
channelCfg.NickServPassword = config.SecureString{}
channelCfg.SASLPassword = config.SecureString{}
resp.Config = channelCfg
default:
resp.Config = map[string]any{}
}
return resp
}
func collectConfiguredSecrets(secrets ...channelSecretPresence) []string {
configured := make([]string, 0, len(secrets))
for _, secret := range secrets {
if secret.configured {
configured = append(configured, secret.key)
}
}
return configured
}
+87
View File
@@ -0,0 +1,87 @@
package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestHandleGetChannelConfig_ReturnsSecretPresenceWithoutLeakingSecrets(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.Channels.Feishu.Enabled = true
cfg.Channels.Feishu.AppID = "cli_test_app"
cfg.Channels.Feishu.AppSecret = *config.NewSecureString("feishu-secret-from-security")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
req := httptest.NewRequest(http.MethodGet, "/api/channels/feishu/config", nil)
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf(
"GET /api/channels/feishu/config status = %d, want %d, body=%s",
rec.Code,
http.StatusOK,
rec.Body.String(),
)
}
if strings.Contains(rec.Body.String(), "feishu-secret-from-security") {
t.Fatalf("response leaked secret value: %s", rec.Body.String())
}
var resp struct {
Config map[string]any `json:"config"`
ConfiguredSecrets []string `json:"configured_secrets"`
ConfigKey string `json:"config_key"`
Variant string `json:"variant"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("json.Unmarshal() error = %v", err)
}
if got := resp.ConfigKey; got != "feishu" {
t.Fatalf("config_key = %q, want %q", got, "feishu")
}
if got := resp.Config["app_id"]; got != "cli_test_app" {
t.Fatalf("config.app_id = %#v, want %q", got, "cli_test_app")
}
if _, exists := resp.Config["app_secret"]; exists {
t.Fatalf("config should omit app_secret, got %#v", resp.Config["app_secret"])
}
if len(resp.ConfiguredSecrets) != 1 || resp.ConfiguredSecrets[0] != "app_secret" {
t.Fatalf("configured_secrets = %#v, want [\"app_secret\"]", resp.ConfiguredSecrets)
}
}
func TestHandleGetChannelConfig_ReturnsNotFoundForUnknownChannel(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
req := httptest.NewRequest(http.MethodGet, "/api/channels/not-a-channel/config", nil)
rec := httptest.NewRecorder()
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusNotFound {
t.Fatalf("GET /api/channels/not-a-channel/config status = %d, want %d", rec.Code, http.StatusNotFound)
}
}
+34 -28
View File
@@ -357,7 +357,13 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
return true
}
return cmd.Process.Signal(syscall.Signal(0)) == nil
err := cmd.Process.Signal(syscall.Signal(0))
if err == nil {
return true
}
var errno syscall.Errno
// EPERM means the process exists but cannot be signaled by this user.
return errors.As(err, &errno) && errno == syscall.EPERM
}
func setGatewayRuntimeStatusLocked(status string) {
@@ -401,6 +407,15 @@ func gatewayStatusWithoutHealthLocked() string {
return "error"
}
if gateway.runtimeStatus == "running" {
// For attached processes there is no waiter goroutine; degrade stale
// running state once the tracked process exits.
if !isCmdProcessAliveLocked(gateway.cmd) {
gateway.cmd = nil
gateway.owned = false
gateway.bootDefaultModel = ""
gateway.bootConfigSignature = ""
return "stopped"
}
return "running"
}
if gateway.runtimeStatus == "error" {
@@ -614,6 +629,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
// Start a goroutine to probe pidFile and health, update runtime state once ready.
go func() {
healthConfirmed := false
for i := 0; i < 30; i++ { // try for up to 15 seconds
time.Sleep(500 * time.Millisecond)
gateway.mu.Lock()
@@ -628,6 +644,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.mu.Lock()
if gateway.cmd == cmd {
gateway.pidData = pd
gateway.picoToken = cfg.Channels.Pico.Token.String()
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
@@ -647,7 +664,11 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
return
if !healthConfirmed {
healthConfirmed = true
logger.InfoC("gateway", "Gateway health endpoint reachable; waiting for pid file")
}
continue
}
}
}()
@@ -922,34 +943,19 @@ func (h *Handler) gatewayStatusData() map[string]any {
data["pid"] = pidData.PID
gateway.mu.Unlock()
} else {
// Fallback: probe health endpoint to get pid and status
_, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err != nil {
gateway.mu.Lock()
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
// Intentionally skip health probe here; the startup goroutine
// (startGatewayLocked) already handles liveness detection via
// pidFile polling and health fallback.
gateway.mu.Lock()
status := gatewayStatusWithoutHealthLocked()
data["gateway_status"] = status
// Keep last known pidData while gateway is still in a transient
// running state; otherwise websocket proxy may lose auth token
// during short pid-file races.
if status == "stopped" || status == "error" {
gateway.pidData = nil
gateway.mu.Unlock()
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
} else {
logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
if statusCode != http.StatusOK {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
gateway.pidData = nil
gateway.mu.Unlock()
data["gateway_status"] = "error"
data["status_code"] = statusCode
} else {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("running")
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
gateway.mu.Unlock()
}
}
gateway.mu.Unlock()
}
gatewayStatus, _ := data["gateway_status"].(string)
+95 -1
View File
@@ -15,8 +15,11 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
"github.com/sipeed/picoclaw/web/backend/utils"
)
@@ -444,7 +447,93 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
}
}
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
func TestGatewayStatusKeepsPidDataWhileTrackedProcessAliveWhenPidFileUnavailable(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
gateway.cmd = cmd
gateway.pidData = &ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "existing-token",
}
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.pidData == nil {
t.Fatal("gateway.pidData was cleared while runtime status remained running")
}
}
func TestGatewayStatusDowngradesRunningWhenTrackedProcessExitedAndPidFileMissing(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
gateway.mu.Lock()
gateway.cmd = cmd
gateway.pidData = &ppid.PidFileData{
PID: cmd.Process.Pid,
Token: "stale-token",
}
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "stopped" {
t.Fatalf("gateway_status = %#v, want %q", got, "stopped")
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.pidData != nil {
t.Fatal("gateway.pidData should be cleared when tracked process has exited")
}
}
func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
@@ -468,6 +557,9 @@ func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
}
_, err := ppid.WritePidFile(globalConfigDir(), "localhost", 0)
require.NoError(t, err)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
@@ -513,6 +605,8 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}
_, err = ppid.WritePidFile(globalConfigDir(), "localhost", 0)
require.NoError(t, err)
bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
+17 -12
View File
@@ -4,14 +4,16 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
)
type launcherConfigPayload struct {
Port int `json:"port"`
Public bool `json:"public"`
AllowedCIDRs []string `json:"allowed_cidrs"`
Port int `json:"port"`
Public bool `json:"public"`
AllowedCIDRs []string `json:"allowed_cidrs"`
LauncherToken string `json:"launcher_token"`
}
func (h *Handler) registerLauncherConfigRoutes(mux *http.ServeMux) {
@@ -48,9 +50,10 @@ func (h *Handler) handleGetLauncherConfig(w http.ResponseWriter, r *http.Request
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(launcherConfigPayload{
Port: cfg.Port,
Public: cfg.Public,
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
Port: cfg.Port,
Public: cfg.Public,
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
LauncherToken: cfg.LauncherToken,
})
}
@@ -62,9 +65,10 @@ func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Requ
}
cfg := launcherconfig.Config{
Port: payload.Port,
Public: payload.Public,
AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
Port: payload.Port,
Public: payload.Public,
AllowedCIDRs: append([]string(nil), payload.AllowedCIDRs...),
LauncherToken: strings.TrimSpace(payload.LauncherToken),
}
if err := launcherconfig.Validate(cfg); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -78,8 +82,9 @@ func (h *Handler) handleUpdateLauncherConfig(w http.ResponseWriter, r *http.Requ
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(launcherConfigPayload{
Port: cfg.Port,
Public: cfg.Public,
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
Port: cfg.Port,
Public: cfg.Public,
AllowedCIDRs: append([]string(nil), cfg.AllowedCIDRs...),
LauncherToken: cfg.LauncherToken,
})
}
+9 -1
View File
@@ -34,6 +34,9 @@ func TestGetLauncherConfigUsesRuntimeFallback(t *testing.T) {
if got.Port != 19999 || !got.Public {
t.Fatalf("response = %+v, want port=19999 public=true", got)
}
if got.LauncherToken != "" {
t.Fatalf("response launcher_token = %q, want empty", got.LauncherToken)
}
if len(got.AllowedCIDRs) != 1 || got.AllowedCIDRs[0] != "192.168.1.0/24" {
t.Fatalf("response allowed_cidrs = %v, want [192.168.1.0/24]", got.AllowedCIDRs)
}
@@ -50,7 +53,9 @@ func TestPutLauncherConfigPersists(t *testing.T) {
req := httptest.NewRequest(
http.MethodPut,
"/api/system/launcher-config",
strings.NewReader(`{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"]}`),
strings.NewReader(
`{"port":18080,"public":true,"allowed_cidrs":["192.168.1.0/24"],"launcher_token":"saved-token"}`,
),
)
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
@@ -67,6 +72,9 @@ func TestPutLauncherConfigPersists(t *testing.T) {
if cfg.Port != 18080 || !cfg.Public {
t.Fatalf("saved config = %+v, want port=18080 public=true", cfg)
}
if cfg.LauncherToken != "saved-token" {
t.Fatalf("saved launcher_token = %q, want %q", cfg.LauncherToken, "saved-token")
}
if len(cfg.AllowedCIDRs) != 1 || cfg.AllowedCIDRs[0] != "192.168.1.0/24" {
t.Fatalf("saved config allowed_cidrs = %v, want [192.168.1.0/24]", cfg.AllowedCIDRs)
}
+301 -8
View File
@@ -1,19 +1,36 @@
package api
import (
"context"
"encoding/json"
"fmt"
"hash/fnv"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"golang.org/x/sync/singleflight"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
const modelProbeTimeout = 800 * time.Millisecond
const (
modelProbeTimeout = 800 * time.Millisecond
modelProbeSuccessBaseInterval = 2 * time.Second
modelProbeSuccessMaxInterval = 60 * time.Second
modelProbeFailureBaseInterval = 1 * time.Second
modelProbeFailureMaxInterval = 30 * time.Second
modelProbeBackoffMaxShift = 8
modelProbeCacheMaxEntries = 1024
modelProbeCacheEntryTTL = 30 * time.Minute
modelProbeCacheTrimToEntries = modelProbeCacheMaxEntries * 8 / 10
modelProbeTTLGCInterval = 1 * time.Minute
)
const (
modelStatusAvailable = "available"
@@ -30,8 +47,41 @@ var (
probeTCPServiceFunc = probeTCPService
probeOllamaModelFunc = probeOllamaModel
probeOpenAICompatibleModelFunc = probeOpenAICompatibleModel
modelProbeNowFunc = time.Now
modelProbeState = newModelProbeCacheState()
)
type modelProbeCacheState struct {
mu sync.RWMutex
cache map[string]*modelProbeCacheEntry
group singleflight.Group
nextTTLGCAt time.Time
}
type modelProbeCacheEntry struct {
lastResult bool
hasResult bool
successStreak int
failureStreak int
nextProbeAt time.Time
updatedAt time.Time
}
func newModelProbeCacheState() *modelProbeCacheState {
return &modelProbeCacheState{cache: map[string]*modelProbeCacheEntry{}}
}
func resetModelProbeCache() {
modelProbeState.resetForTest()
}
func (s *modelProbeCacheState) resetForTest() {
s.mu.Lock()
defer s.mu.Unlock()
s.cache = map[string]*modelProbeCacheEntry{}
s.nextTTLGCAt = time.Time{}
}
func hasModelConfiguration(m *config.ModelConfig) bool {
authMethod := strings.ToLower(strings.TrimSpace(m.AuthMethod))
apiKey := strings.TrimSpace(m.APIKey())
@@ -93,6 +143,34 @@ func requiresRuntimeProbe(m *config.ModelConfig) bool {
}
func probeLocalModelAvailability(m *config.ModelConfig) bool {
cacheKey := modelProbeCacheKey(m)
return modelProbeState.probe(cacheKey, func() bool {
return runLocalModelProbe(m)
})
}
func (s *modelProbeCacheState) probe(cacheKey string, probeFunc func() bool) bool {
now := modelProbeNowFunc()
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
return cachedResult
}
v, _, _ := s.group.Do(cacheKey, func() (any, error) {
now = modelProbeNowFunc()
if cachedResult, ok := s.getCachedResult(cacheKey, now); ok {
return cachedResult, nil
}
result := probeFunc()
s.setCachedResult(cacheKey, result, now)
return result, nil
})
result, _ := v.(bool)
return result
}
func runLocalModelProbe(m *config.ModelConfig) bool {
apiBase := modelProbeAPIBase(m)
protocol, modelID := splitModel(m.Model)
switch protocol {
@@ -112,6 +190,195 @@ func probeLocalModelAvailability(m *config.ModelConfig) bool {
}
}
func modelProbeCacheKey(m *config.ModelConfig) string {
protocol, modelID := splitModel(m.Model)
apiBaseRaw := modelProbeAPIBase(m)
apiBase := strings.ToLower(strings.TrimRight(strings.TrimSpace(apiBaseRaw), "/"))
apiKeyFingerprint := modelProbeAPIKeyFingerprint(m.APIKey())
var b strings.Builder
b.Grow(len(protocol) + len(modelID) + len(apiBase) + len(apiKeyFingerprint) + 8)
b.WriteString(protocol)
b.WriteByte('|')
b.WriteString(modelID)
b.WriteByte('|')
b.WriteString(apiBase)
b.WriteByte('|')
b.WriteString(apiKeyFingerprint)
return b.String()
}
func modelProbeAPIKeyFingerprint(raw string) string {
apiKey := strings.TrimSpace(raw)
if apiKey == "" {
return "none"
}
h := fnv.New64a()
_, _ = h.Write([]byte(apiKey))
return strconv.FormatUint(h.Sum64(), 36)
}
func (s *modelProbeCacheState) getCachedResult(cacheKey string, now time.Time) (bool, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
entry, ok := s.cache[cacheKey]
if !ok || !entry.hasResult {
return false, false
}
if now.Before(entry.nextProbeAt) {
return entry.lastResult, true
}
return false, false
}
func (s *modelProbeCacheState) setCachedResult(cacheKey string, result bool, now time.Time) {
s.mu.Lock()
entry, ok := s.cache[cacheKey]
if !ok {
entry = &modelProbeCacheEntry{}
s.cache[cacheKey] = entry
}
entry.lastResult = result
entry.hasResult = true
entry.updatedAt = now
var delay time.Duration
if result {
entry.successStreak++
entry.failureStreak = 0
delay = modelProbeBackoffDelay(
modelProbeSuccessBaseInterval,
modelProbeSuccessMaxInterval,
entry.successStreak,
)
} else {
entry.failureStreak++
entry.successStreak = 0
delay = modelProbeBackoffDelay(
modelProbeFailureBaseInterval,
modelProbeFailureMaxInterval,
entry.failureStreak,
)
}
entry.nextProbeAt = now.Add(delay)
shouldRunTTLGC := modelProbeCacheEntryTTL > 0 && (s.nextTTLGCAt.IsZero() || !now.Before(s.nextTTLGCAt))
if shouldRunTTLGC {
s.nextTTLGCAt = now.Add(modelProbeTTLGCInterval)
}
shouldRunSizeGC := len(s.cache) > modelProbeCacheMaxEntries
s.mu.Unlock()
if shouldRunTTLGC || shouldRunSizeGC {
s.gc(now, shouldRunTTLGC)
}
}
func (s *modelProbeCacheState) gc(now time.Time, runTTL bool) {
type evictionCandidate struct {
key string
updatedAt time.Time
}
var expireBefore time.Time
if runTTL && modelProbeCacheEntryTTL > 0 {
expireBefore = now.Add(-modelProbeCacheEntryTTL)
}
s.mu.RLock()
cacheLen := len(s.cache)
if cacheLen == 0 {
s.mu.RUnlock()
return
}
expiredKeys := make([]string, 0)
if !expireBefore.IsZero() {
expiredKeys = make([]string, 0, min(cacheLen/8+1, 64))
for key, entry := range s.cache {
if entry.updatedAt.Before(expireBefore) {
expiredKeys = append(expiredKeys, key)
}
}
}
effectiveLen := cacheLen - len(expiredKeys)
removeCount := max(effectiveLen-modelProbeCacheTrimToEntries, 0)
candidates := make([]evictionCandidate, 0)
if removeCount > 0 {
candidates = make([]evictionCandidate, 0, effectiveLen)
for key, entry := range s.cache {
if !expireBefore.IsZero() && entry.updatedAt.Before(expireBefore) {
continue
}
candidates = append(candidates, evictionCandidate{key: key, updatedAt: entry.updatedAt})
}
}
s.mu.RUnlock()
if len(expiredKeys) == 0 && len(candidates) == 0 {
return
}
toEvict := map[string]time.Time{}
for i := 0; i < removeCount && len(candidates) > 0; i++ {
oldest := 0
for j := 1; j < len(candidates); j++ {
if candidates[j].updatedAt.Before(candidates[oldest].updatedAt) {
oldest = j
}
}
victim := candidates[oldest]
toEvict[victim.key] = victim.updatedAt
candidates[oldest] = candidates[len(candidates)-1]
candidates = candidates[:len(candidates)-1]
}
s.mu.Lock()
defer s.mu.Unlock()
if !expireBefore.IsZero() {
for _, key := range expiredKeys {
entry, ok := s.cache[key]
if ok && entry.updatedAt.Before(expireBefore) {
delete(s.cache, key)
}
}
}
for key, victimUpdatedAt := range toEvict {
entry, ok := s.cache[key]
if ok && !entry.updatedAt.After(victimUpdatedAt) {
delete(s.cache, key)
}
}
}
func modelProbeBackoffDelay(base, maxDelay time.Duration, streak int) time.Duration {
if streak <= 0 {
streak = 1
}
shift := min(streak-1, modelProbeBackoffMaxShift)
delay := base * time.Duration(1<<shift)
if maxDelay > 0 && (delay > maxDelay || delay < 0) {
return maxDelay
}
if delay <= 0 {
return base
}
return delay
}
func modelProbeAPIBase(m *config.ModelConfig) string {
if apiBase := strings.TrimSpace(m.APIBase); apiBase != "" {
return normalizeModelProbeAPIBase(apiBase)
@@ -207,7 +474,11 @@ func probeTCPService(raw string) bool {
return false
}
conn, err := net.DialTimeout("tcp", hostPort, modelProbeTimeout)
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
defer cancel()
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, "tcp", hostPort)
if err != nil {
return false
}
@@ -262,7 +533,10 @@ func probeOpenAICompatibleModel(apiBase, modelID, apiKey string) bool {
}
func getJSON(rawURL string, out any, apiKey string) error {
req, err := http.NewRequest(http.MethodGet, rawURL, nil)
ctx, cancel := context.WithTimeout(context.Background(), modelProbeTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return err
}
@@ -270,7 +544,7 @@ func getJSON(rawURL string, out any, apiKey string) error {
req.Header.Set("Authorization", "Bearer "+apiKey)
}
client := &http.Client{Timeout: modelProbeTimeout}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err
@@ -336,10 +610,29 @@ func ollamaModelMatches(candidate, want string) bool {
if candidate == "" || want == "" {
return false
}
if strings.EqualFold(candidate, want) {
return true
candidateBase, candidateTag := splitOllamaModel(candidate)
wantBase, wantTag := splitOllamaModel(want)
if candidateBase == "" || wantBase == "" {
return false
}
base, _, _ := strings.Cut(candidate, ":")
return strings.EqualFold(base, want)
if candidateTag == "" {
candidateTag = "latest"
}
if wantTag == "" {
wantTag = "latest"
}
return strings.EqualFold(candidateBase, wantBase) && strings.EqualFold(candidateTag, wantTag)
}
func splitOllamaModel(raw string) (base, tag string) {
raw = strings.TrimSpace(raw)
if raw == "" {
return "", ""
}
base, tag, _ = strings.Cut(raw, ":")
return strings.TrimSpace(base), strings.TrimSpace(tag)
}
+307
View File
@@ -3,7 +3,10 @@ package api
import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -85,3 +88,307 @@ func TestProbeLocalModelAvailability_LMStudioUsesOpenAICompatibleProbe(t *testin
t.Fatal("probeOpenAICompatibleModelFunc was not called for lmstudio")
}
}
func TestModelProbeCacheKey_DifferentAPIKeysProduceDifferentKeys(t *testing.T) {
base := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
AuthMethod: "local",
ConnectMode: "",
}
m1 := *base
m1.SetAPIKey("key-a")
m2 := *base
m2.SetAPIKey("key-b")
k1 := modelProbeCacheKey(&m1)
k2 := modelProbeCacheKey(&m2)
if k1 == k2 {
t.Fatal("modelProbeCacheKey() should differ when api key changes")
}
}
func TestModelProbeCacheKey_NormalizesTrailingSlashInAPIBase(t *testing.T) {
m1 := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
m2 := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1/",
}
k1 := modelProbeCacheKey(m1)
k2 := modelProbeCacheKey(m2)
if k1 != k2 {
t.Fatalf("modelProbeCacheKey() mismatch for equivalent api_base values: %q vs %q", k1, k2)
}
}
func TestModelProbeCacheKey_IgnoresDisplayAndConnectionFields(t *testing.T) {
base := &config.ModelConfig{
ModelName: "vllm-one",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
AuthMethod: "none",
ConnectMode: "http",
}
changed := &config.ModelConfig{
ModelName: "vllm-two",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
AuthMethod: "token",
ConnectMode: "ws",
}
k1 := modelProbeCacheKey(base)
k2 := modelProbeCacheKey(changed)
if k1 != k2 {
t.Fatalf("modelProbeCacheKey() should ignore non-probe fields, got %q vs %q", k1, k2)
}
}
func TestProbeLocalModelAvailability_SuccessBackoff(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000000, 0)
modelProbeNowFunc = func() time.Time { return now }
calls := 0
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
calls++
return true
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
if !probeLocalModelAvailability(model) {
t.Fatal("first probe result = false, want true")
}
if calls != 1 {
t.Fatalf("probe calls after first probe = %d, want 1", calls)
}
if !probeLocalModelAvailability(model) {
t.Fatal("cached probe result = false, want true")
}
if calls != 1 {
t.Fatalf("probe calls after immediate re-check = %d, want 1", calls)
}
now = now.Add(modelProbeSuccessBaseInterval)
if !probeLocalModelAvailability(model) {
t.Fatal("second probe result = false, want true")
}
if calls != 2 {
t.Fatalf("probe calls after success backoff window = %d, want 2", calls)
}
now = now.Add(modelProbeSuccessBaseInterval)
if !probeLocalModelAvailability(model) {
t.Fatal("cached result after doubled backoff = false, want true")
}
if calls != 2 {
t.Fatalf("probe calls before doubled backoff expires = %d, want 2", calls)
}
now = now.Add(modelProbeSuccessBaseInterval)
if !probeLocalModelAvailability(model) {
t.Fatal("third probe result = false, want true")
}
if calls != 3 {
t.Fatalf("probe calls after doubled backoff expires = %d, want 3", calls)
}
}
func TestProbeLocalModelAvailability_FailureBackoff(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000100, 0)
modelProbeNowFunc = func() time.Time { return now }
calls := 0
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
calls++
return false
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
if probeLocalModelAvailability(model) {
t.Fatal("first probe result = true, want false")
}
if calls != 1 {
t.Fatalf("probe calls after first failure = %d, want 1", calls)
}
if probeLocalModelAvailability(model) {
t.Fatal("cached failed probe result = true, want false")
}
if calls != 1 {
t.Fatalf("probe calls after immediate failed re-check = %d, want 1", calls)
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("second failed probe result = true, want false")
}
if calls != 2 {
t.Fatalf("probe calls after failure backoff window = %d, want 2", calls)
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("cached failure after doubled backoff = true, want false")
}
if calls != 2 {
t.Fatalf("probe calls before doubled failure backoff expires = %d, want 2", calls)
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("third failed probe result = true, want false")
}
if calls != 3 {
t.Fatalf("probe calls after doubled failure backoff expires = %d, want 3", calls)
}
}
func TestProbeLocalModelAvailability_ResultFlipResetsBackoff(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000200, 0)
modelProbeNowFunc = func() time.Time { return now }
results := []bool{true, false, false}
index := 0
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
if index >= len(results) {
return false
}
result := results[index]
index++
return result
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
if !probeLocalModelAvailability(model) {
t.Fatal("first probe result = false, want true")
}
now = now.Add(modelProbeSuccessBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("second probe result = true, want false")
}
now = now.Add(modelProbeFailureBaseInterval)
if probeLocalModelAvailability(model) {
t.Fatal("third probe result = true, want false")
}
if index != 3 {
t.Fatalf("probe invocations = %d, want 3", index)
}
}
func TestProbeLocalModelAvailability_DeduplicatesInflightProbe(t *testing.T) {
resetModelProbeHooks(t)
now := time.Unix(1700000300, 0)
modelProbeNowFunc = func() time.Time { return now }
var calls int32
probeStarted := make(chan struct{})
releaseProbe := make(chan struct{})
probeOpenAICompatibleModelFunc = func(apiBase, modelID, apiKey string) bool {
if atomic.AddInt32(&calls, 1) == 1 {
close(probeStarted)
}
<-releaseProbe
return true
}
model := &config.ModelConfig{
ModelName: "local-vllm",
Model: "vllm/custom-model",
APIBase: "http://127.0.0.1:8000/v1",
}
const workers = 8
var wg sync.WaitGroup
results := make(chan bool, workers)
workerStarted := make(chan struct{}, workers)
for range workers {
wg.Add(1)
go func() {
defer wg.Done()
workerStarted <- struct{}{}
results <- probeLocalModelAvailability(model)
}()
}
for range workers {
<-workerStarted
}
select {
case <-probeStarted:
case <-time.After(200 * time.Millisecond):
t.Fatal("probe did not start in time")
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("concurrent probe calls = %d, want 1", got)
}
close(releaseProbe)
wg.Wait()
close(results)
for result := range results {
if !result {
t.Fatal("deduplicated probe result = false, want true")
}
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("final probe calls = %d, want 1", got)
}
}
func TestOllamaModelMatches_WithTagRequiresExactTag(t *testing.T) {
if ollamaModelMatches("llama3:8b", "llama3:7b") {
t.Fatal("ollamaModelMatches() = true, want false for mismatched tags")
}
if !ollamaModelMatches("llama3:7b", "llama3:7b") {
t.Fatal("ollamaModelMatches() = false, want true for exact tagged match")
}
if ollamaModelMatches("llama3:8b", "llama3") {
t.Fatal("ollamaModelMatches() = true, want false when request omits tag (defaults to latest)")
}
if !ollamaModelMatches("llama3:latest", "llama3") {
t.Fatal("ollamaModelMatches() = false, want true when request omits tag and candidate is latest")
}
if !ollamaModelMatches("llama3", "llama3") {
t.Fatal("ollamaModelMatches() = false, want true when both candidate and request omit tag (latest)")
}
}
+17 -7
View File
@@ -32,13 +32,14 @@ type modelResponse struct {
Proxy string `json:"proxy,omitempty"`
AuthMethod string `json:"auth_method,omitempty"`
// Advanced fields
ConnectMode string `json:"connect_mode,omitempty"`
Workspace string `json:"workspace,omitempty"`
RPM int `json:"rpm,omitempty"`
MaxTokensField string `json:"max_tokens_field,omitempty"`
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"`
ExtraBody map[string]any `json:"extra_body,omitempty"`
ConnectMode string `json:"connect_mode,omitempty"`
Workspace string `json:"workspace,omitempty"`
RPM int `json:"rpm,omitempty"`
MaxTokensField string `json:"max_tokens_field,omitempty"`
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"`
ExtraBody map[string]any `json:"extra_body,omitempty"`
CustomHeaders map[string]string `json:"custom_headers,omitempty"`
// Meta
Enabled bool `json:"enabled"`
Available bool `json:"available"`
@@ -87,6 +88,7 @@ func (h *Handler) handleListModels(w http.ResponseWriter, r *http.Request) {
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
Enabled: m.Enabled,
Available: modelStatuses[i].Available,
Status: modelStatuses[i].Status,
@@ -216,6 +218,14 @@ func (h *Handler) handleUpdateModel(w http.ResponseWriter, r *http.Request) {
} else if len(mc.ExtraBody) == 0 {
mc.ExtraBody = nil
}
// Preserve existing CustomHeaders when omitted (nil), but clear it when
// the frontend sends an empty object {} to indicate the field should
// be removed.
if mc.CustomHeaders == nil {
mc.CustomHeaders = cfg.ModelList[idx].CustomHeaders
} else if len(mc.CustomHeaders) == 0 {
mc.CustomHeaders = nil
}
cfg.ModelList[idx] = &mc.ModelConfig
+110
View File
@@ -20,10 +20,14 @@ func resetModelProbeHooks(t *testing.T) {
origTCPProbe := probeTCPServiceFunc
origOllamaProbe := probeOllamaModelFunc
origOpenAIProbe := probeOpenAICompatibleModelFunc
origNow := modelProbeNowFunc
resetModelProbeCache()
t.Cleanup(func() {
probeTCPServiceFunc = origTCPProbe
probeOllamaModelFunc = origOllamaProbe
probeOpenAICompatibleModelFunc = origOpenAIProbe
modelProbeNowFunc = origNow
resetModelProbeCache()
})
}
@@ -426,6 +430,112 @@ func TestHandleAddModel_PersistsAPIKey(t *testing.T) {
}
}
func TestHandleAddModel_PersistsCustomHeaders(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/models", bytes.NewBufferString(`{
"model_name":"new-model-headers",
"model":"openai/gpt-4o-mini",
"custom_headers":{"X-Source":"coding-plan","X-Agent":"openclaw"}
}`))
req.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
if len(cfg.ModelList) != 2 {
t.Fatalf("len(model_list) = %d, want 2", len(cfg.ModelList))
}
added := cfg.ModelList[1]
if added.CustomHeaders == nil {
t.Fatal("custom_headers should not be nil")
}
if got := added.CustomHeaders["X-Source"]; got != "coding-plan" {
t.Fatalf("custom_headers[X-Source] = %q, want %q", got, "coding-plan")
}
if got := added.CustomHeaders["X-Agent"]; got != "openclaw" {
t.Fatalf("custom_headers[X-Agent] = %q, want %q", got, "openclaw")
}
}
func TestHandleUpdateModel_CustomHeadersPreserveAndClear(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
cfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
cfg.ModelList = []*config.ModelConfig{{
ModelName: "editable",
Model: "openai/gpt-4o-mini",
APIKeys: config.SimpleSecureStrings("sk-existing"),
CustomHeaders: map[string]string{"X-Source": "coding-plan"},
}}
err = config.SaveConfig(configPath, cfg)
if err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
// Omitted custom_headers should preserve existing value.
recPreserve := httptest.NewRecorder()
reqPreserve := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
"model_name":"editable",
"model":"openai/gpt-4o-mini"
}`))
reqPreserve.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(recPreserve, reqPreserve)
if recPreserve.Code != http.StatusOK {
t.Fatalf("preserve status = %d, want %d, body=%s", recPreserve.Code, http.StatusOK, recPreserve.Body.String())
}
afterPreserve, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() after preserve error = %v", err)
}
if got := afterPreserve.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
t.Fatalf("preserved custom_headers[X-Source] = %q, want %q", got, "coding-plan")
}
// Empty object should clear custom_headers.
recClear := httptest.NewRecorder()
reqClear := httptest.NewRequest(http.MethodPut, "/api/models/0", bytes.NewBufferString(`{
"model_name":"editable",
"model":"openai/gpt-4o-mini",
"custom_headers":{}
}`))
reqClear.Header.Set("Content-Type", "application/json")
mux.ServeHTTP(recClear, reqClear)
if recClear.Code != http.StatusOK {
t.Fatalf("clear status = %d, want %d, body=%s", recClear.Code, http.StatusOK, recClear.Body.String())
}
afterClear, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() after clear error = %v", err)
}
if afterClear.ModelList[0].CustomHeaders != nil {
t.Fatalf("custom_headers = %#v, want nil", afterClear.ModelList[0].CustomHeaders)
}
}
// TestHandleSetDefaultModel_RejectsNonexistentModel tests that setting a non-existent
// model as default returns 404. This covers the case where virtual models (which are
// filtered by SaveConfig) cannot be set as default.
+27 -1
View File
@@ -11,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
@@ -57,9 +58,34 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
gateway.mu.Lock()
ensurePicoTokenCachedLocked(h.configPath)
gatewayAvailable := gateway.pidData != nil
cachedPID := gateway.pidData
trackedCmd := gateway.cmd
gateway.mu.Unlock()
gatewayAvailable := false
// Prefer fresh PID file data when available.
if pidData := ppid.ReadPidFileWithCheck(globalConfigDir()); pidData != nil {
gateway.mu.Lock()
gateway.pidData = pidData
setGatewayRuntimeStatusLocked("running")
gatewayAvailable = true
gateway.mu.Unlock()
} else if cachedPID != nil {
// No PID file now: keep availability only while tracked process is
// still alive (covers short PID-file races at startup/restart).
if isCmdProcessAliveLocked(trackedCmd) {
gatewayAvailable = true
} else {
gateway.mu.Lock()
if gateway.cmd == trackedCmd {
gateway.pidData = nil
setGatewayRuntimeStatusLocked("stopped")
}
gatewayAvailable = gateway.pidData != nil
gateway.mu.Unlock()
}
}
if !gatewayAvailable {
logger.Warnf("Gateway not available for WebSocket proxy")
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
+151
View File
@@ -11,6 +11,7 @@ import (
"strconv"
"testing"
"github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
@@ -307,6 +308,9 @@ func TestHandlePicoSetup_Response(t *testing.T) {
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
@@ -335,6 +339,16 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
t.Fatalf("WritePidFile() error = %v", err)
}
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
})
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "pico"
@@ -378,6 +392,9 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
}
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
@@ -399,6 +416,12 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
t.Fatalf("WritePidFile() error = %v", err)
}
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
@@ -426,6 +449,134 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
}
}
func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
home := t.TempDir()
t.Setenv("PICOCLAW_HOME", home)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, r.Header.Get(protocolKey))
}))
defer server.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server.URL)
cfg.Channels.Pico.Enabled = true
cfg.Channels.Pico.SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
pidData, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port)
if err != nil {
t.Fatalf("WritePidFile() error = %v", err)
}
t.Cleanup(func() {
ppid.RemovePidFile(globalConfigDir())
})
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
origStatus := gateway.runtimeStatus
t.Cleanup(func() {
gateway.mu.Lock()
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
gateway.runtimeStatus = origStatus
gateway.mu.Unlock()
})
gateway.mu.Lock()
gateway.pidData = nil
gateway.picoToken = ""
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
expected := tokenPrefix + pico.PicoTokenPrefix + pidData.Token + "ui-token"
if got := rec.Body.String(); got != expected {
t.Fatalf("forwarded protocol = %q, want %q", got, expected)
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.pidData == nil {
t.Fatal("gateway.pidData should be loaded from pid file")
}
if gateway.runtimeStatus != "running" {
t.Fatalf("runtimeStatus = %q, want %q", gateway.runtimeStatus, "running")
}
}
func TestHandleWebSocketProxyRejectsStalePidDataAfterProcessExit(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
cfg := config.DefaultConfig()
cfg.Channels.Pico.Enabled = true
cfg.Channels.Pico.SetToken("ui-token")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
cmd := startLongRunningProcess(t)
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
origPidData := gateway.pidData
origPicoToken := gateway.picoToken
origCmd := gateway.cmd
origStatus := gateway.runtimeStatus
t.Cleanup(func() {
gateway.mu.Lock()
gateway.pidData = origPidData
gateway.picoToken = origPicoToken
gateway.cmd = origCmd
gateway.runtimeStatus = origStatus
gateway.mu.Unlock()
})
gateway.mu.Lock()
gateway.pidData = &ppid.PidFileData{PID: cmd.Process.Pid, Token: "stale-token"}
gateway.picoToken = "ui-token"
gateway.cmd = cmd
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
req := httptest.NewRequest(http.MethodGet, "/pico/ws?session_id=test-session", nil)
req.Header.Set(protocolKey, tokenPrefix+"ui-token")
rec := httptest.NewRecorder()
handler(rec, req)
if rec.Code != http.StatusServiceUnavailable {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
}
gateway.mu.Lock()
defer gateway.mu.Unlock()
if gateway.pidData != nil {
t.Fatal("gateway.pidData should be cleared after stale process exit is detected")
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()
+3
View File
@@ -81,6 +81,9 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
// Launcher service parameters (port/public)
h.registerLauncherConfigRoutes(mux)
// Self-update endpoint (requires dashboard auth)
h.registerUpdateRoutes(mux)
// Runtime build/version metadata
h.registerVersionRoutes(mux)
+113 -34
View File
@@ -44,12 +44,24 @@ type sessionListItem struct {
Updated string `json:"updated"`
}
type sessionChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Media []string `json:"media,omitempty"`
}
// legacyPicoSessionPrefix is the legacy key prefix used by older Pico JSON/JSONL
// sessions before structured scope metadata existed.
const (
legacyPicoSessionPrefix = "agent:main:pico:direct:pico:"
maxSessionJSONLLineSize = 10 * 1024 * 1024 // 10 MB
picoSessionPrefix = legacyPicoSessionPrefix
// Keep the session API aligned with the shared JSONL store reader limit in
// pkg/memory/jsonl.go so oversized lines fail consistently everywhere.
maxSessionJSONLLineSize = 10 * 1024 * 1024
maxSessionTitleRunes = 60
handledToolResponseSummaryText = "Requested output delivered via tool attachment."
)
// extractLegacyPicoSessionID extracts the session UUID from an old Pico key.
@@ -327,32 +339,21 @@ func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessio
func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem {
preview := ""
for _, msg := range sess.Messages {
if msg.Role == "user" && strings.TrimSpace(msg.Content) != "" {
preview = msg.Content
if msg.Role == "user" {
preview = sessionMessagePreview(msg)
}
if preview != "" {
break
}
}
title := strings.TrimSpace(sess.Summary)
if title == "" {
title = preview
}
title = truncateRunes(title, maxSessionTitleRunes)
preview = truncateRunes(preview, maxSessionTitleRunes)
if preview == "" {
preview = "(empty)"
}
if title == "" {
title = preview
}
title := preview
validMessageCount := 0
for _, msg := range sess.Messages {
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
validMessageCount++
}
}
validMessageCount := len(visibleSessionMessages(sess.Messages))
return sessionListItem{
ID: sessionID,
@@ -379,6 +380,99 @@ func truncateRunes(s string, maxLen int) string {
return string(runes[:maxLen]) + "..."
}
func sessionMessageVisible(msg providers.Message) bool {
return strings.TrimSpace(msg.Content) != "" || len(msg.Media) > 0
}
func sessionMessagePreview(msg providers.Message) string {
if content := strings.TrimSpace(msg.Content); content != "" {
return content
}
if len(msg.Media) > 0 {
return "[image]"
}
return ""
}
func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
transcript := make([]sessionChatMessage, 0, len(messages))
for _, msg := range messages {
switch msg.Role {
case "user":
if sessionMessageVisible(msg) {
transcript = append(transcript, sessionChatMessage{
Role: "user",
Content: msg.Content,
Media: append([]string(nil), msg.Media...),
})
}
case "assistant":
visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls)
if len(visibleToolMessages) > 0 {
transcript = append(transcript, visibleToolMessages...)
}
// Pico web chat can persist both visible `message` tool output and a
// later plain assistant reply in the same turn. Hide only the fixed
// internal summary that marks handled tool delivery.
if len(visibleToolMessages) > 0 || !sessionMessageVisible(msg) || assistantMessageInternalOnly(msg) {
continue
}
transcript = append(transcript, sessionChatMessage{
Role: "assistant",
Content: msg.Content,
Media: append([]string(nil), msg.Media...),
})
}
}
return transcript
}
func assistantMessageInternalOnly(msg providers.Message) bool {
return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText
}
func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage {
if len(toolCalls) == 0 {
return nil
}
messages := make([]sessionChatMessage, 0, len(toolCalls))
for _, tc := range toolCalls {
name := tc.Name
argsJSON := ""
if tc.Function != nil {
if name == "" {
name = tc.Function.Name
}
argsJSON = tc.Function.Arguments
}
switch name {
case "message":
var args struct {
Content string `json:"content"`
}
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
continue
}
if strings.TrimSpace(args.Content) == "" {
continue
}
messages = append(messages, sessionChatMessage{
Role: "assistant",
Content: args.Content,
})
}
}
return messages
}
// sessionsDir resolves the path to the gateway's session storage directory.
// It reads the workspace from config, falling back to ~/.picoclaw/workspace.
func (h *Handler) sessionsDir() (string, error) {
@@ -530,22 +624,7 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
}
}
// Convert to a simpler format for the frontend
type chatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
messages := make([]chatMessage, 0, len(sess.Messages))
for _, msg := range sess.Messages {
// Only include user and assistant messages that have actual content
if (msg.Role == "user" || msg.Role == "assistant") && strings.TrimSpace(msg.Content) != "" {
messages = append(messages, chatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
}
messages := visibleSessionMessages(sess.Messages)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
+376 -23
View File
@@ -6,6 +6,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
@@ -34,9 +35,9 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
store, storeErr := memory.NewJSONLStore(dir)
if storeErr != nil {
t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := legacyPicoSessionPrefix + "history-jsonl"
@@ -87,22 +88,26 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) {
if items[0].MessageCount != 2 {
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
}
if items[0].Title != "JSONL-backed session" {
t.Fatalf("items[0].Title = %q, want %q", items[0].Title, "JSONL-backed session")
if items[0].Title != "Explain why the history API is empty after migration." {
t.Fatalf(
"items[0].Title = %q, want %q",
items[0].Title,
"Explain why the history API is empty after migration.",
)
}
if items[0].Preview != "Explain why the history API is empty after migration." {
t.Fatalf("items[0].Preview = %q", items[0].Preview)
}
}
func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
func TestHandleListSessions_TitleUsesFirstUserMessage(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
store, storeErr := memory.NewJSONLStore(dir)
if storeErr != nil {
t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := legacyPicoSessionPrefix + "summary-title"
@@ -139,10 +144,7 @@ func TestHandleListSessions_TitleUsesTrimmedSummary(t *testing.T) {
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
expectedTitle := truncateRunes(
"This summary is intentionally longer than sixty characters so it must be truncated in the history menu.",
maxSessionTitleRunes,
)
expectedTitle := truncateRunes("fallback preview", maxSessionTitleRunes)
if items[0].Title != expectedTitle {
t.Fatalf("items[0].Title = %q", items[0].Title)
}
@@ -220,22 +222,20 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
store, storeErr := memory.NewJSONLStore(dir)
if storeErr != nil {
t.Fatalf("NewJSONLStore() error = %v", storeErr)
}
sessionKey := "sk_v1_scope_discovery"
addErr := store.AddFullMessage(nil, sessionKey, providers.Message{
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: "scope discovered session",
})
if addErr != nil {
t.Fatalf("AddFullMessage() error = %v", addErr)
}); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
summaryErr := store.SetSummary(nil, sessionKey, "scope summary")
if summaryErr != nil {
t.Fatalf("SetSummary() error = %v", summaryErr)
if err := store.SetSummary(nil, sessionKey, "scope summary"); err != nil {
t.Fatalf("SetSummary() error = %v", err)
}
scopeData, err := json.Marshal(session.SessionScope{
@@ -292,6 +292,359 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
}
}
func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "detail-message-tool"
for _, msg := range []providers.Message{
{Role: "user", Content: "test"},
{
Role: "assistant",
Content: "",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Function: &providers.FunctionCall{
Name: "message",
Arguments: `{"content":"visible tool output"}`,
},
},
},
},
{Role: "tool", Content: "Message sent to pico:pico:detail-message-tool", ToolCallID: "call_1"},
{Role: "assistant", Content: handledToolResponseSummaryText},
} {
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-message-tool", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Messages) != 2 {
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
}
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[1])
}
}
func TestHandleGetSession_PreservesFinalAssistantReplyAfterMessageToolOutput(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "detail-message-tool-final-reply"
for _, msg := range []providers.Message{
{Role: "user", Content: "test"},
{
Role: "assistant",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Function: &providers.FunctionCall{
Name: "message",
Arguments: `{"content":"visible tool output"}`,
},
},
},
},
{Role: "tool", Content: "Message sent to pico:pico:detail-message-tool-final-reply", ToolCallID: "call_1"},
{Role: "assistant", Content: "final assistant reply"},
} {
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-message-tool-final-reply", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Messages) != 3 {
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
}
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
t.Fatalf("interim assistant message = %#v, want visible tool output", resp.Messages[1])
}
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "final assistant reply" {
t.Fatalf("final assistant message = %#v, want final assistant reply", resp.Messages[2])
}
}
func TestHandleListSessions_MessageCountUsesVisibleTranscript(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "list-visible-count"
for _, msg := range []providers.Message{
{Role: "user", Content: "test"},
{
Role: "assistant",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Function: &providers.FunctionCall{
Name: "message",
Arguments: `{"content":"visible tool output"}`,
},
},
},
},
{Role: "tool", Content: "Message sent to pico:pico:list-visible-count", ToolCallID: "call_1"},
{Role: "assistant", Content: handledToolResponseSummaryText},
} {
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var items []sessionListItem
if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
if items[0].MessageCount != 2 {
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
}
}
func TestHandleGetSession_IncludesMediaOnlyMessages(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "detail-media-only"
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Media: []string{"data:image/png;base64,abc123"},
}); err != nil {
t.Fatalf("AddFullMessage(user) error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-media-only", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var resp struct {
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
Media []string `json:"media"`
} `json:"messages"`
}
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(resp.Messages) != 1 {
t.Fatalf("len(resp.Messages) = %d, want 1", len(resp.Messages))
}
if resp.Messages[0].Role != "user" || len(resp.Messages[0].Media) != 1 {
t.Fatalf("message = %#v, want user message with media", resp.Messages[0])
}
}
func TestHandleSessions_SupportsJSONLMessagesUpToStoreCap(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "detail-large-jsonl"
largeContent := strings.Repeat("x", 9*1024*1024)
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Content: largeContent,
}); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
listRec := httptest.NewRecorder()
listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
mux.ServeHTTP(listRec, listReq)
if listRec.Code != http.StatusOK {
t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String())
}
var items []sessionListItem
if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil {
t.Fatalf("list Unmarshal() error = %v", err)
}
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
detailRec := httptest.NewRecorder()
detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-large-jsonl", nil)
mux.ServeHTTP(detailRec, detailReq)
if detailRec.Code != http.StatusOK {
t.Fatalf(
"detail status = %d, want %d, body=%s",
detailRec.Code,
http.StatusOK,
detailRec.Body.String(),
)
}
var resp struct {
Messages []struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"messages"`
}
if err := json.Unmarshal(detailRec.Body.Bytes(), &resp); err != nil {
t.Fatalf("detail Unmarshal() error = %v", err)
}
if len(resp.Messages) != 1 {
t.Fatalf("len(resp.Messages) = %d, want 1", len(resp.Messages))
}
if resp.Messages[0].Role != "user" {
t.Fatalf("resp.Messages[0].Role = %q, want %q", resp.Messages[0].Role, "user")
}
if got := len(resp.Messages[0].Content); got != len(largeContent) {
t.Fatalf("len(resp.Messages[0].Content) = %d, want %d", got, len(largeContent))
}
}
func TestHandleListSessions_UsesImagePreviewForMediaOnlyMessage(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
dir := sessionsTestDir(t, configPath)
store, err := memory.NewJSONLStore(dir)
if err != nil {
t.Fatalf("NewJSONLStore() error = %v", err)
}
sessionKey := picoSessionPrefix + "preview-media-only"
if err := store.AddFullMessage(nil, sessionKey, providers.Message{
Role: "user",
Media: []string{"data:image/png;base64,abc123"},
}); err != nil {
t.Fatalf("AddFullMessage() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/sessions", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var items []sessionListItem
if err := json.Unmarshal(rec.Body.Bytes(), &items); err != nil {
t.Fatalf("Unmarshal() error = %v", err)
}
if len(items) != 1 {
t.Fatalf("len(items) = %d, want 1", len(items))
}
if items[0].Preview != "[image]" {
t.Fatalf("items[0].Preview = %q, want %q", items[0].Preview, "[image]")
}
if items[0].MessageCount != 1 {
t.Fatalf("items[0].MessageCount = %d, want 1", items[0].MessageCount)
}
}
func TestHandleDeleteSession_JSONLStorage(t *testing.T) {
configPath, cleanup := setupOAuthTestEnv(t)
defer cleanup()
+833 -54
View File
@@ -1,40 +1,115 @@
package api
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/fileutil"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/utils"
)
type skillSupportResponse struct {
Skills []skills.SkillInfo `json:"skills"`
Skills []skillSupportItem `json:"skills"`
}
type skillSupportItem struct {
Name string `json:"name"`
Path string `json:"path"`
Source string `json:"source"`
Description string `json:"description"`
OriginKind string `json:"origin_kind"`
RegistryName string `json:"registry_name,omitempty"`
RegistryURL string `json:"registry_url,omitempty"`
InstalledVersion string `json:"installed_version,omitempty"`
InstalledAt int64 `json:"installed_at,omitempty"`
}
type skillDetailResponse struct {
Name string `json:"name"`
Path string `json:"path"`
Source string `json:"source"`
Description string `json:"description"`
Content string `json:"content"`
skillSupportItem
Content string `json:"content"`
}
type skillSearchResultItem struct {
Score float64 `json:"score"`
Slug string `json:"slug"`
DisplayName string `json:"display_name"`
Summary string `json:"summary"`
Version string `json:"version"`
RegistryName string `json:"registry_name"`
URL string `json:"url,omitempty"`
Installed bool `json:"installed"`
InstalledName string `json:"installed_name,omitempty"`
}
type skillSearchResponse struct {
Results []skillSearchResultItem `json:"results"`
Limit int `json:"limit"`
Offset int `json:"offset"`
NextOffset int `json:"next_offset,omitempty"`
HasMore bool `json:"has_more"`
}
type installSkillRequest struct {
Slug string `json:"slug"`
Registry string `json:"registry"`
Version string `json:"version,omitempty"`
Force bool `json:"force,omitempty"`
}
type installSkillResponse struct {
Status string `json:"status"`
Slug string `json:"slug"`
Registry string `json:"registry"`
Version string `json:"version"`
Summary string `json:"summary,omitempty"`
IsSuspicious bool `json:"is_suspicious,omitempty"`
InstalledSkill *skillSupportItem `json:"skill,omitempty"`
}
type installedSkillOriginMeta struct {
Version int `json:"version"`
OriginKind string `json:"origin_kind,omitempty"`
Registry string `json:"registry,omitempty"`
Slug string `json:"slug,omitempty"`
RegistryURL string `json:"registry_url,omitempty"`
InstalledVersion string `json:"installed_version,omitempty"`
InstalledAt int64 `json:"installed_at"`
}
var (
skillNameSanitizer = regexp.MustCompile(`[^a-z0-9-]+`)
importedSkillFrontmatter = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
skillFrontmatterStripper = regexp.MustCompile(`(?s)^---(?:\r\n|\n|\r)(.*?)(?:\r\n|\n|\r)---(?:\r\n|\n|\r)*`)
persistSkillOriginMeta = writeSkillOriginMeta
workspaceSkillWriteMu sync.Mutex
errImportedSkillExists = errors.New("skill already exists")
)
const (
maxImportedSkillSize = 1 << 20
maxRegistrySearchFanout = 1000
)
func (h *Handler) registerSkillRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/skills", h.handleListSkills)
mux.HandleFunc("GET /api/skills/{name}", h.handleGetSkill)
mux.HandleFunc("GET /api/skills/search", h.handleSearchSkills)
mux.HandleFunc("POST /api/skills/install", h.handleInstallSkill)
mux.HandleFunc("POST /api/skills/import", h.handleImportSkill)
mux.HandleFunc("DELETE /api/skills/{name}", h.handleDeleteSkill)
}
@@ -46,11 +121,15 @@ func (h *Handler) handleListSkills(w http.ResponseWriter, r *http.Request) {
return
}
loader := newSkillsLoader(cfg.WorkspacePath())
items, err := buildSkillSupportItems(cfg)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to build skill list: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(skillSupportResponse{
Skills: loader.ListSkills(),
Skills: items,
})
}
@@ -61,16 +140,18 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
return
}
loader := newSkillsLoader(cfg.WorkspacePath())
skillItems, err := buildSkillSupportItems(cfg)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to build skill list: %v", err), http.StatusInternalServerError)
return
}
name := r.PathValue("name")
allSkills := loader.ListSkills()
for _, skill := range allSkills {
if skill.Name != name {
for _, skillItem := range skillItems {
if skillItem.Name != name {
continue
}
content, err := loadSkillContent(skill.Path)
content, err := loadSkillContent(skillItem.Path)
if err != nil {
http.Error(w, "Skill content not found", http.StatusNotFound)
return
@@ -78,11 +159,8 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(skillDetailResponse{
Name: skill.Name,
Path: skill.Path,
Source: skill.Source,
Description: skill.Description,
Content: content,
skillSupportItem: skillItem,
Content: content,
})
return
}
@@ -90,6 +168,266 @@ func (h *Handler) handleGetSkill(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Skill not found", http.StatusNotFound)
}
func (h *Handler) handleSearchSkills(w http.ResponseWriter, r *http.Request) {
cfg, loadErr := config.LoadConfig(h.configPath)
if loadErr != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", loadErr), http.StatusInternalServerError)
return
}
if registryErr := ensureSkillRegistryToolEnabled(cfg, "find_skills"); registryErr != nil {
http.Error(w, registryErr.Error(), http.StatusBadRequest)
return
}
query := strings.TrimSpace(r.URL.Query().Get("q"))
limit := 20
if rawLimit := strings.TrimSpace(r.URL.Query().Get("limit")); rawLimit != "" {
parsedLimit, parseErr := strconv.Atoi(rawLimit)
if parseErr != nil || parsedLimit < 1 || parsedLimit > 50 {
http.Error(w, "limit must be between 1 and 50", http.StatusBadRequest)
return
}
limit = parsedLimit
}
offset := 0
if rawOffset := strings.TrimSpace(r.URL.Query().Get("offset")); rawOffset != "" {
parsedOffset, parseErr := strconv.Atoi(rawOffset)
if parseErr != nil || parsedOffset < 0 {
http.Error(w, "offset must be 0 or greater", http.StatusBadRequest)
return
}
offset = parsedOffset
}
installedSkills, err := buildOccupiedWorkspaceSkillsByDirectory(cfg)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to inspect installed skills: %v", err), http.StatusInternalServerError)
return
}
if query == "" {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(skillSearchResponse{
Results: []skillSearchResultItem{},
Limit: limit,
Offset: offset,
HasMore: false,
})
return
}
registryMgr := newSkillsRegistryManager(cfg)
searchLimit := offset + limit + 1
if searchLimit > maxRegistrySearchFanout {
searchLimit = maxRegistrySearchFanout
}
results, err := registryMgr.SearchAll(r.Context(), query, searchLimit)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to search skills: %v", err), http.StatusBadGateway)
return
}
if offset > len(results) {
offset = len(results)
}
end := offset + limit
if end > len(results) {
end = len(results)
}
pageResults := results[offset:end]
response := make([]skillSearchResultItem, 0, len(pageResults))
for _, result := range pageResults {
installedSkill, installed := installedSkills[result.Slug]
item := skillSearchResultItem{
Score: result.Score,
Slug: result.Slug,
DisplayName: result.DisplayName,
Summary: result.Summary,
Version: result.Version,
RegistryName: result.RegistryName,
URL: registrySkillURL(cfg, result.RegistryName, result.Slug),
Installed: installed,
}
if installed {
item.InstalledName = installedSkill.Name
}
response = append(response, item)
}
w.Header().Set("Content-Type", "application/json")
nextOffset := 0
hasMore := len(results) > end
if hasMore {
nextOffset = end
}
json.NewEncoder(w).Encode(skillSearchResponse{
Results: response,
Limit: limit,
Offset: offset,
NextOffset: nextOffset,
HasMore: hasMore,
})
}
func (h *Handler) handleInstallSkill(w http.ResponseWriter, r *http.Request) {
cfg, loadErr := config.LoadConfig(h.configPath)
if loadErr != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", loadErr), http.StatusInternalServerError)
return
}
if registryErr := ensureSkillRegistryToolEnabled(cfg, "install_skill"); registryErr != nil {
http.Error(w, registryErr.Error(), http.StatusBadRequest)
return
}
var req installSkillRequest
if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
http.Error(w, fmt.Sprintf("Invalid JSON: %v", decodeErr), http.StatusBadRequest)
return
}
req.Slug = strings.TrimSpace(req.Slug)
req.Registry = strings.TrimSpace(req.Registry)
req.Version = strings.TrimSpace(req.Version)
if validateErr := utils.ValidateSkillIdentifier(req.Slug); validateErr != nil {
http.Error(
w,
fmt.Sprintf("invalid slug %q: error: %s", req.Slug, validateErr.Error()),
http.StatusBadRequest,
)
return
}
if validateErr := utils.ValidateSkillIdentifier(req.Registry); validateErr != nil {
http.Error(
w,
fmt.Sprintf("invalid registry %q: error: %s", req.Registry, validateErr.Error()),
http.StatusBadRequest,
)
return
}
registryMgr := newSkillsRegistryManager(cfg)
registry := registryMgr.GetRegistry(req.Registry)
if registry == nil {
http.Error(w, fmt.Sprintf("registry %q not found", req.Registry), http.StatusBadRequest)
return
}
workspace := cfg.WorkspacePath()
skillsRoot := filepath.Join(workspace, "skills")
targetDir := filepath.Join(workspace, "skills", req.Slug)
workspaceSkillWriteMu.Lock()
defer workspaceSkillWriteMu.Unlock()
targetExists := false
if _, statErr := os.Stat(targetDir); statErr == nil {
targetExists = true
} else if !os.IsNotExist(statErr) {
http.Error(w, fmt.Sprintf("Failed to inspect install target: %v", statErr), http.StatusInternalServerError)
return
}
if !req.Force && targetExists {
http.Error(w, fmt.Sprintf("skill %q already installed at %s", req.Slug, targetDir), http.StatusConflict)
return
}
if err := os.MkdirAll(skillsRoot, 0o755); err != nil {
http.Error(w, fmt.Sprintf("Failed to create skills directory: %v", err), http.StatusInternalServerError)
return
}
stagedWorkspaceRoot, stagedTargetDir, err := createStagedSkillInstall(skillsRoot, req.Slug)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to prepare staged install: %v", err), http.StatusInternalServerError)
return
}
defer os.RemoveAll(stagedWorkspaceRoot)
result, err := registry.DownloadAndInstall(r.Context(), req.Slug, req.Version, stagedTargetDir)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to install skill: %v", err), http.StatusBadGateway)
return
}
if result.IsMalwareBlocked {
http.Error(
w,
fmt.Sprintf("skill %q is flagged as malicious and cannot be installed", req.Slug),
http.StatusForbidden,
)
return
}
if findWorkspaceSkillInfoByDirectory(stagedWorkspaceRoot, req.Slug) == nil {
http.Error(
w,
fmt.Sprintf("Failed to install skill: registry archive for %q is not a valid skill", req.Slug),
http.StatusBadGateway,
)
return
}
installedAt := time.Now().UnixMilli()
if err := persistSkillOriginMeta(stagedTargetDir, installedSkillOriginMeta{
Version: 1,
OriginKind: "third_party",
Registry: registry.Name(),
Slug: req.Slug,
RegistryURL: registrySkillURL(cfg, registry.Name(), req.Slug),
InstalledVersion: result.Version,
InstalledAt: installedAt,
}); err != nil {
http.Error(w, fmt.Sprintf("Failed to persist skill metadata: %v", err), http.StatusInternalServerError)
return
}
if err := commitStagedSkillInstall(
stagedWorkspaceRoot,
stagedTargetDir,
targetDir,
req.Force && targetExists,
); err != nil {
http.Error(w, fmt.Sprintf("Failed to activate installed skill: %v", err), http.StatusInternalServerError)
return
}
validatedSkill := findWorkspaceSkillByDirectory(cfg, req.Slug)
if validatedSkill == nil {
http.Error(
w,
fmt.Sprintf("Failed to install skill: activated archive for %q is not a valid skill", req.Slug),
http.StatusBadGateway,
)
return
}
installedSkill := &skillSupportItem{
Name: validatedSkill.Name,
Path: validatedSkill.Path,
Source: validatedSkill.Source,
Description: validatedSkill.Description,
OriginKind: "third_party",
RegistryName: registry.Name(),
RegistryURL: registrySkillURL(cfg, registry.Name(), req.Slug),
InstalledVersion: result.Version,
InstalledAt: installedAt,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(installSkillResponse{
Status: "ok",
Slug: req.Slug,
Registry: registry.Name(),
Version: result.Version,
Summary: result.Summary,
IsSuspicious: result.IsSuspicious,
InstalledSkill: installedSkill,
})
}
func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
@@ -110,54 +448,26 @@ func (h *Handler) handleImportSkill(w http.ResponseWriter, r *http.Request) {
}
defer uploadedFile.Close()
content, err := io.ReadAll(io.LimitReader(uploadedFile, (1<<20)+1))
content, err := io.ReadAll(io.LimitReader(uploadedFile, maxImportedSkillSize+1))
if err != nil {
http.Error(w, fmt.Sprintf("Failed to read file: %v", err), http.StatusBadRequest)
return
}
if len(content) > 1<<20 {
if len(content) > maxImportedSkillSize {
http.Error(w, "file exceeds 1MB limit", http.StatusBadRequest)
return
}
workspaceSkillWriteMu.Lock()
defer workspaceSkillWriteMu.Unlock()
skillName, err := normalizeImportedSkillName(fileHeader.Filename, content)
importedSkill, statusCode, err := importUploadedSkill(cfg, fileHeader.Filename, content)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
http.Error(w, err.Error(), statusCode)
return
}
content = normalizeImportedSkillContent(content, skillName)
workspace := cfg.WorkspacePath()
skillDir := filepath.Join(workspace, "skills", skillName)
skillFile := filepath.Join(skillDir, "SKILL.md")
if _, err := os.Stat(skillDir); err == nil {
http.Error(w, "skill already exists", http.StatusConflict)
return
}
if err := os.MkdirAll(skillDir, 0o755); err != nil {
http.Error(w, fmt.Sprintf("Failed to create skill directory: %v", err), http.StatusInternalServerError)
return
}
if err := os.WriteFile(skillFile, content, 0o644); err != nil {
http.Error(w, fmt.Sprintf("Failed to save skill: %v", err), http.StatusInternalServerError)
return
}
loader := newSkillsLoader(workspace)
for _, skill := range loader.ListSkills() {
if skill.Path == skillFile || (skill.Name == skillName && skill.Source == "workspace") {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(skill)
return
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"name": skillName,
"path": skillFile,
})
json.NewEncoder(w).Encode(importedSkill)
}
func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
@@ -169,6 +479,9 @@ func (h *Handler) handleDeleteSkill(w http.ResponseWriter, r *http.Request) {
loader := newSkillsLoader(cfg.WorkspacePath())
name := r.PathValue("name")
workspaceSkillWriteMu.Lock()
defer workspaceSkillWriteMu.Unlock()
for _, skill := range loader.ListSkills() {
if skill.Name != name {
continue
@@ -197,12 +510,274 @@ func newSkillsLoader(workspace string) *skills.SkillsLoader {
)
}
func newSkillsRegistryManager(cfg *config.Config) *skills.RegistryManager {
clawHubConfig := cfg.Tools.Skills.Registries.ClawHub
return skills.NewRegistryManagerFromConfig(skills.RegistryConfig{
MaxConcurrentSearches: cfg.Tools.Skills.MaxConcurrentSearches,
ClawHub: skills.ClawHubConfig{
Enabled: clawHubConfig.Enabled,
BaseURL: clawHubConfig.BaseURL,
AuthToken: clawHubConfig.AuthToken.String(),
SearchPath: clawHubConfig.SearchPath,
SkillsPath: clawHubConfig.SkillsPath,
DownloadPath: clawHubConfig.DownloadPath,
Timeout: clawHubConfig.Timeout,
MaxZipSize: clawHubConfig.MaxZipSize,
MaxResponseSize: clawHubConfig.MaxResponseSize,
},
})
}
func ensureSkillRegistryToolEnabled(cfg *config.Config, toolName string) error {
if !cfg.Tools.IsToolEnabled("skills") {
return fmt.Errorf("tools.skills is disabled")
}
if !cfg.Tools.IsToolEnabled(toolName) {
return fmt.Errorf("%s is disabled", toolName)
}
return nil
}
func buildSkillSupportItems(cfg *config.Config) ([]skillSupportItem, error) {
rawSkills := newSkillsLoader(cfg.WorkspacePath()).ListSkills()
items := make([]skillSupportItem, 0, len(rawSkills))
for _, skill := range rawSkills {
item, err := enrichSkillInfo(cfg, skill)
if err != nil {
return nil, err
}
items = append(items, item)
}
return items, nil
}
func buildWorkspaceSkillItemsByDirectory(cfg *config.Config) (map[string]skillSupportItem, error) {
result := make(map[string]skillSupportItem)
items, err := buildSkillSupportItems(cfg)
if err != nil {
return nil, err
}
for _, skill := range items {
if skill.Source != "workspace" {
continue
}
dir := filepath.Base(filepath.Dir(skill.Path))
if dir == "" {
continue
}
result[dir] = skill
}
return result, nil
}
func buildOccupiedWorkspaceSkillsByDirectory(cfg *config.Config) (map[string]skillSupportItem, error) {
result := make(map[string]skillSupportItem)
items, err := buildSkillSupportItems(cfg)
if err != nil {
return nil, err
}
for _, skill := range items {
if skill.Source != "workspace" {
continue
}
key := filepath.Base(filepath.Dir(skill.Path))
if meta, err := readInstalledSkillOriginMeta(skill.Path); err == nil && meta != nil && meta.Slug != "" {
key = meta.Slug
}
if key == "" {
continue
}
result[key] = skill
}
return result, nil
}
func findWorkspaceSkillByDirectory(cfg *config.Config, directory string) *skillSupportItem {
items, err := buildWorkspaceSkillItemsByDirectory(cfg)
if err != nil {
return nil
}
skill, ok := items[directory]
if !ok {
return nil
}
return &skill
}
func findWorkspaceSkillInfoByDirectory(workspace, directory string) *skills.SkillInfo {
loader := skills.NewSkillsLoader(workspace, "", "")
for _, skill := range loader.ListSkills() {
if skill.Source != "workspace" {
continue
}
if filepath.Base(filepath.Dir(skill.Path)) != directory {
continue
}
skillCopy := skill
return &skillCopy
}
return nil
}
func createStagedSkillInstall(skillsRoot, slug string) (string, string, error) {
stagedWorkspaceRoot, err := os.MkdirTemp(skillsRoot, "."+slug+"-install-*")
if err != nil {
return "", "", err
}
stagedTargetDir := filepath.Join(stagedWorkspaceRoot, "skills", slug)
return stagedWorkspaceRoot, stagedTargetDir, nil
}
func commitStagedSkillInstall(stagedWorkspaceRoot, stagedTargetDir, targetDir string, replaceExisting bool) error {
if !replaceExisting {
return os.Rename(stagedTargetDir, targetDir)
}
backupDir, err := reserveTempDirPath(filepath.Dir(targetDir), "."+filepath.Base(targetDir)+"-backup-*")
if err != nil {
return err
}
if err := os.Rename(targetDir, backupDir); err != nil {
return fmt.Errorf("failed to move existing skill aside: %w", err)
}
if err := os.Rename(stagedTargetDir, targetDir); err != nil {
if rollbackErr := os.Rename(backupDir, targetDir); rollbackErr != nil {
return fmt.Errorf("failed to activate replacement: %w (rollback failed: %v)", err, rollbackErr)
}
return fmt.Errorf("failed to activate replacement: %w", err)
}
_ = os.RemoveAll(backupDir)
_ = os.RemoveAll(stagedWorkspaceRoot)
return nil
}
func reserveTempDirPath(parent, pattern string) (string, error) {
tempDir, err := os.MkdirTemp(parent, pattern)
if err != nil {
return "", err
}
if err := os.Remove(tempDir); err != nil {
return "", err
}
return tempDir, nil
}
func enrichSkillInfo(cfg *config.Config, skill skills.SkillInfo) (skillSupportItem, error) {
item := skillSupportItem{
Name: skill.Name,
Path: skill.Path,
Source: skill.Source,
Description: skill.Description,
OriginKind: "builtin",
}
switch skill.Source {
case "builtin":
item.OriginKind = "builtin"
case "global":
item.OriginKind = "builtin"
case "workspace":
meta, err := readInstalledSkillOriginMeta(skill.Path)
if err == nil && meta != nil {
switch meta.OriginKind {
case "manual":
item.OriginKind = "manual"
item.InstalledAt = meta.InstalledAt
case "third_party":
item.OriginKind = "third_party"
item.RegistryName = meta.Registry
item.RegistryURL = registrySkillURLFromMeta(cfg, meta)
item.InstalledVersion = meta.InstalledVersion
item.InstalledAt = meta.InstalledAt
default:
if meta.Registry != "" || meta.Slug != "" || meta.InstalledVersion != "" {
item.OriginKind = "third_party"
item.RegistryName = meta.Registry
item.RegistryURL = registrySkillURLFromMeta(cfg, meta)
item.InstalledVersion = meta.InstalledVersion
item.InstalledAt = meta.InstalledAt
} else {
item.OriginKind = "builtin"
item.InstalledAt = meta.InstalledAt
}
}
} else {
item.OriginKind = "builtin"
}
default:
item.OriginKind = "builtin"
}
return item, nil
}
func readInstalledSkillOriginMeta(skillPath string) (*installedSkillOriginMeta, error) {
metaPath := filepath.Join(filepath.Dir(skillPath), ".skill-origin.json")
data, err := os.ReadFile(metaPath)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var meta installedSkillOriginMeta
if err := json.Unmarshal(data, &meta); err != nil {
return nil, err
}
return &meta, nil
}
func writeSkillOriginMeta(targetDir string, meta installedSkillOriginMeta) error {
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return err
}
return fileutil.WriteFileAtomic(filepath.Join(targetDir, ".skill-origin.json"), data, 0o600)
}
func registrySkillURL(cfg *config.Config, registryName, slug string) string {
switch registryName {
case "clawhub":
baseURL := strings.TrimRight(cfg.Tools.Skills.Registries.ClawHub.BaseURL, "/")
if baseURL == "" {
baseURL = "https://clawhub.ai"
}
return baseURL + "/skills/" + url.PathEscape(slug)
default:
return ""
}
}
func registrySkillURLFromMeta(cfg *config.Config, meta *installedSkillOriginMeta) string {
if meta == nil || meta.Slug == "" {
return ""
}
if meta.RegistryURL != "" {
return meta.RegistryURL
}
if cfg == nil || meta.Registry == "" {
return ""
}
return registrySkillURL(cfg, meta.Registry, meta.Slug)
}
func normalizeImportedSkillName(filename string, content []byte) (string, error) {
return normalizeImportedSkillNameWithHint(filename, "", content)
}
func normalizeImportedSkillNameWithHint(filename, directoryHint string, content []byte) (string, error) {
rawContent := strings.ReplaceAll(string(content), "\r\n", "\n")
rawContent = strings.ReplaceAll(rawContent, "\r", "\n")
metadata, _ := extractImportedSkillMetadata(rawContent)
raw := strings.TrimSpace(metadata["name"])
if raw == "" {
raw = strings.TrimSpace(directoryHint)
}
if raw == "" {
raw = strings.TrimSpace(strings.TrimSuffix(filepath.Base(filename), filepath.Ext(filename)))
}
@@ -259,6 +834,210 @@ func normalizeImportedSkillContent(content []byte, skillName string) []byte {
return []byte(builder.String())
}
func importUploadedSkill(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
if isImportedSkillArchive(filename, content) {
return importUploadedSkillArchive(cfg, filename, content)
}
return importUploadedMarkdownSkill(cfg, filename, content)
}
func importUploadedMarkdownSkill(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
skillName, err := normalizeImportedSkillName(filename, content)
if err != nil {
return nil, http.StatusBadRequest, err
}
normalizedContent := normalizeImportedSkillContent(content, skillName)
workspace := cfg.WorkspacePath()
skillDir := filepath.Join(workspace, "skills", skillName)
skillFile := filepath.Join(skillDir, "SKILL.md")
if err := ensureWorkspaceSkillDoesNotExist(skillDir); err != nil {
return nil, statusCodeForImportedSkillWriteError(err), err
}
if err := os.MkdirAll(skillDir, 0o755); err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to create skill directory: %v", err)
}
if err := fileutil.WriteFileAtomic(skillFile, normalizedContent, 0o644); err != nil {
_ = os.RemoveAll(skillDir)
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to save skill: %v", err)
}
return finalizeImportedSkill(cfg, skillDir, skillName, false)
}
func importUploadedSkillArchive(cfg *config.Config, filename string, content []byte) (*skillSupportItem, int, error) {
tmpDir, tempDirErr := os.MkdirTemp("", "picoclaw-skill-import-*")
if tempDirErr != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to create temp directory: %v", tempDirErr)
}
defer os.RemoveAll(tmpDir)
archivePath := filepath.Join(tmpDir, "import.zip")
if writeErr := fileutil.WriteFileAtomic(archivePath, content, 0o600); writeErr != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to stage uploaded archive: %v", writeErr)
}
extractDir := filepath.Join(tmpDir, "extract")
if extractErr := utils.ExtractZipFile(archivePath, extractDir); extractErr != nil {
return nil, http.StatusBadRequest, fmt.Errorf("invalid ZIP archive: %w", extractErr)
}
skillRoot, err := findImportedSkillRoot(extractDir)
if err != nil {
return nil, http.StatusBadRequest, err
}
skillFile := filepath.Join(skillRoot, "SKILL.md")
skillContent, err := os.ReadFile(skillFile)
if err != nil {
return nil, http.StatusBadRequest, fmt.Errorf("failed to read SKILL.md from archive: %w", err)
}
directoryHint := ""
if filepath.Clean(skillRoot) != filepath.Clean(extractDir) {
directoryHint = filepath.Base(skillRoot)
}
skillName, err := normalizeImportedSkillNameWithHint(filename, directoryHint, skillContent)
if err != nil {
return nil, http.StatusBadRequest, err
}
workspace := cfg.WorkspacePath()
skillDir := filepath.Join(workspace, "skills", skillName)
if err := ensureWorkspaceSkillDoesNotExist(skillDir); err != nil {
return nil, statusCodeForImportedSkillWriteError(err), err
}
if err := copyImportedSkillTree(skillRoot, skillDir); err != nil {
_ = os.RemoveAll(skillDir)
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to save skill: %v", err)
}
normalizedContent := normalizeImportedSkillContent(skillContent, skillName)
if err := fileutil.WriteFileAtomic(filepath.Join(skillDir, "SKILL.md"), normalizedContent, 0o644); err != nil {
_ = os.RemoveAll(skillDir)
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to normalize skill: %v", err)
}
return finalizeImportedSkill(cfg, skillDir, skillName, true)
}
func isImportedSkillArchive(filename string, content []byte) bool {
if strings.EqualFold(filepath.Ext(filename), ".zip") {
return true
}
return len(content) >= 4 && bytes.HasPrefix(content, []byte("PK\x03\x04"))
}
func ensureWorkspaceSkillDoesNotExist(skillDir string) error {
if _, err := os.Stat(skillDir); err == nil {
return errImportedSkillExists
} else if !os.IsNotExist(err) {
return fmt.Errorf("failed to inspect skill directory: %w", err)
}
return nil
}
func statusCodeForImportedSkillWriteError(err error) int {
if err == nil {
return http.StatusOK
}
if errors.Is(err, errImportedSkillExists) {
return http.StatusConflict
}
return http.StatusInternalServerError
}
func finalizeImportedSkill(
cfg *config.Config,
skillDir string,
skillName string,
requireValidatedSkill bool,
) (*skillSupportItem, int, error) {
if err := persistSkillOriginMeta(skillDir, installedSkillOriginMeta{
Version: 1,
OriginKind: "manual",
InstalledAt: time.Now().UnixMilli(),
}); err != nil {
_ = os.RemoveAll(skillDir)
return nil, http.StatusInternalServerError, fmt.Errorf("Failed to persist skill metadata: %v", err)
}
if importedSkill := findWorkspaceSkillByDirectory(cfg, skillName); importedSkill != nil {
return importedSkill, http.StatusOK, nil
}
if requireValidatedSkill {
_ = os.RemoveAll(skillDir)
return nil, http.StatusBadRequest, fmt.Errorf("imported archive is not a valid skill")
}
return &skillSupportItem{
Name: skillName,
Path: filepath.Join(skillDir, "SKILL.md"),
Source: "workspace",
Description: "Imported skill",
OriginKind: "manual",
}, http.StatusOK, nil
}
func findImportedSkillRoot(extractDir string) (string, error) {
skillFiles := make([]string, 0, 1)
err := filepath.WalkDir(extractDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
if d.Name() == "SKILL.md" {
skillFiles = append(skillFiles, path)
}
return nil
})
if err != nil {
return "", fmt.Errorf("failed to inspect ZIP archive: %w", err)
}
switch len(skillFiles) {
case 0:
return "", fmt.Errorf("ZIP archive must contain a SKILL.md file")
case 1:
return filepath.Dir(skillFiles[0]), nil
default:
return "", fmt.Errorf("ZIP archive must contain exactly one SKILL.md file")
}
}
func copyImportedSkillTree(srcDir, destDir string) error {
return filepath.WalkDir(srcDir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
relPath, err := filepath.Rel(srcDir, path)
if err != nil {
return err
}
if relPath == "." {
return os.MkdirAll(destDir, 0o755)
}
destPath := filepath.Join(destDir, relPath)
info, err := d.Info()
if err != nil {
return err
}
if d.IsDir() {
return os.MkdirAll(destPath, 0o755)
}
if !info.Mode().IsRegular() {
return fmt.Errorf("archive contains unsupported file %q", relPath)
}
return fileutil.CopyFile(path, destPath, info.Mode().Perm())
})
}
func extractImportedSkillMetadata(raw string) (map[string]string, string) {
matches := importedSkillFrontmatter.FindStringSubmatch(raw)
if len(matches) != 2 {
File diff suppressed because it is too large Load Diff
+52
View File
@@ -0,0 +1,52 @@
package api
import (
"encoding/json"
"net/http"
"github.com/sipeed/picoclaw/pkg/updater"
)
// registerUpdateRoutes registers the self-update endpoint.
func (h *Handler) registerUpdateRoutes(mux *http.ServeMux) {
mux.HandleFunc("/api/update", h.handleUpdate)
}
type updateRequest struct {
URL string `json:"url,omitempty"`
Binary string `json:"binary,omitempty"`
}
type updateResponse struct {
Status string `json:"status"`
Message string `json:"message,omitempty"`
}
func (h *Handler) handleUpdate(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
_ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: "method not allowed"})
return
}
dec := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20))
var req updateRequest
if err := dec.Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: "invalid request body"})
return
}
binary := req.Binary
if binary == "" {
binary = "picoclaw-launcher"
}
if err := updater.UpdateSelfFromRelease(req.URL, "", "", binary); err != nil {
w.WriteHeader(http.StatusInternalServerError)
_ = json.NewEncoder(w).Encode(updateResponse{Status: "error", Message: err.Error()})
return
}
_ = json.NewEncoder(w).Encode(updateResponse{Status: "ok", Message: "update applied; restart to use new version"})
}