From 78fd080189cdaab14861a41273fc5e0cc6d1cf43 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 27 Apr 2026 13:09:03 +0800 Subject: [PATCH] fix(events): keep runtime observers non-blocking Add a non-blocking runtime publish path and switch hot-path publishers to it. Enforce subscription timeout boundaries, keep ordered subscriber snapshots up to date on subscribe changes, expose all runtime kinds to process hooks, add safe log attrs for non-agent events, and close the gateway message bus on full shutdown. --- pkg/agent/events_runtime.go | 14 +--- pkg/agent/hook_mount.go | 22 +----- pkg/agent/hook_mount_test.go | 4 +- pkg/agent/runtime_event_logger_test.go | 26 +++++++ pkg/bus/bus.go | 1 + pkg/bus/bus_test.go | 10 +++ pkg/bus/events.go | 22 +++--- pkg/channels/events.go | 45 +++++++++--- pkg/channels/manager_test.go | 6 ++ pkg/events/bus.go | 42 ++++++----- pkg/events/events_test.go | 99 ++++++++++++++++++++++++++ pkg/events/kind.go | 53 ++++++++++++++ pkg/events/subscription.go | 73 ++++++++++++++++--- pkg/events/subscription_test.go | 39 ++++++++++ pkg/gateway/events.go | 19 +++-- pkg/gateway/gateway.go | 7 +- pkg/gateway/gateway_test.go | 37 ++++++++++ pkg/mcp/events.go | 33 ++++++--- pkg/mcp/manager_test.go | 8 +++ pkg/tools/integration/mcp_tool.go | 20 +++++- pkg/tools/integration/mcp_tool_test.go | 5 ++ 21 files changed, 486 insertions(+), 99 deletions(-) diff --git a/pkg/agent/events_runtime.go b/pkg/agent/events_runtime.go index b530f6161..2284665e6 100644 --- a/pkg/agent/events_runtime.go +++ b/pkg/agent/events_runtime.go @@ -1,23 +1,13 @@ package agent -import ( - "context" - "time" - - runtimeevents "github.com/sipeed/picoclaw/pkg/events" -) - -const runtimeEventPublishTimeout = 100 * time.Millisecond +import runtimeevents "github.com/sipeed/picoclaw/pkg/events" func (al *AgentLoop) publishRuntimeEvent(evt runtimeevents.Event) { if al == nil || al.runtimeEvents == nil { return } - ctx, cancel := context.WithTimeout(context.Background(), runtimeEventPublishTimeout) - defer cancel() - - al.runtimeEvents.Publish(ctx, evt) + al.runtimeEvents.PublishNonBlocking(evt) } func runtimeScopeFromHookMeta(meta HookMeta, eventCtx *TurnContext) runtimeevents.Scope { diff --git a/pkg/agent/hook_mount.go b/pkg/agent/hook_mount.go index 409d56b32..c518feee8 100644 --- a/pkg/agent/hook_mount.go +++ b/pkg/agent/hook_mount.go @@ -311,27 +311,7 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) } func validHookEventKinds() map[string]string { - runtimeKinds := []runtimeevents.Kind{ - runtimeevents.KindAgentTurnStart, - runtimeevents.KindAgentTurnEnd, - runtimeevents.KindAgentLLMRequest, - runtimeevents.KindAgentLLMDelta, - runtimeevents.KindAgentLLMResponse, - runtimeevents.KindAgentLLMRetry, - runtimeevents.KindAgentContextCompress, - runtimeevents.KindAgentSessionSummarize, - runtimeevents.KindAgentToolExecStart, - runtimeevents.KindAgentToolExecEnd, - runtimeevents.KindAgentToolExecSkipped, - runtimeevents.KindAgentSteeringInjected, - runtimeevents.KindAgentFollowUpQueued, - runtimeevents.KindAgentInterruptReceived, - runtimeevents.KindAgentSubTurnSpawn, - runtimeevents.KindAgentSubTurnEnd, - runtimeevents.KindAgentSubTurnResultDelivered, - runtimeevents.KindAgentSubTurnOrphan, - runtimeevents.KindAgentError, - } + runtimeKinds := runtimeevents.KnownKinds() kinds := make(map[string]string, len(runtimeKinds)*2) for _, kind := range runtimeKinds { kinds[kind.String()] = kind.String() diff --git a/pkg/agent/hook_mount_test.go b/pkg/agent/hook_mount_test.go index bf98e6e03..5cd64af7b 100644 --- a/pkg/agent/hook_mount_test.go +++ b/pkg/agent/hook_mount_test.go @@ -163,6 +163,8 @@ func TestProcessHookObserveKindsFromConfigAcceptsRuntimeNames(t *testing.T) { kinds, enabled, err := processHookObserveKindsFromConfig([]string{ "tool_exec_start", "agent.tool.exec_end", + "gateway.ready", + "mcp.server.failed", }) if err != nil { t.Fatalf("processHookObserveKindsFromConfig failed: %v", err) @@ -171,7 +173,7 @@ func TestProcessHookObserveKindsFromConfigAcceptsRuntimeNames(t *testing.T) { t.Fatal("expected observe to be enabled") } - want := []string{"agent.tool.exec_start", "agent.tool.exec_end"} + want := []string{"agent.tool.exec_start", "agent.tool.exec_end", "gateway.ready", "mcp.server.failed"} if !slices.Equal(kinds, want) { t.Fatalf("observe kinds = %v, want %v", kinds, want) } diff --git a/pkg/agent/runtime_event_logger_test.go b/pkg/agent/runtime_event_logger_test.go index a64529314..1c95b365c 100644 --- a/pkg/agent/runtime_event_logger_test.go +++ b/pkg/agent/runtime_event_logger_test.go @@ -102,6 +102,32 @@ func TestRuntimeEventLogFieldsSummarizeAgentPayload(t *testing.T) { } } +func TestRuntimeEventLogFieldsIncludeSafeAttrs(t *testing.T) { + fields := runtimeEventLogFields(runtimeevents.Event{ + ID: "evt-gateway", + Kind: runtimeevents.KindGatewayReady, + Severity: runtimeevents.SeverityInfo, + Attrs: map[string]any{ + "duration_ms": 42, + "error": "startup failed", + "event_kind": "conflict", + }, + }) + + if fields["duration_ms"] != 42 || fields["error"] != "startup failed" { + t.Fatalf("missing safe attrs: %#v", fields) + } + if fields["event_kind"] != runtimeevents.KindGatewayReady.String() { + t.Fatalf("event_kind overwritten by attrs: %#v", fields) + } + if fields["attr_event_kind"] != "conflict" { + t.Fatalf("conflicting attr not preserved with prefix: %#v", fields) + } + if _, ok := fields["payload"]; ok { + t.Fatalf("raw payload should not be included by runtimeEventLogFields: %#v", fields) + } +} + func runtimeEventLoggerStateForTest( al *AgentLoop, ) (*runtimeEventLogger, runtimeevents.Subscription) { diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index fe01f31d5..dee67d87c 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -55,6 +55,7 @@ type MessageBus struct { // EventPublisher is the minimal runtime event publisher used by MessageBus. type EventPublisher interface { Publish(ctx context.Context, evt runtimeevents.Event) runtimeevents.PublishResult + PublishNonBlocking(evt runtimeevents.Event) runtimeevents.PublishResult } func NewMessageBus() *MessageBus { diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go index e3abdc0e6..a0a9e1e14 100644 --- a/pkg/bus/bus_test.go +++ b/pkg/bus/bus_test.go @@ -203,6 +203,9 @@ func TestMessageBusPublishesRuntimeFailureAndCloseEvents(t *testing.T) { failed.Severity != runtimeevents.SeverityError { t.Fatalf("publish failed event = %+v", failed) } + if failed.Attrs["stream"] != "inbound" || failed.Attrs["error"] == "" { + t.Fatalf("publish failed attrs = %#v, want stream and error", failed.Attrs) + } if err := mb.PublishOutbound(context.Background(), OutboundMessage{ Context: NewOutboundContext("telegram", "chat-1", ""), @@ -213,9 +216,13 @@ func TestMessageBusPublishesRuntimeFailureAndCloseEvents(t *testing.T) { mb.Close() seen := map[runtimeevents.Kind]bool{} + var drainedAttrs map[string]any for range 3 { evt := receiveBusRuntimeEvent(t, eventsCh) seen[evt.Kind] = true + if evt.Kind == runtimeevents.KindBusCloseDrained { + drainedAttrs = evt.Attrs + } } for _, kind := range []runtimeevents.Kind{ runtimeevents.KindBusCloseStarted, @@ -226,6 +233,9 @@ func TestMessageBusPublishesRuntimeFailureAndCloseEvents(t *testing.T) { t.Fatalf("missing %s event, seen=%v", kind, seen) } } + if drainedAttrs["drained"] != 1 { + t.Fatalf("bus close drained attrs = %#v, want drained count", drainedAttrs) + } } func receiveBusRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event { diff --git a/pkg/bus/events.go b/pkg/bus/events.go index c542892a1..4640ed1fc 100644 --- a/pkg/bus/events.go +++ b/pkg/bus/events.go @@ -1,14 +1,9 @@ package bus import ( - "context" - "time" - runtimeevents "github.com/sipeed/picoclaw/pkg/events" ) -const busEventPublishTimeout = 100 * time.Millisecond - type busPublishFailedPayload struct { Stream string `json:"stream"` Error string `json:"error"` @@ -27,9 +22,7 @@ func (mb *MessageBus) publishFailure(stream string, scope runtimeevents.Scope, e return } - ctx, cancel := context.WithTimeout(context.Background(), busEventPublishTimeout) - defer cancel() - publisher.Publish(ctx, runtimeevents.Event{ + publisher.PublishNonBlocking(runtimeevents.Event{ Kind: runtimeevents.KindBusPublishFailed, Source: runtimeevents.Source{Component: "bus", Name: stream}, Scope: scope, @@ -38,6 +31,10 @@ func (mb *MessageBus) publishFailure(stream string, scope runtimeevents.Scope, e Stream: stream, Error: err.Error(), }, + Attrs: map[string]any{ + "stream": stream, + "error": err.Error(), + }, }) } @@ -50,13 +47,16 @@ func (mb *MessageBus) publishCloseEvent(kind runtimeevents.Kind, drained int) { return } - ctx, cancel := context.WithTimeout(context.Background(), busEventPublishTimeout) - defer cancel() - publisher.Publish(ctx, runtimeevents.Event{ + attrs := map[string]any{} + if drained > 0 { + attrs["drained"] = drained + } + publisher.PublishNonBlocking(runtimeevents.Event{ Kind: kind, Source: runtimeevents.Source{Component: "bus"}, Severity: runtimeevents.SeverityInfo, Payload: busClosePayload{Drained: drained}, + Attrs: attrs, }) } diff --git a/pkg/channels/events.go b/pkg/channels/events.go index f731b84c5..60e5640f0 100644 --- a/pkg/channels/events.go +++ b/pkg/channels/events.go @@ -1,15 +1,10 @@ package channels import ( - "context" - "time" - "github.com/sipeed/picoclaw/pkg/bus" runtimeevents "github.com/sipeed/picoclaw/pkg/events" ) -const channelEventPublishTimeout = 100 * time.Millisecond - func channelTypeForEvent(m *Manager, channelName string) string { if m == nil || m.config == nil { return channelName @@ -33,17 +28,51 @@ func (m *Manager) publishChannelEvent( if scope.Channel == "" { scope.Channel = channelName } - ctx, cancel := context.WithTimeout(context.Background(), channelEventPublishTimeout) - defer cancel() - m.runtimeEvents.Publish(ctx, runtimeevents.Event{ + m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{ Kind: kind, Source: runtimeevents.Source{Component: "channel", Name: channelName}, Scope: scope, Severity: severity, Payload: payload, + Attrs: channelEventAttrs(payload), }) } +func channelEventAttrs(payload any) map[string]any { + switch payload := payload.(type) { + case ChannelLifecyclePayload: + attrs := map[string]any{} + setAttrString(attrs, "type", payload.Type) + setAttrString(attrs, "error", payload.Error) + return attrs + case ChannelOutboundPayload: + attrs := map[string]any{} + if payload.Media { + attrs["media"] = payload.Media + } + if payload.ContentLen > 0 { + attrs["content_len"] = payload.ContentLen + } + if len(payload.MessageIDs) > 0 { + attrs["message_ids_count"] = len(payload.MessageIDs) + } + setAttrString(attrs, "reply_to_message_id", payload.ReplyToMessageID) + setAttrString(attrs, "error", payload.Error) + if payload.Retries > 0 { + attrs["retries"] = payload.Retries + } + return attrs + default: + return nil + } +} + +func setAttrString(attrs map[string]any, key, value string) { + if value != "" { + attrs[key] = value + } +} + func (m *Manager) publishOutboundSent( channelName string, msg bus.OutboundMessage, diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 8680a08a8..5369f7b71 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -380,6 +380,9 @@ func TestSendWithRetryPublishesOutboundRuntimeEvents(t *testing.T) { if sent.Kind != runtimeevents.KindChannelMessageOutboundSent || sent.Scope.ChatID != "chat-1" { t.Fatalf("sent event = %+v", sent) } + if sent.Attrs["content_len"] != 5 { + t.Fatalf("sent attrs = %#v, want content_len", sent.Attrs) + } failWorker := &channelWorker{ ch: &mockChannel{ @@ -402,6 +405,9 @@ func TestSendWithRetryPublishesOutboundRuntimeEvents(t *testing.T) { if failed.Severity != runtimeevents.SeverityError { t.Fatalf("failed severity = %q", failed.Severity) } + if failed.Attrs["error"] == "" || failed.Attrs["retries"] != maxRetries { + t.Fatalf("failed attrs = %#v, want error and retries", failed.Attrs) + } } func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { diff --git a/pkg/events/bus.go b/pkg/events/bus.go index 80ed9eb80..f193ccb74 100644 --- a/pkg/events/bus.go +++ b/pkg/events/bus.go @@ -14,6 +14,7 @@ var globalEventSeq atomic.Uint64 // Bus publishes runtime events and creates filtered channels. type Bus interface { Publish(ctx context.Context, evt Event) PublishResult + PublishNonBlocking(evt Event) PublishResult Channel() EventChannel Close() error Stats() Stats @@ -30,9 +31,10 @@ type PublishResult struct { // EventBus is an in-process runtime event broadcaster. type EventBus struct { - mu sync.RWMutex - subs map[uint64]*eventSubscription - closed bool + mu sync.RWMutex + subs map[uint64]*eventSubscription + orderedSubs []*eventSubscription + closed bool nextSubID atomic.Uint64 published atomic.Uint64 @@ -53,6 +55,15 @@ func NewBus() *EventBus { // Publish broadcasts evt to subscriptions whose filters match it. func (b *EventBus) Publish(ctx context.Context, evt Event) PublishResult { + return b.publish(ctx, evt, false) +} + +// PublishNonBlocking broadcasts evt without waiting for subscriber queue capacity. +func (b *EventBus) PublishNonBlocking(evt Event) PublishResult { + return b.publish(context.Background(), evt, true) +} + +func (b *EventBus) publish(ctx context.Context, evt Event, nonBlocking bool) PublishResult { if b == nil { return PublishResult{Closed: true} } @@ -82,7 +93,7 @@ func (b *EventBus) Publish(ctx context.Context, evt Event) PublishResult { result.Matched++ b.matched.Add(1) - delivery := sub.enqueue(ctx, evt) + delivery := sub.enqueue(ctx, evt, nonBlocking) if delivery.closed { continue } @@ -114,11 +125,9 @@ func (b *EventBus) Close() error { return nil } b.closed = true - subs := make([]*eventSubscription, 0, len(b.subs)) - for id, sub := range b.subs { - subs = append(subs, sub) - delete(b.subs, id) - } + subs := b.orderedSubs + b.subs = nil + b.orderedSubs = nil b.mu.Unlock() for _, sub := range subs { @@ -135,14 +144,9 @@ func (b *EventBus) Stats() Stats { b.mu.RLock() closed := b.closed - subs := make([]*eventSubscription, 0, len(b.subs)) - for _, sub := range b.subs { - subs = append(subs, sub) - } + subs := b.orderedSubs b.mu.RUnlock() - sortSubscriptions(subs) - stats := Stats{ Published: b.published.Load(), Matched: b.matched.Load(), @@ -180,6 +184,7 @@ func (b *EventBus) subscribe( return nil, ErrBusClosed } b.subs[id] = sub + b.rebuildOrderedSubscribersLocked() b.mu.Unlock() if handler != nil { @@ -194,6 +199,7 @@ func (b *EventBus) unsubscribe(id uint64) { sub, ok := b.subs[id] if ok { delete(b.subs, id) + b.rebuildOrderedSubscribersLocked() } b.mu.Unlock() @@ -210,12 +216,16 @@ func (b *EventBus) snapshotSubscribers() ([]*eventSubscription, bool) { return nil, true } + return b.orderedSubs, false +} + +func (b *EventBus) rebuildOrderedSubscribersLocked() { subs := make([]*eventSubscription, 0, len(b.subs)) for _, sub := range b.subs { subs = append(subs, sub) } sortSubscriptions(subs) - return subs, false + b.orderedSubs = subs } func sortSubscriptions(subs []*eventSubscription) { diff --git a/pkg/events/events_test.go b/pkg/events/events_test.go index 19b9df96d..6991e8291 100644 --- a/pkg/events/events_test.go +++ b/pkg/events/events_test.go @@ -131,6 +131,105 @@ func TestBlockRespectsContext(t *testing.T) { } } +func TestPublishNonBlockingDropsForFullBlockSubscriber(t *testing.T) { + t.Parallel() + + bus := NewBus() + defer closeBus(t, bus) + + sub, _, err := bus.Channel().SubscribeChan( + context.Background(), + SubscribeOptions{Name: "block", Buffer: 1, Backpressure: Block}, + ) + if err != nil { + t.Fatalf("SubscribeChan failed: %v", err) + } + + first := bus.PublishNonBlocking(Event{Kind: Kind("test.first")}) + if first.Delivered != 1 { + t.Fatalf("first PublishNonBlocking = %+v, want one delivered event", first) + } + + resultCh := make(chan PublishResult, 1) + go func() { + resultCh <- bus.PublishNonBlocking(Event{Kind: Kind("test.second")}) + }() + + select { + case second := <-resultCh: + if second.Matched != 1 || second.Delivered != 0 || second.Dropped != 1 || second.Blocked != 0 { + t.Fatalf("second PublishNonBlocking = %+v, want non-blocking drop", second) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("PublishNonBlocking blocked on full Block subscriber") + } + + if got := sub.Stats().Dropped; got != 1 { + t.Fatalf("subscription dropped = %d, want 1", got) + } +} + +func TestStatsSubscribersKeepPriorityOrder(t *testing.T) { + t.Parallel() + + bus := NewBus() + defer closeBus(t, bus) + + low, _, err := bus.Channel().SubscribeChan( + context.Background(), + SubscribeOptions{Name: "low", Priority: -1}, + ) + if err != nil { + t.Fatalf("SubscribeChan low failed: %v", err) + } + high, _, err := bus.Channel().SubscribeChan( + context.Background(), + SubscribeOptions{Name: "high", Priority: 10}, + ) + if err != nil { + t.Fatalf("SubscribeChan high failed: %v", err) + } + peer, _, err := bus.Channel().SubscribeChan( + context.Background(), + SubscribeOptions{Name: "peer", Priority: 10}, + ) + if err != nil { + t.Fatalf("SubscribeChan peer failed: %v", err) + } + + stats := bus.Stats() + got := []string{ + stats.SubscriberStats[0].Name, + stats.SubscriberStats[1].Name, + stats.SubscriberStats[2].Name, + } + want := []string{"high", "peer", "low"} + if got[0] != want[0] || got[1] != want[1] || got[2] != want[2] { + t.Fatalf("subscriber order = %v, want %v", got, want) + } + + if err := high.Close(); err != nil { + t.Fatalf("Close high failed: %v", err) + } + + stats = bus.Stats() + got = []string{ + stats.SubscriberStats[0].Name, + stats.SubscriberStats[1].Name, + } + want = []string{"peer", "low"} + if got[0] != want[0] || got[1] != want[1] { + t.Fatalf("subscriber order after unsubscribe = %v, want %v", got, want) + } + + if err := peer.Close(); err != nil { + t.Fatalf("Close peer failed: %v", err) + } + if err := low.Close(); err != nil { + t.Fatalf("Close low failed: %v", err) + } +} + func receiveEvent(t *testing.T, ch <-chan Event) Event { t.Helper() diff --git a/pkg/events/kind.go b/pkg/events/kind.go index 85e61f741..b9327e155 100644 --- a/pkg/events/kind.go +++ b/pkg/events/kind.go @@ -101,3 +101,56 @@ const ( // KindMCPToolCallEnd is emitted when an MCP tool call ends. KindMCPToolCallEnd Kind = "mcp.tool.call.end" ) + +var knownKinds = []Kind{ + KindAgentTurnStart, + KindAgentTurnEnd, + KindAgentLLMRequest, + KindAgentLLMDelta, + KindAgentLLMResponse, + KindAgentLLMRetry, + KindAgentContextCompress, + KindAgentSessionSummarize, + KindAgentToolExecStart, + KindAgentToolExecEnd, + KindAgentToolExecSkipped, + KindAgentSteeringInjected, + KindAgentFollowUpQueued, + KindAgentInterruptReceived, + KindAgentSubTurnSpawn, + KindAgentSubTurnEnd, + KindAgentSubTurnResultDelivered, + KindAgentSubTurnOrphan, + KindAgentError, + KindChannelLifecycleStarted, + KindChannelLifecycleInitialized, + KindChannelLifecycleStartFailed, + KindChannelLifecycleStopped, + KindChannelWebhookRegistered, + KindChannelWebhookUnregistered, + KindChannelMessageOutboundQueued, + KindChannelMessageOutboundSent, + KindChannelMessageOutboundFailed, + KindChannelRateLimited, + KindBusPublishFailed, + KindBusCloseStarted, + KindBusCloseCompleted, + KindBusCloseDrained, + KindGatewayStart, + KindGatewayReady, + KindGatewayShutdown, + KindGatewayReloadStarted, + KindGatewayReloadCompleted, + KindGatewayReloadFailed, + KindMCPServerConnected, + KindMCPServerConnecting, + KindMCPServerFailed, + KindMCPToolDiscovered, + KindMCPToolCallStart, + KindMCPToolCallEnd, +} + +// KnownKinds returns the runtime event kinds declared by this package. +func KnownKinds() []Kind { + return append([]Kind(nil), knownKinds...) +} diff --git a/pkg/events/subscription.go b/pkg/events/subscription.go index 1b3977300..6619707a7 100644 --- a/pkg/events/subscription.go +++ b/pkg/events/subscription.go @@ -28,8 +28,11 @@ type SubscribeOptions struct { Priority int Concurrency ConcurrencyKind Backpressure BackpressurePolicy - Timeout time.Duration - PanicPolicy PanicPolicy + // Timeout bounds how long the subscription worker waits for one handler call. + // Handlers should still honor ctx cancellation; timed-out calls keep running + // until their handler returns. + Timeout time.Duration + PanicPolicy PanicPolicy } // ConcurrencyKind controls how handler subscriptions process queued events. @@ -107,6 +110,11 @@ type eventSubscription struct { counters subscriberCounters } +type handlerResult struct { + err error + panicked bool +} + func normalizeSubscribeOptions(opts SubscribeOptions) SubscribeOptions { if opts.Buffer <= 0 { opts.Buffer = defaultSubscriberBuffer @@ -234,26 +242,54 @@ func (s *eventSubscription) handle(ctx context.Context, evt Event) { if ctx == nil { ctx = context.Background() } - if s.opts.Timeout > 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, s.opts.Timeout) - defer cancel() + + if s.opts.Timeout <= 0 { + s.recordHandlerResult(ctx, s.invokeHandler(ctx, evt)) + return } + ctx, cancel := context.WithTimeout(ctx, s.opts.Timeout) + defer cancel() + + done := make(chan handlerResult, 1) + go func() { + done <- s.invokeHandler(ctx, evt) + }() + + select { + case result := <-done: + s.recordHandlerResult(ctx, result) + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + s.counters.timedOut.Add(1) + } + s.counters.failed.Add(1) + } +} + +func (s *eventSubscription) invokeHandler(ctx context.Context, evt Event) (result handlerResult) { if s.opts.PanicPolicy != Crash { defer func() { if recovered := recover(); recovered != nil { s.counters.panicked.Add(1) + result.panicked = true log.Printf("events: subscriber %q recovered panic: %v", s.name, recovered) } }() } - err := s.handler(ctx, evt) + result.err = s.handler(ctx, evt) + return result +} + +func (s *eventSubscription) recordHandlerResult(ctx context.Context, result handlerResult) { + if result.panicked { + return + } if errors.Is(ctx.Err(), context.DeadlineExceeded) { s.counters.timedOut.Add(1) } - if err != nil { + if result.err != nil { s.counters.failed.Add(1) return } @@ -303,11 +339,15 @@ type deliveryResult struct { closed bool } -func (s *eventSubscription) enqueue(ctx context.Context, evt Event) deliveryResult { +func (s *eventSubscription) enqueue(ctx context.Context, evt Event, nonBlocking bool) deliveryResult { if ctx == nil { ctx = context.Background() } + if nonBlocking { + return s.enqueueNonBlocking(evt) + } + if s.opts.Backpressure == Block { return s.enqueueBlocking(ctx, evt) } @@ -343,6 +383,21 @@ func (s *eventSubscription) enqueueBlocking(ctx context.Context, evt Event) deli return s.enqueueBlock(ctx, evt) } +func (s *eventSubscription) enqueueNonBlocking(evt Event) deliveryResult { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.closed { + return deliveryResult{closed: true} + } + + s.counters.received.Add(1) + if s.opts.Backpressure == DropOldest { + return s.enqueueDropOldest(evt) + } + return s.enqueueDropNewest(evt) +} + func (s *eventSubscription) enqueueDropNewest(evt Event) deliveryResult { select { case <-s.closing: diff --git a/pkg/events/subscription_test.go b/pkg/events/subscription_test.go index 44d4be64b..8fde731cc 100644 --- a/pkg/events/subscription_test.go +++ b/pkg/events/subscription_test.go @@ -185,6 +185,45 @@ func TestLockedHandlerProcessesSequentially(t *testing.T) { } } +func TestHandlerTimeoutDoesNotWedgeLockedSubscription(t *testing.T) { + t.Parallel() + + bus := NewBus() + defer closeBus(t, bus) + + releaseFirst := make(chan struct{}) + defer close(releaseFirst) + + var calls atomic.Uint64 + sub, err := bus.Channel().Subscribe( + context.Background(), + SubscribeOptions{Name: "timeout", Buffer: 2, Concurrency: Locked, Timeout: 20 * time.Millisecond}, + func(context.Context, Event) error { + if calls.Add(1) == 1 { + <-releaseFirst + } + return nil + }, + ) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + + bus.Publish(context.Background(), Event{Kind: Kind("test.first")}) + waitForStat(t, func() uint64 { + return sub.Stats().TimedOut + }, 1) + + bus.Publish(context.Background(), Event{Kind: Kind("test.second")}) + waitForStat(t, func() uint64 { + return sub.Stats().Handled + }, 1) + + if got := sub.Stats().Failed; got != 1 { + t.Fatalf("subscription failed = %d, want timeout failure", got) + } +} + func waitForSubscriptionDone(t *testing.T, sub Subscription) { t.Helper() diff --git a/pkg/gateway/events.go b/pkg/gateway/events.go index fd263fba6..0f454ed7d 100644 --- a/pkg/gateway/events.go +++ b/pkg/gateway/events.go @@ -1,15 +1,12 @@ package gateway import ( - "context" "time" "github.com/sipeed/picoclaw/pkg/agent" runtimeevents "github.com/sipeed/picoclaw/pkg/events" ) -const gatewayEventPublishTimeout = 100 * time.Millisecond - type gatewayEventPayload struct { DurationMS int64 `json:"duration_ms,omitempty"` Error string `json:"error,omitempty"` @@ -35,12 +32,22 @@ func publishGatewayEvent( payload.Error = err.Error() } - ctx, cancel := context.WithTimeout(context.Background(), gatewayEventPublishTimeout) - defer cancel() - al.RuntimeEventBus().Publish(ctx, runtimeevents.Event{ + al.RuntimeEventBus().PublishNonBlocking(runtimeevents.Event{ Kind: kind, Source: runtimeevents.Source{Component: "gateway"}, Severity: severity, Payload: payload, + Attrs: gatewayEventAttrs(payload), }) } + +func gatewayEventAttrs(payload gatewayEventPayload) map[string]any { + attrs := map[string]any{} + if payload.DurationMS > 0 { + attrs["duration_ms"] = payload.DurationMS + } + if payload.Error != "" { + attrs["error"] = payload.Error + } + return attrs +} diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index dd579ed7e..4fd06d836 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -267,7 +267,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr select { case <-sigChan: logger.Info("Shutting down...") - shutdownGateway(runningServices, agentLoop, provider, true) + shutdownGateway(runningServices, agentLoop, provider, msgBus, true) return nil case newCfg := <-configReloadChan: if !runningServices.reloading.CompareAndSwap(false, true) { @@ -510,6 +510,7 @@ func shutdownGateway( runningServices *services, agentLoop *agent.AgentLoop, provider providers.LLMProvider, + msgBus *bus.MessageBus, fullShutdown bool, ) { publishGatewayEvent(agentLoop, runtimeevents.KindGatewayShutdown, time.Time{}, nil) @@ -520,6 +521,10 @@ func shutdownGateway( stopAndCleanupServices(runningServices, gracefulShutdownTimeout, false) + if fullShutdown && msgBus != nil { + msgBus.Close() + } + agentLoop.Stop() agentLoop.Close() diff --git a/pkg/gateway/gateway_test.go b/pkg/gateway/gateway_test.go index 9af833b6b..ab3833ba6 100644 --- a/pkg/gateway/gateway_test.go +++ b/pkg/gateway/gateway_test.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "errors" "fmt" "os" "os/exec" @@ -159,6 +160,42 @@ func TestPublishGatewayEvent(t *testing.T) { if payload.DurationMS <= 0 { t.Fatalf("DurationMS = %d, want positive", payload.DurationMS) } + if evt.Attrs["duration_ms"] == nil { + t.Fatalf("gateway event attrs missing duration_ms: %#v", evt.Attrs) + } +} + +func TestShutdownGatewayClosesMessageBus(t *testing.T) { + msgBus := bus.NewMessageBus() + al := agent.NewAgentLoop( + config.DefaultConfig(), + msgBus, + &startupBlockedProvider{reason: "not used"}, + ) + msgBus.SetEventPublisher(al.RuntimeEventBus()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sub, eventsCh, err := al.RuntimeEventBus().Channel().OfKind(runtimeevents.KindBusCloseCompleted).SubscribeChan( + ctx, + runtimeevents.SubscribeOptions{Name: "bus-close-test", Buffer: 4}, + ) + if err != nil { + t.Fatalf("SubscribeChan() error = %v", err) + } + defer func() { + _ = sub.Close() + }() + + shutdownGateway(&services{}, al, &startupBlockedProvider{reason: "not used"}, msgBus, true) + + evt := receiveGatewayRuntimeEvent(t, eventsCh) + if evt.Kind != runtimeevents.KindBusCloseCompleted { + t.Fatalf("shutdown event kind = %q, want %q", evt.Kind, runtimeevents.KindBusCloseCompleted) + } + if err := msgBus.PublishVoiceControl(context.Background(), bus.VoiceControl{}); !errors.Is(err, bus.ErrBusClosed) { + t.Fatalf("PublishVoiceControl after shutdown error = %v, want %v", err, bus.ErrBusClosed) + } } func receiveGatewayRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event { diff --git a/pkg/mcp/events.go b/pkg/mcp/events.go index 843fa9088..3b7f53f96 100644 --- a/pkg/mcp/events.go +++ b/pkg/mcp/events.go @@ -1,15 +1,10 @@ package mcp import ( - "context" - "time" - "github.com/sipeed/picoclaw/pkg/config" runtimeevents "github.com/sipeed/picoclaw/pkg/events" ) -const mcpEventPublishTimeout = 100 * time.Millisecond - func (m *Manager) publishServerEvent( kind runtimeevents.Kind, serverName string, @@ -36,13 +31,12 @@ func (m *Manager) publishServerEvent( payload.Error = err.Error() } - ctx, cancel := context.WithTimeout(context.Background(), mcpEventPublishTimeout) - defer cancel() - m.runtimeEvents.Publish(ctx, runtimeevents.Event{ + m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{ Kind: kind, Source: runtimeevents.Source{Component: "mcp", Name: serverName}, Severity: severity, Payload: payload, + Attrs: mcpServerEventAttrs(payload), }) } @@ -57,16 +51,33 @@ func (m *Manager) publishToolDiscovered(serverName string, cfg config.MCPServerC Command: cfg.Command, Tool: toolName, } - ctx, cancel := context.WithTimeout(context.Background(), mcpEventPublishTimeout) - defer cancel() - m.runtimeEvents.Publish(ctx, runtimeevents.Event{ + m.runtimeEvents.PublishNonBlocking(runtimeevents.Event{ Kind: runtimeevents.KindMCPToolDiscovered, Source: runtimeevents.Source{Component: "mcp", Name: serverName}, Severity: runtimeevents.SeverityInfo, Payload: payload, + Attrs: mcpServerEventAttrs(payload), }) } +func mcpServerEventAttrs(payload ServerEventPayload) map[string]any { + attrs := map[string]any{} + setMCPAttrString(attrs, "server", payload.Server) + setMCPAttrString(attrs, "type", payload.Type) + setMCPAttrString(attrs, "tool", payload.Tool) + if payload.ToolCount > 0 { + attrs["tool_count"] = payload.ToolCount + } + setMCPAttrString(attrs, "error", payload.Error) + return attrs +} + +func setMCPAttrString(attrs map[string]any, key, value string) { + if value != "" { + attrs[key] = value + } +} + func mcpTransportType(cfg config.MCPServerConfig) string { if cfg.Type != "" { return cfg.Type diff --git a/pkg/mcp/manager_test.go b/pkg/mcp/manager_test.go index 17b272820..5789a37a9 100644 --- a/pkg/mcp/manager_test.go +++ b/pkg/mcp/manager_test.go @@ -300,6 +300,11 @@ func TestConnectServerPublishesRuntimeEvents(t *testing.T) { connected.Severity != runtimeevents.SeverityInfo { t.Fatalf("connected event = %+v", connected) } + if connected.Attrs["server"] != "good" || + connected.Attrs["type"] != "stdio" || + connected.Attrs["tool_count"] != 1 { + t.Fatalf("connected attrs = %#v", connected.Attrs) + } err = mgr.ConnectServer(context.Background(), "bad", config.MCPServerConfig{ Type: "stdio", @@ -314,6 +319,9 @@ func TestConnectServerPublishesRuntimeEvents(t *testing.T) { failed.Severity != runtimeevents.SeverityError { t.Fatalf("failed event = %+v", failed) } + if failed.Attrs["server"] != "bad" || failed.Attrs["error"] != "connect failed" { + t.Fatalf("failed attrs = %#v", failed.Attrs) + } } func receiveMCPRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event { diff --git a/pkg/tools/integration/mcp_tool.go b/pkg/tools/integration/mcp_tool.go index 1ee86a4a9..8cfc1de5e 100644 --- a/pkg/tools/integration/mcp_tool.go +++ b/pkg/tools/integration/mcp_tool.go @@ -310,17 +310,31 @@ func (t *MCPTool) publishRuntimeEvent( severity = runtimeevents.SeverityError } - publishCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - t.runtimeEvents.Publish(publishCtx, runtimeevents.Event{ + t.runtimeEvents.PublishNonBlocking(runtimeevents.Event{ Kind: kind, Source: runtimeevents.Source{Component: "mcp", Name: t.serverName}, Scope: scope, Severity: severity, Payload: payload, + Attrs: mcpToolCallEventAttrs(payload), }) } +func mcpToolCallEventAttrs(payload MCPToolCallPayload) map[string]any { + attrs := map[string]any{ + "server": payload.Server, + "tool": payload.Tool, + "duration_ms": payload.DurationMS, + } + if payload.IsError { + attrs["is_error"] = payload.IsError + } + if payload.Error != "" { + attrs["error"] = payload.Error + } + return attrs +} + // extractContentText extracts text from MCP content array func extractContentText(content []mcp.Content) string { var parts []string diff --git a/pkg/tools/integration/mcp_tool_test.go b/pkg/tools/integration/mcp_tool_test.go index f346e5329..7c961e1e1 100644 --- a/pkg/tools/integration/mcp_tool_test.go +++ b/pkg/tools/integration/mcp_tool_test.go @@ -350,6 +350,11 @@ func TestMCPTool_Execute_PublishesRuntimeEvents(t *testing.T) { if payload.Server != "github" || payload.Tool != "search_repos" || payload.IsError { t.Fatalf("ended payload = %+v", payload) } + if ended.Attrs["server"] != "github" || + ended.Attrs["tool"] != "search_repos" || + ended.Attrs["duration_ms"] == nil { + t.Fatalf("ended attrs = %#v", ended.Attrs) + } } func receiveMCPToolRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {