mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix gateway reload will cause pico stop working issue (#2082)
* fix gateway reload will cause pico stop working issue * fix for review
This commit is contained in:
@@ -2321,7 +2321,7 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
|
||||
if outbound.Content != "thinking trace" {
|
||||
t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("expected reasoning content to be published to reasoning channel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// dynamicServeMux is an http.Handler that supports dynamic registration
|
||||
// and unregistration of handlers without recreating the server.
|
||||
type dynamicServeMux struct {
|
||||
mu sync.RWMutex
|
||||
handlers map[string]http.Handler
|
||||
}
|
||||
|
||||
func newDynamicServeMux() *dynamicServeMux {
|
||||
return &dynamicServeMux{
|
||||
handlers: make(map[string]http.Handler),
|
||||
}
|
||||
}
|
||||
|
||||
// Handle registers the handler for the given pattern.
|
||||
func (dm *dynamicServeMux) Handle(pattern string, handler http.Handler) {
|
||||
dm.mu.Lock()
|
||||
defer dm.mu.Unlock()
|
||||
dm.handlers[pattern] = handler
|
||||
}
|
||||
|
||||
// HandleFunc registers the handler function for the given pattern.
|
||||
func (dm *dynamicServeMux) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
|
||||
dm.Handle(pattern, http.HandlerFunc(handler))
|
||||
}
|
||||
|
||||
// Unhandle removes the handler for the given pattern.
|
||||
func (dm *dynamicServeMux) Unhandle(pattern string) {
|
||||
dm.mu.Lock()
|
||||
defer dm.mu.Unlock()
|
||||
delete(dm.handlers, pattern)
|
||||
}
|
||||
|
||||
// ServeHTTP dispatches the request to the handler whose pattern best matches
|
||||
// the request URL path. It supports both exact path matches and subtree
|
||||
// (trailing-slash) prefix matches, choosing the longest prefix on collision.
|
||||
func (dm *dynamicServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
dm.mu.RLock()
|
||||
defer dm.mu.RUnlock()
|
||||
|
||||
path := r.URL.Path
|
||||
|
||||
// Exact match first.
|
||||
if h, ok := dm.handlers[path]; ok {
|
||||
h.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Longest subtree prefix match (patterns ending with "/").
|
||||
var bestLen int
|
||||
var bestHandler http.Handler
|
||||
for pattern, handler := range dm.handlers {
|
||||
if strings.HasSuffix(pattern, "/") && strings.HasPrefix(path, pattern) {
|
||||
if len(pattern) > bestLen {
|
||||
bestLen = len(pattern)
|
||||
bestHandler = handler
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bestHandler != nil {
|
||||
bestHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDynamicServeMuxExactMatch(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
dm.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/health", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxSubtreePrefixMatch(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
dm.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
})
|
||||
|
||||
for _, path := range []string{"/api/", "/api/v1", "/api/v1/resource"} {
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, path, nil))
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("path %q: expected 201, got %d", path, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxExactOverPrefix(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
dm.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
dm.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
})
|
||||
|
||||
// Exact match wins
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("exact match: expected 200, got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Prefix match for sub-paths
|
||||
rec = httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/v1", nil))
|
||||
if rec.Code != http.StatusCreated {
|
||||
t.Fatalf("prefix match: expected 201, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxLongestPrefixWins(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
dm.HandleFunc("/a/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
dm.HandleFunc("/a/b/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/a/b/c", nil))
|
||||
if rec.Code != http.StatusAccepted {
|
||||
t.Fatalf("longest prefix: expected 202, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxNotFound(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/nonexistent", nil))
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("expected 404, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxUnhandle(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
dm.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Verify it works before removal
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("before unhandle: expected 200, got %d", rec.Code)
|
||||
}
|
||||
|
||||
// Remove and verify 404
|
||||
dm.Unhandle("/test")
|
||||
rec = httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/test", nil))
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("after unhandle: expected 404, got %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxConcurrent(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
dm.HandleFunc("/static", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
|
||||
// Concurrent Handle/Unhandle
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
pattern := "/concurrent"
|
||||
if i%2 == 0 {
|
||||
dm.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
})
|
||||
} else {
|
||||
dm.Unhandle(pattern)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent ServeHTTP
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/static", nil))
|
||||
// Should not panic; result is either 200 or 404
|
||||
_ = rec.Code
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestDynamicServeMuxHandleUsesHandler(t *testing.T) {
|
||||
dm := newDynamicServeMux()
|
||||
|
||||
var called bool
|
||||
dm.Handle("/handler", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
dm.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/handler", nil))
|
||||
if !called {
|
||||
t.Fatal("handler was not called")
|
||||
}
|
||||
}
|
||||
+80
-24
@@ -83,7 +83,7 @@ type Manager struct {
|
||||
config *config.Config
|
||||
mediaStore media.MediaStore
|
||||
dispatchTask *asyncTask
|
||||
mux *http.ServeMux
|
||||
mux *dynamicServeMux
|
||||
httpServer *http.Server
|
||||
mu sync.RWMutex
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
@@ -436,7 +436,7 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
// It registers health endpoints from the health server and discovers channels
|
||||
// that implement WebhookHandler and/or HealthChecker to register their handlers.
|
||||
func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
m.mux = http.NewServeMux()
|
||||
m.mux = newDynamicServeMux()
|
||||
|
||||
// Register health endpoints
|
||||
if healthServer != nil {
|
||||
@@ -444,22 +444,7 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
}
|
||||
|
||||
// Discover and register webhook handlers and health checkers
|
||||
for name, ch := range m.channels {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Handle(wh.WebhookPath(), wh)
|
||||
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
})
|
||||
}
|
||||
if hc, ok := ch.(HealthChecker); ok {
|
||||
m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler)
|
||||
logger.InfoCF("channels", "Health endpoint registered", map[string]any{
|
||||
"channel": name,
|
||||
"path": hc.HealthPath(),
|
||||
})
|
||||
}
|
||||
}
|
||||
m.registerHTTPHandlersLocked()
|
||||
|
||||
m.httpServer = &http.Server{
|
||||
Addr: addr,
|
||||
@@ -469,6 +454,53 @@ func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) {
|
||||
}
|
||||
}
|
||||
|
||||
// registerHTTPHandlersLocked registers webhook and health-check handlers for
|
||||
// all channels currently in m.channels. Caller must hold m.mu (or ensure
|
||||
// exclusive access).
|
||||
func (m *Manager) registerHTTPHandlersLocked() {
|
||||
for name, ch := range m.channels {
|
||||
m.registerChannelHTTPHandler(name, ch)
|
||||
}
|
||||
}
|
||||
|
||||
// registerChannelHTTPHandler registers the webhook/health handlers for a
|
||||
// single channel onto m.mux.
|
||||
func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Handle(wh.WebhookPath(), wh)
|
||||
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
})
|
||||
}
|
||||
if hc, ok := ch.(HealthChecker); ok {
|
||||
m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler)
|
||||
logger.InfoCF("channels", "Health endpoint registered", map[string]any{
|
||||
"channel": name,
|
||||
"path": hc.HealthPath(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// unregisterChannelHTTPHandler removes the webhook/health handlers for a
|
||||
// single channel from m.mux.
|
||||
func (m *Manager) unregisterChannelHTTPHandler(name string, ch Channel) {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Unhandle(wh.WebhookPath())
|
||||
logger.InfoCF("channels", "Webhook handler unregistered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
})
|
||||
}
|
||||
if hc, ok := ch.(HealthChecker); ok {
|
||||
m.mux.Unhandle(hc.HealthPath())
|
||||
logger.InfoCF("channels", "Health endpoint unregistered", map[string]any{
|
||||
"channel": name,
|
||||
"path": hc.HealthPath(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) StartAll(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -984,8 +1016,17 @@ func (m *Manager) GetEnabledChannels() []string {
|
||||
func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Save old config so we can revert on error.
|
||||
oldConfig := m.config
|
||||
|
||||
// Update config early: initChannel uses m.config via factory(m.config, m.bus).
|
||||
m.config = cfg
|
||||
|
||||
list := toChannelHashes(cfg)
|
||||
added, removed := compareChannels(m.channelHashes, list)
|
||||
|
||||
deferFuncs := make([]func(), 0, len(removed)+len(added))
|
||||
for _, name := range removed {
|
||||
// Stop all channels
|
||||
channel := m.channels[name]
|
||||
@@ -998,20 +1039,24 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
deferFuncs = append(deferFuncs, func() {
|
||||
m.UnregisterChannel(name)
|
||||
}()
|
||||
})
|
||||
}
|
||||
dispatchCtx, cancel := context.WithCancel(ctx)
|
||||
m.dispatchTask = &asyncTask{cancel: cancel}
|
||||
cc, err := toChannelConfig(cfg, added)
|
||||
if err != nil {
|
||||
logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err))
|
||||
m.config = oldConfig
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
err = m.initChannels(cc)
|
||||
if err != nil {
|
||||
logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err))
|
||||
m.config = oldConfig
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
for _, name := range added {
|
||||
@@ -1031,13 +1076,18 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
go func() {
|
||||
deferFuncs = append(deferFuncs, func() {
|
||||
m.RegisterChannel(name, channel)
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
m.config = cfg
|
||||
m.channelHashes = toChannelHashes(cfg)
|
||||
// Commit hashes only on full success.
|
||||
m.channelHashes = list
|
||||
go func() {
|
||||
for _, f := range deferFuncs {
|
||||
f()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1045,11 +1095,17 @@ func (m *Manager) RegisterChannel(name string, channel Channel) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.channels[name] = channel
|
||||
if m.mux != nil {
|
||||
m.registerChannelHTTPHandler(name, channel)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) UnregisterChannel(name string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if ch, ok := m.channels[name]; ok && m.mux != nil {
|
||||
m.unregisterChannelHTTPHandler(name, ch)
|
||||
}
|
||||
if w, ok := m.workers[name]; ok && w != nil {
|
||||
close(w.queue)
|
||||
<-w.done
|
||||
|
||||
+5
-16
@@ -490,12 +490,13 @@ func restartServices(
|
||||
}
|
||||
al.SetMediaStore(runningServices.MediaStore)
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error recreating channel manager: %w", err)
|
||||
}
|
||||
al.SetChannelManager(runningServices.ChannelManager)
|
||||
|
||||
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
|
||||
return fmt.Errorf("error reload channels: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Channels restarted.")
|
||||
|
||||
enabledChannels := runningServices.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
|
||||
@@ -503,18 +504,6 @@ func restartServices(
|
||||
fmt.Println(" ⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
// Reuse existing HealthServer to preserve reloadFunc
|
||||
if runningServices.HealthServer == nil {
|
||||
runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
}
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
|
||||
return fmt.Errorf("error reload channels: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Channels restarted.")
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
runningServices.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
|
||||
@@ -198,9 +198,17 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
// HandlerMux is the interface for registering HTTP handlers, used by
|
||||
// RegisterOnMux so that callers can pass any mux implementation
|
||||
// (e.g. *http.ServeMux or a custom dynamic mux).
|
||||
type HandlerMux interface {
|
||||
Handle(pattern string, handler http.Handler)
|
||||
HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request))
|
||||
}
|
||||
|
||||
// RegisterOnMux registers /health, /ready and /reload handlers onto the given mux.
|
||||
// This allows the health endpoints to be served by a shared HTTP server.
|
||||
func (s *Server) RegisterOnMux(mux *http.ServeMux) {
|
||||
func (s *Server) RegisterOnMux(mux HandlerMux) {
|
||||
mux.HandleFunc("/health", s.healthHandler)
|
||||
mux.HandleFunc("/ready", s.readyHandler)
|
||||
mux.HandleFunc("/reload", s.reloadHandler)
|
||||
|
||||
@@ -353,6 +353,10 @@ func WarnCF(component string, message string, fields map[string]any) {
|
||||
logMessage(WARN, component, message, fields)
|
||||
}
|
||||
|
||||
func Warnf(message string, ss ...any) {
|
||||
logMessage(WARN, "", fmt.Sprintf(message, ss...), nil)
|
||||
}
|
||||
|
||||
func Error(message string) {
|
||||
logMessage(ERROR, "", message, nil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user