diff --git a/cmd/picoclaw/internal/agent/helpers.go b/cmd/picoclaw/internal/agent/helpers.go index 746e9755e..f754abc65 100644 --- a/cmd/picoclaw/internal/agent/helpers.go +++ b/cmd/picoclaw/internal/agent/helpers.go @@ -48,6 +48,7 @@ func agentCmd(message, sessionKey, model string, debug bool) error { } msgBus := bus.NewMessageBus() + defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) // Print agent startup info (only for interactive mode) diff --git a/cmd/picoclaw/internal/gateway/helpers.go b/cmd/picoclaw/internal/gateway/helpers.go index ec5ad5485..e3a51b5e9 100644 --- a/cmd/picoclaw/internal/gateway/helpers.go +++ b/cmd/picoclaw/internal/gateway/helpers.go @@ -223,6 +223,7 @@ func gatewayCmd(debug bool) error { cp.Close() } cancel() + msgBus.Close() healthServer.Stop(context.Background()) deviceService.Stop() heartbeatService.Stop() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 124f45675..ebbeec0c1 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -121,7 +121,7 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: content, @@ -200,7 +200,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: msg.Channel, ChatID: msg.ChatID, Content: response, @@ -469,7 +469,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -586,7 +586,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -716,7 +716,7 @@ func (al *AgentLoop) runLLMIteration( // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 58c0a25d5..100ddc456 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -2,81 +2,80 @@ package bus import ( "context" - "sync" + "errors" + "sync/atomic" ) +// ErrBusClosed is returned when publishing to a closed MessageBus. +var ErrBusClosed = errors.New("message bus closed") + type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage - handlers map[string]MessageHandler - closed bool - mu sync.RWMutex + done chan struct{} + closed atomic.Bool } func NewMessageBus() *MessageBus { return &MessageBus{ inbound: make(chan InboundMessage, 100), outbound: make(chan OutboundMessage, 100), - handlers: make(map[string]MessageHandler), + done: make(chan struct{}), } } -func (mb *MessageBus) PublishInbound(msg InboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.inbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.inbound <- msg } func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { select { - case msg := <-mb.inbound: - return msg, true + case msg, ok := <-mb.inbound: + return msg, ok + case <-mb.done: + return InboundMessage{}, false case <-ctx.Done(): return InboundMessage{}, false } } -func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.outbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.outbound <- msg } func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) { select { - case msg := <-mb.outbound: - return msg, true + case msg, ok := <-mb.outbound: + return msg, ok + case <-mb.done: + return OutboundMessage{}, false case <-ctx.Done(): return OutboundMessage{}, false } } -func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) { - mb.mu.Lock() - defer mb.mu.Unlock() - mb.handlers[channel] = handler -} - -func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { - mb.mu.RLock() - defer mb.mu.RUnlock() - handler, ok := mb.handlers[channel] - return handler, ok -} - func (mb *MessageBus) Close() { - mb.mu.Lock() - defer mb.mu.Unlock() - if mb.closed { - return + if mb.closed.CompareAndSwap(false, true) { + close(mb.done) } - mb.closed = true - close(mb.inbound) - close(mb.outbound) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 000000000..47826824e --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,229 @@ +package bus + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestPublishConsume(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + } + + if err := mb.PublishInbound(ctx, msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got, ok := mb.ConsumeInbound(ctx) + if !ok { + t.Fatal("ConsumeInbound returned ok=false") + } + if got.Content != "hello" { + t.Fatalf("expected content 'hello', got %q", got.Content) + } + if got.Channel != "test" { + t.Fatalf("expected channel 'test', got %q", got.Channel) + } +} + +func TestPublishOutboundSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := OutboundMessage{ + Channel: "telegram", + ChatID: "123", + Content: "world", + } + + if err := mb.PublishOutbound(ctx, msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got, ok := mb.SubscribeOutbound(ctx) + if !ok { + t.Fatal("SubscribeOutbound returned ok=false") + } + if got.Content != "world" { + t.Fatalf("expected content 'world', got %q", got.Content) + } +} + +func TestPublishInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + // Fill the buffer + ctx := context.Background() + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Now buffer is full; publish with a cancelled context + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestPublishInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestPublishOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestConsumeInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when context is cancelled") + } +} + +func TestConsumeInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestSubscribeOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.SubscribeOutbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestConcurrentPublishClose(t *testing.T) { + mb := NewMessageBus() + ctx := context.Background() + + const numGoroutines = 100 + var wg sync.WaitGroup + wg.Add(numGoroutines + 1) + + // Spawn many goroutines trying to publish + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // Use a short timeout context so we don't block forever after close + publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + // Errors are expected; we just must not panic or deadlock + _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"}) + }() + } + + // Close from another goroutine + go func() { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + mb.Close() + }() + + // Must complete without deadlock + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // success + case <-time.After(5 * time.Second): + t.Fatal("test timed out - possible deadlock") + } +} + +func TestPublishInbound_FullBuffer(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + // Fill the buffer + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Buffer is full; publish with short timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error when buffer is full and context times out") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestCloseIdempotent(t *testing.T) { + mb := NewMessageBus() + + // Multiple Close calls must not panic + mb.Close() + mb.Close() + mb.Close() + + // After close, publish should return ErrBusClosed + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) + } +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go index e49713eb8..358829c55 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -24,5 +24,3 @@ type OutboundMessage struct { ChatID string `json:"chat_id"` Content string `json:"content"` } - -type MessageHandler func(InboundMessage) error diff --git a/pkg/channels/base.go b/pkg/channels/base.go index d967d9e91..adacb8c78 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -143,7 +143,7 @@ func (c *BaseChannel) HandleMessage( Metadata: metadata, } - c.bus.PublishInbound(msg) + c.bus.PublishInbound(context.TODO(), msg) } func (c *BaseChannel) SetRunning(running bool) { diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1541d3c57..408e1c8aa 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -127,7 +127,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { } msg := ev.FormatMessage() - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: msg, diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index e05a9fdbf..3e58dbc7a 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -7,6 +7,7 @@ package heartbeat import ( + "context" "fmt" "os" "path/filepath" @@ -307,7 +308,7 @@ func (hs *HeartbeatService) sendResponse(response string) { return } - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: response, diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 562fffc84..3c13f5968 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -294,7 +294,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } - t.msgBus.PublishOutbound(bus.OutboundMessage{ + t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: output, @@ -304,7 +304,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // If deliver=true, send message directly without agent processing if job.Payload.Deliver { - t.msgBus.PublishOutbound(bus.OutboundMessage{ + t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: job.Payload.Message, diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index ad371a649..081a02872 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -218,7 +218,7 @@ After completing the task, provide a clear summary of what was done.` // Send announce message back to main agent if sm.bus != nil { announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - sm.bus.PublishInbound(bus.InboundMessage{ + sm.bus.PublishInbound(context.TODO(), bus.InboundMessage{ Channel: "system", SenderID: fmt.Sprintf("subagent:%s", task.ID), // Format: "original_channel:original_chat_id" for routing back