diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go index 92e9caae9..9049a5c72 100644 --- a/pkg/agent/hooks_test.go +++ b/pkg/agent/hooks_test.go @@ -515,7 +515,6 @@ type respondWithMediaHook struct { media []string responseHandled bool forLLM string - sendMediaErr error } func (h *respondWithMediaHook) BeforeTool( diff --git a/pkg/health/server.go b/pkg/health/server.go index 2602cb965..a152d8ab1 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -7,6 +7,7 @@ import ( "fmt" "maps" "net/http" + "os" "sync" "time" ) @@ -31,6 +32,7 @@ type Check struct { type StatusResponse struct { Status string `json:"status"` Uptime string `json:"uptime"` + PID int `json:"pid,omitempty"` Checks map[string]Check `json:"checks,omitempty"` } @@ -170,6 +172,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { resp := StatusResponse{ Status: "ok", Uptime: uptime.String(), + PID: os.Getpid(), } json.NewEncoder(w).Encode(resp) diff --git a/pkg/pid/pidfile.go b/pkg/pid/pidfile.go index 0b6d461c2..f7c1f42b2 100644 --- a/pkg/pid/pidfile.go +++ b/pkg/pid/pidfile.go @@ -151,6 +151,30 @@ func RemovePidFile(homePath string) { os.Remove(pidPath) } +// RemovePidFileIfPID deletes the PID file only when the recorded PID matches +// expectedPID. It returns true when the file is removed successfully. +func RemovePidFileIfPID(homePath string, expectedPID int) bool { + if expectedPID <= 0 { + return false + } + + pidMu.Lock() + defer pidMu.Unlock() + + pidPath := pidFilePath(homePath) + data, err := readPidFileUnlocked(pidPath) + if err != nil { + return false + } + if data.PID != expectedPID { + return false + } + if err := os.Remove(pidPath); err != nil { + return false + } + return true +} + // readPidFileUnlocked reads the PID file without acquiring the lock. // Caller must hold pidMu. func readPidFileUnlocked(pidPath string) (*PidFileData, error) { diff --git a/pkg/pid/pidfile_test.go b/pkg/pid/pidfile_test.go index e54b93f4f..2da44bbbc 100644 --- a/pkg/pid/pidfile_test.go +++ b/pkg/pid/pidfile_test.go @@ -244,6 +244,40 @@ func TestRemovePidFileNonexistent(t *testing.T) { RemovePidFile(dir) } +func TestRemovePidFileIfPID(t *testing.T) { + dir := tmpDir(t) + + other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"} + raw, _ := json.MarshalIndent(other, "", " ") + path := filepath.Join(dir, pidFileName) + os.WriteFile(path, raw, 0o600) + + removed := RemovePidFileIfPID(dir, 99999999) + if !removed { + t.Fatal("expected RemovePidFileIfPID to remove matching pid file") + } + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("PID file should be removed for matching expected PID") + } +} + +func TestRemovePidFileIfPIDMismatch(t *testing.T) { + dir := tmpDir(t) + + other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"} + raw, _ := json.MarshalIndent(other, "", " ") + path := filepath.Join(dir, pidFileName) + os.WriteFile(path, raw, 0o600) + + removed := RemovePidFileIfPID(dir, 88888888) + if removed { + t.Fatal("expected RemovePidFileIfPID to keep non-matching pid file") + } + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Error("PID file should NOT be removed for mismatching expected PID") + } +} + // TestReadPidFileUnlockedInvalidJSON returns error for malformed content. func TestReadPidFileUnlockedInvalidJSON(t *testing.T) { dir := tmpDir(t) diff --git a/pkg/seahorse/schema_test.go b/pkg/seahorse/schema_test.go index 17879f66c..e11e6e96e 100644 --- a/pkg/seahorse/schema_test.go +++ b/pkg/seahorse/schema_test.go @@ -2,14 +2,26 @@ package seahorse import ( "database/sql" + "fmt" + "strings" + "sync/atomic" "testing" _ "modernc.org/sqlite" ) +var testDBCounter uint64 + func openTestDB(t *testing.T) *sql.DB { t.Helper() - db, err := sql.Open("sqlite", ":memory:") + + n := atomic.AddUint64(&testDBCounter, 1) + testName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name()) + // Use a shared in-memory database so concurrent goroutines/connections in tests + // observe the same schema/data. + dsn := fmt.Sprintf("file:seahorse_test_%s_%d?mode=memory&cache=shared", testName, n) + + db, err := sql.Open("sqlite", dsn) if err != nil { t.Fatalf("open test db: %v", err) } diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 139f2c8c8..8994e9c60 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -108,6 +108,8 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, return client.Get(url) } +var gatewayProcessMatcher = isLikelyGatewayProcess + // getGatewayHealth checks the gateway health endpoint and returns the status response. // Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid. func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) { @@ -117,7 +119,7 @@ func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (* gateway.mu.Lock() if d := gateway.pidData; d != nil && d.Port > 0 { port = d.Port - host = d.Host + host = gatewayProbeHost(d.Host) } gateway.mu.Unlock() if port == 0 { @@ -150,6 +152,150 @@ func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusRes return &healthResponse, resp.StatusCode, nil } +// isLikelyGatewayProcess returns whether PID appears to be a picoclaw gateway +// process plus whether inspection was conclusive on this platform/environment. +func isLikelyGatewayProcess(pid int) (bool, bool) { + if pid <= 0 { + return false, true + } + + if runtime.GOOS == "windows" { + psCmd := fmt.Sprintf( + `$p=Get-CimInstance Win32_Process -Filter "ProcessId = %d"; if ($null -eq $p) { "" } else { $p.CommandLine }`, + pid, + ) + out, err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", psCmd).Output() + if err == nil { + cmdline := strings.TrimSpace(string(out)) + if cmdline != "" { + return looksLikeGatewayCommandLine(cmdline), true + } + } + + // Fallback: determine only whether the process still exists. + out, err = exec.Command("tasklist", "/FI", "PID eq "+strconv.Itoa(pid), "/FO", "CSV", "/NH").Output() + if err != nil { + return false, false + } + line := strings.ToLower(strings.TrimSpace(string(out))) + if line == "" { + return false, true + } + // A CSV row means the process exists, but may have a custom executable + // name we cannot classify here. + if strings.HasPrefix(line, "\"") { + if strings.Contains(line, "\"picoclaw.exe\"") { + return true, true + } + return false, false + } + if strings.Contains(line, "no tasks are running") { + return false, true + } + return false, true + } + + out, err := exec.Command("ps", "-o", "command=", "-p", strconv.Itoa(pid)).Output() + if err != nil { + return false, false + } + cmdline := strings.ToLower(strings.TrimSpace(string(out))) + if cmdline == "" { + return false, true + } + return looksLikeGatewayCommandLine(cmdline), true +} + +// looksLikeGatewayCommandLine checks whether a process command line likely +// represents "picoclaw gateway ..." regardless of executable filename. +func looksLikeGatewayCommandLine(cmdline string) bool { + fields := strings.Fields(strings.ToLower(strings.TrimSpace(cmdline))) + if len(fields) == 0 { + return false + } + for _, f := range fields { + token := strings.Trim(f, `"'`) + if token == "gateway" || strings.HasSuffix(token, "/gateway") || strings.HasSuffix(token, `\gateway`) { + return true + } + } + return false +} + +func (h *Handler) getGatewayHealthForPidData( + pidData *ppid.PidFileData, + cfg *config.Config, + timeout time.Duration, +) (*health.StatusResponse, int, error) { + if pidData == nil { + return nil, 0, errors.New("nil pid data") + } + + port := pidData.Port + if port == 0 { + port = 18790 + if cfg != nil && cfg.Gateway.Port != 0 { + port = cfg.Gateway.Port + } + } + + host := gatewayProbeHost(strings.TrimSpace(pidData.Host)) + if host == "" { + host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) + } + if host == "" { + host = "127.0.0.1" + } + + url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health" + return getGatewayHealthByURL(url, timeout) +} + +func (h *Handler) validateGatewayPidData( + pidData *ppid.PidFileData, + cfg *config.Config, +) (ok bool, decisive bool, reason string) { + if pidData == nil || pidData.PID <= 0 { + return false, true, "invalid pid data" + } + + if gatewayProcess, inspected := gatewayProcessMatcher(pidData.PID); inspected { + if !gatewayProcess { + return false, true, "pid process command is not picoclaw gateway" + } + return true, true, "" + } + + healthResp, statusCode, err := h.getGatewayHealthForPidData(pidData, cfg, 800*time.Millisecond) + if err != nil { + return false, false, fmt.Sprintf("health probe failed: %v", err) + } + if statusCode != http.StatusOK { + return false, false, fmt.Sprintf("health endpoint returned status %d", statusCode) + } + if healthResp.PID > 0 && healthResp.PID != pidData.PID { + return false, true, fmt.Sprintf("health pid mismatch: pidFile=%d, health=%d", pidData.PID, healthResp.PID) + } + return true, true, "" +} + +func (h *Handler) sanitizeGatewayPidData(pidData *ppid.PidFileData, cfg *config.Config) *ppid.PidFileData { + if pidData == nil { + return nil + } + + ok, decisive, reason := h.validateGatewayPidData(pidData, cfg) + if ok { + return pidData + } + + logger.Warnf("ignore pid file for PID %d: %s", pidData.PID, reason) + if decisive && ppid.RemovePidFileIfPID(globalConfigDir(), pidData.PID) { + logger.Warnf("removed stale pid file for PID %d", pidData.PID) + } + return nil +} + // registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux. func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) { mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus) @@ -164,7 +310,7 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) { // starts it when possible. Intended to be called by the backend at startup. func (h *Handler) TryAutoStartGateway() { // Check PID file first to detect an already-running gateway. - pidData := ppid.ReadPidFileWithCheck(globalConfigDir()) + pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil) if pidData != nil { gateway.mu.Lock() ready, reason, err := h.gatewayStartReady() @@ -472,6 +618,11 @@ func stopGatewayLocked() (int, error) { } pid := gateway.cmd.Process.Pid + if !gateway.owned { + if isGateway, inspected := gatewayProcessMatcher(pid); inspected && !isGateway { + return pid, fmt.Errorf("refuse to stop non-gateway process (PID %d)", pid) + } + } // Send SIGTERM for graceful shutdown (SIGKILL on Windows) var sigErr error @@ -681,7 +832,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int // POST /api/gateway/start func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { // Check PID file first to detect an already-running gateway. - pidData := ppid.ReadPidFileWithCheck(globalConfigDir()) + pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil) if pidData != nil { pid := pidData.PID gateway.mu.Lock() @@ -807,9 +958,22 @@ func (h *Handler) RestartGateway() (int, error) { gateway.mu.Lock() previousCmd := gateway.cmd + previousOwned := gateway.owned setGatewayRuntimeStatusLocked("restarting") gateway.mu.Unlock() + if previousCmd != nil && previousCmd.Process != nil && !previousOwned { + if isGateway, inspected := gatewayProcessMatcher(previousCmd.Process.Pid); inspected && !isGateway { + logger.Warnf("refuse restarting non-gateway process (PID: %d)", previousCmd.Process.Pid) + gateway.mu.Lock() + if gateway.cmd == previousCmd { + setGatewayRuntimeStatusLocked("running") + } + gateway.mu.Unlock() + return 0, fmt.Errorf("refuse to restart non-gateway process (PID %d)", previousCmd.Process.Pid) + } + } + if err = stopGatewayProcessForRestart(previousCmd); err != nil { gateway.mu.Lock() if gateway.cmd == previousCmd { @@ -921,7 +1085,7 @@ func (h *Handler) gatewayStatusData() map[string]any { } // Primary detection: read PID file and check if process is alive. - pidData := ppid.ReadPidFileWithCheck(globalConfigDir()) + pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), cfg) if pidData != nil { gateway.mu.Lock() gateway.pidData = pidData diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index 1f5f13e27..d300b657c 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -15,8 +15,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ppid "github.com/sipeed/picoclaw/pkg/pid" @@ -40,6 +38,36 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd { return cmd } +func startGatewayLikeProcess(t *testing.T) *exec.Cmd { + t.Helper() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + t.Skip("gateway-like process commandline check is not deterministic on Windows tests") + } + cmd = exec.Command("sh", "-c", "sleep 30 # picoclaw gateway") + + if err := cmd.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + return cmd +} + +func writeTestPidFile(t *testing.T, data ppid.PidFileData) string { + t.Helper() + + path := filepath.Join(globalConfigDir(), ".picoclaw.pid") + raw, err := json.MarshalIndent(data, "", " ") + if err != nil { + t.Fatalf("marshal pid file: %v", err) + } + if err := os.WriteFile(path, raw, 0o600); err != nil { + t.Fatalf("write pid file: %v", err) + } + return path +} + func mockGatewayHealthResponse(statusCode, pid int) *http.Response { return &http.Response{ StatusCode: statusCode, @@ -68,12 +96,14 @@ func resetGatewayTestState(t *testing.T) { t.Helper() originalHealthGet := gatewayHealthGet + originalProcessMatcher := gatewayProcessMatcher originalRestartGracePeriod := gatewayRestartGracePeriod originalRestartForceKillWindow := gatewayRestartForceKillWindow originalRestartPollInterval := gatewayRestartPollInterval t.Setenv("PICOCLAW_HOME", t.TempDir()) t.Cleanup(func() { gatewayHealthGet = originalHealthGet + gatewayProcessMatcher = originalProcessMatcher gatewayRestartGracePeriod = originalRestartGracePeriod gatewayRestartForceKillWindow = originalRestartForceKillWindow gatewayRestartPollInterval = originalRestartPollInterval @@ -105,6 +135,105 @@ func TestGatewayStartReady_NoDefaultModel(t *testing.T) { } } +func TestLooksLikeGatewayCommandLine(t *testing.T) { + cases := []struct { + name string + cmdline string + want bool + }{ + { + name: "default picoclaw gateway", + cmdline: "/usr/local/bin/picoclaw gateway -E", + want: true, + }, + { + name: "renamed binary with gateway subcommand", + cmdline: "/opt/bin/custom-claw gateway -E -d", + want: true, + }, + { + name: "standalone gateway binary path", + cmdline: "/opt/bin/gateway -E", + want: true, + }, + { + name: "non gateway process", + cmdline: "/bin/sleep 30", + want: false, + }, + { + name: "gateway substring only", + cmdline: "/opt/bin/gatewayd --serve", + want: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := looksLikeGatewayCommandLine(tc.cmdline) + if got != tc.want { + t.Fatalf("looksLikeGatewayCommandLine(%q) = %v, want %v", tc.cmdline, got, tc.want) + } + }) + } +} + +func TestValidateGatewayPidDataAcceptsHealthWhenMatcherInconclusive(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + const testPID = 34567 + pidData := &ppid.PidFileData{ + PID: testPID, + Host: "127.0.0.1", + Port: 18790, + } + + gatewayProcessMatcher = func(int) (bool, bool) { return false, false } + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return mockGatewayHealthResponse(http.StatusOK, testPID), nil + } + + ok, decisive, reason := h.validateGatewayPidData(pidData, nil) + if !ok { + t.Fatalf("validateGatewayPidData() ok = false, want true (reason=%q)", reason) + } + if !decisive { + t.Fatalf("validateGatewayPidData() decisive = false, want true") + } +} + +func TestValidateGatewayPidDataRejectsHealthPidMismatchWhenMatcherInconclusive(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + + pidData := &ppid.PidFileData{ + PID: 34567, + Host: "127.0.0.1", + Port: 18790, + } + + gatewayProcessMatcher = func(int) (bool, bool) { return false, false } + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return mockGatewayHealthResponse(http.StatusOK, 99999), nil + } + + ok, decisive, reason := h.validateGatewayPidData(pidData, nil) + if ok { + t.Fatalf("validateGatewayPidData() ok = true, want false") + } + if !decisive { + t.Fatalf("validateGatewayPidData() decisive = false, want true") + } + if !strings.Contains(reason, "health pid mismatch") { + t.Fatalf("validateGatewayPidData() reason = %q, want contains %q", reason, "health pid mismatch") + } +} + func TestGatewayStartReady_InvalidDefaultModel(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") cfg := config.DefaultConfig() @@ -533,7 +662,7 @@ func TestGatewayStatusDowngradesRunningWhenTrackedProcessExitedAndPidFileMissing } } -func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) { +func TestGatewayStatusIgnoresAndRemovesPidFileForNonGatewayProcess(t *testing.T) { resetGatewayTestState(t) configPath := filepath.Join(t.TempDir(), "config.json") @@ -549,6 +678,87 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) { _ = cmd.Wait() }) + pidPath := writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "stale-token", + Host: "127.0.0.1", + Port: 18790, + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if got := body["gateway_status"]; got != "stopped" { + t.Fatalf("gateway_status = %#v, want %q", got, "stopped") + } + if _, err := os.Stat(pidPath); !os.IsNotExist(err) { + t.Fatal("stale pid file should be removed for non-gateway process") + } +} + +func TestGatewayStopRefusesNonGatewayAttachedProcess(t *testing.T) { + resetGatewayTestState(t) + if runtime.GOOS == "windows" { + t.Skip("commandline-based process type check is best-effort on Windows") + } + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.owned = false + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/stop", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + if !isCmdProcessAliveLocked(cmd) { + t.Fatal("non-gateway process should not be terminated by /api/gateway/stop") + } +} + +func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) { + resetGatewayTestState(t) + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + gateway.mu.Lock() setGatewayRuntimeStatusLocked("stopped") gateway.mu.Unlock() @@ -557,8 +767,12 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) { return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil } - _, err := ppid.WritePidFile(globalConfigDir(), "localhost", 0) - require.NoError(t, err) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: "127.0.0.1", + Port: 18790, + }) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) @@ -583,6 +797,7 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) { func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) { resetGatewayTestState(t) + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } configPath := filepath.Join(t.TempDir(), "config.json") cfg := config.DefaultConfig() @@ -601,16 +816,23 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) { mux := http.NewServeMux() h.RegisterRoutes(mux) - process, err := os.FindProcess(os.Getpid()) - if err != nil { - t.Fatalf("FindProcess() error = %v", err) - } - _, err = ppid.WritePidFile(globalConfigDir(), "localhost", 0) - require.NoError(t, err) + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: "127.0.0.1", + Port: 18790, + }) bootSignature := computeConfigSignature(cfg) gateway.mu.Lock() - gateway.cmd = &exec.Cmd{Process: process} + gateway.cmd = cmd gateway.bootDefaultModel = cfg.ModelList[0].ModelName gateway.bootConfigSignature = bootSignature setGatewayRuntimeStatusLocked("running") diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index 95bbfd2c1..1d6b46d32 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -64,7 +64,7 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc { gatewayAvailable := false // Prefer fresh PID file data when available. - if pidData := ppid.ReadPidFileWithCheck(globalConfigDir()); pidData != nil { + if pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil); pidData != nil { gateway.mu.Lock() gateway.pidData = pidData setGatewayRuntimeStatusLocked("running") diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index 04888fde7..af5ba205f 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -308,6 +308,10 @@ func TestHandlePicoSetup_Response(t *testing.T) { } func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { + origMatcher := gatewayProcessMatcher + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + t.Cleanup(func() { gatewayProcessMatcher = origMatcher }) + home := t.TempDir() t.Setenv("PICOCLAW_HOME", home) @@ -339,9 +343,19 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { if err := config.SaveConfig(configPath, cfg); err != nil { t.Fatalf("SaveConfig() error = %v", err) } - if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil { - t.Fatalf("WritePidFile() error = %v", err) - } + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: cfg.Gateway.Host, + Port: cfg.Gateway.Port, + }) origPidData := gateway.pidData origPicoToken := gateway.picoToken t.Cleanup(func() { @@ -392,6 +406,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { } func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) { + origMatcher := gatewayProcessMatcher + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + t.Cleanup(func() { gatewayProcessMatcher = origMatcher }) + home := t.TempDir() t.Setenv("PICOCLAW_HOME", home) @@ -416,9 +434,19 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) { if err := config.SaveConfig(configPath, cfg); err != nil { t.Fatalf("SaveConfig() error = %v", err) } - if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil { - t.Fatalf("WritePidFile() error = %v", err) - } + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + writeTestPidFile(t, ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: cfg.Gateway.Host, + Port: cfg.Gateway.Port, + }) t.Cleanup(func() { ppid.RemovePidFile(globalConfigDir()) }) @@ -450,6 +478,10 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) { } func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) { + origMatcher := gatewayProcessMatcher + gatewayProcessMatcher = func(int) (bool, bool) { return true, true } + t.Cleanup(func() { gatewayProcessMatcher = origMatcher }) + home := t.TempDir() t.Setenv("PICOCLAW_HOME", home) @@ -475,10 +507,20 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) { t.Fatalf("SaveConfig() error = %v", err) } - pidData, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port) - if err != nil { - t.Fatalf("WritePidFile() error = %v", err) + cmd := startGatewayLikeProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + pidData := ppid.PidFileData{ + PID: cmd.Process.Pid, + Token: "test-token", + Host: cfg.Gateway.Host, + Port: cfg.Gateway.Port, } + writeTestPidFile(t, pidData) t.Cleanup(func() { ppid.RemovePidFile(globalConfigDir()) })