mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #2134 from cytown/t3
add pid file for gateway running and auth token for /reload and pico channel
This commit is contained in:
@@ -7,9 +7,7 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -17,63 +15,30 @@ import (
|
||||
|
||||
"github.com/gdamore/tcell/v2"
|
||||
"github.com/rivo/tview"
|
||||
)
|
||||
|
||||
const pidFileName = "gateway.pid"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
)
|
||||
|
||||
type gatewayStatus struct {
|
||||
running bool
|
||||
pid int
|
||||
version string
|
||||
}
|
||||
|
||||
func getPidPath() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw", pidFileName)
|
||||
}
|
||||
|
||||
func isProcessRunning(pid int) bool {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
cmd := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %d", pid))
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(string(output), strconv.Itoa(pid))
|
||||
case "darwin":
|
||||
cmd := exec.Command("ps", "aux")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(string(output), fmt.Sprintf(" %d ", pid))
|
||||
default:
|
||||
// Linux and other unix-like systems.
|
||||
_, err := os.Stat(fmt.Sprintf("/proc/%d", pid))
|
||||
return err == nil
|
||||
}
|
||||
func picoHome() string {
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func getGatewayStatus() gatewayStatus {
|
||||
pidPath := getPidPath()
|
||||
data, err := os.ReadFile(pidPath)
|
||||
if err != nil {
|
||||
return gatewayStatus{running: false}
|
||||
}
|
||||
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
|
||||
if err != nil {
|
||||
return gatewayStatus{running: false}
|
||||
}
|
||||
if !isProcessRunning(pid) {
|
||||
os.Remove(pidPath)
|
||||
data := ppid.ReadPidFileWithCheck(picoHome())
|
||||
if data == nil {
|
||||
return gatewayStatus{running: false}
|
||||
}
|
||||
return gatewayStatus{
|
||||
running: true,
|
||||
pid: pid,
|
||||
pid: data.PID,
|
||||
version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,13 +48,12 @@ func startGateway() error {
|
||||
return fmt.Errorf("gateway is already running (PID: %d)", status.pid)
|
||||
}
|
||||
|
||||
pidPath := getPidPath()
|
||||
var cmd *exec.Cmd
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("cmd", "/C", "start /B picoclaw gateway > NUL 2>&1")
|
||||
} else {
|
||||
cmd = exec.Command("sh", "-c", "nohup picoclaw gateway > /dev/null 2>&1 & echo $! > "+pidPath)
|
||||
cmd = exec.Command("sh", "-c", "nohup picoclaw gateway > /dev/null 2>&1 &")
|
||||
}
|
||||
|
||||
err := cmd.Start()
|
||||
@@ -118,9 +82,8 @@ func startGateway() error {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
pid, err := strconv.Atoi(line)
|
||||
_, err := strconv.Atoi(line)
|
||||
if err == nil {
|
||||
os.WriteFile(pidPath, []byte(strconv.Itoa(pid)), 0o600)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -143,21 +106,20 @@ func stopGateway() error {
|
||||
if runtime.GOOS == "windows" {
|
||||
err = exec.Command("taskkill", "/F", "/PID", strconv.Itoa(status.pid)).Run()
|
||||
} else {
|
||||
err = exec.Command("kill", "-9", strconv.Itoa(status.pid)).Run()
|
||||
err = exec.Command("kill", strconv.Itoa(status.pid)).Run()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 多次尝试确认进程已停止
|
||||
// Wait for process to stop (ReadPidFileWithCheck cleans up stale pid file)
|
||||
for i := 0; i < 5; i++ {
|
||||
if !isProcessRunning(status.pid) {
|
||||
if !getGatewayStatus().running {
|
||||
break
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
os.Remove(getPidPath())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -219,7 +181,11 @@ func (a *App) newGatewayPage() tview.Primitive {
|
||||
updateStatus = func() {
|
||||
status := getGatewayStatus()
|
||||
if status.running {
|
||||
statusTV.SetText(fmt.Sprintf("[#39ff14::b]GATEWAY RUNNING[-]\n\nPID: %d", status.pid))
|
||||
versionInfo := ""
|
||||
if status.version != "" {
|
||||
versionInfo = fmt.Sprintf("\nVersion: %s", status.version)
|
||||
}
|
||||
statusTV.SetText(fmt.Sprintf("[#39ff14::b]GATEWAY RUNNING[-]\n\nPID: %d%s", status.pid, versionInfo))
|
||||
buttons.SetItemText(0, " [gray]START[white] ", "")
|
||||
buttons.SetItemText(1, " [red]STOP[white] ", "")
|
||||
} else {
|
||||
|
||||
@@ -14,11 +14,7 @@ const Logo = pkg.Logo
|
||||
// GetPicoclawHome returns the picoclaw home directory.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
func GetPicoclawHome() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome)
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func GetConfigPath() string {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -59,14 +58,7 @@ func (cb *ContextBuilder) WithSplitOnMarker(enabled bool) *ContextBuilder {
|
||||
}
|
||||
|
||||
func getGlobalConfigDir() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome)
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||
|
||||
+1
-6
@@ -6,7 +6,6 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/fileutil"
|
||||
)
|
||||
@@ -41,11 +40,7 @@ func (c *AuthCredential) NeedsRefresh() bool {
|
||||
}
|
||||
|
||||
func authFilePath() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return filepath.Join(home, "auth.json")
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome, "auth.json")
|
||||
return filepath.Join(config.GetHome(), "auth.json")
|
||||
}
|
||||
|
||||
func LoadStore() (*AuthStore, error) {
|
||||
|
||||
@@ -17,6 +17,8 @@ const (
|
||||
TypeTypingStop = "typing.stop"
|
||||
TypeError = "error"
|
||||
TypePong = "pong"
|
||||
|
||||
PicoTokenPrefix = "pico-"
|
||||
)
|
||||
|
||||
// PicoMessage is the wire format for all Pico Protocol messages.
|
||||
|
||||
@@ -41,11 +41,7 @@ type contextTokensFile struct {
|
||||
}
|
||||
|
||||
func picoclawHomeDir() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
userHome, _ := os.UserHomeDir()
|
||||
return filepath.Join(userHome, ".picoclaw")
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func genWeixinAccountKey(cfg config.WeixinConfig) string {
|
||||
|
||||
@@ -1031,12 +1031,7 @@ func LoadConfig(path string) (*Config, error) {
|
||||
|
||||
// Ensure Workspace has a default if not set
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
homePath, _ := os.UserHomeDir()
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else if homePath != "" {
|
||||
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
|
||||
}
|
||||
homePath := GetHome()
|
||||
cfg.Agents.Defaults.Workspace = filepath.Join(homePath, pkg.WorkspaceName)
|
||||
}
|
||||
|
||||
|
||||
+1
-11
@@ -6,7 +6,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
@@ -14,16 +13,7 @@ import (
|
||||
|
||||
// DefaultConfig returns the default configuration for PicoClaw.
|
||||
func DefaultConfig() *Config {
|
||||
// Determine the base path for the workspace.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
var homePath string
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else {
|
||||
userHome, _ := os.UserHomeDir()
|
||||
homePath = filepath.Join(userHome, pkg.DefaultPicoClawHome)
|
||||
}
|
||||
workspacePath := filepath.Join(homePath, pkg.WorkspaceName)
|
||||
workspacePath := filepath.Join(GetHome(), pkg.WorkspaceName)
|
||||
|
||||
return &Config{
|
||||
Version: CurrentVersion,
|
||||
|
||||
@@ -5,6 +5,13 @@
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
)
|
||||
|
||||
// Runtime environment variable keys for the picoclaw process.
|
||||
// These control the location of files and binaries at runtime and are read
|
||||
// directly via os.Getenv / os.LookupEnv. All picoclaw-specific keys use the
|
||||
@@ -35,3 +42,16 @@ const (
|
||||
// Default: "127.0.0.1"
|
||||
EnvGatewayHost = "PICOCLAW_GATEWAY_HOST"
|
||||
)
|
||||
|
||||
func GetHome() string {
|
||||
homePath, _ := os.UserHomeDir()
|
||||
if picoclawHome := os.Getenv(EnvHome); picoclawHome != "" {
|
||||
homePath = picoclawHome
|
||||
} else if homePath != "" {
|
||||
homePath = filepath.Join(homePath, pkg.DefaultPicoClawHome)
|
||||
}
|
||||
if homePath == "" {
|
||||
homePath = "."
|
||||
}
|
||||
return homePath
|
||||
}
|
||||
|
||||
+37
-3
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/line"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/maixcam"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/onebot"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
|
||||
@@ -36,6 +37,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -61,6 +63,7 @@ type services struct {
|
||||
HealthServer *health.Server
|
||||
manualReloadChan chan struct{}
|
||||
reloading atomic.Bool
|
||||
authToken string
|
||||
}
|
||||
|
||||
type startupBlockedProvider struct {
|
||||
@@ -113,6 +116,13 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
|
||||
logger.Infof("Log level set to %q", cfg.Gateway.LogLevel)
|
||||
}
|
||||
|
||||
// Enforce singleton: write PID file with generated token.
|
||||
pidData, err := pid.WritePidFile(homePath, cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("singleton check failed: %w", err)
|
||||
}
|
||||
defer pid.RemovePidFile(homePath)
|
||||
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
@@ -139,7 +149,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) error
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus, pidData.Token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -238,6 +248,9 @@ func executeReload(
|
||||
debug bool,
|
||||
) error {
|
||||
defer runningServices.reloading.Store(false)
|
||||
|
||||
overridePicoToken(newCfg, runningServices.authToken)
|
||||
|
||||
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
|
||||
}
|
||||
|
||||
@@ -262,6 +275,7 @@ func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
authToken string,
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
@@ -304,6 +318,8 @@ func setupAndStartServices(
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
overridePicoToken(cfg, authToken)
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
@@ -328,7 +344,8 @@ func setupAndStartServices(
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.authToken = authToken
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port, authToken)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
@@ -547,6 +564,9 @@ func restartServices(
|
||||
logger.InfoCF("voice", "Transcription disabled", nil)
|
||||
}
|
||||
|
||||
// NOTE: PID file is written once at startup and not updated on reload.
|
||||
// Changing the gateway listen address requires a full restart.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -665,6 +685,20 @@ func setupCronTool(
|
||||
return cronService, nil
|
||||
}
|
||||
|
||||
// overridePicoToken replaces the pico channel token with the one from the PID file.
|
||||
// The PID file is the single source of truth for the pico auth token;
|
||||
// it is generated once at gateway startup and remains unchanged across reloads.
|
||||
func overridePicoToken(cfg *config.Config, token string) {
|
||||
if !cfg.Channels.Pico.Enabled {
|
||||
return
|
||||
}
|
||||
picoToken := cfg.Channels.Pico.Token.String()
|
||||
if picoToken == "" || strings.HasPrefix(picoToken, pico.PicoTokenPrefix) {
|
||||
return
|
||||
}
|
||||
cfg.Channels.Pico.SetToken(pico.PicoTokenPrefix + token + picoToken)
|
||||
}
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
|
||||
+32
-4
@@ -2,11 +2,11 @@ package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -18,6 +18,7 @@ type Server struct {
|
||||
checks map[string]Check
|
||||
startTime time.Time
|
||||
reloadFunc func() error
|
||||
authToken string // optional bearer token for protected endpoints
|
||||
}
|
||||
|
||||
type Check struct {
|
||||
@@ -31,15 +32,15 @@ type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
Pid int `json:"pid"`
|
||||
}
|
||||
|
||||
func NewServer(host string, port int) *Server {
|
||||
func NewServer(host string, port int, token string) *Server {
|
||||
mux := http.NewServeMux()
|
||||
s := &Server{
|
||||
ready: false,
|
||||
checks: make(map[string]Check),
|
||||
startTime: time.Now(),
|
||||
authToken: token,
|
||||
}
|
||||
|
||||
mux.HandleFunc("/health", s.healthHandler)
|
||||
@@ -123,6 +124,21 @@ func (s *Server) reloadHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Token check
|
||||
s.mu.RLock()
|
||||
requiredToken := s.authToken
|
||||
s.mu.RUnlock()
|
||||
|
||||
if requiredToken != "" {
|
||||
given := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if given == "" || subtle.ConstantTimeCompare([]byte(given), []byte(requiredToken)) != 1 {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]string{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
reloadFunc := s.reloadFunc
|
||||
s.mu.Unlock()
|
||||
@@ -154,7 +170,6 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
Pid: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
@@ -220,3 +235,16 @@ func statusString(ok bool) string {
|
||||
}
|
||||
return "fail"
|
||||
}
|
||||
|
||||
// extractBearerToken returns the token from an "Authorization: Bearer <t>" header,
|
||||
// or the empty string if the header is missing or malformed.
|
||||
func extractBearerToken(header string) string {
|
||||
const prefix = "Bearer "
|
||||
if len(header) < len(prefix) {
|
||||
return ""
|
||||
}
|
||||
if header[:len(prefix)] != prefix {
|
||||
return ""
|
||||
}
|
||||
return header[len(prefix):]
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ func newTestServer() *Server {
|
||||
ready: false,
|
||||
checks: make(map[string]Check),
|
||||
startTime: time.Now(),
|
||||
authToken: "test",
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -37,9 +38,6 @@ func TestHealthHandler_ReturnsOK(t *testing.T) {
|
||||
if resp.Status != "ok" {
|
||||
t.Errorf("status = %q, want %q", resp.Status, "ok")
|
||||
}
|
||||
if resp.Pid == 0 {
|
||||
t.Error("pid should not be 0")
|
||||
}
|
||||
if resp.Uptime == "" {
|
||||
t.Error("uptime should not be empty")
|
||||
}
|
||||
@@ -168,6 +166,7 @@ func TestReloadHandler_NoReloadFunc(t *testing.T) {
|
||||
s := newTestServer()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
req.Header.Set("Authorization", "Bearer test")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.reloadHandler(w, req)
|
||||
@@ -186,6 +185,7 @@ func TestReloadHandler_Success(t *testing.T) {
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
req.Header.Set("Authorization", "Bearer test")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.reloadHandler(w, req)
|
||||
@@ -205,6 +205,7 @@ func TestReloadHandler_Error(t *testing.T) {
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/reload", nil)
|
||||
req.Header.Set("Authorization", "Bearer test")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
s.reloadHandler(w, req)
|
||||
@@ -292,7 +293,7 @@ func TestRegisterOnMux(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
s := NewServer("127.0.0.1", 0)
|
||||
s := NewServer("127.0.0.1", 0, "")
|
||||
if s == nil {
|
||||
t.Fatal("NewServer returned nil")
|
||||
}
|
||||
@@ -305,7 +306,7 @@ func TestNewServer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStartContext_Cancellation(t *testing.T) {
|
||||
s := NewServer("127.0.0.1", 0)
|
||||
s := NewServer("127.0.0.1", 0, "")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
@@ -14,14 +12,7 @@ func ResolveTargetHome(override string) (string, error) {
|
||||
if override != "" {
|
||||
return ExpandHome(override), nil
|
||||
}
|
||||
if envHome := os.Getenv(config.EnvHome); envHome != "" {
|
||||
return ExpandHome(envHome), nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, pkg.DefaultPicoClawHome), nil
|
||||
return config.GetHome(), nil
|
||||
}
|
||||
|
||||
func ExpandHome(path string) string {
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
package pid
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const pidFileName = ".picoclaw.pid"
|
||||
|
||||
// PidFileData is the JSON structure stored in the PID file.
|
||||
type PidFileData struct {
|
||||
PID int `json:"pid"`
|
||||
Token string `json:"token"`
|
||||
Version string `json:"version"`
|
||||
Port int `json:"port"`
|
||||
Host string `json:"host"`
|
||||
}
|
||||
|
||||
var pidMu sync.Mutex
|
||||
|
||||
// pidFilePath returns the absolute path for the PID file given the home directory.
|
||||
func pidFilePath(homePath string) string {
|
||||
return filepath.Join(homePath, pidFileName)
|
||||
}
|
||||
|
||||
// generateToken creates a cryptographically random 32-character hex token.
|
||||
func generateToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to something pseudo-random if crypto/rand fails
|
||||
return fmt.Sprintf("%032x", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// WritePidFile creates (or overwrites) the PID file atomically.
|
||||
// It returns an error if another gateway instance appears to be running
|
||||
// (a valid PID file exists with a live process).
|
||||
func WritePidFile(homePath, host string, port int) (*PidFileData, error) {
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
|
||||
// Check for existing PID file → singleton enforcement.
|
||||
if data, err := readPidFileUnlocked(pidPath); err == nil {
|
||||
if os.Getpid() != data.PID {
|
||||
logger.Infof("found pid file (PID: %d, version: %s)", data.PID, data.Version)
|
||||
if isProcessRunning(data.PID) {
|
||||
return nil, fmt.Errorf("gateway is already running (PID: %d, version: %s)", data.PID, data.Version)
|
||||
}
|
||||
logger.Warnf("not running (PID: %d) so will remove the pid file: %s", data.PID, pidPath)
|
||||
}
|
||||
// Stale PID file; process no longer exists → clean up.
|
||||
os.Remove(pidPath)
|
||||
}
|
||||
|
||||
data := &PidFileData{
|
||||
PID: os.Getpid(),
|
||||
Version: config.GetVersion(),
|
||||
Port: port,
|
||||
Host: host,
|
||||
}
|
||||
|
||||
token := generateToken()
|
||||
data.Token = token
|
||||
|
||||
raw, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal pid file: %w", err)
|
||||
}
|
||||
|
||||
// Ensure parent directory exists.
|
||||
dir := filepath.Dir(pidPath)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create pid directory: %w", err)
|
||||
}
|
||||
|
||||
// Write atomically via temp file + rename.
|
||||
tmp := pidPath + ".tmp"
|
||||
if err := os.WriteFile(tmp, raw, 0o600); err != nil {
|
||||
return nil, fmt.Errorf("failed to write pid file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, pidPath); err != nil {
|
||||
os.Remove(tmp)
|
||||
return nil, fmt.Errorf("failed to rename pid file: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// ReadPidFileWithCheck reads the PID file and additionally checks if
|
||||
// the recorded process is still alive. Returns nil if the file is
|
||||
// missing, unreadable, or the process has exited.
|
||||
func ReadPidFileWithCheck(homePath string) *PidFileData {
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
data, err := readPidFileUnlocked(pidPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isProcessRunning(data.PID) {
|
||||
os.Remove(pidPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// RemovePidFile deletes the PID file (e.g. on graceful shutdown).
|
||||
func RemovePidFile(homePath string) {
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
// Only remove if the PID matches our own process (avoid deleting
|
||||
// a file that belongs to a newer gateway instance).
|
||||
if data, err := readPidFileUnlocked(pidPath); err == nil {
|
||||
if data.PID != os.Getpid() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof("remove pid file: %s", pidPath)
|
||||
os.Remove(pidPath)
|
||||
}
|
||||
|
||||
// readPidFileUnlocked reads the PID file without acquiring the lock.
|
||||
// Caller must hold pidMu.
|
||||
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
|
||||
raw, err := os.ReadFile(pidPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data PidFileData
|
||||
if err := json.Unmarshal(raw, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate PID is a positive integer.
|
||||
if data.PID <= 0 {
|
||||
return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package pid
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// tmpDir returns a clean temporary directory for a test.
|
||||
func tmpDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
dir, err := os.MkdirTemp("", "pidtest-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { os.RemoveAll(dir) })
|
||||
return dir
|
||||
}
|
||||
|
||||
// TestGenerateToken verifies that generateToken produces a 32-character hex string.
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
token := generateToken()
|
||||
if len(token) != 32 {
|
||||
t.Errorf("expected token length 32, got %d (token: %q)", len(token), token)
|
||||
}
|
||||
// Verify all characters are valid hex.
|
||||
for _, c := range token {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("token contains non-hex character: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGenerateTokenUniqueness checks that two consecutive tokens differ.
|
||||
func TestGenerateTokenUniqueness(t *testing.T) {
|
||||
a := generateToken()
|
||||
b := generateToken()
|
||||
if a == b {
|
||||
t.Error("two consecutive tokens should not be equal")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPidFilePath returns the expected path.
|
||||
func TestPidFilePath(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
got := pidFilePath(dir)
|
||||
want := filepath.Join(dir, pidFileName)
|
||||
if got != want {
|
||||
t.Errorf("pidFilePath(%q) = %q, want %q", dir, got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWritePidFile creates a PID file and verifies its contents.
|
||||
func TestWritePidFile(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
data, err := WritePidFile(dir, "127.0.0.1", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
if data.PID != os.Getpid() {
|
||||
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
|
||||
}
|
||||
if data.Host != "127.0.0.1" {
|
||||
t.Errorf("Host = %q, want %q", data.Host, "127.0.0.1")
|
||||
}
|
||||
if data.Port != 18790 {
|
||||
t.Errorf("Port = %d, want %d", data.Port, 18790)
|
||||
}
|
||||
if len(data.Token) != 32 {
|
||||
t.Errorf("Token length = %d, want 32", len(data.Token))
|
||||
}
|
||||
|
||||
// Verify the file exists and can be unmarshalled.
|
||||
raw, err := os.ReadFile(filepath.Join(dir, pidFileName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read pid file: %v", err)
|
||||
}
|
||||
|
||||
var fileData PidFileData
|
||||
if err = json.Unmarshal(raw, &fileData); err != nil {
|
||||
t.Fatalf("failed to unmarshal pid file: %v", err)
|
||||
}
|
||||
if fileData.PID != data.PID || fileData.Token != data.Token {
|
||||
t.Error("file data mismatch")
|
||||
}
|
||||
|
||||
// Verify file permissions (owner-only read/write).
|
||||
info, err := os.Stat(filepath.Join(dir, pidFileName))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to stat pid file: %v", err)
|
||||
}
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0o600 {
|
||||
t.Errorf("file permission = %o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWritePidFileOverwrite writes twice and verifies the PID file is replaced.
|
||||
func TestWritePidFileOverwrite(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
data1, err := WritePidFile(dir, "0.0.0.0", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("first WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
// Second write should succeed because the PID matches our process.
|
||||
data2, err := WritePidFile(dir, "0.0.0.0", 18800)
|
||||
if err != nil {
|
||||
t.Fatalf("second WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
if data2.Token == data1.Token {
|
||||
t.Error("token should change on re-write")
|
||||
}
|
||||
if data2.Port != 18800 {
|
||||
t.Errorf("Port = %d, want 18800", data2.Port)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWritePidFileStalePID writes a PID file with a non-running PID, then
|
||||
// verifies WritePidFile cleans it up and writes a new one.
|
||||
func TestWritePidFileStalePID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
// Write a PID file with a PID that almost certainly doesn't exist.
|
||||
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(stale, "", " ")
|
||||
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
|
||||
|
||||
data, err := WritePidFile(dir, "127.0.0.1", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile with stale PID failed: %v", err)
|
||||
}
|
||||
if data.PID != os.Getpid() {
|
||||
t.Errorf("PID = %d, want %d", data.PID, os.Getpid())
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheck verifies reading a valid PID file for the current process.
|
||||
func TestReadPidFileWithCheck(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
// Some sandboxed environments (e.g. macOS test runner) may restrict
|
||||
// signal(0), causing isProcessRunning(getpid()) to return false.
|
||||
if !isProcessRunning(os.Getpid()) {
|
||||
t.Skip("skipping: isProcessRunning(getpid()) is false in this environment")
|
||||
}
|
||||
|
||||
written, err := WritePidFile(dir, "127.0.0.1", 18790)
|
||||
if err != nil {
|
||||
t.Fatalf("WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
read := ReadPidFileWithCheck(dir)
|
||||
if read == nil {
|
||||
t.Fatal("ReadPidFileWithCheck returned nil for current process")
|
||||
}
|
||||
if read.PID != written.PID || read.Token != written.Token {
|
||||
t.Error("read data doesn't match written data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheckNonexistent returns nil for missing file.
|
||||
func TestReadPidFileWithCheckNonexistent(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
data := ReadPidFileWithCheck(dir)
|
||||
if data != nil {
|
||||
t.Error("expected nil for nonexistent PID file")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheckStalePID auto-cleans a PID file whose process is dead.
|
||||
func TestReadPidFileWithCheckStalePID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
stale := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(stale, "", " ")
|
||||
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
|
||||
|
||||
data := ReadPidFileWithCheck(dir)
|
||||
if data != nil {
|
||||
t.Error("expected nil for stale PID")
|
||||
}
|
||||
|
||||
// File should be cleaned up.
|
||||
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
|
||||
t.Error("stale PID file should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFile removes the PID file for the current process.
|
||||
func TestRemovePidFile(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
if _, err := WritePidFile(dir, "127.0.0.1", 18790); err != nil {
|
||||
t.Fatalf("WritePidFile failed: %v", err)
|
||||
}
|
||||
|
||||
RemovePidFile(dir)
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, pidFileName)); !os.IsNotExist(err) {
|
||||
t.Error("PID file should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFileDifferentPID does not remove a PID file owned by another process.
|
||||
func TestRemovePidFileDifferentPID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(other, "", " ")
|
||||
os.WriteFile(filepath.Join(dir, pidFileName), raw, 0o600)
|
||||
|
||||
RemovePidFile(dir)
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, pidFileName)); os.IsNotExist(err) {
|
||||
t.Error("PID file should NOT be removed (different PID)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFileNonexistent does not error on missing file.
|
||||
func TestRemovePidFileNonexistent(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
// Should not panic or error.
|
||||
RemovePidFile(dir)
|
||||
}
|
||||
|
||||
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
|
||||
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, []byte("not json"), 0o600)
|
||||
|
||||
_, err := readPidFileUnlocked(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileUnlockedInvalidPID returns error for non-positive PID.
|
||||
func TestReadPidFileUnlockedInvalidPID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, []byte(`{"pid": -1, "token": "a"}`), 0o600)
|
||||
|
||||
_, err := readPidFileUnlocked(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid PID")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//go:build !windows
|
||||
|
||||
package pid
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// isProcessRunning checks whether a process with the given PID is alive
|
||||
// on Unix-like systems using signal(0).
|
||||
func isProcessRunning(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
p, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// Signal(nil) does not kill the process but checks existence on Unix.
|
||||
return p.Signal(syscall.Signal(0)) == nil
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
//go:build windows
|
||||
|
||||
package pid
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procOpenProcess = kernel32.NewProc("OpenProcess")
|
||||
procGetExitCodeProcess = kernel32.NewProc("GetExitCodeProcess")
|
||||
procCloseHandle = kernel32.NewProc("CloseHandle")
|
||||
processQueryLimitedInformation = uint32(0x1000)
|
||||
stillActive = uint32(259)
|
||||
)
|
||||
|
||||
// isProcessRunning checks whether a process with the given PID is alive
|
||||
// on Windows using OpenProcess + GetExitCodeProcess.
|
||||
func isProcessRunning(pid int) bool {
|
||||
if pid <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
handle, _, err := procOpenProcess.Call(
|
||||
uintptr(processQueryLimitedInformation),
|
||||
0,
|
||||
uintptr(pid),
|
||||
)
|
||||
if handle == 0 || err != nil {
|
||||
return false
|
||||
}
|
||||
defer procCloseHandle.Call(handle)
|
||||
|
||||
var exitCode uint32
|
||||
ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
|
||||
if ret == 0 || err != nil {
|
||||
return false
|
||||
}
|
||||
return exitCode == stillActive
|
||||
}
|
||||
@@ -87,6 +87,9 @@ func (h *Handler) handleUpdateConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh cached pico token in case user changed it.
|
||||
refreshPicoToken(&cfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
@@ -182,6 +185,9 @@ func (h *Handler) handlePatchConfig(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh cached pico token in case user changed it.
|
||||
refreshPicoToken(&newCfg)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||
}
|
||||
|
||||
+188
-121
@@ -17,9 +17,11 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
@@ -33,11 +35,49 @@ var gateway = struct {
|
||||
runtimeStatus string
|
||||
startupDeadline time.Time
|
||||
logs *LogBuffer
|
||||
pidData *ppid.PidFileData // pid file data read from picoclaw.pid.json
|
||||
picoToken string // cached pico token from config (for proxy auth validation)
|
||||
}{
|
||||
runtimeStatus: "stopped",
|
||||
logs: NewLogBuffer(200),
|
||||
}
|
||||
|
||||
// refreshPicoToken updates gateway.picoToken from cfg
|
||||
func refreshPicoToken(cfg *config.Config) {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
gateway.picoToken = cfg.Channels.Pico.Token.String()
|
||||
}
|
||||
|
||||
// refreshPicoTokensLocked reads the pico token from config and caches it.
|
||||
// Caller must hold gateway.mu (or be sole writer).
|
||||
func refreshPicoTokensLocked(configPath string) {
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
gateway.picoToken = cfg.Channels.Pico.Token.String()
|
||||
}
|
||||
|
||||
const (
|
||||
protocolKey = "Sec-Websocket-Protocol"
|
||||
tokenPrefix = "token."
|
||||
)
|
||||
|
||||
// picoComposedToken returns "pico-"+pidToken+picoToken for gateway auth.
|
||||
func picoComposedToken(token string) string {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
// if not initial pico token, don't allow gateway auth
|
||||
if gateway.picoToken == "" || gateway.pidData == nil {
|
||||
return ""
|
||||
}
|
||||
if tokenPrefix+gateway.picoToken != token {
|
||||
return ""
|
||||
}
|
||||
return pico.PicoTokenPrefix + gateway.pidData.Token + gateway.picoToken
|
||||
}
|
||||
|
||||
var (
|
||||
gatewayStartupWindow = 15 * time.Second
|
||||
gatewayRestartGracePeriod = 5 * time.Second
|
||||
@@ -50,16 +90,29 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response.
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
port := 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
// Prefer port/host from pidData when available.
|
||||
var port int
|
||||
var host string
|
||||
gateway.mu.Lock()
|
||||
if d := gateway.pidData; d != nil && d.Port > 0 {
|
||||
port = d.Port
|
||||
host = d.Host
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
if port == 0 {
|
||||
port = 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
}
|
||||
if host == "" {
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
}
|
||||
|
||||
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
|
||||
url := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + "/health"
|
||||
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
@@ -92,30 +145,33 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
// TryAutoStartGateway checks whether gateway start preconditions are met and
|
||||
// starts it when possible. Intended to be called by the backend at startup.
|
||||
func (h *Handler) TryAutoStartGateway() {
|
||||
// Check if gateway is already running via health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err))
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason))
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
|
||||
}
|
||||
// Check PID file first to detect an already-running gateway.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
if pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err))
|
||||
gateway.mu.Unlock()
|
||||
return
|
||||
}
|
||||
logger.Infof("ready: %v, reason: %s", ready, reason)
|
||||
if !ready {
|
||||
logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason))
|
||||
gateway.mu.Unlock()
|
||||
return
|
||||
}
|
||||
pid := pidData.PID
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
|
||||
} else {
|
||||
gateway.pidData = pidData
|
||||
refreshPicoTokensLocked(h.configPath)
|
||||
logger.InfoC("gateway", fmt.Sprintf("Attached to running gateway via PID file (PID: %d)", pid))
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
@@ -400,6 +456,7 @@ func stopGatewayLocked() (int, error) {
|
||||
gateway.cmd = nil
|
||||
gateway.owned = false
|
||||
gateway.bootDefaultModel = ""
|
||||
gateway.pidData = nil
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
|
||||
return pid, nil
|
||||
@@ -452,6 +509,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
pid = existingPid
|
||||
gateway.cmd = nil // Clear first to ensure clean state
|
||||
if err = attachToGatewayProcessLocked(pid, cfg); err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to existing gateway (PID %d): %v", pid, err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -461,6 +519,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
// Start new process
|
||||
// Locate the picoclaw executable
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
logger.InfoC("gateway", fmt.Sprintf("Starting gateway process (%s)", execPath))
|
||||
|
||||
cmd = exec.Command(execPath, "gateway", "-E")
|
||||
cmd.Env = os.Environ()
|
||||
@@ -488,10 +547,16 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
if _, err := h.EnsurePicoChannel(""); err != nil {
|
||||
changed, err := h.EnsurePicoChannel("")
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
}
|
||||
// Refresh cached pico token in case EnsurePicoChannel generated a new one.
|
||||
// Already holding gateway.mu from caller.
|
||||
if changed {
|
||||
refreshPicoTokensLocked(h.configPath)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return 0, fmt.Errorf("failed to start gateway: %w", err)
|
||||
@@ -529,7 +594,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
gateway.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Start a goroutine to probe health and update the runtime state once ready.
|
||||
// Start a goroutine to probe pidFile and health, update runtime state once ready.
|
||||
go func() {
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -539,13 +604,26 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
if !stillOurs {
|
||||
return
|
||||
}
|
||||
|
||||
// Poll for pidFile first — once available we have port/host/token.
|
||||
if pd := ppid.ReadPidFileWithCheck(globalConfigDir()); pd != nil && pd.PID == pid {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
gateway.pidData = pd
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
logger.InfoC("gateway", fmt.Sprintf("Gateway pidFile detected (PID: %d, port: %d)", pd.PID, pd.Port))
|
||||
return
|
||||
}
|
||||
|
||||
// Fallback: probe health endpoint to confirm liveness.
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
|
||||
// Verify the health endpoint returns the expected pid
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
@@ -563,49 +641,47 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
//
|
||||
// POST /api/gateway/start
|
||||
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent duplicate starts by checking health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
gateway.mu.Unlock()
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
gateway.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
// Check PID file first to detect an already-running gateway.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
if pidData != nil {
|
||||
pid := pidData.PID
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
gateway.mu.Unlock()
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
gateway.mu.Unlock()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
|
||||
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
"status": "precondition_failed",
|
||||
"message": reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
if err != nil {
|
||||
gateway.mu.Unlock()
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
|
||||
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
gateway.pidData = pidData
|
||||
gateway.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
@@ -805,65 +881,56 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
// Probe health endpoint to get pid and status
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
// Primary detection: read PID file and check if process is alive.
|
||||
pidData := ppid.ReadPidFileWithCheck(globalConfigDir())
|
||||
if pidData != nil {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
gateway.pidData = pidData
|
||||
if pidData.Version != "" {
|
||||
data["gateway_version"] = pidData.Version
|
||||
}
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
// Attach if we don't already track this PID.
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != pidData.PID {
|
||||
_ = attachToGatewayProcessLocked(pidData.PID, cfg)
|
||||
}
|
||||
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = pidData.PID
|
||||
gateway.mu.Unlock()
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
|
||||
} else {
|
||||
if statusCode != http.StatusOK {
|
||||
logger.WarnC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
|
||||
// Fallback: probe health endpoint to get pid and status
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
gateway.pidData = nil
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = statusCode
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
|
||||
oldPid := "none"
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
|
||||
}
|
||||
logger.InfoC(
|
||||
"gateway",
|
||||
fmt.Sprintf(
|
||||
"Detected new gateway PID (old: %s, new: %d), attempting to attach",
|
||||
oldPid,
|
||||
healthResp.Pid,
|
||||
),
|
||||
)
|
||||
|
||||
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
|
||||
// Failed to find the process, treat as error
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
data["gateway_status"] = "error"
|
||||
data["pid"] = healthResp.Pid
|
||||
logger.ErrorC(
|
||||
"gateway",
|
||||
fmt.Sprintf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err),
|
||||
)
|
||||
} else {
|
||||
// Successfully attached, update response data
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
|
||||
if statusCode != http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
gateway.pidData = nil
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = statusCode
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -470,9 +470,6 @@ func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["pid"]; got != float64(cmd.Process.Pid) {
|
||||
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != false {
|
||||
t.Fatalf("gateway_restart_required = %#v, want false", got)
|
||||
}
|
||||
|
||||
+53
-8
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// registerPicoRoutes binds Pico Channel management endpoints to the ServeMux.
|
||||
@@ -26,20 +27,55 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
|
||||
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
|
||||
// The gateway bind host and port are resolved from the latest configuration.
|
||||
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
|
||||
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
|
||||
func (h *Handler) createWsProxy(origProtocol string, token string) *httputil.ReverseProxy {
|
||||
wsProxy := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
target := h.gatewayProxyURL()
|
||||
r.SetURL(target)
|
||||
r.Out.Header.Set(protocolKey, tokenPrefix+token)
|
||||
},
|
||||
ModifyResponse: func(r *http.Response) error {
|
||||
if prot := r.Header.Values(protocolKey); len(prot) > 0 {
|
||||
r.Header.Del(protocolKey)
|
||||
if origProtocol != "" {
|
||||
r.Header.Set(protocolKey, origProtocol)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
logger.Errorf("Failed to proxy WebSocket: %v", err)
|
||||
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
return wsProxy
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// The reverse proxy forwards the incoming upgrade handshake as-is.
|
||||
// It validates the client token before forwarding; rejects immediately on failure.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy := h.createWsProxy()
|
||||
proxy.ServeHTTP(w, r)
|
||||
gateway.mu.Lock()
|
||||
gatewayAvailable := gateway.pidData != nil
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if !gatewayAvailable {
|
||||
logger.Warnf("Gateway not available for WebSocket proxy")
|
||||
http.Error(w, "Gateway not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
prot := r.Header.Values(protocolKey)
|
||||
if len(prot) > 0 {
|
||||
origProtocol := prot[0]
|
||||
newToken := picoComposedToken(prot[0])
|
||||
if newToken != "" {
|
||||
h.createWsProxy(origProtocol, newToken).ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
logger.Warnf("Invalid Pico token: %v", prot)
|
||||
http.Error(w, "Invalid Pico token", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,6 +117,11 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh cached pico token.
|
||||
gateway.mu.Lock()
|
||||
gateway.picoToken = token
|
||||
gateway.mu.Unlock()
|
||||
|
||||
wsURL := h.buildWsURL(r)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
@@ -140,11 +181,15 @@ func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Reload config (EnsurePicoChannel may have modified it) and refresh cache.
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to load config: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if changed {
|
||||
refreshPicoToken(cfg)
|
||||
}
|
||||
|
||||
wsURL := h.buildWsURL(r)
|
||||
|
||||
@@ -162,7 +207,7 @@ func generateSecureToken() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Fallback to something pseudo-random if crypto/rand fails
|
||||
return fmt.Sprintf("pico_%x", time.Now().UnixNano())
|
||||
return fmt.Sprintf("%032x", time.Now().UnixNano())
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
ppid "github.com/sipeed/picoclaw/pkg/pid"
|
||||
)
|
||||
|
||||
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
||||
@@ -335,10 +336,22 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gateway.pidData = &ppid.PidFileData{}
|
||||
gateway.picoToken = "pico"
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
req1.Header.Set(protocolKey, tokenPrefix+"wrong_token")
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusForbidden {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusForbidden)
|
||||
}
|
||||
|
||||
req1 = httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
req1.Header.Set(protocolKey, tokenPrefix+"pico")
|
||||
rec1 = httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
||||
}
|
||||
@@ -352,6 +365,7 @@ func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
req2.Header.Set(protocolKey, tokenPrefix+"pico")
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler(rec2, req2)
|
||||
|
||||
|
||||
@@ -309,14 +309,7 @@ func loadSkillContent(path string) (string, error) {
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw")
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
func builtinSkillsDir() string {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
@@ -44,6 +46,15 @@ func (rr *responseRecorder) Unwrap() http.ResponseWriter {
|
||||
return rr.ResponseWriter
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker so that WebSocket upgrades work through
|
||||
// the middleware layer.
|
||||
func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := rr.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Logger logs each HTTP request with method, path, status code, and duration.
|
||||
func Logger(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -9,16 +9,13 @@ import (
|
||||
"runtime"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// GetPicoclawHome returns the picoclaw home directory.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
func GetPicoclawHome() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw")
|
||||
return config.GetHome()
|
||||
}
|
||||
|
||||
// GetDefaultConfigPath returns the default path to the picoclaw config file.
|
||||
@@ -47,6 +44,7 @@ func FindPicoclawBinary() string {
|
||||
}
|
||||
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
logger.Debugf("Trying to find picoclaw binary in %s", exe)
|
||||
candidate := filepath.Join(filepath.Dir(exe), binaryName)
|
||||
if info, err := os.Stat(candidate); err == nil && !info.IsDir() {
|
||||
return candidate
|
||||
|
||||
Reference in New Issue
Block a user