fix gateway reload will cause pico stop working issue (#2082)

* fix gateway reload will cause pico stop working issue

* fix for review
This commit is contained in:
Cytown
2026-03-28 11:30:31 +08:00
committed by GitHub
parent 60d7ec20a5
commit f1cb7cc8f5
8 changed files with 336 additions and 43 deletions
+1 -1
View File
@@ -2321,7 +2321,7 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
if outbound.Content != "thinking trace" {
t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace")
}
case <-time.After(2 * time.Second):
case <-time.After(3 * time.Second):
t.Fatal("expected reasoning content to be published to reasoning channel")
}
}
+74
View File
@@ -0,0 +1,74 @@
package channels
import (
"net/http"
"strings"
"sync"
)
// dynamicServeMux is an http.Handler that supports dynamic registration
// and unregistration of handlers without recreating the server.
type dynamicServeMux struct {
mu sync.RWMutex
handlers map[string]http.Handler
}
func newDynamicServeMux() *dynamicServeMux {
return &dynamicServeMux{
handlers: make(map[string]http.Handler),
}
}
// Handle registers the handler for the given pattern.
func (dm *dynamicServeMux) Handle(pattern string, handler http.Handler) {
dm.mu.Lock()
defer dm.mu.Unlock()
dm.handlers[pattern] = handler
}
// HandleFunc registers the handler function for the given pattern.
func (dm *dynamicServeMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
dm.Handle(pattern, http.HandlerFunc(handler))
}
// Unhandle removes the handler for the given pattern.
func (dm *dynamicServeMux) Unhandle(pattern string) {
dm.mu.Lock()
defer dm.mu.Unlock()
delete(dm.handlers, pattern)
}
// ServeHTTP dispatches the request to the handler whose pattern best matches
// the request URL path. It supports both exact path matches and subtree
// (trailing-slash) prefix matches, choosing the longest prefix on collision.
func (dm *dynamicServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
dm.mu.RLock()
defer dm.mu.RUnlock()
path := r.URL.Path
// Exact match first.
if h, ok := dm.handlers[path]; ok {
h.ServeHTTP(w, r)
return
}
// Longest subtree prefix match (patterns ending with "/").
var bestLen int
var bestHandler http.Handler
for pattern, handler := range dm.handlers {
if strings.HasSuffix(pattern, "/") && strings.HasPrefix(path, pattern) {
if len(pattern) > bestLen {
bestLen = len(pattern)
bestHandler = handler
}
}
}
if bestHandler != nil {
bestHandler.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
}
+162
View File
@@ -0,0 +1,162 @@
package channels
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
)
func TestDynamicServeMuxExactMatch(t *testing.T) {
dm := newDynamicServeMux()
dm.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil))
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", rec.Code)
}
}
func TestDynamicServeMuxSubtreePrefixMatch(t *testing.T) {
dm := newDynamicServeMux()
dm.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
})
for _, path := range []string{"/api/", "/api/v1", "/api/v1/resource"} {
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
if rec.Code != http.StatusCreated {
t.Fatalf("path %q: expected 201, got %d", path, rec.Code)
}
}
}
func TestDynamicServeMuxExactOverPrefix(t *testing.T) {
dm := newDynamicServeMux()
dm.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
dm.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusCreated)
})
// Exact match wins
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil))
if rec.Code != http.StatusOK {
t.Fatalf("exact match: expected 200, got %d", rec.Code)
}
// Prefix match for sub-paths
rec = httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1", nil))
if rec.Code != http.StatusCreated {
t.Fatalf("prefix match: expected 201, got %d", rec.Code)
}
}
func TestDynamicServeMuxLongestPrefixWins(t *testing.T) {
dm := newDynamicServeMux()
dm.HandleFunc("/a/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
dm.HandleFunc("/a/b/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
})
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/a/b/c", nil))
if rec.Code != http.StatusAccepted {
t.Fatalf("longest prefix: expected 202, got %d", rec.Code)
}
}
func TestDynamicServeMuxNotFound(t *testing.T) {
dm := newDynamicServeMux()
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/nonexistent", nil))
if rec.Code != http.StatusNotFound {
t.Fatalf("expected 404, got %d", rec.Code)
}
}
func TestDynamicServeMuxUnhandle(t *testing.T) {
dm := newDynamicServeMux()
dm.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Verify it works before removal
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))
if rec.Code != http.StatusOK {
t.Fatalf("before unhandle: expected 200, got %d", rec.Code)
}
// Remove and verify 404
dm.Unhandle("/test")
rec = httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))
if rec.Code != http.StatusNotFound {
t.Fatalf("after unhandle: expected 404, got %d", rec.Code)
}
}
func TestDynamicServeMuxConcurrent(t *testing.T) {
dm := newDynamicServeMux()
dm.HandleFunc("/static", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
var wg sync.WaitGroup
const goroutines = 50
// Concurrent Handle/Unhandle
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
pattern := "/concurrent"
if i%2 == 0 {
dm.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
})
} else {
dm.Unhandle(pattern)
}
}(i)
}
// Concurrent ServeHTTP
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/static", nil))
// Should not panic; result is either 200 or 404
_ = rec.Code
}()
}
wg.Wait()
}
func TestDynamicServeMuxHandleUsesHandler(t *testing.T) {
dm := newDynamicServeMux()
var called bool
dm.Handle("/handler", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
}))
rec := httptest.NewRecorder()
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/handler", nil))
if !called {
t.Fatal("handler was not called")
}
}
+80 -24
View File
@@ -83,7 +83,7 @@ type Manager struct {
config *config.Config
mediaStore media.MediaStore
dispatchTask *asyncTask
mux *http.ServeMux
mux *dynamicServeMux
httpServer *http.Server
mu sync.RWMutex
placeholders sync.Map // "channel:chatID" → placeholderID (string)
@@ -436,7 +436,7 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
// It registers health endpoints from the health server and discovers channels
// that implement WebhookHandler and/or HealthChecker to register their handlers.
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
m.mux = http.NewServeMux()
m.mux = newDynamicServeMux()
// Register health endpoints
if healthServer != nil {
@@ -444,22 +444,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
}
// Discover and register webhook handlers and health checkers
for name, ch := range m.channels {
if wh, ok := ch.(WebhookHandler); ok {
m.mux.Handle(wh.WebhookPath(), wh)
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
"channel": name,
"path": wh.WebhookPath(),
})
}
if hc, ok := ch.(HealthChecker); ok {
m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler)
logger.InfoCF("channels", "Health endpoint registered", map[string]any{
"channel": name,
"path": hc.HealthPath(),
})
}
}
m.registerHTTPHandlersLocked()
m.httpServer = &http.Server{
Addr: addr,
@@ -469,6 +454,53 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
}
}
// registerHTTPHandlersLocked registers webhook and health-check handlers for
// all channels currently in m.channels. Caller must hold m.mu (or ensure
// exclusive access).
func (m *Manager) registerHTTPHandlersLocked() {
for name, ch := range m.channels {
m.registerChannelHTTPHandler(name, ch)
}
}
// registerChannelHTTPHandler registers the webhook/health handlers for a
// single channel onto m.mux.
func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
if wh, ok := ch.(WebhookHandler); ok {
m.mux.Handle(wh.WebhookPath(), wh)
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
"channel": name,
"path": wh.WebhookPath(),
})
}
if hc, ok := ch.(HealthChecker); ok {
m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler)
logger.InfoCF("channels", "Health endpoint registered", map[string]any{
"channel": name,
"path": hc.HealthPath(),
})
}
}
// unregisterChannelHTTPHandler removes the webhook/health handlers for a
// single channel from m.mux.
func (m *Manager) unregisterChannelHTTPHandler(name string, ch Channel) {
if wh, ok := ch.(WebhookHandler); ok {
m.mux.Unhandle(wh.WebhookPath())
logger.InfoCF("channels", "Webhook handler unregistered", map[string]any{
"channel": name,
"path": wh.WebhookPath(),
})
}
if hc, ok := ch.(HealthChecker); ok {
m.mux.Unhandle(hc.HealthPath())
logger.InfoCF("channels", "Health endpoint unregistered", map[string]any{
"channel": name,
"path": hc.HealthPath(),
})
}
}
func (m *Manager) StartAll(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
@@ -984,8 +1016,17 @@ func (m *Manager) GetEnabledChannels() []string {
func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
// Save old config so we can revert on error.
oldConfig := m.config
// Update config early: initChannel uses m.config via factory(m.config, m.bus).
m.config = cfg
list := toChannelHashes(cfg)
added, removed := compareChannels(m.channelHashes, list)
deferFuncs := make([]func(), 0, len(removed)+len(added))
for _, name := range removed {
// Stop all channels
channel := m.channels[name]
@@ -998,20 +1039,24 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
"error": err.Error(),
})
}
go func() {
deferFuncs = append(deferFuncs, func() {
m.UnregisterChannel(name)
}()
})
}
dispatchCtx, cancel := context.WithCancel(ctx)
m.dispatchTask = &asyncTask{cancel: cancel}
cc, err := toChannelConfig(cfg, added)
if err != nil {
logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err))
m.config = oldConfig
cancel()
return err
}
err = m.initChannels(cc)
if err != nil {
logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err))
m.config = oldConfig
cancel()
return err
}
for _, name := range added {
@@ -1031,13 +1076,18 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
m.workers[name] = w
go m.runWorker(dispatchCtx, name, w)
go m.runMediaWorker(dispatchCtx, name, w)
go func() {
deferFuncs = append(deferFuncs, func() {
m.RegisterChannel(name, channel)
}()
})
}
m.config = cfg
m.channelHashes = toChannelHashes(cfg)
// Commit hashes only on full success.
m.channelHashes = list
go func() {
for _, f := range deferFuncs {
f()
}
}()
return nil
}
@@ -1045,11 +1095,17 @@ func (m *Manager) RegisterChannel(name string, channel Channel) {
m.mu.Lock()
defer m.mu.Unlock()
m.channels[name] = channel
if m.mux != nil {
m.registerChannelHTTPHandler(name, channel)
}
}
func (m *Manager) UnregisterChannel(name string) {
m.mu.Lock()
defer m.mu.Unlock()
if ch, ok := m.channels[name]; ok && m.mux != nil {
m.unregisterChannelHTTPHandler(name, ch)
}
if w, ok := m.workers[name]; ok && w != nil {
close(w.queue)
<-w.done
+5 -16
View File
@@ -490,12 +490,13 @@ func restartServices(
}
al.SetMediaStore(runningServices.MediaStore)
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
if err != nil {
return fmt.Errorf("error recreating channel manager: %w", err)
}
al.SetChannelManager(runningServices.ChannelManager)
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
return fmt.Errorf("error reload channels: %w", err)
}
fmt.Println(" ✓ Channels restarted.")
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
@@ -503,18 +504,6 @@ func restartServices(
fmt.Println(" ⚠ Warning: No channels enabled")
}
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
// Reuse existing HealthServer to preserve reloadFunc
if runningServices.HealthServer == nil {
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
}
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
return fmt.Errorf("error reload channels: %w", err)
}
fmt.Println(" ✓ Channels restarted.")
stateManager := state.NewManager(cfg.WorkspacePath())
runningServices.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
+9 -1
View File
@@ -198,9 +198,17 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
})
}
// HandlerMux is the interface for registering HTTP handlers, used by
// RegisterOnMux so that callers can pass any mux implementation
// (e.g. *http.ServeMux or a custom dynamic mux).
type HandlerMux interface {
Handle(pattern string, handler http.Handler)
HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request))
}
// RegisterOnMux registers /health, /ready and /reload handlers onto the given mux.
// This allows the health endpoints to be served by a shared HTTP server.
func (s *Server) RegisterOnMux(mux *http.ServeMux) {
func (s *Server) RegisterOnMux(mux HandlerMux) {
mux.HandleFunc("/health", s.healthHandler)
mux.HandleFunc("/ready", s.readyHandler)
mux.HandleFunc("/reload", s.reloadHandler)
+4
View File
@@ -353,6 +353,10 @@ func WarnCF(component string, message string, fields map[string]any) {
logMessage(WARN, component, message, fields)
}
func Warnf(message string, ss ...any) {
logMessage(WARN, "", fmt.Sprintf(message, ss...), nil)
}
func Error(message string) {
logMessage(ERROR, "", message, nil)
}