mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat: add web gateway hot reload and polling state sync (#1684)
* feat(gateway): support hot reload and empty startup - extract gateway runtime into pkg/gateway - add gateway.hot_reload config with default and example values - allow starting the gateway without a default model via --allow-empty - stop treating missing enabled channels as a startup error - update related tests * feat: replace gateway SSE updates with polling-based state sync - remove gateway SSE broadcasting and event endpoint - add polling-based gateway status refresh with stopping state handling - detect when gateway restart is required after default model changes - resolve gateway health and websocket proxy targets from configured host - update gateway UI labels and add backend/frontend test coverage
This commit is contained in:
@@ -1,80 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// GatewayEvent represents a state change event for the gateway process.
|
||||
type GatewayEvent struct {
|
||||
Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
|
||||
PID int `json:"pid,omitempty"`
|
||||
BootDefaultModel string `json:"boot_default_model,omitempty"`
|
||||
ConfigDefaultModel string `json:"config_default_model,omitempty"`
|
||||
RestartRequired bool `json:"gateway_restart_required,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
|
||||
type EventBroadcaster struct {
|
||||
mu sync.RWMutex
|
||||
clients map[chan string]struct{}
|
||||
}
|
||||
|
||||
// NewEventBroadcaster creates a new broadcaster.
|
||||
func NewEventBroadcaster() *EventBroadcaster {
|
||||
return &EventBroadcaster{
|
||||
clients: make(map[chan string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new listener channel and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (b *EventBroadcaster) Subscribe() chan string {
|
||||
ch := make(chan string, 8)
|
||||
b.mu.Lock()
|
||||
b.clients[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a listener channel and closes it.
|
||||
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Check if the channel is still registered before closing
|
||||
if _, exists := b.clients[ch]; exists {
|
||||
delete(b.clients, ch)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast sends a GatewayEvent to all connected SSE clients.
|
||||
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
for ch := range b.clients {
|
||||
// Non-blocking send; drop event if client is slow
|
||||
select {
|
||||
case ch <- string(data):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown closes all subscriber channels, notifying all SSE clients to disconnect.
|
||||
// This should be called when the server is shutting down.
|
||||
func (b *EventBroadcaster) Shutdown() {
|
||||
// Close all channels to notify listeners
|
||||
for ch := range b.clients {
|
||||
b.Unsubscribe(ch)
|
||||
}
|
||||
// Clear the map
|
||||
b.clients = make(map[chan string]struct{})
|
||||
}
|
||||
+60
-149
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -30,11 +31,9 @@ var gateway = struct {
|
||||
runtimeStatus string
|
||||
startupDeadline time.Time
|
||||
logs *LogBuffer
|
||||
events *EventBroadcaster
|
||||
}{
|
||||
runtimeStatus: "stopped",
|
||||
logs: NewLogBuffer(200),
|
||||
events: NewEventBroadcaster(),
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -51,11 +50,19 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
if port == 0 {
|
||||
port = 18790
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
port := 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/health", port)
|
||||
|
||||
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
|
||||
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
|
||||
func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
resp, err := gatewayHealthGet(url, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -73,7 +80,6 @@ func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse,
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
|
||||
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
|
||||
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
|
||||
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
|
||||
@@ -87,7 +93,7 @@ func (h *Handler) TryAutoStartGateway() {
|
||||
// Check if gateway is already running via health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
@@ -170,6 +176,16 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
|
||||
return modelCfg
|
||||
}
|
||||
|
||||
func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
|
||||
if gatewayStatus != "running" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
|
||||
return false
|
||||
}
|
||||
return configDefaultModel != bootDefaultModel
|
||||
}
|
||||
|
||||
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return false
|
||||
@@ -220,7 +236,7 @@ func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func gatewayStatusOnHealthFailureLocked() string {
|
||||
func gatewayStatusWithoutHealthLocked() string {
|
||||
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return gateway.runtimeStatus
|
||||
@@ -233,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string {
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
|
||||
func currentGatewayStatusLocked(processAlive bool) string {
|
||||
if !processAlive {
|
||||
if gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return "restarting"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "stopped"
|
||||
}
|
||||
return gatewayStatusOnHealthFailureLocked()
|
||||
return "stopped"
|
||||
}
|
||||
|
||||
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
|
||||
@@ -319,15 +319,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Broadcast the attached state
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: initialStatus,
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
@@ -335,7 +326,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
// Locate the picoclaw executable
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
|
||||
cmd = exec.Command(execPath, "gateway")
|
||||
cmd = exec.Command(execPath, "gateway", "-E")
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
@@ -376,15 +367,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
pid = cmd.Process.Pid
|
||||
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
|
||||
|
||||
// Broadcast the launch state immediately so clients can reflect it without polling.
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: initialStatus,
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
|
||||
// Capture stdout/stderr in background
|
||||
go scanPipe(stdoutPipe, gateway.logs)
|
||||
go scanPipe(stderrPipe, gateway.logs)
|
||||
@@ -398,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
shouldBroadcastStopped := false
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
if gateway.runtimeStatus != "restarting" {
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
shouldBroadcastStopped = true
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if shouldBroadcastStopped {
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "stopped",
|
||||
RestartRequired: false,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// Start a goroutine to probe health and broadcast "running" once ready
|
||||
// Start a goroutine to probe health and update the runtime state once ready.
|
||||
go func() {
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -431,7 +404,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 1*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
|
||||
// Verify the health endpoint returns the expected pid
|
||||
gateway.mu.Lock()
|
||||
@@ -439,13 +412,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "running",
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -461,7 +427,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent duplicate starts by checking health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
@@ -597,10 +563,6 @@ func (h *Handler) RestartGateway() (int, error) {
|
||||
gateway.mu.Lock()
|
||||
previousCmd := gateway.cmd
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "restarting",
|
||||
RestartRequired: false,
|
||||
})
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
|
||||
@@ -704,24 +666,20 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data := map[string]any{}
|
||||
configDefaultModel := ""
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
configDefaultModel := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
if configDefaultModel != "" {
|
||||
data["config_default_model"] = configDefaultModel
|
||||
}
|
||||
}
|
||||
|
||||
// Probe health endpoint to get pid and status
|
||||
port := 0
|
||||
if cfgErr == nil && cfg != nil {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
|
||||
healthResp, statusCode, err := getGatewayHealth(port, 2*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = currentGatewayStatusLocked(true)
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
gateway.mu.Unlock()
|
||||
log.Printf("Gateway health check failed: %v", err)
|
||||
} else {
|
||||
@@ -734,45 +692,43 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data["status_code"] = statusCode
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
// Check if this pid matches our tracked process
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil && gateway.cmd.Process.Pid == healthResp.Pid {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
} else {
|
||||
// Health endpoint responded with a different pid
|
||||
// This could be a manual restart, try to attach to the new process
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
|
||||
oldPid := "none"
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
|
||||
}
|
||||
log.Printf("Detected new gateway PID (old: %s, new: %d), attempting to attach", oldPid, healthResp.Pid)
|
||||
|
||||
log.Printf(
|
||||
"Detected gateway PID from health (old: %s, new: %d), attempting to attach",
|
||||
oldPid,
|
||||
healthResp.Pid,
|
||||
)
|
||||
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
|
||||
// Failed to find the process, treat as error
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
data["gateway_status"] = "error"
|
||||
data["pid"] = healthResp.Pid
|
||||
log.Printf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err)
|
||||
} else {
|
||||
// Successfully attached, update response data
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
log.Printf(
|
||||
"Failed to attach to gateway process reported by health (PID: %d): %v",
|
||||
healthResp.Pid,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
data["gateway_restart_required"] = false
|
||||
bootDefaultModel, _ := data["boot_default_model"].(string)
|
||||
gatewayStatus, _ := data["gateway_status"].(string)
|
||||
data["gateway_restart_required"] = gatewayRestartRequired(
|
||||
configDefaultModel,
|
||||
bootDefaultModel,
|
||||
gatewayStatus,
|
||||
)
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
if readyErr != nil {
|
||||
@@ -842,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any {
|
||||
return data
|
||||
}
|
||||
|
||||
// handleGatewayEvents serves an SSE stream of gateway state change events.
|
||||
//
|
||||
// GET /api/gateway/events
|
||||
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "SSE not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Subscribe to gateway events
|
||||
ch := gateway.events.Subscribe()
|
||||
defer gateway.events.Unsubscribe(ch)
|
||||
|
||||
// Send initial status so the client doesn't start blank
|
||||
initial := h.currentGatewayStatus()
|
||||
fmt.Fprintf(w, "data: %s\n\n", initial)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// currentGatewayStatus returns the current gateway status as a JSON string.
|
||||
func (h *Handler) currentGatewayStatus() string {
|
||||
data := h.gatewayStatusData()
|
||||
encoded, _ := json.Marshal(data)
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
|
||||
func scanPipe(r io.Reader, buf *LogBuffer) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
|
||||
return bindHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayProxyURL() *url.URL {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
port := 18790
|
||||
bindHost := ""
|
||||
if err == nil && cfg != nil {
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
bindHost = h.effectiveGatewayBindHost(cfg)
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
|
||||
}
|
||||
}
|
||||
|
||||
func requestHostName(r *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err == nil {
|
||||
|
||||
@@ -2,9 +2,12 @@ package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
@@ -59,6 +62,79 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
|
||||
t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://192.168.1.10:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://127.0.0.1:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
@@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["pid"]; got != float64(cmd.Process.Pid) {
|
||||
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != false {
|
||||
t.Fatalf("gateway_restart_required = %#v, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
|
||||
ModelName: "second-model",
|
||||
Model: "openai/gpt-4.1",
|
||||
APIKey: "second-key",
|
||||
})
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
updatedCfg.Agents.Defaults.ModelName = "second-model"
|
||||
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
|
||||
t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
|
||||
}
|
||||
if got := body["config_default_model"]; got != "second-model" {
|
||||
t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
|
||||
+7
-17
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -22,20 +21,13 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
// WebSocket proxy: forward /pico/ws to gateway
|
||||
// This allows the frontend to connect via the same port as the web UI,
|
||||
// avoiding the need to expose extra ports for WebSocket communication.
|
||||
wsProxy := h.createWsProxy()
|
||||
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy(wsProxy))
|
||||
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
|
||||
}
|
||||
|
||||
// createWsProxy creates a reverse proxy to the gateway WebSocket endpoint.
|
||||
// The gateway port is read from the configuration.
|
||||
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
|
||||
// The gateway bind host and port are resolved from the latest configuration.
|
||||
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
gatewayPort := 18790 // default
|
||||
if err == nil && cfg.Gateway.Port != 0 {
|
||||
gatewayPort = cfg.Gateway.Port
|
||||
}
|
||||
gatewayURL, _ := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", gatewayPort))
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(gatewayURL)
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
|
||||
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
@@ -43,12 +35,10 @@ func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// It ensures the Connection and Upgrade headers are properly forwarded.
|
||||
func (h *Handler) handleWebSocketProxy(proxy *httputil.ReverseProxy) http.HandlerFunc {
|
||||
// The reverse proxy forwards the incoming upgrade handshake as-is.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set headers for WebSocket upgrade
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
proxy := h.createWsProxy()
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -235,3 +238,77 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
t.Error("response should have changed=true on first setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server1")
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server2")
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec1.Body.String(); body != "server1" {
|
||||
t.Fatalf("first body = %q, want %q", body, "server1")
|
||||
}
|
||||
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec2.Body.String(); body != "server2" {
|
||||
t.Fatalf("second body = %q, want %q", body, "server2")
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(parsed.Port())
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
@@ -71,7 +71,4 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
h.registerLauncherConfigRoutes(mux)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the handler, closing all SSE connections.
|
||||
func (h *Handler) Shutdown() {
|
||||
gateway.events.Shutdown()
|
||||
}
|
||||
func (h *Handler) Shutdown() {}
|
||||
|
||||
@@ -4,16 +4,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JSONContentType sets the Content-Type header to application/json for
|
||||
// API requests handled by the wrapped handler.
|
||||
// SSE endpoints (text/event-stream) are excluded.
|
||||
func JSONContentType(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
|
||||
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -32,7 +30,6 @@ func (rr *responseRecorder) WriteHeader(code int) {
|
||||
}
|
||||
|
||||
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
|
||||
// This is required for SSE (Server-Sent Events) to work through the middleware.
|
||||
func (rr *responseRecorder) Flush() {
|
||||
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
|
||||
@@ -94,7 +94,7 @@ func onReady() {
|
||||
func onExit() {
|
||||
fmt.Println(T(Exiting))
|
||||
|
||||
// First, shutdown API handler to close all SSE connections
|
||||
// First, shutdown API handler
|
||||
if apiHandler != nil {
|
||||
apiHandler.Shutdown()
|
||||
}
|
||||
|
||||
@@ -56,14 +56,20 @@ export function AppHeader() {
|
||||
const isRunning = gwState === "running"
|
||||
const isStarting = gwState === "starting"
|
||||
const isRestarting = gwState === "restarting"
|
||||
const isStopping = gwState === "stopping"
|
||||
const isStopped = gwState === "stopped" || gwState === "unknown"
|
||||
const showNotConnectedHint =
|
||||
!isRestarting && canStart && (gwState === "stopped" || gwState === "error")
|
||||
!isRestarting &&
|
||||
!isStopping &&
|
||||
canStart &&
|
||||
(gwState === "stopped" || gwState === "error")
|
||||
|
||||
const [showStopDialog, setShowStopDialog] = React.useState(false)
|
||||
|
||||
const handleGatewayToggle = () => {
|
||||
if (gwLoading || isRestarting || (!isRunning && !canStart)) return
|
||||
if (gwLoading || isRestarting || isStopping || (!isRunning && !canStart)) {
|
||||
return
|
||||
}
|
||||
if (isRunning) {
|
||||
setShowStopDialog(true)
|
||||
} else {
|
||||
@@ -137,7 +143,7 @@ export function AppHeader() {
|
||||
size="icon-sm"
|
||||
className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25"
|
||||
onClick={handleGatewayRestart}
|
||||
disabled={gwLoading || isRestarting || !canStart}
|
||||
disabled={gwLoading || isRestarting || isStopping || !canStart}
|
||||
aria-label={t("header.gateway.action.restart")}
|
||||
>
|
||||
<IconRefresh className="size-4" />
|
||||
@@ -168,25 +174,31 @@ export function AppHeader() {
|
||||
</Tooltip>
|
||||
) : (
|
||||
<Button
|
||||
variant={isStarting || isRestarting ? "secondary" : "default"}
|
||||
variant={
|
||||
isStarting || isRestarting || isStopping ? "secondary" : "default"
|
||||
}
|
||||
size="sm"
|
||||
className={`h-8 gap-2 px-3 ${
|
||||
isStopped ? "bg-green-500 text-white hover:bg-green-600" : ""
|
||||
}`}
|
||||
onClick={handleGatewayToggle}
|
||||
disabled={gwLoading || isStarting || isRestarting || !canStart}
|
||||
disabled={
|
||||
gwLoading || isStarting || isRestarting || isStopping || !canStart
|
||||
}
|
||||
>
|
||||
{gwLoading || isStarting || isRestarting ? (
|
||||
{gwLoading || isStarting || isRestarting || isStopping ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin opacity-70" />
|
||||
) : (
|
||||
<IconPlayerPlay className="h-4 w-4 opacity-80" />
|
||||
)}
|
||||
<span className="text-xs font-semibold">
|
||||
{isRestarting
|
||||
? t("header.gateway.status.restarting")
|
||||
: isStarting
|
||||
? t("header.gateway.status.starting")
|
||||
: t("header.gateway.action.start")}
|
||||
{isStopping
|
||||
? t("header.gateway.status.stopping")
|
||||
: isRestarting
|
||||
? t("header.gateway.status.restarting")
|
||||
: isStarting
|
||||
? t("header.gateway.status.starting")
|
||||
: t("header.gateway.action.start")}
|
||||
</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@@ -37,7 +37,9 @@ export function useGatewayLogs() {
|
||||
const fetchLogs = async () => {
|
||||
if (
|
||||
!mounted ||
|
||||
!["running", "starting", "restarting"].includes(gateway.status)
|
||||
!["running", "starting", "restarting", "stopping"].includes(
|
||||
gateway.status,
|
||||
)
|
||||
) {
|
||||
if (mounted) {
|
||||
timeout = setTimeout(fetchLogs, 1000)
|
||||
|
||||
@@ -1,83 +1,24 @@
|
||||
import { useAtomValue } from "jotai"
|
||||
import { useCallback, useEffect, useState } from "react"
|
||||
|
||||
import { restartGateway, startGateway, stopGateway } from "@/api/gateway"
|
||||
import {
|
||||
type GatewayStatusResponse,
|
||||
getGatewayStatus,
|
||||
restartGateway,
|
||||
startGateway,
|
||||
stopGateway,
|
||||
} from "@/api/gateway"
|
||||
import {
|
||||
applyGatewayStatusToStore,
|
||||
beginGatewayStoppingTransition,
|
||||
cancelGatewayStoppingTransition,
|
||||
gatewayAtom,
|
||||
refreshGatewayState,
|
||||
subscribeGatewayPolling,
|
||||
updateGatewayStore,
|
||||
} from "@/store"
|
||||
|
||||
// Global variable to ensure we only have one SSE connection
|
||||
let sseInitialized = false
|
||||
|
||||
export function useGateway() {
|
||||
const gateway = useAtomValue(gatewayAtom)
|
||||
const { status: state, canStart, restartRequired } = gateway
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const applyGatewayStatus = useCallback((data: GatewayStatusResponse) => {
|
||||
applyGatewayStatusToStore(data)
|
||||
}, [])
|
||||
|
||||
// Initialize global SSE connection once
|
||||
useEffect(() => {
|
||||
if (sseInitialized) return
|
||||
sseInitialized = true
|
||||
|
||||
getGatewayStatus()
|
||||
.then((data) => applyGatewayStatus(data))
|
||||
.catch(() => {
|
||||
updateGatewayStore({
|
||||
status: "unknown",
|
||||
canStart: true,
|
||||
restartRequired: false,
|
||||
})
|
||||
})
|
||||
|
||||
const statusPoll = window.setInterval(() => {
|
||||
getGatewayStatus()
|
||||
.then((data) => applyGatewayStatus(data))
|
||||
.catch(() => {
|
||||
// ignore polling errors
|
||||
})
|
||||
}, 5000)
|
||||
|
||||
// Subscribe to SSE for real-time updates globally
|
||||
const es = new EventSource("/api/gateway/events")
|
||||
|
||||
es.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data)
|
||||
if (
|
||||
data.gateway_status ||
|
||||
typeof data.gateway_start_allowed === "boolean"
|
||||
) {
|
||||
applyGatewayStatus(data)
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
// EventSource will auto-reconnect. Preserve the last known gateway
|
||||
// status so transient SSE disconnects do not suppress chat websocket
|
||||
// reconnects while polling catches up.
|
||||
}
|
||||
|
||||
return () => {
|
||||
window.clearInterval(statusPoll)
|
||||
es.close()
|
||||
sseInitialized = false
|
||||
}
|
||||
}, [applyGatewayStatus])
|
||||
return subscribeGatewayPolling()
|
||||
}, [])
|
||||
|
||||
const start = useCallback(async () => {
|
||||
if (!canStart) return
|
||||
@@ -85,33 +26,28 @@ export function useGateway() {
|
||||
setLoading(true)
|
||||
try {
|
||||
await startGateway()
|
||||
// SSE will push the real state changes, but set optimistic state
|
||||
updateGatewayStore({ status: "starting" })
|
||||
} catch (err) {
|
||||
console.error("Failed to start gateway:", err)
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatus(status)
|
||||
} catch {
|
||||
updateGatewayStore({ status: "unknown" })
|
||||
}
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [applyGatewayStatus, canStart])
|
||||
|
||||
const stop = useCallback(async () => {
|
||||
setLoading(true)
|
||||
try {
|
||||
await stopGateway()
|
||||
updateGatewayStore({
|
||||
status: "stopped",
|
||||
canStart: true,
|
||||
status: "starting",
|
||||
restartRequired: false,
|
||||
})
|
||||
} catch (err) {
|
||||
console.error("Failed to stop gateway:", err)
|
||||
console.error("Failed to start gateway:", err)
|
||||
} finally {
|
||||
await refreshGatewayState({ force: true })
|
||||
setLoading(false)
|
||||
}
|
||||
}, [canStart])
|
||||
|
||||
const stop = useCallback(async () => {
|
||||
setLoading(true)
|
||||
beginGatewayStoppingTransition()
|
||||
try {
|
||||
await stopGateway()
|
||||
} catch (err) {
|
||||
console.error("Failed to stop gateway:", err)
|
||||
cancelGatewayStoppingTransition()
|
||||
} finally {
|
||||
await refreshGatewayState({ force: true })
|
||||
setLoading(false)
|
||||
}
|
||||
}, [])
|
||||
@@ -119,34 +55,20 @@ export function useGateway() {
|
||||
const restart = useCallback(async () => {
|
||||
if (state !== "running") return
|
||||
|
||||
const previousState = state
|
||||
const previousCanStart = canStart
|
||||
const previousRestartRequired = restartRequired
|
||||
|
||||
setLoading(true)
|
||||
updateGatewayStore({
|
||||
status: "restarting",
|
||||
restartRequired: false,
|
||||
})
|
||||
|
||||
try {
|
||||
await restartGateway()
|
||||
updateGatewayStore({
|
||||
status: "restarting",
|
||||
restartRequired: false,
|
||||
})
|
||||
} catch (err) {
|
||||
console.error("Failed to restart gateway:", err)
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatus(status)
|
||||
} catch {
|
||||
updateGatewayStore({
|
||||
status: previousState,
|
||||
canStart: previousCanStart,
|
||||
restartRequired: previousRestartRequired,
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
await refreshGatewayState({ force: true })
|
||||
setLoading(false)
|
||||
}
|
||||
}, [applyGatewayStatus, canStart, restartRequired, state])
|
||||
}, [state])
|
||||
|
||||
return { state, loading, canStart, restartRequired, start, stop, restart }
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@
|
||||
},
|
||||
"status": {
|
||||
"starting": "Starting Gateway...",
|
||||
"restarting": "Restarting Gateway..."
|
||||
"restarting": "Restarting Gateway...",
|
||||
"stopping": "Stopping Gateway..."
|
||||
},
|
||||
"restartRequired": "Model changes require a gateway restart to take effect."
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@
|
||||
},
|
||||
"status": {
|
||||
"starting": "服务启动中...",
|
||||
"restarting": "服务重启中..."
|
||||
"restarting": "服务重启中...",
|
||||
"stopping": "服务停止中..."
|
||||
},
|
||||
"restartRequired": "切换默认模型后需要重启服务才能生效。"
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ export type GatewayState =
|
||||
| "running"
|
||||
| "starting"
|
||||
| "restarting"
|
||||
| "stopping"
|
||||
| "stopped"
|
||||
| "error"
|
||||
| "unknown"
|
||||
@@ -24,9 +25,29 @@ const DEFAULT_GATEWAY_STATE: GatewayStoreState = {
|
||||
restartRequired: false,
|
||||
}
|
||||
|
||||
const GATEWAY_POLL_INTERVAL_MS = 2000
|
||||
const GATEWAY_TRANSIENT_POLL_INTERVAL_MS = 1000
|
||||
const GATEWAY_STOPPING_TIMEOUT_MS = 5000
|
||||
|
||||
interface RefreshGatewayStateOptions {
|
||||
force?: boolean
|
||||
}
|
||||
|
||||
// Global atom for gateway state
|
||||
export const gatewayAtom = atom<GatewayStoreState>(DEFAULT_GATEWAY_STATE)
|
||||
|
||||
let gatewayPollingSubscribers = 0
|
||||
let gatewayPollingTimer: ReturnType<typeof setTimeout> | null = null
|
||||
let gatewayPollingRequest: Promise<void> | null = null
|
||||
let gatewayStoppingTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
function clearGatewayStoppingTimeout() {
|
||||
if (gatewayStoppingTimer !== null) {
|
||||
clearTimeout(gatewayStoppingTimer)
|
||||
gatewayStoppingTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeGatewayStoreState(
|
||||
prev: GatewayStoreState,
|
||||
patch: GatewayStorePatch,
|
||||
@@ -49,10 +70,38 @@ export function updateGatewayStore(
|
||||
| GatewayStorePatch
|
||||
| ((prev: GatewayStoreState) => GatewayStorePatch | GatewayStoreState),
|
||||
) {
|
||||
getDefaultStore().set(gatewayAtom, (prev) => {
|
||||
const store = getDefaultStore()
|
||||
store.set(gatewayAtom, (prev) => {
|
||||
const nextPatch = typeof patch === "function" ? patch(prev) : patch
|
||||
return normalizeGatewayStoreState(prev, nextPatch)
|
||||
})
|
||||
const nextState = store.get(gatewayAtom)
|
||||
if (nextState?.status !== "stopping") {
|
||||
clearGatewayStoppingTimeout()
|
||||
}
|
||||
}
|
||||
|
||||
export function beginGatewayStoppingTransition() {
|
||||
clearGatewayStoppingTimeout()
|
||||
updateGatewayStore({
|
||||
status: "stopping",
|
||||
canStart: false,
|
||||
restartRequired: false,
|
||||
})
|
||||
gatewayStoppingTimer = setTimeout(() => {
|
||||
gatewayStoppingTimer = null
|
||||
updateGatewayStore((prev) =>
|
||||
prev.status === "stopping" ? { status: "running" } : prev,
|
||||
)
|
||||
void refreshGatewayState({ force: true })
|
||||
}, GATEWAY_STOPPING_TIMEOUT_MS)
|
||||
}
|
||||
|
||||
export function cancelGatewayStoppingTransition() {
|
||||
clearGatewayStoppingTimeout()
|
||||
updateGatewayStore((prev) =>
|
||||
prev.status === "stopping" ? { status: "running" } : prev,
|
||||
)
|
||||
}
|
||||
|
||||
export function applyGatewayStatusToStore(
|
||||
@@ -64,21 +113,92 @@ export function applyGatewayStatusToStore(
|
||||
>,
|
||||
) {
|
||||
updateGatewayStore((prev) => ({
|
||||
status: data.gateway_status ?? prev.status,
|
||||
canStart: data.gateway_start_allowed ?? prev.canStart,
|
||||
restartRequired:
|
||||
data.gateway_restart_required ??
|
||||
(data.gateway_status && data.gateway_status !== "running"
|
||||
status:
|
||||
prev.status === "stopping" && data.gateway_status === "running"
|
||||
? "stopping"
|
||||
: (data.gateway_status ?? prev.status),
|
||||
canStart:
|
||||
prev.status === "stopping" && data.gateway_status === "running"
|
||||
? false
|
||||
: prev.restartRequired),
|
||||
: (data.gateway_start_allowed ?? prev.canStart),
|
||||
restartRequired:
|
||||
prev.status === "stopping" && data.gateway_status === "running"
|
||||
? false
|
||||
: (data.gateway_restart_required ?? prev.restartRequired),
|
||||
}))
|
||||
}
|
||||
|
||||
export async function refreshGatewayState() {
|
||||
function nextGatewayPollInterval() {
|
||||
const status = getDefaultStore().get(gatewayAtom).status
|
||||
if (
|
||||
status === "starting" ||
|
||||
status === "restarting" ||
|
||||
status === "stopping"
|
||||
) {
|
||||
return GATEWAY_TRANSIENT_POLL_INTERVAL_MS
|
||||
}
|
||||
return GATEWAY_POLL_INTERVAL_MS
|
||||
}
|
||||
|
||||
function scheduleGatewayPoll(delay = nextGatewayPollInterval()) {
|
||||
if (gatewayPollingSubscribers === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
if (gatewayPollingTimer !== null) {
|
||||
clearTimeout(gatewayPollingTimer)
|
||||
}
|
||||
|
||||
gatewayPollingTimer = setTimeout(() => {
|
||||
gatewayPollingTimer = null
|
||||
void refreshGatewayState()
|
||||
}, delay)
|
||||
}
|
||||
|
||||
export async function refreshGatewayState(
|
||||
options: RefreshGatewayStateOptions = {},
|
||||
) {
|
||||
if (gatewayPollingRequest) {
|
||||
await gatewayPollingRequest
|
||||
if (options.force) {
|
||||
return refreshGatewayState()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
gatewayPollingRequest = (async () => {
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatusToStore(status)
|
||||
} catch {
|
||||
// Preserve the last known state when a poll fails.
|
||||
} finally {
|
||||
gatewayPollingRequest = null
|
||||
scheduleGatewayPoll()
|
||||
}
|
||||
})()
|
||||
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatusToStore(status)
|
||||
} catch {
|
||||
updateGatewayStore(DEFAULT_GATEWAY_STATE)
|
||||
await gatewayPollingRequest
|
||||
} finally {
|
||||
if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) {
|
||||
clearTimeout(gatewayPollingTimer)
|
||||
gatewayPollingTimer = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function subscribeGatewayPolling() {
|
||||
gatewayPollingSubscribers += 1
|
||||
if (gatewayPollingSubscribers === 1) {
|
||||
void refreshGatewayState()
|
||||
}
|
||||
|
||||
return () => {
|
||||
gatewayPollingSubscribers = Math.max(0, gatewayPollingSubscribers - 1)
|
||||
if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) {
|
||||
clearTimeout(gatewayPollingTimer)
|
||||
gatewayPollingTimer = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user