mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'main' into version
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
@@ -188,6 +189,27 @@ func validateConfig(cfg *config.Config) []string {
|
||||
errs = append(errs, "channels.discord.token is required when discord channel is enabled")
|
||||
}
|
||||
|
||||
if cfg.Tools.Exec.Enabled {
|
||||
if cfg.Tools.Exec.EnableDenyPatterns {
|
||||
errs = append(
|
||||
errs,
|
||||
validateRegexPatterns("tools.exec.custom_deny_patterns", cfg.Tools.Exec.CustomDenyPatterns)...)
|
||||
}
|
||||
errs = append(
|
||||
errs,
|
||||
validateRegexPatterns("tools.exec.custom_allow_patterns", cfg.Tools.Exec.CustomAllowPatterns)...)
|
||||
}
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
func validateRegexPatterns(field string, patterns []string) []string {
|
||||
var errs []string
|
||||
for index, pattern := range patterns {
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
errs = append(errs, fmt.Sprintf("%s[%d] is not a valid regular expression: %v", field, index, err))
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
|
||||
@@ -86,3 +86,82 @@ func TestHandleUpdateConfig_DoesNotInheritDefaultModelFields(t *testing.T) {
|
||||
t.Fatalf("model_list[0].api_base = %q, want empty string", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePatchConfig_RejectsInvalidExecRegexPatterns(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"custom_deny_patterns": ["("]
|
||||
}
|
||||
}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
if !bytes.Contains(rec.Body.Bytes(), []byte("custom_deny_patterns")) {
|
||||
t.Fatalf("expected validation error mentioning custom_deny_patterns, body=%s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePatchConfig_AllowsInvalidExecRegexPatternsWhenExecDisabled(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enabled": false,
|
||||
"custom_deny_patterns": ["("],
|
||||
"custom_allow_patterns": ["("]
|
||||
}
|
||||
}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePatchConfig_AllowsInvalidDenyRegexPatternsWhenDenyPatternsDisabled(t *testing.T) {
|
||||
configPath, cleanup := setupOAuthTestEnv(t)
|
||||
defer cleanup()
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/config", bytes.NewBufferString(`{
|
||||
"tools": {
|
||||
"exec": {
|
||||
"enabled": true,
|
||||
"enable_deny_patterns": false,
|
||||
"custom_deny_patterns": ["("]
|
||||
}
|
||||
}
|
||||
}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
mux.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// GatewayEvent represents a state change event for the gateway process.
|
||||
type GatewayEvent struct {
|
||||
Status string `json:"gateway_status"` // "running", "starting", "stopped", "error"
|
||||
PID int `json:"pid,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
|
||||
type EventBroadcaster struct {
|
||||
mu sync.RWMutex
|
||||
clients map[chan string]struct{}
|
||||
}
|
||||
|
||||
// NewEventBroadcaster creates a new broadcaster.
|
||||
func NewEventBroadcaster() *EventBroadcaster {
|
||||
return &EventBroadcaster{
|
||||
clients: make(map[chan string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new listener channel and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (b *EventBroadcaster) Subscribe() chan string {
|
||||
ch := make(chan string, 8)
|
||||
b.mu.Lock()
|
||||
b.clients[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a listener channel and closes it.
|
||||
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
|
||||
b.mu.Lock()
|
||||
delete(b.clients, ch)
|
||||
b.mu.Unlock()
|
||||
close(ch)
|
||||
}
|
||||
|
||||
// Broadcast sends a GatewayEvent to all connected SSE clients.
|
||||
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
for ch := range b.clients {
|
||||
// Non-blocking send; drop event if client is slow
|
||||
select {
|
||||
case ch <- string(data):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
+513
-203
@@ -3,9 +3,9 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -18,24 +18,70 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
// gateway holds the state for the managed gateway process.
|
||||
var gateway = struct {
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
logs *LogBuffer
|
||||
events *EventBroadcaster
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
owned bool // true if we started the process, false if we attached to an existing one
|
||||
bootDefaultModel string
|
||||
runtimeStatus string
|
||||
startupDeadline time.Time
|
||||
logs *LogBuffer
|
||||
}{
|
||||
logs: NewLogBuffer(200),
|
||||
events: NewEventBroadcaster(),
|
||||
runtimeStatus: "stopped",
|
||||
logs: NewLogBuffer(200),
|
||||
}
|
||||
|
||||
var (
|
||||
gatewayStartupWindow = 15 * time.Second
|
||||
gatewayRestartGracePeriod = 5 * time.Second
|
||||
gatewayRestartForceKillWindow = 3 * time.Second
|
||||
gatewayRestartPollInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
client := http.Client{Timeout: timeout}
|
||||
return client.Get(url)
|
||||
}
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
port := 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
|
||||
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
|
||||
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
|
||||
func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
resp, err := gatewayHealthGet(url, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var healthResponse health.StatusResponse
|
||||
if decErr := json.NewDecoder(resp.Body).Decode(&healthResponse); decErr != nil {
|
||||
return nil, resp.StatusCode, decErr
|
||||
}
|
||||
|
||||
return &healthResponse, resp.StatusCode, nil
|
||||
}
|
||||
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
|
||||
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
|
||||
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
|
||||
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
|
||||
mux.HandleFunc("POST /api/gateway/stop", h.handleGatewayStop)
|
||||
@@ -45,32 +91,55 @@ func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
// TryAutoStartGateway checks whether gateway start preconditions are met and
|
||||
// starts it when possible. Intended to be called by the backend at startup.
|
||||
func (h *Handler) TryAutoStartGateway() {
|
||||
// Check if gateway is already running via health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err))
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason))
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
if isGatewayProcessAliveLocked() {
|
||||
return
|
||||
}
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
gateway.cmd = nil
|
||||
}
|
||||
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
log.Printf("Skip auto-starting gateway: %v", err)
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Skip auto-starting gateway: %v", err))
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
log.Printf("Skip auto-starting gateway: %s", reason)
|
||||
logger.InfoC("gateway", fmt.Sprintf("Skip auto-starting gateway: %s", reason))
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked()
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
log.Printf("Failed to auto-start gateway: %v", err)
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to auto-start gateway: %v", err))
|
||||
return
|
||||
}
|
||||
log.Printf("Gateway auto-started (PID: %d)", pid)
|
||||
logger.InfoC("gateway", fmt.Sprintf("Gateway auto-started (PID: %d)", pid))
|
||||
}
|
||||
|
||||
// gatewayStartReady validates whether current config can start the gateway.
|
||||
@@ -108,8 +177,14 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
|
||||
return modelCfg
|
||||
}
|
||||
|
||||
func isGatewayProcessAliveLocked() bool {
|
||||
return isCmdProcessAliveLocked(gateway.cmd)
|
||||
func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
|
||||
if gatewayStatus != "running" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
|
||||
return false
|
||||
}
|
||||
return configDefaultModel != bootDefaultModel
|
||||
}
|
||||
|
||||
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
@@ -131,20 +206,191 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
return cmd.Process.Signal(syscall.Signal(0)) == nil
|
||||
}
|
||||
|
||||
func (h *Handler) startGatewayLocked() (int, error) {
|
||||
func setGatewayRuntimeStatusLocked(status string) {
|
||||
gateway.runtimeStatus = status
|
||||
if status == "starting" || status == "restarting" {
|
||||
gateway.startupDeadline = time.Now().Add(gatewayStartupWindow)
|
||||
return
|
||||
}
|
||||
gateway.startupDeadline = time.Time{}
|
||||
}
|
||||
|
||||
// attachToGatewayProcess attaches to an existing gateway process by PID
|
||||
// and updates the gateway state accordingly.
|
||||
// Assumes gateway.mu is held by the caller.
|
||||
func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
|
||||
process, err := os.FindProcess(pid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find process for PID %d: %w", pid, err)
|
||||
}
|
||||
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.owned = false // We didn't start this process
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
|
||||
// Update bootDefaultModel from config
|
||||
if cfg != nil {
|
||||
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
gateway.bootDefaultModel = defaultModelName
|
||||
}
|
||||
|
||||
logger.InfoC("gateway", fmt.Sprintf("Attached to gateway process (PID: %d)", pid))
|
||||
return nil
|
||||
}
|
||||
|
||||
func gatewayStatusWithoutHealthLocked() string {
|
||||
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return gateway.runtimeStatus
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
if gateway.runtimeStatus == "running" {
|
||||
return "running"
|
||||
}
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "stopped"
|
||||
}
|
||||
|
||||
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if !isCmdProcessAliveLocked(cmd) {
|
||||
return true
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return false
|
||||
}
|
||||
time.Sleep(gatewayRestartPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// StopGateway stops the gateway process if it was started by this handler.
|
||||
// This method is called during application shutdown to ensure the gateway subprocess
|
||||
// is properly terminated. It only stops processes that were started by this handler,
|
||||
// not processes that were attached to from existing instances.
|
||||
func (h *Handler) StopGateway() {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
// Only stop if we own the process (started it ourselves)
|
||||
if !gateway.owned || gateway.cmd == nil || gateway.cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := stopGatewayLocked()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to stop gateway (PID %d): %v", pid, err))
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoC("gateway", fmt.Sprintf("Gateway stopped (PID: %d)", pid))
|
||||
}
|
||||
|
||||
// stopGatewayLocked sends a stop signal to the gateway process.
|
||||
// Assumes gateway.mu is held by the caller.
|
||||
// Returns the PID of the stopped process and any error encountered.
|
||||
func stopGatewayLocked() (int, error) {
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
pid := gateway.cmd.Process.Pid
|
||||
|
||||
// Send SIGTERM for graceful shutdown (SIGKILL on Windows)
|
||||
var sigErr error
|
||||
if runtime.GOOS == "windows" {
|
||||
sigErr = gateway.cmd.Process.Kill()
|
||||
} else {
|
||||
sigErr = gateway.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
if sigErr != nil {
|
||||
return pid, sigErr
|
||||
}
|
||||
|
||||
logger.InfoC("gateway", fmt.Sprintf("Sent stop signal to gateway (PID: %d)", pid))
|
||||
gateway.cmd = nil
|
||||
gateway.owned = false
|
||||
gateway.bootDefaultModel = ""
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
func stopGatewayProcessForRestart(cmd *exec.Cmd) error {
|
||||
if cmd == nil || cmd.Process == nil || !isCmdProcessAliveLocked(cmd) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var stopErr error
|
||||
if runtime.GOOS == "windows" {
|
||||
stopErr = cmd.Process.Kill()
|
||||
} else {
|
||||
stopErr = cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
if stopErr != nil && isCmdProcessAliveLocked(cmd) {
|
||||
return fmt.Errorf("failed to stop existing gateway: %w", stopErr)
|
||||
}
|
||||
|
||||
if waitForGatewayProcessExit(cmd, gatewayRestartGracePeriod) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if runtime.GOOS != "windows" {
|
||||
killErr := cmd.Process.Signal(syscall.SIGKILL)
|
||||
if killErr != nil && isCmdProcessAliveLocked(cmd) {
|
||||
return fmt.Errorf("failed to force-stop existing gateway: %w", killErr)
|
||||
}
|
||||
if waitForGatewayProcessExit(cmd, gatewayRestartForceKillWindow) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("existing gateway did not exit before restart")
|
||||
}
|
||||
|
||||
func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
|
||||
var cmd *exec.Cmd
|
||||
var pid int
|
||||
|
||||
if existingPid > 0 {
|
||||
// Attach to existing process
|
||||
pid = existingPid
|
||||
gateway.cmd = nil // Clear first to ensure clean state
|
||||
if err = attachToGatewayProcessLocked(pid, cfg); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// Start new process
|
||||
// Locate the picoclaw executable
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
|
||||
cmd := exec.Command(execPath, "gateway")
|
||||
cmd = exec.Command(execPath, "gateway", "-E")
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
// config file without requiring a --config flag on the gateway subcommand.
|
||||
if h.configPath != "" {
|
||||
cmd.Env = append(cmd.Env, "PICOCLAW_CONFIG="+h.configPath)
|
||||
cmd.Env = append(cmd.Env, config.EnvConfig+"="+h.configPath)
|
||||
}
|
||||
if host := h.gatewayHostOverride(); host != "" {
|
||||
cmd.Env = append(cmd.Env, "PICOCLAW_GATEWAY_HOST="+host)
|
||||
cmd.Env = append(cmd.Env, config.EnvGatewayHost+"="+host)
|
||||
}
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
@@ -161,8 +407,8 @@ func (h *Handler) startGatewayLocked() (int, error) {
|
||||
gateway.logs.Reset()
|
||||
|
||||
// Ensure Pico Channel is configured before starting gateway
|
||||
if _, err := h.ensurePicoChannel(); err != nil {
|
||||
log.Printf("Warning: failed to ensure pico channel: %v", err)
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Warning: failed to ensure pico channel: %v", err))
|
||||
// Non-fatal: gateway can still start without pico channel
|
||||
}
|
||||
|
||||
@@ -171,11 +417,11 @@ func (h *Handler) startGatewayLocked() (int, error) {
|
||||
}
|
||||
|
||||
gateway.cmd = cmd
|
||||
pid := cmd.Process.Pid
|
||||
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
|
||||
|
||||
// Broadcast starting event
|
||||
gateway.events.Broadcast(GatewayEvent{Status: "starting", PID: pid})
|
||||
gateway.owned = true // We started this process
|
||||
gateway.bootDefaultModel = defaultModelName
|
||||
setGatewayRuntimeStatusLocked(initialStatus)
|
||||
pid = cmd.Process.Pid
|
||||
logger.InfoC("gateway", fmt.Sprintf("Started picoclaw gateway (PID: %d) from %s", pid, execPath))
|
||||
|
||||
// Capture stdout/stderr in background
|
||||
go scanPipe(stdoutPipe, gateway.logs)
|
||||
@@ -184,22 +430,23 @@ func (h *Handler) startGatewayLocked() (int, error) {
|
||||
// Wait for exit in background and clean up
|
||||
go func() {
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Printf("Gateway process exited: %v", err)
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Gateway process exited: %v", err))
|
||||
} else {
|
||||
log.Printf("Gateway process exited normally")
|
||||
logger.InfoC("gateway", "Gateway process exited normally")
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
if gateway.runtimeStatus != "restarting" {
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
// Broadcast stopped event
|
||||
gateway.events.Broadcast(GatewayEvent{Status: "stopped"})
|
||||
}()
|
||||
|
||||
// Start a goroutine to probe health and broadcast "running" once ready
|
||||
// Start a goroutine to probe health and update the runtime state once ready.
|
||||
go func() {
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -213,20 +460,15 @@ func (h *Handler) startGatewayLocked() (int, error) {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
healthPort := cfg.Gateway.Port
|
||||
if healthPort == 0 {
|
||||
healthPort = 18790
|
||||
}
|
||||
healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort)))
|
||||
client := http.Client{Timeout: 1 * time.Second}
|
||||
resp, err := client.Get(healthURL)
|
||||
if err == nil {
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
gateway.events.Broadcast(GatewayEvent{Status: "running", PID: pid})
|
||||
return
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
|
||||
// Verify the health endpoint returns the expected pid
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -238,21 +480,57 @@ func (h *Handler) startGatewayLocked() (int, error) {
|
||||
//
|
||||
// POST /api/gateway/start
|
||||
func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent duplicate starts by checking health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
gateway.mu.Lock()
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
gateway.mu.Unlock()
|
||||
http.Error(
|
||||
w,
|
||||
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
return
|
||||
}
|
||||
if !ready {
|
||||
gateway.mu.Unlock()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
_, err = h.startGatewayLocked("starting", pid)
|
||||
gateway.mu.Unlock()
|
||||
if err != nil {
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Failed to attach to running gateway (PID: %d): %v", pid, err))
|
||||
http.Error(w, fmt.Sprintf("Failed to attach to gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
// Prevent duplicate starts
|
||||
if isGatewayProcessAliveLocked() {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusConflict)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "already_running",
|
||||
"pid": gateway.cmd.Process.Pid,
|
||||
})
|
||||
return
|
||||
}
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
gateway.cmd = nil
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
}
|
||||
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
@@ -274,7 +552,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pid, err := h.startGatewayLocked()
|
||||
pid, err := h.startGatewayLocked("starting", 0)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
@@ -288,6 +566,8 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handleGatewayStop stops the running gateway subprocess gracefully.
|
||||
// Note: Unlike StopGateway (which only stops self-started processes), this API endpoint
|
||||
// stops any gateway process, including attached ones. This is intentional for user control.
|
||||
//
|
||||
// POST /api/gateway/stop
|
||||
func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -302,23 +582,12 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pid := gateway.cmd.Process.Pid
|
||||
|
||||
// Send SIGTERM for graceful shutdown (SIGKILL on Windows)
|
||||
var sigErr error
|
||||
if runtime.GOOS == "windows" {
|
||||
sigErr = gateway.cmd.Process.Kill()
|
||||
} else {
|
||||
sigErr = gateway.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
if sigErr != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to stop gateway (PID %d): %v", pid, sigErr), http.StatusInternalServerError)
|
||||
pid, err := stopGatewayLocked()
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to stop gateway (PID %d): %v", pid, err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Sent stop signal to gateway (PID: %d)", pid)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
@@ -326,34 +595,97 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// RestartGateway restarts the gateway process. This is a non-blocking operation
|
||||
// that stops the current gateway (if running) and starts a new one.
|
||||
// Returns the PID of the new gateway process or an error.
|
||||
func (h *Handler) RestartGateway() (int, error) {
|
||||
ready, reason, err := h.gatewayStartReady()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to validate gateway start conditions: %w", err)
|
||||
}
|
||||
if !ready {
|
||||
return 0, &preconditionFailedError{reason: reason}
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
previousCmd := gateway.cmd
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == previousCmd {
|
||||
if isCmdProcessAliveLocked(previousCmd) {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
} else {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
return 0, fmt.Errorf("failed to stop gateway: %w", err)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == previousCmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
}
|
||||
pid, err := h.startGatewayLocked("restarting", 0)
|
||||
if err != nil {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to start gateway: %w", err)
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
// preconditionFailedError is returned when gateway restart preconditions are not met
|
||||
type preconditionFailedError struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func (e *preconditionFailedError) Error() string {
|
||||
return e.reason
|
||||
}
|
||||
|
||||
// IsBadRequest returns true if the error should result in a 400 Bad Request status
|
||||
func (e *preconditionFailedError) IsBadRequest() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// handleGatewayRestart stops the gateway (if running) and starts a new instance.
|
||||
//
|
||||
// POST /api/gateway/restart
|
||||
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
|
||||
gateway.mu.Lock()
|
||||
|
||||
// Stop existing process if running
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
if isCmdProcessAliveLocked(gateway.cmd) {
|
||||
// Process is alive, send SIGTERM
|
||||
if runtime.GOOS == "windows" {
|
||||
gateway.cmd.Process.Kill()
|
||||
} else {
|
||||
gateway.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
|
||||
// Wait briefly for it to exit
|
||||
gateway.mu.Unlock()
|
||||
time.Sleep(2 * time.Second)
|
||||
gateway.mu.Lock()
|
||||
pid, err := h.RestartGateway()
|
||||
if err != nil {
|
||||
// Check if it's a precondition failed error
|
||||
var precondErr *preconditionFailedError
|
||||
if errors.As(err, &precondErr) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "precondition_failed",
|
||||
"message": precondErr.reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
gateway.cmd = nil
|
||||
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
gateway.mu.Unlock()
|
||||
|
||||
// Start fresh via the existing handler
|
||||
h.handleGatewayStart(w, r)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"status": "ok",
|
||||
"pid": pid,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGatewayClearLogs clears the in-memory gateway log buffer.
|
||||
@@ -370,59 +702,96 @@ func (h *Handler) handleGatewayClearLogs(w http.ResponseWriter, r *http.Request)
|
||||
})
|
||||
}
|
||||
|
||||
// handleGatewayStatus returns the gateway run status, health info, and logs.
|
||||
// handleGatewayStatus returns the gateway run status and health info.
|
||||
//
|
||||
// GET /api/gateway/status
|
||||
func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
data := h.gatewayStatusData()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data := map[string]any{}
|
||||
|
||||
// Check process state
|
||||
gateway.mu.Lock()
|
||||
processAlive := isGatewayProcessAliveLocked()
|
||||
if processAlive {
|
||||
data["pid"] = gateway.cmd.Process.Pid
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if !processAlive {
|
||||
data["gateway_status"] = "stopped"
|
||||
} else {
|
||||
// Process is alive — probe its health endpoint
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
host := "127.0.0.1"
|
||||
port := 18790
|
||||
if err == nil && cfg != nil {
|
||||
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
configDefaultModel := ""
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
if configDefaultModel != "" {
|
||||
data["config_default_model"] = configDefaultModel
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||
client := http.Client{Timeout: 2 * time.Second}
|
||||
resp, err := client.Get(url)
|
||||
|
||||
if err != nil {
|
||||
data["gateway_status"] = "starting"
|
||||
// Probe health endpoint to get pid and status
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
gateway.mu.Unlock()
|
||||
logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err))
|
||||
} else {
|
||||
logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode))
|
||||
if statusCode != http.StatusOK {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
gateway.mu.Unlock()
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = statusCode
|
||||
} else {
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
data["gateway_status"] = "error"
|
||||
data["status_code"] = resp.StatusCode
|
||||
} else {
|
||||
var healthData map[string]any
|
||||
if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil {
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
|
||||
oldPid := "none"
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
|
||||
}
|
||||
logger.InfoC(
|
||||
"gateway",
|
||||
fmt.Sprintf(
|
||||
"Detected new gateway PID (old: %s, new: %d), attempting to attach",
|
||||
oldPid,
|
||||
healthResp.Pid,
|
||||
),
|
||||
)
|
||||
|
||||
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
|
||||
// Failed to find the process, treat as error
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
data["gateway_status"] = "error"
|
||||
data["pid"] = healthResp.Pid
|
||||
logger.ErrorC(
|
||||
"gateway",
|
||||
fmt.Sprintf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err),
|
||||
)
|
||||
} else {
|
||||
for k, v := range healthData {
|
||||
data[k] = v
|
||||
// Successfully attached, update response data
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
}
|
||||
}
|
||||
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
bootDefaultModel, _ := data["boot_default_model"].(string)
|
||||
gatewayStatus, _ := data["gateway_status"].(string)
|
||||
data["gateway_restart_required"] = gatewayRestartRequired(
|
||||
configDefaultModel,
|
||||
bootDefaultModel,
|
||||
gatewayStatus,
|
||||
)
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
if readyErr != nil {
|
||||
data["gateway_start_allowed"] = false
|
||||
@@ -434,16 +803,22 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Append incremental log data
|
||||
appendGatewayLogs(r, data)
|
||||
return data
|
||||
}
|
||||
|
||||
// handleGatewayLogs returns buffered gateway logs, optionally incrementally.
|
||||
//
|
||||
// GET /api/gateway/logs
|
||||
func (h *Handler) handleGatewayLogs(w http.ResponseWriter, r *http.Request) {
|
||||
data := gatewayLogsData(r)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// appendGatewayLogs reads log_offset and log_run_id query params from the request
|
||||
// and populates the response data map with incremental log lines.
|
||||
func appendGatewayLogs(r *http.Request, data map[string]any) {
|
||||
// gatewayLogsData reads log_offset and log_run_id query params from the request
|
||||
// and returns incremental log lines.
|
||||
func gatewayLogsData(r *http.Request) map[string]any {
|
||||
data := map[string]any{}
|
||||
clientOffset := 0
|
||||
clientRunID := -1
|
||||
|
||||
@@ -465,7 +840,7 @@ func appendGatewayLogs(r *http.Request, data map[string]any) {
|
||||
data["logs"] = []string{}
|
||||
data["log_total"] = 0
|
||||
data["log_run_id"] = 0
|
||||
return
|
||||
return data
|
||||
}
|
||||
|
||||
// If runID changed, reset offset to get all logs from new run
|
||||
@@ -482,72 +857,7 @@ func appendGatewayLogs(r *http.Request, data map[string]any) {
|
||||
data["logs"] = lines
|
||||
data["log_total"] = total
|
||||
data["log_run_id"] = runID
|
||||
}
|
||||
|
||||
// handleGatewayEvents serves an SSE stream of gateway state change events.
|
||||
//
|
||||
// GET /api/gateway/events
|
||||
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "SSE not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Subscribe to gateway events
|
||||
ch := gateway.events.Subscribe()
|
||||
defer gateway.events.Unsubscribe(ch)
|
||||
|
||||
// Send initial status so the client doesn't start blank
|
||||
initial := h.currentGatewayStatus()
|
||||
fmt.Fprintf(w, "data: %s\n\n", initial)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// currentGatewayStatus returns the current gateway status as a JSON string.
|
||||
func (h *Handler) currentGatewayStatus() string {
|
||||
gateway.mu.Lock()
|
||||
defer gateway.mu.Unlock()
|
||||
|
||||
data := map[string]any{
|
||||
"gateway_status": "stopped",
|
||||
}
|
||||
if isGatewayProcessAliveLocked() {
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = gateway.cmd.Process.Pid
|
||||
}
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
if readyErr != nil {
|
||||
data["gateway_start_allowed"] = false
|
||||
data["gateway_start_reason"] = readyErr.Error()
|
||||
} else {
|
||||
data["gateway_start_allowed"] = ready
|
||||
if !ready {
|
||||
data["gateway_start_reason"] = reason
|
||||
}
|
||||
}
|
||||
|
||||
encoded, _ := json.Marshal(data)
|
||||
return string(encoded)
|
||||
return data
|
||||
}
|
||||
|
||||
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
|
||||
return bindHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayProxyURL() *url.URL {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
port := 18790
|
||||
bindHost := ""
|
||||
if err == nil && cfg != nil {
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
bindHost = h.effectiveGatewayBindHost(cfg)
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
|
||||
}
|
||||
}
|
||||
|
||||
func requestHostName(r *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err == nil {
|
||||
@@ -57,10 +75,34 @@ func requestHostName(r *http.Request) string {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
func requestWSScheme(r *http.Request) string {
|
||||
if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-Proto")); forwarded != "" {
|
||||
proto := strings.ToLower(strings.TrimSpace(strings.Split(forwarded, ",")[0]))
|
||||
if proto == "https" || proto == "wss" {
|
||||
return "wss"
|
||||
}
|
||||
if proto == "http" || proto == "ws" {
|
||||
return "ws"
|
||||
}
|
||||
}
|
||||
|
||||
if r.TLS != nil {
|
||||
return "wss"
|
||||
}
|
||||
|
||||
return "ws"
|
||||
}
|
||||
|
||||
func (h *Handler) buildWsURL(r *http.Request, cfg *config.Config) string {
|
||||
host := h.effectiveGatewayBindHost(cfg)
|
||||
if host == "" || host == "0.0.0.0" {
|
||||
host = requestHostName(r)
|
||||
}
|
||||
return "ws://" + net.JoinHostPort(host, strconv.Itoa(cfg.Gateway.Port)) + "/pico/ws"
|
||||
// Use web server port instead of gateway port to avoid exposing extra ports
|
||||
// The WebSocket connection will be proxied by the backend to the gateway
|
||||
wsPort := h.serverPort
|
||||
if wsPort == 0 {
|
||||
wsPort = 18800 // default web server port
|
||||
}
|
||||
return requestWSScheme(r) + "://" + net.JoinHostPort(host, strconv.Itoa(wsPort)) + "/pico/ws"
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
@@ -47,8 +51,8 @@ func TestBuildWsURLUsesRequestHostWhenLauncherPublicSaved(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req.Host = "192.168.1.9:18800"
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18790/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18790/pico/ws")
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://192.168.1.9:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://192.168.1.9:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,3 +61,128 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
t.Fatalf("gatewayProbeHost() = %q, want %q", got, "127.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
|
||||
t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://192.168.1.10:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://127.0.0.1:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "http://launcher.local/api/pico/token", nil)
|
||||
req.Host = "chat.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "wss://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://chat.example.com:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesWSSWhenRequestIsTLS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "wss://secure.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "wss://secure.example.com:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLPrefersForwardedHTTPOverTLS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "0.0.0.0"
|
||||
cfg.Gateway.Port = 18790
|
||||
|
||||
req := httptest.NewRequest("GET", "https://launcher.local/api/pico/token", nil)
|
||||
req.Host = "chat.example.com"
|
||||
req.TLS = &tls.ConnectionState{}
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
if got := h.buildWsURL(req, cfg); got != "ws://chat.example.com:18800/pico/ws" {
|
||||
t.Fatalf("buildWsURL() = %q, want %q", got, "ws://chat.example.com:18800/pico/ws")
|
||||
}
|
||||
}
|
||||
|
||||
+554
-13
@@ -2,19 +2,86 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
func startLongRunningProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("powershell", "-NoProfile", "-Command", "Start-Sleep -Seconds 30")
|
||||
} else {
|
||||
cmd = exec.Command("sleep", "30")
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("TERM handling differs on Windows")
|
||||
}
|
||||
|
||||
cmd := exec.Command("sh", "-c", "trap '' TERM; sleep 30")
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("Start() error = %v", err)
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func resetGatewayTestState(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
originalRestartGracePeriod := gatewayRestartGracePeriod
|
||||
originalRestartForceKillWindow := gatewayRestartForceKillWindow
|
||||
originalRestartPollInterval := gatewayRestartPollInterval
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
gatewayRestartGracePeriod = originalRestartGracePeriod
|
||||
gatewayRestartForceKillWindow = originalRestartForceKillWindow
|
||||
gatewayRestartPollInterval = originalRestartPollInterval
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
@@ -247,7 +314,7 @@ func TestGatewayStartReady_OAuthModelRequiresStoredCredential(t *testing.T) {
|
||||
}
|
||||
cfg.ModelList = []config.ModelConfig{{
|
||||
ModelName: "openai-oauth",
|
||||
Model: "openai/gpt-5.2",
|
||||
Model: "openai/gpt-5.4",
|
||||
AuthMethod: "oauth",
|
||||
}}
|
||||
cfg.Agents.Defaults.ModelName = "openai-oauth"
|
||||
@@ -317,6 +384,477 @@ func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = "existing-model"
|
||||
// Simulate a process that has already reached the running state.
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["pid"]; got != float64(cmd.Process.Pid) {
|
||||
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != false {
|
||||
t.Fatalf("gateway_restart_required = %#v, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
|
||||
ModelName: "second-model",
|
||||
Model: "openai/gpt-4.1",
|
||||
APIKey: "second-key",
|
||||
})
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
updatedCfg.Agents.Defaults.ModelName = "second-model"
|
||||
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
|
||||
t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
|
||||
}
|
||||
if got := body["config_default_model"]; got != "second-model" {
|
||||
t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = "existing-model"
|
||||
setGatewayRuntimeStatusLocked("starting")
|
||||
gateway.startupDeadline = time.Now().Add(-time.Second)
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "error" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReturnsRestartingDuringRestartGap(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "restarting" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "restarting")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayRestartKeepsRunningProcessWhenPreconditionsFail(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = ""
|
||||
cfg.ModelList[0].AuthMethod = ""
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = "existing-model"
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd)
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if !stillRunning {
|
||||
t.Fatalf("gateway process was stopped when restart preconditions failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayRestartKeepsOldProcessWhenItDoesNotExitInTime(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startIgnoringTermProcess(t)
|
||||
t.Cleanup(func() {
|
||||
gateway.mu.Lock()
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gatewayRestartGracePeriod = 150 * time.Millisecond
|
||||
gatewayRestartForceKillWindow = 150 * time.Millisecond
|
||||
gatewayRestartPollInterval = 10 * time.Millisecond
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = cmd
|
||||
gateway.bootDefaultModel = "existing-model"
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd)
|
||||
status := gateway.runtimeStatus
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if !stillRunning {
|
||||
t.Fatalf("gateway process was replaced before the old process exited")
|
||||
}
|
||||
if status != "running" {
|
||||
t.Fatalf("runtimeStatus = %q, want %q", status, "running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayRestartReturnsErrorStatusWhenReplacementFailsToStart(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
invalidBinaryPath := filepath.Join(t.TempDir(), "fake-picoclaw")
|
||||
if err := os.WriteFile(invalidBinaryPath, []byte("#!/bin/sh\n"), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
t.Setenv("PICOCLAW_BINARY", invalidBinaryPath)
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("restart status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
statusRec := httptest.NewRecorder()
|
||||
statusReq := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(statusRec, statusReq)
|
||||
|
||||
if statusRec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", statusRec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(statusRec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "error" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusExcludesLogsFields(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := body["logs"]; ok {
|
||||
t.Fatalf("logs unexpectedly present in status response: %#v", body["logs"])
|
||||
}
|
||||
if _, ok := body["log_total"]; ok {
|
||||
t.Fatalf("log_total unexpectedly present in status response: %#v", body["log_total"])
|
||||
}
|
||||
if _, ok := body["log_run_id"]; ok {
|
||||
t.Fatalf("log_run_id unexpectedly present in status response: %#v", body["log_run_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayLogsReturnsIncrementalHistory(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
gateway.logs.Clear()
|
||||
gateway.logs.Append("first line")
|
||||
gateway.logs.Append("second line")
|
||||
runID := gateway.logs.RunID()
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/gateway/logs?log_offset=1&log_run_id="+strconv.Itoa(runID),
|
||||
nil,
|
||||
)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("logs status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal logs response: %v", err)
|
||||
}
|
||||
|
||||
logs, ok := body["logs"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("logs missing or not array: %#v", body["logs"])
|
||||
}
|
||||
if len(logs) != 1 || logs[0] != "second line" {
|
||||
t.Fatalf("logs = %#v, want [\"second line\"]", logs)
|
||||
}
|
||||
if got := body["log_total"]; got != float64(2) {
|
||||
t.Fatalf("log_total = %#v, want 2", got)
|
||||
}
|
||||
if got := body["log_run_id"]; got != float64(runID) {
|
||||
t.Fatalf("log_run_id = %#v, want %d", got, runID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayClearLogsResetsBufferedHistory(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
@@ -353,33 +891,36 @@ func TestGatewayClearLogsResetsBufferedHistory(t *testing.T) {
|
||||
t.Fatalf("log_run_id = %d, want > %d", int(clearRunID), previousRunID)
|
||||
}
|
||||
|
||||
statusRec := httptest.NewRecorder()
|
||||
statusReq := httptest.NewRequest(
|
||||
logsRec := httptest.NewRecorder()
|
||||
logsReq := httptest.NewRequest(
|
||||
http.MethodGet,
|
||||
"/api/gateway/status?log_offset=0&log_run_id="+strconv.Itoa(previousRunID),
|
||||
"/api/gateway/logs?log_offset=0&log_run_id="+strconv.Itoa(previousRunID),
|
||||
nil,
|
||||
)
|
||||
mux.ServeHTTP(statusRec, statusReq)
|
||||
mux.ServeHTTP(logsRec, logsReq)
|
||||
|
||||
if statusRec.Code != http.StatusOK {
|
||||
t.Fatalf("status code = %d, want %d", statusRec.Code, http.StatusOK)
|
||||
if logsRec.Code != http.StatusOK {
|
||||
t.Fatalf("logs code = %d, want %d", logsRec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var statusBody map[string]any
|
||||
if err := json.Unmarshal(statusRec.Body.Bytes(), &statusBody); err != nil {
|
||||
t.Fatalf("unmarshal status response: %v", err)
|
||||
var logsBody map[string]any
|
||||
if err := json.Unmarshal(logsRec.Body.Bytes(), &logsBody); err != nil {
|
||||
t.Fatalf("unmarshal logs response: %v", err)
|
||||
}
|
||||
|
||||
logs, ok := statusBody["logs"].([]any)
|
||||
logs, ok := logsBody["logs"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("logs missing or not array: %#v", statusBody["logs"])
|
||||
t.Fatalf("logs missing or not array: %#v", logsBody["logs"])
|
||||
}
|
||||
if len(logs) != 0 {
|
||||
t.Fatalf("logs len = %d, want 0", len(logs))
|
||||
}
|
||||
if got := statusBody["log_total"]; got != float64(0) {
|
||||
if got := logsBody["log_total"]; got != float64(0) {
|
||||
t.Fatalf("log_total = %#v, want 0", got)
|
||||
}
|
||||
if got := logsBody["log_run_id"]; got != clearBody["log_run_id"] {
|
||||
t.Fatalf("log_run_id = %#v, want %#v", got, clearBody["log_run_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindPicoclawBinary_EnvOverride(t *testing.T) {
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestHandleListModels_ConfiguredStatusUsesRuntimeProbesForLocalModels(t *tes
|
||||
cfg.ModelList = []config.ModelConfig{
|
||||
{
|
||||
ModelName: "openai-oauth",
|
||||
Model: "openai/gpt-5.2",
|
||||
Model: "openai/gpt-5.4",
|
||||
AuthMethod: "oauth",
|
||||
},
|
||||
{
|
||||
@@ -81,8 +81,8 @@ func TestHandleListModels_ConfiguredStatusUsesRuntimeProbesForLocalModels(t *tes
|
||||
APIKey: "remote-key",
|
||||
},
|
||||
{
|
||||
ModelName: "copilot-gpt-5.2",
|
||||
Model: "github-copilot/gpt-5.2",
|
||||
ModelName: "copilot-gpt-5.4",
|
||||
Model: "github-copilot/gpt-5.4",
|
||||
APIBase: "http://127.0.0.1:4321",
|
||||
AuthMethod: "oauth",
|
||||
},
|
||||
@@ -128,7 +128,7 @@ func TestHandleListModels_ConfiguredStatusUsesRuntimeProbesForLocalModels(t *tes
|
||||
if !got["vllm-remote"] {
|
||||
t.Fatalf("remote vllm model configured = false, want true with api_key")
|
||||
}
|
||||
if !got["copilot-gpt-5.2"] {
|
||||
if !got["copilot-gpt-5.4"] {
|
||||
t.Fatalf("copilot model configured = false, want true when local bridge probe succeeds")
|
||||
}
|
||||
if len(openAIProbes) != 1 || openAIProbes[0] != "http://127.0.0.1:8000/v1|custom-model" {
|
||||
|
||||
@@ -7,13 +7,13 @@ import (
|
||||
"fmt"
|
||||
"html"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -714,7 +714,7 @@ func (h *Handler) persistCredentialAndConfig(provider, authMethod string, cred *
|
||||
if cp.Email == "" {
|
||||
email, err := oauthFetchGoogleUserEmailFunc(cp.AccessToken)
|
||||
if err != nil {
|
||||
log.Printf("oauth warning: could not fetch google email: %v", err)
|
||||
logger.ErrorC("oauth", fmt.Sprintf("oauth warning: could not fetch google email: %v", err))
|
||||
} else {
|
||||
cp.Email = email
|
||||
}
|
||||
@@ -722,7 +722,7 @@ func (h *Handler) persistCredentialAndConfig(provider, authMethod string, cred *
|
||||
if cp.ProjectID == "" {
|
||||
projectID, err := oauthFetchAntigravityProject(cp.AccessToken)
|
||||
if err != nil {
|
||||
log.Printf("oauth warning: could not fetch antigravity project id: %v", err)
|
||||
logger.ErrorC("oauth", fmt.Sprintf("oauth warning: could not fetch antigravity project id: %v", err))
|
||||
} else {
|
||||
cp.ProjectID = projectID
|
||||
}
|
||||
@@ -780,8 +780,8 @@ func defaultModelConfigForProvider(provider, authMethod string) config.ModelConf
|
||||
switch provider {
|
||||
case oauthProviderOpenAI:
|
||||
return config.ModelConfig{
|
||||
ModelName: "gpt-5.2",
|
||||
Model: "openai/gpt-5.2",
|
||||
ModelName: "gpt-5.4",
|
||||
Model: "openai/gpt-5.4",
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
case oauthProviderAnthropic:
|
||||
|
||||
@@ -167,8 +167,8 @@ func TestOAuthLogoutClearsCredentialAndConfig(t *testing.T) {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
|
||||
ModelName: "gpt-5.2",
|
||||
Model: "openai/gpt-5.2",
|
||||
ModelName: "gpt-5.4",
|
||||
Model: "openai/gpt-5.4",
|
||||
AuthMethod: "oauth",
|
||||
})
|
||||
if err = config.SaveConfig(configPath, cfg); err != nil {
|
||||
|
||||
+37
-12
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -16,6 +17,30 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/pico/token", h.handleGetPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/token", h.handleRegenPicoToken)
|
||||
mux.HandleFunc("POST /api/pico/setup", h.handlePicoSetup)
|
||||
|
||||
// WebSocket proxy: forward /pico/ws to gateway
|
||||
// This allows the frontend to connect via the same port as the web UI,
|
||||
// avoiding the need to expose extra ports for WebSocket communication.
|
||||
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
|
||||
}
|
||||
|
||||
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
|
||||
// The gateway bind host and port are resolved from the latest configuration.
|
||||
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
|
||||
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
return wsProxy
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// The reverse proxy forwards the incoming upgrade handshake as-is.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
proxy := h.createWsProxy()
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// handleGetPicoToken returns the current WS token and URL for the frontend.
|
||||
@@ -65,9 +90,14 @@ func (h *Handler) handleRegenPicoToken(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// ensurePicoChannel checks if the Pico Channel is properly configured and
|
||||
// enables it with sensible defaults if not. Returns true if config was changed.
|
||||
func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
// ensurePicoChannel enables the Pico channel with sane defaults if it isn't
|
||||
// already configured. Returns true when the config was modified.
|
||||
//
|
||||
// callerOrigin is the Origin header from the setup request. If non-empty and
|
||||
// no origins are configured yet, it's written as the allowed origin so the
|
||||
// WebSocket handshake works for whatever host the caller is on (LAN, custom
|
||||
// port, etc.). Pass "" when there's no request context.
|
||||
func (h *Handler) ensurePicoChannel(callerOrigin string) (bool, error) {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to load config: %w", err)
|
||||
@@ -85,14 +115,9 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
changed = true
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
changed = true
|
||||
}
|
||||
|
||||
// Make sure origins are allowed (frontend might be running on a different port like 5173 during dev)
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"*"}
|
||||
// Seed origins from the request instead of hardcoding ports.
|
||||
if len(cfg.Channels.Pico.AllowOrigins) == 0 && callerOrigin != "" {
|
||||
cfg.Channels.Pico.AllowOrigins = []string{callerOrigin}
|
||||
changed = true
|
||||
}
|
||||
|
||||
@@ -109,7 +134,7 @@ func (h *Handler) ensurePicoChannel() (bool, error) {
|
||||
//
|
||||
// POST /api/pico/setup
|
||||
func (h *Handler) handlePicoSetup(w http.ResponseWriter, r *http.Request) {
|
||||
changed, err := h.ensurePicoChannel()
|
||||
changed, err := h.ensurePicoChannel(r.Header.Get("Origin"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,314 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestEnsurePicoChannel_FreshConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.ensurePicoChannel("")
|
||||
if err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatal("ensurePicoChannel() should report changed on a fresh config")
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Channels.Pico.Enabled {
|
||||
t.Error("expected Pico to be enabled after setup")
|
||||
}
|
||||
if cfg.Channels.Pico.Token == "" {
|
||||
t.Error("expected a non-empty token after setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotEnableTokenQuery(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.AllowTokenQuery {
|
||||
t.Error("setup must not enable allow_token_query by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_DoesNotSetWildcardOrigins(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel("http://localhost:18800"); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
for _, origin := range cfg.Channels.Pico.AllowOrigins {
|
||||
if origin == "*" {
|
||||
t.Error("setup must not set wildcard origin '*'")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_NoOriginWithoutCaller(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
if _, err := h.ensurePicoChannel(""); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
// Without a caller origin, allow_origins stays empty (CheckOrigin
|
||||
// allows all when the list is empty, so the channel still works).
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 0 {
|
||||
t.Errorf("allow_origins = %v, want empty when no caller origin", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_SetsCallerOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
lanOrigin := "http://192.168.1.9:18800"
|
||||
if _, err := h.ensurePicoChannel(lanOrigin); err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != lanOrigin {
|
||||
t.Errorf("allow_origins = %v, want [%s]", cfg.Channels.Pico.AllowOrigins, lanOrigin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_PreservesUserSettings(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Pre-configure with custom user settings
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Channels.Pico.Enabled = true
|
||||
cfg.Channels.Pico.Token = "user-custom-token"
|
||||
cfg.Channels.Pico.AllowTokenQuery = true
|
||||
cfg.Channels.Pico.AllowOrigins = []string{"https://myapp.example.com"}
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
|
||||
changed, err := h.ensurePicoChannel("")
|
||||
if err != nil {
|
||||
t.Fatalf("ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("ensurePicoChannel() should not change a fully configured config")
|
||||
}
|
||||
|
||||
cfg, err = config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Channels.Pico.Token != "user-custom-token" {
|
||||
t.Errorf("token = %q, want %q", cfg.Channels.Pico.Token, "user-custom-token")
|
||||
}
|
||||
if !cfg.Channels.Pico.AllowTokenQuery {
|
||||
t.Error("user's allow_token_query=true must be preserved")
|
||||
}
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "https://myapp.example.com" {
|
||||
t.Errorf("allow_origins = %v, want [https://myapp.example.com]", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsurePicoChannel_Idempotent(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
origin := "http://localhost:18800"
|
||||
|
||||
// First call sets things up
|
||||
if _, err := h.ensurePicoChannel(origin); err != nil {
|
||||
t.Fatalf("first ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
|
||||
cfg1, _ := config.LoadConfig(configPath)
|
||||
token1 := cfg1.Channels.Pico.Token
|
||||
|
||||
// Second call should be a no-op
|
||||
changed, err := h.ensurePicoChannel(origin)
|
||||
if err != nil {
|
||||
t.Fatalf("second ensurePicoChannel() error = %v", err)
|
||||
}
|
||||
if changed {
|
||||
t.Error("second ensurePicoChannel() should not report changed")
|
||||
}
|
||||
|
||||
cfg2, _ := config.LoadConfig(configPath)
|
||||
if cfg2.Channels.Pico.Token != token1 {
|
||||
t.Error("token should not change on subsequent calls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_IncludesRequestOrigin(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
||||
req.Header.Set("Origin", "http://10.0.0.5:3000")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handlePicoSetup(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Channels.Pico.AllowOrigins) != 1 || cfg.Channels.Pico.AllowOrigins[0] != "http://10.0.0.5:3000" {
|
||||
t.Errorf("allow_origins = %v, want [http://10.0.0.5:3000]", cfg.Channels.Pico.AllowOrigins)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/pico/setup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
h.handlePicoSetup(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var resp map[string]any
|
||||
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if resp["token"] == nil || resp["token"] == "" {
|
||||
t.Error("response should contain a non-empty token")
|
||||
}
|
||||
if resp["ws_url"] == nil || resp["ws_url"] == "" {
|
||||
t.Error("response should contain ws_url")
|
||||
}
|
||||
if resp["enabled"] != true {
|
||||
t.Error("response should have enabled=true")
|
||||
}
|
||||
if resp["changed"] != true {
|
||||
t.Error("response should have changed=true on first setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server1")
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server2")
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec1.Body.String(); body != "server1" {
|
||||
t.Fatalf("first body = %q, want %q", body, "server1")
|
||||
}
|
||||
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec2.Body.String(); body != "server2" {
|
||||
t.Fatalf("second body = %q, want %q", body, "server2")
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(parsed.Port())
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
@@ -70,3 +70,8 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Launcher service parameters (port/public)
|
||||
h.registerLauncherConfigRoutes(mux)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the handler, stopping the gateway if it was started by this handler.
|
||||
func (h *Handler) Shutdown() {
|
||||
h.StopGateway()
|
||||
}
|
||||
|
||||
@@ -309,7 +309,7 @@ func loadSkillContent(path string) (string, error) {
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
if home := os.Getenv("PICOCLAW_HOME"); home != "" {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
@@ -320,7 +320,7 @@ func globalConfigDir() string {
|
||||
}
|
||||
|
||||
func builtinSkillsDir() string {
|
||||
if path := os.Getenv("PICOCLAW_BUILTIN_SKILLS"); path != "" {
|
||||
if path := os.Getenv(config.EnvBuiltinSkills); path != "" {
|
||||
return path
|
||||
}
|
||||
wd, err := os.Getwd()
|
||||
|
||||
@@ -118,6 +118,12 @@ var toolCatalog = []toolCatalogEntry{
|
||||
Category: "agents",
|
||||
ConfigKey: "spawn",
|
||||
},
|
||||
{
|
||||
Name: "spawn_status",
|
||||
Description: "Query the status of spawned subagents.",
|
||||
Category: "agents",
|
||||
ConfigKey: "spawn_status",
|
||||
},
|
||||
{
|
||||
Name: "i2c",
|
||||
Description: "Interact with I2C hardware devices exposed on the host.",
|
||||
@@ -205,7 +211,7 @@ func buildToolSupport(cfg *config.Config) []toolSupportItem {
|
||||
reasonCode = "requires_skills"
|
||||
}
|
||||
}
|
||||
case "spawn":
|
||||
case "spawn", "spawn_status":
|
||||
if cfg.Tools.IsToolEnabled(entry.ConfigKey) {
|
||||
if cfg.Tools.IsToolEnabled("subagent") {
|
||||
status = "enabled"
|
||||
@@ -300,6 +306,12 @@ func applyToolState(cfg *config.Config, toolName string, enabled bool) error {
|
||||
if enabled {
|
||||
cfg.Tools.Subagent.Enabled = true
|
||||
}
|
||||
case "spawn_status":
|
||||
cfg.Tools.SpawnStatus.Enabled = enabled
|
||||
if enabled {
|
||||
cfg.Tools.Spawn.Enabled = true
|
||||
cfg.Tools.Subagent.Enabled = true
|
||||
}
|
||||
case "i2c":
|
||||
cfg.Tools.I2C.Enabled = enabled
|
||||
case "spi":
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
browserDelay = 500 * time.Millisecond
|
||||
shutdownTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
// shutdownApp gracefully shuts down all server components and resources.
|
||||
// It performs the following shutdown sequence:
|
||||
// - Shuts down the API handler to close all active SSE (Server-Sent Events) connections
|
||||
// - Disables HTTP keep-alive to prevent new connections during shutdown
|
||||
// - Attempts graceful HTTP server shutdown with timeout
|
||||
// - Logs shutdown status at appropriate log levels
|
||||
//
|
||||
// The function handles timeout errors gracefully by logging them at info level
|
||||
// since context.DeadlineExceeded is expected when there are active long-running
|
||||
// connections (such as SSE streams).
|
||||
//
|
||||
// This function should be called during application termination to ensure
|
||||
// clean resource cleanup and proper connection closure.
|
||||
func shutdownApp() {
|
||||
// First, shutdown API handler to close all SSE connections
|
||||
if apiHandler != nil {
|
||||
apiHandler.Shutdown()
|
||||
}
|
||||
|
||||
if server != nil {
|
||||
// Disable keep-alive to allow graceful shutdown
|
||||
server.SetKeepAlivesEnabled(false)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
// Context deadline exceeded is expected if there are active connections
|
||||
// This is not necessarily an error, so log it at info level
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
logger.Infof("Server shutdown timeout after %v, forcing close", shutdownTimeout)
|
||||
} else {
|
||||
logger.Errorf("Server shutdown error: %v", err)
|
||||
}
|
||||
} else {
|
||||
logger.Infof("Server shutdown completed successfully")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func openBrowser() error {
|
||||
if serverAddr == "" {
|
||||
return fmt.Errorf("server address not set")
|
||||
}
|
||||
return utils.OpenBrowser(serverAddr)
|
||||
}
|
||||
+14
-4
@@ -2,11 +2,14 @@ package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
//go:embed all:dist
|
||||
@@ -14,13 +17,20 @@ var frontendFS embed.FS
|
||||
|
||||
// registerEmbedRoutes sets up the HTTP handler to serve the embedded frontend files
|
||||
func registerEmbedRoutes(mux *http.ServeMux) {
|
||||
// Register correct MIME type for SVG files
|
||||
// Go's built-in mime.TypeByExtension returns "image/svg" which is incorrect
|
||||
// The correct MIME type per RFC 6838 is "image/svg+xml"
|
||||
if err := mime.AddExtensionType(".svg", "image/svg+xml"); err != nil {
|
||||
logger.ErrorC("web", fmt.Sprintf("Warning: failed to register SVG MIME type: %v", err))
|
||||
}
|
||||
|
||||
// Attempt to get the subdirectory 'dist' where Vite usually builds
|
||||
subFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
// Log a warning if dist doesn't exist yet (e.g., during development before a frontend build)
|
||||
log.Printf(
|
||||
"Warning: no 'dist' folder found in embedded frontend. " +
|
||||
"Ensure you run `pnpm build:backend` in the frontend directory " +
|
||||
logger.WarnC("web",
|
||||
"Warning: no 'dist' folder found in embedded frontend. "+
|
||||
"Ensure you run `pnpm build:backend` in the frontend directory "+
|
||||
"before building the Go backend.",
|
||||
)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Language represents the supported languages
|
||||
type Language string
|
||||
|
||||
const (
|
||||
LanguageEnglish Language = "en"
|
||||
LanguageChinese Language = "zh"
|
||||
)
|
||||
|
||||
// current language (default: English)
|
||||
var currentLang Language = LanguageEnglish
|
||||
|
||||
// TranslationKey represents a translation key used for i18n
|
||||
type TranslationKey string
|
||||
|
||||
const (
|
||||
AppTooltip TranslationKey = "AppTooltip"
|
||||
MenuOpen TranslationKey = "MenuOpen"
|
||||
MenuOpenTooltip TranslationKey = "MenuOpenTooltip"
|
||||
MenuAbout TranslationKey = "MenuAbout"
|
||||
MenuAboutTooltip TranslationKey = "MenuAboutTooltip"
|
||||
MenuVersion TranslationKey = "MenuVersion"
|
||||
MenuVersionTooltip TranslationKey = "MenuVersionTooltip"
|
||||
MenuGitHub TranslationKey = "MenuGitHub"
|
||||
MenuDocs TranslationKey = "MenuDocs"
|
||||
MenuRestart TranslationKey = "MenuRestart"
|
||||
MenuRestartTooltip TranslationKey = "MenuRestartTooltip"
|
||||
MenuQuit TranslationKey = "MenuQuit"
|
||||
MenuQuitTooltip TranslationKey = "MenuQuitTooltip"
|
||||
Exiting TranslationKey = "Exiting"
|
||||
DocUrl TranslationKey = "DocUrl"
|
||||
)
|
||||
|
||||
// Translation tables
|
||||
// Chinese translations intentionally contain Han script
|
||||
//
|
||||
//nolint:gosmopolitan
|
||||
var translations = map[Language]map[TranslationKey]string{
|
||||
LanguageEnglish: {
|
||||
AppTooltip: "%s - Web Console",
|
||||
MenuOpen: "Open Console",
|
||||
MenuOpenTooltip: "Open PicoClaw console in browser",
|
||||
MenuAbout: "About",
|
||||
MenuAboutTooltip: "About PicoClaw",
|
||||
MenuVersion: "Version: %s",
|
||||
MenuVersionTooltip: "Current version number",
|
||||
MenuGitHub: "GitHub",
|
||||
MenuDocs: "Documentation",
|
||||
MenuRestart: "Restart Service",
|
||||
MenuRestartTooltip: "Restart Gateway service",
|
||||
MenuQuit: "Quit",
|
||||
MenuQuitTooltip: "Exit PicoClaw",
|
||||
Exiting: "Exiting PicoClaw...",
|
||||
DocUrl: "https://docs.picoclaw.io/docs/",
|
||||
},
|
||||
LanguageChinese: {
|
||||
AppTooltip: "%s - Web Console",
|
||||
MenuOpen: "打开控制台",
|
||||
MenuOpenTooltip: "在浏览器中打开 PicoClaw 控制台",
|
||||
MenuAbout: "关于",
|
||||
MenuAboutTooltip: "关于 PicoClaw",
|
||||
MenuVersion: "版本: %s",
|
||||
MenuVersionTooltip: "当前版本号",
|
||||
MenuGitHub: "GitHub",
|
||||
MenuDocs: "文档",
|
||||
MenuRestart: "重启服务",
|
||||
MenuRestartTooltip: "重启核心服务",
|
||||
MenuQuit: "退出",
|
||||
MenuQuitTooltip: "退出 PicoClaw",
|
||||
Exiting: "正在退出 PicoClaw...",
|
||||
DocUrl: "https://docs.picoclaw.io/zh-Hans/docs/",
|
||||
},
|
||||
}
|
||||
|
||||
// SetLanguage sets the current language
|
||||
func SetLanguage(lang string) {
|
||||
lang = strings.ToLower(strings.TrimSpace(lang))
|
||||
|
||||
// Extract language code before first underscore or dot
|
||||
// e.g., "en_US.UTF-8" -> "en", "zh_CN" -> "zh"
|
||||
if idx := strings.IndexAny(lang, "_."); idx > 0 {
|
||||
lang = lang[:idx]
|
||||
}
|
||||
|
||||
if lang == "zh" || lang == "zh-cn" || lang == "chinese" {
|
||||
currentLang = LanguageChinese
|
||||
} else {
|
||||
currentLang = LanguageEnglish
|
||||
}
|
||||
}
|
||||
|
||||
// GetLanguage returns the current language
|
||||
func GetLanguage() Language {
|
||||
return currentLang
|
||||
}
|
||||
|
||||
// T translates a key to the current language
|
||||
func T(key TranslationKey, args ...any) string {
|
||||
if trans, ok := translations[currentLang][key]; ok {
|
||||
if len(args) > 0 {
|
||||
return fmt.Sprintf(trans, args...)
|
||||
}
|
||||
return trans
|
||||
}
|
||||
return string(key)
|
||||
}
|
||||
|
||||
// Initialize i18n from environment variable
|
||||
func init() {
|
||||
if lang := os.Getenv("LANG"); lang != "" {
|
||||
SetLanguage(lang)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 102 KiB |
+111
-29
@@ -15,23 +15,42 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/web/backend/api"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
"github.com/sipeed/picoclaw/web/backend/middleware"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
appName = "PicoClaw"
|
||||
)
|
||||
|
||||
var (
|
||||
appVersion = config.Version
|
||||
|
||||
server *http.Server
|
||||
serverAddr string
|
||||
apiHandler *api.Handler
|
||||
|
||||
noBrowser *bool
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.String("port", "18800", "Port to listen on")
|
||||
public := flag.Bool("public", false, "Listen on all interfaces (0.0.0.0) instead of localhost only")
|
||||
noBrowser := flag.Bool("no-browser", false, "Do not auto-open browser on startup")
|
||||
noBrowser = flag.Bool("no-browser", false, "Do not auto-open browser on startup")
|
||||
lang := flag.String("lang", "", "Language: en (English) or zh (Chinese). Default: auto-detect from system locale")
|
||||
console := flag.Bool("console", false, "Console mode, no GUI")
|
||||
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "PicoClaw Launcher - A web-based configuration editor\n\n")
|
||||
@@ -51,6 +70,32 @@ func main() {
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
// Initialize logger
|
||||
picoHome := utils.GetPicoclawHome()
|
||||
// By default, detect terminal to decide console log behavior
|
||||
// If -console-logs flag is explicitly set, it overrides the detection
|
||||
enableConsole := *console
|
||||
if !enableConsole {
|
||||
// Disable console logging by setting level to Fatal (no output)
|
||||
logger.SetConsoleLevel(logger.FATAL)
|
||||
|
||||
logPath := filepath.Join(picoHome, "logs", "web.log")
|
||||
if err := logger.EnableFileLogging(logPath); err != nil {
|
||||
// FIXME: https://github.com/sipeed/picoclaw/issues/1734
|
||||
fmt.Fprintf(os.Stderr, "Failed to initialize logger: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logger.DisableFileLogging()
|
||||
}
|
||||
|
||||
logger.InfoC("web", "PicoClaw Launcher starting...")
|
||||
logger.InfoC("web", fmt.Sprintf("PicoClaw Home: %s", picoHome))
|
||||
|
||||
// Set language from command line or auto-detect
|
||||
if *lang != "" {
|
||||
SetLanguage(*lang)
|
||||
}
|
||||
|
||||
// Resolve config path
|
||||
configPath := utils.GetDefaultConfigPath()
|
||||
if flag.NArg() > 0 {
|
||||
@@ -59,11 +104,11 @@ func main() {
|
||||
|
||||
absPath, err := filepath.Abs(configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve config path: %v", err)
|
||||
logger.Fatalf("Failed to resolve config path: %v", err)
|
||||
}
|
||||
err = utils.EnsureOnboarded(absPath)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to initialize PicoClaw config automatically: %v", err)
|
||||
logger.Errorf("Warning: Failed to initialize PicoClaw config automatically: %v", err)
|
||||
}
|
||||
|
||||
var explicitPort bool
|
||||
@@ -80,7 +125,7 @@ func main() {
|
||||
launcherPath := launcherconfig.PathForAppConfig(absPath)
|
||||
launcherCfg, err := launcherconfig.Load(launcherPath, launcherconfig.Default())
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to load %s: %v", launcherPath, err)
|
||||
logger.ErrorC("web", fmt.Sprintf("Warning: Failed to load %s: %v", launcherPath, err))
|
||||
launcherCfg = launcherconfig.Default()
|
||||
}
|
||||
|
||||
@@ -98,7 +143,7 @@ func main() {
|
||||
if err == nil {
|
||||
err = errors.New("must be in range 1-65535")
|
||||
}
|
||||
log.Fatalf("Invalid port %q: %v", effectivePort, err)
|
||||
logger.Fatalf("Invalid port %q: %v", effectivePort, err)
|
||||
}
|
||||
|
||||
// Determine listen address
|
||||
@@ -113,7 +158,7 @@ func main() {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// API Routes (e.g. /api/status)
|
||||
apiHandler := api.NewHandler(absPath)
|
||||
apiHandler = api.NewHandler(absPath)
|
||||
apiHandler.SetServerOptions(portNum, effectivePublic, explicitPublic, launcherCfg.AllowedCIDRs)
|
||||
apiHandler.RegisterRoutes(mux)
|
||||
|
||||
@@ -122,7 +167,7 @@ func main() {
|
||||
|
||||
accessControlledMux, err := middleware.IPAllowlist(launcherCfg.AllowedCIDRs, mux)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid allowed CIDR configuration: %v", err)
|
||||
logger.Fatalf("Invalid allowed CIDR configuration: %v", err)
|
||||
}
|
||||
|
||||
// Apply middleware stack
|
||||
@@ -132,29 +177,33 @@ func main() {
|
||||
),
|
||||
)
|
||||
|
||||
// Print startup banner
|
||||
fmt.Print(utils.Banner)
|
||||
fmt.Println()
|
||||
fmt.Println(" Open the following URL in your browser:")
|
||||
fmt.Println()
|
||||
fmt.Printf(" >> http://localhost:%s <<\n", effectivePort)
|
||||
// Print startup banner (only in console mode)
|
||||
if enableConsole {
|
||||
fmt.Print(utils.Banner)
|
||||
fmt.Println()
|
||||
fmt.Println(" Open the following URL in your browser:")
|
||||
fmt.Println()
|
||||
fmt.Printf(" >> http://localhost:%s <<\n", effectivePort)
|
||||
if effectivePublic {
|
||||
if ip := utils.GetLocalIP(); ip != "" {
|
||||
fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort)
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
// Log startup info to file
|
||||
logger.InfoC("web", fmt.Sprintf("Server will listen on http://localhost:%s", effectivePort))
|
||||
if effectivePublic {
|
||||
if ip := utils.GetLocalIP(); ip != "" {
|
||||
fmt.Printf(" >> http://%s:%s <<\n", ip, effectivePort)
|
||||
logger.InfoC("web", fmt.Sprintf("Public access enabled at http://%s:%s", ip, effectivePort))
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// Auto-open browser
|
||||
if !*noBrowser {
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
url := "http://localhost:" + effectivePort
|
||||
if err := utils.OpenBrowser(url); err != nil {
|
||||
log.Printf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
// Share the local URL with the launcher runtime.
|
||||
serverAddr = fmt.Sprintf("http://localhost:%s", effectivePort)
|
||||
|
||||
// Auto-open browser will be handled by the launcher runtime.
|
||||
|
||||
// Auto-start gateway after backend starts listening.
|
||||
go func() {
|
||||
@@ -162,8 +211,41 @@ func main() {
|
||||
apiHandler.TryAutoStartGateway()
|
||||
}()
|
||||
|
||||
// Start the Server
|
||||
if err := http.ListenAndServe(addr, handler); err != nil {
|
||||
log.Fatalf("Server failed to start: %v", err)
|
||||
// Start the Server in a goroutine
|
||||
server = &http.Server{Addr: addr, Handler: handler}
|
||||
go func() {
|
||||
logger.InfoC("web", fmt.Sprintf("Server listening on %s", addr))
|
||||
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Fatalf("Server failed to start: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer shutdownApp()
|
||||
|
||||
// Start system tray or run in console mode
|
||||
if enableConsole {
|
||||
if !*noBrowser {
|
||||
// Auto-open browser after systray is ready (if not disabled)
|
||||
// Check no-browser flag via environment or pass as parameter if needed
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
// Main event loop - wait for signals or config changes
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// GUI mode: start system tray
|
||||
runTray()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// JSONContentType sets the Content-Type header to application/json for
|
||||
// API requests handled by the wrapped handler.
|
||||
// SSE endpoints (text/event-stream) are excluded.
|
||||
func JSONContentType(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
|
||||
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -32,7 +32,6 @@ func (rr *responseRecorder) WriteHeader(code int) {
|
||||
}
|
||||
|
||||
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
|
||||
// This is required for SSE (Server-Sent Events) to work through the middleware.
|
||||
func (rr *responseRecorder) Flush() {
|
||||
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
@@ -51,7 +50,7 @@ func Logger(next http.Handler) http.Handler {
|
||||
start := time.Now()
|
||||
rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
next.ServeHTTP(rec, r)
|
||||
log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.statusCode, time.Since(start))
|
||||
logger.DebugC("http", fmt.Sprintf("%s %s %d %s", r.Method, r.URL.Path, rec.statusCode, time.Since(start)))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,7 +60,7 @@ func Recoverer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
log.Printf("panic recovered: %v\n%s", err, debug.Stack())
|
||||
logger.ErrorC("http", fmt.Sprintf("panic recovered: %v\n%s", err, debug.Stack()))
|
||||
http.Error(w, `{"error":"internal server error"}`, http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
//go:build (!darwin && !freebsd) || cgo
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
"fyne.io/systray"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/web/backend/utils"
|
||||
)
|
||||
|
||||
func runTray() {
|
||||
systray.Run(onReady, onExit)
|
||||
}
|
||||
|
||||
// onReady is called when the system tray is ready
|
||||
func onReady() {
|
||||
// Set icon and tooltip
|
||||
systray.SetIcon(getIcon())
|
||||
systray.SetTooltip(fmt.Sprintf(T(AppTooltip), appName))
|
||||
|
||||
// Create menu items
|
||||
mOpen := systray.AddMenuItem(T(MenuOpen), T(MenuOpenTooltip))
|
||||
mAbout := systray.AddMenuItem(T(MenuAbout), T(MenuAboutTooltip))
|
||||
|
||||
// Add version info under About menu
|
||||
mVersion := mAbout.AddSubMenuItem(fmt.Sprintf(T(MenuVersion), appVersion), T(MenuVersionTooltip))
|
||||
mVersion.Disable()
|
||||
mRepo := mAbout.AddSubMenuItem(T(MenuGitHub), "")
|
||||
mDocs := mAbout.AddSubMenuItem(T(MenuDocs), "")
|
||||
|
||||
systray.AddSeparator()
|
||||
|
||||
// Add restart option
|
||||
mRestart := systray.AddMenuItem(T(MenuRestart), T(MenuRestartTooltip))
|
||||
|
||||
systray.AddSeparator()
|
||||
|
||||
// Quit option
|
||||
mQuit := systray.AddMenuItem(T(MenuQuit), T(MenuQuitTooltip))
|
||||
|
||||
// Handle menu clicks
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-mOpen.ClickedCh:
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Failed to open browser: %v", err)
|
||||
}
|
||||
|
||||
case <-mVersion.ClickedCh:
|
||||
// Version info - do nothing, just shows current version
|
||||
|
||||
case <-mRepo.ClickedCh:
|
||||
if err := utils.OpenBrowser("https://github.com/sipeed/picoclaw"); err != nil {
|
||||
logger.Errorf("Failed to open GitHub: %v", err)
|
||||
}
|
||||
|
||||
case <-mDocs.ClickedCh:
|
||||
if err := utils.OpenBrowser(T(DocUrl)); err != nil {
|
||||
logger.Errorf("Failed to open docs: %v", err)
|
||||
}
|
||||
|
||||
case <-mRestart.ClickedCh:
|
||||
fmt.Println("Restart request received...")
|
||||
if apiHandler != nil {
|
||||
if pid, err := apiHandler.RestartGateway(); err != nil {
|
||||
logger.Errorf("Failed to restart gateway: %v", err)
|
||||
} else {
|
||||
logger.Infof("Gateway restarted (PID: %d)", pid)
|
||||
}
|
||||
}
|
||||
|
||||
case <-mQuit.ClickedCh:
|
||||
systray.Quit()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if !*noBrowser {
|
||||
// Auto-open browser after systray is ready (if not disabled)
|
||||
// Check no-browser flag via environment or pass as parameter if needed
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onExit is called when the system tray is exiting
|
||||
func onExit() {
|
||||
logger.Info(T(Exiting))
|
||||
}
|
||||
|
||||
// getIcon returns the system tray icon
|
||||
func getIcon() []byte {
|
||||
return iconData
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package main
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed icon.png
|
||||
var iconData []byte
|
||||
@@ -0,0 +1,8 @@
|
||||
//go:build windows
|
||||
|
||||
package main
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed icon.ico
|
||||
var iconData []byte
|
||||
@@ -0,0 +1,33 @@
|
||||
//go:build (darwin || freebsd) && !cgo
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func runTray() {
|
||||
logger.Infof("System tray is unavailable in %s builds without cgo; running without tray", runtime.GOOS)
|
||||
|
||||
if !*noBrowser {
|
||||
go func() {
|
||||
time.Sleep(browserDelay)
|
||||
if err := openBrowser(); err != nil {
|
||||
logger.Errorf("Warning: Failed to auto-open browser: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
<-ctx.Done()
|
||||
shutdownApp()
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
var execCommand = exec.Command
|
||||
@@ -19,7 +21,7 @@ func EnsureOnboarded(configPath string) error {
|
||||
}
|
||||
|
||||
cmd := execCommand(FindPicoclawBinary(), "onboard")
|
||||
cmd.Env = append(os.Environ(), "PICOCLAW_CONFIG="+configPath)
|
||||
cmd.Env = append(os.Environ(), config.EnvConfig+"="+configPath)
|
||||
cmd.Stdin = strings.NewReader("n\n")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
@@ -7,21 +7,26 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// GetPicoclawHome returns the picoclaw home directory.
|
||||
// Priority: $PICOCLAW_HOME > ~/.picoclaw
|
||||
func GetPicoclawHome() string {
|
||||
if home := os.Getenv(config.EnvHome); home != "" {
|
||||
return home
|
||||
}
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".picoclaw")
|
||||
}
|
||||
|
||||
// GetDefaultConfigPath returns the default path to the picoclaw config file.
|
||||
func GetDefaultConfigPath() string {
|
||||
if configPath := os.Getenv("PICOCLAW_CONFIG"); configPath != "" {
|
||||
if configPath := os.Getenv(config.EnvConfig); configPath != "" {
|
||||
return configPath
|
||||
}
|
||||
if picoclawHome := os.Getenv("PICOCLAW_HOME"); picoclawHome != "" {
|
||||
return filepath.Join(picoclawHome, "config.json")
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "config.json"
|
||||
}
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
return filepath.Join(GetPicoclawHome(), "config.json")
|
||||
}
|
||||
|
||||
// FindPicoclawBinary locates the picoclaw executable.
|
||||
@@ -35,7 +40,7 @@ func FindPicoclawBinary() string {
|
||||
binaryName = "picoclaw.exe"
|
||||
}
|
||||
|
||||
if p := os.Getenv("PICOCLAW_BINARY"); p != "" {
|
||||
if p := os.Getenv(config.EnvBinary); p != "" {
|
||||
if info, _ := os.Stat(p); info != nil && !info.IsDir() {
|
||||
return p
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user