mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor[gateway]: just reload the changed channels on reload occurred (#1773)
This commit is contained in:
+94
-27
@@ -86,9 +86,10 @@ type Manager struct {
|
||||
mux *http.ServeMux
|
||||
httpServer *http.Server
|
||||
mu sync.RWMutex
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
reactionUndos sync.Map // "channel:chatID" → reactionEntry
|
||||
placeholders sync.Map // "channel:chatID" → placeholderID (string)
|
||||
typingStops sync.Map // "channel:chatID" → func()
|
||||
reactionUndos sync.Map // "channel:chatID" → reactionEntry
|
||||
channelHashes map[string]string // channel name → config hash
|
||||
}
|
||||
|
||||
type asyncTask struct {
|
||||
@@ -178,17 +179,21 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
|
||||
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
|
||||
m := &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
bus: messageBus,
|
||||
config: cfg,
|
||||
mediaStore: store,
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
bus: messageBus,
|
||||
config: cfg,
|
||||
mediaStore: store,
|
||||
channelHashes: make(map[string]string),
|
||||
}
|
||||
|
||||
if err := m.initChannels(); err != nil {
|
||||
if err := m.initChannels(&cfg.Channels); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Store initial config hashes for all channels
|
||||
m.channelHashes = toChannelHashes(cfg)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -232,15 +237,15 @@ func (m *Manager) initChannel(name, displayName string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) initChannels() error {
|
||||
func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
logger.InfoC("channels", "Initializing channel manager")
|
||||
|
||||
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
|
||||
if channels.Telegram.Enabled && channels.Telegram.Token != "" {
|
||||
m.initChannel("telegram", "Telegram")
|
||||
}
|
||||
|
||||
if m.config.Channels.WhatsApp.Enabled {
|
||||
waCfg := m.config.Channels.WhatsApp
|
||||
if channels.WhatsApp.Enabled {
|
||||
waCfg := channels.WhatsApp
|
||||
if waCfg.UseNative {
|
||||
m.initChannel("whatsapp_native", "WhatsApp Native")
|
||||
} else if waCfg.BridgeURL != "" {
|
||||
@@ -248,62 +253,62 @@ func (m *Manager) initChannels() error {
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.Feishu.Enabled {
|
||||
if channels.Feishu.Enabled {
|
||||
m.initChannel("feishu", "Feishu")
|
||||
}
|
||||
|
||||
if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" {
|
||||
if channels.Discord.Enabled && channels.Discord.Token != "" {
|
||||
m.initChannel("discord", "Discord")
|
||||
}
|
||||
|
||||
if m.config.Channels.MaixCam.Enabled {
|
||||
if channels.MaixCam.Enabled {
|
||||
m.initChannel("maixcam", "MaixCam")
|
||||
}
|
||||
|
||||
if m.config.Channels.QQ.Enabled {
|
||||
if channels.QQ.Enabled {
|
||||
m.initChannel("qq", "QQ")
|
||||
}
|
||||
|
||||
if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" {
|
||||
if channels.DingTalk.Enabled && channels.DingTalk.ClientID != "" {
|
||||
m.initChannel("dingtalk", "DingTalk")
|
||||
}
|
||||
|
||||
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
|
||||
if channels.Slack.Enabled && channels.Slack.BotToken != "" {
|
||||
m.initChannel("slack", "Slack")
|
||||
}
|
||||
|
||||
if m.config.Channels.Matrix.Enabled &&
|
||||
if channels.Matrix.Enabled &&
|
||||
m.config.Channels.Matrix.Homeserver != "" &&
|
||||
m.config.Channels.Matrix.UserID != "" &&
|
||||
m.config.Channels.Matrix.AccessToken != "" {
|
||||
m.initChannel("matrix", "Matrix")
|
||||
}
|
||||
|
||||
if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" {
|
||||
if channels.LINE.Enabled && channels.LINE.ChannelAccessToken != "" {
|
||||
m.initChannel("line", "LINE")
|
||||
}
|
||||
|
||||
if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" {
|
||||
if channels.OneBot.Enabled && channels.OneBot.WSUrl != "" {
|
||||
m.initChannel("onebot", "OneBot")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" {
|
||||
if channels.WeCom.Enabled && channels.WeCom.Token != "" {
|
||||
m.initChannel("wecom", "WeCom")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeComAIBot.Enabled && m.config.Channels.WeComAIBot.Token != "" {
|
||||
if channels.WeComAIBot.Enabled && channels.WeComAIBot.Token != "" {
|
||||
m.initChannel("wecom_aibot", "WeCom AI Bot")
|
||||
}
|
||||
|
||||
if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" {
|
||||
if channels.WeComApp.Enabled && channels.WeComApp.CorpID != "" {
|
||||
m.initChannel("wecom_app", "WeCom App")
|
||||
}
|
||||
|
||||
if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" {
|
||||
if channels.Pico.Enabled && channels.Pico.Token != "" {
|
||||
m.initChannel("pico", "Pico")
|
||||
}
|
||||
|
||||
if m.config.Channels.IRC.Enabled && m.config.Channels.IRC.Server != "" {
|
||||
if channels.IRC.Enabled && channels.IRC.Server != "" {
|
||||
m.initChannel("irc", "IRC")
|
||||
}
|
||||
|
||||
@@ -825,6 +830,68 @@ func (m *Manager) GetEnabledChannels() []string {
|
||||
return names
|
||||
}
|
||||
|
||||
// Reload updates the config reference without restarting channels.
|
||||
// This is used when channel config hasn't changed but other parts of the config have.
|
||||
func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
list := toChannelHashes(cfg)
|
||||
added, removed := compareChannels(m.channelHashes, list)
|
||||
for _, name := range removed {
|
||||
// Stop all channels
|
||||
channel := m.channels[name]
|
||||
logger.InfoCF("channels", "Stopping channel", map[string]any{
|
||||
"channel": name,
|
||||
})
|
||||
if err := channel.Stop(ctx); err != nil {
|
||||
logger.ErrorCF("channels", "Error stopping channel", map[string]any{
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
m.UnregisterChannel(name)
|
||||
}()
|
||||
}
|
||||
dispatchCtx, cancel := context.WithCancel(ctx)
|
||||
m.dispatchTask = &asyncTask{cancel: cancel}
|
||||
cc, err := toChannelConfig(cfg, added)
|
||||
if err != nil {
|
||||
logger.ErrorC("channels", fmt.Sprintf("toChannelConfig error: %v", err))
|
||||
return err
|
||||
}
|
||||
err = m.initChannels(cc)
|
||||
if err != nil {
|
||||
logger.ErrorC("channels", fmt.Sprintf("initChannels error: %v", err))
|
||||
return err
|
||||
}
|
||||
for _, name := range added {
|
||||
channel := m.channels[name]
|
||||
logger.InfoCF("channels", "Starting channel", map[string]any{
|
||||
"channel": name,
|
||||
})
|
||||
if err := channel.Start(ctx); err != nil {
|
||||
logger.ErrorCF("channels", "Failed to start channel", map[string]any{
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
w := newChannelWorker(name, channel)
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
go func() {
|
||||
m.RegisterChannel(name, channel)
|
||||
}()
|
||||
}
|
||||
|
||||
m.config = cfg
|
||||
m.channelHashes = toChannelHashes(cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) RegisterChannel(name string, channel Channel) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func toChannelHashes(cfg *config.Config) map[string]string {
|
||||
result := make(map[string]string)
|
||||
ch := cfg.Channels
|
||||
// should not be error
|
||||
marshal, _ := json.Marshal(ch)
|
||||
var channelConfig map[string]map[string]any
|
||||
_ = json.Unmarshal(marshal, &channelConfig)
|
||||
|
||||
for key, value := range channelConfig {
|
||||
if !value["enabled"].(bool) {
|
||||
continue
|
||||
}
|
||||
valueBytes, _ := json.Marshal(value)
|
||||
hash := md5.Sum(valueBytes)
|
||||
result[key] = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func compareChannels(old, news map[string]string) (added, removed []string) {
|
||||
for key, newHash := range news {
|
||||
if oldHash, ok := old[key]; ok {
|
||||
if newHash != oldHash {
|
||||
removed = append(removed, key)
|
||||
added = append(added, key)
|
||||
}
|
||||
} else {
|
||||
added = append(added, key)
|
||||
}
|
||||
}
|
||||
for key := range old {
|
||||
if _, ok := news[key]; !ok {
|
||||
removed = append(removed, key)
|
||||
}
|
||||
}
|
||||
return added, removed
|
||||
}
|
||||
|
||||
func toChannelConfig(cfg *config.Config, list []string) (*config.ChannelsConfig, error) {
|
||||
result := &config.ChannelsConfig{}
|
||||
ch := cfg.Channels
|
||||
// should not be error
|
||||
marshal, _ := json.Marshal(ch)
|
||||
var channelConfig map[string]map[string]any
|
||||
_ = json.Unmarshal(marshal, &channelConfig)
|
||||
temp := make(map[string]map[string]any, 0)
|
||||
|
||||
for key, value := range channelConfig {
|
||||
found := false
|
||||
for _, s := range list {
|
||||
if key == s {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found || !value["enabled"].(bool) {
|
||||
continue
|
||||
}
|
||||
temp[key] = value
|
||||
}
|
||||
|
||||
marshal, err := json.Marshal(temp)
|
||||
if err != nil {
|
||||
logger.Errorf("marshal error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
err = json.Unmarshal(marshal, result)
|
||||
if err != nil {
|
||||
logger.Errorf("unmarshal error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func TestToChannelHashes(t *testing.T) {
|
||||
logger.SetLevel(logger.DEBUG)
|
||||
cfg := config.DefaultConfig()
|
||||
results := toChannelHashes(cfg)
|
||||
assert.Equal(t, 0, len(results))
|
||||
logger.Debugf("results: %v", results)
|
||||
cfg2 := config.DefaultConfig()
|
||||
cfg2.Channels.DingTalk.Enabled = true
|
||||
results2 := toChannelHashes(cfg2)
|
||||
assert.Equal(t, 1, len(results2))
|
||||
logger.Debugf("results2: %v", results2)
|
||||
added, removed := compareChannels(results, results2)
|
||||
assert.EqualValues(t, []string{"dingtalk"}, added)
|
||||
assert.EqualValues(t, []string(nil), removed)
|
||||
cfg3 := config.DefaultConfig()
|
||||
cfg3.Channels.Telegram.Enabled = true
|
||||
results3 := toChannelHashes(cfg3)
|
||||
assert.Equal(t, 1, len(results3))
|
||||
logger.Debugf("results3: %v", results3)
|
||||
added, removed = compareChannels(results2, results3)
|
||||
assert.EqualValues(t, []string{"dingtalk"}, removed)
|
||||
assert.EqualValues(t, []string{"telegram"}, added)
|
||||
cfg3.Channels.Telegram.Token = "114314"
|
||||
results4 := toChannelHashes(cfg3)
|
||||
assert.Equal(t, 1, len(results4))
|
||||
logger.Debugf("results4: %v", results4)
|
||||
added, removed = compareChannels(results3, results4)
|
||||
assert.EqualValues(t, []string{"telegram"}, removed)
|
||||
assert.EqualValues(t, []string{"telegram"}, added)
|
||||
cc, err := toChannelConfig(cfg3, added)
|
||||
assert.NoError(t, err)
|
||||
logger.Debugf("cc: %#v", cc.Telegram)
|
||||
assert.Equal(t, "114314", cc.Telegram.Token)
|
||||
assert.Equal(t, true, cc.Telegram.Enabled)
|
||||
cc, err = toChannelConfig(cfg2, added)
|
||||
assert.NoError(t, err)
|
||||
logger.Debugf("cc: %#v", cc.Telegram)
|
||||
assert.Equal(t, "", cc.Telegram.Token)
|
||||
assert.Equal(t, false, cc.Telegram.Enabled)
|
||||
}
|
||||
@@ -324,11 +324,12 @@ func setupAndStartServices(
|
||||
return runningServices, nil
|
||||
}
|
||||
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration) {
|
||||
func stopAndCleanupServices(runningServices *services, shutdownTimeout time.Duration, isReload bool) {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer shutdownCancel()
|
||||
|
||||
if runningServices.ChannelManager != nil {
|
||||
// reload should not stop channel manager
|
||||
if !isReload && runningServices.ChannelManager != nil {
|
||||
runningServices.ChannelManager.StopAll(shutdownCtx)
|
||||
}
|
||||
if runningServices.DeviceService != nil {
|
||||
@@ -357,7 +358,7 @@ func shutdownGateway(
|
||||
cp.Close()
|
||||
}
|
||||
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false)
|
||||
|
||||
agentLoop.Stop()
|
||||
agentLoop.Close()
|
||||
@@ -384,7 +385,7 @@ func handleConfigReload(
|
||||
logger.Infof(" New model is '%s', recreating provider...", newModel)
|
||||
|
||||
logger.Info(" Stopping all services...")
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout)
|
||||
stopAndCleanupServices(runningServices, serviceShutdownTimeout, true)
|
||||
|
||||
newProvider, newModelID, err := createStartupProvider(newCfg, allowEmptyStartup)
|
||||
if err != nil {
|
||||
@@ -494,8 +495,8 @@ func restartServices(
|
||||
}
|
||||
runningServices.ChannelManager.SetupHTTPServer(addr, runningServices.HealthServer)
|
||||
|
||||
if err = runningServices.ChannelManager.StartAll(context.Background()); err != nil {
|
||||
return fmt.Errorf("error restarting channels: %w", err)
|
||||
if err = runningServices.ChannelManager.Reload(context.Background(), cfg); err != nil {
|
||||
return fmt.Errorf("error reload channels: %w", err)
|
||||
}
|
||||
fmt.Println(" ✓ Channels restarted.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user