feat(web): add restart-required state for default model changes (#1499)

- track boot and config default models in gateway status/events
- preserve running, starting, and restarting states during health checks
- add safer gateway restart handling with stronger backend test coverage
- expose restart-required UI and refresh model state after default model update
This commit is contained in:
wenjie
2026-03-13 16:30:59 +08:00
committed by GitHub
parent 4ccea5eb93
commit 87257819f6
19 changed files with 1022 additions and 253 deletions
+5 -2
View File
@@ -7,8 +7,11 @@ import (
// GatewayEvent represents a state change event for the gateway process.
type GatewayEvent struct {
Status string `json:"gateway_status"` // "running", "starting", "stopped", "error"
PID int `json:"pid,omitempty"`
Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
PID int `json:"pid,omitempty"`
BootDefaultModel string `json:"boot_default_model,omitempty"`
ConfigDefaultModel string `json:"config_default_model,omitempty"`
RestartRequired bool `json:"gateway_restart_required,omitempty"`
}
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
+274 -68
View File
@@ -23,13 +23,29 @@ import (
// gateway holds the state for the managed gateway process.
var gateway = struct {
mu sync.Mutex
cmd *exec.Cmd
logs *LogBuffer
events *EventBroadcaster
mu sync.Mutex
cmd *exec.Cmd
bootDefaultModel string
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
events *EventBroadcaster
}{
logs: NewLogBuffer(200),
events: NewEventBroadcaster(),
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
events: NewEventBroadcaster(),
}
var (
gatewayStartupWindow = 15 * time.Second
gatewayRestartGracePeriod = 5 * time.Second
gatewayRestartForceKillWindow = 3 * time.Second
gatewayRestartPollInterval = 100 * time.Millisecond
)
var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
client := http.Client{Timeout: timeout}
return client.Get(url)
}
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
@@ -65,7 +81,7 @@ func (h *Handler) TryAutoStartGateway() {
return
}
pid, err := h.startGatewayLocked()
pid, err := h.startGatewayLocked("starting")
if err != nil {
log.Printf("Failed to auto-start gateway: %v", err)
return
@@ -131,7 +147,110 @@ func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
return cmd.Process.Signal(syscall.Signal(0)) == nil
}
func (h *Handler) startGatewayLocked() (int, error) {
func setGatewayRuntimeStatusLocked(status string) {
gateway.runtimeStatus = status
if status == "starting" || status == "restarting" {
gateway.startupDeadline = time.Now().Add(gatewayStartupWindow)
return
}
gateway.startupDeadline = time.Time{}
}
func gatewayStatusOnHealthFailureLocked() string {
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return gateway.runtimeStatus
}
return "error"
}
if gateway.runtimeStatus == "running" {
return "running"
}
if gateway.runtimeStatus == "error" {
return "error"
}
return "error"
}
func currentGatewayStatusLocked(processAlive bool) string {
if !processAlive {
if gateway.runtimeStatus == "restarting" {
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
return "restarting"
}
return "error"
}
if gateway.runtimeStatus == "error" {
return "error"
}
return "stopped"
}
return gatewayStatusOnHealthFailureLocked()
}
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
if cmd == nil || cmd.Process == nil {
return true
}
deadline := time.Now().Add(timeout)
for {
if !isCmdProcessAliveLocked(cmd) {
return true
}
if time.Now().After(deadline) {
return false
}
time.Sleep(gatewayRestartPollInterval)
}
}
func stopGatewayProcessForRestart(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil || !isCmdProcessAliveLocked(cmd) {
return nil
}
var stopErr error
if runtime.GOOS == "windows" {
stopErr = cmd.Process.Kill()
} else {
stopErr = cmd.Process.Signal(syscall.SIGTERM)
}
if stopErr != nil && isCmdProcessAliveLocked(cmd) {
return fmt.Errorf("failed to stop existing gateway: %w", stopErr)
}
if waitForGatewayProcessExit(cmd, gatewayRestartGracePeriod) {
return nil
}
if runtime.GOOS != "windows" {
killErr := cmd.Process.Signal(syscall.SIGKILL)
if killErr != nil && isCmdProcessAliveLocked(cmd) {
return fmt.Errorf("failed to force-stop existing gateway: %w", killErr)
}
if waitForGatewayProcessExit(cmd, gatewayRestartForceKillWindow) {
return nil
}
}
return fmt.Errorf("existing gateway did not exit before restart")
}
func gatewayRestartRequired(status, bootDefaultModel, configDefaultModel string) bool {
return status == "running" &&
bootDefaultModel != "" &&
configDefaultModel != "" &&
bootDefaultModel != configDefaultModel
}
func (h *Handler) startGatewayLocked(initialStatus string) (int, error) {
cfg, err := config.LoadConfig(h.configPath)
if err != nil {
return 0, fmt.Errorf("failed to load config: %w", err)
}
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
// Locate the picoclaw executable
execPath := utils.FindPicoclawBinary()
@@ -171,11 +290,19 @@ func (h *Handler) startGatewayLocked() (int, error) {
}
gateway.cmd = cmd
gateway.bootDefaultModel = defaultModelName
setGatewayRuntimeStatusLocked(initialStatus)
pid := cmd.Process.Pid
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
// Broadcast starting event
gateway.events.Broadcast(GatewayEvent{Status: "starting", PID: pid})
// Broadcast the launch state immediately so clients can reflect it without polling.
gateway.events.Broadcast(GatewayEvent{
Status: initialStatus,
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
// Capture stdout/stderr in background
go scanPipe(stdoutPipe, gateway.logs)
@@ -190,13 +317,23 @@ func (h *Handler) startGatewayLocked() (int, error) {
}
gateway.mu.Lock()
shouldBroadcastStopped := false
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
if gateway.runtimeStatus != "restarting" {
setGatewayRuntimeStatusLocked("stopped")
shouldBroadcastStopped = true
}
}
gateway.mu.Unlock()
// Broadcast stopped event
gateway.events.Broadcast(GatewayEvent{Status: "stopped"})
if shouldBroadcastStopped {
gateway.events.Broadcast(GatewayEvent{
Status: "stopped",
RestartRequired: false,
})
}
}()
// Start a goroutine to probe health and broadcast "running" once ready
@@ -219,12 +356,22 @@ func (h *Handler) startGatewayLocked() (int, error) {
healthPort = 18790
}
healthURL := fmt.Sprintf("http://%s/health", net.JoinHostPort(healthHost, strconv.Itoa(healthPort)))
client := http.Client{Timeout: 1 * time.Second}
resp, err := client.Get(healthURL)
resp, err := gatewayHealthGet(healthURL, 1*time.Second)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
gateway.events.Broadcast(GatewayEvent{Status: "running", PID: pid})
gateway.mu.Lock()
if gateway.cmd == cmd {
setGatewayRuntimeStatusLocked("running")
}
gateway.mu.Unlock()
gateway.events.Broadcast(GatewayEvent{
Status: "running",
PID: pid,
BootDefaultModel: defaultModelName,
ConfigDefaultModel: defaultModelName,
RestartRequired: false,
})
return
}
}
@@ -253,6 +400,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
}
if gateway.cmd != nil && gateway.cmd.Process != nil {
gateway.cmd = nil
setGatewayRuntimeStatusLocked("stopped")
}
ready, reason, err := h.gatewayStartReady()
@@ -274,7 +422,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
return
}
pid, err := h.startGatewayLocked()
pid, err := h.startGatewayLocked("starting")
if err != nil {
http.Error(w, fmt.Sprintf("Failed to start gateway: %v", err), http.StatusInternalServerError)
return
@@ -330,30 +478,72 @@ func (h *Handler) handleGatewayStop(w http.ResponseWriter, r *http.Request) {
//
// POST /api/gateway/restart
func (h *Handler) handleGatewayRestart(w http.ResponseWriter, r *http.Request) {
gateway.mu.Lock()
// Stop existing process if running
if gateway.cmd != nil && gateway.cmd.Process != nil {
if isCmdProcessAliveLocked(gateway.cmd) {
// Process is alive, send SIGTERM
if runtime.GOOS == "windows" {
gateway.cmd.Process.Kill()
} else {
gateway.cmd.Process.Signal(syscall.SIGTERM)
}
// Wait briefly for it to exit
gateway.mu.Unlock()
time.Sleep(2 * time.Second)
gateway.mu.Lock()
}
gateway.cmd = nil
ready, reason, err := h.gatewayStartReady()
if err != nil {
http.Error(
w,
fmt.Sprintf("Failed to validate gateway start conditions: %v", err),
http.StatusInternalServerError,
)
return
}
if !ready {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]any{
"status": "precondition_failed",
"message": reason,
})
return
}
gateway.mu.Lock()
previousCmd := gateway.cmd
setGatewayRuntimeStatusLocked("restarting")
gateway.events.Broadcast(GatewayEvent{
Status: "restarting",
RestartRequired: false,
})
gateway.mu.Unlock()
// Start fresh via the existing handler
h.handleGatewayStart(w, r)
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
gateway.mu.Lock()
if gateway.cmd == previousCmd {
if isCmdProcessAliveLocked(previousCmd) {
setGatewayRuntimeStatusLocked("running")
} else {
gateway.cmd = nil
gateway.bootDefaultModel = ""
setGatewayRuntimeStatusLocked("error")
}
}
gateway.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
return
}
gateway.mu.Lock()
if gateway.cmd == previousCmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
}
pid, err := h.startGatewayLocked("restarting")
if err != nil {
gateway.cmd = nil
gateway.bootDefaultModel = ""
setGatewayRuntimeStatusLocked("error")
}
gateway.mu.Unlock()
if err != nil {
http.Error(w, fmt.Sprintf("Failed to restart gateway: %v", err), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]any{
"status": "ok",
"pid": pid,
})
}
// handleGatewayClearLogs clears the in-memory gateway log buffer.
@@ -374,24 +564,44 @@ func (h *Handler) handleGatewayClearLogs(w http.ResponseWriter, r *http.Request)
//
// GET /api/gateway/status
func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
data := h.gatewayStatusData(r, true)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(data)
}
func (h *Handler) gatewayStatusData(r *http.Request, includeLogs bool) map[string]any {
data := map[string]any{}
cfg, cfgErr := config.LoadConfig(h.configPath)
configDefaultModel := ""
if cfgErr == nil && cfg != nil {
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
if configDefaultModel != "" {
data["config_default_model"] = configDefaultModel
}
}
// Check process state
gateway.mu.Lock()
processAlive := isGatewayProcessAliveLocked()
bootDefaultModel := ""
if processAlive {
data["pid"] = gateway.cmd.Process.Pid
if gateway.bootDefaultModel != "" {
data["boot_default_model"] = gateway.bootDefaultModel
bootDefaultModel = gateway.bootDefaultModel
}
}
gateway.mu.Unlock()
if !processAlive {
data["gateway_status"] = "stopped"
gateway.mu.Lock()
data["gateway_status"] = currentGatewayStatusLocked(false)
gateway.mu.Unlock()
} else {
// Process is alive — probe its health endpoint
cfg, err := config.LoadConfig(h.configPath)
host := "127.0.0.1"
port := 18790
if err == nil && cfg != nil {
if cfgErr == nil && cfg != nil {
host = gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
if cfg.Gateway.Port != 0 {
port = cfg.Gateway.Port
@@ -399,21 +609,31 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
}
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, strconv.Itoa(port)))
client := http.Client{Timeout: 2 * time.Second}
resp, err := client.Get(url)
resp, err := gatewayHealthGet(url, 2*time.Second)
if err != nil {
data["gateway_status"] = "starting"
gateway.mu.Lock()
data["gateway_status"] = currentGatewayStatusLocked(true)
gateway.mu.Unlock()
} else {
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
gateway.mu.Unlock()
data["gateway_status"] = "error"
data["status_code"] = resp.StatusCode
} else {
var healthData map[string]any
if decErr := json.NewDecoder(resp.Body).Decode(&healthData); decErr != nil {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("error")
gateway.mu.Unlock()
data["gateway_status"] = "error"
} else {
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
for k, v := range healthData {
data[k] = v
}
@@ -423,6 +643,13 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
}
}
status, _ := data["gateway_status"].(string)
data["gateway_restart_required"] = gatewayRestartRequired(
status,
bootDefaultModel,
configDefaultModel,
)
ready, reason, readyErr := h.gatewayStartReady()
if readyErr != nil {
data["gateway_start_allowed"] = false
@@ -434,11 +661,11 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
}
}
// Append incremental log data
appendGatewayLogs(r, data)
if includeLogs {
appendGatewayLogs(r, data)
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(data)
return data
}
// appendGatewayLogs reads log_offset and log_run_id query params from the request
@@ -524,28 +751,7 @@ func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
// currentGatewayStatus returns the current gateway status as a JSON string.
func (h *Handler) currentGatewayStatus() string {
gateway.mu.Lock()
defer gateway.mu.Unlock()
data := map[string]any{
"gateway_status": "stopped",
}
if isGatewayProcessAliveLocked() {
data["gateway_status"] = "running"
data["pid"] = gateway.cmd.Process.Pid
}
ready, reason, readyErr := h.gatewayStartReady()
if readyErr != nil {
data["gateway_start_allowed"] = false
data["gateway_start_reason"] = readyErr.Error()
} else {
data["gateway_start_allowed"] = ready
if !ready {
data["gateway_start_reason"] = reason
}
}
data := h.gatewayStatusData(nil, false)
encoded, _ := json.Marshal(data)
return string(encoded)
}
+390
View File
@@ -2,19 +2,76 @@ package api
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/web/backend/utils"
)
func startLongRunningProcess(t *testing.T) *exec.Cmd {
t.Helper()
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.Command("powershell", "-NoProfile", "-Command", "Start-Sleep -Seconds 30")
} else {
cmd = exec.Command("sleep", "30")
}
if err := cmd.Start(); err != nil {
t.Fatalf("Start() error = %v", err)
}
return cmd
}
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("TERM handling differs on Windows")
}
cmd := exec.Command("sh", "-c", "trap '' TERM; sleep 30")
if err := cmd.Start(); err != nil {
t.Fatalf("Start() error = %v", err)
}
return cmd
}
func resetGatewayTestState(t *testing.T) {
t.Helper()
originalHealthGet := gatewayHealthGet
originalRestartGracePeriod := gatewayRestartGracePeriod
originalRestartForceKillWindow := gatewayRestartForceKillWindow
originalRestartPollInterval := gatewayRestartPollInterval
t.Cleanup(func() {
gatewayHealthGet = originalHealthGet
gatewayRestartGracePeriod = originalRestartGracePeriod
gatewayRestartForceKillWindow = originalRestartForceKillWindow
gatewayRestartPollInterval = originalRestartPollInterval
gateway.mu.Lock()
gateway.cmd = nil
gateway.bootDefaultModel = ""
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
})
}
func TestGatewayStartReady_NoDefaultModel(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
@@ -317,6 +374,339 @@ func TestGatewayStatusIncludesStartConditionWhenNotReady(t *testing.T) {
}
}
func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
gateway.cmd = cmd
gateway.bootDefaultModel = "existing-model"
// Simulate a process that has already reached the running state.
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return nil, errors.New("probe failed")
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
}
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
gateway.cmd = cmd
gateway.bootDefaultModel = "existing-model"
setGatewayRuntimeStatusLocked("starting")
gateway.startupDeadline = time.Now().Add(-time.Second)
gateway.mu.Unlock()
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return nil, errors.New("probe failed")
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "error" {
t.Fatalf("gateway_status = %#v, want %q", got, "error")
}
}
func TestGatewayStatusReturnsRestartingDuringRestartGap(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
gateway.mu.Lock()
setGatewayRuntimeStatusLocked("restarting")
gateway.mu.Unlock()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "restarting" {
t.Fatalf("gateway_status = %#v, want %q", got, "restarting")
}
}
func TestGatewayStatusIncludesRestartRequiredWhenModelsDiffer(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = "test-key"
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
gateway.cmd = cmd
gateway.bootDefaultModel = "previous-model"
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
rec := httptest.NewRecorder()
rec.WriteHeader(http.StatusOK)
_, _ = rec.WriteString(`{"ok":true}`)
return rec.Result(), nil
}
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}
func TestGatewayRestartKeepsRunningProcessWhenPreconditionsFail(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = ""
cfg.ModelList[0].AuthMethod = ""
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startLongRunningProcess(t)
t.Cleanup(func() {
gateway.mu.Lock()
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
}
gateway.mu.Unlock()
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gateway.mu.Lock()
gateway.cmd = cmd
gateway.bootDefaultModel = "existing-model"
gateway.mu.Unlock()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest)
}
gateway.mu.Lock()
stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd)
gateway.mu.Unlock()
if !stillRunning {
t.Fatalf("gateway process was stopped when restart preconditions failed")
}
}
func TestGatewayRestartKeepsOldProcessWhenItDoesNotExitInTime(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = "test-key"
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
cmd := startIgnoringTermProcess(t)
t.Cleanup(func() {
gateway.mu.Lock()
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
}
gateway.mu.Unlock()
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = cmd.Wait()
})
gatewayRestartGracePeriod = 150 * time.Millisecond
gatewayRestartForceKillWindow = 150 * time.Millisecond
gatewayRestartPollInterval = 10 * time.Millisecond
gateway.mu.Lock()
gateway.cmd = cmd
gateway.bootDefaultModel = "existing-model"
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusInternalServerError {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
}
gateway.mu.Lock()
stillRunning := gateway.cmd == cmd && isCmdProcessAliveLocked(cmd)
status := gateway.runtimeStatus
gateway.mu.Unlock()
if !stillRunning {
t.Fatalf("gateway process was replaced before the old process exited")
}
if status != "running" {
t.Fatalf("runtimeStatus = %q, want %q", status, "running")
}
}
func TestGatewayRestartReturnsErrorStatusWhenReplacementFailsToStart(t *testing.T) {
resetGatewayTestState(t)
configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].APIKey = "test-key"
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}
invalidBinaryPath := filepath.Join(t.TempDir(), "fake-picoclaw")
if err := os.WriteFile(invalidBinaryPath, []byte("#!/bin/sh\n"), 0o644); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
t.Setenv("PICOCLAW_BINARY", invalidBinaryPath)
h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/gateway/restart", nil)
mux.ServeHTTP(rec, req)
if rec.Code != http.StatusInternalServerError {
t.Fatalf("restart status = %d, want %d", rec.Code, http.StatusInternalServerError)
}
statusRec := httptest.NewRecorder()
statusReq := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(statusRec, statusReq)
if statusRec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", statusRec.Code, http.StatusOK)
}
var body map[string]any
if err := json.Unmarshal(statusRec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}
if got := body["gateway_status"]; got != "error" {
t.Fatalf("gateway_status = %#v, want %q", got, "error")
}
}
func TestGatewayClearLogsResetsBufferedHistory(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
h := NewHandler(configPath)
+4 -1
View File
@@ -1,10 +1,13 @@
// API client for gateway process management.
interface GatewayStatusResponse {
gateway_status: "running" | "starting" | "stopped" | "error"
gateway_status: "running" | "starting" | "restarting" | "stopped" | "error"
gateway_start_allowed?: boolean
gateway_start_reason?: string
gateway_restart_required?: boolean
pid?: number
boot_default_model?: string
config_default_model?: string
logs?: string[]
log_total?: number
log_run_id?: number
+1 -1
View File
@@ -84,7 +84,7 @@ export async function setDefaultModel(
body: JSON.stringify({ model_name: modelName }),
})
void refreshGatewayState()
await refreshGatewayState()
return response
}
+77 -31
View File
@@ -6,6 +6,7 @@ import {
IconMoon,
IconPlayerPlay,
IconPower,
IconRefresh,
IconSun,
} from "@tabler/icons-react"
import { Link } from "@tanstack/react-router"
@@ -31,6 +32,11 @@ import {
} from "@/components/ui/dropdown-menu.tsx"
import { Separator } from "@/components/ui/separator.tsx"
import { SidebarTrigger } from "@/components/ui/sidebar"
import {
Tooltip,
TooltipContent,
TooltipTrigger,
} from "@/components/ui/tooltip"
import { useGateway } from "@/hooks/use-gateway.ts"
import { useTheme } from "@/hooks/use-theme.ts"
@@ -41,27 +47,35 @@ export function AppHeader() {
state: gwState,
loading: gwLoading,
canStart,
restartRequired,
start,
restart,
stop,
} = useGateway()
const isRunning = gwState === "running"
const isStarting = gwState === "starting"
const isRestarting = gwState === "restarting"
const isStopped = gwState === "stopped" || gwState === "unknown"
const showNotConnectedHint =
canStart && (gwState === "stopped" || gwState === "error")
!isRestarting && canStart && (gwState === "stopped" || gwState === "error")
const [showStopDialog, setShowStopDialog] = React.useState(false)
const handleGatewayToggle = () => {
if (gwLoading || (!isRunning && !canStart)) return
if (gwLoading || isRestarting || (!isRunning && !canStart)) return
if (isRunning) {
setShowStopDialog(true)
} else {
start()
void start()
}
}
const handleGatewayRestart = () => {
if (gwLoading || isRestarting || !restartRequired || !canStart) return
void restart()
}
const confirmStop = () => {
setShowStopDialog(false)
stop()
@@ -115,35 +129,67 @@ export function AppHeader() {
</AlertDialog>
<div className="text-muted-foreground flex items-center gap-1 text-sm font-medium md:gap-2">
{restartRequired && (
<Tooltip delayDuration={700}>
<TooltipTrigger asChild>
<Button
variant="secondary"
size="icon-sm"
className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25"
onClick={handleGatewayRestart}
disabled={gwLoading || isRestarting || !canStart}
aria-label={t("header.gateway.action.restart")}
>
<IconRefresh className="size-4" />
</Button>
</TooltipTrigger>
<TooltipContent>
{t("header.gateway.restartRequired")}
</TooltipContent>
</Tooltip>
)}
{/* Gateway Start/Stop */}
<Button
variant={isStarting ? "secondary" : "default"}
size="sm"
className={`h-8 gap-2 px-3 ${
isRunning
? "bg-destructive/10 text-destructive hover:bg-destructive/20"
: isStopped
? "bg-green-500 text-white hover:bg-green-600"
: ""
}`}
onClick={handleGatewayToggle}
disabled={gwLoading || isStarting || (!isRunning && !canStart)}
>
{gwLoading || isStarting ? (
<IconLoader2 className="h-4 w-4 animate-spin opacity-70" />
) : isRunning ? (
<IconPower className="h-4 w-4 opacity-80" />
) : (
<IconPlayerPlay className="h-4 w-4 opacity-80" />
)}
<span className="text-xs font-semibold">
{isRunning
? t("header.gateway.action.stop")
: isStarting
? t("header.gateway.status.starting")
: t("header.gateway.action.start")}
</span>
</Button>
{isRunning ? (
<Tooltip delayDuration={700}>
<TooltipTrigger asChild>
<Button
variant="destructive"
size="icon-sm"
className="size-8"
onClick={handleGatewayToggle}
disabled={gwLoading}
aria-label={t("header.gateway.action.stop")}
>
<IconPower className="h-4 w-4 opacity-80" />
</Button>
</TooltipTrigger>
<TooltipContent>{t("header.gateway.action.stop")}</TooltipContent>
</Tooltip>
) : (
<Button
variant={isStarting || isRestarting ? "secondary" : "default"}
size="sm"
className={`h-8 gap-2 px-3 ${
isStopped ? "bg-green-500 text-white hover:bg-green-600" : ""
}`}
onClick={handleGatewayToggle}
disabled={gwLoading || isStarting || isRestarting || !canStart}
>
{gwLoading || isStarting || isRestarting ? (
<IconLoader2 className="h-4 w-4 animate-spin opacity-70" />
) : (
<IconPlayerPlay className="h-4 w-4 opacity-80" />
)}
<span className="text-xs font-semibold">
{isRestarting
? t("header.gateway.status.restarting")
: isStarting
? t("header.gateway.status.starting")
: t("header.gateway.action.start")}
</span>
</Button>
)}
<Separator
className="mx-4 my-2 hidden md:block"
+17 -5
View File
@@ -20,6 +20,7 @@ export function ChatPage() {
const { t } = useTranslation()
const scrollRef = useRef<HTMLDivElement>(null)
const [isAtBottom, setIsAtBottom] = useState(true)
const [hasScrolled, setHasScrolled] = useState(false)
const [input, setInput] = useState("")
const {
@@ -56,14 +57,22 @@ export function ChatPage() {
onDeletedActiveSession: newChat,
})
const handleScroll = (e: React.UIEvent<HTMLDivElement>) => {
const { scrollTop, scrollHeight, clientHeight } = e.currentTarget
const syncScrollState = (element: HTMLDivElement) => {
const { scrollTop, scrollHeight, clientHeight } = element
setHasScrolled(scrollTop > 0)
setIsAtBottom(scrollHeight - scrollTop <= clientHeight + 10)
}
const handleScroll = (e: React.UIEvent<HTMLDivElement>) => {
syncScrollState(e.currentTarget)
}
useEffect(() => {
if (isAtBottom && scrollRef.current) {
scrollRef.current.scrollTop = scrollRef.current.scrollHeight
if (scrollRef.current) {
if (isAtBottom) {
scrollRef.current.scrollTop = scrollRef.current.scrollHeight
}
syncScrollState(scrollRef.current)
}
}, [messages, isTyping, isAtBottom])
@@ -77,6 +86,9 @@ export function ChatPage() {
<div className="bg-background/95 flex h-full flex-col">
<PageHeader
title={t("navigation.chat")}
className={`transition-shadow ${
hasScrolled ? "shadow-sm" : "shadow-none"
}`}
titleExtra={
hasConfiguredModels && (
<ModelSelector
@@ -90,7 +102,7 @@ export function ChatPage() {
}
>
<Button
variant="outline"
variant="secondary"
size="sm"
onClick={newChat}
className="h-9 gap-2"
@@ -37,7 +37,7 @@ export function ModelSelector({
>
<SelectValue placeholder={t("chat.noModel")} />
</SelectTrigger>
<SelectContent>
<SelectContent position="popper" align="start">
{apiKeyModels.length > 0 && (
<SelectGroup>
<SelectLabel>{t("chat.modelGroup.apikey")}</SelectLabel>
@@ -41,7 +41,7 @@ export function SessionHistoryMenu({
return (
<DropdownMenu onOpenChange={onOpenChange}>
<DropdownMenuTrigger asChild>
<Button variant="outline" size="sm" className="h-9 gap-2">
<Button variant="secondary" size="sm" className="h-9 gap-2">
<IconHistory className="size-4" />
<span className="hidden sm:inline">{t("chat.history")}</span>
</Button>
@@ -110,7 +110,7 @@ export function EditModelSheet({
: undefined,
thinking_level: form.thinkingLevel || undefined,
})
if (setAsDefault) {
if (setAsDefault && !model.is_default) {
await setDefaultModel(model.model_name)
}
onSaved()
@@ -79,6 +79,8 @@ export function ModelsPage() {
}, [fetchModels])
const handleSetDefault = async (model: ModelInfo) => {
if (model.is_default) return
setSettingDefaultIndex(model.index)
try {
await setDefaultModel(model.model_name)
+14 -2
View File
@@ -2,16 +2,28 @@ import { IconMenu2 } from "@tabler/icons-react"
import type { ReactNode } from "react"
import { SidebarTrigger } from "@/components/ui/sidebar"
import { cn } from "@/lib/utils"
interface PageHeaderProps {
title: string
titleExtra?: ReactNode
children?: ReactNode
className?: string
}
export function PageHeader({ title, titleExtra, children }: PageHeaderProps) {
export function PageHeader({
title,
titleExtra,
children,
className,
}: PageHeaderProps) {
return (
<div className="flex h-14 shrink-0 items-center justify-between px-6 pt-2">
<div
className={cn(
"z-40 flex h-14 shrink-0 items-center justify-between px-6 pt-2",
className,
)}
>
<div className="flex items-center gap-4">
<SidebarTrigger className="border-border/60 bg-background text-muted-foreground hover:bg-accent hover:text-foreground hidden h-9 w-9 rounded-lg border sm:flex [&>svg]:size-5">
<IconMenu2 />
+24 -12
View File
@@ -1,4 +1,4 @@
import { useCallback, useEffect, useMemo, useState } from "react"
import { useCallback, useEffect, useMemo, useRef, useState } from "react"
import { type ModelInfo, getModels, setDefaultModel } from "@/api/models"
@@ -20,6 +20,7 @@ function isLocalModel(model: ModelInfo): boolean {
export function useChatModels({ isConnected }: UseChatModelsOptions) {
const [modelList, setModelList] = useState<ModelInfo[]>([])
const [defaultModelName, setDefaultModelName] = useState("")
const setDefaultRequestIdRef = useRef(0)
const loadModels = useCallback(async () => {
try {
@@ -41,17 +42,28 @@ export function useChatModels({ isConnected }: UseChatModelsOptions) {
return () => clearTimeout(timerId)
}, [isConnected, loadModels])
const handleSetDefault = useCallback(async (modelName: string) => {
try {
await setDefaultModel(modelName)
setDefaultModelName(modelName)
setModelList((prev) =>
prev.map((m) => ({ ...m, is_default: m.model_name === modelName })),
)
} catch (err) {
console.error("Failed to set default model:", err)
}
}, [])
const handleSetDefault = useCallback(
async (modelName: string) => {
if (modelName === defaultModelName) return
const requestId = ++setDefaultRequestIdRef.current
try {
await setDefaultModel(modelName)
const data = await getModels()
if (requestId !== setDefaultRequestIdRef.current) {
return
}
setModelList(data.models)
if (data.models.some((m) => m.model_name === data.default_model)) {
setDefaultModelName(data.default_model)
}
} catch (err) {
console.error("Failed to set default model:", err)
}
},
[defaultModelName],
)
const hasConfiguredModels = useMemo(
() => modelList.some((m) => m.configured),
+1 -1
View File
@@ -37,7 +37,7 @@ export function useGatewayLogs() {
const fetchLogs = async () => {
if (
!mounted ||
(gateway.status !== "running" && gateway.status !== "starting")
!["running", "starting", "restarting"].includes(gateway.status)
) {
if (mounted) {
timeout = setTimeout(fetchLogs, 1000)
+55 -28
View File
@@ -1,31 +1,30 @@
import { useAtom } from "jotai"
import { useAtomValue } from "jotai"
import { useCallback, useEffect, useState } from "react"
import {
type GatewayStatusResponse,
getGatewayStatus,
restartGateway,
startGateway,
stopGateway,
} from "@/api/gateway"
import { gatewayAtom } from "@/store"
import {
applyGatewayStatusToStore,
gatewayAtom,
updateGatewayStore,
} from "@/store"
// Global variable to ensure we only have one SSE connection
let sseInitialized = false
export function useGateway() {
const [{ status: state, canStart }, setGateway] = useAtom(gatewayAtom)
const gateway = useAtomValue(gatewayAtom)
const { status: state, canStart, restartRequired } = gateway
const [loading, setLoading] = useState(false)
const applyGatewayStatus = useCallback(
(data: GatewayStatusResponse) => {
setGateway((prev) => ({
...prev,
status: data.gateway_status ?? "unknown",
canStart: data.gateway_start_allowed ?? true,
}))
},
[setGateway],
)
const applyGatewayStatus = useCallback((data: GatewayStatusResponse) => {
applyGatewayStatusToStore(data)
}, [])
// Initialize global SSE connection once
useEffect(() => {
@@ -35,9 +34,10 @@ export function useGateway() {
getGatewayStatus()
.then((data) => applyGatewayStatus(data))
.catch(() => {
setGateway({
updateGatewayStore({
status: "unknown",
canStart: true,
restartRequired: false,
})
})
@@ -59,14 +59,7 @@ export function useGateway() {
data.gateway_status ||
typeof data.gateway_start_allowed === "boolean"
) {
setGateway((prev) => ({
...prev,
status: data.gateway_status ?? prev.status,
canStart:
typeof data.gateway_start_allowed === "boolean"
? data.gateway_start_allowed
: prev.canStart,
}))
applyGatewayStatus(data)
}
} catch {
// ignore
@@ -75,7 +68,9 @@ export function useGateway() {
es.onerror = () => {
// EventSource will auto-reconnect
setGateway((prev) => ({ ...prev, status: "unknown" }))
updateGatewayStore((prev) =>
prev.status === "restarting" ? {} : { status: "unknown" },
)
}
return () => {
@@ -83,7 +78,7 @@ export function useGateway() {
es.close()
sseInitialized = false
}
}, [applyGatewayStatus, setGateway])
}, [applyGatewayStatus])
const start = useCallback(async () => {
if (!canStart) return
@@ -92,19 +87,19 @@ export function useGateway() {
try {
await startGateway()
// SSE will push the real state changes, but set optimistic state
setGateway((prev) => ({ ...prev, status: "starting" }))
updateGatewayStore({ status: "starting" })
} catch (err) {
console.error("Failed to start gateway:", err)
try {
const status = await getGatewayStatus()
applyGatewayStatus(status)
} catch {
setGateway((prev) => ({ ...prev, status: "unknown" }))
updateGatewayStore({ status: "unknown" })
}
} finally {
setLoading(false)
}
}, [applyGatewayStatus, canStart, setGateway])
}, [applyGatewayStatus, canStart])
const stop = useCallback(async () => {
setLoading(true)
@@ -117,5 +112,37 @@ export function useGateway() {
}
}, [])
return { state, loading, canStart, start, stop }
const restart = useCallback(async () => {
if (state !== "running") return
const previousState = state
const previousCanStart = canStart
const previousRestartRequired = restartRequired
setLoading(true)
updateGatewayStore({
status: "restarting",
restartRequired: false,
})
try {
await restartGateway()
} catch (err) {
console.error("Failed to restart gateway:", err)
try {
const status = await getGatewayStatus()
applyGatewayStatus(status)
} catch {
updateGatewayStore({
status: previousState,
canStart: previousCanStart,
restartRequired: previousRestartRequired,
})
}
} finally {
setLoading(false)
}
}, [applyGatewayStatus, canStart, restartRequired, state])
return { state, loading, canStart, restartRequired, start, stop, restart }
}
+97 -83
View File
@@ -130,8 +130,9 @@ export function usePicoChat() {
const [connectionState, setConnectionState] =
useState<ConnectionState>("disconnected")
const [isTyping, setIsTyping] = useState(false)
const [activeSessionId, setActiveSessionId] =
useState<string>(() => readStoredSessionId() || generateSessionId())
const [activeSessionId, setActiveSessionId] = useState<string>(
() => readStoredSessionId() || generateSessionId(),
)
const wsRef = useRef<WebSocket | null>(null)
const isConnectingRef = useRef(false)
@@ -144,9 +145,7 @@ export function usePicoChat() {
setMessages((prev) => {
const next =
typeof nextState === "function"
? (
nextState as (prevState: ChatMessage[]) => ChatMessage[]
)(prev)
? (nextState as (prevState: ChatMessage[]) => ChatMessage[])(prev)
: nextState
if (next !== prev) {
@@ -220,64 +219,69 @@ export function usePicoChat() {
}
}, [loadSessionMessages, setTrackedMessages])
const handlePicoMessage = useCallback((msg: PicoMessage) => {
const payload = msg.payload || {}
const handlePicoMessage = useCallback(
(msg: PicoMessage) => {
const payload = msg.payload || {}
switch (msg.type) {
case "message.create": {
const content = (payload.content as string) || ""
const messageId = (payload.message_id as string) || `pico-${Date.now()}`
// Use provided timestamp or current time
const timestampRaw =
msg.timestamp !== undefined && Number.isFinite(Number(msg.timestamp))
? normalizeUnixTimestamp(Number(msg.timestamp))
: Date.now()
switch (msg.type) {
case "message.create": {
const content = (payload.content as string) || ""
const messageId =
(payload.message_id as string) || `pico-${Date.now()}`
// Use provided timestamp or current time
const timestampRaw =
msg.timestamp !== undefined &&
Number.isFinite(Number(msg.timestamp))
? normalizeUnixTimestamp(Number(msg.timestamp))
: Date.now()
setTrackedMessages((prev) => [
...prev,
{
id: messageId,
role: "assistant",
content,
timestamp: timestampRaw,
},
])
setIsTyping(false)
break
setTrackedMessages((prev) => [
...prev,
{
id: messageId,
role: "assistant",
content,
timestamp: timestampRaw,
},
])
setIsTyping(false)
break
}
case "message.update": {
const content = (payload.content as string) || ""
const messageId = payload.message_id as string
if (!messageId) break
setTrackedMessages((prev) =>
prev.map((m) => (m.id === messageId ? { ...m, content } : m)),
)
break
}
case "typing.start":
setIsTyping(true)
break
case "typing.stop":
setIsTyping(false)
break
case "error":
console.error("Pico error:", payload)
setIsTyping(false)
break
case "pong":
// heartbeat response, ignore
break
default:
console.log("Unknown pico message type:", msg.type)
}
case "message.update": {
const content = (payload.content as string) || ""
const messageId = payload.message_id as string
if (!messageId) break
setTrackedMessages((prev) =>
prev.map((m) => (m.id === messageId ? { ...m, content } : m)),
)
break
}
case "typing.start":
setIsTyping(true)
break
case "typing.stop":
setIsTyping(false)
break
case "error":
console.error("Pico error:", payload)
setIsTyping(false)
break
case "pong":
// heartbeat response, ignore
break
default:
console.log("Unknown pico message type:", msg.type)
}
}, [setTrackedMessages])
},
[setTrackedMessages],
)
const connect = useCallback(async () => {
if (
@@ -389,32 +393,35 @@ export function usePicoChat() {
return () => disconnect()
}, [disconnect])
const sendMessage = useCallback((content: string) => {
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
console.warn("WebSocket not connected")
return
}
const sendMessage = useCallback(
(content: string) => {
if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
console.warn("WebSocket not connected")
return
}
const id = `msg-${++msgIdCounter.current}-${Date.now()}`
const timestampRaw = Date.now()
const id = `msg-${++msgIdCounter.current}-${Date.now()}`
const timestampRaw = Date.now()
// Add user message to local state
setTrackedMessages((prev) => [
...prev,
{ id, role: "user", content, timestamp: timestampRaw },
])
// Add user message to local state
setTrackedMessages((prev) => [
...prev,
{ id, role: "user", content, timestamp: timestampRaw },
])
// Show typing indicator immediately
setIsTyping(true)
// Show typing indicator immediately
setIsTyping(true)
// Send via Pico Protocol
const picoMsg: PicoMessage = {
type: "message.send",
id,
payload: { content },
}
wsRef.current.send(JSON.stringify(picoMsg))
}, [setTrackedMessages])
// Send via Pico Protocol
const picoMsg: PicoMessage = {
type: "message.send",
id,
payload: { content },
}
wsRef.current.send(JSON.stringify(picoMsg))
},
[setTrackedMessages],
)
// Switch to a historical session
const switchSession = useCallback(
@@ -443,7 +450,14 @@ export function usePicoChat() {
}
}, 100)
},
[connect, disconnect, gatewayState, loadSessionMessages, setTrackedMessages, t],
[
connect,
disconnect,
gatewayState,
loadSessionMessages,
setTrackedMessages,
t,
],
)
// Start a new empty chat
+6 -3
View File
@@ -58,11 +58,14 @@
},
"action": {
"start": "Start Gateway",
"stop": "Stop Gateway"
"stop": "Stop Gateway",
"restart": "Restart Gateway"
},
"status": {
"starting": "Starting Gateway..."
}
"starting": "Starting Gateway...",
"restarting": "Restarting Gateway..."
},
"restartRequired": "Model changes require a gateway restart to take effect."
}
},
"common": {
+6 -3
View File
@@ -58,11 +58,14 @@
},
"action": {
"start": "启动服务",
"stop": "停止服务"
"stop": "停止服务",
"restart": "重启服务"
},
"status": {
"starting": "服务启动中..."
}
"starting": "服务启动中...",
"restarting": "服务重启中..."
},
"restartRequired": "切换默认模型后需要重启服务才能生效。"
}
},
"common": {
+46 -10
View File
@@ -5,6 +5,7 @@ import { type GatewayStatusResponse, getGatewayStatus } from "@/api/gateway"
export type GatewayState =
| "running"
| "starting"
| "restarting"
| "stopped"
| "error"
| "unknown"
@@ -12,19 +13,54 @@ export type GatewayState =
export interface GatewayStoreState {
status: GatewayState
canStart: boolean
restartRequired: boolean
}
type GatewayStorePatch = Partial<GatewayStoreState>
const DEFAULT_GATEWAY_STATE: GatewayStoreState = {
status: "unknown",
canStart: true,
restartRequired: false,
}
// Global atom for gateway state
export const gatewayAtom = atom<GatewayStoreState>({
status: "unknown",
canStart: true,
})
export const gatewayAtom = atom<GatewayStoreState>(DEFAULT_GATEWAY_STATE)
function applyGatewayStatusToStore(data: GatewayStatusResponse) {
getDefaultStore().set(gatewayAtom, (prev) => ({
...prev,
status: data.gateway_status ?? "unknown",
canStart: data.gateway_start_allowed ?? true,
function normalizeGatewayStoreState(
prev: GatewayStoreState,
patch: GatewayStorePatch,
) {
return { ...prev, ...patch }
}
export function updateGatewayStore(
patch:
| GatewayStorePatch
| ((prev: GatewayStoreState) => GatewayStorePatch | GatewayStoreState),
) {
getDefaultStore().set(gatewayAtom, (prev) => {
const nextPatch = typeof patch === "function" ? patch(prev) : patch
return normalizeGatewayStoreState(prev, nextPatch)
})
}
export function applyGatewayStatusToStore(
data: Partial<
Pick<
GatewayStatusResponse,
"gateway_status" | "gateway_start_allowed" | "gateway_restart_required"
>
>,
) {
updateGatewayStore((prev) => ({
status: data.gateway_status ?? prev.status,
canStart: data.gateway_start_allowed ?? prev.canStart,
restartRequired:
data.gateway_restart_required ??
(data.gateway_status && data.gateway_status !== "running"
? false
: prev.restartRequired),
}))
}
@@ -33,6 +69,6 @@ export async function refreshGatewayState() {
const status = await getGatewayStatus()
applyGatewayStatusToStore(status)
} catch {
// Best-effort refresh only; keep current state on error.
updateGatewayStore(DEFAULT_GATEWAY_STATE)
}
}