From 0bb561548f099a60b588f0aaef30615af45aad05 Mon Sep 17 00:00:00 2001 From: Cytown Date: Sun, 29 Mar 2026 01:14:39 +0800 Subject: [PATCH 1/2] add pid file for gateway running and auth token for /reload and pico channel --- cmd/picoclaw-launcher-tui/ui/gateway.go | 72 ++---- cmd/picoclaw/internal/helpers.go | 6 +- pkg/agent/context.go | 10 +- pkg/auth/store.go | 7 +- pkg/channels/weixin/state.go | 6 +- pkg/config/config.go | 7 +- pkg/config/defaults.go | 12 +- pkg/config/envkeys.go | 20 ++ pkg/gateway/gateway.go | 40 ++- pkg/health/server.go | 36 ++- pkg/health/server_test.go | 11 +- pkg/migrate/internal/common.go | 11 +- pkg/pid/pidfile.go | 159 ++++++++++++ pkg/pid/pidfile_test.go | 253 +++++++++++++++++++ pkg/pid/pidfile_unix.go | 22 ++ pkg/pid/pidfile_windows.go | 42 ++++ web/backend/api/config.go | 6 + web/backend/api/gateway.go | 310 +++++++++++++++--------- web/backend/api/gateway_test.go | 3 - web/backend/api/pico.go | 61 ++++- web/backend/api/pico_test.go | 14 ++ web/backend/api/skills.go | 9 +- web/backend/middleware/middleware.go | 11 + web/backend/utils/runtime.go | 8 +- 24 files changed, 876 insertions(+), 260 deletions(-) create mode 100644 pkg/pid/pidfile.go create mode 100644 pkg/pid/pidfile_test.go create mode 100644 pkg/pid/pidfile_unix.go create mode 100644 pkg/pid/pidfile_windows.go diff --git a/cmd/picoclaw-launcher-tui/ui/gateway.go b/cmd/picoclaw-launcher-tui/ui/gateway.go index 1138c12db..781204bf2 100644 --- a/cmd/picoclaw-launcher-tui/ui/gateway.go +++ b/cmd/picoclaw-launcher-tui/ui/gateway.go @@ -7,9 +7,7 @@ package ui import ( "fmt" - "os" "os/exec" - "path/filepath" "runtime" "strconv" "strings" @@ -17,61 +15,30 @@ import ( "github.com/gdamore/tcell/v2" "github.com/rivo/tview" -) -const pidFileName = "gateway.pid" + "github.com/sipeed/picoclaw/pkg/config" + ppid "github.com/sipeed/picoclaw/pkg/pid" +) type gatewayStatus struct { running bool pid int + version string } -func getPidPath() string { - home, err := os.UserHomeDir() - if err != nil { - home = "." - } - return filepath.Join(home, ".picoclaw", pidFileName) -} - -func isProcessRunning(pid int) bool { - if runtime.GOOS == "windows" { - cmd := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %d", pid)) - output, err := cmd.Output() - if err != nil { - return false - } - return strings.Contains(string(output), strconv.Itoa(pid)) - } else if runtime.GOOS == "darwin" { - cmd := exec.Command("ps", "aux") - output, err := cmd.Output() - if err != nil { - return false - } - return strings.Contains(string(output), fmt.Sprintf(" %d ", pid)) - } - // Linux - _, err := os.Stat(fmt.Sprintf("/proc/%d", pid)) - return err == nil +func picoHome() string { + return config.GetHome() } func getGatewayStatus() gatewayStatus { - pidPath := getPidPath() - data, err := os.ReadFile(pidPath) - if err != nil { - return gatewayStatus{running: false} - } - pid, err := strconv.Atoi(strings.TrimSpace(string(data))) - if err != nil { - return gatewayStatus{running: false} - } - if !isProcessRunning(pid) { - os.Remove(pidPath) + data := ppid.ReadPidFileWithCheck(picoHome()) + if data == nil { return gatewayStatus{running: false} } return gatewayStatus{ running: true, - pid: pid, + pid: data.PID, + version: data.Version, } } @@ -81,13 +48,12 @@ func startGateway() error { return fmt.Errorf("gateway is already running (PID: %d)", status.pid) } - pidPath := getPidPath() var cmd *exec.Cmd if runtime.GOOS == "windows" { cmd = exec.Command("cmd", "/C", "start /B picoclaw gateway > NUL 2>&1") } else { - cmd = exec.Command("sh", "-c", "nohup picoclaw gateway > /dev/null 2>&1 & echo $! > "+pidPath) + cmd = exec.Command("sh", "-c", "nohup picoclaw gateway > /dev/null 2>&1 &") } err := cmd.Start() @@ -116,9 +82,8 @@ func startGateway() error { if line == "" { continue } - pid, err := strconv.Atoi(line) + _, err := strconv.Atoi(line) if err == nil { - os.WriteFile(pidPath, []byte(strconv.Itoa(pid)), 0o600) break } } @@ -141,21 +106,20 @@ func stopGateway() error { if runtime.GOOS == "windows" { err = exec.Command("taskkill", "/F", "/PID", strconv.Itoa(status.pid)).Run() } else { - err = exec.Command("kill", "-9", strconv.Itoa(status.pid)).Run() + err = exec.Command("kill", strconv.Itoa(status.pid)).Run() } if err != nil { return err } - // 多次尝试确认进程已停止 + // Wait for process to stop (ReadPidFileWithCheck cleans up stale pid file) for i := 0; i < 5; i++ { - if !isProcessRunning(status.pid) { + if !getGatewayStatus().running { break } time.Sleep(200 * time.Millisecond) } - os.Remove(getPidPath()) return nil } @@ -217,7 +181,11 @@ func (a *App) newGatewayPage() tview.Primitive { updateStatus = func() { status := getGatewayStatus() if status.running { - statusTV.SetText(fmt.Sprintf("[#39ff14::b]GATEWAY RUNNING[-]\n\nPID: %d", status.pid)) + versionInfo := "" + if status.version != "" { + versionInfo = fmt.Sprintf("\nVersion: %s", status.version) + } + statusTV.SetText(fmt.Sprintf("[#39ff14::b]GATEWAY RUNNING[-]\n\nPID: %d%s", status.pid, versionInfo)) buttons.SetItemText(0, " [gray]START[white] ", "") buttons.SetItemText(1, " [red]STOP[white] ", "") } else { diff --git a/cmd/picoclaw/internal/helpers.go b/cmd/picoclaw/internal/helpers.go index 17de88ccb..afe5074a7 100644 --- a/cmd/picoclaw/internal/helpers.go +++ b/cmd/picoclaw/internal/helpers.go @@ -14,11 +14,7 @@ const Logo = pkg.Logo // GetPicoclawHome returns the picoclaw home directory. // Priority: $PICOCLAW_HOME > ~/.picoclaw func GetPicoclawHome() string { - if home := os.Getenv(config.EnvHome); home != "" { - return home - } - home, _ := os.UserHomeDir() - return filepath.Join(home, pkg.DefaultPicoClawHome) + return config.GetHome() } func GetConfigPath() string { diff --git a/pkg/agent/context.go b/pkg/agent/context.go index c3fcc9fff..b5c68650a 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" @@ -59,14 +58,7 @@ func (cb *ContextBuilder) WithSplitOnMarker(enabled bool) *ContextBuilder { } func getGlobalConfigDir() string { - if home := os.Getenv(config.EnvHome); home != "" { - return home - } - home, err := os.UserHomeDir() - if err != nil { - return "" - } - return filepath.Join(home, pkg.DefaultPicoClawHome) + return config.GetHome() } func NewContextBuilder(workspace string) *ContextBuilder { diff --git a/pkg/auth/store.go b/pkg/auth/store.go index 8a878d553..dfea11df4 100644 --- a/pkg/auth/store.go +++ b/pkg/auth/store.go @@ -6,7 +6,6 @@ import ( "path/filepath" "time" - "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/fileutil" ) @@ -41,11 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool { } func authFilePath() string { - if home := os.Getenv(config.EnvHome); home != "" { - return filepath.Join(home, "auth.json") - } - home, _ := os.UserHomeDir() - return filepath.Join(home, pkg.DefaultPicoClawHome, "auth.json") + return filepath.Join(config.GetHome(), "auth.json") } func LoadStore() (*AuthStore, error) { diff --git a/pkg/channels/weixin/state.go b/pkg/channels/weixin/state.go index 2d1b9f4a6..854a4ab53 100644 --- a/pkg/channels/weixin/state.go +++ b/pkg/channels/weixin/state.go @@ -37,11 +37,7 @@ type syncCursorFile struct { } func picoclawHomeDir() string { - if home := os.Getenv(config.EnvHome); home != "" { - return home - } - userHome, _ := os.UserHomeDir() - return filepath.Join(userHome, ".picoclaw") + return config.GetHome() } func buildWeixinSyncBufPath(cfg config.WeixinConfig) string { diff --git a/pkg/config/config.go b/pkg/config/config.go index 533f45a44..e7326625c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1081,12 +1081,7 @@ func LoadConfig(path string) (*Config, error) { // Ensure Workspace has a default if not set if cfg.Agents.Defaults.Workspace == "" { - homePath, _ := os.UserHomeDir() - if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" { - homePath = picoclawHome - } else if homePath != "" { - homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome) - } + homePath := GetHome() cfg.Agents.Defaults.Workspace = filepath.Join(homePath, pkg.WorkspaceName) } diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index bc4ab0649..bded97fcd 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -6,7 +6,6 @@ package config import ( - "os" "path/filepath" "github.com/sipeed/picoclaw/pkg" @@ -14,16 +13,7 @@ import ( // DefaultConfig returns the default configuration for PicoClaw. func DefaultConfig() *Config { - // Determine the base path for the workspace. - // Priority: $PICOCLAW_HOME > ~/.picoclaw - var homePath string - if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" { - homePath = picoclawHome - } else { - userHome, _ := os.UserHomeDir() - homePath = filepath.Join(userHome, pkg.DefaultPicoClawHome) - } - workspacePath := filepath.Join(homePath, pkg.WorkspaceName) + workspacePath := filepath.Join(GetHome(), pkg.WorkspaceName) return &Config{ Version: CurrentVersion, diff --git a/pkg/config/envkeys.go b/pkg/config/envkeys.go index b04ff19f5..615769d3c 100644 --- a/pkg/config/envkeys.go +++ b/pkg/config/envkeys.go @@ -5,6 +5,13 @@ package config +import ( + "os" + "path/filepath" + + "github.com/sipeed/picoclaw/pkg" +) + // Runtime environment variable keys for the picoclaw process. // These control the location of files and binaries at runtime and are read // directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the @@ -35,3 +42,16 @@ const ( // Default: "127.0.0.1" EnvGatewayHost = "PICOCLAW_GATEWAY_HOST" ) + +func GetHome() string { + homePath, _ := os.UserHomeDir() + if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" { + homePath = picoclawHome + } else if homePath != "" { + homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome) + } + if homePath == "" { + homePath = "." + } + return homePath +} diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index c35b3e744..a63530806 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "sync" "sync/atomic" "syscall" @@ -36,6 +37,7 @@ import ( "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" + "github.com/sipeed/picoclaw/pkg/pid" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" @@ -61,6 +63,7 @@ type services struct { HealthServer *health.Server manualReloadChan chan struct{} reloading atomic.Bool + authToken string } type startupBlockedProvider struct { @@ -107,6 +110,13 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error fmt.Println("🔍 Debug mode enabled") } + // Enforce singleton: write PID file with generated token. + pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port) + if err != nil { + return fmt.Errorf("singleton check failed: %w", err) + } + defer pid.RemovePidFile(homePath) + provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup) if err != nil { return fmt.Errorf("error creating provider: %w", err) @@ -133,7 +143,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error "skills_available": skillsInfo["available"], }) - runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus) + runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token) if err != nil { return err } @@ -224,6 +234,9 @@ func executeReload( allowEmptyStartup bool, ) error { defer runningServices.reloading.Store(false) + + overridePicoToken(newCfg, runningServices.authToken) + return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup) } @@ -248,6 +261,7 @@ func setupAndStartServices( cfg *config.Config, agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, + authToken string, ) (*services, error) { runningServices := &services{} @@ -290,6 +304,8 @@ func setupAndStartServices( fms.Start() } + overridePicoToken(cfg, authToken) + runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore) if err != nil { if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok { @@ -314,7 +330,8 @@ func setupAndStartServices( } addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) - runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + runningServices.authToken = authToken + runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken) runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { @@ -524,6 +541,9 @@ func restartServices( logger.InfoCF("voice", "Transcription disabled", nil) } + // NOTE: PID file is written once at startup and not updated on reload. + // Changing the gateway listen address requires a full restart. + return nil } @@ -642,6 +662,22 @@ func setupCronTool( return cronService, nil } +const picoTokenPrefix = "pico-" + +// overridePicoToken replaces the pico channel token with the one from the PID file. +// The PID file is the single source of truth for the pico auth token; +// it is generated once at gateway startup and remains unchanged across reloads. +func overridePicoToken(cfg *config.Config, token string) { + if !cfg.Channels.Pico.Enabled { + return + } + picoToken := cfg.Channels.Pico.Token.String() + if picoToken == "" || strings.HasPrefix(picoToken, picoTokenPrefix) { + return + } + cfg.Channels.Pico.SetToken(picoTokenPrefix + token + picoToken) +} + func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult { return func(prompt, channel, chatID string) *tools.ToolResult { if channel == "" || chatID == "" { diff --git a/pkg/health/server.go b/pkg/health/server.go index 387cb0756..2602cb965 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -2,11 +2,11 @@ package health import ( "context" + "crypto/subtle" "encoding/json" "fmt" "maps" "net/http" - "os" "sync" "time" ) @@ -18,6 +18,7 @@ type Server struct { checks map[string]Check startTime time.Time reloadFunc func() error + authToken string // optional bearer token for protected endpoints } type Check struct { @@ -31,15 +32,15 @@ type StatusResponse struct { Status string `json:"status"` Uptime string `json:"uptime"` Checks map[string]Check `json:"checks,omitempty"` - Pid int `json:"pid"` } -func NewServer(host string, port int) *Server { +func NewServer(host string, port int, token string) *Server { mux := http.NewServeMux() s := &Server{ ready: false, checks: make(map[string]Check), startTime: time.Now(), + authToken: token, } mux.HandleFunc("/health", s.healthHandler) @@ -123,6 +124,21 @@ func (s *Server) reloadHandler(w http.ResponseWriter, r *http.Request) { return } + // Token check + s.mu.RLock() + requiredToken := s.authToken + s.mu.RUnlock() + + if requiredToken != "" { + given := extractBearerToken(r.Header.Get("Authorization")) + if given == "" || subtle.ConstantTimeCompare([]byte(given), []byte(requiredToken)) != 1 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"}) + return + } + } + s.mu.Lock() reloadFunc := s.reloadFunc s.mu.Unlock() @@ -154,7 +170,6 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { resp := StatusResponse{ Status: "ok", Uptime: uptime.String(), - Pid: os.Getpid(), } json.NewEncoder(w).Encode(resp) @@ -220,3 +235,16 @@ func statusString(ok bool) string { } return "fail" } + +// extractBearerToken returns the token from an "Authorization: Bearer " header, +// or the empty string if the header is missing or malformed. +func extractBearerToken(header string) string { + const prefix = "Bearer " + if len(header) < len(prefix) { + return "" + } + if header[:len(prefix)] != prefix { + return "" + } + return header[len(prefix):] +} diff --git a/pkg/health/server_test.go b/pkg/health/server_test.go index 6e0b5e66b..c4982fff9 100644 --- a/pkg/health/server_test.go +++ b/pkg/health/server_test.go @@ -15,6 +15,7 @@ func newTestServer() *Server { ready: false, checks: make(map[string]Check), startTime: time.Now(), + authToken: "test", } return s } @@ -37,9 +38,6 @@ func TestHealthHandler_ReturnsOK(t *testing.T) { if resp.Status != "ok" { t.Errorf("status = %q, want %q", resp.Status, "ok") } - if resp.Pid == 0 { - t.Error("pid should not be 0") - } if resp.Uptime == "" { t.Error("uptime should not be empty") } @@ -168,6 +166,7 @@ func TestReloadHandler_NoReloadFunc(t *testing.T) { s := newTestServer() req := httptest.NewRequest(http.MethodPost, "/reload", nil) + req.Header.Set("Authorization", "Bearer test") w := httptest.NewRecorder() s.reloadHandler(w, req) @@ -186,6 +185,7 @@ func TestReloadHandler_Success(t *testing.T) { }) req := httptest.NewRequest(http.MethodPost, "/reload", nil) + req.Header.Set("Authorization", "Bearer test") w := httptest.NewRecorder() s.reloadHandler(w, req) @@ -205,6 +205,7 @@ func TestReloadHandler_Error(t *testing.T) { }) req := httptest.NewRequest(http.MethodPost, "/reload", nil) + req.Header.Set("Authorization", "Bearer test") w := httptest.NewRecorder() s.reloadHandler(w, req) @@ -292,7 +293,7 @@ func TestRegisterOnMux(t *testing.T) { } func TestNewServer(t *testing.T) { - s := NewServer("127.0.0.1", 0) + s := NewServer("127.0.0.1", 0, "") if s == nil { t.Fatal("NewServer returned nil") } @@ -305,7 +306,7 @@ func TestNewServer(t *testing.T) { } func TestStartContext_Cancellation(t *testing.T) { - s := NewServer("127.0.0.1", 0) + s := NewServer("127.0.0.1", 0, "") ctx, cancel := context.WithCancel(context.Background()) diff --git a/pkg/migrate/internal/common.go b/pkg/migrate/internal/common.go index 65a87adc4..f1179c3a9 100644 --- a/pkg/migrate/internal/common.go +++ b/pkg/migrate/internal/common.go @@ -1,12 +1,10 @@ package internal import ( - "fmt" "io" "os" "path/filepath" - "github.com/sipeed/picoclaw/pkg" "github.com/sipeed/picoclaw/pkg/config" ) @@ -14,14 +12,7 @@ func ResolveTargetHome(override string) (string, error) { if override != "" { return ExpandHome(override), nil } - if envHome := os.Getenv(config.EnvHome); envHome != "" { - return ExpandHome(envHome), nil - } - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("resolving home directory: %w", err) - } - return filepath.Join(home, pkg.DefaultPicoClawHome), nil + return config.GetHome(), nil } func ExpandHome(path string) string { diff --git a/pkg/pid/pidfile.go b/pkg/pid/pidfile.go new file mode 100644 index 000000000..584b9b2b5 --- /dev/null +++ b/pkg/pid/pidfile.go @@ -0,0 +1,159 @@ +package pid + +import ( + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +const pidFileName = ".picoclaw.pid" + +// PidFileData is the JSON structure stored in the PID file. +type PidFileData struct { + PID int `json:"pid"` + Token string `json:"token"` + Version string `json:"version"` + Port int `json:"port"` + Host string `json:"host"` +} + +var pidMu sync.Mutex + +// pidFilePath returns the absolute path for the PID file given the home directory. +func pidFilePath(homePath string) string { + return filepath.Join(homePath, pidFileName) +} + +// generateToken creates a cryptographically random 32-character hex token. +func generateToken() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // Fallback to something pseudo-random if crypto/rand fails + return fmt.Sprintf("%032x", time.Now().UnixNano()) + } + return hex.EncodeToString(b) +} + +// WritePidFile creates (or overwrites) the PID file atomically. +// It returns an error if another gateway instance appears to be running +// (a valid PID file exists with a live process). +func WritePidFile(homePath, host string, port int) (*PidFileData, error) { + pidMu.Lock() + defer pidMu.Unlock() + + pidPath := pidFilePath(homePath) + + // Check for existing PID file → singleton enforcement. + if data, err := readPidFileUnlocked(pidPath); err == nil { + if os.Getpid() != data.PID { + logger.Infof("found pid file (PID: %d, version: %s)", data.PID, data.Version) + if isProcessRunning(data.PID) { + return nil, fmt.Errorf("gateway is already running (PID: %d, version: %s)", data.PID, data.Version) + } + logger.Warnf("not running (PID: %d) so will remove the pid file: %s", data.PID, pidPath) + } + // Stale PID file; process no longer exists → clean up. + os.Remove(pidPath) + } + + data := &PidFileData{ + PID: os.Getpid(), + Version: config.GetVersion(), + Port: port, + Host: host, + } + + token := generateToken() + data.Token = token + + raw, err := json.MarshalIndent(data, "", " ") + if err != nil { + return nil, fmt.Errorf("failed to marshal pid file: %w", err) + } + + // Ensure parent directory exists. + dir := filepath.Dir(pidPath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("failed to create pid directory: %w", err) + } + + // Write atomically via temp file + rename. + tmp := pidPath + ".tmp" + if err := os.WriteFile(tmp, raw, 0o600); err != nil { + return nil, fmt.Errorf("failed to write pid file: %w", err) + } + if err := os.Rename(tmp, pidPath); err != nil { + os.Remove(tmp) + return nil, fmt.Errorf("failed to rename pid file: %w", err) + } + + return data, nil +} + +// ReadPidFileWithCheck reads the PID file and additionally checks if +// the recorded process is still alive. Returns nil if the file is +// missing, unreadable, or the process has exited. +func ReadPidFileWithCheck(homePath string) *PidFileData { + pidMu.Lock() + defer pidMu.Unlock() + + pidPath := pidFilePath(homePath) + data, err := readPidFileUnlocked(pidPath) + if err != nil { + return nil + } + + if !isProcessRunning(data.PID) { + os.Remove(pidPath) + return nil + } + + return data +} + +// RemovePidFile deletes the PID file (e.g. on graceful shutdown). +func RemovePidFile(homePath string) { + pidMu.Lock() + defer pidMu.Unlock() + + pidPath := pidFilePath(homePath) + // Only remove if the PID matches our own process (avoid deleting + // a file that belongs to a newer gateway instance). + if data, err := readPidFileUnlocked(pidPath); err == nil { + if data.PID != os.Getpid() { + return + } + } + + logger.Infof("remove pid file: %s", pidPath) + os.Remove(pidPath) +} + +// readPidFileUnlocked reads the PID file without acquiring the lock. +// Caller must hold pidMu. +func readPidFileUnlocked(pidPath string) (*PidFileData, error) { + raw, err := os.ReadFile(pidPath) + if err != nil { + return nil, err + } + + var data PidFileData + if err := json.Unmarshal(raw, &data); err != nil { + return nil, err + } + + // Validate PID is a positive integer. + if data.PID <= 0 { + return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID) + } + + return &data, nil +} diff --git a/pkg/pid/pidfile_test.go b/pkg/pid/pidfile_test.go new file mode 100644 index 000000000..921f590ad --- /dev/null +++ b/pkg/pid/pidfile_test.go @@ -0,0 +1,253 @@ +package pid + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +// tmpDir returns a clean temporary directory for a test. +func tmpDir(t *testing.T) string { + t.Helper() + dir, err := os.MkdirTemp("", "pidtest-*") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(dir) }) + return dir +} + +// TestGenerateToken verifies that generateToken produces a 32-character hex string. +func TestGenerateToken(t *testing.T) { + token := generateToken() + if len(token) != 32 { + t.Errorf("expected token length 32, got %d (token: %q)", len(token), token) + } + // Verify all characters are valid hex. + for _, c := range token { + if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) { + t.Errorf("token contains non-hex character: %c", c) + } + } +} + +// TestGenerateTokenUniqueness checks that two consecutive tokens differ. +func TestGenerateTokenUniqueness(t *testing.T) { + a := generateToken() + b := generateToken() + if a == b { + t.Error("two consecutive tokens should not be equal") + } +} + +// TestPidFilePath returns the expected path. +func TestPidFilePath(t *testing.T) { + dir := tmpDir(t) + got := pidFilePath(dir) + want := filepath.Join(dir, pidFileName) + if got != want { + t.Errorf("pidFilePath(%q) = %q, want %q", dir, got, want) + } +} + +// TestWritePidFile creates a PID file and verifies its contents. +func TestWritePidFile(t *testing.T) { + dir := tmpDir(t) + data, err := WritePidFile(dir, "127.0.0.1", 18790) + if err != nil { + t.Fatalf("WritePidFile failed: %v", err) + } + + if data.PID != os.Getpid() { + t.Errorf("PID = %d, want %d", data.PID, os.Getpid()) + } + if data.Host != "127.0.0.1" { + t.Errorf("Host = %q, want %q", data.Host, "127.0.0.1") + } + if data.Port != 18790 { + t.Errorf("Port = %d, want %d", data.Port, 18790) + } + if len(data.Token) != 32 { + t.Errorf("Token length = %d, want 32", len(data.Token)) + } + + // Verify the file exists and can be unmarshalled. + raw, err := os.ReadFile(filepath.Join(dir, pidFileName)) + if err != nil { + t.Fatalf("failed to read pid file: %v", err) + } + + var fileData PidFileData + if err = json.Unmarshal(raw, &fileData); err != nil { + t.Fatalf("failed to unmarshal pid file: %v", err) + } + if fileData.PID != data.PID || fileData.Token != data.Token { + t.Error("file data mismatch") + } + + // Verify file permissions (owner-only read/write). + info, err := os.Stat(filepath.Join(dir, pidFileName)) + if err != nil { + t.Fatalf("failed to stat pid file: %v", err) + } + perm := info.Mode().Perm() + if perm != 0o600 { + t.Errorf("file permission = %o, want 0600", perm) + } +} + +// TestWritePidFileOverwrite writes twice and verifies the PID file is replaced. +func TestWritePidFileOverwrite(t *testing.T) { + dir := tmpDir(t) + + data1, err := WritePidFile(dir, "0.0.0.0", 18790) + if err != nil { + t.Fatalf("first WritePidFile failed: %v", err) + } + + // Second write should succeed because the PID matches our process. + data2, err := WritePidFile(dir, "0.0.0.0", 18800) + if err != nil { + t.Fatalf("second WritePidFile failed: %v", err) + } + + if data2.Token == data1.Token { + t.Error("token should change on re-write") + } + if data2.Port != 18800 { + t.Errorf("Port = %d, want 18800", data2.Port) + } +} + +// TestWritePidFileStalePID writes a PID file with a non-running PID, then +// verifies WritePidFile cleans it up and writes a new one. +func TestWritePidFileStalePID(t *testing.T) { + dir := tmpDir(t) + + // Write a PID file with a PID that almost certainly doesn't exist. + stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"} + raw, _ := json.MarshalIndent(stale, "", " ") + os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600) + + data, err := WritePidFile(dir, "127.0.0.1", 18790) + if err != nil { + t.Fatalf("WritePidFile with stale PID failed: %v", err) + } + if data.PID != os.Getpid() { + t.Errorf("PID = %d, want %d", data.PID, os.Getpid()) + } +} + +// TestReadPidFileWithCheck verifies reading a valid PID file for the current process. +func TestReadPidFileWithCheck(t *testing.T) { + dir := tmpDir(t) + + // Some sandboxed environments (e.g. macOS test runner) may restrict + // signal(0), causing isProcessRunning(getpid()) to return false. + if !isProcessRunning(os.Getpid()) { + t.Skip("skipping: isProcessRunning(getpid()) is false in this environment") + } + + written, err := WritePidFile(dir, "127.0.0.1", 18790) + if err != nil { + t.Fatalf("WritePidFile failed: %v", err) + } + + read := ReadPidFileWithCheck(dir) + if read == nil { + t.Fatal("ReadPidFileWithCheck returned nil for current process") + } + if read.PID != written.PID || read.Token != written.Token { + t.Error("read data doesn't match written data") + } +} + +// TestReadPidFileWithCheckNonexistent returns nil for missing file. +func TestReadPidFileWithCheckNonexistent(t *testing.T) { + dir := tmpDir(t) + data := ReadPidFileWithCheck(dir) + if data != nil { + t.Error("expected nil for nonexistent PID file") + } +} + +// TestReadPidFileWithCheckStalePID auto-cleans a PID file whose process is dead. +func TestReadPidFileWithCheckStalePID(t *testing.T) { + dir := tmpDir(t) + + stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"} + raw, _ := json.MarshalIndent(stale, "", " ") + os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600) + + data := ReadPidFileWithCheck(dir) + if data != nil { + t.Error("expected nil for stale PID") + } + + // File should be cleaned up. + if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) { + t.Error("stale PID file should be removed") + } +} + +// TestRemovePidFile removes the PID file for the current process. +func TestRemovePidFile(t *testing.T) { + dir := tmpDir(t) + + if _, err := WritePidFile(dir, "127.0.0.1", 18790); err != nil { + t.Fatalf("WritePidFile failed: %v", err) + } + + RemovePidFile(dir) + + if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) { + t.Error("PID file should be removed") + } +} + +// TestRemovePidFileDifferentPID does not remove a PID file owned by another process. +func TestRemovePidFileDifferentPID(t *testing.T) { + dir := tmpDir(t) + + other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"} + raw, _ := json.MarshalIndent(other, "", " ") + os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600) + + RemovePidFile(dir) + + if _, err := os.Stat(filepath.Join(dir, pidFileName)); os.IsNotExist(err) { + t.Error("PID file should NOT be removed (different PID)") + } +} + +// TestRemovePidFileNonexistent does not error on missing file. +func TestRemovePidFileNonexistent(t *testing.T) { + dir := tmpDir(t) + // Should not panic or error. + RemovePidFile(dir) +} + +// TestReadPidFileUnlockedInvalidJSON returns error for malformed content. +func TestReadPidFileUnlockedInvalidJSON(t *testing.T) { + dir := tmpDir(t) + path := filepath.Join(dir, pidFileName) + os.WriteFile(path, []byte("not json"), 0o600) + + _, err := readPidFileUnlocked(path) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +// TestReadPidFileUnlockedInvalidPID returns error for non-positive PID. +func TestReadPidFileUnlockedInvalidPID(t *testing.T) { + dir := tmpDir(t) + path := filepath.Join(dir, pidFileName) + os.WriteFile(path, []byte(`{"pid": -1, "token": "a"}`), 0o600) + + _, err := readPidFileUnlocked(path) + if err == nil { + t.Error("expected error for invalid PID") + } +} diff --git a/pkg/pid/pidfile_unix.go b/pkg/pid/pidfile_unix.go new file mode 100644 index 000000000..5459d8370 --- /dev/null +++ b/pkg/pid/pidfile_unix.go @@ -0,0 +1,22 @@ +//go:build !windows + +package pid + +import ( + "os" + "syscall" +) + +// isProcessRunning checks whether a process with the given PID is alive +// on Unix-like systems using signal(0). +func isProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + p, err := os.FindProcess(pid) + if err != nil { + return false + } + // Signal(nil) does not kill the process but checks existence on Unix. + return p.Signal(syscall.Signal(0)) == nil +} diff --git a/pkg/pid/pidfile_windows.go b/pkg/pid/pidfile_windows.go new file mode 100644 index 000000000..6a2cce793 --- /dev/null +++ b/pkg/pid/pidfile_windows.go @@ -0,0 +1,42 @@ +//go:build windows + +package pid + +import ( + "syscall" + "unsafe" +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procOpenProcess = kernel32.NewProc("OpenProcess") + procGetExitCodeProcess = kernel32.NewProc("GetExitCodeProcess") + procCloseHandle = kernel32.NewProc("CloseHandle") + processQueryLimitedInformation = uint32(0x1000) + stillActive = uint32(259) +) + +// isProcessRunning checks whether a process with the given PID is alive +// on Windows using OpenProcess + GetExitCodeProcess. +func isProcessRunning(pid int) bool { + if pid <= 0 { + return false + } + + handle, _, err := procOpenProcess.Call( + uintptr(processQueryLimitedInformation), + 0, + uintptr(pid), + ) + if handle == 0 || err != nil { + return false + } + defer procCloseHandle.Call(handle) + + var exitCode uint32 + ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode))) + if ret == 0 || err != nil { + return false + } + return exitCode == stillActive +} diff --git a/web/backend/api/config.go b/web/backend/api/config.go index 0add7594d..8da2f53e0 100644 --- a/web/backend/api/config.go +++ b/web/backend/api/config.go @@ -81,6 +81,9 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) { return } + // Refresh cached pico token in case user changed it. + refreshPicoToken(&cfg) + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) } @@ -175,6 +178,9 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) { return } + // Refresh cached pico token in case user changed it. + refreshPicoToken(&newCfg) + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) } diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 808475cb6..4aebd2eaa 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -20,6 +20,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" + ppid "github.com/sipeed/picoclaw/pkg/pid" "github.com/sipeed/picoclaw/web/backend/utils" ) @@ -32,11 +33,50 @@ var gateway = struct { runtimeStatus string startupDeadline time.Time logs *LogBuffer + pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json + picoToken string // cached pico token from config (for proxy auth validation) }{ runtimeStatus: "stopped", logs: NewLogBuffer(200), } +// refreshPicoToken updates gateway.picoToken from cfg +func refreshPicoToken(cfg *config.Config) { + gateway.mu.Lock() + defer gateway.mu.Unlock() + gateway.picoToken = cfg.Channels.Pico.Token.String() +} + +// refreshPicoTokensLocked reads the pico token from config and caches it. +// Caller must hold gateway.mu (or be sole writer). +func refreshPicoTokensLocked(configPath string) { + cfg, err := config.LoadConfig(configPath) + if err != nil { + return + } + gateway.picoToken = cfg.Channels.Pico.Token.String() +} + +const ( + protocolKey = "Sec-Websocket-Protocol" + picoTokenPrefix = "pico-" + tokenPrefix = "token." +) + +// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth. +func picoComposedToken(token string) string { + gateway.mu.Lock() + defer gateway.mu.Unlock() + // if not initial pico token, don't allow gateway auth + if gateway.picoToken == "" || gateway.pidData == nil { + return "" + } + if tokenPrefix+gateway.picoToken != token { + return "" + } + return picoTokenPrefix + gateway.pidData.Token + gateway.picoToken +} + var ( gatewayStartupWindow = 15 * time.Second gatewayRestartGracePeriod = 5 * time.Second @@ -49,16 +89,29 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, return client.Get(url) } -// getGatewayHealth checks the gateway health endpoint and returns the status response +// getGatewayHealth checks the gateway health endpoint and returns the status response. // Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid. func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) { - port := 18790 - if cfg != nil && cfg.Gateway.Port != 0 { - port = cfg.Gateway.Port + // Prefer port/host from pidData when available. + var port int + var host string + gateway.mu.Lock() + if d := gateway.pidData; d != nil && d.Port > 0 { + port = d.Port + host = d.Host + } + gateway.mu.Unlock() + if port == 0 { + port = 18790 + if cfg != nil && cfg.Gateway.Port != 0 { + port = cfg.Gateway.Port + } + } + if host == "" { + host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) } - probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) - url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health" + url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health" return getGatewayHealthByURL(url, timeout) } @@ -91,30 +144,34 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) { // TryAutoStartGateway checks whether gateway start preconditions are met and // starts it when possible. Intended to be called by the backend at startup. func (h *Handler) TryAutoStartGateway() { - // Check if gateway is already running via health endpoint - cfg, cfgErr := config.LoadConfig(h.configPath) - if cfgErr == nil && cfg != nil { - healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) - if err == nil && statusCode == http.StatusOK { - // Gateway is already running, attach to the existing process - pid := healthResp.Pid - gateway.mu.Lock() - defer gateway.mu.Unlock() - ready, reason, err := h.gatewayStartReady() - if err != nil { - logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err)) - return - } - if !ready { - logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason)) - return - } - _, err = h.startGatewayLocked("starting", pid) - if err != nil { - logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err)) - } + // Check PID file first to detect an already-running gateway. + pidData := ppid.ReadPidFileWithCheck(globalConfigDir()) + logger.Infof("pidData: %v", pidData) + if pidData != nil { + gateway.mu.Lock() + ready, reason, err := h.gatewayStartReady() + if err != nil { + logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err)) + gateway.mu.Unlock() return } + logger.Infof("ready: %v, reason: %s", ready, reason) + if !ready { + logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason)) + gateway.mu.Unlock() + return + } + pid := pidData.PID + _, err = h.startGatewayLocked("starting", pid) + if err != nil { + logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err)) + } else { + gateway.pidData = pidData + refreshPicoTokensLocked(h.configPath) + logger.InfoC("gateway", fmt.Sprintf("Attached to running gateway via PID file (PID: %d)", pid)) + } + gateway.mu.Unlock() + return } gateway.mu.Lock() @@ -319,6 +376,7 @@ func stopGatewayLocked() (int, error) { gateway.cmd = nil gateway.owned = false gateway.bootDefaultModel = "" + gateway.pidData = nil setGatewayRuntimeStatusLocked("stopped") return pid, nil @@ -371,6 +429,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int pid = existingPid gateway.cmd = nil // Clear first to ensure clean state if err = attachToGatewayProcessLocked(pid, cfg); err != nil { + logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to existing gateway (PID %d): %v", pid, err)) return 0, err } @@ -380,6 +439,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int // Start new process // Locate the picoclaw executable execPath := utils.FindPicoclawBinary() + logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath)) cmd = exec.Command(execPath, "gateway", "-E") cmd.Env = os.Environ() @@ -407,10 +467,16 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.logs.Reset() // Ensure Pico Channel is configured before starting gateway - if _, err := h.EnsurePicoChannel(""); err != nil { + changed, err := h.EnsurePicoChannel("") + if err != nil { logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err)) // Non-fatal: gateway can still start without pico channel } + // Refresh cached pico token in case EnsurePicoChannel generated a new one. + // Already holding gateway.mu from caller. + if changed { + refreshPicoTokensLocked(h.configPath) + } if err := cmd.Start(); err != nil { return 0, fmt.Errorf("failed to start gateway: %w", err) @@ -446,7 +512,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int gateway.mu.Unlock() }() - // Start a goroutine to probe health and update the runtime state once ready. + // Start a goroutine to probe pidFile and health, update runtime state once ready. go func() { for i := 0; i < 30; i++ { // try for up to 15 seconds time.Sleep(500 * time.Millisecond) @@ -456,13 +522,26 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int if !stillOurs { return } + + // Poll for pidFile first — once available we have port/host/token. + if pd := ppid.ReadPidFileWithCheck(globalConfigDir()); pd != nil && pd.PID == pid { + gateway.mu.Lock() + if gateway.cmd == cmd { + gateway.pidData = pd + setGatewayRuntimeStatusLocked("running") + } + gateway.mu.Unlock() + logger.InfoC("gateway", fmt.Sprintf("Gateway pidFile detected (PID: %d, port: %d)", pd.PID, pd.Port)) + return + } + + // Fallback: probe health endpoint to confirm liveness. cfg, err := config.LoadConfig(h.configPath) if err != nil { continue } - healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second) - if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid { - // Verify the health endpoint returns the expected pid + _, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second) + if err == nil && statusCode == http.StatusOK { gateway.mu.Lock() if gateway.cmd == cmd { setGatewayRuntimeStatusLocked("running") @@ -480,49 +559,47 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int // // POST /api/gateway/start func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { - // Prevent duplicate starts by checking health endpoint - cfg, cfgErr := config.LoadConfig(h.configPath) - if cfgErr == nil && cfg != nil { - healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) - if err == nil && statusCode == http.StatusOK { - // Gateway is already running, attach to the existing process - pid := healthResp.Pid - gateway.mu.Lock() - ready, reason, err := h.gatewayStartReady() - if err != nil { - gateway.mu.Unlock() - http.Error( - w, - fmt.Sprintf("Failed to validate gateway start conditions: %v", err), - http.StatusInternalServerError, - ) - return - } - if !ready { - gateway.mu.Unlock() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - json.NewEncoder(w).Encode(map[string]any{ - "status": "precondition_failed", - "message": reason, - }) - return - } - _, err = h.startGatewayLocked("starting", pid) + // Check PID file first to detect an already-running gateway. + pidData := ppid.ReadPidFileWithCheck(globalConfigDir()) + if pidData != nil { + pid := pidData.PID + gateway.mu.Lock() + ready, reason, err := h.gatewayStartReady() + if err != nil { + gateway.mu.Unlock() + http.Error( + w, + fmt.Sprintf("Failed to validate gateway start conditions: %v", err), + http.StatusInternalServerError, + ) + return + } + if !ready { gateway.mu.Unlock() - if err != nil { - logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err)) - http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError) - return - } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(map[string]any{ - "status": "ok", - "pid": pid, + "status": "precondition_failed", + "message": reason, }) return } + _, err = h.startGatewayLocked("starting", pid) + if err != nil { + gateway.mu.Unlock() + logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err)) + http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError) + return + } + gateway.pidData = pidData + gateway.mu.Unlock() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]any{ + "status": "ok", + "pid": pid, + }) + return } gateway.mu.Lock() @@ -722,65 +799,56 @@ func (h *Handler) gatewayStatusData() map[string]any { } } - // Probe health endpoint to get pid and status - healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) - if err != nil { + // Primary detection: read PID file and check if process is alive. + pidData := ppid.ReadPidFileWithCheck(globalConfigDir()) + if pidData != nil { gateway.mu.Lock() - data["gateway_status"] = gatewayStatusWithoutHealthLocked() + gateway.pidData = pidData + if pidData.Version != "" { + data["gateway_version"] = pidData.Version + } + setGatewayRuntimeStatusLocked("running") + + // Attach if we don't already track this PID. + if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != pidData.PID { + _ = attachToGatewayProcessLocked(pidData.PID, cfg) + } + + bootDefaultModel := gateway.bootDefaultModel + if bootDefaultModel != "" { + data["boot_default_model"] = bootDefaultModel + } + data["gateway_status"] = "running" + data["pid"] = pidData.PID gateway.mu.Unlock() - logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err)) } else { - if statusCode != http.StatusOK { - logger.WarnC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode)) + // Fallback: probe health endpoint to get pid and status + _, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second) + if err != nil { gateway.mu.Lock() - setGatewayRuntimeStatusLocked("error") + data["gateway_status"] = gatewayStatusWithoutHealthLocked() + gateway.pidData = nil gateway.mu.Unlock() - data["gateway_status"] = "error" - data["status_code"] = statusCode + logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err)) } else { - gateway.mu.Lock() - setGatewayRuntimeStatusLocked("running") - if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid { - oldPid := "none" - if gateway.cmd != nil && gateway.cmd.Process != nil { - oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid) - } - logger.InfoC( - "gateway", - fmt.Sprintf( - "Detected new gateway PID (old: %s, new: %d), attempting to attach", - oldPid, - healthResp.Pid, - ), - ) - - if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil { - // Failed to find the process, treat as error - setGatewayRuntimeStatusLocked("error") - data["gateway_status"] = "error" - data["pid"] = healthResp.Pid - logger.ErrorC( - "gateway", - fmt.Sprintf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err), - ) - } else { - // Successfully attached, update response data - bootDefaultModel := gateway.bootDefaultModel - if bootDefaultModel != "" { - data["boot_default_model"] = bootDefaultModel - } - data["gateway_status"] = "running" - data["pid"] = healthResp.Pid + logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode)) + if statusCode != http.StatusOK { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("error") + gateway.pidData = nil + gateway.mu.Unlock() + data["gateway_status"] = "error" + data["status_code"] = statusCode + } else { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("running") + bootDefaultModel := gateway.bootDefaultModel + if bootDefaultModel != "" { + data["boot_default_model"] = bootDefaultModel } + data["gateway_status"] = "running" + gateway.mu.Unlock() } - - bootDefaultModel := gateway.bootDefaultModel - if bootDefaultModel != "" { - data["boot_default_model"] = bootDefaultModel - } - data["gateway_status"] = "running" - data["pid"] = healthResp.Pid - gateway.mu.Unlock() } } diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index a5ba2bad2..a891be3c1 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -469,9 +469,6 @@ func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) { if got := body["gateway_status"]; got != "running" { t.Fatalf("gateway_status = %#v, want %q", got, "running") } - if got := body["pid"]; got != float64(cmd.Process.Pid) { - t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid) - } if got := body["gateway_restart_required"]; got != false { t.Fatalf("gateway_restart_required = %#v, want false", got) } diff --git a/web/backend/api/pico.go b/web/backend/api/pico.go index d345d980c..8bef33ac8 100644 --- a/web/backend/api/pico.go +++ b/web/backend/api/pico.go @@ -10,6 +10,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" ) // registerPicoRoutes binds Pico Channel management endpoints to the ServeMux. @@ -26,20 +27,55 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) { // createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint. // The gateway bind host and port are resolved from the latest configuration. -func (h *Handler) createWsProxy() *httputil.ReverseProxy { - wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL()) - wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway) +func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy { + wsProxy := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + target := h.gatewayProxyURL() + r.SetURL(target) + r.Out.Header.Set(protocolKey, tokenPrefix+token) + }, + ModifyResponse: func(r *http.Response) error { + if prot := r.Header.Values(protocolKey); len(prot) > 0 { + r.Header.Del(protocolKey) + if origProtocol != "" { + r.Header.Set(protocolKey, origProtocol) + } + } + return nil + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + logger.Errorf("Failed to proxy WebSocket: %v", err) + http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway) + }, } return wsProxy } // handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections. -// The reverse proxy forwards the incoming upgrade handshake as-is. +// It validates the client token before forwarding; rejects immediately on failure. func (h *Handler) handleWebSocketProxy() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - proxy := h.createWsProxy() - proxy.ServeHTTP(w, r) + gateway.mu.Lock() + gatewayAvailable := gateway.pidData != nil + gateway.mu.Unlock() + + if !gatewayAvailable { + logger.Warnf("Gateway not available for WebSocket proxy") + http.Error(w, "Gateway not available", http.StatusServiceUnavailable) + return + } + prot := r.Header.Values(protocolKey) + if len(prot) > 0 { + origProtocol := prot[0] + newToken := picoComposedToken(prot[0]) + if newToken != "" { + h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r) + return + } + } + + logger.Warnf("Invalid Pico token: %v", prot) + http.Error(w, "Invalid Pico token", http.StatusForbidden) } } @@ -81,6 +117,11 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) { return } + // Refresh cached pico token. + gateway.mu.Lock() + gateway.picoToken = token + gateway.mu.Unlock() + wsURL := h.buildWsURL(r, cfg) w.Header().Set("Content-Type", "application/json") @@ -140,11 +181,15 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) { return } + // Reload config (EnsurePicoChannel may have modified it) and refresh cache. cfg, err := config.LoadConfig(h.configPath) if err != nil { http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError) return } + if changed { + refreshPicoToken(cfg) + } wsURL := h.buildWsURL(r, cfg) @@ -162,7 +207,7 @@ func generateSecureToken() string { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { // Fallback to something pseudo-random if crypto/rand fails - return fmt.Sprintf("pico_%x", time.Now().UnixNano()) + return fmt.Sprintf("%032x", time.Now().UnixNano()) } return hex.EncodeToString(b) } diff --git a/web/backend/api/pico_test.go b/web/backend/api/pico_test.go index aa377975d..beff4d77f 100644 --- a/web/backend/api/pico_test.go +++ b/web/backend/api/pico_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/sipeed/picoclaw/pkg/config" + ppid "github.com/sipeed/picoclaw/pkg/pid" ) func TestEnsurePicoChannel_FreshConfig(t *testing.T) { @@ -335,10 +336,22 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { t.Fatalf("SaveConfig() error = %v", err) } + gateway.pidData = &ppid.PidFileData{} + gateway.picoToken = "pico" req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + req1.Header.Set(protocolKey, tokenPrefix+"wrong_token") rec1 := httptest.NewRecorder() handler(rec1, req1) + if rec1.Code != http.StatusForbidden { + t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden) + } + + req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + req1.Header.Set(protocolKey, tokenPrefix+"pico") + rec1 = httptest.NewRecorder() + handler(rec1, req1) + if rec1.Code != http.StatusOK { t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK) } @@ -352,6 +365,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) { } req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil) + req2.Header.Set(protocolKey, tokenPrefix+"pico") rec2 := httptest.NewRecorder() handler(rec2, req2) diff --git a/web/backend/api/skills.go b/web/backend/api/skills.go index 3c2fb57dd..b2036f66c 100644 --- a/web/backend/api/skills.go +++ b/web/backend/api/skills.go @@ -309,14 +309,7 @@ func loadSkillContent(path string) (string, error) { } func globalConfigDir() string { - if home := os.Getenv(config.EnvHome); home != "" { - return home - } - home, err := os.UserHomeDir() - if err != nil { - return "" - } - return filepath.Join(home, ".picoclaw") + return config.GetHome() } func builtinSkillsDir() string { diff --git a/web/backend/middleware/middleware.go b/web/backend/middleware/middleware.go index 5e0dfeb90..a0b7eb998 100644 --- a/web/backend/middleware/middleware.go +++ b/web/backend/middleware/middleware.go @@ -1,7 +1,9 @@ package middleware import ( + "bufio" "fmt" + "net" "net/http" "runtime/debug" "time" @@ -44,6 +46,15 @@ func (rr *responseRecorder) Unwrap() http.ResponseWriter { return rr.ResponseWriter } +// Hijack implements http.Hijacker so that WebSocket upgrades work through +// the middleware layer. +func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := rr.ResponseWriter.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, http.ErrNotSupported +} + // Logger logs each HTTP request with method, path, status code, and duration. func Logger(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/web/backend/utils/runtime.go b/web/backend/utils/runtime.go index 772cd7ec0..0b9e30979 100644 --- a/web/backend/utils/runtime.go +++ b/web/backend/utils/runtime.go @@ -9,16 +9,13 @@ import ( "runtime" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" ) // GetPicoclawHome returns the picoclaw home directory. // Priority: $PICOCLAW_HOME > ~/.picoclaw func GetPicoclawHome() string { - if home := os.Getenv(config.EnvHome); home != "" { - return home - } - home, _ := os.UserHomeDir() - return filepath.Join(home, ".picoclaw") + return config.GetHome() } // GetDefaultConfigPath returns the default path to the picoclaw config file. @@ -47,6 +44,7 @@ func FindPicoclawBinary() string { } if exe, err := os.Executable(); err == nil { + logger.Debugf("Trying to find picoclaw binary in %s", exe) candidate := filepath.Join(filepath.Dir(exe), binaryName) if info, err := os.Stat(candidate); err == nil && !info.IsDir() { return candidate From f0c0219c4cd6cc86b12b8635c185dc0a4327a225 Mon Sep 17 00:00:00 2001 From: Cytown Date: Sun, 29 Mar 2026 16:58:48 +0800 Subject: [PATCH 2/2] fix for review --- pkg/channels/pico/protocol.go | 2 ++ pkg/gateway/gateway.go | 8 +++----- web/backend/api/gateway.go | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index 0a630e193..192c96164 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -17,6 +17,8 @@ const ( TypeTypingStop = "typing.stop" TypeError = "error" TypePong = "pong" + + PicoTokenPrefix = "pico-" ) // PicoMessage is the wire format for all Pico Protocol messages. diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index a63530806..1a9ab4461 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -22,7 +22,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" - _ "github.com/sipeed/picoclaw/pkg/channels/pico" + "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" _ "github.com/sipeed/picoclaw/pkg/channels/slack" _ "github.com/sipeed/picoclaw/pkg/channels/telegram" @@ -662,8 +662,6 @@ func setupCronTool( return cronService, nil } -const picoTokenPrefix = "pico-" - // overridePicoToken replaces the pico channel token with the one from the PID file. // The PID file is the single source of truth for the pico auth token; // it is generated once at gateway startup and remains unchanged across reloads. @@ -672,10 +670,10 @@ func overridePicoToken(cfg *config.Config, token string) { return } picoToken := cfg.Channels.Pico.Token.String() - if picoToken == "" || strings.HasPrefix(picoToken, picoTokenPrefix) { + if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) { return } - cfg.Channels.Pico.SetToken(picoTokenPrefix + token + picoToken) + cfg.Channels.Pico.SetToken(pico.PicoTokenPrefix + token + picoToken) } func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult { diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 98fb77a04..ce3a9ca1e 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -17,6 +17,7 @@ import ( "syscall" "time" + "github.com/sipeed/picoclaw/pkg/channels/pico" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" @@ -59,9 +60,8 @@ func refreshPicoTokensLocked(configPath string) { } const ( - protocolKey = "Sec-Websocket-Protocol" - picoTokenPrefix = "pico-" - tokenPrefix = "token." + protocolKey = "Sec-Websocket-Protocol" + tokenPrefix = "token." ) // picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth. @@ -75,7 +75,7 @@ func picoComposedToken(token string) string { if tokenPrefix+gateway.picoToken != token { return "" } - return picoTokenPrefix + gateway.pidData.Token + gateway.picoToken + return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken } var (