diff --git a/web/backend/api/events.go b/web/backend/api/events.go index 0a8d4a9bb..af44d1824 100644 --- a/web/backend/api/events.go +++ b/web/backend/api/events.go @@ -7,8 +7,11 @@ import ( // GatewayEvent represents a state change event for the gateway process. type GatewayEvent struct { - Status string `json:"gateway_status"` // "running", "starting", "stopped", "error" - PID int `json:"pid,omitempty"` + Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error" + PID int `json:"pid,omitempty"` + BootDefaultModel string `json:"boot_default_model,omitempty"` + ConfigDefaultModel string `json:"config_default_model,omitempty"` + RestartRequired bool `json:"gateway_restart_required,omitempty"` } // EventBroadcaster manages SSE client subscriptions and broadcasts events. diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 41f702e32..95b482ce0 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -23,13 +23,29 @@ import ( // gateway holds the state for the managed gateway process. var gateway = struct { - mu sync.Mutex - cmd *exec.Cmd - logs *LogBuffer - events *EventBroadcaster + mu sync.Mutex + cmd *exec.Cmd + bootDefaultModel string + runtimeStatus string + startupDeadline time.Time + logs *LogBuffer + events *EventBroadcaster }{ - logs: NewLogBuffer(200), - events: NewEventBroadcaster(), + runtimeStatus: "stopped", + logs: NewLogBuffer(200), + events: NewEventBroadcaster(), +} + +var ( + gatewayStartupWindow = 15 * time.Second + gatewayRestartGracePeriod = 5 * time.Second + gatewayRestartForceKillWindow = 3 * time.Second + gatewayRestartPollInterval = 100 * time.Millisecond +) + +var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) { + client := http.Client{Timeout: timeout} + return client.Get(url) } // registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux. @@ -65,7 +81,7 @@ func (h *Handler) TryAutoStartGateway() { return } - pid, err := h.startGatewayLocked() + pid, err := h.startGatewayLocked("starting") if err != nil { log.Printf("Failed to auto-start gateway: %v", err) return @@ -131,7 +147,110 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool { return cmd.Process.Signal(syscall.Signal(0)) == nil } -func (h *Handler) startGatewayLocked() (int, error) { +func setGatewayRuntimeStatusLocked(status string) { + gateway.runtimeStatus = status + if status == "starting" || status == "restarting" { + gateway.startupDeadline = time.Now().Add(gatewayStartupWindow) + return + } + gateway.startupDeadline = time.Time{} +} + +func gatewayStatusOnHealthFailureLocked() string { + if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" { + if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) { + return gateway.runtimeStatus + } + return "error" + } + if gateway.runtimeStatus == "running" { + return "running" + } + if gateway.runtimeStatus == "error" { + return "error" + } + return "error" +} + +func currentGatewayStatusLocked(processAlive bool) string { + if !processAlive { + if gateway.runtimeStatus == "restarting" { + if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) { + return "restarting" + } + return "error" + } + if gateway.runtimeStatus == "error" { + return "error" + } + return "stopped" + } + return gatewayStatusOnHealthFailureLocked() +} + +func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool { + if cmd == nil || cmd.Process == nil { + return true + } + + deadline := time.Now().Add(timeout) + for { + if !isCmdProcessAliveLocked(cmd) { + return true + } + if time.Now().After(deadline) { + return false + } + time.Sleep(gatewayRestartPollInterval) + } +} + +func stopGatewayProcessForRestart(cmd *exec.Cmd) error { + if cmd == nil || cmd.Process == nil || !isCmdProcessAliveLocked(cmd) { + return nil + } + + var stopErr error + if runtime.GOOS == "windows" { + stopErr = cmd.Process.Kill() + } else { + stopErr = cmd.Process.Signal(syscall.SIGTERM) + } + if stopErr != nil && isCmdProcessAliveLocked(cmd) { + return fmt.Errorf("failed to stop existing gateway: %w", stopErr) + } + + if waitForGatewayProcessExit(cmd, gatewayRestartGracePeriod) { + return nil + } + + if runtime.GOOS != "windows" { + killErr := cmd.Process.Signal(syscall.SIGKILL) + if killErr != nil && isCmdProcessAliveLocked(cmd) { + return fmt.Errorf("failed to force-stop existing gateway: %w", killErr) + } + if waitForGatewayProcessExit(cmd, gatewayRestartForceKillWindow) { + return nil + } + } + + return fmt.Errorf("existing gateway did not exit before restart") +} + +func gatewayRestartRequired(status, bootDefaultModel, configDefaultModel string) bool { + return status == "running" && + bootDefaultModel != "" && + configDefaultModel != "" && + bootDefaultModel != configDefaultModel +} + +func (h *Handler) startGatewayLocked(initialStatus string) (int, error) { + cfg, err := config.LoadConfig(h.configPath) + if err != nil { + return 0, fmt.Errorf("failed to load config: %w", err) + } + defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName()) + // Locate the picoclaw executable execPath := utils.FindPicoclawBinary() @@ -171,11 +290,19 @@ func (h *Handler) startGatewayLocked() (int, error) { } gateway.cmd = cmd + gateway.bootDefaultModel = defaultModelName + setGatewayRuntimeStatusLocked(initialStatus) pid := cmd.Process.Pid log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath) - // Broadcast starting event - gateway.events.Broadcast(GatewayEvent{Status: "starting", PID: pid}) + // Broadcast the launch state immediately so clients can reflect it without polling. + gateway.events.Broadcast(GatewayEvent{ + Status: initialStatus, + PID: pid, + BootDefaultModel: defaultModelName, + ConfigDefaultModel: defaultModelName, + RestartRequired: false, + }) // Capture stdout/stderr in background go scanPipe(stdoutPipe, gateway.logs) @@ -190,13 +317,23 @@ func (h *Handler) startGatewayLocked() (int, error) { } gateway.mu.Lock() + shouldBroadcastStopped := false if gateway.cmd == cmd { gateway.cmd = nil + gateway.bootDefaultModel = "" + if gateway.runtimeStatus != "restarting" { + setGatewayRuntimeStatusLocked("stopped") + shouldBroadcastStopped = true + } } gateway.mu.Unlock() - // Broadcast stopped event - gateway.events.Broadcast(GatewayEvent{Status: "stopped"}) + if shouldBroadcastStopped { + gateway.events.Broadcast(GatewayEvent{ + Status: "stopped", + RestartRequired: false, + }) + } }() // Start a goroutine to probe health and broadcast "running" once ready @@ -219,12 +356,22 @@ func (h *Handler) startGatewayLocked() (int, error) { healthPort = 18790 } healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort))) - client := http.Client{Timeout: 1 * time.Second} - resp, err := client.Get(healthURL) + resp, err := gatewayHealthGet(healthURL, 1*time.Second) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { - gateway.events.Broadcast(GatewayEvent{Status: "running", PID: pid}) + gateway.mu.Lock() + if gateway.cmd == cmd { + setGatewayRuntimeStatusLocked("running") + } + gateway.mu.Unlock() + gateway.events.Broadcast(GatewayEvent{ + Status: "running", + PID: pid, + BootDefaultModel: defaultModelName, + ConfigDefaultModel: defaultModelName, + RestartRequired: false, + }) return } } @@ -253,6 +400,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { } if gateway.cmd != nil && gateway.cmd.Process != nil { gateway.cmd = nil + setGatewayRuntimeStatusLocked("stopped") } ready, reason, err := h.gatewayStartReady() @@ -274,7 +422,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) { return } - pid, err := h.startGatewayLocked() + pid, err := h.startGatewayLocked("starting") if err != nil { http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError) return @@ -330,30 +478,72 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) { // // POST /api/gateway/restart func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) { - gateway.mu.Lock() - - // Stop existing process if running - if gateway.cmd != nil && gateway.cmd.Process != nil { - if isCmdProcessAliveLocked(gateway.cmd) { - // Process is alive, send SIGTERM - if runtime.GOOS == "windows" { - gateway.cmd.Process.Kill() - } else { - gateway.cmd.Process.Signal(syscall.SIGTERM) - } - - // Wait briefly for it to exit - gateway.mu.Unlock() - time.Sleep(2 * time.Second) - gateway.mu.Lock() - } - gateway.cmd = nil + ready, reason, err := h.gatewayStartReady() + if err != nil { + http.Error( + w, + fmt.Sprintf("Failed to validate gateway start conditions: %v", err), + http.StatusInternalServerError, + ) + return + } + if !ready { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]any{ + "status": "precondition_failed", + "message": reason, + }) + return } + gateway.mu.Lock() + previousCmd := gateway.cmd + setGatewayRuntimeStatusLocked("restarting") + gateway.events.Broadcast(GatewayEvent{ + Status: "restarting", + RestartRequired: false, + }) gateway.mu.Unlock() - // Start fresh via the existing handler - h.handleGatewayStart(w, r) + if err = stopGatewayProcessForRestart(previousCmd); err != nil { + gateway.mu.Lock() + if gateway.cmd == previousCmd { + if isCmdProcessAliveLocked(previousCmd) { + setGatewayRuntimeStatusLocked("running") + } else { + gateway.cmd = nil + gateway.bootDefaultModel = "" + setGatewayRuntimeStatusLocked("error") + } + } + gateway.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError) + return + } + + gateway.mu.Lock() + if gateway.cmd == previousCmd { + gateway.cmd = nil + gateway.bootDefaultModel = "" + } + pid, err := h.startGatewayLocked("restarting") + if err != nil { + gateway.cmd = nil + gateway.bootDefaultModel = "" + setGatewayRuntimeStatusLocked("error") + } + gateway.mu.Unlock() + if err != nil { + http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "status": "ok", + "pid": pid, + }) } // handleGatewayClearLogs clears the in-memory gateway log buffer. @@ -374,24 +564,44 @@ func (h *Handler) handleGatewayClearLogs(w http.ResponseWriter, r *http.Request) // // GET /api/gateway/status func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { + data := h.gatewayStatusData(r, true) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(data) +} + +func (h *Handler) gatewayStatusData(r *http.Request, includeLogs bool) map[string]any { data := map[string]any{} + cfg, cfgErr := config.LoadConfig(h.configPath) + configDefaultModel := "" + if cfgErr == nil && cfg != nil { + configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName()) + if configDefaultModel != "" { + data["config_default_model"] = configDefaultModel + } + } // Check process state gateway.mu.Lock() processAlive := isGatewayProcessAliveLocked() + bootDefaultModel := "" if processAlive { data["pid"] = gateway.cmd.Process.Pid + if gateway.bootDefaultModel != "" { + data["boot_default_model"] = gateway.bootDefaultModel + bootDefaultModel = gateway.bootDefaultModel + } } gateway.mu.Unlock() if !processAlive { - data["gateway_status"] = "stopped" + gateway.mu.Lock() + data["gateway_status"] = currentGatewayStatusLocked(false) + gateway.mu.Unlock() } else { // Process is alive — probe its health endpoint - cfg, err := config.LoadConfig(h.configPath) host := "127.0.0.1" port := 18790 - if err == nil && cfg != nil { + if cfgErr == nil && cfg != nil { host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg)) if cfg.Gateway.Port != 0 { port = cfg.Gateway.Port @@ -399,21 +609,31 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { } url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port))) - client := http.Client{Timeout: 2 * time.Second} - resp, err := client.Get(url) + resp, err := gatewayHealthGet(url, 2*time.Second) if err != nil { - data["gateway_status"] = "starting" + gateway.mu.Lock() + data["gateway_status"] = currentGatewayStatusLocked(true) + gateway.mu.Unlock() } else { defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("error") + gateway.mu.Unlock() data["gateway_status"] = "error" data["status_code"] = resp.StatusCode } else { var healthData map[string]any if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("error") + gateway.mu.Unlock() data["gateway_status"] = "error" } else { + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() for k, v := range healthData { data[k] = v } @@ -423,6 +643,13 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { } } + status, _ := data["gateway_status"].(string) + data["gateway_restart_required"] = gatewayRestartRequired( + status, + bootDefaultModel, + configDefaultModel, + ) + ready, reason, readyErr := h.gatewayStartReady() if readyErr != nil { data["gateway_start_allowed"] = false @@ -434,11 +661,11 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) { } } - // Append incremental log data - appendGatewayLogs(r, data) + if includeLogs { + appendGatewayLogs(r, data) + } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(data) + return data } // appendGatewayLogs reads log_offset and log_run_id query params from the request @@ -524,28 +751,7 @@ func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) { // currentGatewayStatus returns the current gateway status as a JSON string. func (h *Handler) currentGatewayStatus() string { - gateway.mu.Lock() - defer gateway.mu.Unlock() - - data := map[string]any{ - "gateway_status": "stopped", - } - if isGatewayProcessAliveLocked() { - data["gateway_status"] = "running" - data["pid"] = gateway.cmd.Process.Pid - } - - ready, reason, readyErr := h.gatewayStartReady() - if readyErr != nil { - data["gateway_start_allowed"] = false - data["gateway_start_reason"] = readyErr.Error() - } else { - data["gateway_start_allowed"] = ready - if !ready { - data["gateway_start_reason"] = reason - } - } - + data := h.gatewayStatusData(nil, false) encoded, _ := json.Marshal(data) return string(encoded) } diff --git a/web/backend/api/gateway_test.go b/web/backend/api/gateway_test.go index d4265776a..fe3fccdee 100644 --- a/web/backend/api/gateway_test.go +++ b/web/backend/api/gateway_test.go @@ -2,19 +2,76 @@ package api import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "os" + "os/exec" "path/filepath" + "runtime" "strconv" "strings" "testing" + "time" "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/web/backend/utils" ) +func startLongRunningProcess(t *testing.T) *exec.Cmd { + t.Helper() + + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("powershell", "-NoProfile", "-Command", "Start-Sleep -Seconds 30") + } else { + cmd = exec.Command("sleep", "30") + } + + if err := cmd.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + return cmd +} + +func startIgnoringTermProcess(t *testing.T) *exec.Cmd { + t.Helper() + + if runtime.GOOS == "windows" { + t.Skip("TERM handling differs on Windows") + } + + cmd := exec.Command("sh", "-c", "trap '' TERM; sleep 30") + if err := cmd.Start(); err != nil { + t.Fatalf("Start() error = %v", err) + } + + return cmd +} + +func resetGatewayTestState(t *testing.T) { + t.Helper() + + originalHealthGet := gatewayHealthGet + originalRestartGracePeriod := gatewayRestartGracePeriod + originalRestartForceKillWindow := gatewayRestartForceKillWindow + originalRestartPollInterval := gatewayRestartPollInterval + t.Cleanup(func() { + gatewayHealthGet = originalHealthGet + gatewayRestartGracePeriod = originalRestartGracePeriod + gatewayRestartForceKillWindow = originalRestartForceKillWindow + gatewayRestartPollInterval = originalRestartPollInterval + + gateway.mu.Lock() + gateway.cmd = nil + gateway.bootDefaultModel = "" + setGatewayRuntimeStatusLocked("stopped") + gateway.mu.Unlock() + }) +} + func TestGatewayStartReady_NoDefaultModel(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) @@ -317,6 +374,339 @@ func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) { } } +func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + // Simulate a process that has already reached the running state. + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return nil, errors.New("probe failed") + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "running" { + t.Fatalf("gateway_status = %#v, want %q", got, "running") + } +} + +func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + setGatewayRuntimeStatusLocked("starting") + gateway.startupDeadline = time.Now().Add(-time.Second) + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + return nil, errors.New("probe failed") + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "error" { + t.Fatalf("gateway_status = %#v, want %q", got, "error") + } +} + +func TestGatewayStatusReturnsRestartingDuringRestartGap(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + gateway.mu.Lock() + setGatewayRuntimeStatusLocked("restarting") + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "restarting" { + t.Fatalf("gateway_status = %#v, want %q", got, "restarting") + } +} + +func TestGatewayStatusIncludesRestartRequiredWhenModelsDiffer(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "previous-model" + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + gatewayHealthGet = func(string, time.Duration) (*http.Response, error) { + rec := httptest.NewRecorder() + rec.WriteHeader(http.StatusOK) + _, _ = rec.WriteString(`{"ok":true}`) + return rec.Result(), nil + } + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_restart_required"]; got != true { + t.Fatalf("gateway_restart_required = %#v, want true", got) + } +} + +func TestGatewayRestartKeepsRunningProcessWhenPreconditionsFail(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "" + cfg.ModelList[0].AuthMethod = "" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startLongRunningProcess(t) + t.Cleanup(func() { + gateway.mu.Lock() + if gateway.cmd == cmd { + gateway.cmd = nil + gateway.bootDefaultModel = "" + } + gateway.mu.Unlock() + + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } + + gateway.mu.Lock() + stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd) + gateway.mu.Unlock() + + if !stillRunning { + t.Fatalf("gateway process was stopped when restart preconditions failed") + } +} + +func TestGatewayRestartKeepsOldProcessWhenItDoesNotExitInTime(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + cmd := startIgnoringTermProcess(t) + t.Cleanup(func() { + gateway.mu.Lock() + if gateway.cmd == cmd { + gateway.cmd = nil + gateway.bootDefaultModel = "" + } + gateway.mu.Unlock() + + if cmd.Process != nil { + _ = cmd.Process.Kill() + } + _ = cmd.Wait() + }) + + gatewayRestartGracePeriod = 150 * time.Millisecond + gatewayRestartForceKillWindow = 150 * time.Millisecond + gatewayRestartPollInterval = 10 * time.Millisecond + + gateway.mu.Lock() + gateway.cmd = cmd + gateway.bootDefaultModel = "existing-model" + setGatewayRuntimeStatusLocked("running") + gateway.mu.Unlock() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + + gateway.mu.Lock() + stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd) + status := gateway.runtimeStatus + gateway.mu.Unlock() + + if !stillRunning { + t.Fatalf("gateway process was replaced before the old process exited") + } + if status != "running" { + t.Fatalf("runtimeStatus = %q, want %q", status, "running") + } +} + +func TestGatewayRestartReturnsErrorStatusWhenReplacementFailsToStart(t *testing.T) { + resetGatewayTestState(t) + + configPath := filepath.Join(t.TempDir(), "config.json") + cfg := config.DefaultConfig() + cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName + cfg.ModelList[0].APIKey = "test-key" + if err := config.SaveConfig(configPath, cfg); err != nil { + t.Fatalf("SaveConfig() error = %v", err) + } + + invalidBinaryPath := filepath.Join(t.TempDir(), "fake-picoclaw") + if err := os.WriteFile(invalidBinaryPath, []byte("#!/bin/sh\n"), 0o644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + t.Setenv("PICOCLAW_BINARY", invalidBinaryPath) + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil) + mux.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("restart status = %d, want %d", rec.Code, http.StatusInternalServerError) + } + + statusRec := httptest.NewRecorder() + statusReq := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil) + mux.ServeHTTP(statusRec, statusReq) + + if statusRec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", statusRec.Code, http.StatusOK) + } + + var body map[string]any + if err := json.Unmarshal(statusRec.Body.Bytes(), &body); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + + if got := body["gateway_status"]; got != "error" { + t.Fatalf("gateway_status = %#v, want %q", got, "error") + } +} + func TestGatewayClearLogsResetsBufferedHistory(t *testing.T) { configPath := filepath.Join(t.TempDir(), "config.json") h := NewHandler(configPath) diff --git a/web/frontend/src/api/gateway.ts b/web/frontend/src/api/gateway.ts index 020e92e3a..1688a5278 100644 --- a/web/frontend/src/api/gateway.ts +++ b/web/frontend/src/api/gateway.ts @@ -1,10 +1,13 @@ // API client for gateway process management. interface GatewayStatusResponse { - gateway_status: "running" | "starting" | "stopped" | "error" + gateway_status: "running" | "starting" | "restarting" | "stopped" | "error" gateway_start_allowed?: boolean gateway_start_reason?: string + gateway_restart_required?: boolean pid?: number + boot_default_model?: string + config_default_model?: string logs?: string[] log_total?: number log_run_id?: number diff --git a/web/frontend/src/api/models.ts b/web/frontend/src/api/models.ts index 6a4544c65..8e49b48b4 100644 --- a/web/frontend/src/api/models.ts +++ b/web/frontend/src/api/models.ts @@ -84,7 +84,7 @@ export async function setDefaultModel( body: JSON.stringify({ model_name: modelName }), }) - void refreshGatewayState() + await refreshGatewayState() return response } diff --git a/web/frontend/src/components/app-header.tsx b/web/frontend/src/components/app-header.tsx index 7a50fe0fb..fe0c84e69 100644 --- a/web/frontend/src/components/app-header.tsx +++ b/web/frontend/src/components/app-header.tsx @@ -6,6 +6,7 @@ import { IconMoon, IconPlayerPlay, IconPower, + IconRefresh, IconSun, } from "@tabler/icons-react" import { Link } from "@tanstack/react-router" @@ -31,6 +32,11 @@ import { } from "@/components/ui/dropdown-menu.tsx" import { Separator } from "@/components/ui/separator.tsx" import { SidebarTrigger } from "@/components/ui/sidebar" +import { + Tooltip, + TooltipContent, + TooltipTrigger, +} from "@/components/ui/tooltip" import { useGateway } from "@/hooks/use-gateway.ts" import { useTheme } from "@/hooks/use-theme.ts" @@ -41,27 +47,35 @@ export function AppHeader() { state: gwState, loading: gwLoading, canStart, + restartRequired, start, + restart, stop, } = useGateway() const isRunning = gwState === "running" const isStarting = gwState === "starting" + const isRestarting = gwState === "restarting" const isStopped = gwState === "stopped" || gwState === "unknown" const showNotConnectedHint = - canStart && (gwState === "stopped" || gwState === "error") + !isRestarting && canStart && (gwState === "stopped" || gwState === "error") const [showStopDialog, setShowStopDialog] = React.useState(false) const handleGatewayToggle = () => { - if (gwLoading || (!isRunning && !canStart)) return + if (gwLoading || isRestarting || (!isRunning && !canStart)) return if (isRunning) { setShowStopDialog(true) } else { - start() + void start() } } + const handleGatewayRestart = () => { + if (gwLoading || isRestarting || !restartRequired || !canStart) return + void restart() + } + const confirmStop = () => { setShowStopDialog(false) stop() @@ -115,35 +129,67 @@ export function AppHeader() {