make gateway aware of config.json change (#1187)

* make gateway aware of config.json change

* fix according to code review

* fix lint

* fix review comment

* fix for review

* refactor to fix review

* fix for review

* fix for review
This commit is contained in:
Cytown
2026-03-13 14:27:46 +08:00
committed by GitHub
parent dfa36f39cb
commit 9676e51e89
4 changed files with 640 additions and 95 deletions
+453 -67
View File
@@ -3,10 +3,10 @@ package gateway
import (
"context"
"fmt"
"log"
"os"
"os/signal"
"path/filepath"
"sync"
"time"
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
@@ -41,12 +41,31 @@ import (
"github.com/sipeed/picoclaw/pkg/voice"
)
// Timeout constants for service operations
const (
serviceRestartTimeout = 30 * time.Second
serviceShutdownTimeout = 30 * time.Second
providerReloadTimeout = 30 * time.Second
gracefulShutdownTimeout = 15 * time.Second
)
// gatewayServices holds references to all running services
type gatewayServices struct {
CronService *cron.CronService
HeartbeatService *heartbeat.HeartbeatService
MediaStore media.MediaStore
ChannelManager *channels.Manager
DeviceService *devices.Service
HealthServer *health.Server
}
func gatewayCmd(debug bool) error {
if debug {
logger.SetLevel(logger.DEBUG)
fmt.Println("🔍 Debug mode enabled")
}
configPath := internal.GetConfigPath()
cfg, err := internal.LoadConfig()
if err != nil {
return fmt.Errorf("error loading config: %w", err)
@@ -83,9 +102,55 @@ func gatewayCmd(debug bool) error {
"skills_available": skillsInfo["available"],
})
// Setup and start all services
services, err := setupAndStartServices(cfg, agentLoop, msgBus)
if err != nil {
return err
}
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
fmt.Println("Press Ctrl+C to stop")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go agentLoop.Run(ctx)
// Setup config file watcher for hot reload
configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
defer stopWatch()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
// Main event loop - wait for signals or config changes
for {
select {
case <-sigChan:
logger.Info("Shutting down...")
shutdownGateway(services, agentLoop, provider, true)
return nil
case newCfg := <-configReloadChan:
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus)
if err != nil {
logger.Errorf("Config reload failed: %v", err)
}
}
}
}
// setupAndStartServices initializes and starts all services
func setupAndStartServices(
cfg *config.Config,
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
) (*gatewayServices, error) {
services := &gatewayServices{}
// Setup cron tool and service
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
cronService := setupCronTool(
services.CronService = setupCronTool(
agentLoop,
msgBus,
cfg.WorkspacePath(),
@@ -93,20 +158,26 @@ func gatewayCmd(debug bool) error {
execTimeout,
cfg,
)
if err := services.CronService.Start(); err != nil {
return nil, fmt.Errorf("error starting cron service: %w", err)
}
fmt.Println("✓ Cron service started")
heartbeatService := heartbeat.NewHeartbeatService(
// Setup heartbeat service
services.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
heartbeatService.SetBus(msgBus)
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
services.HeartbeatService.SetBus(msgBus)
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
// Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
var response string
var err error
response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
@@ -118,24 +189,36 @@ func gatewayCmd(debug bool) error {
// sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
})
if err := services.HeartbeatService.Start(); err != nil {
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
}
fmt.Println("✓ Heartbeat service started")
// Create media store for file lifecycle management with TTL cleanup
mediaStore := media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
mediaStore.Start()
// Start the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
channelManager, err := channels.NewManager(cfg, msgBus, mediaStore)
// Create channel manager
var err error
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
if err != nil {
mediaStore.Stop()
return fmt.Errorf("error creating channel manager: %w", err)
// Stop the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
return nil, fmt.Errorf("error creating channel manager: %w", err)
}
// Inject channel manager and media store into agent loop
agentLoop.SetChannelManager(channelManager)
agentLoop.SetMediaStore(mediaStore)
agentLoop.SetChannelManager(services.ChannelManager)
agentLoop.SetMediaStore(services.MediaStore)
// Wire up voice transcription if a supported provider is configured.
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
@@ -143,83 +226,386 @@ func gatewayCmd(debug bool) error {
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
}
enabledChannels := channelManager.GetEnabledChannels()
enabledChannels := services.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println("⚠ Warning: No channels enabled")
}
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
fmt.Println("Press Ctrl+C to stop")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := cronService.Start(); err != nil {
fmt.Printf("Error starting cron service: %v\n", err)
}
fmt.Println("✓ Cron service started")
if err := heartbeatService.Start(); err != nil {
fmt.Printf("Error starting heartbeat service: %v\n", err)
}
fmt.Println("✓ Heartbeat service started")
stateManager := state.NewManager(cfg.WorkspacePath())
deviceService := devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
deviceService.SetBus(msgBus)
if err := deviceService.Start(ctx); err != nil {
fmt.Printf("Error starting device service: %v\n", err)
} else if cfg.Devices.Enabled {
fmt.Println("✓ Device event service started")
}
// Setup shared HTTP server with health endpoints and webhook handlers
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
channelManager.SetupHTTPServer(addr, healthServer)
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
if err := channelManager.StartAll(ctx); err != nil {
fmt.Printf("Error starting channels: %v\n", err)
return err
if err := services.ChannelManager.StartAll(context.Background()); err != nil {
return nil, fmt.Errorf("error starting channels: %w", err)
}
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
go agentLoop.Run(ctx)
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt)
<-sigChan
fmt.Println("\nShutting down...")
if cp, ok := provider.(providers.StatefulProvider); ok {
cp.Close()
// Setup state manager and device service
stateManager := state.NewManager(cfg.WorkspacePath())
services.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
services.DeviceService.SetBus(msgBus)
if err := services.DeviceService.Start(context.Background()); err != nil {
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println("✓ Device event service started")
}
cancel()
msgBus.Close()
// Use a fresh context with timeout for graceful shutdown,
// since the original ctx is already canceled.
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
return services, nil
}
// stopAndCleanupServices stops all services and cleans up resources
func stopAndCleanupServices(
services *gatewayServices,
shutdownTimeout time.Duration,
) {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer shutdownCancel()
channelManager.StopAll(shutdownCtx)
deviceService.Stop()
heartbeatService.Stop()
cronService.Stop()
mediaStore.Stop()
if services.ChannelManager != nil {
services.ChannelManager.StopAll(shutdownCtx)
}
if services.DeviceService != nil {
services.DeviceService.Stop()
}
if services.HeartbeatService != nil {
services.HeartbeatService.Stop()
}
if services.CronService != nil {
services.CronService.Stop()
}
if services.MediaStore != nil {
// Stop the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
}
}
// shutdownGateway performs a complete gateway shutdown
func shutdownGateway(
services *gatewayServices,
agentLoop *agent.AgentLoop,
provider providers.LLMProvider,
fullShutdown bool,
) {
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
cp.Close()
}
stopAndCleanupServices(services, gracefulShutdownTimeout)
agentLoop.Stop()
agentLoop.Close()
fmt.Println("✓ Gateway stopped")
logger.Info("✓ Gateway stopped")
}
// handleConfigReload handles config file reload by stopping all services,
// reloading the provider and config, and restarting services with the new config.
func handleConfigReload(
ctx context.Context,
al *agent.AgentLoop,
newCfg *config.Config,
providerRef *providers.LLMProvider,
services *gatewayServices,
msgBus *bus.MessageBus,
) error {
logger.Info("🔄 Config file changed, reloading...")
newModel := newCfg.Agents.Defaults.ModelName
if newModel == "" {
newModel = newCfg.Agents.Defaults.Model
}
logger.Infof(" New model is '%s', recreating provider...", newModel)
// Stop all services before reloading
logger.Info(" Stopping all services...")
stopAndCleanupServices(services, serviceShutdownTimeout)
// Create new provider from updated config first to ensure validity
// This will use the correct API key and settings from newCfg.ModelList
newProvider, newModelID, err := providers.CreateProvider(newCfg)
if err != nil {
logger.Errorf(" ⚠ Error creating new provider: %v", err)
logger.Warn(" Attempting to restart services with old provider and config...")
// Try to restart services with old configuration
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error creating new provider: %w", err)
}
if newModelID != "" {
newCfg.Agents.Defaults.ModelName = newModelID
}
// Use the atomic reload method on AgentLoop to safely swap provider and config.
// This handles locking internally to prevent races with in-flight LLM calls
// and concurrent reads of registry/config while the swap occurs.
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
defer reloadCancel()
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
// Close the newly created provider since it wasn't adopted
if cp, ok := newProvider.(providers.StatefulProvider); ok {
cp.Close()
}
logger.Warn(" Attempting to restart services with old provider and config...")
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
}
return fmt.Errorf("error reloading agent loop: %w", err)
}
// Update local provider reference only after successful atomic reload
*providerRef = newProvider
// Restart all services with new config
logger.Info(" Restarting all services with new configuration...")
if err := restartServices(al, services, msgBus); err != nil {
logger.Errorf(" ⚠ Error restarting services: %v", err)
return fmt.Errorf("error restarting services: %w", err)
}
logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)")
return nil
}
// restartServices restarts all services after a config reload
func restartServices(
al *agent.AgentLoop,
services *gatewayServices,
msgBus *bus.MessageBus,
) error {
// Create an independent context with timeout for service restart
// This prevents cancellation from the main loop context during reload
ctx, cancel := context.WithTimeout(context.Background(), serviceRestartTimeout)
defer cancel()
// Get current config from agent loop (which has been updated if this is a reload)
cfg := al.GetConfig()
// Re-create and start cron service with new config
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
services.CronService = setupCronTool(
al,
msgBus,
cfg.WorkspacePath(),
cfg.Agents.Defaults.RestrictToWorkspace,
execTimeout,
cfg,
)
if err := services.CronService.Start(); err != nil {
return fmt.Errorf("error restarting cron service: %w", err)
}
fmt.Println(" ✓ Cron service restarted")
// Re-create and start heartbeat service with new config
services.HeartbeatService = heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
services.HeartbeatService.SetBus(msgBus)
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
var response string
var err error
response, err = al.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
return tools.SilentResult(response)
})
if err := services.HeartbeatService.Start(); err != nil {
return fmt.Errorf("error restarting heartbeat service: %w", err)
}
fmt.Println(" ✓ Heartbeat service restarted")
// Stop the old media store before creating a new one
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
// Re-create media store with new config
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
Enabled: cfg.Tools.MediaCleanup.Enabled,
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
})
// Start the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Start()
}
al.SetMediaStore(services.MediaStore)
// Re-create channel manager with new config
var err error
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
if err != nil {
// Stop the media store if it's a FileMediaStore with cleanup
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
fms.Stop()
}
return fmt.Errorf("error recreating channel manager: %w", err)
}
al.SetChannelManager(services.ChannelManager)
enabledChannels := services.ChannelManager.GetEnabledChannels()
if len(enabledChannels) > 0 {
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
} else {
fmt.Println(" ⚠ Warning: No channels enabled")
}
// Setup HTTP server with new config
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
if err := services.ChannelManager.StartAll(ctx); err != nil {
return fmt.Errorf("error restarting channels: %w", err)
}
fmt.Printf(
" ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n",
cfg.Gateway.Host,
cfg.Gateway.Port,
)
// Re-create device service with new config
stateManager := state.NewManager(cfg.WorkspacePath())
services.DeviceService = devices.NewService(devices.Config{
Enabled: cfg.Devices.Enabled,
MonitorUSB: cfg.Devices.MonitorUSB,
}, stateManager)
services.DeviceService.SetBus(msgBus)
if err := services.DeviceService.Start(ctx); err != nil {
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
} else if cfg.Devices.Enabled {
fmt.Println(" ✓ Device event service restarted")
}
// Wire up voice transcription with new config
transcriber := voice.DetectTranscriber(cfg)
al.SetTranscriber(transcriber) // This will set it to nil if disabled
if transcriber != nil {
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
} else {
logger.InfoCF("voice", "Transcription disabled", nil)
}
return nil
}
// setupConfigWatcherPolling sets up a simple polling-based config file watcher
// Returns a channel for config updates and a stop function
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
configChan := make(chan *config.Config, 1)
stop := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
// Get initial file info
lastModTime := getFileModTime(configPath)
lastSize := getFileSize(configPath)
ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
defer ticker.Stop()
for {
select {
case <-ticker.C:
currentModTime := getFileModTime(configPath)
currentSize := getFileSize(configPath)
// Check if file changed (modification time or size changed)
if currentModTime.After(lastModTime) || currentSize != lastSize {
if debug {
logger.Debugf("🔍 Config file change detected")
}
// Debounce - wait a bit to ensure file write is complete
time.Sleep(500 * time.Millisecond)
// Validate and load new config
newCfg, err := config.LoadConfig(configPath)
if err != nil {
logger.Errorf("⚠ Error loading new config: %v", err)
logger.Warn(" Using previous valid config")
continue
}
// Validate the new config
if err := newCfg.ValidateModelList(); err != nil {
logger.Errorf(" ⚠ New config validation failed: %v", err)
logger.Warn(" Using previous valid config")
continue
}
logger.Info("✓ Config file validated and loaded")
// Update last known state
lastModTime = currentModTime
lastSize = currentSize
// Send new config to main loop (non-blocking)
select {
case configChan <- newCfg:
default:
// Channel full, skip this update
logger.Warn("⚠ Previous config reload still in progress, skipping")
}
}
case <-stop:
return
}
}
}()
stopFunc := func() {
close(stop)
wg.Wait()
}
return configChan, stopFunc
}
// getFileModTime returns the modification time of a file, or zero time if file doesn't exist
func getFileModTime(path string) time.Time {
info, err := os.Stat(path)
if err != nil {
return time.Time{}
}
return info.ModTime()
}
// getFileSize returns the size of a file, or 0 if file doesn't exist
func getFileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {
return 0
}
return info.Size()
}
func setupCronTool(
agentLoop *agent.AgentLoop,
msgBus *bus.MessageBus,
@@ -239,7 +625,7 @@ func setupCronTool(
var err error
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
if err != nil {
log.Fatalf("Critical error during CronTool initialization: %v", err)
logger.Fatalf("Critical error during CronTool initialization: %v", err)
}
agentLoop.RegisterTool(cronTool)
+171 -28
View File
@@ -48,6 +48,9 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
}
// processOptions configures how a message is processed
@@ -239,6 +242,7 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
if err := al.ensureMCPInitialized(ctx); err != nil {
return err
}
@@ -278,7 +282,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// If so, skip publishing to avoid duplicate messages to the user.
// Use default agent's tools to check (message tool is shared).
alreadySent := false
defaultAgent := al.registry.GetDefaultAgent()
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent != nil {
if tool, ok := defaultAgent.Tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
@@ -331,12 +335,13 @@ func (al *AgentLoop) Close() {
}
}
al.registry.Close()
al.GetRegistry().Close()
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
for _, agentID := range al.registry.ListAgentIDs() {
if agent, ok := al.registry.GetAgent(agentID); ok {
registry := al.GetRegistry()
for _, agentID := range registry.ListAgentIDs() {
if agent, ok := registry.GetAgent(agentID); ok {
agent.Tools.Register(tool)
}
}
@@ -346,12 +351,123 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
al.channelManager = cm
}
// ReloadProviderAndConfig atomically swaps the provider and config with proper synchronization.
// It uses a context to allow timeout control from the caller.
// Returns an error if the reload fails or context is canceled.
func (al *AgentLoop) ReloadProviderAndConfig(
ctx context.Context,
provider providers.LLMProvider,
cfg *config.Config,
) error {
// Validate inputs
if provider == nil {
return fmt.Errorf("provider cannot be nil")
}
if cfg == nil {
return fmt.Errorf("config cannot be nil")
}
// Create new registry with updated config and provider
// Wrap in defer/recover to handle any panics gracefully
var registry *AgentRegistry
var panicErr error
done := make(chan struct{}, 1)
go func() {
defer func() {
if r := recover(); r != nil {
panicErr = fmt.Errorf("panic during registry creation: %v", r)
logger.ErrorCF("agent", "Panic during registry creation",
map[string]any{"panic": r})
}
close(done)
}()
registry = NewAgentRegistry(cfg, provider)
}()
// Wait for completion or context cancellation
select {
case <-done:
if registry == nil {
if panicErr != nil {
return fmt.Errorf("registry creation failed: %w", panicErr)
}
return fmt.Errorf("registry creation failed (nil result)")
}
case <-ctx.Done():
return fmt.Errorf("context canceled during registry creation: %w", ctx.Err())
}
// Check context again before proceeding
if err := ctx.Err(); err != nil {
return fmt.Errorf("context canceled after registry creation: %w", err)
}
// Ensure shared tools are re-registered on the new registry
registerSharedTools(cfg, al.bus, registry, provider)
// Atomically swap the config and registry under write lock
// This ensures readers see a consistent pair
al.mu.Lock()
oldRegistry := al.registry
// Store new values
al.cfg = cfg
al.registry = registry
// Also update fallback chain with new config
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
al.mu.Unlock()
// Close old provider after releasing the lock
// This prevents blocking readers while closing
if oldProvider, ok := extractProvider(oldRegistry); ok {
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
// Give in-flight requests a moment to complete
// Use a reasonable timeout that balances cleanup vs resource usage
select {
case <-time.After(100 * time.Millisecond):
stateful.Close()
case <-ctx.Done():
// Context canceled, close immediately but log warning
logger.WarnCF("agent", "Context canceled during provider cleanup, forcing close",
map[string]any{"error": ctx.Err()})
stateful.Close()
}
}
}
logger.InfoCF("agent", "Provider and config reloaded successfully",
map[string]any{
"model": cfg.Agents.Defaults.GetModelName(),
})
return nil
}
// GetRegistry returns the current registry (thread-safe)
func (al *AgentLoop) GetRegistry() *AgentRegistry {
al.mu.RLock()
defer al.mu.RUnlock()
return al.registry
}
// GetConfig returns the current config (thread-safe)
func (al *AgentLoop) GetConfig() *config.Config {
al.mu.RLock()
defer al.mu.RUnlock()
return al.cfg
}
// SetMediaStore injects a MediaStore for media lifecycle management.
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
al.mediaStore = s
// Propagate store to send_file tools in all agents.
al.registry.ForEachTool("send_file", func(t tools.Tool) {
registry := al.GetRegistry()
registry.ForEachTool("send_file", func(t tools.Tool) {
if sf, ok := t.(*tools.SendFileTool); ok {
sf.SetMediaStore(s)
}
@@ -540,7 +656,7 @@ func (al *AgentLoop) ProcessHeartbeat(
ctx context.Context,
content, channel, chatID string,
) (string, error) {
agent := al.registry.GetDefaultAgent()
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
}
@@ -636,7 +752,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
}
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
route := al.registry.ResolveRoute(routing.RouteInput{
registry := al.GetRegistry()
route := registry.ResolveRoute(routing.RouteInput{
Channel: msg.Channel,
AccountID: inboundMetadata(msg, metadataKeyAccountID),
Peer: extractPeer(msg),
@@ -645,9 +762,9 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv
TeamID: inboundMetadata(msg, metadataKeyTeamID),
})
agent, ok := al.registry.GetAgent(route.AgentID)
agent, ok := registry.GetAgent(route.AgentID)
if !ok {
agent = al.registry.GetDefaultAgent()
agent = registry.GetDefaultAgent()
}
if agent == nil {
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
@@ -709,7 +826,7 @@ func (al *AgentLoop) processSystemMessage(
}
// Use default agent for system messages
agent := al.registry.GetDefaultAgent()
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for system message")
}
@@ -765,7 +882,8 @@ func (al *AgentLoop) runAgentLoop(
)
// Resolve media:// refs to base64 data URLs (streaming)
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
cfg := al.GetConfig()
maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 2. Save user message to session
@@ -943,6 +1061,9 @@ func (al *AgentLoop) runLLMIteration(
}
callLLM := func() (*providers.LLMResponse, error) {
al.activeRequests.Add(1)
defer al.activeRequests.Done()
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
ctx,
@@ -1041,6 +1162,7 @@ func (al *AgentLoop) runLLMIteration(
map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"model": activeModel,
"error": err.Error(),
})
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
@@ -1392,7 +1514,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
func (al *AgentLoop) GetStartupInfo() map[string]any {
info := make(map[string]any)
agent := al.registry.GetDefaultAgent()
registry := al.GetRegistry()
agent := registry.GetDefaultAgent()
if agent == nil {
return info
}
@@ -1409,8 +1532,8 @@ func (al *AgentLoop) GetStartupInfo() map[string]any {
// Agents info
info["agents"] = map[string]any{
"count": len(al.registry.ListAgentIDs()),
"ids": al.registry.ListAgentIDs(),
"count": len(registry.ListAgentIDs()),
"ids": registry.ListAgentIDs(),
}
return info
@@ -1598,17 +1721,22 @@ func (al *AgentLoop) retryLLMCall(
var err error
for attempt := 0; attempt < maxRetries; attempt++ {
resp, err = agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": llmTemperature,
"prompt_cache_key": agent.ID,
},
)
al.activeRequests.Add(1)
resp, err = func() (*providers.LLMResponse, error) {
defer al.activeRequests.Done()
return agent.Provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil,
agent.Model,
map[string]any{
"max_tokens": agent.MaxTokens,
"temperature": llmTemperature,
"prompt_cache_key": agent.ID,
},
)
}()
if err == nil && resp != nil && resp.Content != "" {
return resp, nil
}
@@ -1741,9 +1869,11 @@ func (al *AgentLoop) handleCommand(
}
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
registry := al.GetRegistry()
cfg := al.GetConfig()
rt := &commands.Runtime{
Config: al.cfg,
ListAgentIDs: al.registry.ListAgentIDs,
Config: cfg,
ListAgentIDs: registry.ListAgentIDs,
ListDefinitions: al.cmdRegistry.Definitions,
GetEnabledChannels: func() []string {
if al.channelManager == nil {
@@ -1763,7 +1893,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
}
if agent != nil {
rt.GetModelInfo = func() (string, string) {
return agent.Model, al.cfg.Agents.Defaults.Provider
return agent.Model, cfg.Agents.Defaults.Provider
}
rt.SwitchModel = func(value string) (string, error) {
oldModel := agent.Model
@@ -1827,3 +1957,16 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
}
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
}
// Helper to extract provider from registry for cleanup
func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) {
if registry == nil {
return nil, false
}
// Get any agent to access the provider
defaultAgent := registry.GetDefaultAgent()
if defaultAgent == nil {
return nil, false
}
return defaultAgent.Provider, true
}
+12
View File
@@ -195,6 +195,10 @@ func DebugC(component string, message string) {
logMessage(DEBUG, component, message, nil)
}
func Debugf(message string, ss ...any) {
logMessage(DEBUG, "", fmt.Sprintf(message, ss...), nil)
}
func DebugF(message string, fields map[string]any) {
logMessage(DEBUG, "", message, fields)
}
@@ -215,6 +219,10 @@ func InfoF(message string, fields map[string]any) {
logMessage(INFO, "", message, fields)
}
func Infof(message string, ss ...any) {
logMessage(INFO, "", fmt.Sprintf(message, ss...), nil)
}
func InfoCF(component string, message string, fields map[string]any) {
logMessage(INFO, component, message, fields)
}
@@ -243,6 +251,10 @@ func ErrorC(component string, message string) {
logMessage(ERROR, component, message, nil)
}
func Errorf(message string, ss ...any) {
logMessage(ERROR, "", fmt.Sprintf(message, ss...), nil)
}
func ErrorF(message string, fields map[string]any) {
logMessage(ERROR, "", message, fields)
}
+4
View File
@@ -123,17 +123,21 @@ func TestLoggerHelperFunctions(t *testing.T) {
SetLevel(INFO)
Debug("This should not log")
Debugf("this should not log")
Info("This should log")
Warn("This should log")
Error("This should log")
InfoC("test", "Component message")
InfoF("Fields message", map[string]any{"key": "value"})
Infof("test from %v", "Infof")
WarnC("test", "Warning with component")
ErrorF("Error with fields", map[string]any{"error": "test"})
Errorf("test from %v", "Errorf")
SetLevel(DEBUG)
DebugC("test", "Debug with component")
Debugf("test from %v", "Debugf")
WarnF("Warning with fields", map[string]any{"key": "value"})
}