feat: add web gateway hot reload and polling state sync (#1684)

* feat(gateway): support hot reload and empty startup

- extract gateway runtime into pkg/gateway
- add gateway.hot_reload config with default and example values
- allow starting the gateway without a default model via --allow-empty
- stop treating missing enabled channels as a startup error
- update related tests

* feat: replace gateway SSE updates with polling-based state sync

- remove gateway SSE broadcasting and event endpoint
- add polling-based gateway status refresh with stopping state handling
- detect when gateway restart is required after default model changes
- resolve gateway health and websocket proxy targets from configured host
- update gateway UI labels and add backend/frontend test coverage
This commit is contained in:
wenjie
2026-03-17 18:46:00 +08:00
committed by GitHub
parent 11207186c8
commit 8a44410e37
24 changed files with 700 additions and 543 deletions
-80
View File
@@ -1,80 +0,0 @@
package api
import (
"encoding/json"
"sync"
)
// GatewayEvent represents a state change event for the gateway process.
type GatewayEvent struct {
Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
PID int `json:"pid,omitempty"`
BootDefaultModel string `json:"boot_default_model,omitempty"`
ConfigDefaultModel string `json:"config_default_model,omitempty"`
RestartRequired bool `json:"gateway_restart_required,omitempty"`
}
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
type EventBroadcaster struct {
mu sync.RWMutex
clients map[chan string]struct{}
}
// NewEventBroadcaster creates a new broadcaster.
func NewEventBroadcaster() *EventBroadcaster {
return &EventBroadcaster{
clients: make(map[chan string]struct{}),
}
}
// Subscribe adds a new listener channel and returns it.
// The caller must call Unsubscribe when done.
func (b *EventBroadcaster) Subscribe() chan string {
ch := make(chan string, 8)
b.mu.Lock()
b.clients[ch] = struct{}{}
b.mu.Unlock()
return ch
}
// Unsubscribe removes a listener channel and closes it.
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
b.mu.Lock()
defer b.mu.Unlock()
// Check if the channel is still registered before closing
if _, exists := b.clients[ch]; exists {
delete(b.clients, ch)
close(ch)
}
}
// Broadcast sends a GatewayEvent to all connected SSE clients.
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
data, err := json.Marshal(event)
if err != nil {
return
}
b.mu.RLock()
defer b.mu.RUnlock()
for ch := range b.clients {
// Non-blocking send; drop event if client is slow
select {
case ch <- string(data):
default:
}
}
}
// Shutdown closes all subscriber channels, notifying all SSE clients to disconnect.
// This should be called when the server is shutting down.
func (b *EventBroadcaster) Shutdown() {
// Close all channels to notify listeners
for ch := range b.clients {
b.Unsubscribe(ch)
}
// Clear the map
b.clients = make(map[chan string]struct{})
}
+60 -149
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"os/exec"
@@ -30,11 +31,9 @@ var gateway = struct {
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
events *EventBroadcaster
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
events: NewEventBroadcaster(),
}
var (
@@ -51,11 +50,19 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
// getGatewayHealth checks the gateway health endpoint and returns the status response
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse, int, error) {
if port == 0 {
port = 18790
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
port := 18790
if cfg != nil && cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
}
url := fmt.Sprintf("http://127.0.0.1:%d/health", port)
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
return getGatewayHealthByURL(url, timeout)
}
func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
resp, err := gatewayHealthGet(url, timeout)
if err != nil {
return nil, 0, err
@@ -73,7 +80,6 @@ func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse,
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
@@ -87,7 +93,7 @@ func (h *Handler) TryAutoStartGateway() {
// Check if gateway is already running via health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
@@ -170,6 +176,16 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
return modelCfg
}
func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
if gatewayStatus != "running" {
return false
}
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
return false
}
return configDefaultModel != bootDefaultModel
}
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
if cmd == nil || cmd.Process == nil {
return false
@@ -220,7 +236,7 @@ func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
return nil
}
func gatewayStatusOnHealthFailureLocked() string {
func gatewayStatusWithoutHealthLocked() string {
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return gateway.runtimeStatus
@@ -233,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string {
if gateway.runtimeStatus == "error" {
return "error"
}
return "error"
}
func currentGatewayStatusLocked(processAlive bool) string {
if !processAlive {
if gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return "restarting"
}
return "error"
}
if gateway.runtimeStatus == "error" {
return "error"
}
return "stopped"
}
return gatewayStatusOnHealthFailureLocked()
return "stopped"
}
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
@@ -319,15 +319,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
return 0, err
}
// Broadcast the attached state
gateway.events.Broadcast(GatewayEvent{
Status: initialStatus,
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
return pid, nil
}
@@ -335,7 +326,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
// Locate the picoclaw executable
execPath := utils.FindPicoclawBinary()
cmd = exec.Command(execPath, "gateway")
cmd = exec.Command(execPath, "gateway", "-E")
cmd.Env = os.Environ()
// Forward the launcher's config path via the environment variable that
// GetConfigPath() already reads, so the gateway sub-process uses the same
@@ -376,15 +367,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
pid = cmd.Process.Pid
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
// Broadcast the launch state immediately so clients can reflect it without polling.
gateway.events.Broadcast(GatewayEvent{
Status: initialStatus,
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
// Capture stdout/stderr in background
go scanPipe(stdoutPipe, gateway.logs)
go scanPipe(stderrPipe, gateway.logs)
@@ -398,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
}
gateway.mu.Lock()
shouldBroadcastStopped := false
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
if gateway.runtimeStatus != "restarting" {
setGatewayRuntimeStatusLocked("stopped")
shouldBroadcastStopped = true
}
}
gateway.mu.Unlock()
if shouldBroadcastStopped {
gateway.events.Broadcast(GatewayEvent{
Status: "stopped",
RestartRequired: false,
})
}
}()
// Start a goroutine to probe health and broadcast "running" once ready
// Start a goroutine to probe health and update the runtime state once ready.
go func() {
for i := 0; i < 30; i++ { // try for up to 15 seconds
time.Sleep(500 * time.Millisecond)
@@ -431,7 +404,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
if err != nil {
continue
}
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 1*time.Second)
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
// Verify the health endpoint returns the expected pid
gateway.mu.Lock()
@@ -439,13 +412,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
gateway.events.Broadcast(GatewayEvent{
Status: "running",
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
return
}
}
@@ -461,7 +427,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
// Prevent duplicate starts by checking health endpoint
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err == nil && statusCode == http.StatusOK {
// Gateway is already running, attach to the existing process
pid := healthResp.Pid
@@ -597,10 +563,6 @@ func (h *Handler) RestartGateway() (int, error) {
gateway.mu.Lock()
previousCmd := gateway.cmd
setGatewayRuntimeStatusLocked("restarting")
gateway.events.Broadcast(GatewayEvent{
Status: "restarting",
RestartRequired: false,
})
gateway.mu.Unlock()
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
@@ -704,24 +666,20 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
func (h *Handler) gatewayStatusData() map[string]any {
data := map[string]any{}
configDefaultModel := ""
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
configDefaultModel := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
if configDefaultModel != "" {
data["config_default_model"] = configDefaultModel
}
}
// Probe health endpoint to get pid and status
port := 0
if cfgErr == nil && cfg != nil {
port = cfg.Gateway.Port
}
healthResp, statusCode, err := getGatewayHealth(port, 2*time.Second)
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
if err != nil {
gateway.mu.Lock()
data["gateway_status"] = currentGatewayStatusLocked(true)
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
gateway.mu.Unlock()
log.Printf("Gateway health check failed: %v", err)
} else {
@@ -734,45 +692,43 @@ func (h *Handler) gatewayStatusData() map[string]any {
data["status_code"] = statusCode
} else {
gateway.mu.Lock()
// Check if this pid matches our tracked process
if gateway.cmd != nil && gateway.cmd.Process != nil && gateway.cmd.Process.Pid == healthResp.Pid {
setGatewayRuntimeStatusLocked("running")
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = healthResp.Pid
} else {
// Health endpoint responded with a different pid
// This could be a manual restart, try to attach to the new process
setGatewayRuntimeStatusLocked("running")
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
oldPid := "none"
if gateway.cmd != nil && gateway.cmd.Process != nil {
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
}
log.Printf("Detected new gateway PID (old: %s, new: %d), attempting to attach", oldPid, healthResp.Pid)
log.Printf(
"Detected gateway PID from health (old: %s, new: %d), attempting to attach",
oldPid,
healthResp.Pid,
)
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
// Failed to find the process, treat as error
setGatewayRuntimeStatusLocked("error")
data["gateway_status"] = "error"
data["pid"] = healthResp.Pid
log.Printf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err)
} else {
// Successfully attached, update response data
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = healthResp.Pid
log.Printf(
"Failed to attach to gateway process reported by health (PID: %d): %v",
healthResp.Pid,
err,
)
}
}
bootDefaultModel := gateway.bootDefaultModel
if bootDefaultModel != "" {
data["boot_default_model"] = bootDefaultModel
}
data["gateway_status"] = "running"
data["pid"] = healthResp.Pid
gateway.mu.Unlock()
}
}
data["gateway_restart_required"] = false
bootDefaultModel, _ := data["boot_default_model"].(string)
gatewayStatus, _ := data["gateway_status"].(string)
data["gateway_restart_required"] = gatewayRestartRequired(
configDefaultModel,
bootDefaultModel,
gatewayStatus,
)
ready, reason, readyErr := h.gatewayStartReady()
if readyErr != nil {
@@ -842,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any {
return data
}
// handleGatewayEvents serves an SSE stream of gateway state change events.
//
// GET /api/gateway/events
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "SSE not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
// Subscribe to gateway events
ch := gateway.events.Subscribe()
defer gateway.events.Unsubscribe(ch)
// Send initial status so the client doesn't start blank
initial := h.currentGatewayStatus()
fmt.Fprintf(w, "data: %s\n\n", initial)
flusher.Flush()
for {
select {
case <-r.Context().Done():
return
case data, ok := <-ch:
if !ok {
return
}
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
}
}
// currentGatewayStatus returns the current gateway status as a JSON string.
func (h *Handler) currentGatewayStatus() string {
data := h.gatewayStatusData()
encoded, _ := json.Marshal(data)
return string(encoded)
}
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
func scanPipe(r io.Reader, buf *LogBuffer) {
scanner := bufio.NewScanner(r)
+18
View File
@@ -3,6 +3,7 @@ package api
import (
"net"
"net/http"
"net/url"
"strconv"
"strings"
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
return bindHost
}
func (h *Handler) gatewayProxyURL() *url.URL {
cfg, err := config.LoadConfig(h.configPath)
port := 18790
bindHost := ""
if err == nil && cfg != nil {
if cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
}
bindHost = h.effectiveGatewayBindHost(cfg)
}
return &url.URL{
Scheme: "http",
Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
}
}
func requestHostName(r *http.Request) string {
reqHost, _, err := net.SplitHostPort(r.Host)
if err == nil {
+76
View File
@@ -2,9 +2,12 @@ package api
import (
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
@@ -59,6 +62,79 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
}
}
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "192.168.1.10"
cfg.Gateway.Port = 18791
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
}
}
func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "192.168.1.10"
cfg.Gateway.Port = 18791
originalHealthGet := gatewayHealthGet
t.Cleanup(func() {
gatewayHealthGet = originalHealthGet
})
var requestedURL string
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
requestedURL = url
return nil, errors.New("probe failed")
}
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
_ = statusCode
_ = err
if requestedURL != "http://192.168.1.10:18791/health" {
t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
}
}
func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
h.SetServerOptions(18800, true, true, nil)
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = 18791
originalHealthGet := gatewayHealthGet
t.Cleanup(func() {
gatewayHealthGet = originalHealthGet
})
var requestedURL string
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
requestedURL = url
return nil, errors.New("probe failed")
}
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
_ = statusCode
_ = err
if requestedURL != "http://127.0.0.1:18791/health" {
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
}
}
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
+129
View File
@@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
return cmd
}
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
return &http.Response{
StatusCode: statusCode,
Body: io.NopCloser(strings.NewReader(
`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
)),
}
}
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
t.Helper()
@@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
}
}
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["pid"]; got != float64(cmd.Process.Pid) {
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
}
if got := body["gateway_restart_required"]; got != false {
t.Fatalf("gateway_restart_required = %#v, want false", got)
}
}
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = "test-key"
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
ModelName: "second-model",
Model: "openai/gpt-4.1",
APIKey: "second-key",
})
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
process, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
updatedCfg.Agents.Defaults.ModelName = "second-model"
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
}
if got := body["config_default_model"]; got != "second-model" {
t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
resetGatewayTestState(t)
+7 -17
View File
@@ -7,7 +7,6 @@ import (
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/sipeed/picoclaw/pkg/config"
@@ -22,20 +21,13 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
// WebSocket proxy: forward /pico/ws to gateway
// This allows the frontend to connect via the same port as the web UI,
// avoiding the need to expose extra ports for WebSocket communication.
wsProxy := h.createWsProxy()
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy(wsProxy))
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
}
// createWsProxy creates a reverse proxy to the gateway WebSocket endpoint.
// The gateway port is read from the configuration.
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
// The gateway bind host and port are resolved from the latest configuration.
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
cfg, err := config.LoadConfig(h.configPath)
gatewayPort := 18790 // default
if err == nil && cfg.Gateway.Port != 0 {
gatewayPort = cfg.Gateway.Port
}
gatewayURL, _ := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", gatewayPort))
wsProxy := httputil.NewSingleHostReverseProxy(gatewayURL)
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
}
@@ -43,12 +35,10 @@ func (h *Handler) createWsProxy() *httputil.ReverseProxy {
}
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
// It ensures the Connection and Upgrade headers are properly forwarded.
func (h *Handler) handleWebSocketProxy(proxy *httputil.ReverseProxy) http.HandlerFunc {
// The reverse proxy forwards the incoming upgrade handshake as-is.
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set headers for WebSocket upgrade
r.Header.Set("Connection", "upgrade")
r.Header.Set("Upgrade", "websocket")
proxy := h.createWsProxy()
proxy.ServeHTTP(w, r)
}
}
+77
View File
@@ -2,9 +2,12 @@ package api
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"strconv"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
@@ -235,3 +238,77 @@ func TestHandlePicoSetup_Response(t *testing.T) {
t.Error("response should have changed=true on first setup")
}
}
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
handler := h.handleWebSocketProxy()
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "server1")
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/pico/ws" {
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
}
w.WriteHeader(http.StatusOK)
_, _ = io.WriteString(w, "server2")
}))
defer server2.Close()
cfg := config.DefaultConfig()
cfg.Gateway.Host = "127.0.0.1"
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
rec1 := httptest.NewRecorder()
handler(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
}
if body := rec1.Body.String(); body != "server1" {
t.Fatalf("first body = %q, want %q", body, "server1")
}
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
rec2 := httptest.NewRecorder()
handler(rec2, req2)
if rec2.Code != http.StatusOK {
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
}
if body := rec2.Body.String(); body != "server2" {
t.Fatalf("second body = %q, want %q", body, "server2")
}
}
func mustGatewayTestPort(t *testing.T, rawURL string) int {
t.Helper()
parsed, err := url.Parse(rawURL)
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
port, err := strconv.Atoi(parsed.Port())
if err != nil {
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
}
return port
}
+1 -4
View File
@@ -71,7 +71,4 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
h.registerLauncherConfigRoutes(mux)
}
// Shutdown gracefully shuts down the handler, closing all SSE connections.
func (h *Handler) Shutdown() {
gateway.events.Shutdown()
}
func (h *Handler) Shutdown() {}