fix: Fixed the bug where the bus was closed and consumers had unfinished messages. (#1179)

* fix: Fixed the bug where the bus was closed and consumers had unfinished messages.

* fix: remove unnecessary blank line in Close method

* fix: refactor message bus and channel handling for improved performance and reliability

* fix: improve message handling and bus closure logic for better reliability

* fix: reduce sleep duration in agent loop for improved responsiveness

* fix the test case
This commit is contained in:
juju
2026-03-18 00:12:12 +08:00
committed by GitHub
parent f776611e29
commit 9c31b0ca95
11 changed files with 301 additions and 282 deletions
+53 -55
View File
@@ -267,67 +267,65 @@ func (al *AgentLoop) Run(ctx context.Context) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil return nil
default: case msg, ok := <-al.bus.InboundChan():
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok { if !ok {
continue return nil
}
// Process message
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
// Currently disabled because files are deleted before the LLM can access their content.
// defer func() {
// if al.mediaStore != nil && msg.MediaScope != "" {
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
// logger.WarnCF("agent", "Failed to release media", map[string]any{
// "scope": msg.MediaScope,
// "error": releaseErr.Error(),
// })
// }
// }
// }()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
} }
// Process message if response != "" {
func() { // Check if the message tool already sent a response during this round.
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. // If so, skip publishing to avoid duplicate messages to the user.
// Currently disabled because files are deleted before the LLM can access their content. // Use default agent's tools to check (message tool is shared).
// defer func() { alreadySent := false
// if al.mediaStore != nil && msg.MediaScope != "" { defaultAgent := al.GetRegistry().GetDefaultAgent()
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { if defaultAgent != nil {
// logger.WarnCF("agent", "Failed to release media", map[string]any{ if tool, ok := defaultAgent.Tools.Get("message"); ok {
// "scope": msg.MediaScope, if mt, ok := tool.(*tools.MessageTool); ok {
// "error": releaseErr.Error(), alreadySent = mt.HasSentInRound()
// })
// }
// }
// }()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
}
if response != "" {
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
// Use default agent's tools to check (message tool is shared).
alreadySent := false
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent != nil {
if tool, ok := defaultAgent.Tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
} }
} }
if !alreadySent {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
logger.InfoCF("agent", "Published outbound response",
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"content_len": len(response),
})
} else {
logger.DebugCF(
"agent",
"Skipped outbound (message tool already sent)",
map[string]any{"channel": msg.Channel},
)
}
} }
}()
if !alreadySent {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
logger.InfoCF("agent", "Published outbound response",
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"content_len": len(response),
})
} else {
logger.DebugCF(
"agent",
"Skipped outbound (message tool already sent)",
map[string]any{"channel": msg.Channel},
)
}
}
default:
time.Sleep(time.Microsecond * 200)
} }
} }
+72 -39
View File
@@ -997,10 +997,25 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t) al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "reasoning", "telegram", "") al.handleReasoning(context.Background(), "reasoning", "telegram", "")
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
if msg, ok := msgBus.SubscribeOutbound(ctx); ok { for {
t.Fatalf("expected no outbound message, got %+v", msg) select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, got %+v", msg)
}
if msg.Content == "reasoning" {
t.Fatalf("expected no message for empty chatID, got %+v", msg)
}
return
case <-ctx.Done():
t.Log("expected an outbound message, got none within timeout")
return
default:
// Continue to check for message
time.Sleep(5 * time.Millisecond) // Avoid busy loop
}
} }
}) })
@@ -1008,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t) al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1") al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) msg, ok := <-msgBus.OutboundChan()
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if !ok { if !ok {
t.Fatal("expected an outbound message") t.Fatal("expected an outbound message")
} }
@@ -1024,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) {
reasoning := "hello telegram reasoning" reasoning := "hello telegram reasoning"
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat") al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx) for {
if !ok { select {
t.Fatal("expected outbound message") case <-ctx.Done():
} t.Fatal("expected an outbound message, got none within timeout")
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatal("expected outbound message")
}
if msg.Channel != "telegram" { if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg) t.Fatalf("expected telegram channel message, got %+v", msg)
} }
if msg.ChatID != "tg-chat" { if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg) t.Fatalf("expected chatID tg-chat, got %+v", msg)
} }
if msg.Content != reasoning { if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning) t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
}
return
}
} }
}) })
t.Run("expired ctx", func(t *testing.T) { t.Run("expired ctx", func(t *testing.T) {
al, msgBus := newLoop(t) al, msgBus := newLoop(t)
reasoning := "hello telegram reasoning" reasoning := "hello telegram reasoning"
ctx, cancel := context.WithCancel(context.Background())
cancel()
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx) consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
if ok { defer consumeCancel()
t.Fatalf("expected no outbound message, got %+v", msg)
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, but received: %+v", msg)
}
t.Logf("Received unexpected outbound message: %+v", msg)
return
case <-consumeCtx.Done():
t.Fatalf("failed: no message received within timeout")
return
}
} }
}) })
@@ -1092,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) {
// Drain the bus and verify the reasoning message was NOT published // Drain the bus and verify the reasoning message was NOT published
// (it should have been dropped due to timeout). // (it should have been dropped due to timeout).
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) timeer := time.After(1 * time.Second)
defer drainCancel()
foundReasoning := false
for { for {
msg, ok := msgBus.SubscribeOutbound(drainCtx) select {
if !ok { case <-timeer:
break t.Logf(
"no reasoning message received after draining bus for 1s, as expected,length=%d",
len(msgBus.OutboundChan()),
)
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
break
}
if msg.Content == "should timeout" {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
} }
if msg.Content == "should timeout" {
foundReasoning = true
}
}
if foundReasoning {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
} }
}) })
} }
+60 -93
View File
@@ -3,6 +3,7 @@ package bus
import ( import (
"context" "context"
"errors" "errors"
"sync"
"sync/atomic" "sync/atomic"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
@@ -17,8 +18,11 @@ type MessageBus struct {
inbound chan InboundMessage inbound chan InboundMessage
outbound chan OutboundMessage outbound chan OutboundMessage
outboundMedia chan OutboundMediaMessage outboundMedia chan OutboundMediaMessage
done chan struct{}
closed atomic.Bool closeOnce sync.Once
done chan struct{}
closed atomic.Bool
wg sync.WaitGroup
} }
func NewMessageBus() *MessageBus { func NewMessageBus() *MessageBus {
@@ -30,128 +34,91 @@ func NewMessageBus() *MessageBus {
} }
} }
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error {
// check bus closed before acquiring wg, to avoid unnecessary wg.Add and potential deadlock
if mb.closed.Load() { if mb.closed.Load() {
return ErrBusClosed return ErrBusClosed
} }
if err := ctx.Err(); err != nil {
return err // check again,before sending message, to avoid sending to closed channel
}
select { select {
case mb.inbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-mb.done:
return ErrBusClosed
default:
}
mb.wg.Add(1)
defer mb.wg.Done()
select {
case ch <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
} }
} }
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
select { return publish(ctx, mb, mb.inbound, msg)
case msg, ok := <-mb.inbound: }
return msg, ok
case <-mb.done: func (mb *MessageBus) InboundChan() <-chan InboundMessage {
return InboundMessage{}, false return mb.inbound
case <-ctx.Done():
return InboundMessage{}, false
}
} }
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
if mb.closed.Load() { return publish(ctx, mb, mb.outbound, msg)
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
} }
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) { func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
select { return mb.outbound
case msg, ok := <-mb.outbound:
return msg, ok
case <-mb.done:
return OutboundMessage{}, false
case <-ctx.Done():
return OutboundMessage{}, false
}
} }
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
if mb.closed.Load() { return publish(ctx, mb, mb.outboundMedia, msg)
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outboundMedia <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
} }
func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) { func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
select { return mb.outboundMedia
case msg, ok := <-mb.outboundMedia:
return msg, ok
case <-mb.done:
return OutboundMediaMessage{}, false
case <-ctx.Done():
return OutboundMediaMessage{}, false
}
} }
func (mb *MessageBus) Close() { func (mb *MessageBus) Close() {
if mb.closed.CompareAndSwap(false, true) { mb.closeOnce.Do(func() {
// notify all blocked publishers to exit
close(mb.done) close(mb.done)
// Drain buffered channels so messages aren't silently lost. // because every publisher will check mb.closed before acquiring wg
// Channels are NOT closed to avoid send-on-closed panics from concurrent publishers. // so we can be sure that new publishers will not be added new messages after this point
mb.closed.Store(true)
// wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited
mb.wg.Wait()
// close channels safely
close(mb.inbound)
close(mb.outbound)
close(mb.outboundMedia)
// clean up any remaining messages in channels
drained := 0 drained := 0
for { for range mb.inbound {
select { drained++
case <-mb.inbound:
drained++
default:
goto doneInbound
}
} }
doneInbound: for range mb.outbound {
for { drained++
select {
case <-mb.outbound:
drained++
default:
goto doneOutbound
}
} }
doneOutbound: for range mb.outboundMedia {
for { drained++
select {
case <-mb.outboundMedia:
drained++
default:
goto doneMedia
}
} }
doneMedia:
if drained > 0 { if drained > 0 {
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
"count": drained, "count": drained,
}) })
} }
} })
} }
+35 -17
View File
@@ -24,7 +24,7 @@ func TestPublishConsume(t *testing.T) {
t.Fatalf("PublishInbound failed: %v", err) t.Fatalf("PublishInbound failed: %v", err)
} }
got, ok := mb.ConsumeInbound(ctx) got, ok := <-mb.InboundChan()
if !ok { if !ok {
t.Fatal("ConsumeInbound returned ok=false") t.Fatal("ConsumeInbound returned ok=false")
} }
@@ -52,7 +52,7 @@ func TestPublishOutboundSubscribe(t *testing.T) {
t.Fatalf("PublishOutbound failed: %v", err) t.Fatalf("PublishOutbound failed: %v", err)
} }
got, ok := mb.SubscribeOutbound(ctx) got, ok := <-mb.OutboundChan()
if !ok { if !ok {
t.Fatal("SubscribeOutbound returned ok=false") t.Fatal("SubscribeOutbound returned ok=false")
} }
@@ -108,27 +108,48 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
func TestConsumeInbound_ContextCancel(t *testing.T) { func TestConsumeInbound_ContextCancel(t *testing.T) {
mb := NewMessageBus() mb := NewMessageBus()
defer mb.Close() defer mb.Close()
ctx, cancel := context.WithCancel(context.Background()) for i := range defaultBusBufferSize {
cancel() if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
_, ok := mb.ConsumeInbound(ctx) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
if ok { defer cancel()
t.Fatal("expected ok=false when context is canceled") mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
select {
case <-ctx.Done():
t.Log("context canceled, as expected")
case msg, ok := <-mb.InboundChan():
if !ok {
t.Fatal("expected ok=false when context is canceled")
}
if msg.Content == "ContextCancel" {
t.Fatalf("expected content 'ContextCancel', got %q", msg.Content)
}
} }
} }
func TestConsumeInbound_BusClosed(t *testing.T) { func TestConsumeInbound_BusClosed(t *testing.T) {
mb := NewMessageBus() mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) timer := time.AfterFunc(100*time.Millisecond, func() {
defer cancel() mb.Close()
})
_, ok := mb.ConsumeInbound(ctx) select {
if ok { case <-timer.C:
t.Fatal("expected ok=false when bus is closed") t.Log("context canceled, as expected")
case _, ok := <-mb.InboundChan():
if ok {
t.Fatal("expected ok=false when context is canceled")
}
} }
} }
@@ -136,10 +157,7 @@ func TestSubscribeOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus() mb := NewMessageBus()
mb.Close() mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) _, ok := <-mb.OutboundChan()
defer cancel()
_, ok := mb.SubscribeOutbound(ctx)
if ok { if ok {
t.Fatal("expected ok=false when bus is closed") t.Fatal("expected ok=false when bus is closed")
} }
+33 -27
View File
@@ -585,7 +585,7 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
func dispatchLoop[M any]( func dispatchLoop[M any](
ctx context.Context, ctx context.Context,
m *Manager, m *Manager,
subscribe func(context.Context) (M, bool), ch <-chan M,
getChannel func(M) string, getChannel func(M) string,
enqueue func(context.Context, *channelWorker, M) bool, enqueue func(context.Context, *channelWorker, M) bool,
startMsg, stopMsg, unknownMsg, noWorkerMsg string, startMsg, stopMsg, unknownMsg, noWorkerMsg string,
@@ -593,35 +593,41 @@ func dispatchLoop[M any](
logger.InfoC("channels", startMsg) logger.InfoC("channels", startMsg)
for { for {
msg, ok := subscribe(ctx) select {
if !ok { case <-ctx.Done():
logger.InfoC("channels", stopMsg) logger.InfoC("channels", stopMsg)
return return
}
channel := getChannel(msg) case msg, ok := <-ch:
if !ok {
// Silently skip internal channels logger.InfoC("channels", stopMsg)
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
if !enqueue(ctx, w, msg) {
return return
} }
} else if exists {
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel}) channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
if !enqueue(ctx, w, msg) {
return
}
} else if exists {
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
} }
} }
} }
@@ -629,7 +635,7 @@ func dispatchLoop[M any](
func (m *Manager) dispatchOutbound(ctx context.Context) { func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop( dispatchLoop(
ctx, m, ctx, m,
m.bus.SubscribeOutbound, m.bus.OutboundChan(),
func(msg bus.OutboundMessage) string { return msg.Channel }, func(msg bus.OutboundMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool { func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select { select {
@@ -649,7 +655,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
func (m *Manager) dispatchOutboundMedia(ctx context.Context) { func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
dispatchLoop( dispatchLoop(
ctx, m, ctx, m,
m.bus.SubscribeOutboundMedia, m.bus.OutboundMediaChan(),
func(msg bus.OutboundMediaMessage) string { return msg.Channel }, func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool { func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select { select {
+14 -6
View File
@@ -34,11 +34,19 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx) for {
if !ok { select {
t.Fatal("expected inbound message") case <-ctx.Done():
} t.Fatal("timeout waiting for inbound message")
if inbound.Metadata["account_id"] != "7750283E123456" { return
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456") case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
}
return
}
} }
} }
@@ -3,7 +3,6 @@ package telegram
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/mymmrac/telego" "github.com/mymmrac/telego"
@@ -36,10 +35,7 @@ func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
t.Fatalf("handleMessage error: %v", err) t.Fatalf("handleMessage error: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second) inbound, ok := <-messageBus.InboundChan()
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok { if !ok {
t.Fatal("expected inbound message to be forwarded") t.Fatal("expected inbound message to be forwarded")
} }
@@ -108,22 +108,24 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
t.Fatalf("handleMessage error: %v", err) t.Fatalf("handleMessage error: %v", err)
} }
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond)
defer cancel() defer cancel()
select {
inbound, ok := messageBus.ConsumeInbound(ctx) case <-ctx.Done():
if tc.wantForwarded { if tc.wantForwarded {
if !ok { t.Fatal("timeout waiting for message to be forwarded")
t.Fatal("expected inbound message to be forwarded") return
} }
if inbound.Content != tc.wantContent { case inbound, ok := <-messageBus.InboundChan():
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent) if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
}
return
} }
return
}
if ok {
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
} }
}) })
} }
+3 -13
View File
@@ -6,7 +6,6 @@ import (
"errors" "errors"
"strings" "strings"
"testing" "testing"
"time"
"github.com/mymmrac/telego" "github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi" ta "github.com/mymmrac/telego/telegoapi"
@@ -355,10 +354,7 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg) err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second) inbound, ok := <-messageBus.InboundChan()
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
require.True(t, ok, "expected inbound message") require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID // Composite chatID should include thread ID
@@ -397,10 +393,7 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg) err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second) inbound, ok := <-messageBus.InboundChan()
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
require.True(t, ok) require.True(t, ok)
// Plain chatID without thread suffix // Plain chatID without thread suffix
@@ -443,10 +436,7 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
err := ch.handleMessage(context.Background(), msg) err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second) inbound, ok := <-messageBus.InboundChan()
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
require.True(t, ok) require.True(t, ok)
// chatID should NOT include thread suffix for non-forum groups // chatID should NOT include thread suffix for non-forum groups
@@ -3,7 +3,6 @@ package whatsapp
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/channels"
@@ -25,10 +24,7 @@ func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T
"content": "/help", "content": "/help",
}) })
ctx, cancel := context.WithTimeout(context.Background(), time.Second) inbound, ok := <-messageBus.InboundChan()
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok { if !ok {
t.Fatal("expected inbound message to be forwarded") t.Fatal("expected inbound message to be forwarded")
} }
@@ -43,14 +43,19 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel() defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx) select {
if !ok { case <-ctx.Done():
t.Fatal("expected inbound message to be forwarded") t.Fatal("timeout waiting for message to be forwarded")
} return
if inbound.Channel != "whatsapp_native" { case inbound, ok := <-messageBus.InboundChan():
t.Fatalf("channel=%q", inbound.Channel) if !ok {
} t.Fatal("expected inbound message to be forwarded")
if inbound.Content != "/new" { }
t.Fatalf("content=%q", inbound.Content) if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
} }
} }