add pid file for gateway running and auth token for /reload and pico channel

This commit is contained in:
Cytown
2026-03-29 01:14:39 +08:00
parent f1cb7cc8f5
commit 0bb561548f
24 changed files with 876 additions and 260 deletions
+1 -9
View File
@@ -12,7 +12,6 @@ import (
"sync"
"time"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -59,14 +58,7 @@ func (cb *ContextBuilder) WithSplitOnMarker(enabled bool) *ContextBuilder {
}
func getGlobalConfigDir() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, pkg.DefaultPicoClawHome)
return config.GetHome()
}
func NewContextBuilder(workspace string) *ContextBuilder {
+1 -6
View File
@@ -6,7 +6,6 @@ import (
"path/filepath"
"time"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
@@ -41,11 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool {
}
func authFilePath() string {
if home := os.Getenv(config.EnvHome); home != "" {
return filepath.Join(home, "auth.json")
}
home, _ := os.UserHomeDir()
return filepath.Join(home, pkg.DefaultPicoClawHome, "auth.json")
return filepath.Join(config.GetHome(), "auth.json")
}
func LoadStore() (*AuthStore, error) {
+1 -5
View File
@@ -37,11 +37,7 @@ type syncCursorFile struct {
}
func picoclawHomeDir() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
userHome, _ := os.UserHomeDir()
return filepath.Join(userHome, ".picoclaw")
return config.GetHome()
}
func buildWeixinSyncBufPath(cfg config.WeixinConfig) string {
+1 -6
View File
@@ -1081,12 +1081,7 @@ func LoadConfig(path string) (*Config, error) {
// Ensure Workspace has a default if not set
if cfg.Agents.Defaults.Workspace == "" {
homePath, _ := os.UserHomeDir()
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else if homePath != "" {
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
}
homePath := GetHome()
cfg.Agents.Defaults.Workspace = filepath.Join(homePath, pkg.WorkspaceName)
}
+1 -11
View File
@@ -6,7 +6,6 @@
package config
import (
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg"
@@ -14,16 +13,7 @@ import (
// DefaultConfig returns the default configuration for PicoClaw.
func DefaultConfig() *Config {
// Determine the base path for the workspace.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
var homePath string
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else {
userHome, _ := os.UserHomeDir()
homePath = filepath.Join(userHome, pkg.DefaultPicoClawHome)
}
workspacePath := filepath.Join(homePath, pkg.WorkspaceName)
workspacePath := filepath.Join(GetHome(), pkg.WorkspaceName)
return &Config{
Version: CurrentVersion,
+20
View File
@@ -5,6 +5,13 @@
package config
import (
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg"
)
// Runtime environment variable keys for the picoclaw process.
// These control the location of files and binaries at runtime and are read
// directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the
@@ -35,3 +42,16 @@ const (
// Default: "127.0.0.1"
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
)
func GetHome() string {
homePath, _ := os.UserHomeDir()
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else if homePath != "" {
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
}
if homePath == "" {
homePath = "."
}
return homePath
}
+38 -2
View File
@@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
@@ -36,6 +37,7 @@ import (
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/pid"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -61,6 +63,7 @@ type services struct {
HealthServer *health.Server
manualReloadChan chan struct{}
reloading atomic.Bool
authToken string
}
type startupBlockedProvider struct {
@@ -107,6 +110,13 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
fmt.Println("🔍 Debug mode enabled")
}
// Enforce singleton: write PID file with generated token.
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
if err != nil {
return fmt.Errorf("singleton check failed: %w", err)
}
defer pid.RemovePidFile(homePath)
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
@@ -133,7 +143,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
"skills_available": skillsInfo["available"],
})
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
if err != nil {
return err
}
@@ -224,6 +234,9 @@ func executeReload(
allowEmptyStartup bool,
) error {
defer runningServices.reloading.Store(false)
overridePicoToken(newCfg, runningServices.authToken)
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup)
}
@@ -248,6 +261,7 @@ func setupAndStartServices(
cfg *config.Config,
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
authToken string,
) (*services, error) {
runningServices := &services{}
@@ -290,6 +304,8 @@ func setupAndStartServices(
fms.Start()
}
overridePicoToken(cfg, authToken)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
@@ -314,7 +330,8 @@ func setupAndStartServices(
}
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.authToken = authToken
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
@@ -524,6 +541,9 @@ func restartServices(
logger.InfoCF("voice", "Transcription disabled", nil)
}
// NOTE: PID file is written once at startup and not updated on reload.
// Changing the gateway listen address requires a full restart.
return nil
}
@@ -642,6 +662,22 @@ func setupCronTool(
return cronService, nil
}
const picoTokenPrefix = "pico-"
// overridePicoToken replaces the pico channel token with the one from the PID file.
// The PID file is the single source of truth for the pico auth token;
// it is generated once at gateway startup and remains unchanged across reloads.
func overridePicoToken(cfg *config.Config, token string) {
if !cfg.Channels.Pico.Enabled {
return
}
picoToken := cfg.Channels.Pico.Token.String()
if picoToken == "" || strings.HasPrefix(picoToken, picoTokenPrefix) {
return
}
cfg.Channels.Pico.SetToken(picoTokenPrefix + token + picoToken)
}
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
return func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
+32 -4
View File
@@ -2,11 +2,11 @@ package health
import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"maps"
"net/http"
"os"
"sync"
"time"
)
@@ -18,6 +18,7 @@ type Server struct {
checks map[string]Check
startTime time.Time
reloadFunc func() error
authToken string // optional bearer token for protected endpoints
}
type Check struct {
@@ -31,15 +32,15 @@ type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
Pid int `json:"pid"`
}
func NewServer(host string, port int) *Server {
func NewServer(host string, port int, token string) *Server {
mux := http.NewServeMux()
s := &Server{
ready: false,
checks: make(map[string]Check),
startTime: time.Now(),
authToken: token,
}
mux.HandleFunc("/health", s.healthHandler)
@@ -123,6 +124,21 @@ func (s *Server) reloadHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Token check
s.mu.RLock()
requiredToken := s.authToken
s.mu.RUnlock()
if requiredToken != "" {
given := extractBearerToken(r.Header.Get("Authorization"))
if given == "" || subtle.ConstantTimeCompare([]byte(given), []byte(requiredToken)) != 1 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
return
}
}
s.mu.Lock()
reloadFunc := s.reloadFunc
s.mu.Unlock()
@@ -154,7 +170,6 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
Pid: os.Getpid(),
}
json.NewEncoder(w).Encode(resp)
@@ -220,3 +235,16 @@ func statusString(ok bool) string {
}
return "fail"
}
// extractBearerToken returns the token from an "Authorization: Bearer <t>" header,
// or the empty string if the header is missing or malformed.
func extractBearerToken(header string) string {
const prefix = "Bearer "
if len(header) < len(prefix) {
return ""
}
if header[:len(prefix)] != prefix {
return ""
}
return header[len(prefix):]
}
+6 -5
View File
@@ -15,6 +15,7 @@ func newTestServer() *Server {
ready: false,
checks: make(map[string]Check),
startTime: time.Now(),
authToken: "test",
}
return s
}
@@ -37,9 +38,6 @@ func TestHealthHandler_ReturnsOK(t *testing.T) {
if resp.Status != "ok" {
t.Errorf("status = %q, want %q", resp.Status, "ok")
}
if resp.Pid == 0 {
t.Error("pid should not be 0")
}
if resp.Uptime == "" {
t.Error("uptime should not be empty")
}
@@ -168,6 +166,7 @@ func TestReloadHandler_NoReloadFunc(t *testing.T) {
s := newTestServer()
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
req.Header.Set("Authorization", "Bearer test")
w := httptest.NewRecorder()
s.reloadHandler(w, req)
@@ -186,6 +185,7 @@ func TestReloadHandler_Success(t *testing.T) {
})
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
req.Header.Set("Authorization", "Bearer test")
w := httptest.NewRecorder()
s.reloadHandler(w, req)
@@ -205,6 +205,7 @@ func TestReloadHandler_Error(t *testing.T) {
})
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
req.Header.Set("Authorization", "Bearer test")
w := httptest.NewRecorder()
s.reloadHandler(w, req)
@@ -292,7 +293,7 @@ func TestRegisterOnMux(t *testing.T) {
}
func TestNewServer(t *testing.T) {
s := NewServer("127.0.0.1", 0)
s := NewServer("127.0.0.1", 0, "")
if s == nil {
t.Fatal("NewServer returned nil")
}
@@ -305,7 +306,7 @@ func TestNewServer(t *testing.T) {
}
func TestStartContext_Cancellation(t *testing.T) {
s := NewServer("127.0.0.1", 0)
s := NewServer("127.0.0.1", 0, "")
ctx, cancel := context.WithCancel(context.Background())
+1 -10
View File
@@ -1,12 +1,10 @@
package internal
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -14,14 +12,7 @@ func ResolveTargetHome(override string) (string, error) {
if override != "" {
return ExpandHome(override), nil
}
if envHome := os.Getenv(config.EnvHome); envHome != "" {
return ExpandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, pkg.DefaultPicoClawHome), nil
return config.GetHome(), nil
}
func ExpandHome(path string) string {
+159
View File
@@ -0,0 +1,159 @@
package pid
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
const pidFileName = ".picoclaw.pid"
// PidFileData is the JSON structure stored in the PID file.
type PidFileData struct {
PID int `json:"pid"`
Token string `json:"token"`
Version string `json:"version"`
Port int `json:"port"`
Host string `json:"host"`
}
var pidMu sync.Mutex
// pidFilePath returns the absolute path for the PID file given the home directory.
func pidFilePath(homePath string) string {
return filepath.Join(homePath, pidFileName)
}
// generateToken creates a cryptographically random 32-character hex token.
func generateToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// Fallback to something pseudo-random if crypto/rand fails
return fmt.Sprintf("%032x", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
// WritePidFile creates (or overwrites) the PID file atomically.
// It returns an error if another gateway instance appears to be running
// (a valid PID file exists with a live process).
func WritePidFile(homePath, host string, port int) (*PidFileData, error) {
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
// Check for existing PID file → singleton enforcement.
if data, err := readPidFileUnlocked(pidPath); err == nil {
if os.Getpid() != data.PID {
logger.Infof("found pid file (PID: %d, version: %s)", data.PID, data.Version)
if isProcessRunning(data.PID) {
return nil, fmt.Errorf("gateway is already running (PID: %d, version: %s)", data.PID, data.Version)
}
logger.Warnf("not running (PID: %d) so will remove the pid file: %s", data.PID, pidPath)
}
// Stale PID file; process no longer exists → clean up.
os.Remove(pidPath)
}
data := &PidFileData{
PID: os.Getpid(),
Version: config.GetVersion(),
Port: port,
Host: host,
}
token := generateToken()
data.Token = token
raw, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("failed to marshal pid file: %w", err)
}
// Ensure parent directory exists.
dir := filepath.Dir(pidPath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("failed to create pid directory: %w", err)
}
// Write atomically via temp file + rename.
tmp := pidPath + ".tmp"
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
return nil, fmt.Errorf("failed to write pid file: %w", err)
}
if err := os.Rename(tmp, pidPath); err != nil {
os.Remove(tmp)
return nil, fmt.Errorf("failed to rename pid file: %w", err)
}
return data, nil
}
// ReadPidFileWithCheck reads the PID file and additionally checks if
// the recorded process is still alive. Returns nil if the file is
// missing, unreadable, or the process has exited.
func ReadPidFileWithCheck(homePath string) *PidFileData {
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
data, err := readPidFileUnlocked(pidPath)
if err != nil {
return nil
}
if !isProcessRunning(data.PID) {
os.Remove(pidPath)
return nil
}
return data
}
// RemovePidFile deletes the PID file (e.g. on graceful shutdown).
func RemovePidFile(homePath string) {
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
// Only remove if the PID matches our own process (avoid deleting
// a file that belongs to a newer gateway instance).
if data, err := readPidFileUnlocked(pidPath); err == nil {
if data.PID != os.Getpid() {
return
}
}
logger.Infof("remove pid file: %s", pidPath)
os.Remove(pidPath)
}
// readPidFileUnlocked reads the PID file without acquiring the lock.
// Caller must hold pidMu.
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
raw, err := os.ReadFile(pidPath)
if err != nil {
return nil, err
}
var data PidFileData
if err := json.Unmarshal(raw, &data); err != nil {
return nil, err
}
// Validate PID is a positive integer.
if data.PID <= 0 {
return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
}
return &data, nil
}
+253
View File
@@ -0,0 +1,253 @@
package pid
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
// tmpDir returns a clean temporary directory for a test.
func tmpDir(t *testing.T) string {
t.Helper()
dir, err := os.MkdirTemp("", "pidtest-*")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { os.RemoveAll(dir) })
return dir
}
// TestGenerateToken verifies that generateToken produces a 32-character hex string.
func TestGenerateToken(t *testing.T) {
token := generateToken()
if len(token) != 32 {
t.Errorf("expected token length 32, got %d (token: %q)", len(token), token)
}
// Verify all characters are valid hex.
for _, c := range token {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("token contains non-hex character: %c", c)
}
}
}
// TestGenerateTokenUniqueness checks that two consecutive tokens differ.
func TestGenerateTokenUniqueness(t *testing.T) {
a := generateToken()
b := generateToken()
if a == b {
t.Error("two consecutive tokens should not be equal")
}
}
// TestPidFilePath returns the expected path.
func TestPidFilePath(t *testing.T) {
dir := tmpDir(t)
got := pidFilePath(dir)
want := filepath.Join(dir, pidFileName)
if got != want {
t.Errorf("pidFilePath(%q) = %q, want %q", dir, got, want)
}
}
// TestWritePidFile creates a PID file and verifies its contents.
func TestWritePidFile(t *testing.T) {
dir := tmpDir(t)
data, err := WritePidFile(dir, "127.0.0.1", 18790)
if err != nil {
t.Fatalf("WritePidFile failed: %v", err)
}
if data.PID != os.Getpid() {
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
}
if data.Host != "127.0.0.1" {
t.Errorf("Host = %q, want %q", data.Host, "127.0.0.1")
}
if data.Port != 18790 {
t.Errorf("Port = %d, want %d", data.Port, 18790)
}
if len(data.Token) != 32 {
t.Errorf("Token length = %d, want 32", len(data.Token))
}
// Verify the file exists and can be unmarshalled.
raw, err := os.ReadFile(filepath.Join(dir, pidFileName))
if err != nil {
t.Fatalf("failed to read pid file: %v", err)
}
var fileData PidFileData
if err = json.Unmarshal(raw, &fileData); err != nil {
t.Fatalf("failed to unmarshal pid file: %v", err)
}
if fileData.PID != data.PID || fileData.Token != data.Token {
t.Error("file data mismatch")
}
// Verify file permissions (owner-only read/write).
info, err := os.Stat(filepath.Join(dir, pidFileName))
if err != nil {
t.Fatalf("failed to stat pid file: %v", err)
}
perm := info.Mode().Perm()
if perm != 0o600 {
t.Errorf("file permission = %o, want 0600", perm)
}
}
// TestWritePidFileOverwrite writes twice and verifies the PID file is replaced.
func TestWritePidFileOverwrite(t *testing.T) {
dir := tmpDir(t)
data1, err := WritePidFile(dir, "0.0.0.0", 18790)
if err != nil {
t.Fatalf("first WritePidFile failed: %v", err)
}
// Second write should succeed because the PID matches our process.
data2, err := WritePidFile(dir, "0.0.0.0", 18800)
if err != nil {
t.Fatalf("second WritePidFile failed: %v", err)
}
if data2.Token == data1.Token {
t.Error("token should change on re-write")
}
if data2.Port != 18800 {
t.Errorf("Port = %d, want 18800", data2.Port)
}
}
// TestWritePidFileStalePID writes a PID file with a non-running PID, then
// verifies WritePidFile cleans it up and writes a new one.
func TestWritePidFileStalePID(t *testing.T) {
dir := tmpDir(t)
// Write a PID file with a PID that almost certainly doesn't exist.
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(stale, "", " ")
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
data, err := WritePidFile(dir, "127.0.0.1", 18790)
if err != nil {
t.Fatalf("WritePidFile with stale PID failed: %v", err)
}
if data.PID != os.Getpid() {
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
}
}
// TestReadPidFileWithCheck verifies reading a valid PID file for the current process.
func TestReadPidFileWithCheck(t *testing.T) {
dir := tmpDir(t)
// Some sandboxed environments (e.g. macOS test runner) may restrict
// signal(0), causing isProcessRunning(getpid()) to return false.
if !isProcessRunning(os.Getpid()) {
t.Skip("skipping: isProcessRunning(getpid()) is false in this environment")
}
written, err := WritePidFile(dir, "127.0.0.1", 18790)
if err != nil {
t.Fatalf("WritePidFile failed: %v", err)
}
read := ReadPidFileWithCheck(dir)
if read == nil {
t.Fatal("ReadPidFileWithCheck returned nil for current process")
}
if read.PID != written.PID || read.Token != written.Token {
t.Error("read data doesn't match written data")
}
}
// TestReadPidFileWithCheckNonexistent returns nil for missing file.
func TestReadPidFileWithCheckNonexistent(t *testing.T) {
dir := tmpDir(t)
data := ReadPidFileWithCheck(dir)
if data != nil {
t.Error("expected nil for nonexistent PID file")
}
}
// TestReadPidFileWithCheckStalePID auto-cleans a PID file whose process is dead.
func TestReadPidFileWithCheckStalePID(t *testing.T) {
dir := tmpDir(t)
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(stale, "", " ")
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
data := ReadPidFileWithCheck(dir)
if data != nil {
t.Error("expected nil for stale PID")
}
// File should be cleaned up.
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
t.Error("stale PID file should be removed")
}
}
// TestRemovePidFile removes the PID file for the current process.
func TestRemovePidFile(t *testing.T) {
dir := tmpDir(t)
if _, err := WritePidFile(dir, "127.0.0.1", 18790); err != nil {
t.Fatalf("WritePidFile failed: %v", err)
}
RemovePidFile(dir)
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
t.Error("PID file should be removed")
}
}
// TestRemovePidFileDifferentPID does not remove a PID file owned by another process.
func TestRemovePidFileDifferentPID(t *testing.T) {
dir := tmpDir(t)
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(other, "", " ")
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
RemovePidFile(dir)
if _, err := os.Stat(filepath.Join(dir, pidFileName)); os.IsNotExist(err) {
t.Error("PID file should NOT be removed (different PID)")
}
}
// TestRemovePidFileNonexistent does not error on missing file.
func TestRemovePidFileNonexistent(t *testing.T) {
dir := tmpDir(t)
// Should not panic or error.
RemovePidFile(dir)
}
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
dir := tmpDir(t)
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, []byte("not json"), 0o600)
_, err := readPidFileUnlocked(path)
if err == nil {
t.Error("expected error for invalid JSON")
}
}
// TestReadPidFileUnlockedInvalidPID returns error for non-positive PID.
func TestReadPidFileUnlockedInvalidPID(t *testing.T) {
dir := tmpDir(t)
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, []byte(`{"pid": -1, "token": "a"}`), 0o600)
_, err := readPidFileUnlocked(path)
if err == nil {
t.Error("expected error for invalid PID")
}
}
+22
View File
@@ -0,0 +1,22 @@
//go:build !windows
package pid
import (
"os"
"syscall"
)
// isProcessRunning checks whether a process with the given PID is alive
// on Unix-like systems using signal(0).
func isProcessRunning(pid int) bool {
if pid <= 0 {
return false
}
p, err := os.FindProcess(pid)
if err != nil {
return false
}
// Signal(nil) does not kill the process but checks existence on Unix.
return p.Signal(syscall.Signal(0)) == nil
}
+42
View File
@@ -0,0 +1,42 @@
//go:build windows
package pid
import (
"syscall"
"unsafe"
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procOpenProcess = kernel32.NewProc("OpenProcess")
procGetExitCodeProcess = kernel32.NewProc("GetExitCodeProcess")
procCloseHandle = kernel32.NewProc("CloseHandle")
processQueryLimitedInformation = uint32(0x1000)
stillActive = uint32(259)
)
// isProcessRunning checks whether a process with the given PID is alive
// on Windows using OpenProcess + GetExitCodeProcess.
func isProcessRunning(pid int) bool {
if pid <= 0 {
return false
}
handle, _, err := procOpenProcess.Call(
uintptr(processQueryLimitedInformation),
0,
uintptr(pid),
)
if handle == 0 || err != nil {
return false
}
defer procCloseHandle.Call(handle)
var exitCode uint32
ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
if ret == 0 || err != nil {
return false
}
return exitCode == stillActive
}