mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat: add web gateway hot reload and polling state sync (#1684)
* feat(gateway): support hot reload and empty startup - extract gateway runtime into pkg/gateway - add gateway.hot_reload config with default and example values - allow starting the gateway without a default model via --allow-empty - stop treating missing enabled channels as a startup error - update related tests * feat: replace gateway SSE updates with polling-based state sync - remove gateway SSE broadcasting and event endpoint - add polling-based gateway status refresh with stopping state handling - detect when gateway restart is required after default model changes - resolve gateway health and websocket proxy targets from configured host - update gateway UI labels and add backend/frontend test coverage
This commit is contained in:
@@ -5,6 +5,8 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/gateway"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
@@ -12,6 +14,7 @@ import (
|
||||
func NewGatewayCommand() *cobra.Command {
|
||||
var debug bool
|
||||
var noTruncate bool
|
||||
var allowEmpty bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "gateway",
|
||||
@@ -31,12 +34,19 @@ func NewGatewayCommand() *cobra.Command {
|
||||
return nil
|
||||
},
|
||||
RunE: func(_ *cobra.Command, _ []string) error {
|
||||
return gatewayCmd(debug)
|
||||
return gateway.Run(debug, internal.GetConfigPath(), allowEmpty)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&debug, "debug", "d", false, "Enable debug logging")
|
||||
cmd.Flags().BoolVarP(&noTruncate, "no-truncate", "T", false, "Disable string truncation in debug logs")
|
||||
cmd.Flags().BoolVarP(
|
||||
&allowEmpty,
|
||||
"allow-empty",
|
||||
"E",
|
||||
false,
|
||||
"Continue starting even when no default model is configured",
|
||||
)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -28,4 +28,5 @@ func TestNewGatewayCommand(t *testing.T) {
|
||||
|
||||
assert.True(t, cmd.HasFlags())
|
||||
assert.NotNil(t, cmd.Flags().Lookup("debug"))
|
||||
assert.NotNil(t, cmd.Flags().Lookup("allow-empty"))
|
||||
}
|
||||
|
||||
@@ -518,6 +518,7 @@
|
||||
},
|
||||
"gateway": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 18790
|
||||
"port": 18790,
|
||||
"hot_reload": false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,7 +357,6 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
|
||||
if len(m.channels) == 0 {
|
||||
logger.WarnC("channels", "No channels enabled")
|
||||
return errors.New("no channels enabled")
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "Starting all channels")
|
||||
|
||||
@@ -625,8 +625,9 @@ func (c *ModelConfig) Validate() error {
|
||||
}
|
||||
|
||||
type GatewayConfig struct {
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"`
|
||||
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
|
||||
HotReload bool `json:"hot_reload" env:"PICOCLAW_GATEWAY_HOT_RELOAD"`
|
||||
}
|
||||
|
||||
type ToolDiscoveryConfig struct {
|
||||
|
||||
@@ -267,6 +267,9 @@ func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
if cfg.Gateway.HotReload {
|
||||
t.Error("Gateway hot reload should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Providers verifies provider structure
|
||||
|
||||
@@ -395,8 +395,9 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
Host: "127.0.0.1",
|
||||
Port: 18790,
|
||||
HotReload: false,
|
||||
},
|
||||
Tools: ToolsConfig{
|
||||
MediaCleanup: MediaCleanupConfig{
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
@@ -42,15 +41,13 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
// Timeout constants for service operations
|
||||
const (
|
||||
serviceShutdownTimeout = 30 * time.Second
|
||||
providerReloadTimeout = 30 * time.Second
|
||||
gracefulShutdownTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
// gatewayServices holds references to all running services
|
||||
type gatewayServices struct {
|
||||
type services struct {
|
||||
CronService *cron.CronService
|
||||
HeartbeatService *heartbeat.HeartbeatService
|
||||
MediaStore media.MediaStore
|
||||
@@ -59,24 +56,41 @@ type gatewayServices struct {
|
||||
HealthServer *health.Server
|
||||
}
|
||||
|
||||
func gatewayCmd(debug bool) error {
|
||||
type startupBlockedProvider struct {
|
||||
reason string
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) Chat(
|
||||
_ context.Context,
|
||||
_ []providers.Message,
|
||||
_ []providers.ToolDefinition,
|
||||
_ string,
|
||||
_ map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
return nil, fmt.Errorf("%s", p.reason)
|
||||
}
|
||||
|
||||
func (p *startupBlockedProvider) GetDefaultModel() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Run starts the gateway runtime using the configuration loaded from configPath.
|
||||
func Run(debug bool, configPath string, allowEmptyStartup bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
configPath := internal.GetConfigPath()
|
||||
cfg, err := internal.LoadConfig()
|
||||
cfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
}
|
||||
|
||||
provider, modelID, err := providers.CreateProvider(cfg)
|
||||
provider, modelID, err := createStartupProvider(cfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating provider: %w", err)
|
||||
}
|
||||
|
||||
// Use the resolved model ID from provider creation
|
||||
if modelID != "" {
|
||||
cfg.Agents.Defaults.ModelName = modelID
|
||||
}
|
||||
@@ -84,17 +98,13 @@ func gatewayCmd(debug bool) error {
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Print agent startup info
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
toolsInfo := startupInfo["tools"].(map[string]any)
|
||||
skillsInfo := startupInfo["skills"].(map[string]any)
|
||||
fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n",
|
||||
skillsInfo["available"],
|
||||
skillsInfo["total"])
|
||||
fmt.Printf(" • Skills: %d/%d available\n", skillsInfo["available"], skillsInfo["total"])
|
||||
|
||||
// Log to file as well
|
||||
logger.InfoCF("agent", "Agent initialized",
|
||||
map[string]any{
|
||||
"tools_count": toolsInfo["count"],
|
||||
@@ -102,8 +112,7 @@ func gatewayCmd(debug bool) error {
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
// Setup and start all services
|
||||
services, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
runningServices, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -116,23 +125,25 @@ func gatewayCmd(debug bool) error {
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
// Setup config file watcher for hot reload
|
||||
configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
|
||||
var configReloadChan <-chan *config.Config
|
||||
stopWatch := func() {}
|
||||
if cfg.Gateway.HotReload {
|
||||
configReloadChan, stopWatch = setupConfigWatcherPolling(configPath, debug)
|
||||
logger.Info("Config hot reload enabled")
|
||||
}
|
||||
defer stopWatch()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
// Main event loop - wait for signals or config changes
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
shutdownGateway(services, agentLoop, provider, true)
|
||||
shutdownGateway(runningServices, agentLoop, provider, true)
|
||||
return nil
|
||||
|
||||
case newCfg := <-configReloadChan:
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus)
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf("Config reload failed: %v", err)
|
||||
}
|
||||
@@ -140,18 +151,33 @@ func gatewayCmd(debug bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
// setupAndStartServices initializes and starts all services
|
||||
func createStartupProvider(
|
||||
cfg *config.Config,
|
||||
allowEmptyStartup bool,
|
||||
) (providers.LLMProvider, string, error) {
|
||||
modelName := cfg.Agents.Defaults.GetModelName()
|
||||
if modelName == "" && allowEmptyStartup {
|
||||
reason := "no default model configured; gateway started in limited mode"
|
||||
fmt.Printf("⚠ Warning: %s\n", reason)
|
||||
logger.WarnCF("gateway", "Gateway started without default model", map[string]any{
|
||||
"limited_mode": true,
|
||||
})
|
||||
return &startupBlockedProvider{reason: reason}, "", nil
|
||||
}
|
||||
|
||||
return providers.CreateProvider(cfg)
|
||||
}
|
||||
|
||||
func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
) (*gatewayServices, error) {
|
||||
services := &gatewayServices{}
|
||||
) (*services, error) {
|
||||
runningServices := &services{}
|
||||
|
||||
// Setup cron tool and service
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
var err error
|
||||
services.CronService, err = setupCronTool(
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
agentLoop,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
@@ -162,120 +188,105 @@ func setupAndStartServices(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error setting up cron service: %w", err)
|
||||
}
|
||||
if err = services.CronService.Start(); err != nil {
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting cron service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
// Setup heartbeat service
|
||||
services.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
services.HeartbeatService.SetBus(msgBus)
|
||||
services.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
|
||||
if err = services.HeartbeatService.Start(); err != nil {
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(agentLoop))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
// Create media store for file lifecycle management with TTL cleanup
|
||||
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
// Start the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
// Create channel manager
|
||||
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return nil, fmt.Errorf("error creating channel manager: %w", err)
|
||||
}
|
||||
|
||||
// Inject channel manager and media store into agent loop
|
||||
agentLoop.SetChannelManager(services.ChannelManager)
|
||||
agentLoop.SetMediaStore(services.MediaStore)
|
||||
agentLoop.SetChannelManager(runningServices.ChannelManager)
|
||||
agentLoop.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
// Wire up voice transcription if a supported provider is configured.
|
||||
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
|
||||
agentLoop.SetTranscriber(transcriber)
|
||||
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
}
|
||||
|
||||
enabledChannels := services.ChannelManager.GetEnabledChannels()
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
// Setup shared HTTP server with health endpoints and webhook handlers
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = services.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
// Setup state manager and device service
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
services.DeviceService = devices.NewService(devices.Config{
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
services.DeviceService.SetBus(msgBus)
|
||||
if err = services.DeviceService.Start(context.Background()); err != nil {
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err = runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
return services, nil
|
||||
return runningServices, nil
|
||||
}
|
||||
|
||||
// stopAndCleanupServices stops all services and cleans up resources
|
||||
func stopAndCleanupServices(
|
||||
services *gatewayServices,
|
||||
shutdownTimeout time.Duration,
|
||||
) {
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if services.ChannelManager != nil {
|
||||
services.ChannelManager.StopAll(shutdownCtx)
|
||||
if runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if services.DeviceService != nil {
|
||||
services.DeviceService.Stop()
|
||||
if runningServices.DeviceService != nil {
|
||||
runningServices.DeviceService.Stop()
|
||||
}
|
||||
if services.HeartbeatService != nil {
|
||||
services.HeartbeatService.Stop()
|
||||
if runningServices.HeartbeatService != nil {
|
||||
runningServices.HeartbeatService.Stop()
|
||||
}
|
||||
if services.CronService != nil {
|
||||
services.CronService.Stop()
|
||||
if runningServices.CronService != nil {
|
||||
runningServices.CronService.Stop()
|
||||
}
|
||||
if services.MediaStore != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if runningServices.MediaStore != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownGateway performs a complete gateway shutdown
|
||||
func shutdownGateway(
|
||||
services *gatewayServices,
|
||||
runningServices *services,
|
||||
agentLoop *agent.AgentLoop,
|
||||
provider providers.LLMProvider,
|
||||
fullShutdown bool,
|
||||
@@ -284,7 +295,7 @@ func shutdownGateway(
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(services, gracefulShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
@@ -292,15 +303,14 @@ func shutdownGateway(
|
||||
logger.Info("✓ Gateway stopped")
|
||||
}
|
||||
|
||||
// handleConfigReload handles config file reload by stopping all services,
|
||||
// reloading the provider and config, and restarting services with the new config.
|
||||
func handleConfigReload(
|
||||
ctx context.Context,
|
||||
al *agent.AgentLoop,
|
||||
newCfg *config.Config,
|
||||
providerRef *providers.LLMProvider,
|
||||
services *gatewayServices,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
allowEmptyStartup bool,
|
||||
) error {
|
||||
logger.Info("🔄 Config file changed, reloading...")
|
||||
|
||||
@@ -311,18 +321,14 @@ func handleConfigReload(
|
||||
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
// Stop all services before reloading
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(services, serviceShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
|
||||
|
||||
// Create new provider from updated config first to ensure validity
|
||||
// This will use the correct API key and settings from newCfg.ModelList
|
||||
newProvider, newModelID, err := providers.CreateProvider(newCfg)
|
||||
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
logger.Errorf(" ⚠ Error creating new provider: %v", err)
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
// Try to restart services with old configuration
|
||||
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error creating new provider: %w", err)
|
||||
@@ -332,31 +338,25 @@ func handleConfigReload(
|
||||
newCfg.Agents.Defaults.ModelName = newModelID
|
||||
}
|
||||
|
||||
// Use the atomic reload method on AgentLoop to safely swap provider and config.
|
||||
// This handles locking internally to prevent races with in-flight LLM calls
|
||||
// and concurrent reads of registry/config while the swap occurs.
|
||||
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
|
||||
defer reloadCancel()
|
||||
|
||||
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
|
||||
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
|
||||
// Close the newly created provider since it wasn't adopted
|
||||
if cp, ok := newProvider.(providers.StatefulProvider); ok {
|
||||
cp.Close()
|
||||
}
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
|
||||
if restartErr := restartServices(al, runningServices, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error reloading agent loop: %w", err)
|
||||
}
|
||||
|
||||
// Update local provider reference only after successful atomic reload
|
||||
*providerRef = newProvider
|
||||
|
||||
// Restart all services with new config
|
||||
logger.Info(" Restarting all services with new configuration...")
|
||||
if err := restartServices(al, services, msgBus); err != nil {
|
||||
if err := restartServices(al, runningServices, msgBus); err != nil {
|
||||
logger.Errorf(" ⚠ Error restarting services: %v", err)
|
||||
return fmt.Errorf("error restarting services: %w", err)
|
||||
}
|
||||
@@ -365,19 +365,16 @@ func handleConfigReload(
|
||||
return nil
|
||||
}
|
||||
|
||||
// restartServices restarts all services after a config reload
|
||||
func restartServices(
|
||||
al *agent.AgentLoop,
|
||||
services *gatewayServices,
|
||||
runningServices *services,
|
||||
msgBus *bus.MessageBus,
|
||||
) error {
|
||||
// Get current config from agent loop (which has been updated if this is a reload)
|
||||
cfg := al.GetConfig()
|
||||
|
||||
// Re-create and start cron service with new config
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
var err error
|
||||
services.CronService, err = setupCronTool(
|
||||
runningServices.CronService, err = setupCronTool(
|
||||
al,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
@@ -388,57 +385,51 @@ func restartServices(
|
||||
if err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
if err = services.CronService.Start(); err != nil {
|
||||
if err = runningServices.CronService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Cron service restarted")
|
||||
|
||||
// Re-create and start heartbeat service with new config
|
||||
services.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
runningServices.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
services.HeartbeatService.SetBus(msgBus)
|
||||
services.HeartbeatService.SetHandler(createHeartbeatHandler(al))
|
||||
if err = services.HeartbeatService.Start(); err != nil {
|
||||
runningServices.HeartbeatService.SetBus(msgBus)
|
||||
runningServices.HeartbeatService.SetHandler(createHeartbeatHandler(al))
|
||||
if err = runningServices.HeartbeatService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Heartbeat service restarted")
|
||||
|
||||
// Re-create media store with new config
|
||||
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
runningServices.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
// Start the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
al.SetMediaStore(services.MediaStore)
|
||||
al.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
// Re-create channel manager with new config
|
||||
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error recreating channel manager: %w", err)
|
||||
}
|
||||
al.SetChannelManager(services.ChannelManager)
|
||||
al.SetChannelManager(runningServices.ChannelManager)
|
||||
|
||||
enabledChannels := services.ChannelManager.GetEnabledChannels()
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println(" ⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
// Setup HTTP server with new config
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
// Use background context for lifecycle to ensure services persist after restartServices returns
|
||||
if err = services.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
}
|
||||
fmt.Printf(
|
||||
@@ -447,22 +438,20 @@ func restartServices(
|
||||
cfg.Gateway.Port,
|
||||
)
|
||||
|
||||
// Re-create device service with new config
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
services.DeviceService = devices.NewService(devices.Config{
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
services.DeviceService.SetBus(msgBus)
|
||||
if err := services.DeviceService.Start(context.Background()); err != nil {
|
||||
runningServices.DeviceService.SetBus(msgBus)
|
||||
if err := runningServices.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println(" ✓ Device event service restarted")
|
||||
}
|
||||
|
||||
// Wire up voice transcription with new config
|
||||
transcriber := voice.DetectTranscriber(cfg)
|
||||
al.SetTranscriber(transcriber) // This will set it to nil if disabled
|
||||
al.SetTranscriber(transcriber)
|
||||
if transcriber != nil {
|
||||
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
} else {
|
||||
@@ -472,8 +461,6 @@ func restartServices(
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupConfigWatcherPolling sets up a simple polling-based config file watcher
|
||||
// Returns a channel for config updates and a stop function
|
||||
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
|
||||
configChan := make(chan *config.Config, 1)
|
||||
stop := make(chan struct{})
|
||||
@@ -483,11 +470,10 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Get initial file info
|
||||
lastModTime := getFileModTime(configPath)
|
||||
lastSize := getFileSize(configPath)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
@@ -496,20 +482,16 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
currentModTime := getFileModTime(configPath)
|
||||
currentSize := getFileSize(configPath)
|
||||
|
||||
// Check if file changed (modification time or size changed)
|
||||
if currentModTime.After(lastModTime) || currentSize != lastSize {
|
||||
if debug {
|
||||
logger.Debugf("🔍 Config file change detected")
|
||||
}
|
||||
|
||||
// Debounce - wait a bit to ensure file write is complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Update last known state to prevent repeated reload attempts on failure
|
||||
lastModTime = currentModTime
|
||||
lastSize = currentSize
|
||||
|
||||
// Validate and load new config
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
logger.Errorf("⚠ Error loading new config: %v", err)
|
||||
@@ -517,7 +499,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate the new config
|
||||
if err := newCfg.ValidateModelList(); err != nil {
|
||||
logger.Errorf(" ⚠ New config validation failed: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
@@ -526,15 +507,12 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
|
||||
logger.Info("✓ Config file validated and loaded")
|
||||
|
||||
// Send new config to main loop (non-blocking)
|
||||
select {
|
||||
case configChan <- newCfg:
|
||||
default:
|
||||
// Channel full, skip this update
|
||||
logger.Warn("⚠ Previous config reload still in progress, skipping")
|
||||
}
|
||||
}
|
||||
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
@@ -549,7 +527,6 @@ func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Conf
|
||||
return configChan, stopFunc
|
||||
}
|
||||
|
||||
// getFileModTime returns the modification time of a file, or zero time if file doesn't exist
|
||||
func getFileModTime(path string) time.Time {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
@@ -558,7 +535,6 @@ func getFileModTime(path string) time.Time {
|
||||
return info.ModTime()
|
||||
}
|
||||
|
||||
// getFileSize returns the size of a file, or 0 if file doesn't exist
|
||||
func getFileSize(path string) int64 {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
@@ -577,10 +553,8 @@ func setupCronTool(
|
||||
) (*cron.CronService, error) {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool if enabled
|
||||
var cronTool *tools.CronTool
|
||||
if cfg.Tools.IsToolEnabled("cron") {
|
||||
var err error
|
||||
@@ -592,7 +566,6 @@ func setupCronTool(
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
}
|
||||
|
||||
// Set onJob handler
|
||||
if cronTool != nil {
|
||||
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||
result := cronTool.ExecuteJob(context.Background(), job)
|
||||
@@ -605,22 +578,17 @@ func setupCronTool(
|
||||
|
||||
func createHeartbeatHandler(agentLoop *agent.AgentLoop) func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
// Use cli:direct as fallback if no valid channel
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
// Use ProcessHeartbeat - no session history, each heartbeat is independent
|
||||
var response string
|
||||
var err error
|
||||
response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
// For heartbeat, always return silent - the subagent result will be
|
||||
// sent to user via processSystemMessage when the async task completes
|
||||
return tools.SilentResult(response)
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// GatewayEvent represents a state change event for the gateway process.
|
||||
type GatewayEvent struct {
|
||||
Status string `json:"gateway_status"` // "running", "starting", "restarting", "stopped", "error"
|
||||
PID int `json:"pid,omitempty"`
|
||||
BootDefaultModel string `json:"boot_default_model,omitempty"`
|
||||
ConfigDefaultModel string `json:"config_default_model,omitempty"`
|
||||
RestartRequired bool `json:"gateway_restart_required,omitempty"`
|
||||
}
|
||||
|
||||
// EventBroadcaster manages SSE client subscriptions and broadcasts events.
|
||||
type EventBroadcaster struct {
|
||||
mu sync.RWMutex
|
||||
clients map[chan string]struct{}
|
||||
}
|
||||
|
||||
// NewEventBroadcaster creates a new broadcaster.
|
||||
func NewEventBroadcaster() *EventBroadcaster {
|
||||
return &EventBroadcaster{
|
||||
clients: make(map[chan string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe adds a new listener channel and returns it.
|
||||
// The caller must call Unsubscribe when done.
|
||||
func (b *EventBroadcaster) Subscribe() chan string {
|
||||
ch := make(chan string, 8)
|
||||
b.mu.Lock()
|
||||
b.clients[ch] = struct{}{}
|
||||
b.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
// Unsubscribe removes a listener channel and closes it.
|
||||
func (b *EventBroadcaster) Unsubscribe(ch chan string) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Check if the channel is still registered before closing
|
||||
if _, exists := b.clients[ch]; exists {
|
||||
delete(b.clients, ch)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Broadcast sends a GatewayEvent to all connected SSE clients.
|
||||
func (b *EventBroadcaster) Broadcast(event GatewayEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
defer b.mu.RUnlock()
|
||||
|
||||
for ch := range b.clients {
|
||||
// Non-blocking send; drop event if client is slow
|
||||
select {
|
||||
case ch <- string(data):
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown closes all subscriber channels, notifying all SSE clients to disconnect.
|
||||
// This should be called when the server is shutting down.
|
||||
func (b *EventBroadcaster) Shutdown() {
|
||||
// Close all channels to notify listeners
|
||||
for ch := range b.clients {
|
||||
b.Unsubscribe(ch)
|
||||
}
|
||||
// Clear the map
|
||||
b.clients = make(map[chan string]struct{})
|
||||
}
|
||||
+60
-149
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -30,11 +31,9 @@ var gateway = struct {
|
||||
runtimeStatus string
|
||||
startupDeadline time.Time
|
||||
logs *LogBuffer
|
||||
events *EventBroadcaster
|
||||
}{
|
||||
runtimeStatus: "stopped",
|
||||
logs: NewLogBuffer(200),
|
||||
events: NewEventBroadcaster(),
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -51,11 +50,19 @@ var gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response,
|
||||
|
||||
// getGatewayHealth checks the gateway health endpoint and returns the status response
|
||||
// Returns (*health.StatusResponse, statusCode, error). If error is not nil, the other values are not valid.
|
||||
func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
if port == 0 {
|
||||
port = 18790
|
||||
func (h *Handler) getGatewayHealth(cfg *config.Config, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
port := 18790
|
||||
if cfg != nil && cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
url := fmt.Sprintf("http://127.0.0.1:%d/health", port)
|
||||
|
||||
probeHost := gatewayProbeHost(h.effectiveGatewayBindHost(cfg))
|
||||
url := "http://" + net.JoinHostPort(probeHost, strconv.Itoa(port)) + "/health"
|
||||
|
||||
return getGatewayHealthByURL(url, timeout)
|
||||
}
|
||||
|
||||
func getGatewayHealthByURL(url string, timeout time.Duration) (*health.StatusResponse, int, error) {
|
||||
resp, err := gatewayHealthGet(url, timeout)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -73,7 +80,6 @@ func getGatewayHealth(port int, timeout time.Duration) (*health.StatusResponse,
|
||||
// registerGatewayRoutes binds gateway lifecycle endpoints to the ServeMux.
|
||||
func (h *Handler) registerGatewayRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/gateway/status", h.handleGatewayStatus)
|
||||
mux.HandleFunc("GET /api/gateway/events", h.handleGatewayEvents)
|
||||
mux.HandleFunc("GET /api/gateway/logs", h.handleGatewayLogs)
|
||||
mux.HandleFunc("POST /api/gateway/logs/clear", h.handleGatewayClearLogs)
|
||||
mux.HandleFunc("POST /api/gateway/start", h.handleGatewayStart)
|
||||
@@ -87,7 +93,7 @@ func (h *Handler) TryAutoStartGateway() {
|
||||
// Check if gateway is already running via health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
@@ -170,6 +176,16 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
|
||||
return modelCfg
|
||||
}
|
||||
|
||||
func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
|
||||
if gatewayStatus != "running" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
|
||||
return false
|
||||
}
|
||||
return configDefaultModel != bootDefaultModel
|
||||
}
|
||||
|
||||
func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return false
|
||||
@@ -220,7 +236,7 @@ func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func gatewayStatusOnHealthFailureLocked() string {
|
||||
func gatewayStatusWithoutHealthLocked() string {
|
||||
if gateway.runtimeStatus == "starting" || gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return gateway.runtimeStatus
|
||||
@@ -233,23 +249,7 @@ func gatewayStatusOnHealthFailureLocked() string {
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
|
||||
func currentGatewayStatusLocked(processAlive bool) string {
|
||||
if !processAlive {
|
||||
if gateway.runtimeStatus == "restarting" {
|
||||
if gateway.startupDeadline.IsZero() || time.Now().Before(gateway.startupDeadline) {
|
||||
return "restarting"
|
||||
}
|
||||
return "error"
|
||||
}
|
||||
if gateway.runtimeStatus == "error" {
|
||||
return "error"
|
||||
}
|
||||
return "stopped"
|
||||
}
|
||||
return gatewayStatusOnHealthFailureLocked()
|
||||
return "stopped"
|
||||
}
|
||||
|
||||
func waitForGatewayProcessExit(cmd *exec.Cmd, timeout time.Duration) bool {
|
||||
@@ -319,15 +319,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Broadcast the attached state
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: initialStatus,
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
|
||||
return pid, nil
|
||||
}
|
||||
|
||||
@@ -335,7 +326,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
// Locate the picoclaw executable
|
||||
execPath := utils.FindPicoclawBinary()
|
||||
|
||||
cmd = exec.Command(execPath, "gateway")
|
||||
cmd = exec.Command(execPath, "gateway", "-E")
|
||||
cmd.Env = os.Environ()
|
||||
// Forward the launcher's config path via the environment variable that
|
||||
// GetConfigPath() already reads, so the gateway sub-process uses the same
|
||||
@@ -376,15 +367,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
pid = cmd.Process.Pid
|
||||
log.Printf("Started picoclaw gateway (PID: %d) from %s", pid, execPath)
|
||||
|
||||
// Broadcast the launch state immediately so clients can reflect it without polling.
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: initialStatus,
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
|
||||
// Capture stdout/stderr in background
|
||||
go scanPipe(stdoutPipe, gateway.logs)
|
||||
go scanPipe(stderrPipe, gateway.logs)
|
||||
@@ -398,26 +380,17 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
shouldBroadcastStopped := false
|
||||
if gateway.cmd == cmd {
|
||||
gateway.cmd = nil
|
||||
gateway.bootDefaultModel = ""
|
||||
if gateway.runtimeStatus != "restarting" {
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
shouldBroadcastStopped = true
|
||||
}
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if shouldBroadcastStopped {
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "stopped",
|
||||
RestartRequired: false,
|
||||
})
|
||||
}
|
||||
}()
|
||||
|
||||
// Start a goroutine to probe health and broadcast "running" once ready
|
||||
// Start a goroutine to probe health and update the runtime state once ready.
|
||||
go func() {
|
||||
for i := 0; i < 30; i++ { // try for up to 15 seconds
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -431,7 +404,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 1*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 1*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK && healthResp.Pid == pid {
|
||||
// Verify the health endpoint returns the expected pid
|
||||
gateway.mu.Lock()
|
||||
@@ -439,13 +412,6 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
}
|
||||
gateway.mu.Unlock()
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "running",
|
||||
PID: pid,
|
||||
BootDefaultModel: defaultModelName,
|
||||
ConfigDefaultModel: defaultModelName,
|
||||
RestartRequired: false,
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -461,7 +427,7 @@ func (h *Handler) handleGatewayStart(w http.ResponseWriter, r *http.Request) {
|
||||
// Prevent duplicate starts by checking health endpoint
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
healthResp, statusCode, err := getGatewayHealth(cfg.Gateway.Port, 2*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err == nil && statusCode == http.StatusOK {
|
||||
// Gateway is already running, attach to the existing process
|
||||
pid := healthResp.Pid
|
||||
@@ -597,10 +563,6 @@ func (h *Handler) RestartGateway() (int, error) {
|
||||
gateway.mu.Lock()
|
||||
previousCmd := gateway.cmd
|
||||
setGatewayRuntimeStatusLocked("restarting")
|
||||
gateway.events.Broadcast(GatewayEvent{
|
||||
Status: "restarting",
|
||||
RestartRequired: false,
|
||||
})
|
||||
gateway.mu.Unlock()
|
||||
|
||||
if err = stopGatewayProcessForRestart(previousCmd); err != nil {
|
||||
@@ -704,24 +666,20 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data := map[string]any{}
|
||||
configDefaultModel := ""
|
||||
cfg, cfgErr := config.LoadConfig(h.configPath)
|
||||
if cfgErr == nil && cfg != nil {
|
||||
configDefaultModel := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
|
||||
if configDefaultModel != "" {
|
||||
data["config_default_model"] = configDefaultModel
|
||||
}
|
||||
}
|
||||
|
||||
// Probe health endpoint to get pid and status
|
||||
port := 0
|
||||
if cfgErr == nil && cfg != nil {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
|
||||
healthResp, statusCode, err := getGatewayHealth(port, 2*time.Second)
|
||||
healthResp, statusCode, err := h.getGatewayHealth(cfg, 2*time.Second)
|
||||
if err != nil {
|
||||
gateway.mu.Lock()
|
||||
data["gateway_status"] = currentGatewayStatusLocked(true)
|
||||
data["gateway_status"] = gatewayStatusWithoutHealthLocked()
|
||||
gateway.mu.Unlock()
|
||||
log.Printf("Gateway health check failed: %v", err)
|
||||
} else {
|
||||
@@ -734,45 +692,43 @@ func (h *Handler) gatewayStatusData() map[string]any {
|
||||
data["status_code"] = statusCode
|
||||
} else {
|
||||
gateway.mu.Lock()
|
||||
// Check if this pid matches our tracked process
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil && gateway.cmd.Process.Pid == healthResp.Pid {
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
} else {
|
||||
// Health endpoint responded with a different pid
|
||||
// This could be a manual restart, try to attach to the new process
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
if gateway.cmd == nil || gateway.cmd.Process == nil || gateway.cmd.Process.Pid != healthResp.Pid {
|
||||
oldPid := "none"
|
||||
if gateway.cmd != nil && gateway.cmd.Process != nil {
|
||||
oldPid = fmt.Sprintf("%d", gateway.cmd.Process.Pid)
|
||||
}
|
||||
log.Printf("Detected new gateway PID (old: %s, new: %d), attempting to attach", oldPid, healthResp.Pid)
|
||||
|
||||
log.Printf(
|
||||
"Detected gateway PID from health (old: %s, new: %d), attempting to attach",
|
||||
oldPid,
|
||||
healthResp.Pid,
|
||||
)
|
||||
if err := attachToGatewayProcessLocked(healthResp.Pid, cfg); err != nil {
|
||||
// Failed to find the process, treat as error
|
||||
setGatewayRuntimeStatusLocked("error")
|
||||
data["gateway_status"] = "error"
|
||||
data["pid"] = healthResp.Pid
|
||||
log.Printf("Failed to attach to new gateway process (PID: %d): %v", healthResp.Pid, err)
|
||||
} else {
|
||||
// Successfully attached, update response data
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
log.Printf(
|
||||
"Failed to attach to gateway process reported by health (PID: %d): %v",
|
||||
healthResp.Pid,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
bootDefaultModel := gateway.bootDefaultModel
|
||||
if bootDefaultModel != "" {
|
||||
data["boot_default_model"] = bootDefaultModel
|
||||
}
|
||||
data["gateway_status"] = "running"
|
||||
data["pid"] = healthResp.Pid
|
||||
gateway.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
data["gateway_restart_required"] = false
|
||||
bootDefaultModel, _ := data["boot_default_model"].(string)
|
||||
gatewayStatus, _ := data["gateway_status"].(string)
|
||||
data["gateway_restart_required"] = gatewayRestartRequired(
|
||||
configDefaultModel,
|
||||
bootDefaultModel,
|
||||
gatewayStatus,
|
||||
)
|
||||
|
||||
ready, reason, readyErr := h.gatewayStartReady()
|
||||
if readyErr != nil {
|
||||
@@ -842,51 +798,6 @@ func gatewayLogsData(r *http.Request) map[string]any {
|
||||
return data
|
||||
}
|
||||
|
||||
// handleGatewayEvents serves an SSE stream of gateway state change events.
|
||||
//
|
||||
// GET /api/gateway/events
|
||||
func (h *Handler) handleGatewayEvents(w http.ResponseWriter, r *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "SSE not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Subscribe to gateway events
|
||||
ch := gateway.events.Subscribe()
|
||||
defer gateway.events.Unsubscribe(ch)
|
||||
|
||||
// Send initial status so the client doesn't start blank
|
||||
initial := h.currentGatewayStatus()
|
||||
fmt.Fprintf(w, "data: %s\n\n", initial)
|
||||
flusher.Flush()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// currentGatewayStatus returns the current gateway status as a JSON string.
|
||||
func (h *Handler) currentGatewayStatus() string {
|
||||
data := h.gatewayStatusData()
|
||||
encoded, _ := json.Marshal(data)
|
||||
return string(encoded)
|
||||
}
|
||||
|
||||
// scanPipe reads lines from r and appends them to buf. Returns when r reaches EOF.
|
||||
func scanPipe(r io.Reader, buf *LogBuffer) {
|
||||
scanner := bufio.NewScanner(r)
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -46,6 +47,23 @@ func gatewayProbeHost(bindHost string) string {
|
||||
return bindHost
|
||||
}
|
||||
|
||||
func (h *Handler) gatewayProxyURL() *url.URL {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
port := 18790
|
||||
bindHost := ""
|
||||
if err == nil && cfg != nil {
|
||||
if cfg.Gateway.Port != 0 {
|
||||
port = cfg.Gateway.Port
|
||||
}
|
||||
bindHost = h.effectiveGatewayBindHost(cfg)
|
||||
}
|
||||
|
||||
return &url.URL{
|
||||
Scheme: "http",
|
||||
Host: net.JoinHostPort(gatewayProbeHost(bindHost), strconv.Itoa(port)),
|
||||
}
|
||||
}
|
||||
|
||||
func requestHostName(r *http.Request) string {
|
||||
reqHost, _, err := net.SplitHostPort(r.Host)
|
||||
if err == nil {
|
||||
|
||||
@@ -2,9 +2,12 @@ package api
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/web/backend/launcherconfig"
|
||||
@@ -59,6 +62,79 @@ func TestGatewayProbeHostUsesLoopbackForWildcardBind(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayProxyURLUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if got := h.gatewayProxyURL().String(); got != "http://192.168.1.10:18791" {
|
||||
t.Fatalf("gatewayProxyURL() = %q, want %q", got, "http://192.168.1.10:18791")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesConfiguredHost(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "192.168.1.10"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://192.168.1.10:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://192.168.1.10:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGatewayHealthUsesProbeHostForPublicLauncher(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
h.SetServerOptions(18800, true, true, nil)
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = 18791
|
||||
|
||||
originalHealthGet := gatewayHealthGet
|
||||
t.Cleanup(func() {
|
||||
gatewayHealthGet = originalHealthGet
|
||||
})
|
||||
|
||||
var requestedURL string
|
||||
gatewayHealthGet = func(url string, timeout time.Duration) (*http.Response, error) {
|
||||
requestedURL = url
|
||||
return nil, errors.New("probe failed")
|
||||
}
|
||||
|
||||
_, statusCode, err := h.getGatewayHealth(cfg, time.Second)
|
||||
_ = statusCode
|
||||
_ = err
|
||||
|
||||
if requestedURL != "http://127.0.0.1:18791/health" {
|
||||
t.Fatalf("health url = %q, want %q", requestedURL, "http://127.0.0.1:18791/health")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWsURLUsesWSSWhenForwardedProtoIsHTTPS(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
|
||||
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -36,6 +37,15 @@ func startLongRunningProcess(t *testing.T) *exec.Cmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func mockGatewayHealthResponse(statusCode, pid int) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: statusCode,
|
||||
Body: io.NopCloser(strings.NewReader(
|
||||
`{"status":"ok","uptime":"1s","pid":` + strconv.Itoa(pid) + `}`,
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
func startIgnoringTermProcess(t *testing.T) *exec.Cmd {
|
||||
t.Helper()
|
||||
|
||||
@@ -419,6 +429,125 @@ func TestGatewayStatusKeepsRunningWhenHealthProbeFailsAfterRunning(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReportsRunningFromHealthProbe(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
cmd := startLongRunningProcess(t)
|
||||
t.Cleanup(func() {
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
})
|
||||
|
||||
gateway.mu.Lock()
|
||||
setGatewayRuntimeStatusLocked("stopped")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, cmd.Process.Pid), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["pid"]; got != float64(cmd.Process.Pid) {
|
||||
t.Fatalf("pid = %#v, want %d", got, cmd.Process.Pid)
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != false {
|
||||
t.Fatalf("gateway_restart_required = %#v, want false", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
|
||||
cfg.ModelList[0].APIKey = "test-key"
|
||||
cfg.ModelList = append(cfg.ModelList, config.ModelConfig{
|
||||
ModelName: "second-model",
|
||||
Model: "openai/gpt-4.1",
|
||||
APIKey: "second-key",
|
||||
})
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewHandler(configPath)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
process, err := os.FindProcess(os.Getpid())
|
||||
if err != nil {
|
||||
t.Fatalf("FindProcess() error = %v", err)
|
||||
}
|
||||
|
||||
gateway.mu.Lock()
|
||||
gateway.cmd = &exec.Cmd{Process: process}
|
||||
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
|
||||
setGatewayRuntimeStatusLocked("running")
|
||||
gateway.mu.Unlock()
|
||||
|
||||
updatedCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error = %v", err)
|
||||
}
|
||||
updatedCfg.Agents.Defaults.ModelName = "second-model"
|
||||
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
|
||||
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
|
||||
mux.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var body map[string]any
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
|
||||
t.Fatalf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if got := body["gateway_status"]; got != "running" {
|
||||
t.Fatalf("gateway_status = %#v, want %q", got, "running")
|
||||
}
|
||||
if got := body["boot_default_model"]; got != cfg.ModelList[0].ModelName {
|
||||
t.Fatalf("boot_default_model = %#v, want %q", got, cfg.ModelList[0].ModelName)
|
||||
}
|
||||
if got := body["config_default_model"]; got != "second-model" {
|
||||
t.Fatalf("config_default_model = %#v, want %q", got, "second-model")
|
||||
}
|
||||
if got := body["gateway_restart_required"]; got != true {
|
||||
t.Fatalf("gateway_restart_required = %#v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
|
||||
resetGatewayTestState(t)
|
||||
|
||||
|
||||
+7
-17
@@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -22,20 +21,13 @@ func (h *Handler) registerPicoRoutes(mux *http.ServeMux) {
|
||||
// WebSocket proxy: forward /pico/ws to gateway
|
||||
// This allows the frontend to connect via the same port as the web UI,
|
||||
// avoiding the need to expose extra ports for WebSocket communication.
|
||||
wsProxy := h.createWsProxy()
|
||||
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy(wsProxy))
|
||||
mux.HandleFunc("GET /pico/ws", h.handleWebSocketProxy())
|
||||
}
|
||||
|
||||
// createWsProxy creates a reverse proxy to the gateway WebSocket endpoint.
|
||||
// The gateway port is read from the configuration.
|
||||
// createWsProxy creates a reverse proxy to the current gateway WebSocket endpoint.
|
||||
// The gateway bind host and port are resolved from the latest configuration.
|
||||
func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
cfg, err := config.LoadConfig(h.configPath)
|
||||
gatewayPort := 18790 // default
|
||||
if err == nil && cfg.Gateway.Port != 0 {
|
||||
gatewayPort = cfg.Gateway.Port
|
||||
}
|
||||
gatewayURL, _ := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", gatewayPort))
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(gatewayURL)
|
||||
wsProxy := httputil.NewSingleHostReverseProxy(h.gatewayProxyURL())
|
||||
wsProxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Gateway unavailable: "+err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
@@ -43,12 +35,10 @@ func (h *Handler) createWsProxy() *httputil.ReverseProxy {
|
||||
}
|
||||
|
||||
// handleWebSocketProxy wraps a reverse proxy to handle WebSocket connections.
|
||||
// It ensures the Connection and Upgrade headers are properly forwarded.
|
||||
func (h *Handler) handleWebSocketProxy(proxy *httputil.ReverseProxy) http.HandlerFunc {
|
||||
// The reverse proxy forwards the incoming upgrade handshake as-is.
|
||||
func (h *Handler) handleWebSocketProxy() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set headers for WebSocket upgrade
|
||||
r.Header.Set("Connection", "upgrade")
|
||||
r.Header.Set("Upgrade", "websocket")
|
||||
proxy := h.createWsProxy()
|
||||
proxy.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
@@ -235,3 +238,77 @@ func TestHandlePicoSetup_Response(t *testing.T) {
|
||||
t.Error("response should have changed=true on first setup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleWebSocketProxyReloadsGatewayTargetFromConfig(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "config.json")
|
||||
h := NewHandler(configPath)
|
||||
handler := h.handleWebSocketProxy()
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server1 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server1")
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/pico/ws" {
|
||||
t.Fatalf("server2 path = %q, want %q", r.URL.Path, "/pico/ws")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "server2")
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Gateway.Host = "127.0.0.1"
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server1.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec1 := httptest.NewRecorder()
|
||||
handler(rec1, req1)
|
||||
|
||||
if rec1.Code != http.StatusOK {
|
||||
t.Fatalf("first status = %d, want %d", rec1.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec1.Body.String(); body != "server1" {
|
||||
t.Fatalf("first body = %q, want %q", body, "server1")
|
||||
}
|
||||
|
||||
cfg.Gateway.Port = mustGatewayTestPort(t, server2.URL)
|
||||
if err := config.SaveConfig(configPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig() error = %v", err)
|
||||
}
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/pico/ws", nil)
|
||||
rec2 := httptest.NewRecorder()
|
||||
handler(rec2, req2)
|
||||
|
||||
if rec2.Code != http.StatusOK {
|
||||
t.Fatalf("second status = %d, want %d", rec2.Code, http.StatusOK)
|
||||
}
|
||||
if body := rec2.Body.String(); body != "server2" {
|
||||
t.Fatalf("second body = %q, want %q", body, "server2")
|
||||
}
|
||||
}
|
||||
|
||||
func mustGatewayTestPort(t *testing.T, rawURL string) int {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(parsed.Port())
|
||||
if err != nil {
|
||||
t.Fatalf("Atoi(%q) error = %v", parsed.Port(), err)
|
||||
}
|
||||
|
||||
return port
|
||||
}
|
||||
|
||||
@@ -71,7 +71,4 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
h.registerLauncherConfigRoutes(mux)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the handler, closing all SSE connections.
|
||||
func (h *Handler) Shutdown() {
|
||||
gateway.events.Shutdown()
|
||||
}
|
||||
func (h *Handler) Shutdown() {}
|
||||
|
||||
@@ -4,16 +4,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JSONContentType sets the Content-Type header to application/json for
|
||||
// API requests handled by the wrapped handler.
|
||||
// SSE endpoints (text/event-stream) are excluded.
|
||||
func JSONContentType(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.HasPrefix(r.URL.Path, "/api/") && !strings.HasSuffix(r.URL.Path, "/events") {
|
||||
if len(r.URL.Path) >= 5 && r.URL.Path[:5] == "/api/" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -32,7 +30,6 @@ func (rr *responseRecorder) WriteHeader(code int) {
|
||||
}
|
||||
|
||||
// Flush delegates to the underlying ResponseWriter if it implements http.Flusher.
|
||||
// This is required for SSE (Server-Sent Events) to work through the middleware.
|
||||
func (rr *responseRecorder) Flush() {
|
||||
if f, ok := rr.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
|
||||
@@ -94,7 +94,7 @@ func onReady() {
|
||||
func onExit() {
|
||||
fmt.Println(T(Exiting))
|
||||
|
||||
// First, shutdown API handler to close all SSE connections
|
||||
// First, shutdown API handler
|
||||
if apiHandler != nil {
|
||||
apiHandler.Shutdown()
|
||||
}
|
||||
|
||||
@@ -56,14 +56,20 @@ export function AppHeader() {
|
||||
const isRunning = gwState === "running"
|
||||
const isStarting = gwState === "starting"
|
||||
const isRestarting = gwState === "restarting"
|
||||
const isStopping = gwState === "stopping"
|
||||
const isStopped = gwState === "stopped" || gwState === "unknown"
|
||||
const showNotConnectedHint =
|
||||
!isRestarting && canStart && (gwState === "stopped" || gwState === "error")
|
||||
!isRestarting &&
|
||||
!isStopping &&
|
||||
canStart &&
|
||||
(gwState === "stopped" || gwState === "error")
|
||||
|
||||
const [showStopDialog, setShowStopDialog] = React.useState(false)
|
||||
|
||||
const handleGatewayToggle = () => {
|
||||
if (gwLoading || isRestarting || (!isRunning && !canStart)) return
|
||||
if (gwLoading || isRestarting || isStopping || (!isRunning && !canStart)) {
|
||||
return
|
||||
}
|
||||
if (isRunning) {
|
||||
setShowStopDialog(true)
|
||||
} else {
|
||||
@@ -137,7 +143,7 @@ export function AppHeader() {
|
||||
size="icon-sm"
|
||||
className="bg-amber-500/15 text-amber-700 hover:bg-amber-500/25 hover:text-amber-800 dark:text-amber-300 dark:hover:bg-amber-500/25"
|
||||
onClick={handleGatewayRestart}
|
||||
disabled={gwLoading || isRestarting || !canStart}
|
||||
disabled={gwLoading || isRestarting || isStopping || !canStart}
|
||||
aria-label={t("header.gateway.action.restart")}
|
||||
>
|
||||
<IconRefresh className="size-4" />
|
||||
@@ -168,25 +174,31 @@ export function AppHeader() {
|
||||
</Tooltip>
|
||||
) : (
|
||||
<Button
|
||||
variant={isStarting || isRestarting ? "secondary" : "default"}
|
||||
variant={
|
||||
isStarting || isRestarting || isStopping ? "secondary" : "default"
|
||||
}
|
||||
size="sm"
|
||||
className={`h-8 gap-2 px-3 ${
|
||||
isStopped ? "bg-green-500 text-white hover:bg-green-600" : ""
|
||||
}`}
|
||||
onClick={handleGatewayToggle}
|
||||
disabled={gwLoading || isStarting || isRestarting || !canStart}
|
||||
disabled={
|
||||
gwLoading || isStarting || isRestarting || isStopping || !canStart
|
||||
}
|
||||
>
|
||||
{gwLoading || isStarting || isRestarting ? (
|
||||
{gwLoading || isStarting || isRestarting || isStopping ? (
|
||||
<IconLoader2 className="h-4 w-4 animate-spin opacity-70" />
|
||||
) : (
|
||||
<IconPlayerPlay className="h-4 w-4 opacity-80" />
|
||||
)}
|
||||
<span className="text-xs font-semibold">
|
||||
{isRestarting
|
||||
? t("header.gateway.status.restarting")
|
||||
: isStarting
|
||||
? t("header.gateway.status.starting")
|
||||
: t("header.gateway.action.start")}
|
||||
{isStopping
|
||||
? t("header.gateway.status.stopping")
|
||||
: isRestarting
|
||||
? t("header.gateway.status.restarting")
|
||||
: isStarting
|
||||
? t("header.gateway.status.starting")
|
||||
: t("header.gateway.action.start")}
|
||||
</span>
|
||||
</Button>
|
||||
)}
|
||||
|
||||
@@ -37,7 +37,9 @@ export function useGatewayLogs() {
|
||||
const fetchLogs = async () => {
|
||||
if (
|
||||
!mounted ||
|
||||
!["running", "starting", "restarting"].includes(gateway.status)
|
||||
!["running", "starting", "restarting", "stopping"].includes(
|
||||
gateway.status,
|
||||
)
|
||||
) {
|
||||
if (mounted) {
|
||||
timeout = setTimeout(fetchLogs, 1000)
|
||||
|
||||
@@ -1,83 +1,24 @@
|
||||
import { useAtomValue } from "jotai"
|
||||
import { useCallback, useEffect, useState } from "react"
|
||||
|
||||
import { restartGateway, startGateway, stopGateway } from "@/api/gateway"
|
||||
import {
|
||||
type GatewayStatusResponse,
|
||||
getGatewayStatus,
|
||||
restartGateway,
|
||||
startGateway,
|
||||
stopGateway,
|
||||
} from "@/api/gateway"
|
||||
import {
|
||||
applyGatewayStatusToStore,
|
||||
beginGatewayStoppingTransition,
|
||||
cancelGatewayStoppingTransition,
|
||||
gatewayAtom,
|
||||
refreshGatewayState,
|
||||
subscribeGatewayPolling,
|
||||
updateGatewayStore,
|
||||
} from "@/store"
|
||||
|
||||
// Global variable to ensure we only have one SSE connection
|
||||
let sseInitialized = false
|
||||
|
||||
export function useGateway() {
|
||||
const gateway = useAtomValue(gatewayAtom)
|
||||
const { status: state, canStart, restartRequired } = gateway
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const applyGatewayStatus = useCallback((data: GatewayStatusResponse) => {
|
||||
applyGatewayStatusToStore(data)
|
||||
}, [])
|
||||
|
||||
// Initialize global SSE connection once
|
||||
useEffect(() => {
|
||||
if (sseInitialized) return
|
||||
sseInitialized = true
|
||||
|
||||
getGatewayStatus()
|
||||
.then((data) => applyGatewayStatus(data))
|
||||
.catch(() => {
|
||||
updateGatewayStore({
|
||||
status: "unknown",
|
||||
canStart: true,
|
||||
restartRequired: false,
|
||||
})
|
||||
})
|
||||
|
||||
const statusPoll = window.setInterval(() => {
|
||||
getGatewayStatus()
|
||||
.then((data) => applyGatewayStatus(data))
|
||||
.catch(() => {
|
||||
// ignore polling errors
|
||||
})
|
||||
}, 5000)
|
||||
|
||||
// Subscribe to SSE for real-time updates globally
|
||||
const es = new EventSource("/api/gateway/events")
|
||||
|
||||
es.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data)
|
||||
if (
|
||||
data.gateway_status ||
|
||||
typeof data.gateway_start_allowed === "boolean"
|
||||
) {
|
||||
applyGatewayStatus(data)
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}
|
||||
|
||||
es.onerror = () => {
|
||||
// EventSource will auto-reconnect. Preserve the last known gateway
|
||||
// status so transient SSE disconnects do not suppress chat websocket
|
||||
// reconnects while polling catches up.
|
||||
}
|
||||
|
||||
return () => {
|
||||
window.clearInterval(statusPoll)
|
||||
es.close()
|
||||
sseInitialized = false
|
||||
}
|
||||
}, [applyGatewayStatus])
|
||||
return subscribeGatewayPolling()
|
||||
}, [])
|
||||
|
||||
const start = useCallback(async () => {
|
||||
if (!canStart) return
|
||||
@@ -85,33 +26,28 @@ export function useGateway() {
|
||||
setLoading(true)
|
||||
try {
|
||||
await startGateway()
|
||||
// SSE will push the real state changes, but set optimistic state
|
||||
updateGatewayStore({ status: "starting" })
|
||||
} catch (err) {
|
||||
console.error("Failed to start gateway:", err)
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatus(status)
|
||||
} catch {
|
||||
updateGatewayStore({ status: "unknown" })
|
||||
}
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [applyGatewayStatus, canStart])
|
||||
|
||||
const stop = useCallback(async () => {
|
||||
setLoading(true)
|
||||
try {
|
||||
await stopGateway()
|
||||
updateGatewayStore({
|
||||
status: "stopped",
|
||||
canStart: true,
|
||||
status: "starting",
|
||||
restartRequired: false,
|
||||
})
|
||||
} catch (err) {
|
||||
console.error("Failed to stop gateway:", err)
|
||||
console.error("Failed to start gateway:", err)
|
||||
} finally {
|
||||
await refreshGatewayState({ force: true })
|
||||
setLoading(false)
|
||||
}
|
||||
}, [canStart])
|
||||
|
||||
const stop = useCallback(async () => {
|
||||
setLoading(true)
|
||||
beginGatewayStoppingTransition()
|
||||
try {
|
||||
await stopGateway()
|
||||
} catch (err) {
|
||||
console.error("Failed to stop gateway:", err)
|
||||
cancelGatewayStoppingTransition()
|
||||
} finally {
|
||||
await refreshGatewayState({ force: true })
|
||||
setLoading(false)
|
||||
}
|
||||
}, [])
|
||||
@@ -119,34 +55,20 @@ export function useGateway() {
|
||||
const restart = useCallback(async () => {
|
||||
if (state !== "running") return
|
||||
|
||||
const previousState = state
|
||||
const previousCanStart = canStart
|
||||
const previousRestartRequired = restartRequired
|
||||
|
||||
setLoading(true)
|
||||
updateGatewayStore({
|
||||
status: "restarting",
|
||||
restartRequired: false,
|
||||
})
|
||||
|
||||
try {
|
||||
await restartGateway()
|
||||
updateGatewayStore({
|
||||
status: "restarting",
|
||||
restartRequired: false,
|
||||
})
|
||||
} catch (err) {
|
||||
console.error("Failed to restart gateway:", err)
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatus(status)
|
||||
} catch {
|
||||
updateGatewayStore({
|
||||
status: previousState,
|
||||
canStart: previousCanStart,
|
||||
restartRequired: previousRestartRequired,
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
await refreshGatewayState({ force: true })
|
||||
setLoading(false)
|
||||
}
|
||||
}, [applyGatewayStatus, canStart, restartRequired, state])
|
||||
}, [state])
|
||||
|
||||
return { state, loading, canStart, restartRequired, start, stop, restart }
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@
|
||||
},
|
||||
"status": {
|
||||
"starting": "Starting Gateway...",
|
||||
"restarting": "Restarting Gateway..."
|
||||
"restarting": "Restarting Gateway...",
|
||||
"stopping": "Stopping Gateway..."
|
||||
},
|
||||
"restartRequired": "Model changes require a gateway restart to take effect."
|
||||
}
|
||||
|
||||
@@ -63,7 +63,8 @@
|
||||
},
|
||||
"status": {
|
||||
"starting": "服务启动中...",
|
||||
"restarting": "服务重启中..."
|
||||
"restarting": "服务重启中...",
|
||||
"stopping": "服务停止中..."
|
||||
},
|
||||
"restartRequired": "切换默认模型后需要重启服务才能生效。"
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ export type GatewayState =
|
||||
| "running"
|
||||
| "starting"
|
||||
| "restarting"
|
||||
| "stopping"
|
||||
| "stopped"
|
||||
| "error"
|
||||
| "unknown"
|
||||
@@ -24,9 +25,29 @@ const DEFAULT_GATEWAY_STATE: GatewayStoreState = {
|
||||
restartRequired: false,
|
||||
}
|
||||
|
||||
const GATEWAY_POLL_INTERVAL_MS = 2000
|
||||
const GATEWAY_TRANSIENT_POLL_INTERVAL_MS = 1000
|
||||
const GATEWAY_STOPPING_TIMEOUT_MS = 5000
|
||||
|
||||
interface RefreshGatewayStateOptions {
|
||||
force?: boolean
|
||||
}
|
||||
|
||||
// Global atom for gateway state
|
||||
export const gatewayAtom = atom<GatewayStoreState>(DEFAULT_GATEWAY_STATE)
|
||||
|
||||
let gatewayPollingSubscribers = 0
|
||||
let gatewayPollingTimer: ReturnType<typeof setTimeout> | null = null
|
||||
let gatewayPollingRequest: Promise<void> | null = null
|
||||
let gatewayStoppingTimer: ReturnType<typeof setTimeout> | null = null
|
||||
|
||||
function clearGatewayStoppingTimeout() {
|
||||
if (gatewayStoppingTimer !== null) {
|
||||
clearTimeout(gatewayStoppingTimer)
|
||||
gatewayStoppingTimer = null
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeGatewayStoreState(
|
||||
prev: GatewayStoreState,
|
||||
patch: GatewayStorePatch,
|
||||
@@ -49,10 +70,38 @@ export function updateGatewayStore(
|
||||
| GatewayStorePatch
|
||||
| ((prev: GatewayStoreState) => GatewayStorePatch | GatewayStoreState),
|
||||
) {
|
||||
getDefaultStore().set(gatewayAtom, (prev) => {
|
||||
const store = getDefaultStore()
|
||||
store.set(gatewayAtom, (prev) => {
|
||||
const nextPatch = typeof patch === "function" ? patch(prev) : patch
|
||||
return normalizeGatewayStoreState(prev, nextPatch)
|
||||
})
|
||||
const nextState = store.get(gatewayAtom)
|
||||
if (nextState?.status !== "stopping") {
|
||||
clearGatewayStoppingTimeout()
|
||||
}
|
||||
}
|
||||
|
||||
export function beginGatewayStoppingTransition() {
|
||||
clearGatewayStoppingTimeout()
|
||||
updateGatewayStore({
|
||||
status: "stopping",
|
||||
canStart: false,
|
||||
restartRequired: false,
|
||||
})
|
||||
gatewayStoppingTimer = setTimeout(() => {
|
||||
gatewayStoppingTimer = null
|
||||
updateGatewayStore((prev) =>
|
||||
prev.status === "stopping" ? { status: "running" } : prev,
|
||||
)
|
||||
void refreshGatewayState({ force: true })
|
||||
}, GATEWAY_STOPPING_TIMEOUT_MS)
|
||||
}
|
||||
|
||||
export function cancelGatewayStoppingTransition() {
|
||||
clearGatewayStoppingTimeout()
|
||||
updateGatewayStore((prev) =>
|
||||
prev.status === "stopping" ? { status: "running" } : prev,
|
||||
)
|
||||
}
|
||||
|
||||
export function applyGatewayStatusToStore(
|
||||
@@ -64,21 +113,92 @@ export function applyGatewayStatusToStore(
|
||||
>,
|
||||
) {
|
||||
updateGatewayStore((prev) => ({
|
||||
status: data.gateway_status ?? prev.status,
|
||||
canStart: data.gateway_start_allowed ?? prev.canStart,
|
||||
restartRequired:
|
||||
data.gateway_restart_required ??
|
||||
(data.gateway_status && data.gateway_status !== "running"
|
||||
status:
|
||||
prev.status === "stopping" && data.gateway_status === "running"
|
||||
? "stopping"
|
||||
: (data.gateway_status ?? prev.status),
|
||||
canStart:
|
||||
prev.status === "stopping" && data.gateway_status === "running"
|
||||
? false
|
||||
: prev.restartRequired),
|
||||
: (data.gateway_start_allowed ?? prev.canStart),
|
||||
restartRequired:
|
||||
prev.status === "stopping" && data.gateway_status === "running"
|
||||
? false
|
||||
: (data.gateway_restart_required ?? prev.restartRequired),
|
||||
}))
|
||||
}
|
||||
|
||||
export async function refreshGatewayState() {
|
||||
function nextGatewayPollInterval() {
|
||||
const status = getDefaultStore().get(gatewayAtom).status
|
||||
if (
|
||||
status === "starting" ||
|
||||
status === "restarting" ||
|
||||
status === "stopping"
|
||||
) {
|
||||
return GATEWAY_TRANSIENT_POLL_INTERVAL_MS
|
||||
}
|
||||
return GATEWAY_POLL_INTERVAL_MS
|
||||
}
|
||||
|
||||
function scheduleGatewayPoll(delay = nextGatewayPollInterval()) {
|
||||
if (gatewayPollingSubscribers === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
if (gatewayPollingTimer !== null) {
|
||||
clearTimeout(gatewayPollingTimer)
|
||||
}
|
||||
|
||||
gatewayPollingTimer = setTimeout(() => {
|
||||
gatewayPollingTimer = null
|
||||
void refreshGatewayState()
|
||||
}, delay)
|
||||
}
|
||||
|
||||
export async function refreshGatewayState(
|
||||
options: RefreshGatewayStateOptions = {},
|
||||
) {
|
||||
if (gatewayPollingRequest) {
|
||||
await gatewayPollingRequest
|
||||
if (options.force) {
|
||||
return refreshGatewayState()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
gatewayPollingRequest = (async () => {
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatusToStore(status)
|
||||
} catch {
|
||||
// Preserve the last known state when a poll fails.
|
||||
} finally {
|
||||
gatewayPollingRequest = null
|
||||
scheduleGatewayPoll()
|
||||
}
|
||||
})()
|
||||
|
||||
try {
|
||||
const status = await getGatewayStatus()
|
||||
applyGatewayStatusToStore(status)
|
||||
} catch {
|
||||
updateGatewayStore(DEFAULT_GATEWAY_STATE)
|
||||
await gatewayPollingRequest
|
||||
} finally {
|
||||
if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) {
|
||||
clearTimeout(gatewayPollingTimer)
|
||||
gatewayPollingTimer = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function subscribeGatewayPolling() {
|
||||
gatewayPollingSubscribers += 1
|
||||
if (gatewayPollingSubscribers === 1) {
|
||||
void refreshGatewayState()
|
||||
}
|
||||
|
||||
return () => {
|
||||
gatewayPollingSubscribers = Math.max(0, gatewayPollingSubscribers - 1)
|
||||
if (gatewayPollingSubscribers === 0 && gatewayPollingTimer !== null) {
|
||||
clearTimeout(gatewayPollingTimer)
|
||||
gatewayPollingTimer = null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user