Merge pull request #2134 from cytown/t3

add pid file for gateway running and auth token for /reload and pico channel
This commit is contained in:
daming大铭
2026-03-30 18:55:01 +08:00
committed by GitHub
25 changed files with 876 additions and 263 deletions
+20 -54
View File
@@ -7,9 +7,7 @@ package ui
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
@@ -17,63 +15,30 @@ import (
"github.com/gdamore/tcell/v2"
"github.com/rivo/tview"
)
const pidFileName = "gateway.pid"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
type gatewayStatus struct {
running bool
pid int
version string
}
func getPidPath() string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
return filepath.Join(home, ".picoclaw", pidFileName)
}
func isProcessRunning(pid int) bool {
switch runtime.GOOS {
case "windows":
cmd := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %d", pid))
output, err := cmd.Output()
if err != nil {
return false
}
return strings.Contains(string(output), strconv.Itoa(pid))
case "darwin":
cmd := exec.Command("ps", "aux")
output, err := cmd.Output()
if err != nil {
return false
}
return strings.Contains(string(output), fmt.Sprintf(" %d ", pid))
default:
// Linux and other unix-like systems.
_, err := os.Stat(fmt.Sprintf("/proc/%d", pid))
return err == nil
}
func picoHome() string {
return config.GetHome()
}
func getGatewayStatus() gatewayStatus {
pidPath := getPidPath()
data, err := os.ReadFile(pidPath)
if err != nil {
return gatewayStatus{running: false}
}
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
if err != nil {
return gatewayStatus{running: false}
}
if !isProcessRunning(pid) {
os.Remove(pidPath)
data := ppid.ReadPidFileWithCheck(picoHome())
if data == nil {
return gatewayStatus{running: false}
}
return gatewayStatus{
running: true,
pid: pid,
pid: data.PID,
version: data.Version,
}
}
@@ -83,13 +48,12 @@ func startGateway() error {
return fmt.Errorf("gateway is already running (PID: %d)", status.pid)
}
pidPath := getPidPath()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.Command("cmd", "/C", "start /B picoclaw gateway > NUL 2>&1")
} else {
cmd = exec.Command("sh", "-c", "nohup picoclaw gateway > /dev/null 2>&1 & echo $! > "+pidPath)
cmd = exec.Command("sh", "-c", "nohup picoclaw gateway > /dev/null 2>&1 &")
}
err := cmd.Start()
@@ -118,9 +82,8 @@ func startGateway() error {
if line == "" {
continue
}
pid, err := strconv.Atoi(line)
_, err := strconv.Atoi(line)
if err == nil {
os.WriteFile(pidPath, []byte(strconv.Itoa(pid)), 0o600)
break
}
}
@@ -143,21 +106,20 @@ func stopGateway() error {
if runtime.GOOS == "windows" {
err = exec.Command("taskkill", "/F", "/PID", strconv.Itoa(status.pid)).Run()
} else {
err = exec.Command("kill", "-9", strconv.Itoa(status.pid)).Run()
err = exec.Command("kill", strconv.Itoa(status.pid)).Run()
}
if err != nil {
return err
}
// 多次尝试确认进程已停止
// Wait for process to stop (ReadPidFileWithCheck cleans up stale pid file)
for i := 0; i < 5; i++ {
if !isProcessRunning(status.pid) {
if !getGatewayStatus().running {
break
}
time.Sleep(200 * time.Millisecond)
}
os.Remove(getPidPath())
return nil
}
@@ -219,7 +181,11 @@ func (a *App) newGatewayPage() tview.Primitive {
updateStatus = func() {
status := getGatewayStatus()
if status.running {
statusTV.SetText(fmt.Sprintf("[#39ff14::b]GATEWAY RUNNING[-]\n\nPID: %d", status.pid))
versionInfo := ""
if status.version != "" {
versionInfo = fmt.Sprintf("\nVersion: %s", status.version)
}
statusTV.SetText(fmt.Sprintf("[#39ff14::b]GATEWAY RUNNING[-]\n\nPID: %d%s", status.pid, versionInfo))
buttons.SetItemText(0, " [gray]START[white] ", "")
buttons.SetItemText(1, " [red]STOP[white] ", "")
} else {
+1 -5
View File
@@ -14,11 +14,7 @@ const Logo = pkg.Logo
// GetPicoclawHome returns the picoclaw home directory.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
func GetPicoclawHome() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, _ := os.UserHomeDir()
return filepath.Join(home, pkg.DefaultPicoClawHome)
return config.GetHome()
}
func GetConfigPath() string {
+1 -9
View File
@@ -12,7 +12,6 @@ import (
"sync"
"time"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
@@ -59,14 +58,7 @@ func (cb *ContextBuilder) WithSplitOnMarker(enabled bool) *ContextBuilder {
}
func getGlobalConfigDir() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, pkg.DefaultPicoClawHome)
return config.GetHome()
}
func NewContextBuilder(workspace string) *ContextBuilder {
+1 -6
View File
@@ -6,7 +6,6 @@ import (
"path/filepath"
"time"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/fileutil"
)
@@ -41,11 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool {
}
func authFilePath() string {
if home := os.Getenv(config.EnvHome); home != "" {
return filepath.Join(home, "auth.json")
}
home, _ := os.UserHomeDir()
return filepath.Join(home, pkg.DefaultPicoClawHome, "auth.json")
return filepath.Join(config.GetHome(), "auth.json")
}
func LoadStore() (*AuthStore, error) {
+2
View File
@@ -17,6 +17,8 @@ const (
TypeTypingStop = "typing.stop"
TypeError = "error"
TypePong = "pong"
PicoTokenPrefix = "pico-"
)
// PicoMessage is the wire format for all Pico Protocol messages.
+1 -5
View File
@@ -41,11 +41,7 @@ type contextTokensFile struct {
}
func picoclawHomeDir() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
userHome, _ := os.UserHomeDir()
return filepath.Join(userHome, ".picoclaw")
return config.GetHome()
}
func genWeixinAccountKey(cfg config.WeixinConfig) string {
+1 -6
View File
@@ -1031,12 +1031,7 @@ func LoadConfig(path string) (*Config, error) {
// Ensure Workspace has a default if not set
if cfg.Agents.Defaults.Workspace == "" {
homePath, _ := os.UserHomeDir()
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else if homePath != "" {
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
}
homePath := GetHome()
cfg.Agents.Defaults.Workspace = filepath.Join(homePath, pkg.WorkspaceName)
}
+1 -11
View File
@@ -6,7 +6,6 @@
package config
import (
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg"
@@ -14,16 +13,7 @@ import (
// DefaultConfig returns the default configuration for PicoClaw.
func DefaultConfig() *Config {
// Determine the base path for the workspace.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
var homePath string
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else {
userHome, _ := os.UserHomeDir()
homePath = filepath.Join(userHome, pkg.DefaultPicoClawHome)
}
workspacePath := filepath.Join(homePath, pkg.WorkspaceName)
workspacePath := filepath.Join(GetHome(), pkg.WorkspaceName)
return &Config{
Version: CurrentVersion,
+20
View File
@@ -5,6 +5,13 @@
package config
import (
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg"
)
// Runtime environment variable keys for the picoclaw process.
// These control the location of files and binaries at runtime and are read
// directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the
@@ -35,3 +42,16 @@ const (
// Default: "127.0.0.1"
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
)
func GetHome() string {
homePath, _ := os.UserHomeDir()
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
homePath = picoclawHome
} else if homePath != "" {
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
}
if homePath == "" {
homePath = "."
}
return homePath
}
+37 -3
View File
@@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"syscall"
@@ -21,7 +22,7 @@ import (
_ "github.com/sipeed/picoclaw/pkg/channels/line"
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
@@ -36,6 +37,7 @@ import (
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/pid"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -61,6 +63,7 @@ type services struct {
HealthServer *health.Server
manualReloadChan chan struct{}
reloading atomic.Bool
authToken string
}
type startupBlockedProvider struct {
@@ -113,6 +116,13 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
logger.Infof("Log level set to %q", cfg.Gateway.LogLevel)
}
// Enforce singleton: write PID file with generated token.
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
if err != nil {
return fmt.Errorf("singleton check failed: %w", err)
}
defer pid.RemovePidFile(homePath)
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
if err != nil {
return fmt.Errorf("error creating provider: %w", err)
@@ -139,7 +149,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
"skills_available": skillsInfo["available"],
})
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
if err != nil {
return err
}
@@ -238,6 +248,9 @@ func executeReload(
debug bool,
) error {
defer runningServices.reloading.Store(false)
overridePicoToken(newCfg, runningServices.authToken)
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
}
@@ -262,6 +275,7 @@ func setupAndStartServices(
cfg *config.Config,
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
authToken string,
) (*services, error) {
runningServices := &services{}
@@ -304,6 +318,8 @@ func setupAndStartServices(
fms.Start()
}
overridePicoToken(cfg, authToken)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
@@ -328,7 +344,8 @@ func setupAndStartServices(
}
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
runningServices.authToken = authToken
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
@@ -547,6 +564,9 @@ func restartServices(
logger.InfoCF("voice", "Transcription disabled", nil)
}
// NOTE: PID file is written once at startup and not updated on reload.
// Changing the gateway listen address requires a full restart.
return nil
}
@@ -665,6 +685,20 @@ func setupCronTool(
return cronService, nil
}
// overridePicoToken replaces the pico channel token with the one from the PID file.
// The PID file is the single source of truth for the pico auth token;
// it is generated once at gateway startup and remains unchanged across reloads.
func overridePicoToken(cfg *config.Config, token string) {
if !cfg.Channels.Pico.Enabled {
return
}
picoToken := cfg.Channels.Pico.Token.String()
if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) {
return
}
cfg.Channels.Pico.SetToken(pico.PicoTokenPrefix + token + picoToken)
}
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
return func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
+32 -4
View File
@@ -2,11 +2,11 @@ package health
import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"maps"
"net/http"
"os"
"sync"
"time"
)
@@ -18,6 +18,7 @@ type Server struct {
checks map[string]Check
startTime time.Time
reloadFunc func() error
authToken string // optional bearer token for protected endpoints
}
type Check struct {
@@ -31,15 +32,15 @@ type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
Checks map[string]Check `json:"checks,omitempty"`
Pid int `json:"pid"`
}
func NewServer(host string, port int) *Server {
func NewServer(host string, port int, token string) *Server {
mux := http.NewServeMux()
s := &Server{
ready: false,
checks: make(map[string]Check),
startTime: time.Now(),
authToken: token,
}
mux.HandleFunc("/health", s.healthHandler)
@@ -123,6 +124,21 @@ func (s *Server) reloadHandler(w http.ResponseWriter, r *http.Request) {
return
}
// Token check
s.mu.RLock()
requiredToken := s.authToken
s.mu.RUnlock()
if requiredToken != "" {
given := extractBearerToken(r.Header.Get("Authorization"))
if given == "" || subtle.ConstantTimeCompare([]byte(given), []byte(requiredToken)) != 1 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
return
}
}
s.mu.Lock()
reloadFunc := s.reloadFunc
s.mu.Unlock()
@@ -154,7 +170,6 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
Pid: os.Getpid(),
}
json.NewEncoder(w).Encode(resp)
@@ -220,3 +235,16 @@ func statusString(ok bool) string {
}
return "fail"
}
// extractBearerToken returns the token from an "Authorization: Bearer <t>" header,
// or the empty string if the header is missing or malformed.
func extractBearerToken(header string) string {
const prefix = "Bearer "
if len(header) < len(prefix) {
return ""
}
if header[:len(prefix)] != prefix {
return ""
}
return header[len(prefix):]
}
+6 -5
View File
@@ -15,6 +15,7 @@ func newTestServer() *Server {
ready: false,
checks: make(map[string]Check),
startTime: time.Now(),
authToken: "test",
}
return s
}
@@ -37,9 +38,6 @@ func TestHealthHandler_ReturnsOK(t *testing.T) {
if resp.Status != "ok" {
t.Errorf("status = %q, want %q", resp.Status, "ok")
}
if resp.Pid == 0 {
t.Error("pid should not be 0")
}
if resp.Uptime == "" {
t.Error("uptime should not be empty")
}
@@ -168,6 +166,7 @@ func TestReloadHandler_NoReloadFunc(t *testing.T) {
s := newTestServer()
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
req.Header.Set("Authorization", "Bearer test")
w := httptest.NewRecorder()
s.reloadHandler(w, req)
@@ -186,6 +185,7 @@ func TestReloadHandler_Success(t *testing.T) {
})
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
req.Header.Set("Authorization", "Bearer test")
w := httptest.NewRecorder()
s.reloadHandler(w, req)
@@ -205,6 +205,7 @@ func TestReloadHandler_Error(t *testing.T) {
})
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
req.Header.Set("Authorization", "Bearer test")
w := httptest.NewRecorder()
s.reloadHandler(w, req)
@@ -292,7 +293,7 @@ func TestRegisterOnMux(t *testing.T) {
}
func TestNewServer(t *testing.T) {
s := NewServer("127.0.0.1", 0)
s := NewServer("127.0.0.1", 0, "")
if s == nil {
t.Fatal("NewServer returned nil")
}
@@ -305,7 +306,7 @@ func TestNewServer(t *testing.T) {
}
func TestStartContext_Cancellation(t *testing.T) {
s := NewServer("127.0.0.1", 0)
s := NewServer("127.0.0.1", 0, "")
ctx, cancel := context.WithCancel(context.Background())
+1 -10
View File
@@ -1,12 +1,10 @@
package internal
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -14,14 +12,7 @@ func ResolveTargetHome(override string) (string, error) {
if override != "" {
return ExpandHome(override), nil
}
if envHome := os.Getenv(config.EnvHome); envHome != "" {
return ExpandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, pkg.DefaultPicoClawHome), nil
return config.GetHome(), nil
}
func ExpandHome(path string) string {
+159
View File
@@ -0,0 +1,159 @@
package pid
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
const pidFileName = ".picoclaw.pid"
// PidFileData is the JSON structure stored in the PID file.
type PidFileData struct {
PID int `json:"pid"`
Token string `json:"token"`
Version string `json:"version"`
Port int `json:"port"`
Host string `json:"host"`
}
var pidMu sync.Mutex
// pidFilePath returns the absolute path for the PID file given the home directory.
func pidFilePath(homePath string) string {
return filepath.Join(homePath, pidFileName)
}
// generateToken creates a cryptographically random 32-character hex token.
func generateToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// Fallback to something pseudo-random if crypto/rand fails
return fmt.Sprintf("%032x", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
// WritePidFile creates (or overwrites) the PID file atomically.
// It returns an error if another gateway instance appears to be running
// (a valid PID file exists with a live process).
func WritePidFile(homePath, host string, port int) (*PidFileData, error) {
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
// Check for existing PID file → singleton enforcement.
if data, err := readPidFileUnlocked(pidPath); err == nil {
if os.Getpid() != data.PID {
logger.Infof("found pid file (PID: %d, version: %s)", data.PID, data.Version)
if isProcessRunning(data.PID) {
return nil, fmt.Errorf("gateway is already running (PID: %d, version: %s)", data.PID, data.Version)
}
logger.Warnf("not running (PID: %d) so will remove the pid file: %s", data.PID, pidPath)
}
// Stale PID file; process no longer exists → clean up.
os.Remove(pidPath)
}
data := &PidFileData{
PID: os.Getpid(),
Version: config.GetVersion(),
Port: port,
Host: host,
}
token := generateToken()
data.Token = token
raw, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("failed to marshal pid file: %w", err)
}
// Ensure parent directory exists.
dir := filepath.Dir(pidPath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("failed to create pid directory: %w", err)
}
// Write atomically via temp file + rename.
tmp := pidPath + ".tmp"
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
return nil, fmt.Errorf("failed to write pid file: %w", err)
}
if err := os.Rename(tmp, pidPath); err != nil {
os.Remove(tmp)
return nil, fmt.Errorf("failed to rename pid file: %w", err)
}
return data, nil
}
// ReadPidFileWithCheck reads the PID file and additionally checks if
// the recorded process is still alive. Returns nil if the file is
// missing, unreadable, or the process has exited.
func ReadPidFileWithCheck(homePath string) *PidFileData {
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
data, err := readPidFileUnlocked(pidPath)
if err != nil {
return nil
}
if !isProcessRunning(data.PID) {
os.Remove(pidPath)
return nil
}
return data
}
// RemovePidFile deletes the PID file (e.g. on graceful shutdown).
func RemovePidFile(homePath string) {
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
// Only remove if the PID matches our own process (avoid deleting
// a file that belongs to a newer gateway instance).
if data, err := readPidFileUnlocked(pidPath); err == nil {
if data.PID != os.Getpid() {
return
}
}
logger.Infof("remove pid file: %s", pidPath)
os.Remove(pidPath)
}
// readPidFileUnlocked reads the PID file without acquiring the lock.
// Caller must hold pidMu.
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
raw, err := os.ReadFile(pidPath)
if err != nil {
return nil, err
}
var data PidFileData
if err := json.Unmarshal(raw, &data); err != nil {
return nil, err
}
// Validate PID is a positive integer.
if data.PID <= 0 {
return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
}
return &data, nil
}
+253
View File
@@ -0,0 +1,253 @@
package pid
import (
"encoding/json"
"os"
"path/filepath"
"testing"
)
// tmpDir returns a clean temporary directory for a test.
func tmpDir(t *testing.T) string {
t.Helper()
dir, err := os.MkdirTemp("", "pidtest-*")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { os.RemoveAll(dir) })
return dir
}
// TestGenerateToken verifies that generateToken produces a 32-character hex string.
func TestGenerateToken(t *testing.T) {
token := generateToken()
if len(token) != 32 {
t.Errorf("expected token length 32, got %d (token: %q)", len(token), token)
}
// Verify all characters are valid hex.
for _, c := range token {
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
t.Errorf("token contains non-hex character: %c", c)
}
}
}
// TestGenerateTokenUniqueness checks that two consecutive tokens differ.
func TestGenerateTokenUniqueness(t *testing.T) {
a := generateToken()
b := generateToken()
if a == b {
t.Error("two consecutive tokens should not be equal")
}
}
// TestPidFilePath returns the expected path.
func TestPidFilePath(t *testing.T) {
dir := tmpDir(t)
got := pidFilePath(dir)
want := filepath.Join(dir, pidFileName)
if got != want {
t.Errorf("pidFilePath(%q) = %q, want %q", dir, got, want)
}
}
// TestWritePidFile creates a PID file and verifies its contents.
func TestWritePidFile(t *testing.T) {
dir := tmpDir(t)
data, err := WritePidFile(dir, "127.0.0.1", 18790)
if err != nil {
t.Fatalf("WritePidFile failed: %v", err)
}
if data.PID != os.Getpid() {
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
}
if data.Host != "127.0.0.1" {
t.Errorf("Host = %q, want %q", data.Host, "127.0.0.1")
}
if data.Port != 18790 {
t.Errorf("Port = %d, want %d", data.Port, 18790)
}
if len(data.Token) != 32 {
t.Errorf("Token length = %d, want 32", len(data.Token))
}
// Verify the file exists and can be unmarshalled.
raw, err := os.ReadFile(filepath.Join(dir, pidFileName))
if err != nil {
t.Fatalf("failed to read pid file: %v", err)
}
var fileData PidFileData
if err = json.Unmarshal(raw, &fileData); err != nil {
t.Fatalf("failed to unmarshal pid file: %v", err)
}
if fileData.PID != data.PID || fileData.Token != data.Token {
t.Error("file data mismatch")
}
// Verify file permissions (owner-only read/write).
info, err := os.Stat(filepath.Join(dir, pidFileName))
if err != nil {
t.Fatalf("failed to stat pid file: %v", err)
}
perm := info.Mode().Perm()
if perm != 0o600 {
t.Errorf("file permission = %o, want 0600", perm)
}
}
// TestWritePidFileOverwrite writes twice and verifies the PID file is replaced.
func TestWritePidFileOverwrite(t *testing.T) {
dir := tmpDir(t)
data1, err := WritePidFile(dir, "0.0.0.0", 18790)
if err != nil {
t.Fatalf("first WritePidFile failed: %v", err)
}
// Second write should succeed because the PID matches our process.
data2, err := WritePidFile(dir, "0.0.0.0", 18800)
if err != nil {
t.Fatalf("second WritePidFile failed: %v", err)
}
if data2.Token == data1.Token {
t.Error("token should change on re-write")
}
if data2.Port != 18800 {
t.Errorf("Port = %d, want 18800", data2.Port)
}
}
// TestWritePidFileStalePID writes a PID file with a non-running PID, then
// verifies WritePidFile cleans it up and writes a new one.
func TestWritePidFileStalePID(t *testing.T) {
dir := tmpDir(t)
// Write a PID file with a PID that almost certainly doesn't exist.
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(stale, "", " ")
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
data, err := WritePidFile(dir, "127.0.0.1", 18790)
if err != nil {
t.Fatalf("WritePidFile with stale PID failed: %v", err)
}
if data.PID != os.Getpid() {
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
}
}
// TestReadPidFileWithCheck verifies reading a valid PID file for the current process.
func TestReadPidFileWithCheck(t *testing.T) {
dir := tmpDir(t)
// Some sandboxed environments (e.g. macOS test runner) may restrict
// signal(0), causing isProcessRunning(getpid()) to return false.
if !isProcessRunning(os.Getpid()) {
t.Skip("skipping: isProcessRunning(getpid()) is false in this environment")
}
written, err := WritePidFile(dir, "127.0.0.1", 18790)
if err != nil {
t.Fatalf("WritePidFile failed: %v", err)
}
read := ReadPidFileWithCheck(dir)
if read == nil {
t.Fatal("ReadPidFileWithCheck returned nil for current process")
}
if read.PID != written.PID || read.Token != written.Token {
t.Error("read data doesn't match written data")
}
}
// TestReadPidFileWithCheckNonexistent returns nil for missing file.
func TestReadPidFileWithCheckNonexistent(t *testing.T) {
dir := tmpDir(t)
data := ReadPidFileWithCheck(dir)
if data != nil {
t.Error("expected nil for nonexistent PID file")
}
}
// TestReadPidFileWithCheckStalePID auto-cleans a PID file whose process is dead.
func TestReadPidFileWithCheckStalePID(t *testing.T) {
dir := tmpDir(t)
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(stale, "", " ")
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
data := ReadPidFileWithCheck(dir)
if data != nil {
t.Error("expected nil for stale PID")
}
// File should be cleaned up.
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
t.Error("stale PID file should be removed")
}
}
// TestRemovePidFile removes the PID file for the current process.
func TestRemovePidFile(t *testing.T) {
dir := tmpDir(t)
if _, err := WritePidFile(dir, "127.0.0.1", 18790); err != nil {
t.Fatalf("WritePidFile failed: %v", err)
}
RemovePidFile(dir)
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
t.Error("PID file should be removed")
}
}
// TestRemovePidFileDifferentPID does not remove a PID file owned by another process.
func TestRemovePidFileDifferentPID(t *testing.T) {
dir := tmpDir(t)
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(other, "", " ")
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
RemovePidFile(dir)
if _, err := os.Stat(filepath.Join(dir, pidFileName)); os.IsNotExist(err) {
t.Error("PID file should NOT be removed (different PID)")
}
}
// TestRemovePidFileNonexistent does not error on missing file.
func TestRemovePidFileNonexistent(t *testing.T) {
dir := tmpDir(t)
// Should not panic or error.
RemovePidFile(dir)
}
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
dir := tmpDir(t)
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, []byte("not json"), 0o600)
_, err := readPidFileUnlocked(path)
if err == nil {
t.Error("expected error for invalid JSON")
}
}
// TestReadPidFileUnlockedInvalidPID returns error for non-positive PID.
func TestReadPidFileUnlockedInvalidPID(t *testing.T) {
dir := tmpDir(t)
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, []byte(`{"pid": -1, "token": "a"}`), 0o600)
_, err := readPidFileUnlocked(path)
if err == nil {
t.Error("expected error for invalid PID")
}
}
+22
View File
@@ -0,0 +1,22 @@
//go:build !windows
package pid
import (
"os"
"syscall"
)
// isProcessRunning checks whether a process with the given PID is alive
// on Unix-like systems using signal(0).
func isProcessRunning(pid int) bool {
if pid <= 0 {
return false
}
p, err := os.FindProcess(pid)
if err != nil {
return false
}
// Signal(nil) does not kill the process but checks existence on Unix.
return p.Signal(syscall.Signal(0)) == nil
}
+42
View File
@@ -0,0 +1,42 @@
//go:build windows
package pid
import (
"syscall"
"unsafe"
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procOpenProcess = kernel32.NewProc("OpenProcess")
procGetExitCodeProcess = kernel32.NewProc("GetExitCodeProcess")
procCloseHandle = kernel32.NewProc("CloseHandle")
processQueryLimitedInformation = uint32(0x1000)
stillActive = uint32(259)
)
// isProcessRunning checks whether a process with the given PID is alive
// on Windows using OpenProcess + GetExitCodeProcess.
func isProcessRunning(pid int) bool {
if pid <= 0 {
return false
}
handle, _, err := procOpenProcess.Call(
uintptr(processQueryLimitedInformation),
0,
uintptr(pid),
)
if handle == 0 || err != nil {
return false
}
defer procCloseHandle.Call(handle)
var exitCode uint32
ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
if ret == 0 || err != nil {
return false
}
return exitCode == stillActive
}
+6
View File
@@ -87,6 +87,9 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token in case user changed it.
refreshPicoToken(&cfg)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
@@ -182,6 +185,9 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token in case user changed it.
refreshPicoToken(&newCfg)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
}
+188 -121
View File
@@ -17,9 +17,11 @@ import (
"syscall"
"time"
"github.com/sipeed/picoclaw/pkg/channels/pico"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/health"
"github.com/sipeed/picoclaw/pkg/logger"
ppid "github.com/sipeed/picoclaw/pkg/pid"
"github.com/sipeed/picoclaw/web/backend/utils"
)
@@ -33,11 +35,49 @@ var gateway = struct {
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json
picoToken string // cached pico token from config (for proxy auth validation)
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
}
// refreshPicoToken updates gateway.picoToken from cfg
func refreshPicoToken(cfg *config.Config) {
gateway.mu.Lock()
defer gateway.mu.Unlock()
gateway.picoToken = cfg.Channels.Pico.Token.String()
}
// refreshPicoTokensLocked reads the pico token from config and caches it.
// Caller must hold gateway.mu (or be sole writer).
func refreshPicoTokensLocked(configPath string) {
cfg, err := config.LoadConfig(configPath)
if err != nil {
return
}
gateway.picoToken = cfg.Channels.Pico.Token.String()
}
const (
protocolKey = "Sec-Websocket-Protocol"
tokenPrefix = "token."
)
// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth.
func picoComposedToken(token string) string {
gateway.mu.Lock()
defer gateway.mu.Unlock()
// if not initial pico token, don't allow gateway auth
if gateway.picoToken == "" || gateway.pidData == nil {
return ""
}
if tokenPrefix+gateway.picoToken != token {
return ""
}
return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken
}
var (
gatewayStartupWindow = 15 * time.Second
gatewayRestartGracePeriod = 5 * time.Second
@@ -50,16 +90,29 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
return client.Get(url)
}
// getGatewayHealth checks the gateway health endpoint and returns the status response
// getGatewayHealth checks the gateway health endpoint and returns the status response.
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
port := 18790
if cfg != nil && cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
// Prefer port/host from pidData when available.
var port int
var host string
gateway.mu.Lock()
if d := gateway.pidData; d != nil && d.Port > 0 {
port = d.Port
host = d.Host
}
gateway.mu.Unlock()
if port == 0 {
port = 18790
if cfg != nil && cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
}
}
if host == "" {
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
}
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health"
return getGatewayHealthByURL(url, timeout)
}
@@ -92,30 +145,33 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
// TryAutoStartGateway checks whether gateway start preconditions are met and
// starts it when possible. Intended to be called by the backend at startup.
func (h *Handler) TryAutoStartGateway() {
// Check if gateway is already running via health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
gateway.mu.Lock()
defer gateway.mu.Unlock()
ready, reason, err := h.gatewayStartReady()
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err))
return
}
if !ready {
logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason))
return
}
_, err = h.startGatewayLocked("starting", pid)
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
}
// Check PID file first to detect an already-running gateway.
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
if pidData != nil {
gateway.mu.Lock()
ready, reason, err := h.gatewayStartReady()
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err))
gateway.mu.Unlock()
return
}
logger.Infof("ready: %v, reason: %s", ready, reason)
if !ready {
logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason))
gateway.mu.Unlock()
return
}
pid := pidData.PID
_, err = h.startGatewayLocked("starting", pid)
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
} else {
gateway.pidData = pidData
refreshPicoTokensLocked(h.configPath)
logger.InfoC("gateway", fmt.Sprintf("Attached to running gateway via PID file (PID: %d)", pid))
}
gateway.mu.Unlock()
return
}
gateway.mu.Lock()
@@ -400,6 +456,7 @@ func stopGatewayLocked() (int, error) {
gateway.cmd = nil
gateway.owned = false
gateway.bootDefaultModel = ""
gateway.pidData = nil
setGatewayRuntimeStatusLocked("stopped")
return pid, nil
@@ -452,6 +509,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
pid = existingPid
gateway.cmd = nil // Clear first to ensure clean state
if err = attachToGatewayProcessLocked(pid, cfg); err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to existing gateway (PID %d): %v", pid, err))
return 0, err
}
@@ -461,6 +519,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
// Start new process
// Locate the picoclaw executable
execPath := utils.FindPicoclawBinary()
logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath))
cmd = exec.Command(execPath, "gateway", "-E")
cmd.Env = os.Environ()
@@ -488,10 +547,16 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.logs.Reset()
// Ensure Pico Channel is configured before starting gateway
if _, err := h.EnsurePicoChannel(""); err != nil {
changed, err := h.EnsurePicoChannel("")
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
// Non-fatal: gateway can still start without pico channel
}
// Refresh cached pico token in case EnsurePicoChannel generated a new one.
// Already holding gateway.mu from caller.
if changed {
refreshPicoTokensLocked(h.configPath)
}
if err := cmd.Start(); err != nil {
return 0, fmt.Errorf("failed to start gateway: %w", err)
@@ -529,7 +594,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.mu.Unlock()
}()
// Start a goroutine to probe health and update the runtime state once ready.
// Start a goroutine to probe pidFile and health, update runtime state once ready.
go func() {
for i := 0; i < 30; i++ { // try for up to 15 seconds
time.Sleep(500 * time.Millisecond)
@@ -539,13 +604,26 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
if !stillOurs {
return
}
// Poll for pidFile first — once available we have port/host/token.
if pd := ppid.ReadPidFileWithCheck(globalConfigDir()); pd != nil && pd.PID == pid {
gateway.mu.Lock()
if gateway.cmd == cmd {
gateway.pidData = pd
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
logger.InfoC("gateway", fmt.Sprintf("Gateway pidFile detected (PID: %d, port: %d)", pd.PID, pd.Port))
return
}
// Fallback: probe health endpoint to confirm liveness.
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
continue
}
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
// Verify the health endpoint returns the expected pid
_, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
if err == nil && statusCode == http.StatusOK {
gateway.mu.Lock()
if gateway.cmd == cmd {
setGatewayRuntimeStatusLocked("running")
@@ -563,49 +641,47 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
//
// POST /api/gateway/start
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
// Prevent duplicate starts by checking health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
gateway.mu.Lock()
ready, reason, err := h.gatewayStartReady()
if err != nil {
gateway.mu.Unlock()
http.Error(
w,
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
http.StatusInternalServerError,
)
return
}
if !ready {
gateway.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
"status": "precondition_failed",
"message": reason,
})
return
}
_, err = h.startGatewayLocked("starting", pid)
// Check PID file first to detect an already-running gateway.
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
if pidData != nil {
pid := pidData.PID
gateway.mu.Lock()
ready, reason, err := h.gatewayStartReady()
if err != nil {
gateway.mu.Unlock()
http.Error(
w,
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
http.StatusInternalServerError,
)
return
}
if !ready {
gateway.mu.Unlock()
if err != nil {
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
"status": "ok",
"pid": pid,
"status": "precondition_failed",
"message": reason,
})
return
}
_, err = h.startGatewayLocked("starting", pid)
if err != nil {
gateway.mu.Unlock()
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
return
}
gateway.pidData = pidData
gateway.mu.Unlock()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]any{
"status": "ok",
"pid": pid,
})
return
}
gateway.mu.Lock()
@@ -805,65 +881,56 @@ func (h *Handler) gatewayStatusData() map[string]any {
}
}
// Probe health endpoint to get pid and status
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err != nil {
// Primary detection: read PID file and check if process is alive.
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
if pidData != nil {
gateway.mu.Lock()
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
gateway.pidData = pidData
if pidData.Version != "" {
data["gateway_version"] = pidData.Version
}
setGatewayRuntimeStatusLocked("running")
// Attach if we don't already track this PID.
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != pidData.PID {
_ = attachToGatewayProcessLocked(pidData.PID, cfg)
}
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = pidData.PID
gateway.mu.Unlock()
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
} else {
if statusCode != http.StatusOK {
logger.WarnC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
// Fallback: probe health endpoint to get pid and status
_, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err != nil {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
gateway.pidData = nil
gateway.mu.Unlock()
data["gateway_status"] = "error"
data["status_code"] = statusCode
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
} else {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("running")
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
oldPid := "none"
if gateway.cmd != nil && gateway.cmd.Process != nil {
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
}
logger.InfoC(
"gateway",
fmt.Sprintf(
"Detected new gateway PID (old: %s, new: %d), attempting to attach",
oldPid,
healthResp.Pid,
),
)
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
// Failed to find the process, treat as error
setGatewayRuntimeStatusLocked("error")
data["gateway_status"] = "error"
data["pid"] = healthResp.Pid
logger.ErrorC(
"gateway",
fmt.Sprintf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err),
)
} else {
// Successfully attached, update response data
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = healthResp.Pid
logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
if statusCode != http.StatusOK {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
gateway.pidData = nil
gateway.mu.Unlock()
data["gateway_status"] = "error"
data["status_code"] = statusCode
} else {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("running")
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
gateway.mu.Unlock()
}
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = healthResp.Pid
gateway.mu.Unlock()
}
}
-3
View File
@@ -470,9 +470,6 @@ func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["pid"]; got != float64(cmd.Process.Pid) {
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
}
if got := body["gateway_restart_required"]; got != false {
t.Fatalf("gateway_restart_required = %#v, want false", got)
}
+53 -8
View File
@@ -10,6 +10,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
@@ -26,20 +27,55 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
// The gateway bind host and port are resolved from the latest configuration.
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy {
wsProxy := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
target := h.gatewayProxyURL()
r.SetURL(target)
r.Out.Header.Set(protocolKey, tokenPrefix+token)
},
ModifyResponse: func(r *http.Response) error {
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
r.Header.Del(protocolKey)
if origProtocol != "" {
r.Header.Set(protocolKey, origProtocol)
}
}
return nil
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
logger.Errorf("Failed to proxy WebSocket: %v", err)
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
},
}
return wsProxy
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
// The reverse proxy forwards the incoming upgrade handshake as-is.
// It validates the client token before forwarding; rejects immediately on failure.
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
proxy := h.createWsProxy()
proxy.ServeHTTP(w, r)
gateway.mu.Lock()
gatewayAvailable := gateway.pidData != nil
gateway.mu.Unlock()
if !gatewayAvailable {
logger.Warnf("Gateway not available for WebSocket proxy")
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
return
}
prot := r.Header.Values(protocolKey)
if len(prot) > 0 {
origProtocol := prot[0]
newToken := picoComposedToken(prot[0])
if newToken != "" {
h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r)
return
}
}
logger.Warnf("Invalid Pico token: %v", prot)
http.Error(w, "Invalid Pico token", http.StatusForbidden)
}
}
@@ -81,6 +117,11 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
return
}
// Refresh cached pico token.
gateway.mu.Lock()
gateway.picoToken = token
gateway.mu.Unlock()
wsURL := h.buildWsURL(r)
w.Header().Set("Content-Type", "application/json")
@@ -140,11 +181,15 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
return
}
// Reload config (EnsurePicoChannel may have modified it) and refresh cache.
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
return
}
if changed {
refreshPicoToken(cfg)
}
wsURL := h.buildWsURL(r)
@@ -162,7 +207,7 @@ func generateSecureToken() string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// Fallback to something pseudo-random if crypto/rand fails
return fmt.Sprintf("pico_%x", time.Now().UnixNano())
return fmt.Sprintf("%032x", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}
+14
View File
@@ -12,6 +12,7 @@ import (
"testing"
"github.com/sipeed/picoclaw/pkg/config"
ppid "github.com/sipeed/picoclaw/pkg/pid"
)
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
@@ -335,10 +336,22 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
t.Fatalf("SaveConfig() error = %v", err)
}
gateway.pidData = &ppid.PidFileData{}
gateway.picoToken = "pico"
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req1.Header.Set(protocolKey, tokenPrefix+"wrong_token")
rec1 := httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusForbidden {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden)
}
req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req1.Header.Set(protocolKey, tokenPrefix+"pico")
rec1 = httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
}
@@ -352,6 +365,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
}
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
req2.Header.Set(protocolKey, tokenPrefix+"pico")
rec2 := httptest.NewRecorder()
handler(rec2, req2)
+1 -8
View File
@@ -309,14 +309,7 @@ func loadSkillContent(path string) (string, error) {
}
func globalConfigDir() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".picoclaw")
return config.GetHome()
}
func builtinSkillsDir() string {
+11
View File
@@ -1,7 +1,9 @@
package middleware
import (
"bufio"
"fmt"
"net"
"net/http"
"runtime/debug"
"time"
@@ -44,6 +46,15 @@ func (rr *responseRecorder) Unwrap() http.ResponseWriter {
return rr.ResponseWriter
}
// Hijack implements http.Hijacker so that WebSocket upgrades work through
// the middleware layer.
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := rr.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, http.ErrNotSupported
}
// Logger logs each HTTP request with method, path, status code, and duration.
func Logger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+3 -5
View File
@@ -9,16 +9,13 @@ import (
"runtime"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// GetPicoclawHome returns the picoclaw home directory.
// Priority: $PICOCLAW_HOME > ~/.picoclaw
func GetPicoclawHome() string {
if home := os.Getenv(config.EnvHome); home != "" {
return home
}
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw")
return config.GetHome()
}
// GetDefaultConfigPath returns the default path to the picoclaw config file.
@@ -47,6 +44,7 @@ func FindPicoclawBinary() string {
}
if exe, err := os.Executable(); err == nil {
logger.Debugf("Trying to find picoclaw binary in %s", exe)
candidate := filepath.Join(filepath.Dir(exe), binaryName)
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
return candidate