From 2a6ade0fe46d3f2617d1e28742080cfd23361c67 Mon Sep 17 00:00:00 2001 From: Cytown Date: Thu, 19 Mar 2026 13:42:36 +0800 Subject: [PATCH] feat: add /reload to gateway api and command (#1725) * feat: add /reload to gateway api and command * prevent duplicate reload request in same time --- pkg/agent/loop.go | 12 ++++++ pkg/commands/builtin.go | 1 + pkg/commands/cmd_reload.go | 20 ++++++++++ pkg/commands/runtime.go | 1 + pkg/gateway/gateway.go | 77 ++++++++++++++++++++++++++++++++++---- pkg/health/server.go | 53 +++++++++++++++++++++++--- 6 files changed, 150 insertions(+), 14 deletions(-) create mode 100644 pkg/commands/cmd_reload.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 33da33e92..a6eccc3fe 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -49,6 +49,7 @@ type AgentLoop struct { cmdRegistry *commands.Registry mcp mcpRuntime mu sync.RWMutex + reloadFunc func() error // Track active requests for safe provider cleanup activeRequests sync.WaitGroup } @@ -498,6 +499,11 @@ func (al *AgentLoop) SetTranscriber(t voice.Transcriber) { al.transcriber = t } +// SetReloadFunc sets the callback function for triggering config reload. +func (al *AgentLoop) SetReloadFunc(fn func() error) { + al.reloadFunc = fn +} + var audioAnnotationRe = regexp.MustCompile(`\[(voice|audio)(?::[^\]]*)?\]`) // transcribeAudioInMessage resolves audio media refs, transcribes them, and @@ -1931,6 +1937,12 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt return nil }, } + rt.ReloadConfig = func() error { + if al.reloadFunc == nil { + return fmt.Errorf("reload not configured") + } + return al.reloadFunc() + } if agent != nil { rt.GetModelInfo = func() (string, string) { return agent.Model, cfg.Agents.Defaults.Provider diff --git a/pkg/commands/builtin.go b/pkg/commands/builtin.go index aed6a1874..6d9ece82f 100644 --- a/pkg/commands/builtin.go +++ b/pkg/commands/builtin.go @@ -13,5 +13,6 @@ func BuiltinDefinitions() []Definition { switchCommand(), checkCommand(), clearCommand(), + reloadCommand(), } } diff --git a/pkg/commands/cmd_reload.go b/pkg/commands/cmd_reload.go new file mode 100644 index 000000000..07ab44016 --- /dev/null +++ b/pkg/commands/cmd_reload.go @@ -0,0 +1,20 @@ +package commands + +import "context" + +func reloadCommand() Definition { + return Definition{ + Name: "reload", + Description: "Reload the configuration file", + Usage: "/reload", + Handler: func(_ context.Context, req Request, rt *Runtime) error { + if rt == nil || rt.ReloadConfig == nil { + return req.Reply(unavailableMsg) + } + if err := rt.ReloadConfig(); err != nil { + return req.Reply("Failed to reload configuration: " + err.Error()) + } + return req.Reply("Config reload triggered!") + }, + } +} diff --git a/pkg/commands/runtime.go b/pkg/commands/runtime.go index 037184686..84f775808 100644 --- a/pkg/commands/runtime.go +++ b/pkg/commands/runtime.go @@ -14,4 +14,5 @@ type Runtime struct { SwitchModel func(value string) (oldModel string, err error) SwitchChannel func(value string) error ClearHistory func() error + ReloadConfig func() error } diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 6745d1748..ee7815fe2 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -7,6 +7,7 @@ import ( "os/signal" "path/filepath" "sync" + "sync/atomic" "syscall" "time" @@ -54,6 +55,8 @@ type services struct { ChannelManager *channels.Manager DeviceService *devices.Service HealthServer *health.Server + manualReloadChan chan struct{} + reloading atomic.Bool } type startupBlockedProvider struct { @@ -117,6 +120,25 @@ func Run(debug bool, configPath string, allowEmptyStartup bool) error { return err } + // Setup manual reload channel for /reload endpoint + manualReloadChan := make(chan struct{}, 1) + runningServices.manualReloadChan = manualReloadChan + reloadTrigger := func() error { + if !runningServices.reloading.CompareAndSwap(false, true) { + return fmt.Errorf("reload already in progress") + } + select { + case manualReloadChan <- struct{}{}: + return nil + default: + // Should not happen, but reset flag if channel is full + runningServices.reloading.Store(false) + return fmt.Errorf("reload already queued") + } + } + runningServices.HealthServer.SetReloadFunc(reloadTrigger) + agentLoop.SetReloadFunc(reloadTrigger) + fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port) fmt.Println("Press Ctrl+C to stop") @@ -143,14 +165,50 @@ func Run(debug bool, configPath string, allowEmptyStartup bool) error { shutdownGateway(runningServices, agentLoop, provider, true) return nil case newCfg := <-configReloadChan: - err := handleConfigReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup) + if !runningServices.reloading.CompareAndSwap(false, true) { + logger.Warn("Config reload skipped: another reload is in progress") + continue + } + err := executeReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup) if err != nil { logger.Errorf("Config reload failed: %v", err) } + case <-manualReloadChan: + logger.Info("Manual reload triggered via /reload endpoint") + newCfg, err := config.LoadConfig(configPath) + if err != nil { + logger.Errorf("Error loading config for manual reload: %v", err) + runningServices.reloading.Store(false) + continue + } + if err = newCfg.ValidateModelList(); err != nil { + logger.Errorf("Config validation failed: %v", err) + runningServices.reloading.Store(false) + continue + } + err = executeReload(ctx, agentLoop, newCfg, &provider, runningServices, msgBus, allowEmptyStartup) + if err != nil { + logger.Errorf("Manual reload failed: %v", err) + } else { + logger.Info("Manual reload completed successfully") + } } } } +func executeReload( + ctx context.Context, + agentLoop *agent.AgentLoop, + newCfg *config.Config, + provider *providers.LLMProvider, + runningServices *services, + msgBus *bus.MessageBus, + allowEmptyStartup bool, +) error { + defer runningServices.reloading.Store(false) + return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup) +} + func createStartupProvider( cfg *config.Config, allowEmptyStartup bool, @@ -245,7 +303,11 @@ func setupAndStartServices( return nil, fmt.Errorf("error starting channels: %w", err) } - fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + fmt.Printf( + "✓ Health endpoints available at http://%s:%d/health, /ready and /reload (POST)\n", + cfg.Gateway.Host, + cfg.Gateway.Port, + ) stateManager := state.NewManager(cfg.WorkspacePath()) runningServices.DeviceService = devices.NewService(devices.Config{ @@ -426,17 +488,16 @@ func restartServices( } addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) - runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + // Reuse existing HealthServer to preserve reloadFunc + if runningServices.HealthServer == nil { + runningServices.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + } runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { return fmt.Errorf("error restarting channels: %w", err) } - fmt.Printf( - " ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n", - cfg.Gateway.Host, - cfg.Gateway.Port, - ) + fmt.Println(" ✓ Channels restarted.") stateManager := state.NewManager(cfg.WorkspacePath()) runningServices.DeviceService = devices.NewService(devices.Config{ diff --git a/pkg/health/server.go b/pkg/health/server.go index b9ee9f496..fe20e4b94 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -12,11 +12,12 @@ import ( ) type Server struct { - server *http.Server - mu sync.RWMutex - ready bool - checks map[string]Check - startTime time.Time + server *http.Server + mu sync.RWMutex + ready bool + checks map[string]Check + startTime time.Time + reloadFunc func() error } type Check struct { @@ -43,6 +44,7 @@ func NewServer(host string, port int) *Server { mux.HandleFunc("/health", s.healthHandler) mux.HandleFunc("/ready", s.readyHandler) + mux.HandleFunc("/reload", s.reloadHandler) addr := fmt.Sprintf("%s:%d", host, port) s.server = &http.Server{ @@ -106,6 +108,44 @@ func (s *Server) RegisterCheck(name string, checkFn func() (bool, string)) { } } +// SetReloadFunc sets the callback function for config reload. +func (s *Server) SetReloadFunc(fn func() error) { + s.mu.Lock() + defer s.mu.Unlock() + s.reloadFunc = fn +} + +func (s *Server) reloadHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusMethodNotAllowed) + json.NewEncoder(w).Encode(map[string]string{"error": "method not allowed, use POST"}) + return + } + + s.mu.Lock() + reloadFunc := s.reloadFunc + s.mu.Unlock() + + if reloadFunc == nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + json.NewEncoder(w).Encode(map[string]string{"error": "reload not configured"}) + return + } + + if err := reloadFunc(); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]string{"status": "reload triggered"}) +} + func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) @@ -158,11 +198,12 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { }) } -// RegisterOnMux registers /health and /ready handlers onto the given mux. +// RegisterOnMux registers /health, /ready and /reload handlers onto the given mux. // This allows the health endpoints to be served by a shared HTTP server. func (s *Server) RegisterOnMux(mux *http.ServeMux) { mux.HandleFunc("/health", s.healthHandler) mux.HandleFunc("/ready", s.readyHandler) + mux.HandleFunc("/reload", s.reloadHandler) } func statusString(ok bool) string {