refactor(bus): fix deadlock and concurrency issues in MessageBus

PublishInbound/PublishOutbound held RLock during blocking channel sends,
deadlocking against Close() which needs a write lock when the buffer is
full. ConsumeInbound/SubscribeOutbound used bare receives instead of
comma-ok, causing zero-value processing or busy loops after close.

Replace sync.RWMutex+bool with atomic.Bool+done channel so Publish
methods use a lock-free 3-way select (send / done / ctx.Done). Add
context.Context parameter to both Publish methods so callers can cancel
or timeout blocked sends. Close() now only sets the atomic flag and
closes the done channel—never closes the data channels—eliminating
send-on-closed-channel panics.

- Remove dead code: RegisterHandler, GetHandler, handlers map,
  MessageHandler type (zero callers across the whole repo)
- Add ErrBusClosed sentinel error
- Update all 10 caller sites to pass context
- Add msgBus.Close() to gateway and agent shutdown flows
- Add pkg/bus/bus_test.go with 11 test cases covering basic round-trip,
  context cancellation, closed-bus behavior, concurrent publish+close,
  full-buffer timeout, and idempotent Close
This commit is contained in:
Hoshina
2026-02-23 00:44:45 +08:00
parent 38a26d702c
commit afc7a1988f
11 changed files with 283 additions and 54 deletions
+1
View File
@@ -48,6 +48,7 @@ func agentCmd(message, sessionKey, model string, debug bool) error {
}
msgBus := bus.NewMessageBus()
defer msgBus.Close()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
// Print agent startup info (only for interactive mode)
+1
View File
@@ -223,6 +223,7 @@ func gatewayCmd(debug bool) error {
cp.Close()
}
cancel()
msgBus.Close()
healthServer.Stop(context.Background())
deviceService.Stop()
heartbeatService.Stop()
+5 -5
View File
@@ -121,7 +121,7 @@ func registerSharedTools(
// Message tool
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{
msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: content,
@@ -200,7 +200,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
if !alreadySent {
al.bus.PublishOutbound(bus.OutboundMessage{
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
@@ -469,7 +469,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
// 8. Optional: send response via bus
if opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: finalContent,
@@ -586,7 +586,7 @@ func (al *AgentLoop) runLLMIteration(
})
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
al.bus.PublishOutbound(bus.OutboundMessage{
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: "Context window exceeded. Compressing history and retrying...",
@@ -716,7 +716,7 @@ func (al *AgentLoop) runLLMIteration(
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
+40 -41
View File
@@ -2,81 +2,80 @@ package bus
import (
"context"
"sync"
"errors"
"sync/atomic"
)
// ErrBusClosed is returned when publishing to a closed MessageBus.
var ErrBusClosed = errors.New("message bus closed")
type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
handlers map[string]MessageHandler
closed bool
mu sync.RWMutex
done chan struct{}
closed atomic.Bool
}
func NewMessageBus() *MessageBus {
return &MessageBus{
inbound: make(chan InboundMessage, 100),
outbound: make(chan OutboundMessage, 100),
handlers: make(map[string]MessageHandler),
done: make(chan struct{}),
}
}
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
select {
case mb.inbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
mb.inbound <- msg
}
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
select {
case msg := <-mb.inbound:
return msg, true
case msg, ok := <-mb.inbound:
return msg, ok
case <-mb.done:
return InboundMessage{}, false
case <-ctx.Done():
return InboundMessage{}, false
}
}
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
mb.mu.RLock()
defer mb.mu.RUnlock()
if mb.closed {
return
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
select {
case mb.outbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
mb.outbound <- msg
}
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
select {
case msg := <-mb.outbound:
return msg, true
case msg, ok := <-mb.outbound:
return msg, ok
case <-mb.done:
return OutboundMessage{}, false
case <-ctx.Done():
return OutboundMessage{}, false
}
}
func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) {
mb.mu.Lock()
defer mb.mu.Unlock()
mb.handlers[channel] = handler
}
func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
mb.mu.RLock()
defer mb.mu.RUnlock()
handler, ok := mb.handlers[channel]
return handler, ok
}
func (mb *MessageBus) Close() {
mb.mu.Lock()
defer mb.mu.Unlock()
if mb.closed {
return
if mb.closed.CompareAndSwap(false, true) {
close(mb.done)
}
mb.closed = true
close(mb.inbound)
close(mb.outbound)
}
+229
View File
@@ -0,0 +1,229 @@
package bus
import (
"context"
"sync"
"testing"
"time"
)
func TestPublishConsume(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx := context.Background()
msg := InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "hello",
}
if err := mb.PublishInbound(ctx, msg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
got, ok := mb.ConsumeInbound(ctx)
if !ok {
t.Fatal("ConsumeInbound returned ok=false")
}
if got.Content != "hello" {
t.Fatalf("expected content 'hello', got %q", got.Content)
}
if got.Channel != "test" {
t.Fatalf("expected channel 'test', got %q", got.Channel)
}
}
func TestPublishOutboundSubscribe(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx := context.Background()
msg := OutboundMessage{
Channel: "telegram",
ChatID: "123",
Content: "world",
}
if err := mb.PublishOutbound(ctx, msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got, ok := mb.SubscribeOutbound(ctx)
if !ok {
t.Fatal("SubscribeOutbound returned ok=false")
}
if got.Content != "world" {
t.Fatalf("expected content 'world', got %q", got.Content)
}
}
func TestPublishInbound_ContextCancel(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
// Fill the buffer
ctx := context.Background()
for i := 0; i < 100; i++ {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
// Now buffer is full; publish with a cancelled context
cancelCtx, cancel := context.WithCancel(context.Background())
cancel()
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
if err == nil {
t.Fatal("expected error from cancelled context, got nil")
}
if err != context.Canceled {
t.Fatalf("expected context.Canceled, got %v", err)
}
}
func TestPublishInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed, got %v", err)
}
}
func TestPublishOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed, got %v", err)
}
}
func TestConsumeInbound_ContextCancel(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, ok := mb.ConsumeInbound(ctx)
if ok {
t.Fatal("expected ok=false when context is cancelled")
}
}
func TestConsumeInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, ok := mb.ConsumeInbound(ctx)
if ok {
t.Fatal("expected ok=false when bus is closed")
}
}
func TestSubscribeOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, ok := mb.SubscribeOutbound(ctx)
if ok {
t.Fatal("expected ok=false when bus is closed")
}
}
func TestConcurrentPublishClose(t *testing.T) {
mb := NewMessageBus()
ctx := context.Background()
const numGoroutines = 100
var wg sync.WaitGroup
wg.Add(numGoroutines + 1)
// Spawn many goroutines trying to publish
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
// Use a short timeout context so we don't block forever after close
publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
defer cancel()
// Errors are expected; we just must not panic or deadlock
_ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"})
}()
}
// Close from another goroutine
go func() {
defer wg.Done()
time.Sleep(5 * time.Millisecond)
mb.Close()
}()
// Must complete without deadlock
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// success
case <-time.After(5 * time.Second):
t.Fatal("test timed out - possible deadlock")
}
}
func TestPublishInbound_FullBuffer(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx := context.Background()
// Fill the buffer
for i := 0; i < 100; i++ {
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
// Buffer is full; publish with short timeout
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
if err == nil {
t.Fatal("expected error when buffer is full and context times out")
}
if err != context.DeadlineExceeded {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
}
func TestCloseIdempotent(t *testing.T) {
mb := NewMessageBus()
// Multiple Close calls must not panic
mb.Close()
mb.Close()
mb.Close()
// After close, publish should return ErrBusClosed
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
if err != ErrBusClosed {
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
}
}
-2
View File
@@ -24,5 +24,3 @@ type OutboundMessage struct {
ChatID string `json:"chat_id"`
Content string `json:"content"`
}
type MessageHandler func(InboundMessage) error
+1 -1
View File
@@ -143,7 +143,7 @@ func (c *BaseChannel) HandleMessage(
Metadata: metadata,
}
c.bus.PublishInbound(msg)
c.bus.PublishInbound(context.TODO(), msg)
}
func (c *BaseChannel) SetRunning(running bool) {
+1 -1
View File
@@ -127,7 +127,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
}
msg := ev.FormatMessage()
msgBus.PublishOutbound(bus.OutboundMessage{
msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: msg,
+2 -1
View File
@@ -7,6 +7,7 @@
package heartbeat
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -307,7 +308,7 @@ func (hs *HeartbeatService) sendResponse(response string) {
return
}
msgBus.PublishOutbound(bus.OutboundMessage{
msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
+2 -2
View File
@@ -294,7 +294,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
}
t.msgBus.PublishOutbound(bus.OutboundMessage{
t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
@@ -304,7 +304,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// If deliver=true, send message directly without agent processing
if job.Payload.Deliver {
t.msgBus.PublishOutbound(bus.OutboundMessage{
t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: job.Payload.Message,
+1 -1
View File
@@ -218,7 +218,7 @@ After completing the task, provide a clear summary of what was done.`
// Send announce message back to main agent
if sm.bus != nil {
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
sm.bus.PublishInbound(bus.InboundMessage{
sm.bus.PublishInbound(context.TODO(), bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("subagent:%s", task.ID),
// Format: "original_channel:original_chat_id" for routing back