mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
merge: integrate main into refactor-inbound-context-routing-session
This commit is contained in:
+187
-28
@@ -1,8 +1,10 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -10,34 +12,47 @@ import (
|
||||
"github.com/sipeed/picoclaw/web/backend/middleware"
|
||||
)
|
||||
|
||||
// LauncherAuthRouteOpts configures dashboard token login handlers.
|
||||
// PasswordStore is the interface for bcrypt-backed dashboard password persistence.
|
||||
// Implemented by dashboardauth.Store; a nil value falls back to the legacy
|
||||
// static-token comparison.
|
||||
type PasswordStore interface {
|
||||
IsInitialized(ctx context.Context) (bool, error)
|
||||
SetPassword(ctx context.Context, plain string) error
|
||||
VerifyPassword(ctx context.Context, plain string) (bool, error)
|
||||
}
|
||||
|
||||
// LauncherAuthRouteOpts configures dashboard auth handlers.
|
||||
type LauncherAuthRouteOpts struct {
|
||||
// DashboardToken is the fallback plaintext token used when PasswordStore is
|
||||
// nil or not yet initialized (env-var / config-file source, and ?token= auto-login).
|
||||
DashboardToken string
|
||||
SessionCookie string
|
||||
SecureCookie func(*http.Request) bool
|
||||
// TokenHelp is returned on unauthenticated /api/auth/status responses (no secrets).
|
||||
TokenHelp LauncherAuthTokenHelp
|
||||
}
|
||||
|
||||
// LauncherAuthTokenHelp tells the login UI where users can find the dashboard token.
|
||||
type LauncherAuthTokenHelp struct {
|
||||
EnvVarName string `json:"env_var_name"`
|
||||
LogFileAbs string `json:"log_file,omitempty"`
|
||||
ConfigFileAbs string `json:"config_file,omitempty"`
|
||||
TrayCopyMenu bool `json:"tray_copy_menu"`
|
||||
ConsoleStdout bool `json:"console_stdout"`
|
||||
// PasswordStore enables bcrypt-backed password persistence. When non-nil and
|
||||
// initialized, web-form login verifies against the stored hash instead of
|
||||
// the plaintext DashboardToken.
|
||||
PasswordStore PasswordStore
|
||||
// StoreError holds the error returned when opening the password store. When
|
||||
// non-nil and PasswordStore is nil, the auth endpoints surface a recovery
|
||||
// message instead of an opaque 501/503.
|
||||
StoreError error
|
||||
}
|
||||
|
||||
type launcherAuthLoginBody struct {
|
||||
Token string `json:"token"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type launcherAuthSetupBody struct {
|
||||
Password string `json:"password"`
|
||||
Confirm string `json:"confirm"`
|
||||
}
|
||||
|
||||
type launcherAuthStatusResponse struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
TokenHelp *LauncherAuthTokenHelp `json:"token_help,omitempty"`
|
||||
Authenticated bool `json:"authenticated"`
|
||||
Initialized bool `json:"initialized"`
|
||||
}
|
||||
|
||||
// RegisterLauncherAuthRoutes registers /api/auth/login|logout|status.
|
||||
// RegisterLauncherAuthRoutes registers /api/auth/login|logout|status|setup.
|
||||
func RegisterLauncherAuthRoutes(mux *http.ServeMux, opts LauncherAuthRouteOpts) {
|
||||
secure := opts.SecureCookie
|
||||
if secure == nil {
|
||||
@@ -47,22 +62,52 @@ func RegisterLauncherAuthRoutes(mux *http.ServeMux, opts LauncherAuthRouteOpts)
|
||||
token: opts.DashboardToken,
|
||||
sessionCookie: opts.SessionCookie,
|
||||
secureCookie: secure,
|
||||
tokenHelp: opts.TokenHelp,
|
||||
store: opts.PasswordStore,
|
||||
storeErr: opts.StoreError,
|
||||
loginLimit: newLoginRateLimiter(),
|
||||
}
|
||||
mux.HandleFunc("POST /api/auth/login", h.handleLogin)
|
||||
mux.HandleFunc("POST /api/auth/logout", h.handleLogout)
|
||||
mux.HandleFunc("GET /api/auth/status", h.handleStatus)
|
||||
mux.HandleFunc("POST /api/auth/setup", h.handleSetup)
|
||||
}
|
||||
|
||||
type launcherAuthHandlers struct {
|
||||
token string
|
||||
sessionCookie string
|
||||
secureCookie func(*http.Request) bool
|
||||
tokenHelp LauncherAuthTokenHelp
|
||||
store PasswordStore
|
||||
storeErr error // set when the store failed to open; drives recovery messages
|
||||
loginLimit *loginRateLimiter
|
||||
}
|
||||
|
||||
func (h *launcherAuthHandlers) usesLegacyTokenAuth() bool {
|
||||
return h.store == nil && h.storeErr == nil && h.token != ""
|
||||
}
|
||||
|
||||
// isStoreInitialized safely queries the store.
|
||||
// Returns (true, nil) when legacy token auth is active without a password store.
|
||||
// Returns (false, nil) when no store/token fallback is configured.
|
||||
// Returns (false, err) on store errors — callers must treat this as a 5xx, not as
|
||||
// "uninitialized", to keep auth fail-closed.
|
||||
// Exception: handleLogin swallows storeErr and falls back to token auth so
|
||||
// that a corrupt DB does not lock out all access.
|
||||
func (h *launcherAuthHandlers) isStoreInitialized(ctx context.Context) (bool, error) {
|
||||
if h.store == nil {
|
||||
if h.storeErr != nil {
|
||||
return false, fmt.Errorf(
|
||||
"password store unavailable (%w); "+
|
||||
"to recover, stop the application, delete the database file and restart ",
|
||||
h.storeErr)
|
||||
}
|
||||
if h.usesLegacyTokenAuth() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
return h.store.IsInitialized(ctx)
|
||||
}
|
||||
|
||||
func (h *launcherAuthHandlers) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
var body launcherAuthLoginBody
|
||||
@@ -77,10 +122,39 @@ func (h *launcherAuthHandlers) handleLogin(w http.ResponseWriter, r *http.Reques
|
||||
_, _ = w.Write([]byte(`{"error":"too many login attempts"}`))
|
||||
return
|
||||
}
|
||||
in := strings.TrimSpace(body.Token)
|
||||
if len(in) != len(h.token) || subtle.ConstantTimeCompare([]byte(in), []byte(h.token)) != 1 {
|
||||
in := strings.TrimSpace(body.Password)
|
||||
var ok bool
|
||||
|
||||
initialized, initErr := h.isStoreInitialized(r.Context())
|
||||
if initErr != nil {
|
||||
if h.storeErr != nil {
|
||||
// Store failed to open at startup — token login remains available.
|
||||
initialized = false
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
writeErrorf(w, "%v", initErr)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if initialized && h.store != nil {
|
||||
// Bcrypt path: verify against the stored hash.
|
||||
var err error
|
||||
ok, err = h.store.VerifyPassword(r.Context(), in)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
writeErrorf(w, "password verification failed: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Fallback: constant-time compare against the plaintext token.
|
||||
ok = len(in) == len(h.token) &&
|
||||
subtle.ConstantTimeCompare([]byte(in), []byte(h.token)) == 1
|
||||
}
|
||||
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid token"}`))
|
||||
_, _ = w.Write([]byte(`{"error":"invalid password"}`))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -121,23 +195,108 @@ func (h *launcherAuthHandlers) handleLogout(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
func (h *launcherAuthHandlers) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
ok := false
|
||||
authed := false
|
||||
if c, err := r.Cookie(middleware.LauncherDashboardCookieName); err == nil {
|
||||
ok = subtle.ConstantTimeCompare([]byte(c.Value), []byte(h.sessionCookie)) == 1
|
||||
authed = subtle.ConstantTimeCompare([]byte(c.Value), []byte(h.sessionCookie)) == 1
|
||||
}
|
||||
if ok {
|
||||
_, _ = w.Write([]byte(`{"authenticated":true}`))
|
||||
initialized, initErr := h.isStoreInitialized(r.Context())
|
||||
if initErr != nil {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
writeErrorf(w, "%v", initErr)
|
||||
return
|
||||
}
|
||||
resp := launcherAuthStatusResponse{
|
||||
Authenticated: false,
|
||||
TokenHelp: &h.tokenHelp,
|
||||
Authenticated: authed,
|
||||
Initialized: initialized,
|
||||
}
|
||||
enc, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"error":"internal error"}`))
|
||||
writeErrorf(w, "marshal response failed: %v", err)
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(enc)
|
||||
}
|
||||
|
||||
// handleSetup sets or changes the dashboard password.
|
||||
//
|
||||
// Rules:
|
||||
// - If the store has no password yet, the endpoint is open (no session required).
|
||||
// - If a password is already set, the caller must hold a valid session cookie.
|
||||
func (h *launcherAuthHandlers) handleSetup(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if h.usesLegacyTokenAuth() {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
_, _ = w.Write(
|
||||
[]byte(`{"error":"password setup is unavailable on this platform; use the dashboard token instead"}`),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if h.store == nil {
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
_, _ = w.Write([]byte(`{"error":"password store not configured"}`))
|
||||
return
|
||||
}
|
||||
|
||||
initialized, initErr := h.isStoreInitialized(r.Context())
|
||||
if initErr != nil {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
writeErrorf(w, "%v", initErr)
|
||||
return
|
||||
}
|
||||
|
||||
// If already initialized, require an active session (change-password flow).
|
||||
if initialized {
|
||||
authed := false
|
||||
if c, err := r.Cookie(middleware.LauncherDashboardCookieName); err == nil {
|
||||
authed = subtle.ConstantTimeCompare([]byte(c.Value), []byte(h.sessionCookie)) == 1
|
||||
}
|
||||
if !authed {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"must be authenticated to change password"}`))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var body launcherAuthSetupBody
|
||||
if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20)).Decode(&body); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid JSON"}`))
|
||||
return
|
||||
}
|
||||
|
||||
pw := strings.TrimSpace(body.Password)
|
||||
if pw == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"password must not be empty"}`))
|
||||
return
|
||||
}
|
||||
if pw != strings.TrimSpace(body.Confirm) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"passwords do not match"}`))
|
||||
return
|
||||
}
|
||||
if len([]rune(pw)) < 8 {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte(`{"error":"password must be at least 8 characters"}`))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.SetPassword(r.Context(), pw); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
writeErrorf(w, "failed to save password: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"status":"ok"}`))
|
||||
}
|
||||
|
||||
// writeErrorf writes a JSON error response with a formatted message.
|
||||
// json.Marshal is used to safely escape the message string.
|
||||
func writeErrorf(w http.ResponseWriter, format string, args ...any) {
|
||||
msg, _ := json.Marshal(fmt.Sprintf(format, args...))
|
||||
_, _ = w.Write([]byte(`{"error":` + string(msg) + `}`))
|
||||
}
|
||||
|
||||
@@ -23,12 +23,6 @@ func TestLauncherAuthLoginAndStatus(t *testing.T) {
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: tok,
|
||||
SessionCookie: sess,
|
||||
TokenHelp: LauncherAuthTokenHelp{
|
||||
EnvVarName: "PICOCLAW_LAUNCHER_TOKEN",
|
||||
LogFileAbs: "/tmp/launcher.log",
|
||||
TrayCopyMenu: true,
|
||||
ConsoleStdout: false,
|
||||
},
|
||||
})
|
||||
|
||||
t.Run("status_unauthenticated", func(t *testing.T) {
|
||||
@@ -38,23 +32,20 @@ func TestLauncherAuthLoginAndStatus(t *testing.T) {
|
||||
t.Fatalf("status code = %d", rec.Code)
|
||||
}
|
||||
var body struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
TokenHelp *LauncherAuthTokenHelp `json:"token_help"`
|
||||
Authenticated bool `json:"authenticated"`
|
||||
Initialized bool `json:"initialized"`
|
||||
}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if body.Authenticated || body.TokenHelp == nil {
|
||||
t.Fatalf("unexpected body: %+v", body)
|
||||
}
|
||||
if body.TokenHelp.EnvVarName != "PICOCLAW_LAUNCHER_TOKEN" || body.TokenHelp.LogFileAbs != "/tmp/launcher.log" {
|
||||
t.Fatalf("token_help = %+v", body.TokenHelp)
|
||||
if body.Authenticated {
|
||||
t.Fatalf("unexpected authenticated=true: %+v", body)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("login_ok", func(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"token":"`+tok+`"}`))
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"password":"`+tok+`"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
mux.ServeHTTP(rec, req)
|
||||
@@ -84,6 +75,67 @@ func TestLauncherAuthLoginAndStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLauncherAuthLegacyTokenFallbackReportsInitialized(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
const tok = "legacy-fallback-token"
|
||||
sess := middleware.SessionCookieValue(key, tok)
|
||||
mux := http.NewServeMux()
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: tok,
|
||||
SessionCookie: sess,
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/auth/status", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status code = %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
Initialized bool `json:"initialized"`
|
||||
}
|
||||
if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !body.Initialized {
|
||||
t.Fatalf("initialized = false, want true in legacy token fallback mode")
|
||||
}
|
||||
if body.Authenticated {
|
||||
t.Fatalf("unexpected authenticated=true: %+v", body)
|
||||
}
|
||||
|
||||
rec = httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{"password":"`+tok+`"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("login code = %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLauncherAuthSetupRejectedInLegacyTokenFallback(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
sess := middleware.SessionCookieValue(key, "legacy-token")
|
||||
mux := http.NewServeMux()
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: "legacy-token",
|
||||
SessionCookie: sess,
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodPost,
|
||||
"/api/auth/setup",
|
||||
strings.NewReader(`{"password":"12345678","confirm":"12345678"}`),
|
||||
)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusNotImplemented {
|
||||
t.Fatalf("setup code = %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLauncherAuthLogoutRequiresPostAndJSON(t *testing.T) {
|
||||
key := make([]byte, 32)
|
||||
sess := middleware.SessionCookieValue(key, "tok")
|
||||
@@ -91,7 +143,6 @@ func TestLauncherAuthLogoutRequiresPostAndJSON(t *testing.T) {
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: "tok",
|
||||
SessionCookie: sess,
|
||||
TokenHelp: LauncherAuthTokenHelp{EnvVarName: "PICOCLAW_LAUNCHER_TOKEN"},
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -125,11 +176,10 @@ func TestLauncherAuthLoginRateLimit(t *testing.T) {
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: tok,
|
||||
SessionCookie: sess,
|
||||
TokenHelp: LauncherAuthTokenHelp{EnvVarName: "X"},
|
||||
})
|
||||
|
||||
// 11 failing logins by wrong token; each consumes allow() slot after valid JSON.
|
||||
wrongBody := `{"token":"wrong"}`
|
||||
wrongBody := `{"password":"wrong"}`
|
||||
for i := 0; i < loginAttemptsPerIP; i++ {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(wrongBody))
|
||||
@@ -187,7 +237,6 @@ func TestLauncherAuthLogoutEmptyBody(t *testing.T) {
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: "tok",
|
||||
SessionCookie: sess,
|
||||
TokenHelp: LauncherAuthTokenHelp{EnvVarName: "X"},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/logout", nil)
|
||||
@@ -206,7 +255,6 @@ func TestLauncherAuthLogoutRejectsTrailingJSON(t *testing.T) {
|
||||
RegisterLauncherAuthRoutes(mux, LauncherAuthRouteOpts{
|
||||
DashboardToken: "tok",
|
||||
SessionCookie: sess,
|
||||
TokenHelp: LauncherAuthTokenHelp{EnvVarName: "X"},
|
||||
})
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/logout", strings.NewReader(`{}{}`))
|
||||
|
||||
+168
-4
@@ -108,6 +108,8 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
var gatewayProcessMatcher = isLikelyGatewayProcess
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response.
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
@@ -117,7 +119,7 @@ func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*
|
||||
gateway.mu.Lock()
|
||||
if d := gateway.pidData; d != nil && d.Port > 0 {
|
||||
port = d.Port
|
||||
host = d.Host
|
||||
host = gatewayProbeHost(d.Host)
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
if port == 0 {
|
||||
@@ -150,6 +152,150 @@ func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusRes
|
||||
return &healthResponse, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// isLikelyGatewayProcess returns whether PID appears to be a picoclaw gateway
|
||||
// process plus whether inspection was conclusive on this platform/environment.
|
||||
func isLikelyGatewayProcess(pid int) (bool, bool) {
|
||||
if pid <= 0 {
|
||||
return false, true
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
psCmd := fmt.Sprintf(
|
||||
`$p=Get-CimInstance Win32_Process -Filter "ProcessId = %d"; if ($null -eq $p) { "" } else { $p.CommandLine }`,
|
||||
pid,
|
||||
)
|
||||
out, err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", psCmd).Output()
|
||||
if err == nil {
|
||||
cmdline := strings.TrimSpace(string(out))
|
||||
if cmdline != "" {
|
||||
return looksLikeGatewayCommandLine(cmdline), true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: determine only whether the process still exists.
|
||||
out, err = exec.Command("tasklist", "/FI", "PID eq "+strconv.Itoa(pid), "/FO", "CSV", "/NH").Output()
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
line := strings.ToLower(strings.TrimSpace(string(out)))
|
||||
if line == "" {
|
||||
return false, true
|
||||
}
|
||||
// A CSV row means the process exists, but may have a custom executable
|
||||
// name we cannot classify here.
|
||||
if strings.HasPrefix(line, "\"") {
|
||||
if strings.Contains(line, "\"picoclaw.exe\"") {
|
||||
return true, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
if strings.Contains(line, "no tasks are running") {
|
||||
return false, true
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
|
||||
out, err := exec.Command("ps", "-o", "command=", "-p", strconv.Itoa(pid)).Output()
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
cmdline := strings.ToLower(strings.TrimSpace(string(out)))
|
||||
if cmdline == "" {
|
||||
return false, true
|
||||
}
|
||||
return looksLikeGatewayCommandLine(cmdline), true
|
||||
}
|
||||
|
||||
// looksLikeGatewayCommandLine checks whether a process command line likely
|
||||
// represents "picoclaw gateway ..." regardless of executable filename.
|
||||
func looksLikeGatewayCommandLine(cmdline string) bool {
|
||||
fields := strings.Fields(strings.ToLower(strings.TrimSpace(cmdline)))
|
||||
if len(fields) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, f := range fields {
|
||||
token := strings.Trim(f, `"'`)
|
||||
if token == "gateway" || strings.HasSuffix(token, "/gateway") || strings.HasSuffix(token, `\gateway`) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handler) getGatewayHealthForPidData(
|
||||
pidData *ppid.PidFileData,
|
||||
cfg *config.Config,
|
||||
timeout time.Duration,
|
||||
) (*health.StatusResponse, int, error) {
|
||||
if pidData == nil {
|
||||
return nil, 0, errors.New("nil pid data")
|
||||
}
|
||||
|
||||
port := pidData.Port
|
||||
if port == 0 {
|
||||
port = 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
}
|
||||
|
||||
host := gatewayProbeHost(strings.TrimSpace(pidData.Host))
|
||||
if host == "" {
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
}
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
|
||||
url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health"
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
|
||||
func (h *Handler) validateGatewayPidData(
|
||||
pidData *ppid.PidFileData,
|
||||
cfg *config.Config,
|
||||
) (ok bool, decisive bool, reason string) {
|
||||
if pidData == nil || pidData.PID <= 0 {
|
||||
return false, true, "invalid pid data"
|
||||
}
|
||||
|
||||
if gatewayProcess, inspected := gatewayProcessMatcher(pidData.PID); inspected {
|
||||
if !gatewayProcess {
|
||||
return false, true, "pid process command is not picoclaw gateway"
|
||||
}
|
||||
return true, true, ""
|
||||
}
|
||||
|
||||
healthResp, statusCode, err := h.getGatewayHealthForPidData(pidData, cfg, 800*time.Millisecond)
|
||||
if err != nil {
|
||||
return false, false, fmt.Sprintf("health probe failed: %v", err)
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return false, false, fmt.Sprintf("health endpoint returned status %d", statusCode)
|
||||
}
|
||||
if healthResp.PID > 0 && healthResp.PID != pidData.PID {
|
||||
return false, true, fmt.Sprintf("health pid mismatch: pidFile=%d, health=%d", pidData.PID, healthResp.PID)
|
||||
}
|
||||
return true, true, ""
|
||||
}
|
||||
|
||||
func (h *Handler) sanitizeGatewayPidData(pidData *ppid.PidFileData, cfg *config.Config) *ppid.PidFileData {
|
||||
if pidData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, decisive, reason := h.validateGatewayPidData(pidData, cfg)
|
||||
if ok {
|
||||
return pidData
|
||||
}
|
||||
|
||||
logger.Warnf("ignore pid file for PID %d: %s", pidData.PID, reason)
|
||||
if decisive && ppid.RemovePidFileIfPID(globalConfigDir(), pidData.PID) {
|
||||
logger.Warnf("removed stale pid file for PID %d", pidData.PID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
@@ -164,7 +310,7 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
// starts it when possible. Intended to be called by the backend at startup.
|
||||
func (h *Handler) TryAutoStartGateway() {
|
||||
// Check PID file first to detect an already-running gateway.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil)
|
||||
if pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
@@ -472,6 +618,11 @@ func stopGatewayLocked() (int, error) {
|
||||
}
|
||||
|
||||
pid := gateway.cmd.Process.Pid
|
||||
if !gateway.owned {
|
||||
if isGateway, inspected := gatewayProcessMatcher(pid); inspected && !isGateway {
|
||||
return pid, fmt.Errorf("refuse to stop non-gateway process (PID %d)", pid)
|
||||
}
|
||||
}
|
||||
|
||||
// Send SIGTERM for graceful shutdown (SIGKILL on Windows)
|
||||
var sigErr error
|
||||
@@ -681,7 +832,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
// POST /api/gateway/start
|
||||
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Check PID file first to detect an already-running gateway.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil)
|
||||
if pidData != nil {
|
||||
pid := pidData.PID
|
||||
gateway.mu.Lock()
|
||||
@@ -807,9 +958,22 @@ func (h *Handler) RestartGateway() (int, error) {
|
||||
|
||||
gateway.mu.Lock()
|
||||
previousCmd := gateway.cmd
|
||||
previousOwned := gateway.owned
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if previousCmd != nil && previousCmd.Process != nil && !previousOwned {
|
||||
if isGateway, inspected := gatewayProcessMatcher(previousCmd.Process.Pid); inspected && !isGateway {
|
||||
logger.Warnf("refuse restarting non-gateway process (PID: %d)", previousCmd.Process.Pid)
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == previousCmd {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return 0, fmt.Errorf("refuse to restart non-gateway process (PID %d)", previousCmd.Process.Pid)
|
||||
}
|
||||
}
|
||||
|
||||
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == previousCmd {
|
||||
@@ -921,7 +1085,7 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
}
|
||||
|
||||
// Primary detection: read PID file and check if process is alive.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), cfg)
|
||||
if pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = pidData
|
||||
|
||||
+234
-12
@@ -15,8 +15,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
@@ -40,6 +38,36 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func startGatewayLikeProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("gateway-like process commandline check is not deterministic on Windows tests")
|
||||
}
|
||||
cmd = exec.Command("sh", "-c", "sleep 30 # picoclaw gateway")
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func writeTestPidFile(t *testing.T, data ppid.PidFileData) string {
|
||||
t.Helper()
|
||||
|
||||
path := filepath.Join(globalConfigDir(), ".picoclaw.pid")
|
||||
raw, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal pid file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, raw, 0o600); err != nil {
|
||||
t.Fatalf("write pid file: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
@@ -68,12 +96,14 @@ func resetGatewayTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
originalProcessMatcher := gatewayProcessMatcher
|
||||
originalRestartGracePeriod := gatewayRestartGracePeriod
|
||||
originalRestartForceKillWindow := gatewayRestartForceKillWindow
|
||||
originalRestartPollInterval := gatewayRestartPollInterval
|
||||
t.Setenv("PICOCLAW_HOME", t.TempDir())
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
gatewayProcessMatcher = originalProcessMatcher
|
||||
gatewayRestartGracePeriod = originalRestartGracePeriod
|
||||
gatewayRestartForceKillWindow = originalRestartForceKillWindow
|
||||
gatewayRestartPollInterval = originalRestartPollInterval
|
||||
@@ -105,6 +135,105 @@ func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeGatewayCommandLine(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cmdline string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "default picoclaw gateway",
|
||||
cmdline: "/usr/local/bin/picoclaw gateway -E",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "renamed binary with gateway subcommand",
|
||||
cmdline: "/opt/bin/custom-claw gateway -E -d",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "standalone gateway binary path",
|
||||
cmdline: "/opt/bin/gateway -E",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non gateway process",
|
||||
cmdline: "/bin/sleep 30",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "gateway substring only",
|
||||
cmdline: "/opt/bin/gatewayd --serve",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := looksLikeGatewayCommandLine(tc.cmdline)
|
||||
if got != tc.want {
|
||||
t.Fatalf("looksLikeGatewayCommandLine(%q) = %v, want %v", tc.cmdline, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGatewayPidDataAcceptsHealthWhenMatcherInconclusive(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
const testPID = 34567
|
||||
pidData := &ppid.PidFileData{
|
||||
PID: testPID,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
}
|
||||
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return false, false }
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, testPID), nil
|
||||
}
|
||||
|
||||
ok, decisive, reason := h.validateGatewayPidData(pidData, nil)
|
||||
if !ok {
|
||||
t.Fatalf("validateGatewayPidData() ok = false, want true (reason=%q)", reason)
|
||||
}
|
||||
if !decisive {
|
||||
t.Fatalf("validateGatewayPidData() decisive = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGatewayPidDataRejectsHealthPidMismatchWhenMatcherInconclusive(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
pidData := &ppid.PidFileData{
|
||||
PID: 34567,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
}
|
||||
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return false, false }
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, 99999), nil
|
||||
}
|
||||
|
||||
ok, decisive, reason := h.validateGatewayPidData(pidData, nil)
|
||||
if ok {
|
||||
t.Fatalf("validateGatewayPidData() ok = true, want false")
|
||||
}
|
||||
if !decisive {
|
||||
t.Fatalf("validateGatewayPidData() decisive = false, want true")
|
||||
}
|
||||
if !strings.Contains(reason, "health pid mismatch") {
|
||||
t.Fatalf("validateGatewayPidData() reason = %q, want contains %q", reason, "health pid mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_InvalidDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
@@ -533,7 +662,7 @@ func TestGatewayStatusDowngradesRunningWhenTrackedProcessExitedAndPidFileMissing
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
func TestGatewayStatusIgnoresAndRemovesPidFileForNonGatewayProcess(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
@@ -549,6 +678,87 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
pidPath := writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "stale-token",
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if got := body["gateway_status"]; got != "stopped" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "stopped")
|
||||
}
|
||||
if _, err := os.Stat(pidPath); !os.IsNotExist(err) {
|
||||
t.Fatal("stale pid file should be removed for non-gateway process")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStopRefusesNonGatewayAttachedProcess(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("commandline-based process type check is best-effort on Windows")
|
||||
}
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.owned = false
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/gateway/stop", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
if !isCmdProcessAliveLocked(cmd) {
|
||||
t.Fatal("non-gateway process should not be terminated by /api/gateway/stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
@@ -557,8 +767,12 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
_, err := ppid.WritePidFile(globalConfigDir(), "localhost", 0)
|
||||
require.NoError(t, err)
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
@@ -583,6 +797,7 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
@@ -601,16 +816,23 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
_, err = ppid.WritePidFile(globalConfigDir(), "localhost", 0)
|
||||
require.NoError(t, err)
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
})
|
||||
|
||||
bootSignature := computeConfigSignature(cfg)
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
gateway.bootConfigSignature = bootSignature
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
@@ -64,7 +64,7 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
|
||||
gatewayAvailable := false
|
||||
// Prefer fresh PID file data when available.
|
||||
if pidData := ppid.ReadPidFileWithCheck(globalConfigDir()); pidData != nil {
|
||||
if pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil); pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = pidData
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
@@ -308,6 +308,10 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
@@ -339,9 +343,19 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
})
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
t.Cleanup(func() {
|
||||
@@ -392,6 +406,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
@@ -416,9 +434,19 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
@@ -450,6 +478,10 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
@@ -475,10 +507,20 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
pidData, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
pidData := ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
}
|
||||
writeTestPidFile(t, pidData)
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
|
||||
+93
-11
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// registerSessionRoutes binds session list and detail endpoints to the ServeMux.
|
||||
@@ -64,6 +65,11 @@ const (
|
||||
handledToolResponseSummaryText = "Requested output delivered via tool attachment."
|
||||
)
|
||||
|
||||
func defaultToolFeedbackMaxArgsLength() int {
|
||||
defaults := config.AgentDefaults{}
|
||||
return defaults.GetToolFeedbackMaxArgsLength()
|
||||
}
|
||||
|
||||
// extractLegacyPicoSessionID extracts the session UUID from an old Pico key.
|
||||
// Returns the UUID and true if the key matches the Pico session pattern.
|
||||
func extractLegacyPicoSessionID(key string) (string, bool) {
|
||||
@@ -391,7 +397,7 @@ func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessio
|
||||
return picoLegacySessionRef{}, os.ErrNotExist
|
||||
}
|
||||
|
||||
func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem {
|
||||
func buildSessionListItem(sessionID string, sess sessionFile, toolFeedbackMaxArgsLength int) sessionListItem {
|
||||
preview := ""
|
||||
for _, msg := range sess.Messages {
|
||||
if msg.Role == "user" {
|
||||
@@ -408,7 +414,7 @@ func buildSessionListItem(sessionID string, sess sessionFile) sessionListItem {
|
||||
}
|
||||
title := preview
|
||||
|
||||
validMessageCount := len(visibleSessionMessages(sess.Messages))
|
||||
validMessageCount := len(visibleSessionMessages(sess.Messages, toolFeedbackMaxArgsLength))
|
||||
|
||||
return sessionListItem{
|
||||
ID: sessionID,
|
||||
@@ -449,7 +455,7 @@ func sessionMessagePreview(msg providers.Message) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
|
||||
func visibleSessionMessages(messages []providers.Message, toolFeedbackMaxArgsLength int) []sessionChatMessage {
|
||||
transcript := make([]sessionChatMessage, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
@@ -464,6 +470,17 @@ func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
|
||||
}
|
||||
|
||||
case "assistant":
|
||||
// Reasoning-only assistant messages are transient display artifacts and
|
||||
// should not be restored from session history.
|
||||
if assistantMessageTransientThought(msg) {
|
||||
continue
|
||||
}
|
||||
|
||||
toolSummaryMessages := visibleAssistantToolSummaryMessages(msg.ToolCalls, toolFeedbackMaxArgsLength)
|
||||
if len(toolSummaryMessages) > 0 {
|
||||
transcript = append(transcript, toolSummaryMessages...)
|
||||
}
|
||||
|
||||
visibleToolMessages := visibleAssistantToolMessages(msg.ToolCalls)
|
||||
if len(visibleToolMessages) > 0 {
|
||||
transcript = append(transcript, visibleToolMessages...)
|
||||
@@ -472,7 +489,7 @@ func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
|
||||
// Pico web chat can persist both visible `message` tool output and a
|
||||
// later plain assistant reply in the same turn. Hide only the fixed
|
||||
// internal summary that marks handled tool delivery.
|
||||
if len(visibleToolMessages) > 0 || !sessionMessageVisible(msg) || assistantMessageInternalOnly(msg) {
|
||||
if !sessionMessageVisible(msg) || assistantMessageInternalOnly(msg) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -487,10 +504,63 @@ func visibleSessionMessages(messages []providers.Message) []sessionChatMessage {
|
||||
return transcript
|
||||
}
|
||||
|
||||
func assistantMessageTransientThought(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) == "" &&
|
||||
strings.TrimSpace(msg.ReasoningContent) != "" &&
|
||||
len(msg.ToolCalls) == 0 &&
|
||||
len(msg.Media) == 0
|
||||
}
|
||||
|
||||
func assistantMessageInternalOnly(msg providers.Message) bool {
|
||||
return strings.TrimSpace(msg.Content) == handledToolResponseSummaryText
|
||||
}
|
||||
|
||||
func visibleAssistantToolSummaryMessages(
|
||||
toolCalls []providers.ToolCall,
|
||||
toolFeedbackMaxArgsLength int,
|
||||
) []sessionChatMessage {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
if toolFeedbackMaxArgsLength <= 0 {
|
||||
toolFeedbackMaxArgsLength = defaultToolFeedbackMaxArgsLength()
|
||||
}
|
||||
|
||||
messages := make([]sessionChatMessage, 0, len(toolCalls))
|
||||
for _, tc := range toolCalls {
|
||||
name := tc.Name
|
||||
argsJSON := ""
|
||||
if tc.Function != nil {
|
||||
if name == "" {
|
||||
name = tc.Function.Name
|
||||
}
|
||||
argsJSON = tc.Function.Arguments
|
||||
}
|
||||
|
||||
if strings.TrimSpace(name) == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.TrimSpace(argsJSON) == "" && len(tc.Arguments) > 0 {
|
||||
if encodedArgs, err := json.Marshal(tc.Arguments); err == nil {
|
||||
argsJSON = string(encodedArgs)
|
||||
}
|
||||
}
|
||||
|
||||
argsPreview := strings.TrimSpace(argsJSON)
|
||||
if argsPreview == "" {
|
||||
argsPreview = "{}"
|
||||
}
|
||||
|
||||
messages = append(messages, sessionChatMessage{
|
||||
Role: "assistant",
|
||||
Content: utils.FormatToolFeedbackMessage(name, utils.Truncate(argsPreview, toolFeedbackMaxArgsLength)),
|
||||
})
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func visibleAssistantToolMessages(toolCalls []providers.ToolCall) []sessionChatMessage {
|
||||
if len(toolCalls) == 0 {
|
||||
return nil
|
||||
@@ -536,7 +606,19 @@ func (h *Handler) sessionsDir() (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
workspace := cfg.Agents.Defaults.Workspace
|
||||
return resolveSessionsDir(cfg.Agents.Defaults.Workspace), nil
|
||||
}
|
||||
|
||||
func (h *Handler) sessionRuntimeSettings() (string, int, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
return resolveSessionsDir(cfg.Agents.Defaults.Workspace), cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(), nil
|
||||
}
|
||||
|
||||
func resolveSessionsDir(workspace string) string {
|
||||
if workspace == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
workspace = filepath.Join(home, ".picoclaw", "workspace")
|
||||
@@ -552,14 +634,14 @@ func (h *Handler) sessionsDir() (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return filepath.Join(workspace, "sessions"), nil
|
||||
return filepath.Join(workspace, "sessions")
|
||||
}
|
||||
|
||||
// handleListSessions returns a list of Pico session summaries.
|
||||
//
|
||||
// GET /api/sessions
|
||||
func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
dir, err := h.sessionsDir()
|
||||
dir, toolFeedbackMaxArgsLength, err := h.sessionRuntimeSettings()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to resolve sessions directory", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -582,7 +664,7 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
seen[ref.ID] = struct{}{}
|
||||
items = append(items, buildSessionListItem(ref.ID, sess))
|
||||
items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,7 +678,7 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
seen[ref.ID] = struct{}{}
|
||||
items = append(items, buildSessionListItem(ref.ID, sess))
|
||||
items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -645,7 +727,7 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
dir, err := h.sessionsDir()
|
||||
dir, toolFeedbackMaxArgsLength, err := h.sessionRuntimeSettings()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to resolve sessions directory", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -679,7 +761,7 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
messages := visibleSessionMessages(sess.Messages)
|
||||
messages := visibleSessionMessages(sess.Messages, toolFeedbackMaxArgsLength)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
|
||||
+216
-12
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
func sessionsTestDir(t *testing.T, configPath string) string {
|
||||
@@ -292,6 +293,59 @@ func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-transient-thought"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", ReasoningContent: "internal chain of thought"},
|
||||
{Role: "assistant", Content: "final visible answer"},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-transient-thought", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "hello" {
|
||||
t.Fatalf("first message = %#v, want user/hello", resp.Messages[0])
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "final visible answer" {
|
||||
t.Fatalf("second message = %#v, want assistant/final visible answer", resp.Messages[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
@@ -348,11 +402,14 @@ func TestHandleGetSession_ReconstructsVisibleMessageToolOutput(t *testing.T) {
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 2", len(resp.Messages))
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
|
||||
t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[1])
|
||||
if !strings.Contains(resp.Messages[1].Content, "`message`") {
|
||||
t.Fatalf("tool summary message = %#v, want message tool summary", resp.Messages[1])
|
||||
}
|
||||
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "visible tool output" {
|
||||
t.Fatalf("assistant message = %#v, want visible tool output", resp.Messages[2])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -411,14 +468,17 @@ func TestHandleGetSession_PreservesFinalAssistantReplyAfterMessageToolOutput(t *
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
if len(resp.Messages) != 4 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 4", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[1].Role != "assistant" || resp.Messages[1].Content != "visible tool output" {
|
||||
t.Fatalf("interim assistant message = %#v, want visible tool output", resp.Messages[1])
|
||||
if !strings.Contains(resp.Messages[1].Content, "`message`") {
|
||||
t.Fatalf("tool summary message = %#v, want message tool summary", resp.Messages[1])
|
||||
}
|
||||
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "final assistant reply" {
|
||||
t.Fatalf("final assistant message = %#v, want final assistant reply", resp.Messages[2])
|
||||
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "visible tool output" {
|
||||
t.Fatalf("interim assistant message = %#v, want visible tool output", resp.Messages[2])
|
||||
}
|
||||
if resp.Messages[3].Role != "assistant" || resp.Messages[3].Content != "final assistant reply" {
|
||||
t.Fatalf("final assistant message = %#v, want final assistant reply", resp.Messages[3])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -475,8 +535,152 @@ func TestHandleListSessions_MessageCountUsesVisibleTranscript(t *testing.T) {
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("len(items) = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].MessageCount != 2 {
|
||||
t.Fatalf("items[0].MessageCount = %d, want 2", items[0].MessageCount)
|
||||
if items[0].MessageCount != 3 {
|
||||
t.Fatalf("items[0].MessageCount = %d, want 3", items[0].MessageCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_PreservesToolSummaryAndAssistantContent(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
sessionKey := picoSessionPrefix + "detail-tool-summary-and-content"
|
||||
for _, msg := range []providers.Message{
|
||||
{Role: "user", Content: "check file"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "model final reply",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"README.md","start_line":1,"end_line":10}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} {
|
||||
if err := store.AddFullMessage(nil, sessionKey, msg); err != nil {
|
||||
t.Fatalf("AddFullMessage() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-and-content", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) != 3 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want 3", len(resp.Messages))
|
||||
}
|
||||
if resp.Messages[0].Role != "user" || resp.Messages[0].Content != "check file" {
|
||||
t.Fatalf("first message = %#v, want user/check file", resp.Messages[0])
|
||||
}
|
||||
if !strings.Contains(resp.Messages[1].Content, "`read_file`") {
|
||||
t.Fatalf("tool summary message = %#v, want read_file summary", resp.Messages[1])
|
||||
}
|
||||
if resp.Messages[2].Role != "assistant" || resp.Messages[2].Content != "model final reply" {
|
||||
t.Fatalf("assistant message = %#v, want model final reply", resp.Messages[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetSession_UsesConfiguredToolFeedbackMaxArgsLength(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
cfg.Agents.Defaults.ToolFeedback.MaxArgsLength = 20
|
||||
err = config.SaveConfig(configPath, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
dir := sessionsTestDir(t, configPath)
|
||||
store, err := memory.NewJSONLStore(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("NewJSONLStore() error = %v", err)
|
||||
}
|
||||
|
||||
argsJSON := `{"path":"README.md","start_line":1,"end_line":10,"extra":"abcdefghijklmnopqrstuvwxyz"}`
|
||||
sessionKey := picoSessionPrefix + "detail-tool-summary-max-args"
|
||||
err = store.AddFullMessage(nil, sessionKey, providers.Message{Role: "user", Content: "check file"})
|
||||
if err != nil {
|
||||
t.Fatalf("AddFullMessage(user) error = %v", err)
|
||||
}
|
||||
err = store.AddFullMessage(nil, sessionKey, providers.Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: argsJSON,
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddFullMessage(assistant) error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/sessions/detail-tool-summary-max-args", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
err = json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
if err != nil {
|
||||
t.Fatalf("Unmarshal() error = %v", err)
|
||||
}
|
||||
if len(resp.Messages) < 2 {
|
||||
t.Fatalf("len(resp.Messages) = %d, want at least 2", len(resp.Messages))
|
||||
}
|
||||
|
||||
wantPreview := utils.Truncate(argsJSON, 20)
|
||||
if !strings.Contains(resp.Messages[1].Content, wantPreview) {
|
||||
t.Fatalf("tool summary = %q, want preview %q", resp.Messages[1].Content, wantPreview)
|
||||
}
|
||||
if strings.Contains(resp.Messages[1].Content, argsJSON) {
|
||||
t.Fatalf("tool summary = %q, expected configured truncation", resp.Messages[1].Content)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user