mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
add pid file for gateway running and auth token for /reload and pico channel
This commit is contained in:
@@ -12,7 +12,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -59,14 +58,7 @@ func (cb *ContextBuilder) WithSplitOnMarker(enabled bool) *ContextBuilder {
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome)
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
|
||||
+1
-6
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
@@ -41,11 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool {
|
||||
}
|
||||
|
||||
func authFilePath() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return filepath.Join(home, "auth.json")
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome, "auth.json")
|
||||
return filepath.Join(config.GetHome(), "auth.json")
|
||||
}
|
||||
|
||||
func LoadStore() (*AuthStore, error) {
|
||||
|
||||
@@ -37,11 +37,7 @@ type syncCursorFile struct {
|
||||
}
|
||||
|
||||
func picoclawHomeDir() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
userHome, _ := os.UserHomeDir()
|
||||
return filepath.Join(userHome, ".picoclaw")
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
|
||||
|
||||
@@ -1081,12 +1081,7 @@ func LoadConfig(path string) (*Config, error) {
|
||||
|
||||
// Ensure Workspace has a default if not set
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
homePath, _ := os.UserHomeDir()
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else if homePath != "" {
|
||||
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
|
||||
}
|
||||
homePath := GetHome()
|
||||
cfg.Agents.Defaults.Workspace = filepath.Join(homePath, pkg.WorkspaceName)
|
||||
}
|
||||
|
||||
|
||||
+1
-11
@@ -6,7 +6,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
@@ -14,16 +13,7 @@ import (
|
||||
|
||||
// DefaultConfig returns the default configuration for PicoClaw.
|
||||
func DefaultConfig() *Config {
|
||||
// Determine the base path for the workspace.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
var homePath string
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else {
|
||||
userHome, _ := os.UserHomeDir()
|
||||
homePath = filepath.Join(userHome, pkg.DefaultPicoClawHome)
|
||||
}
|
||||
workspacePath := filepath.Join(homePath, pkg.WorkspaceName)
|
||||
workspacePath := filepath.Join(GetHome(), pkg.WorkspaceName)
|
||||
|
||||
return &Config{
|
||||
Version: CurrentVersion,
|
||||
|
||||
@@ -5,6 +5,13 @@
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
)
|
||||
|
||||
// Runtime environment variable keys for the picoclaw process.
|
||||
// These control the location of files and binaries at runtime and are read
|
||||
// directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the
|
||||
@@ -35,3 +42,16 @@ const (
|
||||
// Default: "127.0.0.1"
|
||||
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
|
||||
)
|
||||
|
||||
func GetHome() string {
|
||||
homePath, _ := os.UserHomeDir()
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else if homePath != "" {
|
||||
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
|
||||
}
|
||||
if homePath == "" {
|
||||
homePath = "."
|
||||
}
|
||||
return homePath
|
||||
}
|
||||
|
||||
+38
-2
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@@ -36,6 +37,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -61,6 +63,7 @@ type services struct {
|
||||
HealthServer *health.Server
|
||||
manualReloadChan chan struct{}
|
||||
reloading atomic.Bool
|
||||
authToken string
|
||||
}
|
||||
|
||||
type startupBlockedProvider struct {
|
||||
@@ -107,6 +110,13 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
// Enforce singleton: write PID file with generated token.
|
||||
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("singleton check failed: %w", err)
|
||||
}
|
||||
defer pid.RemovePidFile(homePath)
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
@@ -133,7 +143,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -224,6 +234,9 @@ func executeReload(
|
||||
allowEmptyStartup bool,
|
||||
) error {
|
||||
defer runningServices.reloading.Store(false)
|
||||
|
||||
overridePicoToken(newCfg, runningServices.authToken)
|
||||
|
||||
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup)
|
||||
}
|
||||
|
||||
@@ -248,6 +261,7 @@ func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
authToken string,
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
@@ -290,6 +304,8 @@ func setupAndStartServices(
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
overridePicoToken(cfg, authToken)
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
@@ -314,7 +330,8 @@ func setupAndStartServices(
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.authToken = authToken
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
@@ -524,6 +541,9 @@ func restartServices(
|
||||
logger.InfoCF("voice", "Transcription disabled", nil)
|
||||
}
|
||||
|
||||
// NOTE: PID file is written once at startup and not updated on reload.
|
||||
// Changing the gateway listen address requires a full restart.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -642,6 +662,22 @@ func setupCronTool(
|
||||
return cronService, nil
|
||||
}
|
||||
|
||||
const picoTokenPrefix = "pico-"
|
||||
|
||||
// overridePicoToken replaces the pico channel token with the one from the PID file.
|
||||
// The PID file is the single source of truth for the pico auth token;
|
||||
// it is generated once at gateway startup and remains unchanged across reloads.
|
||||
func overridePicoToken(cfg *config.Config, token string) {
|
||||
if !cfg.Channels.Pico.Enabled {
|
||||
return
|
||||
}
|
||||
picoToken := cfg.Channels.Pico.Token.String()
|
||||
if picoToken == "" || strings.HasPrefix(picoToken, picoTokenPrefix) {
|
||||
return
|
||||
}
|
||||
cfg.Channels.Pico.SetToken(picoTokenPrefix + token + picoToken)
|
||||
}
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
|
||||
+32
-4
@@ -2,11 +2,11 @@ package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -18,6 +18,7 @@ type Server struct {
|
||||
checks map[string]Check
|
||||
startTime time.Time
|
||||
reloadFunc func() error
|
||||
authToken string // optional bearer token for protected endpoints
|
||||
}
|
||||
|
||||
type Check struct {
|
||||
@@ -31,15 +32,15 @@ type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
Pid int `json:"pid"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
func NewServer(host string, port int, token string) *Server {
|
||||
mux := http.NewServeMux()
|
||||
s := &Server{
|
||||
ready: false,
|
||||
checks: make(map[string]Check),
|
||||
startTime: time.Now(),
|
||||
authToken: token,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/health", s.healthHandler)
|
||||
@@ -123,6 +124,21 @@ func (s *Server) reloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Token check
|
||||
s.mu.RLock()
|
||||
requiredToken := s.authToken
|
||||
s.mu.RUnlock()
|
||||
|
||||
if requiredToken != "" {
|
||||
given := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if given == "" || subtle.ConstantTimeCompare([]byte(given), []byte(requiredToken)) != 1 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
reloadFunc := s.reloadFunc
|
||||
s.mu.Unlock()
|
||||
@@ -154,7 +170,6 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
Pid: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
@@ -220,3 +235,16 @@ func statusString(ok bool) string {
|
||||
}
|
||||
return "fail"
|
||||
}
|
||||
|
||||
// extractBearerToken returns the token from an "Authorization: Bearer <t>" header,
|
||||
// or the empty string if the header is missing or malformed.
|
||||
func extractBearerToken(header string) string {
|
||||
const prefix = "Bearer "
|
||||
if len(header) < len(prefix) {
|
||||
return ""
|
||||
}
|
||||
if header[:len(prefix)] != prefix {
|
||||
return ""
|
||||
}
|
||||
return header[len(prefix):]
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func newTestServer() *Server {
|
||||
ready: false,
|
||||
checks: make(map[string]Check),
|
||||
startTime: time.Now(),
|
||||
authToken: "test",
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -37,9 +38,6 @@ func TestHealthHandler_ReturnsOK(t *testing.T) {
|
||||
if resp.Status != "ok" {
|
||||
t.Errorf("status = %q, want %q", resp.Status, "ok")
|
||||
}
|
||||
if resp.Pid == 0 {
|
||||
t.Error("pid should not be 0")
|
||||
}
|
||||
if resp.Uptime == "" {
|
||||
t.Error("uptime should not be empty")
|
||||
}
|
||||
@@ -168,6 +166,7 @@ func TestReloadHandler_NoReloadFunc(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
req.Header.Set("Authorization", "Bearer test")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.reloadHandler(w, req)
|
||||
@@ -186,6 +185,7 @@ func TestReloadHandler_Success(t *testing.T) {
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
req.Header.Set("Authorization", "Bearer test")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.reloadHandler(w, req)
|
||||
@@ -205,6 +205,7 @@ func TestReloadHandler_Error(t *testing.T) {
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
req.Header.Set("Authorization", "Bearer test")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.reloadHandler(w, req)
|
||||
@@ -292,7 +293,7 @@ func TestRegisterOnMux(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
s := NewServer("127.0.0.1", 0)
|
||||
s := NewServer("127.0.0.1", 0, "")
|
||||
if s == nil {
|
||||
t.Fatal("NewServer returned nil")
|
||||
}
|
||||
@@ -305,7 +306,7 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartContext_Cancellation(t *testing.T) {
|
||||
s := NewServer("127.0.0.1", 0)
|
||||
s := NewServer("127.0.0.1", 0, "")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
@@ -14,14 +12,7 @@ func ResolveTargetHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return ExpandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv(config.EnvHome); envHome != "" {
|
||||
return ExpandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome), nil
|
||||
return config.GetHome(), nil
|
||||
}
|
||||
|
||||
func ExpandHome(path string) string {
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package pid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const pidFileName = ".picoclaw.pid"
|
||||
|
||||
// PidFileData is the JSON structure stored in the PID file.
|
||||
type PidFileData struct {
|
||||
PID int `json:"pid"`
|
||||
Token string `json:"token"`
|
||||
Version string `json:"version"`
|
||||
Port int `json:"port"`
|
||||
Host string `json:"host"`
|
||||
}
|
||||
|
||||
var pidMu sync.Mutex
|
||||
|
||||
// pidFilePath returns the absolute path for the PID file given the home directory.
|
||||
func pidFilePath(homePath string) string {
|
||||
return filepath.Join(homePath, pidFileName)
|
||||
}
|
||||
|
||||
// generateToken creates a cryptographically random 32-character hex token.
|
||||
func generateToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to something pseudo-random if crypto/rand fails
|
||||
return fmt.Sprintf("%032x", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// WritePidFile creates (or overwrites) the PID file atomically.
|
||||
// It returns an error if another gateway instance appears to be running
|
||||
// (a valid PID file exists with a live process).
|
||||
func WritePidFile(homePath, host string, port int) (*PidFileData, error) {
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
|
||||
// Check for existing PID file → singleton enforcement.
|
||||
if data, err := readPidFileUnlocked(pidPath); err == nil {
|
||||
if os.Getpid() != data.PID {
|
||||
logger.Infof("found pid file (PID: %d, version: %s)", data.PID, data.Version)
|
||||
if isProcessRunning(data.PID) {
|
||||
return nil, fmt.Errorf("gateway is already running (PID: %d, version: %s)", data.PID, data.Version)
|
||||
}
|
||||
logger.Warnf("not running (PID: %d) so will remove the pid file: %s", data.PID, pidPath)
|
||||
}
|
||||
// Stale PID file; process no longer exists → clean up.
|
||||
os.Remove(pidPath)
|
||||
}
|
||||
|
||||
data := &PidFileData{
|
||||
PID: os.Getpid(),
|
||||
Version: config.GetVersion(),
|
||||
Port: port,
|
||||
Host: host,
|
||||
}
|
||||
|
||||
token := generateToken()
|
||||
data.Token = token
|
||||
|
||||
raw, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal pid file: %w", err)
|
||||
}
|
||||
|
||||
// Ensure parent directory exists.
|
||||
dir := filepath.Dir(pidPath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create pid directory: %w", err)
|
||||
}
|
||||
|
||||
// Write atomically via temp file + rename.
|
||||
tmp := pidPath + ".tmp"
|
||||
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
|
||||
return nil, fmt.Errorf("failed to write pid file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, pidPath); err != nil {
|
||||
os.Remove(tmp)
|
||||
return nil, fmt.Errorf("failed to rename pid file: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ReadPidFileWithCheck reads the PID file and additionally checks if
|
||||
// the recorded process is still alive. Returns nil if the file is
|
||||
// missing, unreadable, or the process has exited.
|
||||
func ReadPidFileWithCheck(homePath string) *PidFileData {
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
data, err := readPidFileUnlocked(pidPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isProcessRunning(data.PID) {
|
||||
os.Remove(pidPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// RemovePidFile deletes the PID file (e.g. on graceful shutdown).
|
||||
func RemovePidFile(homePath string) {
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
// Only remove if the PID matches our own process (avoid deleting
|
||||
// a file that belongs to a newer gateway instance).
|
||||
if data, err := readPidFileUnlocked(pidPath); err == nil {
|
||||
if data.PID != os.Getpid() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("remove pid file: %s", pidPath)
|
||||
os.Remove(pidPath)
|
||||
}
|
||||
|
||||
// readPidFileUnlocked reads the PID file without acquiring the lock.
|
||||
// Caller must hold pidMu.
|
||||
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
|
||||
raw, err := os.ReadFile(pidPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data PidFileData
|
||||
if err := json.Unmarshal(raw, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate PID is a positive integer.
|
||||
if data.PID <= 0 {
|
||||
return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package pid
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// tmpDir returns a clean temporary directory for a test.
|
||||
func tmpDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir, err := os.MkdirTemp("", "pidtest-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { os.RemoveAll(dir) })
|
||||
return dir
|
||||
}
|
||||
|
||||
// TestGenerateToken verifies that generateToken produces a 32-character hex string.
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
token := generateToken()
|
||||
if len(token) != 32 {
|
||||
t.Errorf("expected token length 32, got %d (token: %q)", len(token), token)
|
||||
}
|
||||
// Verify all characters are valid hex.
|
||||
for _, c := range token {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("token contains non-hex character: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateTokenUniqueness checks that two consecutive tokens differ.
|
||||
func TestGenerateTokenUniqueness(t *testing.T) {
|
||||
a := generateToken()
|
||||
b := generateToken()
|
||||
if a == b {
|
||||
t.Error("two consecutive tokens should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPidFilePath returns the expected path.
|
||||
func TestPidFilePath(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
got := pidFilePath(dir)
|
||||
want := filepath.Join(dir, pidFileName)
|
||||
if got != want {
|
||||
t.Errorf("pidFilePath(%q) = %q, want %q", dir, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWritePidFile creates a PID file and verifies its contents.
|
||||
func TestWritePidFile(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
data, err := WritePidFile(dir, "127.0.0.1", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
if data.PID != os.Getpid() {
|
||||
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
|
||||
}
|
||||
if data.Host != "127.0.0.1" {
|
||||
t.Errorf("Host = %q, want %q", data.Host, "127.0.0.1")
|
||||
}
|
||||
if data.Port != 18790 {
|
||||
t.Errorf("Port = %d, want %d", data.Port, 18790)
|
||||
}
|
||||
if len(data.Token) != 32 {
|
||||
t.Errorf("Token length = %d, want 32", len(data.Token))
|
||||
}
|
||||
|
||||
// Verify the file exists and can be unmarshalled.
|
||||
raw, err := os.ReadFile(filepath.Join(dir, pidFileName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read pid file: %v", err)
|
||||
}
|
||||
|
||||
var fileData PidFileData
|
||||
if err = json.Unmarshal(raw, &fileData); err != nil {
|
||||
t.Fatalf("failed to unmarshal pid file: %v", err)
|
||||
}
|
||||
if fileData.PID != data.PID || fileData.Token != data.Token {
|
||||
t.Error("file data mismatch")
|
||||
}
|
||||
|
||||
// Verify file permissions (owner-only read/write).
|
||||
info, err := os.Stat(filepath.Join(dir, pidFileName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to stat pid file: %v", err)
|
||||
}
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0o600 {
|
||||
t.Errorf("file permission = %o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWritePidFileOverwrite writes twice and verifies the PID file is replaced.
|
||||
func TestWritePidFileOverwrite(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
data1, err := WritePidFile(dir, "0.0.0.0", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("first WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
// Second write should succeed because the PID matches our process.
|
||||
data2, err := WritePidFile(dir, "0.0.0.0", 18800)
|
||||
if err != nil {
|
||||
t.Fatalf("second WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
if data2.Token == data1.Token {
|
||||
t.Error("token should change on re-write")
|
||||
}
|
||||
if data2.Port != 18800 {
|
||||
t.Errorf("Port = %d, want 18800", data2.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWritePidFileStalePID writes a PID file with a non-running PID, then
|
||||
// verifies WritePidFile cleans it up and writes a new one.
|
||||
func TestWritePidFileStalePID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
// Write a PID file with a PID that almost certainly doesn't exist.
|
||||
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(stale, "", " ")
|
||||
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
|
||||
|
||||
data, err := WritePidFile(dir, "127.0.0.1", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile with stale PID failed: %v", err)
|
||||
}
|
||||
if data.PID != os.Getpid() {
|
||||
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheck verifies reading a valid PID file for the current process.
|
||||
func TestReadPidFileWithCheck(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
// Some sandboxed environments (e.g. macOS test runner) may restrict
|
||||
// signal(0), causing isProcessRunning(getpid()) to return false.
|
||||
if !isProcessRunning(os.Getpid()) {
|
||||
t.Skip("skipping: isProcessRunning(getpid()) is false in this environment")
|
||||
}
|
||||
|
||||
written, err := WritePidFile(dir, "127.0.0.1", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
read := ReadPidFileWithCheck(dir)
|
||||
if read == nil {
|
||||
t.Fatal("ReadPidFileWithCheck returned nil for current process")
|
||||
}
|
||||
if read.PID != written.PID || read.Token != written.Token {
|
||||
t.Error("read data doesn't match written data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheckNonexistent returns nil for missing file.
|
||||
func TestReadPidFileWithCheckNonexistent(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
data := ReadPidFileWithCheck(dir)
|
||||
if data != nil {
|
||||
t.Error("expected nil for nonexistent PID file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheckStalePID auto-cleans a PID file whose process is dead.
|
||||
func TestReadPidFileWithCheckStalePID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(stale, "", " ")
|
||||
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
|
||||
|
||||
data := ReadPidFileWithCheck(dir)
|
||||
if data != nil {
|
||||
t.Error("expected nil for stale PID")
|
||||
}
|
||||
|
||||
// File should be cleaned up.
|
||||
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
|
||||
t.Error("stale PID file should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFile removes the PID file for the current process.
|
||||
func TestRemovePidFile(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
if _, err := WritePidFile(dir, "127.0.0.1", 18790); err != nil {
|
||||
t.Fatalf("WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
RemovePidFile(dir)
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
|
||||
t.Error("PID file should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFileDifferentPID does not remove a PID file owned by another process.
|
||||
func TestRemovePidFileDifferentPID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(other, "", " ")
|
||||
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
|
||||
|
||||
RemovePidFile(dir)
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, pidFileName)); os.IsNotExist(err) {
|
||||
t.Error("PID file should NOT be removed (different PID)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFileNonexistent does not error on missing file.
|
||||
func TestRemovePidFileNonexistent(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
// Should not panic or error.
|
||||
RemovePidFile(dir)
|
||||
}
|
||||
|
||||
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
|
||||
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, []byte("not json"), 0o600)
|
||||
|
||||
_, err := readPidFileUnlocked(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileUnlockedInvalidPID returns error for non-positive PID.
|
||||
func TestReadPidFileUnlockedInvalidPID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, []byte(`{"pid": -1, "token": "a"}`), 0o600)
|
||||
|
||||
_, err := readPidFileUnlocked(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid PID")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//go:build !windows
|
||||
|
||||
package pid
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// isProcessRunning checks whether a process with the given PID is alive
|
||||
// on Unix-like systems using signal(0).
|
||||
func isProcessRunning(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
p, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Signal(nil) does not kill the process but checks existence on Unix.
|
||||
return p.Signal(syscall.Signal(0)) == nil
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
//go:build windows
|
||||
|
||||
package pid
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procOpenProcess = kernel32.NewProc("OpenProcess")
|
||||
procGetExitCodeProcess = kernel32.NewProc("GetExitCodeProcess")
|
||||
procCloseHandle = kernel32.NewProc("CloseHandle")
|
||||
processQueryLimitedInformation = uint32(0x1000)
|
||||
stillActive = uint32(259)
|
||||
)
|
||||
|
||||
// isProcessRunning checks whether a process with the given PID is alive
|
||||
// on Windows using OpenProcess + GetExitCodeProcess.
|
||||
func isProcessRunning(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
handle, _, err := procOpenProcess.Call(
|
||||
uintptr(processQueryLimitedInformation),
|
||||
0,
|
||||
uintptr(pid),
|
||||
)
|
||||
if handle == 0 || err != nil {
|
||||
return false
|
||||
}
|
||||
defer procCloseHandle.Call(handle)
|
||||
|
||||
var exitCode uint32
|
||||
ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
|
||||
if ret == 0 || err != nil {
|
||||
return false
|
||||
}
|
||||
return exitCode == stillActive
|
||||
}
|
||||
Reference in New Issue
Block a user