From f776611e291b71785132ffba9fb34556eaff6a96 Mon Sep 17 00:00:00 2001 From: juju <14191774+tong3jie@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:02:51 +0800 Subject: [PATCH 1/2] feat(cron): refactor scheduler to event-driven model and add unit tests (#1313) * feat(cron): enhance CronService with wake channel and improve job scheduling logic * fix(cron): update file permission mode to use octal notation in test and fix some lint errors * fix(cron): improve wake channel handling and enhance concurrency in tests --- pkg/cron/service.go | 77 ++++++++++++--- pkg/cron/service_test.go | 199 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+), 11 deletions(-) diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 04775ac42..77a413133 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -65,6 +65,7 @@ type CronService struct { mu sync.RWMutex running bool stopChan chan struct{} + wakeChan chan struct{} gronx *gronx.Gronx } @@ -73,6 +74,7 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { storePath: storePath, onJob: onJob, gronx: gronx.New(), + wakeChan: make(chan struct{}), } // Initialize and load store on creation cs.loadStore() @@ -97,6 +99,9 @@ func (cs *CronService) Start() error { } cs.stopChan = make(chan struct{}) + if cs.wakeChan == nil { + cs.wakeChan = make(chan struct{}) + } cs.running = true go cs.runLoop(cs.stopChan) @@ -119,14 +124,47 @@ func (cs *CronService) Stop() { } func (cs *CronService) runLoop(stopChan chan struct{}) { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() + timer := time.NewTimer(time.Hour) + if !timer.Stop() { + <-timer.C + } + defer timer.Stop() for { + // every loop, recalculate the next wake time + cs.mu.RLock() + nextWake := cs.getNextWakeMS() + cs.mu.RUnlock() + + var delay time.Duration + now := time.Now().UnixMilli() + + if nextWake == nil { + // no jobs, sleep for a long time (or until a new job is added) + delay = time.Hour + } else { + diff := *nextWake - now + if diff <= 0 { + delay = 0 + } else { + delay = time.Duration(diff) * time.Millisecond + } + } + + timer.Reset(delay) + select { case <-stopChan: return - case <-ticker.C: + case <-cs.wakeChan: // wake on new job or update + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + continue + case <-timer.C: cs.checkJobs() } } @@ -264,22 +302,19 @@ func (cs *CronService) executeJobByID(jobID string) { } func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int64 { - if schedule.Kind == "at" { + switch schedule.Kind { + case "at": if schedule.AtMS != nil && *schedule.AtMS > nowMS { return schedule.AtMS } return nil - } - - if schedule.Kind == "every" { + case "every": if schedule.EveryMS == nil || *schedule.EveryMS <= 0 { return nil } next := nowMS + *schedule.EveryMS return &next - } - - if schedule.Kind == "cron" { + case "cron": if schedule.Expr == "" { return nil } @@ -294,9 +329,19 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6 nextMS := nextTime.UnixMilli() return &nextMS + default: + log.Printf("[cron] unknown schedule kind '%s'", schedule.Kind) + return nil } +} - return nil +// wake up the loop to re-evaluate next wake time immediately (e.g. after add/update/remove jobs) +func (cs *CronService) notify() { + select { + case cs.wakeChan <- struct{}{}: + default: + // if the channel is full, it means the loop will wake up soon anyway, so we can skip sending + } } func (cs *CronService) recomputeNextRuns() { @@ -400,6 +445,8 @@ func (cs *CronService) AddJob( return nil, err } + cs.notify() + return &job, nil } @@ -411,6 +458,9 @@ func (cs *CronService) UpdateJob(job *CronJob) error { if cs.store.Jobs[i].ID == job.ID { cs.store.Jobs[i] = *job cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() + + cs.notify() + return cs.saveStoreUnsafe() } } @@ -441,6 +491,8 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool { } } + cs.notify() + return removed } @@ -463,6 +515,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob { if err := cs.saveStoreUnsafe(); err != nil { log.Printf("[cron] failed to save store after enable: %v", err) } + + cs.notify() + return job } } diff --git a/pkg/cron/service_test.go b/pkg/cron/service_test.go index 1a0dd1829..c55e62174 100644 --- a/pkg/cron/service_test.go +++ b/pkg/cron/service_test.go @@ -1,10 +1,13 @@ package cron import ( + "fmt" "os" "path/filepath" "runtime" + "sync" "testing" + "time" ) func TestSaveStore_FilePermissions(t *testing.T) { @@ -36,3 +39,199 @@ func TestSaveStore_FilePermissions(t *testing.T) { func int64Ptr(v int64) *int64 { return &v } + +func setupService(handler JobHandler) (*CronService, string) { + tmpFile := fmt.Sprintf("test_cron_%d.json", time.Now().UnixNano()) + cs := NewCronService(tmpFile, handler) + return cs, tmpFile +} + +func TestCronService_CRUD(t *testing.T) { + cs, path := setupService(nil) + defer os.Remove(path) + + // Test AddJob + at := time.Now().Add(time.Hour).UnixMilli() + job, err := cs.AddJob("Task1", CronSchedule{Kind: "at", AtMS: &at}, "msg", true, "ch", "to") + if err != nil || job.ID == "" { + t.Fatalf("AddJob failed: %v", err) + } + + // Test ListJobs + if len(cs.ListJobs(true)) != 1 { + t.Error("ListJobs should return 1 job") + } + + // Test UpdateJob + job.Name = "UpdatedName" + err = cs.UpdateJob(job) + if err != nil || cs.store.Jobs[0].Name != "UpdatedName" { + t.Error("UpdateJob failed") + } + + // Test EnableJob + cs.EnableJob(job.ID, false) + if cs.store.Jobs[0].Enabled != false || cs.store.Jobs[0].State.NextRunAtMS != nil { + t.Error("EnableJob(false) failed to clear state") + } + + // Test RemoveJob + removed := cs.RemoveJob(job.ID) + if !removed || len(cs.store.Jobs) != 0 { + t.Error("RemoveJob failed") + } +} + +// 2. Test Cron Expression Calculation Logic +func TestCronService_ComputeNextRun(t *testing.T) { + cs, path := setupService(nil) + defer os.Remove(path) + + now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC).UnixMilli() + + tests := []struct { + name string + schedule CronSchedule + wantNil bool + }{ + {"Valid Cron", CronSchedule{Kind: "cron", Expr: "0 * * * *"}, false}, + {"Invalid Cron", CronSchedule{Kind: "cron", Expr: "invalid"}, true}, + {"Every MS", CronSchedule{Kind: "every", EveryMS: int64Ptr(5000)}, false}, + {"At Future", CronSchedule{Kind: "at", AtMS: int64Ptr(now + 1000)}, false}, + {"At Past", CronSchedule{Kind: "at", AtMS: int64Ptr(now - 1000)}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := cs.computeNextRun(&tt.schedule, now) + if (got == nil) != tt.wantNil { + t.Errorf("%s: got %v, wantNil %v", tt.name, got, tt.wantNil) + } + }) + } +} + +// 3. Test Execution Flow +func TestCronService_ExecutionFlow(t *testing.T) { + var mu sync.Mutex + executedJobs := make(map[string]bool) + + handler := func(job *CronJob) (string, error) { + mu.Lock() + executedJobs[job.ID] = true + mu.Unlock() + return "ok", nil + } + + cs, path := setupService(handler) + defer os.Remove(path) + + // Start the service + if err := cs.Start(); err != nil { + t.Fatalf("Start failed: %v", err) + } + defer cs.Stop() + + // Add a job then runs 100ms from now + target := time.Now().Add(100 * time.Millisecond).UnixMilli() + job, _ := cs.AddJob("FastJob", CronSchedule{Kind: "at", AtMS: &target}, "", false, "", "") + + // Check for job execution with a timeout + success := false + for range 20 { + mu.Lock() + if executedJobs[job.ID] { + success = true + mu.Unlock() + break + } + mu.Unlock() + time.Sleep(100 * time.Millisecond) + } + + if !success { + t.Error("Job was not executed in time") + } + + // check that the job is removed after execution (DeleteAfterRun = true) + status := cs.Status() + if status["jobs"].(int) != 0 { + t.Errorf("Job should be deleted after run, got count: %v", status["jobs"]) + } +} + +func TestCronService_PersistenceIntegrity(t *testing.T) { + tmpFile := "persist_test.json" + defer os.Remove(tmpFile) + + // write a job and persist + cs1 := NewCronService(tmpFile, nil) + at := int64(2000000000000) + cs1.AddJob("PersistMe", CronSchedule{Kind: "at", AtMS: &at}, "payload", true, "ch1", "") + + // check file exists + if _, err := os.Stat(tmpFile); os.IsNotExist(err) { + t.Fatal("Store file was not created") + } + + // reload and check data integrity + cs2 := NewCronService(tmpFile, nil) + if err := cs2.Load(); err != nil { + t.Fatalf("Failed to load store: %v", err) + } + + jobs := cs2.ListJobs(true) + if len(jobs) != 1 || jobs[0].Name != "PersistMe" { + t.Errorf("Data corruption after reload. Got: %+v", jobs) + } + + // test loading invalid JSON + os.WriteFile(tmpFile, []byte("{invalid json}"), 0o644) + cs3 := NewCronService(tmpFile, nil) + err := cs3.loadStore() + if err == nil { + t.Error("Should return error when loading invalid JSON") + } +} + +func TestCronService_ConcurrentAccess(t *testing.T) { + cs, path := setupService(nil) + defer os.Remove(path) + + cs.Start() + defer cs.Stop() + + var wg sync.WaitGroup + workers := 10 + iterations := 50 + + wg.Add(workers * 2) + + // add jobs concurrently + for i := range workers { + go func(id int) { + defer wg.Done() + for j := range iterations { + at := time.Now().Add(time.Hour).UnixMilli() + cs.AddJob(fmt.Sprintf("Job-%d-%d", id, j), CronSchedule{Kind: "at", AtMS: &at}, "", false, "", "") + time.Sleep(100 * time.Microsecond) + } + }(i) + } + + // read and update jobs concurrently + for range workers { + go func() { + defer wg.Done() + for j := range iterations { + jobs := cs.ListJobs(true) + if len(jobs) > 0 { + cs.EnableJob(jobs[0].ID, j%2 == 0) + } + time.Sleep(100 * time.Microsecond) + } + }() + } + + wg.Wait() +} From 9c31b0ca958e94cdf081cd30ee61be1806c51013 Mon Sep 17 00:00:00 2001 From: juju <14191774+tong3jie@users.noreply.github.com> Date: Wed, 18 Mar 2026 00:12:12 +0800 Subject: [PATCH 2/2] fix: Fixed the bug where the bus was closed and consumers had unfinished messages. (#1179) * fix: Fixed the bug where the bus was closed and consumers had unfinished messages. * fix: remove unnecessary blank line in Close method * fix: refactor message bus and channel handling for improved performance and reliability * fix: improve message handling and bus closure logic for better reliability * fix: reduce sleep duration in agent loop for improved responsiveness * fix the test case --- pkg/agent/loop.go | 108 ++++++------- pkg/agent/loop_test.go | 111 ++++++++----- pkg/bus/bus.go | 153 +++++++----------- pkg/bus/bus_test.go | 52 ++++-- pkg/channels/manager.go | 60 +++---- pkg/channels/qq/qq_test.go | 20 ++- .../telegram/telegram_dispatch_test.go | 6 +- .../telegram_group_command_filter_test.go | 28 ++-- pkg/channels/telegram/telegram_test.go | 16 +- .../whatsapp/whatsapp_command_test.go | 6 +- .../whatsapp_native/whatsapp_command_test.go | 23 +-- 11 files changed, 301 insertions(+), 282 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 00c9d913a..5c6cb2fe9 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -267,67 +267,65 @@ func (al *AgentLoop) Run(ctx context.Context) error { select { case <-ctx.Done(): return nil - default: - msg, ok := al.bus.ConsumeInbound(ctx) + case msg, ok := <-al.bus.InboundChan(): if !ok { - continue + return nil + } + // Process message + // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. + // Currently disabled because files are deleted before the LLM can access their content. + // defer func() { + // if al.mediaStore != nil && msg.MediaScope != "" { + // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { + // logger.WarnCF("agent", "Failed to release media", map[string]any{ + // "scope": msg.MediaScope, + // "error": releaseErr.Error(), + // }) + // } + // } + // }() + + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) } - // Process message - func() { - // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. - // Currently disabled because files are deleted before the LLM can access their content. - // defer func() { - // if al.mediaStore != nil && msg.MediaScope != "" { - // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { - // logger.WarnCF("agent", "Failed to release media", map[string]any{ - // "scope": msg.MediaScope, - // "error": releaseErr.Error(), - // }) - // } - // } - // }() - - response, err := al.processMessage(ctx, msg) - if err != nil { - response = fmt.Sprintf("Error processing message: %v", err) - } - - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.GetRegistry().GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() - } + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). + alreadySent := false + defaultAgent := al.GetRegistry().GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() } } - - if !alreadySent { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) - logger.InfoCF("agent", "Published outbound response", - map[string]any{ - "channel": msg.Channel, - "chat_id": msg.ChatID, - "content_len": len(response), - }) - } else { - logger.DebugCF( - "agent", - "Skipped outbound (message tool already sent)", - map[string]any{"channel": msg.Channel}, - ) - } } - }() + + if !alreadySent { + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + logger.InfoCF("agent", "Published outbound response", + map[string]any{ + "channel": msg.Channel, + "chat_id": msg.ChatID, + "content_len": len(response), + }) + } else { + logger.DebugCF( + "agent", + "Skipped outbound (message tool already sent)", + map[string]any{"channel": msg.Channel}, + ) + } + } + default: + time.Sleep(time.Microsecond * 200) } } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 47c378771..25ee6ab4d 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -997,10 +997,25 @@ func TestHandleReasoning(t *testing.T) { al, msgBus := newLoop(t) al.handleReasoning(context.Background(), "reasoning", "telegram", "") - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - if msg, ok := msgBus.SubscribeOutbound(ctx); ok { - t.Fatalf("expected no outbound message, got %+v", msg) + for { + select { + case msg, ok := <-msgBus.OutboundChan(): + if !ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + if msg.Content == "reasoning" { + t.Fatalf("expected no message for empty chatID, got %+v", msg) + } + return + case <-ctx.Done(): + t.Log("expected an outbound message, got none within timeout") + return + default: + // Continue to check for message + time.Sleep(5 * time.Millisecond) // Avoid busy loop + } } }) @@ -1008,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) { al, msgBus := newLoop(t) al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1") - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - msg, ok := msgBus.SubscribeOutbound(ctx) + msg, ok := <-msgBus.OutboundChan() if !ok { t.Fatal("expected an outbound message") } @@ -1024,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) { reasoning := "hello telegram reasoning" al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat") - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - msg, ok := msgBus.SubscribeOutbound(ctx) - if !ok { - t.Fatal("expected outbound message") - } + for { + select { + case <-ctx.Done(): + t.Fatal("expected an outbound message, got none within timeout") + return + case msg, ok := <-msgBus.OutboundChan(): + if !ok { + t.Fatal("expected outbound message") + } - if msg.Channel != "telegram" { - t.Fatalf("expected telegram channel message, got %+v", msg) - } - if msg.ChatID != "tg-chat" { - t.Fatalf("expected chatID tg-chat, got %+v", msg) - } - if msg.Content != reasoning { - t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning) + if msg.Channel != "telegram" { + t.Fatalf("expected telegram channel message, got %+v", msg) + } + if msg.ChatID != "tg-chat" { + t.Fatalf("expected chatID tg-chat, got %+v", msg) + } + if msg.Content != reasoning { + t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning) + } + return + } } }) t.Run("expired ctx", func(t *testing.T) { al, msgBus := newLoop(t) reasoning := "hello telegram reasoning" - ctx, cancel := context.WithCancel(context.Background()) - cancel() - al.handleReasoning(ctx, reasoning, "telegram", "tg-chat") - ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - msg, ok := msgBus.SubscribeOutbound(ctx) - if ok { - t.Fatalf("expected no outbound message, got %+v", msg) + al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat") + + consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer consumeCancel() + + for { + select { + case msg, ok := <-msgBus.OutboundChan(): + if !ok { + t.Fatalf("expected no outbound message, but received: %+v", msg) + } + t.Logf("Received unexpected outbound message: %+v", msg) + return + case <-consumeCtx.Done(): + t.Fatalf("failed: no message received within timeout") + return + } } }) @@ -1092,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) { // Drain the bus and verify the reasoning message was NOT published // (it should have been dropped due to timeout). - drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer drainCancel() - foundReasoning := false + timeer := time.After(1 * time.Second) for { - msg, ok := msgBus.SubscribeOutbound(drainCtx) - if !ok { - break + select { + case <-timeer: + t.Logf( + "no reasoning message received after draining bus for 1s, as expected,length=%d", + len(msgBus.OutboundChan()), + ) + return + case msg, ok := <-msgBus.OutboundChan(): + if !ok { + break + } + if msg.Content == "should timeout" { + t.Fatal("expected reasoning message to be dropped when bus is full, but it was published") + } } - if msg.Content == "should timeout" { - foundReasoning = true - } - } - if foundReasoning { - t.Fatal("expected reasoning message to be dropped when bus is full, but it was published") } }) } diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index f5ff9587d..3d08bda4f 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -3,6 +3,7 @@ package bus import ( "context" "errors" + "sync" "sync/atomic" "github.com/sipeed/picoclaw/pkg/logger" @@ -17,8 +18,11 @@ type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage outboundMedia chan OutboundMediaMessage - done chan struct{} - closed atomic.Bool + + closeOnce sync.Once + done chan struct{} + closed atomic.Bool + wg sync.WaitGroup } func NewMessageBus() *MessageBus { @@ -30,128 +34,91 @@ func NewMessageBus() *MessageBus { } } -func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { +func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error { + // check bus closed before acquiring wg, to avoid unnecessary wg.Add and potential deadlock if mb.closed.Load() { return ErrBusClosed } - if err := ctx.Err(); err != nil { - return err - } + + // check again,before sending message, to avoid sending to closed channel select { - case mb.inbound <- msg: - return nil - case <-mb.done: - return ErrBusClosed case <-ctx.Done(): return ctx.Err() + case <-mb.done: + return ErrBusClosed + default: + } + + mb.wg.Add(1) + defer mb.wg.Done() + + select { + case ch <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-mb.done: + return ErrBusClosed } } -func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { - select { - case msg, ok := <-mb.inbound: - return msg, ok - case <-mb.done: - return InboundMessage{}, false - case <-ctx.Done(): - return InboundMessage{}, false - } +func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + return publish(ctx, mb, mb.inbound, msg) +} + +func (mb *MessageBus) InboundChan() <-chan InboundMessage { + return mb.inbound } func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { - if mb.closed.Load() { - return ErrBusClosed - } - if err := ctx.Err(); err != nil { - return err - } - select { - case mb.outbound <- msg: - return nil - case <-mb.done: - return ErrBusClosed - case <-ctx.Done(): - return ctx.Err() - } + return publish(ctx, mb, mb.outbound, msg) } -func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) { - select { - case msg, ok := <-mb.outbound: - return msg, ok - case <-mb.done: - return OutboundMessage{}, false - case <-ctx.Done(): - return OutboundMessage{}, false - } +func (mb *MessageBus) OutboundChan() <-chan OutboundMessage { + return mb.outbound } func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { - if mb.closed.Load() { - return ErrBusClosed - } - if err := ctx.Err(); err != nil { - return err - } - select { - case mb.outboundMedia <- msg: - return nil - case <-mb.done: - return ErrBusClosed - case <-ctx.Done(): - return ctx.Err() - } + return publish(ctx, mb, mb.outboundMedia, msg) } -func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) { - select { - case msg, ok := <-mb.outboundMedia: - return msg, ok - case <-mb.done: - return OutboundMediaMessage{}, false - case <-ctx.Done(): - return OutboundMediaMessage{}, false - } +func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage { + return mb.outboundMedia } func (mb *MessageBus) Close() { - if mb.closed.CompareAndSwap(false, true) { + mb.closeOnce.Do(func() { + // notify all blocked publishers to exit close(mb.done) - // Drain buffered channels so messages aren't silently lost. - // Channels are NOT closed to avoid send-on-closed panics from concurrent publishers. + // because every publisher will check mb.closed before acquiring wg + // so we can be sure that new publishers will not be added new messages after this point + mb.closed.Store(true) + + // wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited + mb.wg.Wait() + + // close channels safely + close(mb.inbound) + close(mb.outbound) + close(mb.outboundMedia) + + // clean up any remaining messages in channels drained := 0 - for { - select { - case <-mb.inbound: - drained++ - default: - goto doneInbound - } + for range mb.inbound { + drained++ } - doneInbound: - for { - select { - case <-mb.outbound: - drained++ - default: - goto doneOutbound - } + for range mb.outbound { + drained++ } - doneOutbound: - for { - select { - case <-mb.outboundMedia: - drained++ - default: - goto doneMedia - } + for range mb.outboundMedia { + drained++ } - doneMedia: + if drained > 0 { logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ "count": drained, }) } - } + }) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go index e07b8c7fe..9b6324ca6 100644 --- a/pkg/bus/bus_test.go +++ b/pkg/bus/bus_test.go @@ -24,7 +24,7 @@ func TestPublishConsume(t *testing.T) { t.Fatalf("PublishInbound failed: %v", err) } - got, ok := mb.ConsumeInbound(ctx) + got, ok := <-mb.InboundChan() if !ok { t.Fatal("ConsumeInbound returned ok=false") } @@ -52,7 +52,7 @@ func TestPublishOutboundSubscribe(t *testing.T) { t.Fatalf("PublishOutbound failed: %v", err) } - got, ok := mb.SubscribeOutbound(ctx) + got, ok := <-mb.OutboundChan() if !ok { t.Fatal("SubscribeOutbound returned ok=false") } @@ -108,27 +108,48 @@ func TestPublishOutbound_BusClosed(t *testing.T) { func TestConsumeInbound_ContextCancel(t *testing.T) { mb := NewMessageBus() + defer mb.Close() - ctx, cancel := context.WithCancel(context.Background()) - cancel() + for i := range defaultBusBufferSize { + if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } - _, ok := mb.ConsumeInbound(ctx) - if ok { - t.Fatal("expected ok=false when context is canceled") + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"}) + + select { + case <-ctx.Done(): + t.Log("context canceled, as expected") + + case msg, ok := <-mb.InboundChan(): + if !ok { + t.Fatal("expected ok=false when context is canceled") + } + if msg.Content == "ContextCancel" { + t.Fatalf("expected content 'ContextCancel', got %q", msg.Content) + } } } func TestConsumeInbound_BusClosed(t *testing.T) { mb := NewMessageBus() - mb.Close() - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + timer := time.AfterFunc(100*time.Millisecond, func() { + mb.Close() + }) - _, ok := mb.ConsumeInbound(ctx) - if ok { - t.Fatal("expected ok=false when bus is closed") + select { + case <-timer.C: + t.Log("context canceled, as expected") + + case _, ok := <-mb.InboundChan(): + if ok { + t.Fatal("expected ok=false when context is canceled") + } } } @@ -136,10 +157,7 @@ func TestSubscribeOutbound_BusClosed(t *testing.T) { mb := NewMessageBus() mb.Close() - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - _, ok := mb.SubscribeOutbound(ctx) + _, ok := <-mb.OutboundChan() if ok { t.Fatal("expected ok=false when bus is closed") } diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 7d49a0e30..aed815399 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -585,7 +585,7 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork func dispatchLoop[M any]( ctx context.Context, m *Manager, - subscribe func(context.Context) (M, bool), + ch <-chan M, getChannel func(M) string, enqueue func(context.Context, *channelWorker, M) bool, startMsg, stopMsg, unknownMsg, noWorkerMsg string, @@ -593,35 +593,41 @@ func dispatchLoop[M any]( logger.InfoC("channels", startMsg) for { - msg, ok := subscribe(ctx) - if !ok { + select { + case <-ctx.Done(): logger.InfoC("channels", stopMsg) return - } - channel := getChannel(msg) - - // Silently skip internal channels - if constants.IsInternalChannel(channel) { - continue - } - - m.mu.RLock() - _, exists := m.channels[channel] - w, wExists := m.workers[channel] - m.mu.RUnlock() - - if !exists { - logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel}) - continue - } - - if wExists && w != nil { - if !enqueue(ctx, w, msg) { + case msg, ok := <-ch: + if !ok { + logger.InfoC("channels", stopMsg) return } - } else if exists { - logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel}) + + channel := getChannel(msg) + + // Silently skip internal channels + if constants.IsInternalChannel(channel) { + continue + } + + m.mu.RLock() + _, exists := m.channels[channel] + w, wExists := m.workers[channel] + m.mu.RUnlock() + + if !exists { + logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel}) + continue + } + + if wExists && w != nil { + if !enqueue(ctx, w, msg) { + return + } + } else if exists { + logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel}) + } } } } @@ -629,7 +635,7 @@ func dispatchLoop[M any]( func (m *Manager) dispatchOutbound(ctx context.Context) { dispatchLoop( ctx, m, - m.bus.SubscribeOutbound, + m.bus.OutboundChan(), func(msg bus.OutboundMessage) string { return msg.Channel }, func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool { select { @@ -649,7 +655,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { func (m *Manager) dispatchOutboundMedia(ctx context.Context) { dispatchLoop( ctx, m, - m.bus.SubscribeOutboundMedia, + m.bus.OutboundMediaChan(), func(msg bus.OutboundMediaMessage) string { return msg.Channel }, func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool { select { diff --git a/pkg/channels/qq/qq_test.go b/pkg/channels/qq/qq_test.go index 3ceee0d09..b04cf5abd 100644 --- a/pkg/channels/qq/qq_test.go +++ b/pkg/channels/qq/qq_test.go @@ -34,11 +34,19 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - inbound, ok := messageBus.ConsumeInbound(ctx) - if !ok { - t.Fatal("expected inbound message") - } - if inbound.Metadata["account_id"] != "7750283E123456" { - t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456") + for { + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for inbound message") + return + case inbound, ok := <-messageBus.InboundChan(): + if !ok { + t.Fatal("expected inbound message") + } + if inbound.Metadata["account_id"] != "7750283E123456" { + t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456") + } + return + } } } diff --git a/pkg/channels/telegram/telegram_dispatch_test.go b/pkg/channels/telegram/telegram_dispatch_test.go index 1ea4a4824..0eb1de5ea 100644 --- a/pkg/channels/telegram/telegram_dispatch_test.go +++ b/pkg/channels/telegram/telegram_dispatch_test.go @@ -3,7 +3,6 @@ package telegram import ( "context" "testing" - "time" "github.com/mymmrac/telego" @@ -36,10 +35,7 @@ func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) { t.Fatalf("handleMessage error: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - inbound, ok := messageBus.ConsumeInbound(ctx) + inbound, ok := <-messageBus.InboundChan() if !ok { t.Fatal("expected inbound message to be forwarded") } diff --git a/pkg/channels/telegram/telegram_group_command_filter_test.go b/pkg/channels/telegram/telegram_group_command_filter_test.go index 0d5b985fe..614b2ca7f 100644 --- a/pkg/channels/telegram/telegram_group_command_filter_test.go +++ b/pkg/channels/telegram/telegram_group_command_filter_test.go @@ -108,22 +108,24 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) { t.Fatalf("handleMessage error: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond) defer cancel() - - inbound, ok := messageBus.ConsumeInbound(ctx) - if tc.wantForwarded { - if !ok { - t.Fatal("expected inbound message to be forwarded") + select { + case <-ctx.Done(): + if tc.wantForwarded { + t.Fatal("timeout waiting for message to be forwarded") + return } - if inbound.Content != tc.wantContent { - t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent) + case inbound, ok := <-messageBus.InboundChan(): + if tc.wantForwarded { + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Content != tc.wantContent { + t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent) + } + return } - return - } - - if ok { - t.Fatalf("expected message to be filtered, got content=%q", inbound.Content) } }) } diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index c2186d0a3..52a2b046c 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -6,7 +6,6 @@ import ( "errors" "strings" "testing" - "time" "github.com/mymmrac/telego" ta "github.com/mymmrac/telego/telegoapi" @@ -355,10 +354,7 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) { err := ch.handleMessage(context.Background(), msg) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - inbound, ok := messageBus.ConsumeInbound(ctx) + inbound, ok := <-messageBus.InboundChan() require.True(t, ok, "expected inbound message") // Composite chatID should include thread ID @@ -397,10 +393,7 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) { err := ch.handleMessage(context.Background(), msg) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - inbound, ok := messageBus.ConsumeInbound(ctx) + inbound, ok := <-messageBus.InboundChan() require.True(t, ok) // Plain chatID without thread suffix @@ -443,10 +436,7 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) { err := ch.handleMessage(context.Background(), msg) require.NoError(t, err) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - inbound, ok := messageBus.ConsumeInbound(ctx) + inbound, ok := <-messageBus.InboundChan() require.True(t, ok) // chatID should NOT include thread suffix for non-forum groups diff --git a/pkg/channels/whatsapp/whatsapp_command_test.go b/pkg/channels/whatsapp/whatsapp_command_test.go index ee8aa4a52..2d85d74f8 100644 --- a/pkg/channels/whatsapp/whatsapp_command_test.go +++ b/pkg/channels/whatsapp/whatsapp_command_test.go @@ -3,7 +3,6 @@ package whatsapp import ( "context" "testing" - "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -25,10 +24,7 @@ func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T "content": "/help", }) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - inbound, ok := messageBus.ConsumeInbound(ctx) + inbound, ok := <-messageBus.InboundChan() if !ok { t.Fatal("expected inbound message to be forwarded") } diff --git a/pkg/channels/whatsapp_native/whatsapp_command_test.go b/pkg/channels/whatsapp_native/whatsapp_command_test.go index cc2dcb619..e51bec392 100644 --- a/pkg/channels/whatsapp_native/whatsapp_command_test.go +++ b/pkg/channels/whatsapp_native/whatsapp_command_test.go @@ -43,14 +43,19 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - inbound, ok := messageBus.ConsumeInbound(ctx) - if !ok { - t.Fatal("expected inbound message to be forwarded") - } - if inbound.Channel != "whatsapp_native" { - t.Fatalf("channel=%q", inbound.Channel) - } - if inbound.Content != "/new" { - t.Fatalf("content=%q", inbound.Content) + select { + case <-ctx.Done(): + t.Fatal("timeout waiting for message to be forwarded") + return + case inbound, ok := <-messageBus.InboundChan(): + if !ok { + t.Fatal("expected inbound message to be forwarded") + } + if inbound.Channel != "whatsapp_native" { + t.Fatalf("channel=%q", inbound.Channel) + } + if inbound.Content != "/new" { + t.Fatalf("content=%q", inbound.Content) + } } }