diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index a3fae5744..77c2e3c10 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -2321,7 +2321,7 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T if outbound.Content != "thinking trace" { t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace") } - case <-time.After(2 * time.Second): + case <-time.After(3 * time.Second): t.Fatal("expected reasoning content to be published to reasoning channel") } } diff --git a/pkg/channels/dynamic_mux.go b/pkg/channels/dynamic_mux.go new file mode 100644 index 000000000..399f18b7a --- /dev/null +++ b/pkg/channels/dynamic_mux.go @@ -0,0 +1,74 @@ +package channels + +import ( + "net/http" + "strings" + "sync" +) + +// dynamicServeMux is an http.Handler that supports dynamic registration +// and unregistration of handlers without recreating the server. +type dynamicServeMux struct { + mu sync.RWMutex + handlers map[string]http.Handler +} + +func newDynamicServeMux() *dynamicServeMux { + return &dynamicServeMux{ + handlers: make(map[string]http.Handler), + } +} + +// Handle registers the handler for the given pattern. +func (dm *dynamicServeMux) Handle(pattern string, handler http.Handler) { + dm.mu.Lock() + defer dm.mu.Unlock() + dm.handlers[pattern] = handler +} + +// HandleFunc registers the handler function for the given pattern. +func (dm *dynamicServeMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { + dm.Handle(pattern, http.HandlerFunc(handler)) +} + +// Unhandle removes the handler for the given pattern. +func (dm *dynamicServeMux) Unhandle(pattern string) { + dm.mu.Lock() + defer dm.mu.Unlock() + delete(dm.handlers, pattern) +} + +// ServeHTTP dispatches the request to the handler whose pattern best matches +// the request URL path. It supports both exact path matches and subtree +// (trailing-slash) prefix matches, choosing the longest prefix on collision. +func (dm *dynamicServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + dm.mu.RLock() + defer dm.mu.RUnlock() + + path := r.URL.Path + + // Exact match first. + if h, ok := dm.handlers[path]; ok { + h.ServeHTTP(w, r) + return + } + + // Longest subtree prefix match (patterns ending with "/"). + var bestLen int + var bestHandler http.Handler + for pattern, handler := range dm.handlers { + if strings.HasSuffix(pattern, "/") && strings.HasPrefix(path, pattern) { + if len(pattern) > bestLen { + bestLen = len(pattern) + bestHandler = handler + } + } + } + + if bestHandler != nil { + bestHandler.ServeHTTP(w, r) + return + } + + http.NotFound(w, r) +} diff --git a/pkg/channels/dynamic_mux_test.go b/pkg/channels/dynamic_mux_test.go new file mode 100644 index 000000000..d895c69c9 --- /dev/null +++ b/pkg/channels/dynamic_mux_test.go @@ -0,0 +1,162 @@ +package channels + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" +) + +func TestDynamicServeMuxExactMatch(t *testing.T) { + dm := newDynamicServeMux() + dm.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil)) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } +} + +func TestDynamicServeMuxSubtreePrefixMatch(t *testing.T) { + dm := newDynamicServeMux() + dm.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + }) + + for _, path := range []string{"/api/", "/api/v1", "/api/v1/resource"} { + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil)) + if rec.Code != http.StatusCreated { + t.Fatalf("path %q: expected 201, got %d", path, rec.Code) + } + } +} + +func TestDynamicServeMuxExactOverPrefix(t *testing.T) { + dm := newDynamicServeMux() + dm.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + dm.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + }) + + // Exact match wins + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil)) + if rec.Code != http.StatusOK { + t.Fatalf("exact match: expected 200, got %d", rec.Code) + } + + // Prefix match for sub-paths + rec = httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1", nil)) + if rec.Code != http.StatusCreated { + t.Fatalf("prefix match: expected 201, got %d", rec.Code) + } +} + +func TestDynamicServeMuxLongestPrefixWins(t *testing.T) { + dm := newDynamicServeMux() + dm.HandleFunc("/a/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + dm.HandleFunc("/a/b/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }) + + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/a/b/c", nil)) + if rec.Code != http.StatusAccepted { + t.Fatalf("longest prefix: expected 202, got %d", rec.Code) + } +} + +func TestDynamicServeMuxNotFound(t *testing.T) { + dm := newDynamicServeMux() + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/nonexistent", nil)) + if rec.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d", rec.Code) + } +} + +func TestDynamicServeMuxUnhandle(t *testing.T) { + dm := newDynamicServeMux() + dm.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + // Verify it works before removal + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil)) + if rec.Code != http.StatusOK { + t.Fatalf("before unhandle: expected 200, got %d", rec.Code) + } + + // Remove and verify 404 + dm.Unhandle("/test") + rec = httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil)) + if rec.Code != http.StatusNotFound { + t.Fatalf("after unhandle: expected 404, got %d", rec.Code) + } +} + +func TestDynamicServeMuxConcurrent(t *testing.T) { + dm := newDynamicServeMux() + dm.HandleFunc("/static", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + var wg sync.WaitGroup + const goroutines = 50 + + // Concurrent Handle/Unhandle + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + pattern := "/concurrent" + if i%2 == 0 { + dm.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }) + } else { + dm.Unhandle(pattern) + } + }(i) + } + + // Concurrent ServeHTTP + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/static", nil)) + // Should not panic; result is either 200 or 404 + _ = rec.Code + }() + } + + wg.Wait() +} + +func TestDynamicServeMuxHandleUsesHandler(t *testing.T) { + dm := newDynamicServeMux() + + var called bool + dm.Handle("/handler", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + })) + + rec := httptest.NewRecorder() + dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/handler", nil)) + if !called { + t.Fatal("handler was not called") + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 0486d2c5f..4e8074189 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -83,7 +83,7 @@ type Manager struct { config *config.Config mediaStore media.MediaStore dispatchTask *asyncTask - mux *http.ServeMux + mux *dynamicServeMux httpServer *http.Server mu sync.RWMutex placeholders sync.Map // "channel:chatID" → placeholderID (string) @@ -436,7 +436,7 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error { // It registers health endpoints from the health server and discovers channels // that implement WebhookHandler and/or HealthChecker to register their handlers. func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { - m.mux = http.NewServeMux() + m.mux = newDynamicServeMux() // Register health endpoints if healthServer != nil { @@ -444,22 +444,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { } // Discover and register webhook handlers and health checkers - for name, ch := range m.channels { - if wh, ok := ch.(WebhookHandler); ok { - m.mux.Handle(wh.WebhookPath(), wh) - logger.InfoCF("channels", "Webhook handler registered", map[string]any{ - "channel": name, - "path": wh.WebhookPath(), - }) - } - if hc, ok := ch.(HealthChecker); ok { - m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler) - logger.InfoCF("channels", "Health endpoint registered", map[string]any{ - "channel": name, - "path": hc.HealthPath(), - }) - } - } + m.registerHTTPHandlersLocked() m.httpServer = &http.Server{ Addr: addr, @@ -469,6 +454,53 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { } } +// registerHTTPHandlersLocked registers webhook and health-check handlers for +// all channels currently in m.channels. Caller must hold m.mu (or ensure +// exclusive access). +func (m *Manager) registerHTTPHandlersLocked() { + for name, ch := range m.channels { + m.registerChannelHTTPHandler(name, ch) + } +} + +// registerChannelHTTPHandler registers the webhook/health handlers for a +// single channel onto m.mux. +func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) { + if wh, ok := ch.(WebhookHandler); ok { + m.mux.Handle(wh.WebhookPath(), wh) + logger.InfoCF("channels", "Webhook handler registered", map[string]any{ + "channel": name, + "path": wh.WebhookPath(), + }) + } + if hc, ok := ch.(HealthChecker); ok { + m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler) + logger.InfoCF("channels", "Health endpoint registered", map[string]any{ + "channel": name, + "path": hc.HealthPath(), + }) + } +} + +// unregisterChannelHTTPHandler removes the webhook/health handlers for a +// single channel from m.mux. +func (m *Manager) unregisterChannelHTTPHandler(name string, ch Channel) { + if wh, ok := ch.(WebhookHandler); ok { + m.mux.Unhandle(wh.WebhookPath()) + logger.InfoCF("channels", "Webhook handler unregistered", map[string]any{ + "channel": name, + "path": wh.WebhookPath(), + }) + } + if hc, ok := ch.(HealthChecker); ok { + m.mux.Unhandle(hc.HealthPath()) + logger.InfoCF("channels", "Health endpoint unregistered", map[string]any{ + "channel": name, + "path": hc.HealthPath(), + }) + } +} + func (m *Manager) StartAll(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() @@ -984,8 +1016,17 @@ func (m *Manager) GetEnabledChannels() []string { func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error { m.mu.Lock() defer m.mu.Unlock() + + // Save old config so we can revert on error. + oldConfig := m.config + + // Update config early: initChannel uses m.config via factory(m.config, m.bus). + m.config = cfg + list := toChannelHashes(cfg) added, removed := compareChannels(m.channelHashes, list) + + deferFuncs := make([]func(), 0, len(removed)+len(added)) for _, name := range removed { // Stop all channels channel := m.channels[name] @@ -998,20 +1039,24 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error { "error": err.Error(), }) } - go func() { + deferFuncs = append(deferFuncs, func() { m.UnregisterChannel(name) - }() + }) } dispatchCtx, cancel := context.WithCancel(ctx) m.dispatchTask = &asyncTask{cancel: cancel} cc, err := toChannelConfig(cfg, added) if err != nil { logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err)) + m.config = oldConfig + cancel() return err } err = m.initChannels(cc) if err != nil { logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err)) + m.config = oldConfig + cancel() return err } for _, name := range added { @@ -1031,13 +1076,18 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error { m.workers[name] = w go m.runWorker(dispatchCtx, name, w) go m.runMediaWorker(dispatchCtx, name, w) - go func() { + deferFuncs = append(deferFuncs, func() { m.RegisterChannel(name, channel) - }() + }) } - m.config = cfg - m.channelHashes = toChannelHashes(cfg) + // Commit hashes only on full success. + m.channelHashes = list + go func() { + for _, f := range deferFuncs { + f() + } + }() return nil } @@ -1045,11 +1095,17 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() m.channels[name] = channel + if m.mux != nil { + m.registerChannelHTTPHandler(name, channel) + } } func (m *Manager) UnregisterChannel(name string) { m.mu.Lock() defer m.mu.Unlock() + if ch, ok := m.channels[name]; ok && m.mux != nil { + m.unregisterChannelHTTPHandler(name, ch) + } if w, ok := m.workers[name]; ok && w != nil { close(w.queue) <-w.done diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 03d7dfe0c..c35b3e744 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -490,12 +490,13 @@ func restartServices( } al.SetMediaStore(runningServices.MediaStore) - runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore) - if err != nil { - return fmt.Errorf("error recreating channel manager: %w", err) - } al.SetChannelManager(runningServices.ChannelManager) + if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil { + return fmt.Errorf("error reload channels: %w", err) + } + fmt.Println(" ✓ Channels restarted.") + enabledChannels := runningServices.ChannelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels) @@ -503,18 +504,6 @@ func restartServices( fmt.Println(" ⚠ Warning: No channels enabled") } - addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) - // Reuse existing HealthServer to preserve reloadFunc - if runningServices.HealthServer == nil { - runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - } - runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) - - if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil { - return fmt.Errorf("error reload channels: %w", err) - } - fmt.Println(" ✓ Channels restarted.") - stateManager := state.NewManager(cfg.WorkspacePath()) runningServices.DeviceService = devices.NewService(devices.Config{ Enabled: cfg.Devices.Enabled, diff --git a/pkg/health/server.go b/pkg/health/server.go index fe20e4b94..387cb0756 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -198,9 +198,17 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { }) } +// HandlerMux is the interface for registering HTTP handlers, used by +// RegisterOnMux so that callers can pass any mux implementation +// (e.g. *http.ServeMux or a custom dynamic mux). +type HandlerMux interface { + Handle(pattern string, handler http.Handler) + HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) +} + // RegisterOnMux registers /health, /ready and /reload handlers onto the given mux. // This allows the health endpoints to be served by a shared HTTP server. -func (s *Server) RegisterOnMux(mux *http.ServeMux) { +func (s *Server) RegisterOnMux(mux HandlerMux) { mux.HandleFunc("/health", s.healthHandler) mux.HandleFunc("/ready", s.readyHandler) mux.HandleFunc("/reload", s.reloadHandler) diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 1bcc1cec9..1244addb2 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -353,6 +353,10 @@ func WarnCF(component string, message string, fields map[string]any) { logMessage(WARN, component, message, fields) } +func Warnf(message string, ss ...any) { + logMessage(WARN, "", fmt.Sprintf(message, ss...), nil) +} + func Error(message string) { logMessage(ERROR, "", message, nil) } diff --git a/web/backend/api/gateway.go b/web/backend/api/gateway.go index 4bde5ce82..808475cb6 100644 --- a/web/backend/api/gateway.go +++ b/web/backend/api/gateway.go @@ -730,8 +730,8 @@ func (h *Handler) gatewayStatusData() map[string]any { gateway.mu.Unlock() logger.ErrorC("gateway", fmt.Sprintf("Gateway health check failed: %v", err)) } else { - logger.InfoC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode)) if statusCode != http.StatusOK { + logger.WarnC("gateway", fmt.Sprintf("Gateway health status: %d", statusCode)) gateway.mu.Lock() setGatewayRuntimeStatusLocked("error") gateway.mu.Unlock()