mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(gateway): validate PID ownership and clean stale pid files (#2422)
* fix(gateway): validate PID ownership and clean stale pid files - include `pid` in health responses for runtime PID verification - add `RemovePidFileIfPID` to safely delete PID files only on PID match - sanitize gateway PID data via process-command checks with health fallback - ignore and remove stale/non-gateway PID files before gateway operations - refuse stop/restart actions when the attached process is not a gateway - update gateway and websocket tests to cover PID validation and safety paths * test(seahorse): use shared in-memory SQLite DB in tests to fix async compaction failures * test: remove unused sendMediaErr field from hook test mock
This commit is contained in:
@@ -515,7 +515,6 @@ type respondWithMediaHook struct {
|
||||
media []string
|
||||
responseHandled bool
|
||||
forLLM string
|
||||
sendMediaErr error
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) BeforeTool(
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -31,6 +32,7 @@ type Check struct {
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
PID int `json:"pid,omitempty"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
@@ -170,6 +172,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
PID: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
@@ -151,6 +151,30 @@ func RemovePidFile(homePath string) {
|
||||
os.Remove(pidPath)
|
||||
}
|
||||
|
||||
// RemovePidFileIfPID deletes the PID file only when the recorded PID matches
|
||||
// expectedPID. It returns true when the file is removed successfully.
|
||||
func RemovePidFileIfPID(homePath string, expectedPID int) bool {
|
||||
if expectedPID <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
data, err := readPidFileUnlocked(pidPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if data.PID != expectedPID {
|
||||
return false
|
||||
}
|
||||
if err := os.Remove(pidPath); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// readPidFileUnlocked reads the PID file without acquiring the lock.
|
||||
// Caller must hold pidMu.
|
||||
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
|
||||
|
||||
@@ -244,6 +244,40 @@ func TestRemovePidFileNonexistent(t *testing.T) {
|
||||
RemovePidFile(dir)
|
||||
}
|
||||
|
||||
func TestRemovePidFileIfPID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(other, "", " ")
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, raw, 0o600)
|
||||
|
||||
removed := RemovePidFileIfPID(dir, 99999999)
|
||||
if !removed {
|
||||
t.Fatal("expected RemovePidFileIfPID to remove matching pid file")
|
||||
}
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Error("PID file should be removed for matching expected PID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovePidFileIfPIDMismatch(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(other, "", " ")
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, raw, 0o600)
|
||||
|
||||
removed := RemovePidFileIfPID(dir, 88888888)
|
||||
if removed {
|
||||
t.Fatal("expected RemovePidFileIfPID to keep non-matching pid file")
|
||||
}
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("PID file should NOT be removed for mismatching expected PID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
|
||||
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
@@ -2,14 +2,26 @@ package seahorse
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
var testDBCounter uint64
|
||||
|
||||
func openTestDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
|
||||
n := atomic.AddUint64(&testDBCounter, 1)
|
||||
testName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
// Use a shared in-memory database so concurrent goroutines/connections in tests
|
||||
// observe the same schema/data.
|
||||
dsn := fmt.Sprintf("file:seahorse_test_%s_%d?mode=memory&cache=shared", testName, n)
|
||||
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("open test db: %v", err)
|
||||
}
|
||||
|
||||
+168
-4
@@ -108,6 +108,8 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
var gatewayProcessMatcher = isLikelyGatewayProcess
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response.
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
@@ -117,7 +119,7 @@ func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*
|
||||
gateway.mu.Lock()
|
||||
if d := gateway.pidData; d != nil && d.Port > 0 {
|
||||
port = d.Port
|
||||
host = d.Host
|
||||
host = gatewayProbeHost(d.Host)
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
if port == 0 {
|
||||
@@ -150,6 +152,150 @@ func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusRes
|
||||
return &healthResponse, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// isLikelyGatewayProcess returns whether PID appears to be a picoclaw gateway
|
||||
// process plus whether inspection was conclusive on this platform/environment.
|
||||
func isLikelyGatewayProcess(pid int) (bool, bool) {
|
||||
if pid <= 0 {
|
||||
return false, true
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
psCmd := fmt.Sprintf(
|
||||
`$p=Get-CimInstance Win32_Process -Filter "ProcessId = %d"; if ($null -eq $p) { "" } else { $p.CommandLine }`,
|
||||
pid,
|
||||
)
|
||||
out, err := exec.Command("powershell", "-NoProfile", "-NonInteractive", "-Command", psCmd).Output()
|
||||
if err == nil {
|
||||
cmdline := strings.TrimSpace(string(out))
|
||||
if cmdline != "" {
|
||||
return looksLikeGatewayCommandLine(cmdline), true
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: determine only whether the process still exists.
|
||||
out, err = exec.Command("tasklist", "/FI", "PID eq "+strconv.Itoa(pid), "/FO", "CSV", "/NH").Output()
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
line := strings.ToLower(strings.TrimSpace(string(out)))
|
||||
if line == "" {
|
||||
return false, true
|
||||
}
|
||||
// A CSV row means the process exists, but may have a custom executable
|
||||
// name we cannot classify here.
|
||||
if strings.HasPrefix(line, "\"") {
|
||||
if strings.Contains(line, "\"picoclaw.exe\"") {
|
||||
return true, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
if strings.Contains(line, "no tasks are running") {
|
||||
return false, true
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
|
||||
out, err := exec.Command("ps", "-o", "command=", "-p", strconv.Itoa(pid)).Output()
|
||||
if err != nil {
|
||||
return false, false
|
||||
}
|
||||
cmdline := strings.ToLower(strings.TrimSpace(string(out)))
|
||||
if cmdline == "" {
|
||||
return false, true
|
||||
}
|
||||
return looksLikeGatewayCommandLine(cmdline), true
|
||||
}
|
||||
|
||||
// looksLikeGatewayCommandLine checks whether a process command line likely
|
||||
// represents "picoclaw gateway ..." regardless of executable filename.
|
||||
func looksLikeGatewayCommandLine(cmdline string) bool {
|
||||
fields := strings.Fields(strings.ToLower(strings.TrimSpace(cmdline)))
|
||||
if len(fields) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, f := range fields {
|
||||
token := strings.Trim(f, `"'`)
|
||||
if token == "gateway" || strings.HasSuffix(token, "/gateway") || strings.HasSuffix(token, `\gateway`) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *Handler) getGatewayHealthForPidData(
|
||||
pidData *ppid.PidFileData,
|
||||
cfg *config.Config,
|
||||
timeout time.Duration,
|
||||
) (*health.StatusResponse, int, error) {
|
||||
if pidData == nil {
|
||||
return nil, 0, errors.New("nil pid data")
|
||||
}
|
||||
|
||||
port := pidData.Port
|
||||
if port == 0 {
|
||||
port = 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
}
|
||||
|
||||
host := gatewayProbeHost(strings.TrimSpace(pidData.Host))
|
||||
if host == "" {
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
}
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
|
||||
url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health"
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
|
||||
func (h *Handler) validateGatewayPidData(
|
||||
pidData *ppid.PidFileData,
|
||||
cfg *config.Config,
|
||||
) (ok bool, decisive bool, reason string) {
|
||||
if pidData == nil || pidData.PID <= 0 {
|
||||
return false, true, "invalid pid data"
|
||||
}
|
||||
|
||||
if gatewayProcess, inspected := gatewayProcessMatcher(pidData.PID); inspected {
|
||||
if !gatewayProcess {
|
||||
return false, true, "pid process command is not picoclaw gateway"
|
||||
}
|
||||
return true, true, ""
|
||||
}
|
||||
|
||||
healthResp, statusCode, err := h.getGatewayHealthForPidData(pidData, cfg, 800*time.Millisecond)
|
||||
if err != nil {
|
||||
return false, false, fmt.Sprintf("health probe failed: %v", err)
|
||||
}
|
||||
if statusCode != http.StatusOK {
|
||||
return false, false, fmt.Sprintf("health endpoint returned status %d", statusCode)
|
||||
}
|
||||
if healthResp.PID > 0 && healthResp.PID != pidData.PID {
|
||||
return false, true, fmt.Sprintf("health pid mismatch: pidFile=%d, health=%d", pidData.PID, healthResp.PID)
|
||||
}
|
||||
return true, true, ""
|
||||
}
|
||||
|
||||
func (h *Handler) sanitizeGatewayPidData(pidData *ppid.PidFileData, cfg *config.Config) *ppid.PidFileData {
|
||||
if pidData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ok, decisive, reason := h.validateGatewayPidData(pidData, cfg)
|
||||
if ok {
|
||||
return pidData
|
||||
}
|
||||
|
||||
logger.Warnf("ignore pid file for PID %d: %s", pidData.PID, reason)
|
||||
if decisive && ppid.RemovePidFileIfPID(globalConfigDir(), pidData.PID) {
|
||||
logger.Warnf("removed stale pid file for PID %d", pidData.PID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
@@ -164,7 +310,7 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
// starts it when possible. Intended to be called by the backend at startup.
|
||||
func (h *Handler) TryAutoStartGateway() {
|
||||
// Check PID file first to detect an already-running gateway.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil)
|
||||
if pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
@@ -472,6 +618,11 @@ func stopGatewayLocked() (int, error) {
|
||||
}
|
||||
|
||||
pid := gateway.cmd.Process.Pid
|
||||
if !gateway.owned {
|
||||
if isGateway, inspected := gatewayProcessMatcher(pid); inspected && !isGateway {
|
||||
return pid, fmt.Errorf("refuse to stop non-gateway process (PID %d)", pid)
|
||||
}
|
||||
}
|
||||
|
||||
// Send SIGTERM for graceful shutdown (SIGKILL on Windows)
|
||||
var sigErr error
|
||||
@@ -681,7 +832,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
// POST /api/gateway/start
|
||||
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Check PID file first to detect an already-running gateway.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil)
|
||||
if pidData != nil {
|
||||
pid := pidData.PID
|
||||
gateway.mu.Lock()
|
||||
@@ -807,9 +958,22 @@ func (h *Handler) RestartGateway() (int, error) {
|
||||
|
||||
gateway.mu.Lock()
|
||||
previousCmd := gateway.cmd
|
||||
previousOwned := gateway.owned
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if previousCmd != nil && previousCmd.Process != nil && !previousOwned {
|
||||
if isGateway, inspected := gatewayProcessMatcher(previousCmd.Process.Pid); inspected && !isGateway {
|
||||
logger.Warnf("refuse restarting non-gateway process (PID: %d)", previousCmd.Process.Pid)
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == previousCmd {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return 0, fmt.Errorf("refuse to restart non-gateway process (PID %d)", previousCmd.Process.Pid)
|
||||
}
|
||||
}
|
||||
|
||||
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == previousCmd {
|
||||
@@ -921,7 +1085,7 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
}
|
||||
|
||||
// Primary detection: read PID file and check if process is alive.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), cfg)
|
||||
if pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = pidData
|
||||
|
||||
+234
-12
@@ -15,8 +15,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
@@ -40,6 +38,36 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func startGatewayLikeProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("gateway-like process commandline check is not deterministic on Windows tests")
|
||||
}
|
||||
cmd = exec.Command("sh", "-c", "sleep 30 # picoclaw gateway")
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func writeTestPidFile(t *testing.T, data ppid.PidFileData) string {
|
||||
t.Helper()
|
||||
|
||||
path := filepath.Join(globalConfigDir(), ".picoclaw.pid")
|
||||
raw, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
t.Fatalf("marshal pid file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(path, raw, 0o600); err != nil {
|
||||
t.Fatalf("write pid file: %v", err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
@@ -68,12 +96,14 @@ func resetGatewayTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
originalProcessMatcher := gatewayProcessMatcher
|
||||
originalRestartGracePeriod := gatewayRestartGracePeriod
|
||||
originalRestartForceKillWindow := gatewayRestartForceKillWindow
|
||||
originalRestartPollInterval := gatewayRestartPollInterval
|
||||
t.Setenv("PICOCLAW_HOME", t.TempDir())
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
gatewayProcessMatcher = originalProcessMatcher
|
||||
gatewayRestartGracePeriod = originalRestartGracePeriod
|
||||
gatewayRestartForceKillWindow = originalRestartForceKillWindow
|
||||
gatewayRestartPollInterval = originalRestartPollInterval
|
||||
@@ -105,6 +135,105 @@ func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLooksLikeGatewayCommandLine(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
cmdline string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "default picoclaw gateway",
|
||||
cmdline: "/usr/local/bin/picoclaw gateway -E",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "renamed binary with gateway subcommand",
|
||||
cmdline: "/opt/bin/custom-claw gateway -E -d",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "standalone gateway binary path",
|
||||
cmdline: "/opt/bin/gateway -E",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non gateway process",
|
||||
cmdline: "/bin/sleep 30",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "gateway substring only",
|
||||
cmdline: "/opt/bin/gatewayd --serve",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := looksLikeGatewayCommandLine(tc.cmdline)
|
||||
if got != tc.want {
|
||||
t.Fatalf("looksLikeGatewayCommandLine(%q) = %v, want %v", tc.cmdline, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGatewayPidDataAcceptsHealthWhenMatcherInconclusive(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
const testPID = 34567
|
||||
pidData := &ppid.PidFileData{
|
||||
PID: testPID,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
}
|
||||
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return false, false }
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, testPID), nil
|
||||
}
|
||||
|
||||
ok, decisive, reason := h.validateGatewayPidData(pidData, nil)
|
||||
if !ok {
|
||||
t.Fatalf("validateGatewayPidData() ok = false, want true (reason=%q)", reason)
|
||||
}
|
||||
if !decisive {
|
||||
t.Fatalf("validateGatewayPidData() decisive = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGatewayPidDataRejectsHealthPidMismatchWhenMatcherInconclusive(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
pidData := &ppid.PidFileData{
|
||||
PID: 34567,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
}
|
||||
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return false, false }
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, 99999), nil
|
||||
}
|
||||
|
||||
ok, decisive, reason := h.validateGatewayPidData(pidData, nil)
|
||||
if ok {
|
||||
t.Fatalf("validateGatewayPidData() ok = true, want false")
|
||||
}
|
||||
if !decisive {
|
||||
t.Fatalf("validateGatewayPidData() decisive = false, want true")
|
||||
}
|
||||
if !strings.Contains(reason, "health pid mismatch") {
|
||||
t.Fatalf("validateGatewayPidData() reason = %q, want contains %q", reason, "health pid mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_InvalidDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
@@ -533,7 +662,7 @@ func TestGatewayStatusDowngradesRunningWhenTrackedProcessExitedAndPidFileMissing
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
func TestGatewayStatusIgnoresAndRemovesPidFileForNonGatewayProcess(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
@@ -549,6 +678,87 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
pidPath := writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "stale-token",
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
if got := body["gateway_status"]; got != "stopped" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "stopped")
|
||||
}
|
||||
if _, err := os.Stat(pidPath); !os.IsNotExist(err) {
|
||||
t.Fatal("stale pid file should be removed for non-gateway process")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStopRefusesNonGatewayAttachedProcess(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("commandline-based process type check is best-effort on Windows")
|
||||
}
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.owned = false
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/gateway/stop", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
if !isCmdProcessAliveLocked(cmd) {
|
||||
t.Fatal("non-gateway process should not be terminated by /api/gateway/stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
@@ -557,8 +767,12 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
_, err := ppid.WritePidFile(globalConfigDir(), "localhost", 0)
|
||||
require.NoError(t, err)
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
@@ -583,6 +797,7 @@ func TestGatewayStatusReportsRunningFromPidProbe(t *testing.T) {
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
@@ -601,16 +816,23 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
_, err = ppid.WritePidFile(globalConfigDir(), "localhost", 0)
|
||||
require.NoError(t, err)
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
})
|
||||
|
||||
bootSignature := computeConfigSignature(cfg)
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
gateway.bootConfigSignature = bootSignature
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
@@ -64,7 +64,7 @@ func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
|
||||
gatewayAvailable := false
|
||||
// Prefer fresh PID file data when available.
|
||||
if pidData := ppid.ReadPidFileWithCheck(globalConfigDir()); pidData != nil {
|
||||
if pidData := h.sanitizeGatewayPidData(ppid.ReadPidFileWithCheck(globalConfigDir()), nil); pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
gateway.pidData = pidData
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
@@ -308,6 +308,10 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
@@ -339,9 +343,19 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
})
|
||||
origPidData := gateway.pidData
|
||||
origPicoToken := gateway.picoToken
|
||||
t.Cleanup(func() {
|
||||
@@ -392,6 +406,10 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
@@ -416,9 +434,19 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
if _, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port); err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
}
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
writeTestPidFile(t, ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
@@ -450,6 +478,10 @@ func TestHandleWebSocketProxyLoadsCachedPicoTokenWhenMissing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
origMatcher := gatewayProcessMatcher
|
||||
gatewayProcessMatcher = func(int) (bool, bool) { return true, true }
|
||||
t.Cleanup(func() { gatewayProcessMatcher = origMatcher })
|
||||
|
||||
home := t.TempDir()
|
||||
t.Setenv("PICOCLAW_HOME", home)
|
||||
|
||||
@@ -475,10 +507,20 @@ func TestHandleWebSocketProxyLoadsPidDataOnDemand(t *testing.T) {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
pidData, err := ppid.WritePidFile(globalConfigDir(), cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile() error = %v", err)
|
||||
cmd := startGatewayLikeProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
pidData := ppid.PidFileData{
|
||||
PID: cmd.Process.Pid,
|
||||
Token: "test-token",
|
||||
Host: cfg.Gateway.Host,
|
||||
Port: cfg.Gateway.Port,
|
||||
}
|
||||
writeTestPidFile(t, pidData)
|
||||
t.Cleanup(func() {
|
||||
ppid.RemovePidFile(globalConfigDir())
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user