refactor[gateway]: just reload the changed channels on reload occurred (#1773)

This commit is contained in:
Cytown
2026-03-19 15:28:52 +08:00
committed by GitHub
parent 2a6ade0fe4
commit a8ce992429
4 changed files with 238 additions and 33 deletions
+94 -27
View File
@@ -86,9 +86,10 @@ type Manager struct {
mux *http.ServeMux
httpServer *http.Server
mu sync.RWMutex
placeholders sync.Map // "channel:chatID" → placeholderID (string)
typingStops sync.Map // "channel:chatID" → func()
reactionUndos sync.Map // "channel:chatID" → reactionEntry
placeholders sync.Map // "channel:chatID" → placeholderID (string)
typingStops sync.Map // "channel:chatID" → func()
reactionUndos sync.Map // "channel:chatID" → reactionEntry
channelHashes map[string]string // channel name → config hash
}
type asyncTask struct {
@@ -178,17 +179,21 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
m := &Manager{
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
bus: messageBus,
config: cfg,
mediaStore: store,
channels: make(map[string]Channel),
workers: make(map[string]*channelWorker),
bus: messageBus,
config: cfg,
mediaStore: store,
channelHashes: make(map[string]string),
}
if err := m.initChannels(); err != nil {
if err := m.initChannels(&cfg.Channels); err != nil {
return nil, err
}
// Store initial config hashes for all channels
m.channelHashes = toChannelHashes(cfg)
return m, nil
}
@@ -232,15 +237,15 @@ func (m *Manager) initChannel(name, displayName string) {
}
}
func (m *Manager) initChannels() error {
func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
logger.InfoC("channels", "Initializing channel manager")
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
if channels.Telegram.Enabled && channels.Telegram.Token != "" {
m.initChannel("telegram", "Telegram")
}
if m.config.Channels.WhatsApp.Enabled {
waCfg := m.config.Channels.WhatsApp
if channels.WhatsApp.Enabled {
waCfg := channels.WhatsApp
if waCfg.UseNative {
m.initChannel("whatsapp_native", "WhatsApp Native")
} else if waCfg.BridgeURL != "" {
@@ -248,62 +253,62 @@ func (m *Manager) initChannels() error {
}
}
if m.config.Channels.Feishu.Enabled {
if channels.Feishu.Enabled {
m.initChannel("feishu", "Feishu")
}
if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" {
if channels.Discord.Enabled && channels.Discord.Token != "" {
m.initChannel("discord", "Discord")
}
if m.config.Channels.MaixCam.Enabled {
if channels.MaixCam.Enabled {
m.initChannel("maixcam", "MaixCam")
}
if m.config.Channels.QQ.Enabled {
if channels.QQ.Enabled {
m.initChannel("qq", "QQ")
}
if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" {
if channels.DingTalk.Enabled && channels.DingTalk.ClientID != "" {
m.initChannel("dingtalk", "DingTalk")
}
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
if channels.Slack.Enabled && channels.Slack.BotToken != "" {
m.initChannel("slack", "Slack")
}
if m.config.Channels.Matrix.Enabled &&
if channels.Matrix.Enabled &&
m.config.Channels.Matrix.Homeserver != "" &&
m.config.Channels.Matrix.UserID != "" &&
m.config.Channels.Matrix.AccessToken != "" {
m.initChannel("matrix", "Matrix")
}
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
if channels.LINE.Enabled && channels.LINE.ChannelAccessToken != "" {
m.initChannel("line", "LINE")
}
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
if channels.OneBot.Enabled && channels.OneBot.WSUrl != "" {
m.initChannel("onebot", "OneBot")
}
if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" {
if channels.WeCom.Enabled && channels.WeCom.Token != "" {
m.initChannel("wecom", "WeCom")
}
if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
if channels.WeComAIBot.Enabled && channels.WeComAIBot.Token != "" {
m.initChannel("wecom_aibot", "WeCom AI Bot")
}
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" {
m.initChannel("wecom_app", "WeCom App")
}
if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" {
if channels.Pico.Enabled && channels.Pico.Token != "" {
m.initChannel("pico", "Pico")
}
if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" {
if channels.IRC.Enabled && channels.IRC.Server != "" {
m.initChannel("irc", "IRC")
}
@@ -825,6 +830,68 @@ func (m *Manager) GetEnabledChannels() []string {
return names
}
// Reload updates the config reference without restarting channels.
// This is used when channel config hasn't changed but other parts of the config have.
func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
m.mu.Lock()
defer m.mu.Unlock()
list := toChannelHashes(cfg)
added, removed := compareChannels(m.channelHashes, list)
for _, name := range removed {
// Stop all channels
channel := m.channels[name]
logger.InfoCF("channels", "Stopping channel", map[string]any{
"channel": name,
})
if err := channel.Stop(ctx); err != nil {
logger.ErrorCF("channels", "Error stopping channel", map[string]any{
"channel": name,
"error": err.Error(),
})
}
go func() {
m.UnregisterChannel(name)
}()
}
dispatchCtx, cancel := context.WithCancel(ctx)
m.dispatchTask = &asyncTask{cancel: cancel}
cc, err := toChannelConfig(cfg, added)
if err != nil {
logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err))
return err
}
err = m.initChannels(cc)
if err != nil {
logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err))
return err
}
for _, name := range added {
channel := m.channels[name]
logger.InfoCF("channels", "Starting channel", map[string]any{
"channel": name,
})
if err := channel.Start(ctx); err != nil {
logger.ErrorCF("channels", "Failed to start channel", map[string]any{
"channel": name,
"error": err.Error(),
})
continue
}
// Lazily create worker only after channel starts successfully
w := newChannelWorker(name, channel)
m.workers[name] = w
go m.runWorker(dispatchCtx, name, w)
go m.runMediaWorker(dispatchCtx, name, w)
go func() {
m.RegisterChannel(name, channel)
}()
}
m.config = cfg
m.channelHashes = toChannelHashes(cfg)
return nil
}
func (m *Manager) RegisterChannel(name string, channel Channel) {
m.mu.Lock()
defer m.mu.Unlock()
+86
View File
@@ -0,0 +1,86 @@
package channels
import (
"crypto/md5"
"encoding/hex"
"encoding/json"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
func toChannelHashes(cfg *config.Config) map[string]string {
result := make(map[string]string)
ch := cfg.Channels
// should not be error
marshal, _ := json.Marshal(ch)
var channelConfig map[string]map[string]any
_ = json.Unmarshal(marshal, &channelConfig)
for key, value := range channelConfig {
if !value["enabled"].(bool) {
continue
}
valueBytes, _ := json.Marshal(value)
hash := md5.Sum(valueBytes)
result[key] = hex.EncodeToString(hash[:])
}
return result
}
func compareChannels(old, news map[string]string) (added, removed []string) {
for key, newHash := range news {
if oldHash, ok := old[key]; ok {
if newHash != oldHash {
removed = append(removed, key)
added = append(added, key)
}
} else {
added = append(added, key)
}
}
for key := range old {
if _, ok := news[key]; !ok {
removed = append(removed, key)
}
}
return added, removed
}
func toChannelConfig(cfg *config.Config, list []string) (*config.ChannelsConfig, error) {
result := &config.ChannelsConfig{}
ch := cfg.Channels
// should not be error
marshal, _ := json.Marshal(ch)
var channelConfig map[string]map[string]any
_ = json.Unmarshal(marshal, &channelConfig)
temp := make(map[string]map[string]any, 0)
for key, value := range channelConfig {
found := false
for _, s := range list {
if key == s {
found = true
break
}
}
if !found || !value["enabled"].(bool) {
continue
}
temp[key] = value
}
marshal, err := json.Marshal(temp)
if err != nil {
logger.Errorf("marshal error: %v", err)
return nil, err
}
err = json.Unmarshal(marshal, result)
if err != nil {
logger.Errorf("unmarshal error: %v", err)
return nil, err
}
return result, nil
}
+51
View File
@@ -0,0 +1,51 @@
package channels
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
func TestToChannelHashes(t *testing.T) {
logger.SetLevel(logger.DEBUG)
cfg := config.DefaultConfig()
results := toChannelHashes(cfg)
assert.Equal(t, 0, len(results))
logger.Debugf("results: %v", results)
cfg2 := config.DefaultConfig()
cfg2.Channels.DingTalk.Enabled = true
results2 := toChannelHashes(cfg2)
assert.Equal(t, 1, len(results2))
logger.Debugf("results2: %v", results2)
added, removed := compareChannels(results, results2)
assert.EqualValues(t, []string{"dingtalk"}, added)
assert.EqualValues(t, []string(nil), removed)
cfg3 := config.DefaultConfig()
cfg3.Channels.Telegram.Enabled = true
results3 := toChannelHashes(cfg3)
assert.Equal(t, 1, len(results3))
logger.Debugf("results3: %v", results3)
added, removed = compareChannels(results2, results3)
assert.EqualValues(t, []string{"dingtalk"}, removed)
assert.EqualValues(t, []string{"telegram"}, added)
cfg3.Channels.Telegram.Token = "114314"
results4 := toChannelHashes(cfg3)
assert.Equal(t, 1, len(results4))
logger.Debugf("results4: %v", results4)
added, removed = compareChannels(results3, results4)
assert.EqualValues(t, []string{"telegram"}, removed)
assert.EqualValues(t, []string{"telegram"}, added)
cc, err := toChannelConfig(cfg3, added)
assert.NoError(t, err)
logger.Debugf("cc: %#v", cc.Telegram)
assert.Equal(t, "114314", cc.Telegram.Token)
assert.Equal(t, true, cc.Telegram.Enabled)
cc, err = toChannelConfig(cfg2, added)
assert.NoError(t, err)
logger.Debugf("cc: %#v", cc.Telegram)
assert.Equal(t, "", cc.Telegram.Token)
assert.Equal(t, false, cc.Telegram.Enabled)
}
+7 -6
View File
@@ -324,11 +324,12 @@ func setupAndStartServices(
return runningServices, nil
}
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration, isReload bool) {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer shutdownCancel()
if runningServices.ChannelManager != nil {
// reload should not stop channel manager
if !isReload && runningServices.ChannelManager != nil {
runningServices.ChannelManager.StopAll(shutdownCtx)
}
if runningServices.DeviceService != nil {
@@ -357,7 +358,7 @@ func shutdownGateway(
cp.Close()
}
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false)
agentLoop.Stop()
agentLoop.Close()
@@ -384,7 +385,7 @@ func handleConfigReload(
logger.Infof(" New model is '%s', recreating provider...", newModel)
logger.Info(" Stopping all services...")
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
stopAndCleanupServices(runningServices, serviceShutdownTimeout, true)
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
if err != nil {
@@ -494,8 +495,8 @@ func restartServices(
}
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
return fmt.Errorf("error restarting channels: %w", err)
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
return fmt.Errorf("error reload channels: %w", err)
}
fmt.Println(" ✓ Channels restarted.")