Merge branch 'main' into feat/markdown-output-format-web-fetch

This commit is contained in:
afjcjsbx
2026-03-17 17:21:14 +01:00
13 changed files with 566 additions and 293 deletions
+53 -55
View File
@@ -267,67 +267,65 @@ func (al *AgentLoop) Run(ctx context.Context) error {
select {
case <-ctx.Done():
return nil
default:
msg, ok := al.bus.ConsumeInbound(ctx)
case msg, ok := <-al.bus.InboundChan():
if !ok {
continue
return nil
}
// Process message
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
// Currently disabled because files are deleted before the LLM can access their content.
// defer func() {
// if al.mediaStore != nil && msg.MediaScope != "" {
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
// logger.WarnCF("agent", "Failed to release media", map[string]any{
// "scope": msg.MediaScope,
// "error": releaseErr.Error(),
// })
// }
// }
// }()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
}
// Process message
func() {
// TODO: Re-enable media cleanup after inbound media is properly consumed by the agent.
// Currently disabled because files are deleted before the LLM can access their content.
// defer func() {
// if al.mediaStore != nil && msg.MediaScope != "" {
// if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil {
// logger.WarnCF("agent", "Failed to release media", map[string]any{
// "scope": msg.MediaScope,
// "error": releaseErr.Error(),
// })
// }
// }
// }()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
}
if response != "" {
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
// Use default agent's tools to check (message tool is shared).
alreadySent := false
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent != nil {
if tool, ok := defaultAgent.Tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
if response != "" {
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
// Use default agent's tools to check (message tool is shared).
alreadySent := false
defaultAgent := al.GetRegistry().GetDefaultAgent()
if defaultAgent != nil {
if tool, ok := defaultAgent.Tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
}
if !alreadySent {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
logger.InfoCF("agent", "Published outbound response",
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"content_len": len(response),
})
} else {
logger.DebugCF(
"agent",
"Skipped outbound (message tool already sent)",
map[string]any{"channel": msg.Channel},
)
}
}
}()
if !alreadySent {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
logger.InfoCF("agent", "Published outbound response",
map[string]any{
"channel": msg.Channel,
"chat_id": msg.ChatID,
"content_len": len(response),
})
} else {
logger.DebugCF(
"agent",
"Skipped outbound (message tool already sent)",
map[string]any{"channel": msg.Channel},
)
}
}
default:
time.Sleep(time.Microsecond * 200)
}
}
+72 -39
View File
@@ -997,10 +997,25 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "reasoning", "telegram", "")
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if msg, ok := msgBus.SubscribeOutbound(ctx); ok {
t.Fatalf("expected no outbound message, got %+v", msg)
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, got %+v", msg)
}
if msg.Content == "reasoning" {
t.Fatalf("expected no message for empty chatID, got %+v", msg)
}
return
case <-ctx.Done():
t.Log("expected an outbound message, got none within timeout")
return
default:
// Continue to check for message
time.Sleep(5 * time.Millisecond) // Avoid busy loop
}
}
})
@@ -1008,9 +1023,7 @@ func TestHandleReasoning(t *testing.T) {
al, msgBus := newLoop(t)
al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
msg, ok := <-msgBus.OutboundChan()
if !ok {
t.Fatal("expected an outbound message")
}
@@ -1024,35 +1037,52 @@ func TestHandleReasoning(t *testing.T) {
reasoning := "hello telegram reasoning"
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if !ok {
t.Fatal("expected outbound message")
}
for {
select {
case <-ctx.Done():
t.Fatal("expected an outbound message, got none within timeout")
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatal("expected outbound message")
}
if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg)
}
if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg)
}
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
if msg.Channel != "telegram" {
t.Fatalf("expected telegram channel message, got %+v", msg)
}
if msg.ChatID != "tg-chat" {
t.Fatalf("expected chatID tg-chat, got %+v", msg)
}
if msg.Content != reasoning {
t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning)
}
return
}
}
})
t.Run("expired ctx", func(t *testing.T) {
al, msgBus := newLoop(t)
reasoning := "hello telegram reasoning"
ctx, cancel := context.WithCancel(context.Background())
cancel()
al.handleReasoning(ctx, reasoning, "telegram", "tg-chat")
ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
msg, ok := msgBus.SubscribeOutbound(ctx)
if ok {
t.Fatalf("expected no outbound message, got %+v", msg)
al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat")
consumeCtx, consumeCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer consumeCancel()
for {
select {
case msg, ok := <-msgBus.OutboundChan():
if !ok {
t.Fatalf("expected no outbound message, but received: %+v", msg)
}
t.Logf("Received unexpected outbound message: %+v", msg)
return
case <-consumeCtx.Done():
t.Fatalf("failed: no message received within timeout")
return
}
}
})
@@ -1092,20 +1122,23 @@ func TestHandleReasoning(t *testing.T) {
// Drain the bus and verify the reasoning message was NOT published
// (it should have been dropped due to timeout).
drainCtx, drainCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer drainCancel()
foundReasoning := false
timeer := time.After(1 * time.Second)
for {
msg, ok := msgBus.SubscribeOutbound(drainCtx)
if !ok {
break
select {
case <-timeer:
t.Logf(
"no reasoning message received after draining bus for 1s, as expected,length=%d",
len(msgBus.OutboundChan()),
)
return
case msg, ok := <-msgBus.OutboundChan():
if !ok {
break
}
if msg.Content == "should timeout" {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
}
if msg.Content == "should timeout" {
foundReasoning = true
}
}
if foundReasoning {
t.Fatal("expected reasoning message to be dropped when bus is full, but it was published")
}
})
}
+60 -93
View File
@@ -3,6 +3,7 @@ package bus
import (
"context"
"errors"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/logger"
@@ -17,8 +18,11 @@ type MessageBus struct {
inbound chan InboundMessage
outbound chan OutboundMessage
outboundMedia chan OutboundMediaMessage
done chan struct{}
closed atomic.Bool
closeOnce sync.Once
done chan struct{}
closed atomic.Bool
wg sync.WaitGroup
}
func NewMessageBus() *MessageBus {
@@ -30,128 +34,91 @@ func NewMessageBus() *MessageBus {
}
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error {
// check bus closed before acquiring wg, to avoid unnecessary wg.Add and potential deadlock
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
// check again,before sending message, to avoid sending to closed channel
select {
case mb.inbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
default:
}
mb.wg.Add(1)
defer mb.wg.Done()
select {
case ch <- msg:
return nil
case <-ctx.Done():
return ctx.Err()
case <-mb.done:
return ErrBusClosed
}
}
func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) {
select {
case msg, ok := <-mb.inbound:
return msg, ok
case <-mb.done:
return InboundMessage{}, false
case <-ctx.Done():
return InboundMessage{}, false
}
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
return publish(ctx, mb, mb.inbound, msg)
}
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
return mb.inbound
}
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outbound <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
return publish(ctx, mb, mb.outbound, msg)
}
func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) {
select {
case msg, ok := <-mb.outbound:
return msg, ok
case <-mb.done:
return OutboundMessage{}, false
case <-ctx.Done():
return OutboundMessage{}, false
}
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
return mb.outbound
}
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
if mb.closed.Load() {
return ErrBusClosed
}
if err := ctx.Err(); err != nil {
return err
}
select {
case mb.outboundMedia <- msg:
return nil
case <-mb.done:
return ErrBusClosed
case <-ctx.Done():
return ctx.Err()
}
return publish(ctx, mb, mb.outboundMedia, msg)
}
func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) {
select {
case msg, ok := <-mb.outboundMedia:
return msg, ok
case <-mb.done:
return OutboundMediaMessage{}, false
case <-ctx.Done():
return OutboundMediaMessage{}, false
}
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
return mb.outboundMedia
}
func (mb *MessageBus) Close() {
if mb.closed.CompareAndSwap(false, true) {
mb.closeOnce.Do(func() {
// notify all blocked publishers to exit
close(mb.done)
// Drain buffered channels so messages aren't silently lost.
// Channels are NOT closed to avoid send-on-closed panics from concurrent publishers.
// because every publisher will check mb.closed before acquiring wg
// so we can be sure that new publishers will not be added new messages after this point
mb.closed.Store(true)
// wait for all ongoing Publish calls to finish, ensuring all messages have been sent to channels or exited
mb.wg.Wait()
// close channels safely
close(mb.inbound)
close(mb.outbound)
close(mb.outboundMedia)
// clean up any remaining messages in channels
drained := 0
for {
select {
case <-mb.inbound:
drained++
default:
goto doneInbound
}
for range mb.inbound {
drained++
}
doneInbound:
for {
select {
case <-mb.outbound:
drained++
default:
goto doneOutbound
}
for range mb.outbound {
drained++
}
doneOutbound:
for {
select {
case <-mb.outboundMedia:
drained++
default:
goto doneMedia
}
for range mb.outboundMedia {
drained++
}
doneMedia:
if drained > 0 {
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
"count": drained,
})
}
}
})
}
+35 -17
View File
@@ -24,7 +24,7 @@ func TestPublishConsume(t *testing.T) {
t.Fatalf("PublishInbound failed: %v", err)
}
got, ok := mb.ConsumeInbound(ctx)
got, ok := <-mb.InboundChan()
if !ok {
t.Fatal("ConsumeInbound returned ok=false")
}
@@ -52,7 +52,7 @@ func TestPublishOutboundSubscribe(t *testing.T) {
t.Fatalf("PublishOutbound failed: %v", err)
}
got, ok := mb.SubscribeOutbound(ctx)
got, ok := <-mb.OutboundChan()
if !ok {
t.Fatal("SubscribeOutbound returned ok=false")
}
@@ -108,27 +108,48 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
func TestConsumeInbound_ContextCancel(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := range defaultBusBufferSize {
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
t.Fatalf("fill failed at %d: %v", i, err)
}
}
_, ok := mb.ConsumeInbound(ctx)
if ok {
t.Fatal("expected ok=false when context is canceled")
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
select {
case <-ctx.Done():
t.Log("context canceled, as expected")
case msg, ok := <-mb.InboundChan():
if !ok {
t.Fatal("expected ok=false when context is canceled")
}
if msg.Content == "ContextCancel" {
t.Fatalf("expected content 'ContextCancel', got %q", msg.Content)
}
}
}
func TestConsumeInbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
timer := time.AfterFunc(100*time.Millisecond, func() {
mb.Close()
})
_, ok := mb.ConsumeInbound(ctx)
if ok {
t.Fatal("expected ok=false when bus is closed")
select {
case <-timer.C:
t.Log("context canceled, as expected")
case _, ok := <-mb.InboundChan():
if ok {
t.Fatal("expected ok=false when context is canceled")
}
}
}
@@ -136,10 +157,7 @@ func TestSubscribeOutbound_BusClosed(t *testing.T) {
mb := NewMessageBus()
mb.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, ok := mb.SubscribeOutbound(ctx)
_, ok := <-mb.OutboundChan()
if ok {
t.Fatal("expected ok=false when bus is closed")
}
+33 -27
View File
@@ -585,7 +585,7 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork
func dispatchLoop[M any](
ctx context.Context,
m *Manager,
subscribe func(context.Context) (M, bool),
ch <-chan M,
getChannel func(M) string,
enqueue func(context.Context, *channelWorker, M) bool,
startMsg, stopMsg, unknownMsg, noWorkerMsg string,
@@ -593,35 +593,41 @@ func dispatchLoop[M any](
logger.InfoC("channels", startMsg)
for {
msg, ok := subscribe(ctx)
if !ok {
select {
case <-ctx.Done():
logger.InfoC("channels", stopMsg)
return
}
channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
if !enqueue(ctx, w, msg) {
case msg, ok := <-ch:
if !ok {
logger.InfoC("channels", stopMsg)
return
}
} else if exists {
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
channel := getChannel(msg)
// Silently skip internal channels
if constants.IsInternalChannel(channel) {
continue
}
m.mu.RLock()
_, exists := m.channels[channel]
w, wExists := m.workers[channel]
m.mu.RUnlock()
if !exists {
logger.WarnCF("channels", unknownMsg, map[string]any{"channel": channel})
continue
}
if wExists && w != nil {
if !enqueue(ctx, w, msg) {
return
}
} else if exists {
logger.WarnCF("channels", noWorkerMsg, map[string]any{"channel": channel})
}
}
}
}
@@ -629,7 +635,7 @@ func dispatchLoop[M any](
func (m *Manager) dispatchOutbound(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutbound,
m.bus.OutboundChan(),
func(msg bus.OutboundMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
select {
@@ -649,7 +655,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
dispatchLoop(
ctx, m,
m.bus.SubscribeOutboundMedia,
m.bus.OutboundMediaChan(),
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
select {
+14 -6
View File
@@ -34,11 +34,19 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
for {
select {
case <-ctx.Done():
t.Fatal("timeout waiting for inbound message")
return
case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message")
}
if inbound.Metadata["account_id"] != "7750283E123456" {
t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456")
}
return
}
}
}
@@ -3,7 +3,6 @@ package telegram
import (
"context"
"testing"
"time"
"github.com/mymmrac/telego"
@@ -36,10 +35,7 @@ func TestHandleMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -108,22 +108,24 @@ func TestHandleMessage_GroupMentionOnly_BotCommandEntity(t *testing.T) {
t.Fatalf("handleMessage error: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 150*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Microsecond)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
select {
case <-ctx.Done():
if tc.wantForwarded {
t.Fatal("timeout waiting for message to be forwarded")
return
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
case inbound, ok := <-messageBus.InboundChan():
if tc.wantForwarded {
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Content != tc.wantContent {
t.Fatalf("content=%q want=%q", inbound.Content, tc.wantContent)
}
return
}
return
}
if ok {
t.Fatalf("expected message to be filtered, got content=%q", inbound.Content)
}
})
}
+3 -13
View File
@@ -6,7 +6,6 @@ import (
"errors"
"strings"
"testing"
"time"
"github.com/mymmrac/telego"
ta "github.com/mymmrac/telego/telegoapi"
@@ -355,10 +354,7 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok, "expected inbound message")
// Composite chatID should include thread ID
@@ -397,10 +393,7 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// Plain chatID without thread suffix
@@ -443,10 +436,7 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) {
err := ch.handleMessage(context.Background(), msg)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
require.True(t, ok)
// chatID should NOT include thread suffix for non-forum groups
@@ -3,7 +3,6 @@ package whatsapp
import (
"context"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -25,10 +24,7 @@ func TestHandleIncomingMessage_DoesNotConsumeGenericCommandsLocally(t *testing.T
"content": "/help",
})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
inbound, ok := <-messageBus.InboundChan()
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
@@ -43,14 +43,19 @@ func TestHandleIncoming_DoesNotConsumeGenericCommandsLocally(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
inbound, ok := messageBus.ConsumeInbound(ctx)
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
select {
case <-ctx.Done():
t.Fatal("timeout waiting for message to be forwarded")
return
case inbound, ok := <-messageBus.InboundChan():
if !ok {
t.Fatal("expected inbound message to be forwarded")
}
if inbound.Channel != "whatsapp_native" {
t.Fatalf("channel=%q", inbound.Channel)
}
if inbound.Content != "/new" {
t.Fatalf("content=%q", inbound.Content)
}
}
}
+66 -11
View File
@@ -65,6 +65,7 @@ type CronService struct {
mu sync.RWMutex
running bool
stopChan chan struct{}
wakeChan chan struct{}
gronx *gronx.Gronx
}
@@ -73,6 +74,7 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
storePath: storePath,
onJob: onJob,
gronx: gronx.New(),
wakeChan: make(chan struct{}),
}
// Initialize and load store on creation
cs.loadStore()
@@ -97,6 +99,9 @@ func (cs *CronService) Start() error {
}
cs.stopChan = make(chan struct{})
if cs.wakeChan == nil {
cs.wakeChan = make(chan struct{})
}
cs.running = true
go cs.runLoop(cs.stopChan)
@@ -119,14 +124,47 @@ func (cs *CronService) Stop() {
}
func (cs *CronService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
timer := time.NewTimer(time.Hour)
if !timer.Stop() {
<-timer.C
}
defer timer.Stop()
for {
// every loop, recalculate the next wake time
cs.mu.RLock()
nextWake := cs.getNextWakeMS()
cs.mu.RUnlock()
var delay time.Duration
now := time.Now().UnixMilli()
if nextWake == nil {
// no jobs, sleep for a long time (or until a new job is added)
delay = time.Hour
} else {
diff := *nextWake - now
if diff <= 0 {
delay = 0
} else {
delay = time.Duration(diff) * time.Millisecond
}
}
timer.Reset(delay)
select {
case <-stopChan:
return
case <-ticker.C:
case <-cs.wakeChan: // wake on new job or update
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
continue
case <-timer.C:
cs.checkJobs()
}
}
@@ -264,22 +302,19 @@ func (cs *CronService) executeJobByID(jobID string) {
}
func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int64 {
if schedule.Kind == "at" {
switch schedule.Kind {
case "at":
if schedule.AtMS != nil && *schedule.AtMS > nowMS {
return schedule.AtMS
}
return nil
}
if schedule.Kind == "every" {
case "every":
if schedule.EveryMS == nil || *schedule.EveryMS <= 0 {
return nil
}
next := nowMS + *schedule.EveryMS
return &next
}
if schedule.Kind == "cron" {
case "cron":
if schedule.Expr == "" {
return nil
}
@@ -294,9 +329,19 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6
nextMS := nextTime.UnixMilli()
return &nextMS
default:
log.Printf("[cron] unknown schedule kind '%s'", schedule.Kind)
return nil
}
}
return nil
// wake up the loop to re-evaluate next wake time immediately (e.g. after add/update/remove jobs)
func (cs *CronService) notify() {
select {
case cs.wakeChan <- struct{}{}:
default:
// if the channel is full, it means the loop will wake up soon anyway, so we can skip sending
}
}
func (cs *CronService) recomputeNextRuns() {
@@ -400,6 +445,8 @@ func (cs *CronService) AddJob(
return nil, err
}
cs.notify()
return &job, nil
}
@@ -411,6 +458,9 @@ func (cs *CronService) UpdateJob(job *CronJob) error {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i] = *job
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
cs.notify()
return cs.saveStoreUnsafe()
}
}
@@ -441,6 +491,8 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool {
}
}
cs.notify()
return removed
}
@@ -463,6 +515,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob {
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store after enable: %v", err)
}
cs.notify()
return job
}
}
+199
View File
@@ -1,10 +1,13 @@
package cron
import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"
)
func TestSaveStore_FilePermissions(t *testing.T) {
@@ -36,3 +39,199 @@ func TestSaveStore_FilePermissions(t *testing.T) {
func int64Ptr(v int64) *int64 {
return &v
}
func setupService(handler JobHandler) (*CronService, string) {
tmpFile := fmt.Sprintf("test_cron_%d.json", time.Now().UnixNano())
cs := NewCronService(tmpFile, handler)
return cs, tmpFile
}
func TestCronService_CRUD(t *testing.T) {
cs, path := setupService(nil)
defer os.Remove(path)
// Test AddJob
at := time.Now().Add(time.Hour).UnixMilli()
job, err := cs.AddJob("Task1", CronSchedule{Kind: "at", AtMS: &at}, "msg", true, "ch", "to")
if err != nil || job.ID == "" {
t.Fatalf("AddJob failed: %v", err)
}
// Test ListJobs
if len(cs.ListJobs(true)) != 1 {
t.Error("ListJobs should return 1 job")
}
// Test UpdateJob
job.Name = "UpdatedName"
err = cs.UpdateJob(job)
if err != nil || cs.store.Jobs[0].Name != "UpdatedName" {
t.Error("UpdateJob failed")
}
// Test EnableJob
cs.EnableJob(job.ID, false)
if cs.store.Jobs[0].Enabled != false || cs.store.Jobs[0].State.NextRunAtMS != nil {
t.Error("EnableJob(false) failed to clear state")
}
// Test RemoveJob
removed := cs.RemoveJob(job.ID)
if !removed || len(cs.store.Jobs) != 0 {
t.Error("RemoveJob failed")
}
}
// 2. Test Cron Expression Calculation Logic
func TestCronService_ComputeNextRun(t *testing.T) {
cs, path := setupService(nil)
defer os.Remove(path)
now := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC).UnixMilli()
tests := []struct {
name string
schedule CronSchedule
wantNil bool
}{
{"Valid Cron", CronSchedule{Kind: "cron", Expr: "0 * * * *"}, false},
{"Invalid Cron", CronSchedule{Kind: "cron", Expr: "invalid"}, true},
{"Every MS", CronSchedule{Kind: "every", EveryMS: int64Ptr(5000)}, false},
{"At Future", CronSchedule{Kind: "at", AtMS: int64Ptr(now + 1000)}, false},
{"At Past", CronSchedule{Kind: "at", AtMS: int64Ptr(now - 1000)}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := cs.computeNextRun(&tt.schedule, now)
if (got == nil) != tt.wantNil {
t.Errorf("%s: got %v, wantNil %v", tt.name, got, tt.wantNil)
}
})
}
}
// 3. Test Execution Flow
func TestCronService_ExecutionFlow(t *testing.T) {
var mu sync.Mutex
executedJobs := make(map[string]bool)
handler := func(job *CronJob) (string, error) {
mu.Lock()
executedJobs[job.ID] = true
mu.Unlock()
return "ok", nil
}
cs, path := setupService(handler)
defer os.Remove(path)
// Start the service
if err := cs.Start(); err != nil {
t.Fatalf("Start failed: %v", err)
}
defer cs.Stop()
// Add a job then runs 100ms from now
target := time.Now().Add(100 * time.Millisecond).UnixMilli()
job, _ := cs.AddJob("FastJob", CronSchedule{Kind: "at", AtMS: &target}, "", false, "", "")
// Check for job execution with a timeout
success := false
for range 20 {
mu.Lock()
if executedJobs[job.ID] {
success = true
mu.Unlock()
break
}
mu.Unlock()
time.Sleep(100 * time.Millisecond)
}
if !success {
t.Error("Job was not executed in time")
}
// check that the job is removed after execution (DeleteAfterRun = true)
status := cs.Status()
if status["jobs"].(int) != 0 {
t.Errorf("Job should be deleted after run, got count: %v", status["jobs"])
}
}
func TestCronService_PersistenceIntegrity(t *testing.T) {
tmpFile := "persist_test.json"
defer os.Remove(tmpFile)
// write a job and persist
cs1 := NewCronService(tmpFile, nil)
at := int64(2000000000000)
cs1.AddJob("PersistMe", CronSchedule{Kind: "at", AtMS: &at}, "payload", true, "ch1", "")
// check file exists
if _, err := os.Stat(tmpFile); os.IsNotExist(err) {
t.Fatal("Store file was not created")
}
// reload and check data integrity
cs2 := NewCronService(tmpFile, nil)
if err := cs2.Load(); err != nil {
t.Fatalf("Failed to load store: %v", err)
}
jobs := cs2.ListJobs(true)
if len(jobs) != 1 || jobs[0].Name != "PersistMe" {
t.Errorf("Data corruption after reload. Got: %+v", jobs)
}
// test loading invalid JSON
os.WriteFile(tmpFile, []byte("{invalid json}"), 0o644)
cs3 := NewCronService(tmpFile, nil)
err := cs3.loadStore()
if err == nil {
t.Error("Should return error when loading invalid JSON")
}
}
func TestCronService_ConcurrentAccess(t *testing.T) {
cs, path := setupService(nil)
defer os.Remove(path)
cs.Start()
defer cs.Stop()
var wg sync.WaitGroup
workers := 10
iterations := 50
wg.Add(workers * 2)
// add jobs concurrently
for i := range workers {
go func(id int) {
defer wg.Done()
for j := range iterations {
at := time.Now().Add(time.Hour).UnixMilli()
cs.AddJob(fmt.Sprintf("Job-%d-%d", id, j), CronSchedule{Kind: "at", AtMS: &at}, "", false, "", "")
time.Sleep(100 * time.Microsecond)
}
}(i)
}
// read and update jobs concurrently
for range workers {
go func() {
defer wg.Done()
for j := range iterations {
jobs := cs.ListJobs(true)
if len(jobs) > 0 {
cs.EnableJob(jobs[0].ID, j%2 == 0)
}
time.Sleep(100 * time.Microsecond)
}
}()
}
wg.Wait()
}