mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
make gateway aware of config.json change (#1187)
* make gateway aware of config.json change * fix according to code review * fix lint * fix review comment * fix for review * refactor to fix review * fix for review * fix for review
This commit is contained in:
@@ -3,10 +3,10 @@ package gateway
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/cmd/picoclaw/internal"
|
||||
@@ -41,12 +41,31 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
// Time budgets for the gateway's service lifecycle operations.
const (
	// serviceRestartTimeout is the time budget for restarting services after a
	// config reload.
	serviceRestartTimeout = 30 * time.Second

	// serviceShutdownTimeout is the time budget for stopping services ahead of
	// a config reload.
	serviceShutdownTimeout = 30 * time.Second

	// providerReloadTimeout is the time budget for swapping the provider and
	// config inside the agent loop.
	providerReloadTimeout = 30 * time.Second

	// gracefulShutdownTimeout is the time budget for stopping services when the
	// gateway itself shuts down.
	gracefulShutdownTimeout = 15 * time.Second
)
|
||||
|
||||
// gatewayServices holds references to all running services
|
||||
type gatewayServices struct {
|
||||
CronService *cron.CronService
|
||||
HeartbeatService *heartbeat.HeartbeatService
|
||||
MediaStore media.MediaStore
|
||||
ChannelManager *channels.Manager
|
||||
DeviceService *devices.Service
|
||||
HealthServer *health.Server
|
||||
}
|
||||
|
||||
func gatewayCmd(debug bool) error {
|
||||
if debug {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
fmt.Println("🔍 Debug mode enabled")
|
||||
}
|
||||
|
||||
configPath := internal.GetConfigPath()
|
||||
cfg, err := internal.LoadConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error loading config: %w", err)
|
||||
@@ -83,9 +102,55 @@ func gatewayCmd(debug bool) error {
|
||||
"skills_available": skillsInfo["available"],
|
||||
})
|
||||
|
||||
// Setup and start all services
|
||||
services, err := setupAndStartServices(cfg, agentLoop, msgBus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
// Setup config file watcher for hot reload
|
||||
configReloadChan, stopWatch := setupConfigWatcherPolling(configPath, debug)
|
||||
defer stopWatch()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
|
||||
// Main event loop - wait for signals or config changes
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
logger.Info("Shutting down...")
|
||||
shutdownGateway(services, agentLoop, provider, true)
|
||||
return nil
|
||||
|
||||
case newCfg := <-configReloadChan:
|
||||
err := handleConfigReload(ctx, agentLoop, newCfg, &provider, services, msgBus)
|
||||
if err != nil {
|
||||
logger.Errorf("Config reload failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setupAndStartServices initializes and starts all services
|
||||
func setupAndStartServices(
|
||||
cfg *config.Config,
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
) (*gatewayServices, error) {
|
||||
services := &gatewayServices{}
|
||||
|
||||
// Setup cron tool and service
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
cronService := setupCronTool(
|
||||
services.CronService = setupCronTool(
|
||||
agentLoop,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
@@ -93,20 +158,26 @@ func gatewayCmd(debug bool) error {
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err := services.CronService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting cron service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
// Setup heartbeat service
|
||||
services.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
heartbeatService.SetBus(msgBus)
|
||||
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
services.HeartbeatService.SetBus(msgBus)
|
||||
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
// Use cli:direct as fallback if no valid channel
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
// Use ProcessHeartbeat - no session history, each heartbeat is independent
|
||||
var response string
|
||||
var err error
|
||||
response, err = agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
@@ -118,24 +189,36 @@ func gatewayCmd(debug bool) error {
|
||||
// sent to user via processSystemMessage when the async task completes
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
if err := services.HeartbeatService.Start(); err != nil {
|
||||
return nil, fmt.Errorf("error starting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
// Create media store for file lifecycle management with TTL cleanup
|
||||
mediaStore := media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
mediaStore.Start()
|
||||
// Start the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
channelManager, err := channels.NewManager(cfg, msgBus, mediaStore)
|
||||
// Create channel manager
|
||||
var err error
|
||||
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
|
||||
if err != nil {
|
||||
mediaStore.Stop()
|
||||
return fmt.Errorf("error creating channel manager: %w", err)
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return nil, fmt.Errorf("error creating channel manager: %w", err)
|
||||
}
|
||||
|
||||
// Inject channel manager and media store into agent loop
|
||||
agentLoop.SetChannelManager(channelManager)
|
||||
agentLoop.SetMediaStore(mediaStore)
|
||||
agentLoop.SetChannelManager(services.ChannelManager)
|
||||
agentLoop.SetMediaStore(services.MediaStore)
|
||||
|
||||
// Wire up voice transcription if a supported provider is configured.
|
||||
if transcriber := voice.DetectTranscriber(cfg); transcriber != nil {
|
||||
@@ -143,83 +226,386 @@ func gatewayCmd(debug bool) error {
|
||||
logger.InfoCF("voice", "Transcription enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
}
|
||||
|
||||
enabledChannels := channelManager.GetEnabledChannels()
|
||||
enabledChannels := services.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf("✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println("⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Gateway started on %s:%d\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
fmt.Println("Press Ctrl+C to stop")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if err := cronService.Start(); err != nil {
|
||||
fmt.Printf("Error starting cron service: %v\n", err)
|
||||
}
|
||||
fmt.Println("✓ Cron service started")
|
||||
|
||||
if err := heartbeatService.Start(); err != nil {
|
||||
fmt.Printf("Error starting heartbeat service: %v\n", err)
|
||||
}
|
||||
fmt.Println("✓ Heartbeat service started")
|
||||
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
deviceService := devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
deviceService.SetBus(msgBus)
|
||||
if err := deviceService.Start(ctx); err != nil {
|
||||
fmt.Printf("Error starting device service: %v\n", err)
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
|
||||
// Setup shared HTTP server with health endpoints and webhook handlers
|
||||
healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
channelManager.SetupHTTPServer(addr, healthServer)
|
||||
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
|
||||
|
||||
if err := channelManager.StartAll(ctx); err != nil {
|
||||
fmt.Printf("Error starting channels: %v\n", err)
|
||||
return err
|
||||
if err := services.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return nil, fmt.Errorf("error starting channels: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
|
||||
go agentLoop.Run(ctx)
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
fmt.Println("\nShutting down...")
|
||||
if cp, ok := provider.(providers.StatefulProvider); ok {
|
||||
cp.Close()
|
||||
// Setup state manager and device service
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
services.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
services.DeviceService.SetBus(msgBus)
|
||||
if err := services.DeviceService.Start(context.Background()); err != nil {
|
||||
logger.ErrorCF("device", "Error starting device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println("✓ Device event service started")
|
||||
}
|
||||
cancel()
|
||||
msgBus.Close()
|
||||
|
||||
// Use a fresh context with timeout for graceful shutdown,
|
||||
// since the original ctx is already canceled.
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
return services, nil
|
||||
}
|
||||
|
||||
// stopAndCleanupServices stops all services and cleans up resources
|
||||
func stopAndCleanupServices(
|
||||
services *gatewayServices,
|
||||
shutdownTimeout time.Duration,
|
||||
) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
channelManager.StopAll(shutdownCtx)
|
||||
deviceService.Stop()
|
||||
heartbeatService.Stop()
|
||||
cronService.Stop()
|
||||
mediaStore.Stop()
|
||||
if services.ChannelManager != nil {
|
||||
services.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if services.DeviceService != nil {
|
||||
services.DeviceService.Stop()
|
||||
}
|
||||
if services.HeartbeatService != nil {
|
||||
services.HeartbeatService.Stop()
|
||||
}
|
||||
if services.CronService != nil {
|
||||
services.CronService.Stop()
|
||||
}
|
||||
if services.MediaStore != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shutdownGateway performs a complete gateway shutdown
|
||||
func shutdownGateway(
|
||||
services *gatewayServices,
|
||||
agentLoop *agent.AgentLoop,
|
||||
provider providers.LLMProvider,
|
||||
fullShutdown bool,
|
||||
) {
|
||||
if cp, ok := provider.(providers.StatefulProvider); ok && fullShutdown {
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(services, gracefulShutdownTimeout)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
fmt.Println("✓ Gateway stopped")
|
||||
|
||||
logger.Info("✓ Gateway stopped")
|
||||
}
|
||||
|
||||
// handleConfigReload handles config file reload by stopping all services,
|
||||
// reloading the provider and config, and restarting services with the new config.
|
||||
func handleConfigReload(
|
||||
ctx context.Context,
|
||||
al *agent.AgentLoop,
|
||||
newCfg *config.Config,
|
||||
providerRef *providers.LLMProvider,
|
||||
services *gatewayServices,
|
||||
msgBus *bus.MessageBus,
|
||||
) error {
|
||||
logger.Info("🔄 Config file changed, reloading...")
|
||||
|
||||
newModel := newCfg.Agents.Defaults.ModelName
|
||||
if newModel == "" {
|
||||
newModel = newCfg.Agents.Defaults.Model
|
||||
}
|
||||
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
// Stop all services before reloading
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(services, serviceShutdownTimeout)
|
||||
|
||||
// Create new provider from updated config first to ensure validity
|
||||
// This will use the correct API key and settings from newCfg.ModelList
|
||||
newProvider, newModelID, err := providers.CreateProvider(newCfg)
|
||||
if err != nil {
|
||||
logger.Errorf(" ⚠ Error creating new provider: %v", err)
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
// Try to restart services with old configuration
|
||||
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error creating new provider: %w", err)
|
||||
}
|
||||
|
||||
if newModelID != "" {
|
||||
newCfg.Agents.Defaults.ModelName = newModelID
|
||||
}
|
||||
|
||||
// Use the atomic reload method on AgentLoop to safely swap provider and config.
|
||||
// This handles locking internally to prevent races with in-flight LLM calls
|
||||
// and concurrent reads of registry/config while the swap occurs.
|
||||
reloadCtx, reloadCancel := context.WithTimeout(context.Background(), providerReloadTimeout)
|
||||
defer reloadCancel()
|
||||
|
||||
if err := al.ReloadProviderAndConfig(reloadCtx, newProvider, newCfg); err != nil {
|
||||
logger.Errorf(" ⚠ Error reloading agent loop: %v", err)
|
||||
// Close the newly created provider since it wasn't adopted
|
||||
if cp, ok := newProvider.(providers.StatefulProvider); ok {
|
||||
cp.Close()
|
||||
}
|
||||
logger.Warn(" Attempting to restart services with old provider and config...")
|
||||
if restartErr := restartServices(al, services, msgBus); restartErr != nil {
|
||||
logger.Errorf(" ⚠ Failed to restart services: %v", restartErr)
|
||||
}
|
||||
return fmt.Errorf("error reloading agent loop: %w", err)
|
||||
}
|
||||
|
||||
// Update local provider reference only after successful atomic reload
|
||||
*providerRef = newProvider
|
||||
|
||||
// Restart all services with new config
|
||||
logger.Info(" Restarting all services with new configuration...")
|
||||
if err := restartServices(al, services, msgBus); err != nil {
|
||||
logger.Errorf(" ⚠ Error restarting services: %v", err)
|
||||
return fmt.Errorf("error restarting services: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(" ✓ Provider, configuration, and services reloaded successfully (thread-safe)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restartServices restarts all services after a config reload
|
||||
func restartServices(
|
||||
al *agent.AgentLoop,
|
||||
services *gatewayServices,
|
||||
msgBus *bus.MessageBus,
|
||||
) error {
|
||||
// Create an independent context with timeout for service restart
|
||||
// This prevents cancellation from the main loop context during reload
|
||||
ctx, cancel := context.WithTimeout(context.Background(), serviceRestartTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Get current config from agent loop (which has been updated if this is a reload)
|
||||
cfg := al.GetConfig()
|
||||
|
||||
// Re-create and start cron service with new config
|
||||
execTimeout := time.Duration(cfg.Tools.Cron.ExecTimeoutMinutes) * time.Minute
|
||||
services.CronService = setupCronTool(
|
||||
al,
|
||||
msgBus,
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Agents.Defaults.RestrictToWorkspace,
|
||||
execTimeout,
|
||||
cfg,
|
||||
)
|
||||
if err := services.CronService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting cron service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Cron service restarted")
|
||||
|
||||
// Re-create and start heartbeat service with new config
|
||||
services.HeartbeatService = heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
services.HeartbeatService.SetBus(msgBus)
|
||||
services.HeartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
var response string
|
||||
var err error
|
||||
response, err = al.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
if err := services.HeartbeatService.Start(); err != nil {
|
||||
return fmt.Errorf("error restarting heartbeat service: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Heartbeat service restarted")
|
||||
|
||||
// Stop the old media store before creating a new one
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
|
||||
// Re-create media store with new config
|
||||
services.MediaStore = media.NewFileMediaStoreWithCleanup(media.MediaCleanerConfig{
|
||||
Enabled: cfg.Tools.MediaCleanup.Enabled,
|
||||
MaxAge: time.Duration(cfg.Tools.MediaCleanup.MaxAge) * time.Minute,
|
||||
Interval: time.Duration(cfg.Tools.MediaCleanup.Interval) * time.Minute,
|
||||
})
|
||||
// Start the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Start()
|
||||
}
|
||||
al.SetMediaStore(services.MediaStore)
|
||||
|
||||
// Re-create channel manager with new config
|
||||
var err error
|
||||
services.ChannelManager, err = channels.NewManager(cfg, msgBus, services.MediaStore)
|
||||
if err != nil {
|
||||
// Stop the media store if it's a FileMediaStore with cleanup
|
||||
if fms, ok := services.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
}
|
||||
return fmt.Errorf("error recreating channel manager: %w", err)
|
||||
}
|
||||
al.SetChannelManager(services.ChannelManager)
|
||||
|
||||
enabledChannels := services.ChannelManager.GetEnabledChannels()
|
||||
if len(enabledChannels) > 0 {
|
||||
fmt.Printf(" ✓ Channels enabled: %s\n", enabledChannels)
|
||||
} else {
|
||||
fmt.Println(" ⚠ Warning: No channels enabled")
|
||||
}
|
||||
|
||||
// Setup HTTP server with new config
|
||||
addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.HealthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port)
|
||||
services.ChannelManager.SetupHTTPServer(addr, services.HealthServer)
|
||||
|
||||
if err := services.ChannelManager.StartAll(ctx); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
}
|
||||
fmt.Printf(
|
||||
" ✓ Channels restarted, health endpoints at http://%s:%d/health and ready\n",
|
||||
cfg.Gateway.Host,
|
||||
cfg.Gateway.Port,
|
||||
)
|
||||
|
||||
// Re-create device service with new config
|
||||
stateManager := state.NewManager(cfg.WorkspacePath())
|
||||
services.DeviceService = devices.NewService(devices.Config{
|
||||
Enabled: cfg.Devices.Enabled,
|
||||
MonitorUSB: cfg.Devices.MonitorUSB,
|
||||
}, stateManager)
|
||||
services.DeviceService.SetBus(msgBus)
|
||||
if err := services.DeviceService.Start(ctx); err != nil {
|
||||
logger.WarnCF("device", "Failed to restart device service", map[string]any{"error": err.Error()})
|
||||
} else if cfg.Devices.Enabled {
|
||||
fmt.Println(" ✓ Device event service restarted")
|
||||
}
|
||||
|
||||
// Wire up voice transcription with new config
|
||||
transcriber := voice.DetectTranscriber(cfg)
|
||||
al.SetTranscriber(transcriber) // This will set it to nil if disabled
|
||||
if transcriber != nil {
|
||||
logger.InfoCF("voice", "Transcription re-enabled (agent-level)", map[string]any{"provider": transcriber.Name()})
|
||||
} else {
|
||||
logger.InfoCF("voice", "Transcription disabled", nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupConfigWatcherPolling sets up a simple polling-based config file watcher
|
||||
// Returns a channel for config updates and a stop function
|
||||
func setupConfigWatcherPolling(configPath string, debug bool) (chan *config.Config, func()) {
|
||||
configChan := make(chan *config.Config, 1)
|
||||
stop := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
// Get initial file info
|
||||
lastModTime := getFileModTime(configPath)
|
||||
lastSize := getFileSize(configPath)
|
||||
|
||||
ticker := time.NewTicker(2 * time.Second) // Check every 2 seconds
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
currentModTime := getFileModTime(configPath)
|
||||
currentSize := getFileSize(configPath)
|
||||
|
||||
// Check if file changed (modification time or size changed)
|
||||
if currentModTime.After(lastModTime) || currentSize != lastSize {
|
||||
if debug {
|
||||
logger.Debugf("🔍 Config file change detected")
|
||||
}
|
||||
|
||||
// Debounce - wait a bit to ensure file write is complete
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Validate and load new config
|
||||
newCfg, err := config.LoadConfig(configPath)
|
||||
if err != nil {
|
||||
logger.Errorf("⚠ Error loading new config: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate the new config
|
||||
if err := newCfg.ValidateModelList(); err != nil {
|
||||
logger.Errorf(" ⚠ New config validation failed: %v", err)
|
||||
logger.Warn(" Using previous valid config")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("✓ Config file validated and loaded")
|
||||
|
||||
// Update last known state
|
||||
lastModTime = currentModTime
|
||||
lastSize = currentSize
|
||||
|
||||
// Send new config to main loop (non-blocking)
|
||||
select {
|
||||
case configChan <- newCfg:
|
||||
default:
|
||||
// Channel full, skip this update
|
||||
logger.Warn("⚠ Previous config reload still in progress, skipping")
|
||||
}
|
||||
}
|
||||
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
stopFunc := func() {
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
return configChan, stopFunc
|
||||
}
|
||||
|
||||
// getFileModTime reports the modification time of the file at path. It returns
// the zero time.Time when os.Stat fails, for example because the file does not
// exist.
func getFileModTime(path string) time.Time {
	if info, err := os.Stat(path); err == nil {
		return info.ModTime()
	}
	return time.Time{}
}
|
||||
|
||||
// getFileSize reports the size in bytes of the file at path. It returns 0 when
// os.Stat fails, for example because the file does not exist.
func getFileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}
|
||||
|
||||
func setupCronTool(
|
||||
agentLoop *agent.AgentLoop,
|
||||
msgBus *bus.MessageBus,
|
||||
@@ -239,7 +625,7 @@ func setupCronTool(
|
||||
var err error
|
||||
cronTool, err = tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict, execTimeout, cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Critical error during CronTool initialization: %v", err)
|
||||
logger.Fatalf("Critical error during CronTool initialization: %v", err)
|
||||
}
|
||||
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
+171
-28
@@ -48,6 +48,9 @@ type AgentLoop struct {
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
mu sync.RWMutex
|
||||
// Track active requests for safe provider cleanup
|
||||
activeRequests sync.WaitGroup
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -239,6 +242,7 @@ func registerSharedTools(
|
||||
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -278,7 +282,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
// Use default agent's tools to check (message tool is shared).
|
||||
alreadySent := false
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
defaultAgent := al.GetRegistry().GetDefaultAgent()
|
||||
if defaultAgent != nil {
|
||||
if tool, ok := defaultAgent.Tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
@@ -331,12 +335,13 @@ func (al *AgentLoop) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
al.registry.Close()
|
||||
al.GetRegistry().Close()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
for _, agentID := range al.registry.ListAgentIDs() {
|
||||
if agent, ok := al.registry.GetAgent(agentID); ok {
|
||||
registry := al.GetRegistry()
|
||||
for _, agentID := range registry.ListAgentIDs() {
|
||||
if agent, ok := registry.GetAgent(agentID); ok {
|
||||
agent.Tools.Register(tool)
|
||||
}
|
||||
}
|
||||
@@ -346,12 +351,123 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
al.channelManager = cm
|
||||
}
|
||||
|
||||
// ReloadProviderAndConfig atomically swaps the provider and config with proper synchronization.
|
||||
// It uses a context to allow timeout control from the caller.
|
||||
// Returns an error if the reload fails or context is canceled.
|
||||
func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
ctx context.Context,
|
||||
provider providers.LLMProvider,
|
||||
cfg *config.Config,
|
||||
) error {
|
||||
// Validate inputs
|
||||
if provider == nil {
|
||||
return fmt.Errorf("provider cannot be nil")
|
||||
}
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("config cannot be nil")
|
||||
}
|
||||
|
||||
// Create new registry with updated config and provider
|
||||
// Wrap in defer/recover to handle any panics gracefully
|
||||
var registry *AgentRegistry
|
||||
var panicErr error
|
||||
done := make(chan struct{}, 1)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicErr = fmt.Errorf("panic during registry creation: %v", r)
|
||||
logger.ErrorCF("agent", "Panic during registry creation",
|
||||
map[string]any{"panic": r})
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
registry = NewAgentRegistry(cfg, provider)
|
||||
}()
|
||||
|
||||
// Wait for completion or context cancellation
|
||||
select {
|
||||
case <-done:
|
||||
if registry == nil {
|
||||
if panicErr != nil {
|
||||
return fmt.Errorf("registry creation failed: %w", panicErr)
|
||||
}
|
||||
return fmt.Errorf("registry creation failed (nil result)")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("context canceled during registry creation: %w", ctx.Err())
|
||||
}
|
||||
|
||||
// Check context again before proceeding
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fmt.Errorf("context canceled after registry creation: %w", err)
|
||||
}
|
||||
|
||||
// Ensure shared tools are re-registered on the new registry
|
||||
registerSharedTools(cfg, al.bus, registry, provider)
|
||||
|
||||
// Atomically swap the config and registry under write lock
|
||||
// This ensures readers see a consistent pair
|
||||
al.mu.Lock()
|
||||
oldRegistry := al.registry
|
||||
|
||||
// Store new values
|
||||
al.cfg = cfg
|
||||
al.registry = registry
|
||||
|
||||
// Also update fallback chain with new config
|
||||
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker())
|
||||
|
||||
al.mu.Unlock()
|
||||
|
||||
// Close old provider after releasing the lock
|
||||
// This prevents blocking readers while closing
|
||||
if oldProvider, ok := extractProvider(oldRegistry); ok {
|
||||
if stateful, ok := oldProvider.(providers.StatefulProvider); ok {
|
||||
// Give in-flight requests a moment to complete
|
||||
// Use a reasonable timeout that balances cleanup vs resource usage
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
stateful.Close()
|
||||
case <-ctx.Done():
|
||||
// Context canceled, close immediately but log warning
|
||||
logger.WarnCF("agent", "Context canceled during provider cleanup, forcing close",
|
||||
map[string]any{"error": ctx.Err()})
|
||||
stateful.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Provider and config reloaded successfully",
|
||||
map[string]any{
|
||||
"model": cfg.Agents.Defaults.GetModelName(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRegistry returns the current registry (thread-safe)
|
||||
func (al *AgentLoop) GetRegistry() *AgentRegistry {
|
||||
al.mu.RLock()
|
||||
defer al.mu.RUnlock()
|
||||
return al.registry
|
||||
}
|
||||
|
||||
// GetConfig returns the current config (thread-safe)
|
||||
func (al *AgentLoop) GetConfig() *config.Config {
|
||||
al.mu.RLock()
|
||||
defer al.mu.RUnlock()
|
||||
return al.cfg
|
||||
}
|
||||
|
||||
// SetMediaStore injects a MediaStore for media lifecycle management.
|
||||
func (al *AgentLoop) SetMediaStore(s media.MediaStore) {
|
||||
al.mediaStore = s
|
||||
|
||||
// Propagate store to send_file tools in all agents.
|
||||
al.registry.ForEachTool("send_file", func(t tools.Tool) {
|
||||
registry := al.GetRegistry()
|
||||
registry.ForEachTool("send_file", func(t tools.Tool) {
|
||||
if sf, ok := t.(*tools.SendFileTool); ok {
|
||||
sf.SetMediaStore(s)
|
||||
}
|
||||
@@ -540,7 +656,7 @@ func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
}
|
||||
@@ -636,7 +752,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
registry := al.GetRegistry()
|
||||
route := registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
AccountID: inboundMetadata(msg, metadataKeyAccountID),
|
||||
Peer: extractPeer(msg),
|
||||
@@ -645,9 +762,9 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv
|
||||
TeamID: inboundMetadata(msg, metadataKeyTeamID),
|
||||
})
|
||||
|
||||
agent, ok := al.registry.GetAgent(route.AgentID)
|
||||
agent, ok := registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
agent = al.registry.GetDefaultAgent()
|
||||
agent = registry.GetDefaultAgent()
|
||||
}
|
||||
if agent == nil {
|
||||
return routing.ResolvedRoute{}, nil, fmt.Errorf("no agent available for route (agent_id=%s)", route.AgentID)
|
||||
@@ -709,7 +826,7 @@ func (al *AgentLoop) processSystemMessage(
|
||||
}
|
||||
|
||||
// Use default agent for system messages
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for system message")
|
||||
}
|
||||
@@ -765,7 +882,8 @@ func (al *AgentLoop) runAgentLoop(
|
||||
)
|
||||
|
||||
// Resolve media:// refs to base64 data URLs (streaming)
|
||||
maxMediaSize := al.cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
cfg := al.GetConfig()
|
||||
maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 2. Save user message to session
|
||||
@@ -943,6 +1061,9 @@ func (al *AgentLoop) runLLMIteration(
|
||||
}
|
||||
|
||||
callLLM := func() (*providers.LLMResponse, error) {
|
||||
al.activeRequests.Add(1)
|
||||
defer al.activeRequests.Done()
|
||||
|
||||
if len(activeCandidates) > 1 && al.fallback != nil {
|
||||
fbResult, fbErr := al.fallback.Execute(
|
||||
ctx,
|
||||
@@ -1041,6 +1162,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
@@ -1392,7 +1514,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
|
||||
func (al *AgentLoop) GetStartupInfo() map[string]any {
|
||||
info := make(map[string]any)
|
||||
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
registry := al.GetRegistry()
|
||||
agent := registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return info
|
||||
}
|
||||
@@ -1409,8 +1532,8 @@ func (al *AgentLoop) GetStartupInfo() map[string]any {
|
||||
|
||||
// Agents info
|
||||
info["agents"] = map[string]any{
|
||||
"count": len(al.registry.ListAgentIDs()),
|
||||
"ids": al.registry.ListAgentIDs(),
|
||||
"count": len(registry.ListAgentIDs()),
|
||||
"ids": registry.ListAgentIDs(),
|
||||
}
|
||||
|
||||
return info
|
||||
@@ -1598,17 +1721,22 @@ func (al *AgentLoop) retryLLMCall(
|
||||
var err error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
resp, err = agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
al.activeRequests.Add(1)
|
||||
resp, err = func() (*providers.LLMResponse, error) {
|
||||
defer al.activeRequests.Done()
|
||||
return agent.Provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil,
|
||||
agent.Model,
|
||||
map[string]any{
|
||||
"max_tokens": agent.MaxTokens,
|
||||
"temperature": llmTemperature,
|
||||
"prompt_cache_key": agent.ID,
|
||||
},
|
||||
)
|
||||
}()
|
||||
|
||||
if err == nil && resp != nil && resp.Content != "" {
|
||||
return resp, nil
|
||||
}
|
||||
@@ -1741,9 +1869,11 @@ func (al *AgentLoop) handleCommand(
|
||||
}
|
||||
|
||||
func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOptions) *commands.Runtime {
|
||||
registry := al.GetRegistry()
|
||||
cfg := al.GetConfig()
|
||||
rt := &commands.Runtime{
|
||||
Config: al.cfg,
|
||||
ListAgentIDs: al.registry.ListAgentIDs,
|
||||
Config: cfg,
|
||||
ListAgentIDs: registry.ListAgentIDs,
|
||||
ListDefinitions: al.cmdRegistry.Definitions,
|
||||
GetEnabledChannels: func() []string {
|
||||
if al.channelManager == nil {
|
||||
@@ -1763,7 +1893,7 @@ func (al *AgentLoop) buildCommandsRuntime(agent *AgentInstance, opts *processOpt
|
||||
}
|
||||
if agent != nil {
|
||||
rt.GetModelInfo = func() (string, string) {
|
||||
return agent.Model, al.cfg.Agents.Defaults.Provider
|
||||
return agent.Model, cfg.Agents.Defaults.Provider
|
||||
}
|
||||
rt.SwitchModel = func(value string) (string, error) {
|
||||
oldModel := agent.Model
|
||||
@@ -1827,3 +1957,16 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
}
|
||||
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
|
||||
}
|
||||
|
||||
// Helper to extract provider from registry for cleanup
|
||||
func extractProvider(registry *AgentRegistry) (providers.LLMProvider, bool) {
|
||||
if registry == nil {
|
||||
return nil, false
|
||||
}
|
||||
// Get any agent to access the provider
|
||||
defaultAgent := registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
return nil, false
|
||||
}
|
||||
return defaultAgent.Provider, true
|
||||
}
|
||||
|
||||
@@ -195,6 +195,10 @@ func DebugC(component string, message string) {
|
||||
logMessage(DEBUG, component, message, nil)
|
||||
}
|
||||
|
||||
func Debugf(message string, ss ...any) {
|
||||
logMessage(DEBUG, "", fmt.Sprintf(message, ss...), nil)
|
||||
}
|
||||
|
||||
func DebugF(message string, fields map[string]any) {
|
||||
logMessage(DEBUG, "", message, fields)
|
||||
}
|
||||
@@ -215,6 +219,10 @@ func InfoF(message string, fields map[string]any) {
|
||||
logMessage(INFO, "", message, fields)
|
||||
}
|
||||
|
||||
func Infof(message string, ss ...any) {
|
||||
logMessage(INFO, "", fmt.Sprintf(message, ss...), nil)
|
||||
}
|
||||
|
||||
func InfoCF(component string, message string, fields map[string]any) {
|
||||
logMessage(INFO, component, message, fields)
|
||||
}
|
||||
@@ -243,6 +251,10 @@ func ErrorC(component string, message string) {
|
||||
logMessage(ERROR, component, message, nil)
|
||||
}
|
||||
|
||||
func Errorf(message string, ss ...any) {
|
||||
logMessage(ERROR, "", fmt.Sprintf(message, ss...), nil)
|
||||
}
|
||||
|
||||
func ErrorF(message string, fields map[string]any) {
|
||||
logMessage(ERROR, "", message, fields)
|
||||
}
|
||||
|
||||
@@ -123,17 +123,21 @@ func TestLoggerHelperFunctions(t *testing.T) {
|
||||
SetLevel(INFO)
|
||||
|
||||
Debug("This should not log")
|
||||
Debugf("this should not log")
|
||||
Info("This should log")
|
||||
Warn("This should log")
|
||||
Error("This should log")
|
||||
|
||||
InfoC("test", "Component message")
|
||||
InfoF("Fields message", map[string]any{"key": "value"})
|
||||
Infof("test from %v", "Infof")
|
||||
|
||||
WarnC("test", "Warning with component")
|
||||
ErrorF("Error with fields", map[string]any{"error": "test"})
|
||||
Errorf("test from %v", "Errorf")
|
||||
|
||||
SetLevel(DEBUG)
|
||||
DebugC("test", "Debug with component")
|
||||
Debugf("test from %v", "Debugf")
|
||||
WarnF("Warning with fields", map[string]any{"key": "value"})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user