diff --git a/web/backend/api/events.go b/web/backend/api/events.go index 0a8d4a9bb..af44d1824 100644 --- a/web/backend/api/events.go +++ b/web/backend/api/events.go @@ -7,8 +7,11 @@ import ( // GatewayEvent represents a state change event for the gateway process. type GatewayEvent struct { - Status string `json:"gateway_status"` // "running", "starting", "stopped", "error" - PID int `json:"pid,omitempty"` + Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error" + PID int `json:"pid,omitempty"` + BootDefaultModel string `json:"boot_default_model,omitempty"` + ConfigDefaultModel string `json:"config_default_model,omitempty"` + RestartRequired bool `json:"gateway_restart_required,omitempty"` } // EventBroadcaster manages SSE client subscriptions and broadcasts events. diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 41f702e32..95b482ce0 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -23,13 +23,29 @@ import ( // gateway holds the state for the managed gateway process. var gateway = struct { - mu sync.Mutex - cmd *exec.Cmd - logs *LogBuffer - events *EventBroadcaster + mu sync.Mutex + cmd *exec.Cmd + bootDefaultModel string + runtimeStatus string + startupDeadline time.Time + logs *LogBuffer + events *EventBroadcaster }{ - logs: NewLogBuffer(200), - events: NewEventBroadcaster(), + runtimeStatus: "stopped", + logs: NewLogBuffer(200), + events: NewEventBroadcaster(), +} + +var ( + gatewayStartupWindow = 15 * time.Second + gatewayRestartGracePeriod = 5 * time.Second + gatewayRestartForceKillWindow = 3 * time.Second + gatewayRestartPollInterval = 100 * time.Millisecond +) + +var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { + client := http.Client{Timeout: timeout} + return client.Get(url) } // registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux. @@ -65,7 +81,7 @@ func (h *Handler) TryAutoStartGateway() { return } - pid, err := h.startGatewayLocked() + pid, err := h.startGatewayLocked("starting") if err != nil { log.Printf("Failed to auto-start gateway: %v", err) return @@ -131,7 +147,110 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool { return cmd.Process.Signal(syscall.Signal(0)) == nil } -func (h *Handler) startGatewayLocked() (int, error) { +func setGatewayRuntimeStatusLocked(status string) { + gateway.runtimeStatus = status + if status == "starting" || status == "restarting" { + gateway.startupDeadline = time.Now().Add(gatewayStartupWindow) + return + } + gateway.startupDeadline = time.Time{} +} + +func gatewayStatusOnHealthFailureLocked() string { + if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" { + if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) { + return gateway.runtimeStatus + } + return "error" + } + if gateway.runtimeStatus == "running" { + return "running" + } + if gateway.runtimeStatus == "error" { + return "error" + } + return "error" +} + +func currentGatewayStatusLocked(processAlive bool) string { + if !processAlive { + if gateway.runtimeStatus == "restarting" { + if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) { + return "restarting" + } + return "error" + } + if gateway.runtimeStatus == "error" { + return "error" + } + return "stopped" + } + return gatewayStatusOnHealthFailureLocked() +} + +func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool { + if cmd == nil || cmd.Process == nil { + return true + } + + deadline := time.Now().Add(timeout) + for { + if !isCmdProcessAliveLocked(cmd) { + return true + } + if time.Now().After(deadline) { + return false + } + time.Sleep(gatewayRestartPollInterval) + } +} + +func stopGatewayProcessForRestart(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil || !isCmdProcessAliveLocked(cmd) { + return nil + } + + var stopErr error + if runtime.GOOS == "windows" { + stopErr = cmd.Process.Kill() + } else { + stopErr = cmd.Process.Signal(syscall.SIGTERM) + } + if stopErr != nil && isCmdProcessAliveLocked(cmd) { + return fmt.Errorf("failed to stop existing gateway: %w", stopErr) + } + + if waitForGatewayProcessExit(cmd, gatewayRestartGracePeriod) { + return nil + } + + if runtime.GOOS != "windows" { + killErr := cmd.Process.Signal(syscall.SIGKILL) + if killErr != nil && isCmdProcessAliveLocked(cmd) { + return fmt.Errorf("failed to force-stop existing gateway: %w", killErr) + } + if waitForGatewayProcessExit(cmd, gatewayRestartForceKillWindow) { + return nil + } + } + + return fmt.Errorf("existing gateway did not exit before restart") +} + +func gatewayRestartRequired(status, bootDefaultModel, configDefaultModel string) bool { + return status == "running" && + bootDefaultModel != "" && + configDefaultModel != "" && + bootDefaultModel != configDefaultModel +} + +func (h *Handler) startGatewayLocked(initialStatus string) (int, error) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + return 0, fmt.Errorf("failed to load config: %w", err) + } + defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName()) + // Locate the picoclaw executable execPath := utils.FindPicoclawBinary() @@ -171,11 +290,19 @@ func (h *Handler) startGatewayLocked() (int, error) { } gateway.cmd = cmd + gateway.bootDefaultModel = defaultModelName + setGatewayRuntimeStatusLocked(initialStatus) pid := cmd.Process.Pid log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath) - // Broadcast starting event - gateway.events.Broadcast(GatewayEvent{Status: "starting", PID: pid}) + // Broadcast the launch state immediately so clients can reflect it without polling. + gateway.events.Broadcast(GatewayEvent{ + Status: initialStatus, + PID: pid, + BootDefaultModel: defaultModelName, + ConfigDefaultModel: defaultModelName, + RestartRequired: false, + }) // Capture stdout/stderr in background go scanPipe(stdoutPipe, gateway.logs) @@ -190,13 +317,23 @@ func (h *Handler) startGatewayLocked() (int, error) { } gateway.mu.Lock() + shouldBroadcastStopped := false if gateway.cmd == cmd { gateway.cmd = nil + gateway.bootDefaultModel = "" + if gateway.runtimeStatus != "restarting" { + setGatewayRuntimeStatusLocked("stopped") + shouldBroadcastStopped = true + } } gateway.mu.Unlock() - // Broadcast stopped event - gateway.events.Broadcast(GatewayEvent{Status: "stopped"}) + if shouldBroadcastStopped { + gateway.events.Broadcast(GatewayEvent{ + Status: "stopped", + RestartRequired: false, + }) + } }() // Start a goroutine to probe health and broadcast "running" once ready @@ -219,12 +356,22 @@ func (h *Handler) startGatewayLocked() (int, error) { healthPort = 18790 } healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort))) - client := http.Client{Timeout: 1 * time.Second} - resp, err := client.Get(healthURL) + resp, err := gatewayHealthGet(healthURL, 1*time.Second) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { - gateway.events.Broadcast(GatewayEvent{Status: "running", PID: pid}) + gateway.mu.Lock() + if gateway.cmd == cmd { + setGatewayRuntimeStatusLocked("running") + } + gateway.mu.Unlock() + gateway.events.Broadcast(GatewayEvent{ + Status: "running", + PID: pid, + BootDefaultModel: defaultModelName, + ConfigDefaultModel: defaultModelName, + RestartRequired: false, + }) return } } @@ -253,6 +400,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { } if gateway.cmd != nil && gateway.cmd.Process != nil { gateway.cmd = nil + setGatewayRuntimeStatusLocked("stopped") } ready, reason, err := h.gatewayStartReady() @@ -274,7 +422,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { return } - pid, err := h.startGatewayLocked() + pid, err := h.startGatewayLocked("starting") if err != nil { http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError) return @@ -330,30 +478,72 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) { // // POST /api/gateway/restart func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) { - gateway.mu.Lock() - - // Stop existing process if running - if gateway.cmd != nil && gateway.cmd.Process != nil { - if isCmdProcessAliveLocked(gateway.cmd) { - // Process is alive, send SIGTERM - if runtime.GOOS == "windows" { - gateway.cmd.Process.Kill() - } else { - gateway.cmd.Process.Signal(syscall.SIGTERM) - } - - // Wait briefly for it to exit - gateway.mu.Unlock() - time.Sleep(2 * time.Second) - gateway.mu.Lock() - } - gateway.cmd = nil + ready, reason, err := h.gatewayStartReady() + if err != nil { + http.Error( + w, + fmt.Sprintf("Failed to validate gateway start conditions: %v", err), + http.StatusInternalServerError, + ) + return + } + if !ready { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]any{ + "status": "precondition_failed", + "message": reason, + }) + return } + gateway.mu.Lock() + previousCmd := gateway.cmd + setGatewayRuntimeStatusLocked("restarting") + gateway.events.Broadcast(GatewayEvent{ + Status: "restarting", + RestartRequired: false, + }) gateway.mu.Unlock() - // Start fresh via the existing handler - h.handleGatewayStart(w, r) + if err = stopGatewayProcessForRestart(previousCmd); err != nil { + gateway.mu.Lock() + if gateway.cmd == previousCmd { + if isCmdProcessAliveLocked(previousCmd) { + setGatewayRuntimeStatusLocked("running") + } else { + gateway.cmd = nil + gateway.bootDefaultModel = "" + setGatewayRuntimeStatusLocked("error") + } + } + gateway.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError) + return + } + + gateway.mu.Lock() + if gateway.cmd == previousCmd { + gateway.cmd = nil + gateway.bootDefaultModel = "" + } + pid, err := h.startGatewayLocked("restarting") + if err != nil { + gateway.cmd = nil + gateway.bootDefaultModel = "" + setGatewayRuntimeStatusLocked("error") + } + gateway.mu.Unlock() + if err != nil { + http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "status": "ok", + "pid": pid, + }) } // handleGatewayClearLogs clears the in-memory gateway log buffer. @@ -374,24 +564,44 @@ func (h *Handler) handleGatewayClearLogs(w http.ResponseWriter, r *http.Request) // // GET /api/gateway/status func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { + data := h.gatewayStatusData(r, true) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(data) +} + +func (h *Handler) gatewayStatusData(r *http.Request, includeLogs bool) map[string]any { data := map[string]any{} + cfg, cfgErr := config.LoadConfig(h.configPath) + configDefaultModel := "" + if cfgErr == nil && cfg != nil { + configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName()) + if configDefaultModel != "" { + data["config_default_model"] = configDefaultModel + } + } // Check process state gateway.mu.Lock() processAlive := isGatewayProcessAliveLocked() + bootDefaultModel := "" if processAlive { data["pid"] = gateway.cmd.Process.Pid + if gateway.bootDefaultModel != "" { + data["boot_default_model"] = gateway.bootDefaultModel + bootDefaultModel = gateway.bootDefaultModel + } } gateway.mu.Unlock() if !processAlive { - data["gateway_status"] = "stopped" + gateway.mu.Lock() + data["gateway_status"] = currentGatewayStatusLocked(false) + gateway.mu.Unlock() } else { // Process is alive — probe its health endpoint - cfg, err := config.LoadConfig(h.configPath) host := "127.0.0.1" port := 18790 - if err == nil && cfg != nil { + if cfgErr == nil && cfg != nil { host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) if cfg.Gateway.Port != 0 { port = cfg.Gateway.Port @@ -399,21 +609,31 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { } url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port))) - client := http.Client{Timeout: 2 * time.Second} - resp, err := client.Get(url) + resp, err := gatewayHealthGet(url, 2*time.Second) if err != nil { - data["gateway_status"] = "starting" + gateway.mu.Lock() + data["gateway_status"] = currentGatewayStatusLocked(true) + gateway.mu.Unlock() } else { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("error") + gateway.mu.Unlock() data["gateway_status"] = "error" data["status_code"] = resp.StatusCode } else { var healthData map[string]any if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("error") + gateway.mu.Unlock() data["gateway_status"] = "error" } else { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() for k, v := range healthData { data[k] = v } @@ -423,6 +643,13 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { } } + status, _ := data["gateway_status"].(string) + data["gateway_restart_required"] = gatewayRestartRequired( + status, + bootDefaultModel, + configDefaultModel, + ) + ready, reason, readyErr := h.gatewayStartReady() if readyErr != nil { data["gateway_start_allowed"] = false @@ -434,11 +661,11 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { } } - // Append incremental log data - appendGatewayLogs(r, data) + if includeLogs { + appendGatewayLogs(r, data) + } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(data) + return data } // appendGatewayLogs reads log_offset and log_run_id query params from the request @@ -524,28 +751,7 @@ func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) { // currentGatewayStatus returns the current gateway status as a JSON string. func (h *Handler) currentGatewayStatus() string { - gateway.mu.Lock() - defer gateway.mu.Unlock() - - data := map[string]any{ - "gateway_status": "stopped", - } - if isGatewayProcessAliveLocked() { - data["gateway_status"] = "running" - data["pid"] = gateway.cmd.Process.Pid - } - - ready, reason, readyErr := h.gatewayStartReady() - if readyErr != nil { - data["gateway_start_allowed"] = false - data["gateway_start_reason"] = readyErr.Error() - } else { - data["gateway_start_allowed"] = ready - if !ready { - data["gateway_start_reason"] = reason - } - } - + data := h.gatewayStatusData(nil, false) encoded, _ := json.Marshal(data) return string(encoded) } diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index d4265776a..fe3fccdee 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -2,19 +2,76 @@ package api import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "os" + "os/exec" "path/filepath" + "runtime" "strconv" "strings" "testing" + "time" "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/web/backend/utils" ) +func startLongRunningProcess(t *testing.T) *exec.Cmd { + t.Helper() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("powershell", "-NoProfile", "-Command", "Start-Sleep -Seconds 30") + } else { + cmd = exec.Command("sleep", "30") + } + + if err := cmd.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + return cmd +} + +func startIgnoringTermProcess(t *testing.T) *exec.Cmd { + t.Helper() + + if runtime.GOOS == "windows" { + t.Skip("TERM handling differs on Windows") + } + + cmd := exec.Command("sh", "-c", "trap '' TERM; sleep 30") + if err := cmd.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + return cmd +} + +func resetGatewayTestState(t *testing.T) { + t.Helper() + + originalHealthGet := gatewayHealthGet + originalRestartGracePeriod := gatewayRestartGracePeriod + originalRestartForceKillWindow := gatewayRestartForceKillWindow + originalRestartPollInterval := gatewayRestartPollInterval + t.Cleanup(func() { + gatewayHealthGet = originalHealthGet + gatewayRestartGracePeriod = originalRestartGracePeriod + gatewayRestartForceKillWindow = originalRestartForceKillWindow + gatewayRestartPollInterval = originalRestartPollInterval + + gateway.mu.Lock() + gateway.cmd = nil + gateway.bootDefaultModel = "" + setGatewayRuntimeStatusLocked("stopped") + gateway.mu.Unlock() + }) +} + func TestGatewayStartReady_NoDefaultModel(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -317,6 +374,339 @@ func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) { } } +func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + // Simulate a process that has already reached the running state. + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return nil, errors.New("probe failed") + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "running" { + t.Fatalf("gateway_status = %#v, want %q", got, "running") + } +} + +func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + setGatewayRuntimeStatusLocked("starting") + gateway.startupDeadline = time.Now().Add(-time.Second) + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return nil, errors.New("probe failed") + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "error" { + t.Fatalf("gateway_status = %#v, want %q", got, "error") + } +} + +func TestGatewayStatusReturnsRestartingDuringRestartGap(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("restarting") + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "restarting" { + t.Fatalf("gateway_status = %#v, want %q", got, "restarting") + } +} + +func TestGatewayStatusIncludesRestartRequiredWhenModelsDiffer(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "previous-model" + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + _, _ = rec.WriteString(`{"ok":true}`) + return rec.Result(), nil + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_restart_required"]; got != true { + t.Fatalf("gateway_restart_required = %#v, want true", got) + } +} + +func TestGatewayRestartKeepsRunningProcessWhenPreconditionsFail(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "" + cfg.ModelList[0].AuthMethod = "" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + gateway.mu.Lock() + if gateway.cmd == cmd { + gateway.cmd = nil + gateway.bootDefaultModel = "" + } + gateway.mu.Unlock() + + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + + gateway.mu.Lock() + stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd) + gateway.mu.Unlock() + + if !stillRunning { + t.Fatalf("gateway process was stopped when restart preconditions failed") + } +} + +func TestGatewayRestartKeepsOldProcessWhenItDoesNotExitInTime(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startIgnoringTermProcess(t) + t.Cleanup(func() { + gateway.mu.Lock() + if gateway.cmd == cmd { + gateway.cmd = nil + gateway.bootDefaultModel = "" + } + gateway.mu.Unlock() + + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gatewayRestartGracePeriod = 150 * time.Millisecond + gatewayRestartForceKillWindow = 150 * time.Millisecond + gatewayRestartPollInterval = 10 * time.Millisecond + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + + gateway.mu.Lock() + stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd) + status := gateway.runtimeStatus + gateway.mu.Unlock() + + if !stillRunning { + t.Fatalf("gateway process was replaced before the old process exited") + } + if status != "running" { + t.Fatalf("runtimeStatus = %q, want %q", status, "running") + } +} + +func TestGatewayRestartReturnsErrorStatusWhenReplacementFailsToStart(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + invalidBinaryPath := filepath.Join(t.TempDir(), "fake-picoclaw") + if err := os.WriteFile(invalidBinaryPath, []byte("#!/bin/sh\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + t.Setenv("PICOCLAW_BINARY", invalidBinaryPath) + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("restart status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + + statusRec := httptest.NewRecorder() + statusReq := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(statusRec, statusReq) + + if statusRec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", statusRec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(statusRec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "error" { + t.Fatalf("gateway_status = %#v, want %q", got, "error") + } +} + func TestGatewayClearLogsResetsBufferedHistory(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) diff --git a/web/frontend/src/api/gateway.ts b/web/frontend/src/api/gateway.ts index 020e92e3a..1688a5278 100644 --- a/web/frontend/src/api/gateway.ts +++ b/web/frontend/src/api/gateway.ts @@ -1,10 +1,13 @@ // API client for gateway process management. interface GatewayStatusResponse { - gateway_status: "running" | "starting" | "stopped" | "error" + gateway_status: "running" | "starting" | "restarting" | "stopped" | "error" gateway_start_allowed?: boolean gateway_start_reason?: string + gateway_restart_required?: boolean pid?: number + boot_default_model?: string + config_default_model?: string logs?: string[] log_total?: number log_run_id?: number diff --git a/web/frontend/src/api/models.ts b/web/frontend/src/api/models.ts index 6a4544c65..8e49b48b4 100644 --- a/web/frontend/src/api/models.ts +++ b/web/frontend/src/api/models.ts @@ -84,7 +84,7 @@ export async function setDefaultModel( body: JSON.stringify({ model_name: modelName }), }) - void refreshGatewayState() + await refreshGatewayState() return response } diff --git a/web/frontend/src/components/app-header.tsx b/web/frontend/src/components/app-header.tsx index 7a50fe0fb..fe0c84e69 100644 --- a/web/frontend/src/components/app-header.tsx +++ b/web/frontend/src/components/app-header.tsx @@ -6,6 +6,7 @@ import { IconMoon, IconPlayerPlay, IconPower, + IconRefresh, IconSun, } from "@tabler/icons-react" import { Link } from "@tanstack/react-router" @@ -31,6 +32,11 @@ import { } from "@/components/ui/dropdown-menu.tsx" import { Separator } from "@/components/ui/separator.tsx" import { SidebarTrigger } from "@/components/ui/sidebar" +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip" import { useGateway } from "@/hooks/use-gateway.ts" import { useTheme } from "@/hooks/use-theme.ts" @@ -41,27 +47,35 @@ export function AppHeader() { state: gwState, loading: gwLoading, canStart, + restartRequired, start, + restart, stop, } = useGateway() const isRunning = gwState === "running" const isStarting = gwState === "starting" + const isRestarting = gwState === "restarting" const isStopped = gwState === "stopped" || gwState === "unknown" const showNotConnectedHint = - canStart && (gwState === "stopped" || gwState === "error") + !isRestarting && canStart && (gwState === "stopped" || gwState === "error") const [showStopDialog, setShowStopDialog] = React.useState(false) const handleGatewayToggle = () => { - if (gwLoading || (!isRunning && !canStart)) return + if (gwLoading || isRestarting || (!isRunning && !canStart)) return if (isRunning) { setShowStopDialog(true) } else { - start() + void start() } } + const handleGatewayRestart = () => { + if (gwLoading || isRestarting || !restartRequired || !canStart) return + void restart() + } + const confirmStop = () => { setShowStopDialog(false) stop() @@ -115,35 +129,67 @@ export function AppHeader() {
+ {restartRequired && ( + + + + + + {t("header.gateway.restartRequired")} + + + )} + {/* Gateway Start/Stop */} - + {isRunning ? ( + + + + + {t("header.gateway.action.stop")} + + ) : ( + + )} (null) const [isAtBottom, setIsAtBottom] = useState(true) + const [hasScrolled, setHasScrolled] = useState(false) const [input, setInput] = useState("") const { @@ -56,14 +57,22 @@ export function ChatPage() { onDeletedActiveSession: newChat, }) - const handleScroll = (e: React.UIEvent) => { - const { scrollTop, scrollHeight, clientHeight } = e.currentTarget + const syncScrollState = (element: HTMLDivElement) => { + const { scrollTop, scrollHeight, clientHeight } = element + setHasScrolled(scrollTop > 0) setIsAtBottom(scrollHeight - scrollTop <= clientHeight + 10) } + const handleScroll = (e: React.UIEvent) => { + syncScrollState(e.currentTarget) + } + useEffect(() => { - if (isAtBottom && scrollRef.current) { - scrollRef.current.scrollTop = scrollRef.current.scrollHeight + if (scrollRef.current) { + if (isAtBottom) { + scrollRef.current.scrollTop = scrollRef.current.scrollHeight + } + syncScrollState(scrollRef.current) } }, [messages, isTyping, isAtBottom]) @@ -77,6 +86,9 @@ export function ChatPage() {
diff --git a/web/frontend/src/components/models/edit-model-sheet.tsx b/web/frontend/src/components/models/edit-model-sheet.tsx index 4c77944a9..237991a9f 100644 --- a/web/frontend/src/components/models/edit-model-sheet.tsx +++ b/web/frontend/src/components/models/edit-model-sheet.tsx @@ -110,7 +110,7 @@ export function EditModelSheet({ : undefined, thinking_level: form.thinkingLevel || undefined, }) - if (setAsDefault) { + if (setAsDefault && !model.is_default) { await setDefaultModel(model.model_name) } onSaved() diff --git a/web/frontend/src/components/models/models-page.tsx b/web/frontend/src/components/models/models-page.tsx index b8e80e709..6776e5ca8 100644 --- a/web/frontend/src/components/models/models-page.tsx +++ b/web/frontend/src/components/models/models-page.tsx @@ -79,6 +79,8 @@ export function ModelsPage() { }, [fetchModels]) const handleSetDefault = async (model: ModelInfo) => { + if (model.is_default) return + setSettingDefaultIndex(model.index) try { await setDefaultModel(model.model_name) diff --git a/web/frontend/src/components/page-header.tsx b/web/frontend/src/components/page-header.tsx index 9d4aa6975..656551f39 100644 --- a/web/frontend/src/components/page-header.tsx +++ b/web/frontend/src/components/page-header.tsx @@ -2,16 +2,28 @@ import { IconMenu2 } from "@tabler/icons-react" import type { ReactNode } from "react" import { SidebarTrigger } from "@/components/ui/sidebar" +import { cn } from "@/lib/utils" interface PageHeaderProps { title: string titleExtra?: ReactNode children?: ReactNode + className?: string } -export function PageHeader({ title, titleExtra, children }: PageHeaderProps) { +export function PageHeader({ + title, + titleExtra, + children, + className, +}: PageHeaderProps) { return ( -
+
diff --git a/web/frontend/src/hooks/use-chat-models.ts b/web/frontend/src/hooks/use-chat-models.ts index 8a82ceaf3..9afa882db 100644 --- a/web/frontend/src/hooks/use-chat-models.ts +++ b/web/frontend/src/hooks/use-chat-models.ts @@ -1,4 +1,4 @@ -import { useCallback, useEffect, useMemo, useState } from "react" +import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { type ModelInfo, getModels, setDefaultModel } from "@/api/models" @@ -20,6 +20,7 @@ function isLocalModel(model: ModelInfo): boolean { export function useChatModels({ isConnected }: UseChatModelsOptions) { const [modelList, setModelList] = useState([]) const [defaultModelName, setDefaultModelName] = useState("") + const setDefaultRequestIdRef = useRef(0) const loadModels = useCallback(async () => { try { @@ -41,17 +42,28 @@ export function useChatModels({ isConnected }: UseChatModelsOptions) { return () => clearTimeout(timerId) }, [isConnected, loadModels]) - const handleSetDefault = useCallback(async (modelName: string) => { - try { - await setDefaultModel(modelName) - setDefaultModelName(modelName) - setModelList((prev) => - prev.map((m) => ({ ...m, is_default: m.model_name === modelName })), - ) - } catch (err) { - console.error("Failed to set default model:", err) - } - }, []) + const handleSetDefault = useCallback( + async (modelName: string) => { + if (modelName === defaultModelName) return + const requestId = ++setDefaultRequestIdRef.current + + try { + await setDefaultModel(modelName) + const data = await getModels() + if (requestId !== setDefaultRequestIdRef.current) { + return + } + + setModelList(data.models) + if (data.models.some((m) => m.model_name === data.default_model)) { + setDefaultModelName(data.default_model) + } + } catch (err) { + console.error("Failed to set default model:", err) + } + }, + [defaultModelName], + ) const hasConfiguredModels = useMemo( () => modelList.some((m) => m.configured), diff --git a/web/frontend/src/hooks/use-gateway-logs.ts b/web/frontend/src/hooks/use-gateway-logs.ts index a39e6e930..593e90b26 100644 --- a/web/frontend/src/hooks/use-gateway-logs.ts +++ b/web/frontend/src/hooks/use-gateway-logs.ts @@ -37,7 +37,7 @@ export function useGatewayLogs() { const fetchLogs = async () => { if ( !mounted || - (gateway.status !== "running" && gateway.status !== "starting") + !["running", "starting", "restarting"].includes(gateway.status) ) { if (mounted) { timeout = setTimeout(fetchLogs, 1000) diff --git a/web/frontend/src/hooks/use-gateway.ts b/web/frontend/src/hooks/use-gateway.ts index 097dc3598..848f4d59c 100644 --- a/web/frontend/src/hooks/use-gateway.ts +++ b/web/frontend/src/hooks/use-gateway.ts @@ -1,31 +1,30 @@ -import { useAtom } from "jotai" +import { useAtomValue } from "jotai" import { useCallback, useEffect, useState } from "react" import { type GatewayStatusResponse, getGatewayStatus, + restartGateway, startGateway, stopGateway, } from "@/api/gateway" -import { gatewayAtom } from "@/store" +import { + applyGatewayStatusToStore, + gatewayAtom, + updateGatewayStore, +} from "@/store" // Global variable to ensure we only have one SSE connection let sseInitialized = false export function useGateway() { - const [{ status: state, canStart }, setGateway] = useAtom(gatewayAtom) + const gateway = useAtomValue(gatewayAtom) + const { status: state, canStart, restartRequired } = gateway const [loading, setLoading] = useState(false) - const applyGatewayStatus = useCallback( - (data: GatewayStatusResponse) => { - setGateway((prev) => ({ - ...prev, - status: data.gateway_status ?? "unknown", - canStart: data.gateway_start_allowed ?? true, - })) - }, - [setGateway], - ) + const applyGatewayStatus = useCallback((data: GatewayStatusResponse) => { + applyGatewayStatusToStore(data) + }, []) // Initialize global SSE connection once useEffect(() => { @@ -35,9 +34,10 @@ export function useGateway() { getGatewayStatus() .then((data) => applyGatewayStatus(data)) .catch(() => { - setGateway({ + updateGatewayStore({ status: "unknown", canStart: true, + restartRequired: false, }) }) @@ -59,14 +59,7 @@ export function useGateway() { data.gateway_status || typeof data.gateway_start_allowed === "boolean" ) { - setGateway((prev) => ({ - ...prev, - status: data.gateway_status ?? prev.status, - canStart: - typeof data.gateway_start_allowed === "boolean" - ? data.gateway_start_allowed - : prev.canStart, - })) + applyGatewayStatus(data) } } catch { // ignore @@ -75,7 +68,9 @@ export function useGateway() { es.onerror = () => { // EventSource will auto-reconnect - setGateway((prev) => ({ ...prev, status: "unknown" })) + updateGatewayStore((prev) => + prev.status === "restarting" ? {} : { status: "unknown" }, + ) } return () => { @@ -83,7 +78,7 @@ export function useGateway() { es.close() sseInitialized = false } - }, [applyGatewayStatus, setGateway]) + }, [applyGatewayStatus]) const start = useCallback(async () => { if (!canStart) return @@ -92,19 +87,19 @@ export function useGateway() { try { await startGateway() // SSE will push the real state changes, but set optimistic state - setGateway((prev) => ({ ...prev, status: "starting" })) + updateGatewayStore({ status: "starting" }) } catch (err) { console.error("Failed to start gateway:", err) try { const status = await getGatewayStatus() applyGatewayStatus(status) } catch { - setGateway((prev) => ({ ...prev, status: "unknown" })) + updateGatewayStore({ status: "unknown" }) } } finally { setLoading(false) } - }, [applyGatewayStatus, canStart, setGateway]) + }, [applyGatewayStatus, canStart]) const stop = useCallback(async () => { setLoading(true) @@ -117,5 +112,37 @@ export function useGateway() { } }, []) - return { state, loading, canStart, start, stop } + const restart = useCallback(async () => { + if (state !== "running") return + + const previousState = state + const previousCanStart = canStart + const previousRestartRequired = restartRequired + + setLoading(true) + updateGatewayStore({ + status: "restarting", + restartRequired: false, + }) + + try { + await restartGateway() + } catch (err) { + console.error("Failed to restart gateway:", err) + try { + const status = await getGatewayStatus() + applyGatewayStatus(status) + } catch { + updateGatewayStore({ + status: previousState, + canStart: previousCanStart, + restartRequired: previousRestartRequired, + }) + } + } finally { + setLoading(false) + } + }, [applyGatewayStatus, canStart, restartRequired, state]) + + return { state, loading, canStart, restartRequired, start, stop, restart } } diff --git a/web/frontend/src/hooks/use-pico-chat.ts b/web/frontend/src/hooks/use-pico-chat.ts index 7e3066177..2b7a510af 100644 --- a/web/frontend/src/hooks/use-pico-chat.ts +++ b/web/frontend/src/hooks/use-pico-chat.ts @@ -130,8 +130,9 @@ export function usePicoChat() { const [connectionState, setConnectionState] = useState("disconnected") const [isTyping, setIsTyping] = useState(false) - const [activeSessionId, setActiveSessionId] = - useState(() => readStoredSessionId() || generateSessionId()) + const [activeSessionId, setActiveSessionId] = useState( + () => readStoredSessionId() || generateSessionId(), + ) const wsRef = useRef(null) const isConnectingRef = useRef(false) @@ -144,9 +145,7 @@ export function usePicoChat() { setMessages((prev) => { const next = typeof nextState === "function" - ? ( - nextState as (prevState: ChatMessage[]) => ChatMessage[] - )(prev) + ? (nextState as (prevState: ChatMessage[]) => ChatMessage[])(prev) : nextState if (next !== prev) { @@ -220,64 +219,69 @@ export function usePicoChat() { } }, [loadSessionMessages, setTrackedMessages]) - const handlePicoMessage = useCallback((msg: PicoMessage) => { - const payload = msg.payload || {} + const handlePicoMessage = useCallback( + (msg: PicoMessage) => { + const payload = msg.payload || {} - switch (msg.type) { - case "message.create": { - const content = (payload.content as string) || "" - const messageId = (payload.message_id as string) || `pico-${Date.now()}` - // Use provided timestamp or current time - const timestampRaw = - msg.timestamp !== undefined && Number.isFinite(Number(msg.timestamp)) - ? normalizeUnixTimestamp(Number(msg.timestamp)) - : Date.now() + switch (msg.type) { + case "message.create": { + const content = (payload.content as string) || "" + const messageId = + (payload.message_id as string) || `pico-${Date.now()}` + // Use provided timestamp or current time + const timestampRaw = + msg.timestamp !== undefined && + Number.isFinite(Number(msg.timestamp)) + ? normalizeUnixTimestamp(Number(msg.timestamp)) + : Date.now() - setTrackedMessages((prev) => [ - ...prev, - { - id: messageId, - role: "assistant", - content, - timestamp: timestampRaw, - }, - ]) - setIsTyping(false) - break + setTrackedMessages((prev) => [ + ...prev, + { + id: messageId, + role: "assistant", + content, + timestamp: timestampRaw, + }, + ]) + setIsTyping(false) + break + } + + case "message.update": { + const content = (payload.content as string) || "" + const messageId = payload.message_id as string + if (!messageId) break + + setTrackedMessages((prev) => + prev.map((m) => (m.id === messageId ? { ...m, content } : m)), + ) + break + } + + case "typing.start": + setIsTyping(true) + break + + case "typing.stop": + setIsTyping(false) + break + + case "error": + console.error("Pico error:", payload) + setIsTyping(false) + break + + case "pong": + // heartbeat response, ignore + break + + default: + console.log("Unknown pico message type:", msg.type) } - - case "message.update": { - const content = (payload.content as string) || "" - const messageId = payload.message_id as string - if (!messageId) break - - setTrackedMessages((prev) => - prev.map((m) => (m.id === messageId ? { ...m, content } : m)), - ) - break - } - - case "typing.start": - setIsTyping(true) - break - - case "typing.stop": - setIsTyping(false) - break - - case "error": - console.error("Pico error:", payload) - setIsTyping(false) - break - - case "pong": - // heartbeat response, ignore - break - - default: - console.log("Unknown pico message type:", msg.type) - } - }, [setTrackedMessages]) + }, + [setTrackedMessages], + ) const connect = useCallback(async () => { if ( @@ -389,32 +393,35 @@ export function usePicoChat() { return () => disconnect() }, [disconnect]) - const sendMessage = useCallback((content: string) => { - if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { - console.warn("WebSocket not connected") - return - } + const sendMessage = useCallback( + (content: string) => { + if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) { + console.warn("WebSocket not connected") + return + } - const id = `msg-${++msgIdCounter.current}-${Date.now()}` - const timestampRaw = Date.now() + const id = `msg-${++msgIdCounter.current}-${Date.now()}` + const timestampRaw = Date.now() - // Add user message to local state - setTrackedMessages((prev) => [ - ...prev, - { id, role: "user", content, timestamp: timestampRaw }, - ]) + // Add user message to local state + setTrackedMessages((prev) => [ + ...prev, + { id, role: "user", content, timestamp: timestampRaw }, + ]) - // Show typing indicator immediately - setIsTyping(true) + // Show typing indicator immediately + setIsTyping(true) - // Send via Pico Protocol - const picoMsg: PicoMessage = { - type: "message.send", - id, - payload: { content }, - } - wsRef.current.send(JSON.stringify(picoMsg)) - }, [setTrackedMessages]) + // Send via Pico Protocol + const picoMsg: PicoMessage = { + type: "message.send", + id, + payload: { content }, + } + wsRef.current.send(JSON.stringify(picoMsg)) + }, + [setTrackedMessages], + ) // Switch to a historical session const switchSession = useCallback( @@ -443,7 +450,14 @@ export function usePicoChat() { } }, 100) }, - [connect, disconnect, gatewayState, loadSessionMessages, setTrackedMessages, t], + [ + connect, + disconnect, + gatewayState, + loadSessionMessages, + setTrackedMessages, + t, + ], ) // Start a new empty chat diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json index 453c5905f..b099dec13 100644 --- a/web/frontend/src/i18n/locales/en.json +++ b/web/frontend/src/i18n/locales/en.json @@ -58,11 +58,14 @@ }, "action": { "start": "Start Gateway", - "stop": "Stop Gateway" + "stop": "Stop Gateway", + "restart": "Restart Gateway" }, "status": { - "starting": "Starting Gateway..." - } + "starting": "Starting Gateway...", + "restarting": "Restarting Gateway..." + }, + "restartRequired": "Model changes require a gateway restart to take effect." } }, "common": { diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json index b6bdedbfa..78093e5c7 100644 --- a/web/frontend/src/i18n/locales/zh.json +++ b/web/frontend/src/i18n/locales/zh.json @@ -58,11 +58,14 @@ }, "action": { "start": "启动服务", - "stop": "停止服务" + "stop": "停止服务", + "restart": "重启服务" }, "status": { - "starting": "服务启动中..." - } + "starting": "服务启动中...", + "restarting": "服务重启中..." + }, + "restartRequired": "切换默认模型后需要重启服务才能生效。" } }, "common": { diff --git a/web/frontend/src/store/gateway.ts b/web/frontend/src/store/gateway.ts index 89da9d7fd..b7655839c 100644 --- a/web/frontend/src/store/gateway.ts +++ b/web/frontend/src/store/gateway.ts @@ -5,6 +5,7 @@ import { type GatewayStatusResponse, getGatewayStatus } from "@/api/gateway" export type GatewayState = | "running" | "starting" + | "restarting" | "stopped" | "error" | "unknown" @@ -12,19 +13,54 @@ export type GatewayState = export interface GatewayStoreState { status: GatewayState canStart: boolean + restartRequired: boolean +} + +type GatewayStorePatch = Partial + +const DEFAULT_GATEWAY_STATE: GatewayStoreState = { + status: "unknown", + canStart: true, + restartRequired: false, } // Global atom for gateway state -export const gatewayAtom = atom({ - status: "unknown", - canStart: true, -}) +export const gatewayAtom = atom(DEFAULT_GATEWAY_STATE) -function applyGatewayStatusToStore(data: GatewayStatusResponse) { - getDefaultStore().set(gatewayAtom, (prev) => ({ - ...prev, - status: data.gateway_status ?? "unknown", - canStart: data.gateway_start_allowed ?? true, +function normalizeGatewayStoreState( + prev: GatewayStoreState, + patch: GatewayStorePatch, +) { + return { ...prev, ...patch } +} + +export function updateGatewayStore( + patch: + | GatewayStorePatch + | ((prev: GatewayStoreState) => GatewayStorePatch | GatewayStoreState), +) { + getDefaultStore().set(gatewayAtom, (prev) => { + const nextPatch = typeof patch === "function" ? patch(prev) : patch + return normalizeGatewayStoreState(prev, nextPatch) + }) +} + +export function applyGatewayStatusToStore( + data: Partial< + Pick< + GatewayStatusResponse, + "gateway_status" | "gateway_start_allowed" | "gateway_restart_required" + > + >, +) { + updateGatewayStore((prev) => ({ + status: data.gateway_status ?? prev.status, + canStart: data.gateway_start_allowed ?? prev.canStart, + restartRequired: + data.gateway_restart_required ?? + (data.gateway_status && data.gateway_status !== "running" + ? false + : prev.restartRequired), })) } @@ -33,6 +69,6 @@ export async function refreshGatewayState() { const status = await getGatewayStatus() applyGatewayStatusToStore(status) } catch { - // Best-effort refresh only; keep current state on error. + updateGatewayStore(DEFAULT_GATEWAY_STATE) } }