Files
picoclaw/pkg/channels/registry.go
T
Cytown 667fc85d54 refactor(config): make config.Channel to multiple instance support
add new field type to Channel struct
config.channels refactor to channel_list
update config version to 3
update the docs
2026-04-13 22:21:21 +08:00

79 lines
2.5 KiB
Go

package channels
import (
	"fmt"
	"sort"
	"sync"

	"github.com/sipeed/picoclaw/pkg/bus"
	"github.com/sipeed/picoclaw/pkg/config"
)
// ChannelFactory is a constructor function that creates a Channel from config and message bus.
// Each channel subpackage registers one or more factories via init().
//
// Parameters:
//   - channelName is the config map key for this channel instance (may differ from the channel type).
//   - channelType is the channel type string for the instance. Factories built by
//     RegisterSafeFactory ignore it and locate the instance's config by channelName.
//   - cfg is the full application config.
//   - bus is the message bus handed to the constructed channel.
//
// A factory returns either a usable Channel or a non-nil error.
type ChannelFactory func(channelName, channelType string, cfg *config.Config, bus *bus.MessageBus) (Channel, error)
// Package-level factory registry. factoriesMu guards factories: registration
// takes the write lock, while lookup and listing take the read lock.
var (
	factoriesMu sync.RWMutex
	factories   = make(map[string]ChannelFactory)
)
// RegisterFactory stores f in the registry under name so it can later be found
// by getFactory. It is meant to be called from subpackage init() functions.
// Registering a name that is already present replaces the earlier factory.
func RegisterFactory(name string, f ChannelFactory) {
	factoriesMu.Lock()
	factories[name] = f
	factoriesMu.Unlock()
}
// RegisterSafeFactory is a convenience wrapper that handles GetDecoded() error checking
// and type assertion, reducing boilerplate in channel init() functions.
//
// The registered factory looks up cfg.Channels[channelName], decodes its
// settings with GetDecoded, asserts them to *S, and only then calls ctor.
// It returns an error instead of calling ctor when:
//   - cfg is nil,
//   - there is no config entry for channelName,
//   - decoding fails,
//   - the decoded settings are not a non-nil *S.
//
// Usage:
//
//	func init() {
//		channels.RegisterSafeFactory(config.ChannelTelegram,
//			func(bc *config.Channel, c *config.TelegramSettings, b *bus.MessageBus) (channels.Channel, error) {
//				return NewTelegramChannel(bc, c, b)
//			})
//	}
func RegisterSafeFactory[S any](
	channelType string,
	ctor func(bc *config.Channel, settings *S, bus *bus.MessageBus) (Channel, error),
) {
	// The per-call channelType argument is unused: the instance's config is
	// located by channelName alone.
	RegisterFactory(channelType, func(channelName, _ string, cfg *config.Config, b *bus.MessageBus) (Channel, error) {
		// Guard against a nil config; indexing cfg.Channels would otherwise panic.
		if cfg == nil {
			return nil, fmt.Errorf("channel %q: config is nil", channelName)
		}
		bc := cfg.Channels[channelName]
		if bc == nil {
			return nil, fmt.Errorf("channel %q: config not found", channelName)
		}
		decoded, err := bc.GetDecoded()
		if err != nil {
			return nil, fmt.Errorf("channel %q: failed to decode settings: %w", channelName, err)
		}
		settings, ok := decoded.(*S)
		if !ok {
			return nil, fmt.Errorf("channel %q: expected %T settings, got %T", channelName, (*S)(nil), decoded)
		}
		// A typed-nil *S passes the assertion above; reject it so ctor never
		// receives a nil settings pointer.
		if settings == nil {
			return nil, fmt.Errorf("channel %q: decoded %T settings are nil", channelName, settings)
		}
		return ctor(bc, settings, b)
	})
}
// getFactory returns the factory registered under name. The boolean reports
// whether such a registration exists.
func getFactory(name string) (ChannelFactory, bool) {
	factoriesMu.RLock()
	factory, found := factories[name]
	factoriesMu.RUnlock()
	return factory, found
}
// GetRegisteredFactoryNames returns the names of all registered channel
// factories, sorted in ascending lexical order so results are deterministic
// (Go map iteration order is randomized). The returned slice is newly
// allocated on each call, and callers may modify it freely.
func GetRegisteredFactoryNames() []string {
	factoriesMu.RLock()
	names := make([]string, 0, len(factories))
	for name := range factories {
		names = append(names, name)
	}
	factoriesMu.RUnlock()

	// Sorting works on the private copy, so it happens outside the lock.
	sort.Strings(names)
	return names
}