mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(bus): fix deadlock and concurrency issues in MessageBus
PublishInbound/PublishOutbound held RLock during blocking channel sends, deadlocking against Close(), which needs the write lock, whenever the buffer is full. ConsumeInbound/SubscribeOutbound used bare receives instead of comma-ok, causing zero-value processing or busy loops after close.

Replace sync.RWMutex+bool with atomic.Bool+done channel so the Publish methods use a lock-free 3-way select (send / done / ctx.Done). Add a context.Context parameter to both Publish methods so callers can cancel or time out blocked sends. Close() now only sets the atomic flag and closes the done channel—it never closes the data channels—eliminating send-on-closed-channel panics.

- Remove dead code: RegisterHandler, GetHandler, handlers map, MessageHandler type (zero callers across the whole repo)
- Add ErrBusClosed sentinel error
- Update all 10 caller sites to pass context
- Add msgBus.Close() to gateway and agent shutdown flows
- Add pkg/bus/bus_test.go with 11 test cases covering basic round-trip, context cancellation, closed-bus behavior, concurrent publish+close, full-buffer timeout, and idempotent Close
This commit is contained in:
@@ -48,6 +48,7 @@ func agentCmd(message, sessionKey, model string, debug bool) error {
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
defer msgBus.Close()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Print agent startup info (only for interactive mode)
|
||||
|
||||
@@ -223,6 +223,7 @@ func gatewayCmd(debug bool) error {
|
||||
cp.Close()
|
||||
}
|
||||
cancel()
|
||||
msgBus.Close()
|
||||
healthServer.Stop(context.Background())
|
||||
deviceService.Stop()
|
||||
heartbeatService.Stop()
|
||||
|
||||
+5
-5
@@ -121,7 +121,7 @@ func registerSharedTools(
|
||||
// Message tool
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: content,
|
||||
@@ -200,7 +200,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
@@ -469,7 +469,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
if opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: finalContent,
|
||||
@@ -586,7 +586,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
})
|
||||
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "Context window exceeded. Compressing history and retrying...",
|
||||
@@ -716,7 +716,7 @@ func (al *AgentLoop) runLLMIteration(
|
||||
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: toolResult.ForUser,
|
||||
|
||||
+40
-41
@@ -2,81 +2,80 @@ package bus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ErrBusClosed is returned when publishing to a closed MessageBus.
|
||||
var ErrBusClosed = errors.New("message bus closed")
|
||||
|
||||
type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
handlers map[string]MessageHandler
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
done chan struct{}
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewMessageBus() *MessageBus {
|
||||
return &MessageBus{
|
||||
inbound: make(chan InboundMessage, 100),
|
||||
outbound: make(chan OutboundMessage, 100),
|
||||
handlers: make(map[string]MessageHandler),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
if mb.closed.Load() {
|
||||
return ErrBusClosed
|
||||
}
|
||||
select {
|
||||
case mb.inbound <- msg:
|
||||
return nil
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
mb.inbound <- msg
|
||||
}
|
||||
|
||||
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
|
||||
select {
|
||||
case msg := <-mb.inbound:
|
||||
return msg, true
|
||||
case msg, ok := <-mb.inbound:
|
||||
return msg, ok
|
||||
case <-mb.done:
|
||||
return InboundMessage{}, false
|
||||
case <-ctx.Done():
|
||||
return InboundMessage{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
if mb.closed.Load() {
|
||||
return ErrBusClosed
|
||||
}
|
||||
select {
|
||||
case mb.outbound <- msg:
|
||||
return nil
|
||||
case <-mb.done:
|
||||
return ErrBusClosed
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
mb.outbound <- msg
|
||||
}
|
||||
|
||||
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
|
||||
select {
|
||||
case msg := <-mb.outbound:
|
||||
return msg, true
|
||||
case msg, ok := <-mb.outbound:
|
||||
return msg, ok
|
||||
case <-mb.done:
|
||||
return OutboundMessage{}, false
|
||||
case <-ctx.Done():
|
||||
return OutboundMessage{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
mb.handlers[channel] = handler
|
||||
}
|
||||
|
||||
func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
handler, ok := mb.handlers[channel]
|
||||
return handler, ok
|
||||
}
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
if mb.closed {
|
||||
return
|
||||
if mb.closed.CompareAndSwap(false, true) {
|
||||
close(mb.done)
|
||||
}
|
||||
mb.closed = true
|
||||
close(mb.inbound)
|
||||
close(mb.outbound)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,229 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPublishConsume(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(ctx, msg); err != nil {
|
||||
t.Fatalf("PublishInbound failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := mb.ConsumeInbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("ConsumeInbound returned ok=false")
|
||||
}
|
||||
if got.Content != "hello" {
|
||||
t.Fatalf("expected content 'hello', got %q", got.Content)
|
||||
}
|
||||
if got.Channel != "test" {
|
||||
t.Fatalf("expected channel 'test', got %q", got.Channel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
Content: "world",
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(ctx, msg); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
|
||||
got, ok := mb.SubscribeOutbound(ctx)
|
||||
if !ok {
|
||||
t.Fatal("SubscribeOutbound returned ok=false")
|
||||
}
|
||||
if got.Content != "world" {
|
||||
t.Fatalf("expected content 'world', got %q", got.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
// Fill the buffer
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now buffer is full; publish with a cancelled context
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from cancelled context, got nil")
|
||||
}
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeInbound_ContextCancel(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, ok := mb.ConsumeInbound(ctx)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when context is cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConsumeInbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, ok := mb.ConsumeInbound(ctx)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when bus is closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubscribeOutbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, ok := mb.SubscribeOutbound(ctx)
|
||||
if ok {
|
||||
t.Fatal("expected ok=false when bus is closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentPublishClose(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
ctx := context.Background()
|
||||
|
||||
const numGoroutines = 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines + 1)
|
||||
|
||||
// Spawn many goroutines trying to publish
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// Use a short timeout context so we don't block forever after close
|
||||
publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
|
||||
defer cancel()
|
||||
// Errors are expected; we just must not panic or deadlock
|
||||
_ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"})
|
||||
}()
|
||||
}
|
||||
|
||||
// Close from another goroutine
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
mb.Close()
|
||||
}()
|
||||
|
||||
// Must complete without deadlock
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// success
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("test timed out - possible deadlock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill the buffer
|
||||
for i := 0; i < 100; i++ {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer is full; publish with short timeout
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when buffer is full and context times out")
|
||||
}
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseIdempotent(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
|
||||
// Multiple Close calls must not panic
|
||||
mb.Close()
|
||||
mb.Close()
|
||||
mb.Close()
|
||||
|
||||
// After close, publish should return ErrBusClosed
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -24,5 +24,3 @@ type OutboundMessage struct {
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type MessageHandler func(InboundMessage) error
|
||||
|
||||
@@ -143,7 +143,7 @@ func (c *BaseChannel) HandleMessage(
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
c.bus.PublishInbound(msg)
|
||||
c.bus.PublishInbound(context.TODO(), msg)
|
||||
}
|
||||
|
||||
func (c *BaseChannel) SetRunning(running bool) {
|
||||
|
||||
@@ -127,7 +127,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
|
||||
}
|
||||
|
||||
msg := ev.FormatMessage()
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Content: msg,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
package heartbeat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -307,7 +308,7 @@ func (hs *HeartbeatService) sendResponse(response string) {
|
||||
return
|
||||
}
|
||||
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Content: response,
|
||||
|
||||
+2
-2
@@ -294,7 +294,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
|
||||
}
|
||||
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: output,
|
||||
@@ -304,7 +304,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
|
||||
// If deliver=true, send message directly without agent processing
|
||||
if job.Payload.Deliver {
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: job.Payload.Message,
|
||||
|
||||
@@ -218,7 +218,7 @@ After completing the task, provide a clear summary of what was done.`
|
||||
// Send announce message back to main agent
|
||||
if sm.bus != nil {
|
||||
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
|
||||
sm.bus.PublishInbound(bus.InboundMessage{
|
||||
sm.bus.PublishInbound(context.TODO(), bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("subagent:%s", task.ID),
|
||||
// Format: "original_channel:original_chat_id" for routing back
|
||||
|
||||
Reference in New Issue
Block a user