diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index aed815399..9e5fea1b6 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -86,9 +86,10 @@ type Manager struct { mux *http.ServeMux httpServer *http.Server mu sync.RWMutex - placeholders sync.Map // "channel:chatID" → placeholderID (string) - typingStops sync.Map // "channel:chatID" → func() - reactionUndos sync.Map // "channel:chatID" → reactionEntry + placeholders sync.Map // "channel:chatID" → placeholderID (string) + typingStops sync.Map // "channel:chatID" → func() + reactionUndos sync.Map // "channel:chatID" → reactionEntry + channelHashes map[string]string // channel name → config hash } type asyncTask struct { @@ -178,17 +179,21 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ - channels: make(map[string]Channel), - workers: make(map[string]*channelWorker), - bus: messageBus, - config: cfg, - mediaStore: store, + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + bus: messageBus, + config: cfg, + mediaStore: store, + channelHashes: make(map[string]string), } - if err := m.initChannels(); err != nil { + if err := m.initChannels(&cfg.Channels); err != nil { return nil, err } + // Store initial config hashes for all channels + m.channelHashes = toChannelHashes(cfg) + return m, nil } @@ -232,15 +237,15 @@ func (m *Manager) initChannel(name, displayName string) { } } -func (m *Manager) initChannels() error { +func (m *Manager) initChannels(channels *config.ChannelsConfig) error { logger.InfoC("channels", "Initializing channel manager") - if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { + if channels.Telegram.Enabled && channels.Telegram.Token != "" { m.initChannel("telegram", "Telegram") } - if m.config.Channels.WhatsApp.Enabled { - waCfg := m.config.Channels.WhatsApp + if channels.WhatsApp.Enabled { + waCfg := channels.WhatsApp if waCfg.UseNative { m.initChannel("whatsapp_native", "WhatsApp Native") } else if waCfg.BridgeURL != "" { @@ -248,62 +253,62 @@ func (m *Manager) initChannels() error { } } - if m.config.Channels.Feishu.Enabled { + if channels.Feishu.Enabled { m.initChannel("feishu", "Feishu") } - if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" { + if channels.Discord.Enabled && channels.Discord.Token != "" { m.initChannel("discord", "Discord") } - if m.config.Channels.MaixCam.Enabled { + if channels.MaixCam.Enabled { m.initChannel("maixcam", "MaixCam") } - if m.config.Channels.QQ.Enabled { + if channels.QQ.Enabled { m.initChannel("qq", "QQ") } - if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" { + if channels.DingTalk.Enabled && channels.DingTalk.ClientID != "" { m.initChannel("dingtalk", "DingTalk") } - if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { + if channels.Slack.Enabled && channels.Slack.BotToken != "" { m.initChannel("slack", "Slack") } - if m.config.Channels.Matrix.Enabled && + if channels.Matrix.Enabled && m.config.Channels.Matrix.Homeserver != "" && m.config.Channels.Matrix.UserID != "" && m.config.Channels.Matrix.AccessToken != "" { m.initChannel("matrix", "Matrix") } - if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" { + if channels.LINE.Enabled && channels.LINE.ChannelAccessToken != "" { m.initChannel("line", "LINE") } - if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" { + if channels.OneBot.Enabled && channels.OneBot.WSUrl != "" { m.initChannel("onebot", "OneBot") } - if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" { + if channels.WeCom.Enabled && channels.WeCom.Token != "" { m.initChannel("wecom", "WeCom") } - if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" { + if channels.WeComAIBot.Enabled && channels.WeComAIBot.Token != "" { m.initChannel("wecom_aibot", "WeCom AI Bot") } - if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { + if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" { m.initChannel("wecom_app", "WeCom App") } - if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" { + if channels.Pico.Enabled && channels.Pico.Token != "" { m.initChannel("pico", "Pico") } - if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" { + if channels.IRC.Enabled && channels.IRC.Server != "" { m.initChannel("irc", "IRC") } @@ -825,6 +830,68 @@ func (m *Manager) GetEnabledChannels() []string { return names } +// Reload updates the config reference without restarting channels. +// This is used when channel config hasn't changed but other parts of the config have. +func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error { + m.mu.Lock() + defer m.mu.Unlock() + list := toChannelHashes(cfg) + added, removed := compareChannels(m.channelHashes, list) + for _, name := range removed { + // Stop all channels + channel := m.channels[name] + logger.InfoCF("channels", "Stopping channel", map[string]any{ + "channel": name, + }) + if err := channel.Stop(ctx); err != nil { + logger.ErrorCF("channels", "Error stopping channel", map[string]any{ + "channel": name, + "error": err.Error(), + }) + } + go func() { + m.UnregisterChannel(name) + }() + } + dispatchCtx, cancel := context.WithCancel(ctx) + m.dispatchTask = &asyncTask{cancel: cancel} + cc, err := toChannelConfig(cfg, added) + if err != nil { + logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err)) + return err + } + err = m.initChannels(cc) + if err != nil { + logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err)) + return err + } + for _, name := range added { + channel := m.channels[name] + logger.InfoCF("channels", "Starting channel", map[string]any{ + "channel": name, + }) + if err := channel.Start(ctx); err != nil { + logger.ErrorCF("channels", "Failed to start channel", map[string]any{ + "channel": name, + "error": err.Error(), + }) + continue + } + // Lazily create worker only after channel starts successfully + w := newChannelWorker(name, channel) + m.workers[name] = w + go m.runWorker(dispatchCtx, name, w) + go m.runMediaWorker(dispatchCtx, name, w) + go func() { + m.RegisterChannel(name, channel) + }() + } + + m.config = cfg + m.channelHashes = toChannelHashes(cfg) + return nil +} + func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() diff --git a/pkg/channels/manager_channel.go b/pkg/channels/manager_channel.go new file mode 100644 index 000000000..57cb05412 --- /dev/null +++ b/pkg/channels/manager_channel.go @@ -0,0 +1,86 @@ +package channels + +import ( + "crypto/md5" + "encoding/hex" + "encoding/json" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +func toChannelHashes(cfg *config.Config) map[string]string { + result := make(map[string]string) + ch := cfg.Channels + // should not be error + marshal, _ := json.Marshal(ch) + var channelConfig map[string]map[string]any + _ = json.Unmarshal(marshal, &channelConfig) + + for key, value := range channelConfig { + if !value["enabled"].(bool) { + continue + } + valueBytes, _ := json.Marshal(value) + hash := md5.Sum(valueBytes) + result[key] = hex.EncodeToString(hash[:]) + } + + return result +} + +func compareChannels(old, news map[string]string) (added, removed []string) { + for key, newHash := range news { + if oldHash, ok := old[key]; ok { + if newHash != oldHash { + removed = append(removed, key) + added = append(added, key) + } + } else { + added = append(added, key) + } + } + for key := range old { + if _, ok := news[key]; !ok { + removed = append(removed, key) + } + } + return added, removed +} + +func toChannelConfig(cfg *config.Config, list []string) (*config.ChannelsConfig, error) { + result := &config.ChannelsConfig{} + ch := cfg.Channels + // should not be error + marshal, _ := json.Marshal(ch) + var channelConfig map[string]map[string]any + _ = json.Unmarshal(marshal, &channelConfig) + temp := make(map[string]map[string]any, 0) + + for key, value := range channelConfig { + found := false + for _, s := range list { + if key == s { + found = true + break + } + } + if !found || !value["enabled"].(bool) { + continue + } + temp[key] = value + } + + marshal, err := json.Marshal(temp) + if err != nil { + logger.Errorf("marshal error: %v", err) + return nil, err + } + err = json.Unmarshal(marshal, result) + if err != nil { + logger.Errorf("unmarshal error: %v", err) + return nil, err + } + + return result, nil +} diff --git a/pkg/channels/manager_channel_test.go b/pkg/channels/manager_channel_test.go new file mode 100644 index 000000000..651764c4f --- /dev/null +++ b/pkg/channels/manager_channel_test.go @@ -0,0 +1,51 @@ +package channels + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +func TestToChannelHashes(t *testing.T) { + logger.SetLevel(logger.DEBUG) + cfg := config.DefaultConfig() + results := toChannelHashes(cfg) + assert.Equal(t, 0, len(results)) + logger.Debugf("results: %v", results) + cfg2 := config.DefaultConfig() + cfg2.Channels.DingTalk.Enabled = true + results2 := toChannelHashes(cfg2) + assert.Equal(t, 1, len(results2)) + logger.Debugf("results2: %v", results2) + added, removed := compareChannels(results, results2) + assert.EqualValues(t, []string{"dingtalk"}, added) + assert.EqualValues(t, []string(nil), removed) + cfg3 := config.DefaultConfig() + cfg3.Channels.Telegram.Enabled = true + results3 := toChannelHashes(cfg3) + assert.Equal(t, 1, len(results3)) + logger.Debugf("results3: %v", results3) + added, removed = compareChannels(results2, results3) + assert.EqualValues(t, []string{"dingtalk"}, removed) + assert.EqualValues(t, []string{"telegram"}, added) + cfg3.Channels.Telegram.Token = "114314" + results4 := toChannelHashes(cfg3) + assert.Equal(t, 1, len(results4)) + logger.Debugf("results4: %v", results4) + added, removed = compareChannels(results3, results4) + assert.EqualValues(t, []string{"telegram"}, removed) + assert.EqualValues(t, []string{"telegram"}, added) + cc, err := toChannelConfig(cfg3, added) + assert.NoError(t, err) + logger.Debugf("cc: %#v", cc.Telegram) + assert.Equal(t, "114314", cc.Telegram.Token) + assert.Equal(t, true, cc.Telegram.Enabled) + cc, err = toChannelConfig(cfg2, added) + assert.NoError(t, err) + logger.Debugf("cc: %#v", cc.Telegram) + assert.Equal(t, "", cc.Telegram.Token) + assert.Equal(t, false, cc.Telegram.Enabled) +} diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index ee7815fe2..9a2706b3b 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -324,11 +324,12 @@ func setupAndStartServices( return runningServices, nil } -func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) { +func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration, isReload bool) { shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout) defer shutdownCancel() - if runningServices.ChannelManager != nil { + // reload should not stop channel manager + if !isReload && runningServices.ChannelManager != nil { runningServices.ChannelManager.StopAll(shutdownCtx) } if runningServices.DeviceService != nil { @@ -357,7 +358,7 @@ func shutdownGateway( cp.Close() } - stopAndCleanupServices(runningServices, gracefulShutdownTimeout) + stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false) agentLoop.Stop() agentLoop.Close() @@ -384,7 +385,7 @@ func handleConfigReload( logger.Infof(" New model is '%s', recreating provider...", newModel) logger.Info(" Stopping all services...") - stopAndCleanupServices(runningServices, serviceShutdownTimeout) + stopAndCleanupServices(runningServices, serviceShutdownTimeout, true) newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup) if err != nil { @@ -494,8 +495,8 @@ func restartServices( } runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer) - if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil { - return fmt.Errorf("error restarting channels: %w", err) + if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil { + return fmt.Errorf("error reload channels: %w", err) } fmt.Println(" ✓ Channels restarted.")