From afc7a1988f0af7479f3bc1625e4e3e24a360ea67 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 00:44:45 +0800 Subject: [PATCH] refactor(bus): fix deadlock and concurrency issues in MessageBus MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PublishInbound/PublishOutbound held RLock during blocking channel sends, deadlocking against Close() which needs a write lock when the buffer is full. ConsumeInbound/SubscribeOutbound used bare receives instead of comma-ok, causing zero-value processing or busy loops after close. Replace sync.RWMutex+bool with atomic.Bool+done channel so Publish methods use a lock-free 3-way select (send / done / ctx.Done). Add context.Context parameter to both Publish methods so callers can cancel or timeout blocked sends. Close() now only sets the atomic flag and closes the done channel—never closes the data channels—eliminating send-on-closed-channel panics. - Remove dead code: RegisterHandler, GetHandler, handlers map, MessageHandler type (zero callers across the whole repo) - Add ErrBusClosed sentinel error - Update all 10 caller sites to pass context - Add msgBus.Close() to gateway and agent shutdown flows - Add pkg/bus/bus_test.go with 11 test cases covering basic round-trip, context cancellation, closed-bus behavior, concurrent publish+close, full-buffer timeout, and idempotent Close --- cmd/picoclaw/internal/agent/helpers.go | 1 + cmd/picoclaw/internal/gateway/helpers.go | 1 + pkg/agent/loop.go | 10 +- pkg/bus/bus.go | 81 ++++---- pkg/bus/bus_test.go | 229 +++++++++++++++++++++++ pkg/bus/types.go | 2 - pkg/channels/base.go | 2 +- pkg/devices/service.go | 2 +- pkg/heartbeat/service.go | 3 +- pkg/tools/cron.go | 4 +- pkg/tools/subagent.go | 2 +- 11 files changed, 283 insertions(+), 54 deletions(-) create mode 100644 pkg/bus/bus_test.go diff --git a/cmd/picoclaw/internal/agent/helpers.go b/cmd/picoclaw/internal/agent/helpers.go index 746e9755e..f754abc65 100644 --- a/cmd/picoclaw/internal/agent/helpers.go +++ b/cmd/picoclaw/internal/agent/helpers.go @@ -48,6 +48,7 @@ func agentCmd(message, sessionKey, model string, debug bool) error { } msgBus := bus.NewMessageBus() + defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) // Print agent startup info (only for interactive mode) diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index ec5ad5485..e3a51b5e9 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -223,6 +223,7 @@ func gatewayCmd(debug bool) error { cp.Close() } cancel() + msgBus.Close() healthServer.Stop(context.Background()) deviceService.Stop() heartbeatService.Stop() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 124f45675..ebbeec0c1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -121,7 +121,7 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: content, @@ -200,7 +200,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: msg.Channel, ChatID: msg.ChatID, Content: response, @@ -469,7 +469,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -586,7 +586,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -716,7 +716,7 @@ func (al *AgentLoop) runLLMIteration( // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 58c0a25d5..100ddc456 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -2,81 +2,80 @@ package bus import ( "context" - "sync" + "errors" + "sync/atomic" ) +// ErrBusClosed is returned when publishing to a closed MessageBus. +var ErrBusClosed = errors.New("message bus closed") + type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage - handlers map[string]MessageHandler - closed bool - mu sync.RWMutex + done chan struct{} + closed atomic.Bool } func NewMessageBus() *MessageBus { return &MessageBus{ inbound: make(chan InboundMessage, 100), outbound: make(chan OutboundMessage, 100), - handlers: make(map[string]MessageHandler), + done: make(chan struct{}), } } -func (mb *MessageBus) PublishInbound(msg InboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.inbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.inbound <- msg } func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { select { - case msg := <-mb.inbound: - return msg, true + case msg, ok := <-mb.inbound: + return msg, ok + case <-mb.done: + return InboundMessage{}, false case <-ctx.Done(): return InboundMessage{}, false } } -func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.outbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.outbound <- msg } func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) { select { - case msg := <-mb.outbound: - return msg, true + case msg, ok := <-mb.outbound: + return msg, ok + case <-mb.done: + return OutboundMessage{}, false case <-ctx.Done(): return OutboundMessage{}, false } } -func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) { - mb.mu.Lock() - defer mb.mu.Unlock() - mb.handlers[channel] = handler -} - -func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { - mb.mu.RLock() - defer mb.mu.RUnlock() - handler, ok := mb.handlers[channel] - return handler, ok -} - func (mb *MessageBus) Close() { - mb.mu.Lock() - defer mb.mu.Unlock() - if mb.closed { - return + if mb.closed.CompareAndSwap(false, true) { + close(mb.done) } - mb.closed = true - close(mb.inbound) - close(mb.outbound) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 000000000..47826824e --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,229 @@ +package bus + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestPublishConsume(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + } + + if err := mb.PublishInbound(ctx, msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got, ok := mb.ConsumeInbound(ctx) + if !ok { + t.Fatal("ConsumeInbound returned ok=false") + } + if got.Content != "hello" { + t.Fatalf("expected content 'hello', got %q", got.Content) + } + if got.Channel != "test" { + t.Fatalf("expected channel 'test', got %q", got.Channel) + } +} + +func TestPublishOutboundSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := OutboundMessage{ + Channel: "telegram", + ChatID: "123", + Content: "world", + } + + if err := mb.PublishOutbound(ctx, msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got, ok := mb.SubscribeOutbound(ctx) + if !ok { + t.Fatal("SubscribeOutbound returned ok=false") + } + if got.Content != "world" { + t.Fatalf("expected content 'world', got %q", got.Content) + } +} + +func TestPublishInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + // Fill the buffer + ctx := context.Background() + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Now buffer is full; publish with a cancelled context + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestPublishInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestPublishOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestConsumeInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when context is cancelled") + } +} + +func TestConsumeInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestSubscribeOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.SubscribeOutbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestConcurrentPublishClose(t *testing.T) { + mb := NewMessageBus() + ctx := context.Background() + + const numGoroutines = 100 + var wg sync.WaitGroup + wg.Add(numGoroutines + 1) + + // Spawn many goroutines trying to publish + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // Use a short timeout context so we don't block forever after close + publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + // Errors are expected; we just must not panic or deadlock + _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"}) + }() + } + + // Close from another goroutine + go func() { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + mb.Close() + }() + + // Must complete without deadlock + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // success + case <-time.After(5 * time.Second): + t.Fatal("test timed out - possible deadlock") + } +} + +func TestPublishInbound_FullBuffer(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + // Fill the buffer + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Buffer is full; publish with short timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error when buffer is full and context times out") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestCloseIdempotent(t *testing.T) { + mb := NewMessageBus() + + // Multiple Close calls must not panic + mb.Close() + mb.Close() + mb.Close() + + // After close, publish should return ErrBusClosed + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) + } +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go index e49713eb8..358829c55 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -24,5 +24,3 @@ type OutboundMessage struct { ChatID string `json:"chat_id"` Content string `json:"content"` } - -type MessageHandler func(InboundMessage) error diff --git a/pkg/channels/base.go b/pkg/channels/base.go index d967d9e91..adacb8c78 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -143,7 +143,7 @@ func (c *BaseChannel) HandleMessage( Metadata: metadata, } - c.bus.PublishInbound(msg) + c.bus.PublishInbound(context.TODO(), msg) } func (c *BaseChannel) SetRunning(running bool) { diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1541d3c57..408e1c8aa 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -127,7 +127,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { } msg := ev.FormatMessage() - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: msg, diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index e05a9fdbf..3e58dbc7a 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -7,6 +7,7 @@ package heartbeat import ( + "context" "fmt" "os" "path/filepath" @@ -307,7 +308,7 @@ func (hs *HeartbeatService) sendResponse(response string) { return } - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: response, diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 562fffc84..3c13f5968 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -294,7 +294,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } - t.msgBus.PublishOutbound(bus.OutboundMessage{ + t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: output, @@ -304,7 +304,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // If deliver=true, send message directly without agent processing if job.Payload.Deliver { - t.msgBus.PublishOutbound(bus.OutboundMessage{ + t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: job.Payload.Message, diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index ad371a649..081a02872 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -218,7 +218,7 @@ After completing the task, provide a clear summary of what was done.` // Send announce message back to main agent if sm.bus != nil { announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - sm.bus.PublishInbound(bus.InboundMessage{ + sm.bus.PublishInbound(context.TODO(), bus.InboundMessage{ Channel: "system", SenderID: fmt.Sprintf("subagent:%s", task.ID), // Format: "original_channel:original_chat_id" for routing back