fix: Fixed the bug where the bus was closed and consumers had unfinished messages. (#1179)

* fix: Fixed the bug where the bus was closed and consumers had unfinished messages.

* fix: remove unnecessary blank line in Close method

* fix: refactor message bus and channel handling for improved performance and reliability

* fix: improve message handling and bus closure logic for better reliability

* fix: reduce sleep duration in agent loop for improved responsiveness

* fix the test case
This commit is contained in:
juju
2026-03-18 00:12:12 +08:00
committed by GitHub
parent f776611e29
commit 9c31b0ca95
11 changed files with 301 additions and 282 deletions
+4 -6
View File
@@ -267,14 +267,11 @@ func (al *AgentLoop) Run(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
default:
msg, ok := al.bus.ConsumeInbound(ctx)
case msg, ok := <-al.bus.InboundChan():
if !ok {
continue
return nil
}
// Process message
func() {
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
// Currently disabled because files are deleted before the LLM can access their content.
// defer func() {
@@ -327,7 +324,8 @@ func (al *AgentLoop) Run(ctx context.Context) error {
)
}
}
}()
default:
time.Sleep(time.Microsecond * 200)
}
}
+56 -23
View File
@@ -997,20 +997,33 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, got %+v", msg)
}
if msg.Content == "reasoning" {
t.Fatalf("expected no message for empty chatID, got %+v", msg)
}
return
case <-ctx.Done():
t.Log("expected an outbound message, got none within timeout")
return
default:
// Continue to check for message
time.Sleep(5 * time.Millisecond) // Avoid busy loop
}
}
})
t.Run("publishes one message for non telegram", func(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
msg, ok := <-msgBus.OutboundChan()
if !ok {
t.Fatal("expected an outbound message")
}
@@ -1024,9 +1037,14 @@ func TestHandleReasoning(t *testing.T) {
reasoning := "hello telegram reasoning"
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
for {
select {
case <-ctx.Done():
t.Fatal("expected an outbound message, got none within timeout")
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatal("expected outbound message")
}
@@ -1040,19 +1058,31 @@ func TestHandleReasoning(t *testing.T) {
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
}
return
}
}
})
t.Run("expired ctx", func(t *testing.T) {
al, msgBus := newLoop(t)
reasoning := "hello telegram reasoning"
ctx, cancel := context.WithCancel(context.Background())
cancel()
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if ok {
t.Fatalf("expected no outbound message, got %+v", msg)
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer consumeCancel()
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, but received: %+v", msg)
}
t.Logf("Received unexpected outbound message: %+v", msg)
return
case <-consumeCtx.Done():
t.Fatalf("failed: no message received within timeout")
return
}
}
})
@@ -1092,21 +1122,24 @@ func TestHandleReasoning(t *testing.T) {
// Drain the bus and verify the reasoning message was NOT published
// (it should have been dropped due to timeout).
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer drainCancel()
foundReasoning := false
timeer := time.After(1 * time.Second)
for {
msg, ok := msgBus.SubscribeOutbound(drainCtx)
select {
case <-timeer:
t.Logf(
"no reasoning message received after draining bus for 1s, as expected,length=%d",
len(msgBus.OutboundChan()),
)
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
break
}
if msg.Content == "should timeout" {
foundReasoning = true
}
}
if foundReasoning {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
}
}
})
}
+55 -88
View File
@@ -3,6 +3,7 @@ package bus
import (
"context"
"errors"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -17,8 +18,11 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
outboundMedia chan OutboundMediaMessage
closeOnce sync.Once
done chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
func NewMessageBus() *MessageBus {
@@ -30,128 +34,91 @@ func NewMessageBus() *MessageBus {
}
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
// publish sends msg on ch on behalf of mb. It is the shared implementation
// called by PublishInbound, PublishOutbound and PublishOutboundMedia.
//
// It returns ErrBusClosed if mb.closed is set or mb.done is closed, ctx.Err()
// if ctx is already done or is cancelled while the send is blocked, and nil
// once msg has been handed to ch.
//
// NOTE(review): the closed/done checks and mb.wg.Add(1) are separate steps. If
// Close completes mb.wg.Wait() and closes ch in between, the final select can
// still pick the send case on a closed channel — confirm whether this window
// needs to be guarded.
func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error {
// Check the closed flag before registering with mb.wg, to avoid an
// unnecessary wg.Add and a potential deadlock.
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
// Check again before sending, to avoid sending on a closed channel.
select {
case mb.inbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
default:
}
// Register this send with mb.wg; Close calls mb.wg.Wait() before it closes
// the channels.
mb.wg.Add(1)
defer mb.wg.Done()
select {
case ch <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
}
}
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
select {
case msg, ok := <-mb.inbound:
return msg, ok
case <-mb.done:
return InboundMessage{}, false
case <-ctx.Done():
return InboundMessage{}, false
}
// PublishInbound sends msg on the bus's inbound channel by delegating to
// publish, and returns whatever error publish reports.
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
return publish(ctx, mb, mb.inbound, msg)
}
// InboundChan returns the bus's inbound channel as a receive-only channel.
// Close closes the underlying channel, so a receive can report ok == false.
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
return mb.inbound
}
// PublishOutbound sends msg on the bus's outbound channel. It returns
// ErrBusClosed when the bus has been closed and the context's error when ctx
// is done before the message is accepted.
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
return publish(ctx, mb, mb.outbound, msg)
}
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
select {
case msg, ok := <-mb.outbound:
return msg, ok
case <-mb.done:
return OutboundMessage{}, false
case <-ctx.Done():
return OutboundMessage{}, false
}
// OutboundChan returns the bus's outbound channel as a receive-only channel.
// Close closes the underlying channel, so a receive can report ok == false.
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
return mb.outbound
}
// PublishOutboundMedia sends msg on the bus's outboundMedia channel. It
// returns ErrBusClosed when the bus has been closed and the context's error
// when ctx is done before the message is accepted.
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outboundMedia <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
return publish(ctx, mb, mb.outboundMedia, msg)
}
func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) {
select {
case msg, ok := <-mb.outboundMedia:
return msg, ok
case <-mb.done:
return OutboundMediaMessage{}, false
case <-ctx.Done():
return OutboundMediaMessage{}, false
}
// OutboundMediaChan returns the bus's outboundMedia channel as a receive-only
// channel. Close closes the underlying channel, so a receive can report
// ok == false.
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
return mb.outboundMedia
}
// Close shuts the bus down. The shutdown body is wrapped in mb.closeOnce.Do.
// It closes mb.done, sets mb.closed, waits on mb.wg for in-flight publish
// calls, closes the inbound, outbound and outboundMedia channels, and then
// counts whatever was still buffered, logging the count at debug level when it
// is non-zero.
//
// NOTE(review): the messages counted at the end are received and discarded
// here, not delivered — confirm that dropping them on shutdown is intended.
func (mb *MessageBus) Close() {
if mb.closed.CompareAndSwap(false, true) {
mb.closeOnce.Do(func() {
// notify all blocked publishers to exit
close(mb.done)
// Drain buffered channels so messages aren't silently lost.
// Channels are NOT closed to avoid send-on-closed panics from concurrent publishers.
// Every publisher checks mb.closed before calling wg.Add, so once the flag
// is set no new publisher is expected to add messages after this point.
mb.closed.Store(true)
// wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited
mb.wg.Wait()
// close channels safely
close(mb.inbound)
close(mb.outbound)
close(mb.outboundMedia)
// clean up any remaining messages in channels
drained := 0
for {
select {
case <-mb.inbound:
for range mb.inbound {
drained++
default:
goto doneInbound
}
}
doneInbound:
for {
select {
case <-mb.outbound:
for range mb.outbound {
drained++
default:
goto doneOutbound
}
}
doneOutbound:
for {
select {
case <-mb.outboundMedia:
for range mb.outboundMedia {
drained++
default:
goto doneMedia
}
}
doneMedia:
if drained > 0 {
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
"count": drained,
})
}
}
})
}
+32 -14
View File
@@ -24,7 +24,7 @@ func TestPublishConsume(t *testing.T) {
t.Fatalf("PublishInbound failed: %v", err)
}
got, ok := mb.ConsumeInbound(ctx)
got, ok := <-mb.InboundChan()
if !ok {
t.Fatal("ConsumeInbound returned ok=false")
}
@@ -52,7 +52,7 @@ func TestPublishOutboundSubscribe(t *testing.T) {
t.Fatalf("PublishOutbound failed: %v", err)
}
got, ok := mb.SubscribeOutbound(ctx)
got, ok := <-mb.OutboundChan()
if !ok {
t.Fatal("SubscribeOutbound returned ok=false")
}
@@ -108,27 +108,48 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
// TestConsumeInbound_ContextCancel fills the inbound buffer with
// defaultBusBufferSize messages, then publishes one more under a 100ms
// timeout and selects between that context expiring and a receive from
// InboundChan.
//
// NOTE(review): the failure messages on the receive branch do not match the
// conditions that guard them (see the inline notes) — confirm the intended
// assertions.
func TestConsumeInbound_ContextCancel(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
_, ok := mb.ConsumeInbound(ctx)
if ok {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
// NOTE(review): the error returned by this publish is not checked.
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
select {
case <-ctx.Done():
t.Log("context canceled, as expected")
case msg, ok := <-mb.InboundChan():
// NOTE(review): this fails when ok is false, yet the message says
// ok=false was expected.
if !ok {
t.Fatal("expected ok=false when context is canceled")
}
// NOTE(review): this fails when the content equals "ContextCancel", yet
// the message says that content was expected.
if msg.Content == "ContextCancel" {
t.Fatalf("expected content 'ContextCancel', got %q", msg.Content)
}
}
}
// TestConsumeInbound_BusClosed schedules mb.Close() after 100ms and checks
// that a receive from InboundChan does not report ok == true.
func TestConsumeInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
timer := time.AfterFunc(100*time.Millisecond, func() {
mb.Close()
})
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
select {
// NOTE(review): per the time package docs, C is nil for a Timer created by
// AfterFunc, so this case can never be selected — confirm it is intended.
case <-timer.C:
t.Log("context canceled, as expected")
_, ok := mb.ConsumeInbound(ctx)
case _, ok := <-mb.InboundChan():
if ok {
t.Fatal("expected ok=false when bus is closed")
t.Fatal("expected ok=false when context is canceled")
}
}
}
@@ -136,10 +157,7 @@ func TestSubscribeOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, ok := mb.SubscribeOutbound(ctx)
_, ok := <-mb.OutboundChan()
if ok {
t.Fatal("expected ok=false when bus is closed")
}
+10 -4
View File
@@ -585,7 +585,7 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
func dispatchLoop[M any](
ctx context.Context,
m *Manager,
subscribe func(context.Context) (M, bool),
ch <-chan M,
getChannel func(M) string,
enqueue func(context.Context, *channelWorker, M) bool,
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
@@ -593,7 +593,12 @@ func dispatchLoop[M any](
logger.InfoC("channels", startMsg)
for {
msg, ok := subscribe(ctx)
select {
case <-ctx.Done():
logger.InfoC("channels", stopMsg)
return
case msg, ok := <-ch:
if !ok {
logger.InfoC("channels", stopMsg)
return
@@ -624,12 +629,13 @@ func dispatchLoop[M any](
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
}
}
}
func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutbound,
m.bus.OutboundChan(),
func(msg bus.OutboundMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
@@ -649,7 +655,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutboundMedia,
m.bus.OutboundMediaChan(),
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
+9 -1
View File
@@ -34,11 +34,19 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
for {
select {
case <-ctx.Done():
t.Fatal("timeout waiting for inbound message")
return
case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
}
return
}
}
}
@@ -3,7 +3,6 @@ package telegram
import (
"context"
"testing"
"time"
"github.com/mymmrac/telego"
@@ -36,10 +35,7 @@ func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -108,10 +108,15 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
select {
case <-ctx.Done():
if tc.wantForwarded {
t.Fatal("timeout waiting for message to be forwarded")
return
}
case inbound, ok := <-messageBus.InboundChan():
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
@@ -121,9 +126,6 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
}
return
}
if ok {
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
}
})
}
+3 -13
View File
@@ -6,7 +6,6 @@ import (
"errors"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
@@ -355,10 +354,7 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
@@ -397,10 +393,7 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// Plain chatID without thread suffix
@@ -443,10 +436,7 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// chatID should NOT include thread suffix for non-forum groups
@@ -3,7 +3,6 @@ package whatsapp
import (
"context"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -25,10 +24,7 @@ func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T
"content": "/help",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -43,7 +43,11 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for message to be forwarded")
return
case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -53,4 +57,5 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
}