package bus import ( "context" "errors" "sync" "sync/atomic" "github.com/sipeed/picoclaw/pkg/logger" ) // ErrBusClosed is returned when publishing to a closed MessageBus. var ErrBusClosed = errors.New("message bus closed") const defaultBusBufferSize = 64 type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage outboundMedia chan OutboundMediaMessage closeOnce sync.Once done chan struct{} closed atomic.Bool wg sync.WaitGroup } func NewMessageBus() *MessageBus { return &MessageBus{ inbound: make(chan InboundMessage, defaultBusBufferSize), outbound: make(chan OutboundMessage, defaultBusBufferSize), outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize), done: make(chan struct{}), } } func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error { // check bus closed before acquiring wg, to avoid unnecessary wg.Add and potential deadlock if mb.closed.Load() { return ErrBusClosed } // check again,before sending message, to avoid sending to closed channel select { case <-ctx.Done(): return ctx.Err() case <-mb.done: return ErrBusClosed default: } mb.wg.Add(1) defer mb.wg.Done() select { case ch <- msg: return nil case <-ctx.Done(): return ctx.Err() case <-mb.done: return ErrBusClosed } } func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { return publish(ctx, mb, mb.inbound, msg) } func (mb *MessageBus) InboundChan() <-chan InboundMessage { return mb.inbound } func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { return publish(ctx, mb, mb.outbound, msg) } func (mb *MessageBus) OutboundChan() <-chan OutboundMessage { return mb.outbound } func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { return publish(ctx, mb, mb.outboundMedia, msg) } func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage { return mb.outboundMedia } func (mb *MessageBus) Close() { mb.closeOnce.Do(func() { // notify all blocked publishers to exit close(mb.done) // because every publisher will check mb.closed before acquiring wg // so we can be sure that new publishers will not be added new messages after this point mb.closed.Store(true) // wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited mb.wg.Wait() // close channels safely close(mb.inbound) close(mb.outbound) close(mb.outboundMedia) // clean up any remaining messages in channels drained := 0 for range mb.inbound { drained++ } for range mb.outbound { drained++ } for range mb.outboundMedia { drained++ } if drained > 0 { logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ "count": drained, }) } }) }