mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(events): publish runtime service events
Migrate hook observation to runtime events and update the process hook notification protocol. Add runtime event publication for message bus failures, channel lifecycle/outbound flow, gateway reloads, MCP server state, and MCP tool calls. Validation: go test ./pkg/events/... ./pkg/bus ./pkg/agent ./pkg/channels ./pkg/mcp ./pkg/tools/integration ./pkg/gateway; make lint
This commit is contained in:
@@ -206,3 +206,11 @@ func (al *AgentLoop) RuntimeEventStats() runtimeevents.Stats {
|
||||
}
|
||||
return al.runtimeEvents.Stats()
|
||||
}
|
||||
|
||||
// RuntimeEventBus returns the runtime event bus used by the agent loop.
|
||||
func (al *AgentLoop) RuntimeEventBus() runtimeevents.Bus {
|
||||
if al == nil {
|
||||
return nil
|
||||
}
|
||||
return al.runtimeEvents
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ func NewAgentLoop(
|
||||
al.ownsRuntimeEvents = true
|
||||
}
|
||||
al.providerFactory = providers.CreateProviderFromConfig
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
al.hooks = NewHookManagerWithRuntimeEvents(eventBus, al.runtimeEvents.Channel())
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
al.contextManager = al.resolveContextManager()
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
al.mcp.initOnce.Do(func() {
|
||||
mcpManager := mcp.NewManager()
|
||||
mcpManager := mcp.NewManager(mcp.WithRuntimeEvents(al.runtimeEvents))
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
workspacePath := al.cfg.WorkspacePath()
|
||||
@@ -164,6 +164,7 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
mcpTool.SetWorkspace(agent.Workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
|
||||
mcpTool.SetEventPublisher(al.runtimeEvents)
|
||||
|
||||
if registerAsHidden {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
|
||||
+12
-5
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
type hookRuntime struct {
|
||||
@@ -295,10 +296,11 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error)
|
||||
case "", "*", "all":
|
||||
return nil, true, nil
|
||||
default:
|
||||
if _, ok := validKinds[kind]; !ok {
|
||||
normalizedKind, ok := validKinds[kind]
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
|
||||
}
|
||||
normalized = append(normalized, kind)
|
||||
normalized = append(normalized, normalizedKind)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -308,10 +310,15 @@ func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error)
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
func validHookEventKinds() map[string]struct{} {
|
||||
kinds := make(map[string]struct{}, int(eventKindCount))
|
||||
func validHookEventKinds() map[string]string {
|
||||
kinds := make(map[string]string, int(eventKindCount)*2)
|
||||
for kind := EventKind(0); kind < eventKindCount; kind++ {
|
||||
kinds[kind.String()] = struct{}{}
|
||||
runtimeKind := runtimeKindForAgentEvent(kind).String()
|
||||
kinds[kind.String()] = runtimeKind
|
||||
kinds[runtimeKind] = runtimeKind
|
||||
}
|
||||
kinds[runtimeevents.KindAgentToolExecStart.String()] = runtimeevents.KindAgentToolExecStart.String()
|
||||
kinds[runtimeevents.KindAgentToolExecEnd.String()] = runtimeevents.KindAgentToolExecEnd.String()
|
||||
kinds[runtimeevents.KindAgentToolExecSkipped.String()] = runtimeevents.KindAgentToolExecSkipped.String()
|
||||
return kinds
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -155,7 +156,25 @@ func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T)
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
waitForFileContains(t, eventLog, "agent.turn.end")
|
||||
}
|
||||
|
||||
func TestProcessHookObserveKindsFromConfigAcceptsRuntimeNames(t *testing.T) {
|
||||
kinds, enabled, err := processHookObserveKindsFromConfig([]string{
|
||||
"tool_exec_start",
|
||||
"agent.tool.exec_end",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("processHookObserveKindsFromConfig failed: %v", err)
|
||||
}
|
||||
if !enabled {
|
||||
t.Fatal("expected observe to be enabled")
|
||||
}
|
||||
|
||||
want := []string{"agent.tool.exec_start", "agent.tool.exec_end"}
|
||||
if !slices.Equal(kinds, want) {
|
||||
t.Fatalf("observe kinds = %v, want %v", kinds, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -188,13 +189,26 @@ func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
return nil
|
||||
}
|
||||
if len(ph.observeKinds) > 0 {
|
||||
if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
|
||||
kind := runtimeKindForAgentEvent(evt.Kind).String()
|
||||
if _, ok := ph.observeKinds[kind]; !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ph.notify(ctx, "hook.event", evt)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
||||
if ph == nil || !ph.opts.Observe {
|
||||
return nil
|
||||
}
|
||||
if len(ph.observeKinds) > 0 {
|
||||
if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ph.notify(ctx, "hook.runtime_event", evt)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
waitForFileContains(t, eventLog, "agent.turn.end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
|
||||
@@ -350,10 +350,12 @@ func runProcessHookHelper() error {
|
||||
}
|
||||
|
||||
if msg.ID == 0 {
|
||||
if msg.Method == "hook.event" && eventLog != "" {
|
||||
if (msg.Method == "hook.event" || msg.Method == "hook.runtime_event") && eventLog != "" {
|
||||
var evt map[string]any
|
||||
if err := json.Unmarshal(msg.Params, &evt); err == nil {
|
||||
if rawKind, ok := evt["Kind"].(float64); ok {
|
||||
if kind, ok := evt["kind"].(string); ok {
|
||||
_ = os.WriteFile(eventLog, []byte(kind+"\n"), 0o644)
|
||||
} else if rawKind, ok := evt["Kind"].(float64); ok {
|
||||
kind := EventKind(rawKind)
|
||||
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
|
||||
}
|
||||
|
||||
+99
-7
@@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -75,6 +76,10 @@ type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt Event) error
|
||||
}
|
||||
|
||||
type RuntimeEventObserver interface {
|
||||
OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error
|
||||
}
|
||||
|
||||
type LLMInterceptor interface {
|
||||
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
|
||||
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
|
||||
@@ -193,6 +198,7 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
|
||||
type HookManager struct {
|
||||
eventBus *EventBus
|
||||
runtimeEvents runtimeevents.EventChannel
|
||||
observerTimeout time.Duration
|
||||
interceptorTimeout time.Duration
|
||||
approvalTimeout time.Duration
|
||||
@@ -201,28 +207,56 @@ type HookManager struct {
|
||||
hooks map[string]HookRegistration
|
||||
ordered []HookRegistration
|
||||
|
||||
sub EventSubscription
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
sub EventSubscription
|
||||
runtimeSub runtimeevents.Subscription
|
||||
done chan struct{}
|
||||
runtimeDone chan struct{}
|
||||
runtimeObserveEnabled bool
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHookManager(eventBus *EventBus) *HookManager {
|
||||
return NewHookManagerWithRuntimeEvents(eventBus, nil)
|
||||
}
|
||||
|
||||
func NewHookManagerWithRuntimeEvents(eventBus *EventBus, runtimeEvents runtimeevents.EventChannel) *HookManager {
|
||||
hm := &HookManager{
|
||||
eventBus: eventBus,
|
||||
runtimeEvents: runtimeEvents,
|
||||
observerTimeout: defaultHookObserverTimeout,
|
||||
interceptorTimeout: defaultHookInterceptorTimeout,
|
||||
approvalTimeout: defaultHookApprovalTimeout,
|
||||
hooks: make(map[string]HookRegistration),
|
||||
done: make(chan struct{}),
|
||||
runtimeDone: make(chan struct{}),
|
||||
}
|
||||
|
||||
if eventBus == nil {
|
||||
if eventBus != nil {
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
} else {
|
||||
close(hm.done)
|
||||
return hm
|
||||
}
|
||||
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
if runtimeEvents != nil {
|
||||
sub, ch, err := runtimeEvents.SubscribeChan(context.Background(), runtimeevents.SubscribeOptions{
|
||||
Name: "hook-manager-observer",
|
||||
Buffer: hookObserverBufferSize,
|
||||
})
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Failed to subscribe runtime events for hooks", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
close(hm.runtimeDone)
|
||||
} else {
|
||||
hm.runtimeSub = sub
|
||||
hm.runtimeObserveEnabled = true
|
||||
go hm.dispatchRuntimeEvents(ch)
|
||||
}
|
||||
} else {
|
||||
close(hm.runtimeDone)
|
||||
}
|
||||
|
||||
return hm
|
||||
}
|
||||
|
||||
@@ -235,7 +269,15 @@ func (hm *HookManager) Close() {
|
||||
if hm.eventBus != nil {
|
||||
hm.eventBus.Unsubscribe(hm.sub.ID)
|
||||
}
|
||||
if hm.runtimeSub != nil {
|
||||
if err := hm.runtimeSub.Close(); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to close runtime event hook subscription", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
<-hm.done
|
||||
<-hm.runtimeDone
|
||||
hm.closeAllHooks()
|
||||
})
|
||||
}
|
||||
@@ -297,6 +339,11 @@ func (hm *HookManager) dispatchEvents() {
|
||||
|
||||
for evt := range hm.sub.C {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
if hm.runtimeObserveEnabled {
|
||||
if _, ok := reg.Hook.(RuntimeEventObserver); ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
observer, ok := reg.Hook.(EventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
@@ -306,6 +353,20 @@ func (hm *HookManager) dispatchEvents() {
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) dispatchRuntimeEvents(ch <-chan runtimeevents.Event) {
|
||||
defer close(hm.runtimeDone)
|
||||
|
||||
for evt := range ch {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
observer, ok := reg.Hook.(RuntimeEventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hm.runRuntimeObserver(reg.Name, observer, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
||||
if hm == nil || req == nil {
|
||||
return req, HookDecision{Action: HookActionContinue}
|
||||
@@ -608,6 +669,37 @@ func (hm *HookManager) runObserver(name string, observer EventObserver, evt Even
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) runRuntimeObserver(
|
||||
name string,
|
||||
observer RuntimeEventObserver,
|
||||
evt runtimeevents.Event,
|
||||
) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- observer.OnRuntimeEvent(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Runtime event observer failed", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Runtime event observer timed out", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
@@ -150,6 +151,31 @@ func (h *llmObserverHook) AfterLLM(
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
type dualRuntimeObserverHook struct {
|
||||
legacyCh chan Event
|
||||
runtimeCh chan runtimeevents.Event
|
||||
}
|
||||
|
||||
func (h *dualRuntimeObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if evt.Kind == EventKindTurnEnd {
|
||||
select {
|
||||
case h.legacyCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *dualRuntimeObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
||||
if evt.Kind == runtimeevents.KindAgentTurnEnd {
|
||||
select {
|
||||
case h.runtimeCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type llmSystemRewriteHook struct{}
|
||||
|
||||
func (h *llmSystemRewriteHook) BeforeLLM(
|
||||
@@ -498,6 +524,65 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_RuntimeObserverPreferredOverLegacyObserver(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &dualRuntimeObserverHook{
|
||||
legacyCh: make(chan Event, 1),
|
||||
runtimeCh: make(chan runtimeevents.Event, 1),
|
||||
}
|
||||
if err := al.MountHook(NamedHook("runtime-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
InboundContext: &bus.InboundContext{
|
||||
Channel: "cli",
|
||||
Account: "default",
|
||||
ChatID: "direct",
|
||||
ChatType: "direct",
|
||||
SenderID: "hook-user",
|
||||
MessageID: "msg-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content" {
|
||||
t.Fatalf("expected provider content, got %q", resp)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.runtimeCh:
|
||||
if evt.Kind != runtimeevents.KindAgentTurnEnd {
|
||||
t.Fatalf("runtime observer kind = %q", evt.Kind)
|
||||
}
|
||||
if evt.Scope.SessionKey != "session-1" ||
|
||||
evt.Scope.Channel != "cli" ||
|
||||
evt.Scope.ChatID != "direct" ||
|
||||
evt.Scope.MessageID != "msg-1" {
|
||||
t.Fatalf("runtime observer scope = %+v", evt.Scope)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for runtime observer event")
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.legacyCh:
|
||||
t.Fatalf("legacy observer unexpectedly received %v", evt.Kind)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
|
||||
+43
-5
@@ -6,6 +6,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -48,6 +49,12 @@ type MessageBus struct {
|
||||
closed atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
streamDelegate atomic.Value // stores StreamDelegate
|
||||
eventPublisher atomic.Value // stores EventPublisher
|
||||
}
|
||||
|
||||
// EventPublisher is the minimal runtime event publisher used by MessageBus.
|
||||
type EventPublisher interface {
|
||||
Publish(ctx context.Context, evt runtimeevents.Event) runtimeevents.PublishResult
|
||||
}
|
||||
|
||||
func NewMessageBus() *MessageBus {
|
||||
@@ -92,9 +99,14 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
msg = NormalizeInboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
mb.publishFailure("inbound", runtimeScopeFromInboundContext(msg.Context), ErrMissingInboundContext)
|
||||
return ErrMissingInboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.inbound, msg)
|
||||
if err := publish(ctx, mb, mb.inbound, msg); err != nil {
|
||||
mb.publishFailure("inbound", runtimeScopeFromInboundContext(msg.Context), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
@@ -104,9 +116,14 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
msg = NormalizeOutboundMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
mb.publishFailure("outbound", runtimeScopeFromInboundContext(msg.Context), ErrMissingOutboundContext)
|
||||
return ErrMissingOutboundContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outbound, msg)
|
||||
if err := publish(ctx, mb, mb.outbound, msg); err != nil {
|
||||
mb.publishFailure("outbound", runtimeScopeFromInboundContext(msg.Context), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
@@ -116,9 +133,14 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
|
||||
msg = NormalizeOutboundMediaMessage(msg)
|
||||
if msg.Context.isZero() {
|
||||
mb.publishFailure("outbound_media", runtimeScopeFromInboundContext(msg.Context), ErrMissingOutboundMediaContext)
|
||||
return ErrMissingOutboundMediaContext
|
||||
}
|
||||
return publish(ctx, mb, mb.outboundMedia, msg)
|
||||
if err := publish(ctx, mb, mb.outboundMedia, msg); err != nil {
|
||||
mb.publishFailure("outbound_media", runtimeScopeFromInboundContext(msg.Context), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
@@ -126,7 +148,11 @@ func (mb *MessageBus) OutboundMediaChan() <-chan OutboundMediaMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishAudioChunk(ctx context.Context, chunk AudioChunk) error {
|
||||
return publish(ctx, mb, mb.audioChunks, chunk)
|
||||
if err := publish(ctx, mb, mb.audioChunks, chunk); err != nil {
|
||||
mb.publishFailure("audio_chunk", runtimeScopeFromAudioChunk(chunk), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
|
||||
@@ -134,7 +160,11 @@ func (mb *MessageBus) AudioChunksChan() <-chan AudioChunk {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishVoiceControl(ctx context.Context, ctrl VoiceControl) error {
|
||||
return publish(ctx, mb, mb.voiceControls, ctrl)
|
||||
if err := publish(ctx, mb, mb.voiceControls, ctrl); err != nil {
|
||||
mb.publishFailure("voice_control", runtimeScopeFromVoiceControl(ctrl), err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mb *MessageBus) VoiceControlsChan() <-chan VoiceControl {
|
||||
@@ -146,6 +176,11 @@ func (mb *MessageBus) SetStreamDelegate(d StreamDelegate) {
|
||||
mb.streamDelegate.Store(d)
|
||||
}
|
||||
|
||||
// SetEventPublisher registers a runtime event publisher for bus errors and lifecycle events.
|
||||
func (mb *MessageBus) SetEventPublisher(p EventPublisher) {
|
||||
mb.eventPublisher.Store(p)
|
||||
}
|
||||
|
||||
// GetStreamer returns a Streamer for the given channel+chatID via the delegate.
|
||||
func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (Streamer, bool) {
|
||||
if d, ok := mb.streamDelegate.Load().(StreamDelegate); ok && d != nil {
|
||||
@@ -156,6 +191,7 @@ func (mb *MessageBus) GetStreamer(ctx context.Context, channel, chatID string) (
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.closeOnce.Do(func() {
|
||||
mb.publishCloseEvent(runtimeevents.KindBusCloseStarted, 0)
|
||||
// notify all blocked publishers to exit
|
||||
close(mb.done)
|
||||
|
||||
@@ -195,6 +231,8 @@ func (mb *MessageBus) Close() {
|
||||
logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{
|
||||
"count": drained,
|
||||
})
|
||||
mb.publishCloseEvent(runtimeevents.KindBusCloseDrained, drained)
|
||||
}
|
||||
mb.publishCloseEvent(runtimeevents.KindBusCloseCompleted, drained)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func TestPublishConsume(t *testing.T) {
|
||||
@@ -171,6 +173,76 @@ func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageBusPublishesRuntimeFailureAndCloseEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindBusPublishFailed,
|
||||
runtimeevents.KindBusCloseStarted,
|
||||
runtimeevents.KindBusCloseDrained,
|
||||
runtimeevents.KindBusCloseCompleted,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "bus-events", Buffer: 4})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
mb := NewMessageBus()
|
||||
mb.SetEventPublisher(eventBus)
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{}); err == nil {
|
||||
t.Fatal("expected PublishInbound to fail")
|
||||
}
|
||||
failed := receiveBusRuntimeEvent(t, eventsCh)
|
||||
if failed.Kind != runtimeevents.KindBusPublishFailed ||
|
||||
failed.Source.Name != "inbound" ||
|
||||
failed.Severity != runtimeevents.SeverityError {
|
||||
t.Fatalf("publish failed event = %+v", failed)
|
||||
}
|
||||
|
||||
if err := mb.PublishOutbound(context.Background(), OutboundMessage{
|
||||
Context: NewOutboundContext("telegram", "chat-1", ""),
|
||||
Content: "queued",
|
||||
}); err != nil {
|
||||
t.Fatalf("PublishOutbound failed: %v", err)
|
||||
}
|
||||
mb.Close()
|
||||
|
||||
seen := map[runtimeevents.Kind]bool{}
|
||||
for range 3 {
|
||||
evt := receiveBusRuntimeEvent(t, eventsCh)
|
||||
seen[evt.Kind] = true
|
||||
}
|
||||
for _, kind := range []runtimeevents.Kind{
|
||||
runtimeevents.KindBusCloseStarted,
|
||||
runtimeevents.KindBusCloseDrained,
|
||||
runtimeevents.KindBusCloseCompleted,
|
||||
} {
|
||||
if !seen[kind] {
|
||||
t.Fatalf("missing %s event, seen=%v", kind, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func receiveBusRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
const busEventPublishTimeout = 100 * time.Millisecond
|
||||
|
||||
type busPublishFailedPayload struct {
|
||||
Stream string `json:"stream"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type busClosePayload struct {
|
||||
Drained int `json:"drained,omitempty"`
|
||||
}
|
||||
|
||||
func (mb *MessageBus) publishFailure(stream string, scope runtimeevents.Scope, err error) {
|
||||
if mb == nil || err == nil {
|
||||
return
|
||||
}
|
||||
publisher, ok := mb.eventPublisher.Load().(EventPublisher)
|
||||
if !ok || publisher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), busEventPublishTimeout)
|
||||
defer cancel()
|
||||
publisher.Publish(ctx, runtimeevents.Event{
|
||||
Kind: runtimeevents.KindBusPublishFailed,
|
||||
Source: runtimeevents.Source{Component: "bus", Name: stream},
|
||||
Scope: scope,
|
||||
Severity: runtimeevents.SeverityError,
|
||||
Payload: busPublishFailedPayload{
|
||||
Stream: stream,
|
||||
Error: err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (mb *MessageBus) publishCloseEvent(kind runtimeevents.Kind, drained int) {
|
||||
if mb == nil {
|
||||
return
|
||||
}
|
||||
publisher, ok := mb.eventPublisher.Load().(EventPublisher)
|
||||
if !ok || publisher == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), busEventPublishTimeout)
|
||||
defer cancel()
|
||||
publisher.Publish(ctx, runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "bus"},
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
Payload: busClosePayload{Drained: drained},
|
||||
})
|
||||
}
|
||||
|
||||
func runtimeScopeFromInboundContext(ctx InboundContext) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
Channel: ctx.Channel,
|
||||
Account: ctx.Account,
|
||||
ChatID: ctx.ChatID,
|
||||
TopicID: ctx.TopicID,
|
||||
SpaceID: ctx.SpaceID,
|
||||
SpaceType: ctx.SpaceType,
|
||||
ChatType: ctx.ChatType,
|
||||
SenderID: ctx.SenderID,
|
||||
MessageID: ctx.MessageID,
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeScopeFromAudioChunk(chunk AudioChunk) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
Channel: chunk.Channel,
|
||||
ChatID: chunk.ChatID,
|
||||
}
|
||||
}
|
||||
|
||||
func runtimeScopeFromVoiceControl(ctrl VoiceControl) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
ChatID: ctrl.ChatID,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
const channelEventPublishTimeout = 100 * time.Millisecond
|
||||
|
||||
func channelTypeForEvent(m *Manager, channelName string) string {
|
||||
if m == nil || m.config == nil {
|
||||
return channelName
|
||||
}
|
||||
if bc := m.config.Channels.Get(channelName); bc != nil && bc.Type != "" {
|
||||
return bc.Type
|
||||
}
|
||||
return channelName
|
||||
}
|
||||
|
||||
func (m *Manager) publishChannelEvent(
|
||||
kind runtimeevents.Kind,
|
||||
channelName string,
|
||||
scope runtimeevents.Scope,
|
||||
severity runtimeevents.Severity,
|
||||
payload any,
|
||||
) {
|
||||
if m == nil || m.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
if scope.Channel == "" {
|
||||
scope.Channel = channelName
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), channelEventPublishTimeout)
|
||||
defer cancel()
|
||||
m.runtimeEvents.Publish(ctx, runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "channel", Name: channelName},
|
||||
Scope: scope,
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundSent(
|
||||
channelName string,
|
||||
msg bus.OutboundMessage,
|
||||
messageIDs []string,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundSent,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
MessageIDs: append([]string(nil), messageIDs...),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundQueued(
|
||||
channelName string,
|
||||
msg bus.OutboundMessage,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundQueued,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundFailed(
|
||||
channelName string,
|
||||
msg bus.OutboundMessage,
|
||||
err error,
|
||||
media bool,
|
||||
) {
|
||||
payload := ChannelOutboundPayload{
|
||||
Media: media,
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
Retries: maxRetries,
|
||||
}
|
||||
if err != nil {
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundFailed,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityError,
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundMediaSent(
|
||||
channelName string,
|
||||
msg bus.OutboundMediaMessage,
|
||||
messageIDs []string,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundSent,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{
|
||||
Media: true,
|
||||
MessageIDs: append([]string(nil), messageIDs...),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundMediaQueued(
|
||||
channelName string,
|
||||
msg bus.OutboundMediaMessage,
|
||||
) {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundQueued,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelOutboundPayload{Media: true},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) publishOutboundMediaFailed(
|
||||
channelName string,
|
||||
msg bus.OutboundMediaMessage,
|
||||
err error,
|
||||
) {
|
||||
payload := ChannelOutboundPayload{
|
||||
Media: true,
|
||||
Retries: maxRetries,
|
||||
}
|
||||
if err != nil {
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelMessageOutboundFailed,
|
||||
channelName,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityError,
|
||||
payload,
|
||||
)
|
||||
}
|
||||
|
||||
func scopeFromOutboundContext(ctx bus.InboundContext) runtimeevents.Scope {
|
||||
return runtimeevents.Scope{
|
||||
Channel: ctx.Channel,
|
||||
Account: ctx.Account,
|
||||
ChatID: ctx.ChatID,
|
||||
TopicID: ctx.TopicID,
|
||||
SpaceID: ctx.SpaceID,
|
||||
SpaceType: ctx.SpaceType,
|
||||
ChatType: ctx.ChatType,
|
||||
SenderID: ctx.SenderID,
|
||||
MessageID: ctx.MessageID,
|
||||
}
|
||||
}
|
||||
+125
-1
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
@@ -84,6 +85,7 @@ type Manager struct {
|
||||
channels map[string]Channel
|
||||
workers map[string]*channelWorker
|
||||
bus *bus.MessageBus
|
||||
runtimeEvents runtimeevents.Bus
|
||||
config *config.Config
|
||||
mediaStore media.MediaStore
|
||||
dispatchTask *asyncTask
|
||||
@@ -98,6 +100,32 @@ type Manager struct {
|
||||
channelHashes map[string]string // channel name → config hash
|
||||
}
|
||||
|
||||
// ManagerOption configures a channel Manager.
|
||||
type ManagerOption func(*Manager)
|
||||
|
||||
// WithRuntimeEvents injects the runtime event bus used for channel observations.
|
||||
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
|
||||
return func(m *Manager) {
|
||||
m.runtimeEvents = eventBus
|
||||
}
|
||||
}
|
||||
|
||||
// ChannelLifecyclePayload describes channel lifecycle runtime events.
|
||||
type ChannelLifecyclePayload struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ChannelOutboundPayload describes channel outbound message runtime events.
|
||||
type ChannelOutboundPayload struct {
|
||||
Media bool `json:"media,omitempty"`
|
||||
ContentLen int `json:"content_len,omitempty"`
|
||||
MessageIDs []string `json:"message_ids,omitempty"`
|
||||
ReplyToMessageID string `json:"reply_to_message_id,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Retries int `json:"retries,omitempty"`
|
||||
}
|
||||
|
||||
type toolFeedbackMessageTracker interface {
|
||||
RecordToolFeedbackMessage(chatID, messageID, content string)
|
||||
ClearToolFeedbackMessage(chatID string)
|
||||
@@ -395,7 +423,12 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
|
||||
}
|
||||
}
|
||||
|
||||
func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) {
|
||||
func NewManager(
|
||||
cfg *config.Config,
|
||||
messageBus *bus.MessageBus,
|
||||
store media.MediaStore,
|
||||
opts ...ManagerOption,
|
||||
) (*Manager, error) {
|
||||
m := &Manager{
|
||||
channels: make(map[string]Channel),
|
||||
workers: make(map[string]*channelWorker),
|
||||
@@ -404,6 +437,11 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.Medi
|
||||
mediaStore: store,
|
||||
channelHashes: make(map[string]string),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(m)
|
||||
}
|
||||
}
|
||||
|
||||
// Register as streaming delegate so the agent loop can obtain streamers
|
||||
messageBus.SetStreamDelegate(m)
|
||||
@@ -528,6 +566,13 @@ func (m *Manager) initChannel(typeName, channelName string) {
|
||||
setter.SetOwner(ch)
|
||||
}
|
||||
m.channels[channelName] = ch
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleInitialized,
|
||||
channelName,
|
||||
runtimeevents.Scope{Channel: channelName},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: typeName},
|
||||
)
|
||||
logger.InfoCF("channels", "Channel enabled successfully", map[string]any{
|
||||
"channel": channelName,
|
||||
"type": typeName,
|
||||
@@ -673,6 +718,13 @@ func (m *Manager) registerHTTPHandlersLocked() {
|
||||
func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Handle(wh.WebhookPath(), wh)
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelWebhookRegistered,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
|
||||
)
|
||||
logger.InfoCF("channels", "Webhook handler registered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
@@ -692,6 +744,13 @@ func (m *Manager) registerChannelHTTPHandler(name string, ch Channel) {
|
||||
func (m *Manager) unregisterChannelHTTPHandler(name string, ch Channel) {
|
||||
if wh, ok := ch.(WebhookHandler); ok {
|
||||
m.mux.Unhandle(wh.WebhookPath())
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelWebhookUnregistered,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
|
||||
)
|
||||
logger.InfoCF("channels", "Webhook handler unregistered", map[string]any{
|
||||
"channel": name,
|
||||
"path": wh.WebhookPath(),
|
||||
@@ -730,6 +789,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStartFailed,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityError,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name), Error: err.Error()},
|
||||
)
|
||||
failedStarts = append(failedStarts, fmt.Errorf("channel %s: %w", name, err))
|
||||
failedNames = append(failedNames, name)
|
||||
continue
|
||||
@@ -745,6 +811,13 @@ func (m *Manager) StartAll(ctx context.Context) error {
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStarted,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelType},
|
||||
)
|
||||
}
|
||||
|
||||
if len(m.channels) > 0 && len(m.workers) == 0 {
|
||||
@@ -881,7 +954,15 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStopped,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name)},
|
||||
)
|
||||
}
|
||||
|
||||
logger.InfoC("channels", "All channels stopped")
|
||||
@@ -991,11 +1072,23 @@ func (m *Manager) sendWithRetry(
|
||||
// Rate limit: wait for token
|
||||
if err := w.limiter.Wait(ctx); err != nil {
|
||||
// ctx canceled, shutting down
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelRateLimited,
|
||||
name,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityWarn,
|
||||
ChannelOutboundPayload{
|
||||
ContentLen: len([]rune(msg.Content)),
|
||||
ReplyToMessageID: msg.ReplyToMessageID,
|
||||
Error: err.Error(),
|
||||
},
|
||||
)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Pre-send: stop typing and try to edit placeholder
|
||||
if msgIDs, handled := m.preSend(ctx, name, msg, w.ch); handled {
|
||||
m.publishOutboundSent(name, msg, msgIDs)
|
||||
return msgIDs, true
|
||||
}
|
||||
|
||||
@@ -1004,6 +1097,7 @@ func (m *Manager) sendWithRetry(
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
msgIDs, lastErr = w.ch.Send(ctx, msg)
|
||||
if lastErr == nil {
|
||||
m.publishOutboundSent(name, msg, msgIDs)
|
||||
return msgIDs, true
|
||||
}
|
||||
|
||||
@@ -1043,6 +1137,7 @@ func (m *Manager) sendWithRetry(
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
m.publishOutboundFailed(name, msg, lastErr, false)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
@@ -1105,6 +1200,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
m.publishOutboundQueued(outboundMessageChannel(msg), msg)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
@@ -1125,6 +1221,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
m.publishOutboundMediaQueued(outboundMediaChannel(msg), msg)
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
@@ -1174,6 +1271,16 @@ func (m *Manager) sendMediaWithRetry(
|
||||
|
||||
// Rate limit: wait for token
|
||||
if err := w.limiter.Wait(ctx); err != nil {
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelRateLimited,
|
||||
name,
|
||||
scopeFromOutboundContext(msg.Context),
|
||||
runtimeevents.SeverityWarn,
|
||||
ChannelOutboundPayload{
|
||||
Media: true,
|
||||
Error: err.Error(),
|
||||
},
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1185,6 +1292,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
msgIDs, lastErr = ms.SendMedia(ctx, msg)
|
||||
if lastErr == nil {
|
||||
m.publishOutboundMediaSent(name, msg, msgIDs)
|
||||
return msgIDs, nil
|
||||
}
|
||||
|
||||
@@ -1224,6 +1332,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
m.publishOutboundMediaFailed(name, msg, lastErr)
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
@@ -1361,6 +1470,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
"channel": name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStartFailed,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityError,
|
||||
ChannelLifecyclePayload{Type: channelTypeForEvent(m, name), Error: err.Error()},
|
||||
)
|
||||
continue
|
||||
}
|
||||
// Lazily create worker only after channel starts successfully
|
||||
@@ -1374,6 +1490,13 @@ func (m *Manager) Reload(ctx context.Context, cfg *config.Config) error {
|
||||
m.workers[name] = w
|
||||
go m.runWorker(dispatchCtx, name, w)
|
||||
go m.runMediaWorker(dispatchCtx, name, w)
|
||||
m.publishChannelEvent(
|
||||
runtimeevents.KindChannelLifecycleStarted,
|
||||
name,
|
||||
runtimeevents.Scope{Channel: name},
|
||||
runtimeevents.SeverityInfo,
|
||||
ChannelLifecyclePayload{Type: channelType},
|
||||
)
|
||||
deferFuncs = append(deferFuncs, func() {
|
||||
m.RegisterChannel(name, channel)
|
||||
})
|
||||
@@ -1496,6 +1619,7 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
m.publishOutboundQueued(channelName, msg)
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -242,6 +243,57 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartAllPublishesLifecycleRuntimeEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().SubscribeChan(
|
||||
t.Context(),
|
||||
runtimeevents.SubscribeOptions{Name: "channel-lifecycle", Buffer: 4},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
m := newTestManager()
|
||||
m.runtimeEvents = eventBus
|
||||
m.config = &config.Config{Channels: config.ChannelsConfig{}}
|
||||
m.channels["good"] = &mockChannel{}
|
||||
m.channels["bad"] = &mockChannel{
|
||||
startFn: func(_ context.Context) error { return errors.New("bad start") },
|
||||
}
|
||||
|
||||
if err := m.StartAll(t.Context()); err != nil {
|
||||
t.Fatalf("StartAll() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := m.StopAll(stopCtx); err != nil {
|
||||
t.Errorf("StopAll() error = %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
events := []runtimeevents.Event{
|
||||
receiveChannelRuntimeEvent(t, eventsCh),
|
||||
receiveChannelRuntimeEvent(t, eventsCh),
|
||||
}
|
||||
seen := map[runtimeevents.Kind]runtimeevents.Event{}
|
||||
for _, evt := range events {
|
||||
seen[evt.Kind] = evt
|
||||
}
|
||||
if evt, ok := seen[runtimeevents.KindChannelLifecycleStarted]; !ok || evt.Scope.Channel != "good" {
|
||||
t.Fatalf("missing started event for good channel: %+v", events)
|
||||
}
|
||||
if evt, ok := seen[runtimeevents.KindChannelLifecycleStartFailed]; !ok || evt.Scope.Channel != "bad" {
|
||||
t.Fatalf("missing failed event for bad channel: %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
|
||||
@@ -256,6 +308,21 @@ func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMes
|
||||
return bus.NormalizeOutboundMediaMessage(msg)
|
||||
}
|
||||
|
||||
func receiveChannelRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
@@ -280,6 +347,63 @@ func TestSendWithRetry_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetryPublishesOutboundRuntimeEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindChannelMessageOutboundSent,
|
||||
runtimeevents.KindChannelMessageOutboundFailed,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "channel-outbound", Buffer: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
m := newTestManager()
|
||||
m.runtimeEvents = eventBus
|
||||
|
||||
successWorker := &channelWorker{
|
||||
ch: &mockChannel{},
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.sendWithRetry(
|
||||
context.Background(),
|
||||
"test",
|
||||
successWorker,
|
||||
testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat-1", Content: "hello"}),
|
||||
)
|
||||
sent := receiveChannelRuntimeEvent(t, eventsCh)
|
||||
if sent.Kind != runtimeevents.KindChannelMessageOutboundSent || sent.Scope.ChatID != "chat-1" {
|
||||
t.Fatalf("sent event = %+v", sent)
|
||||
}
|
||||
|
||||
failWorker := &channelWorker{
|
||||
ch: &mockChannel{
|
||||
sendFn: func(context.Context, bus.OutboundMessage) error {
|
||||
return fmt.Errorf("send failed: %w", ErrSendFailed)
|
||||
},
|
||||
},
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.sendWithRetry(
|
||||
context.Background(),
|
||||
"test",
|
||||
failWorker,
|
||||
testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat-2", Content: "hello"}),
|
||||
)
|
||||
failed := receiveChannelRuntimeEvent(t, eventsCh)
|
||||
if failed.Kind != runtimeevents.KindChannelMessageOutboundFailed || failed.Scope.ChatID != "chat-2" {
|
||||
t.Fatalf("failed event = %+v", failed)
|
||||
}
|
||||
if failed.Severity != runtimeevents.SeverityError {
|
||||
t.Fatalf("failed severity = %q", failed.Severity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
|
||||
@@ -47,12 +47,33 @@ const (
|
||||
|
||||
// KindChannelLifecycleStarted is emitted when a channel starts.
|
||||
KindChannelLifecycleStarted Kind = "channel.lifecycle.started"
|
||||
// KindChannelLifecycleInitialized is emitted when a channel is initialized.
|
||||
KindChannelLifecycleInitialized Kind = "channel.lifecycle.initialized"
|
||||
// KindChannelLifecycleStartFailed is emitted when a channel fails to start.
|
||||
KindChannelLifecycleStartFailed Kind = "channel.lifecycle.start_failed"
|
||||
// KindChannelLifecycleStopped is emitted when a channel stops.
|
||||
KindChannelLifecycleStopped Kind = "channel.lifecycle.stopped"
|
||||
// KindChannelWebhookRegistered is emitted when a channel webhook is registered.
|
||||
KindChannelWebhookRegistered Kind = "channel.webhook.registered"
|
||||
// KindChannelWebhookUnregistered is emitted when a channel webhook is unregistered.
|
||||
KindChannelWebhookUnregistered Kind = "channel.webhook.unregistered"
|
||||
// KindChannelMessageOutboundQueued is emitted when an outbound message is queued.
|
||||
KindChannelMessageOutboundQueued Kind = "channel.message.outbound_queued"
|
||||
// KindChannelMessageOutboundSent is emitted when an outbound channel message is sent.
|
||||
KindChannelMessageOutboundSent Kind = "channel.message.outbound_sent"
|
||||
// KindChannelMessageOutboundFailed is emitted when an outbound channel message fails.
|
||||
KindChannelMessageOutboundFailed Kind = "channel.message.outbound_failed"
|
||||
// KindChannelRateLimited is emitted when channel rate limiting blocks delivery.
|
||||
KindChannelRateLimited Kind = "channel.rate_limited"
|
||||
|
||||
// KindBusPublishFailed is emitted when message bus publish fails.
|
||||
KindBusPublishFailed Kind = "bus.publish.failed"
|
||||
// KindBusCloseStarted is emitted when message bus close starts.
|
||||
KindBusCloseStarted Kind = "bus.close.started"
|
||||
// KindBusCloseCompleted is emitted when message bus close completes.
|
||||
KindBusCloseCompleted Kind = "bus.close.completed"
|
||||
// KindBusCloseDrained is emitted when message bus close drains buffered messages.
|
||||
KindBusCloseDrained Kind = "bus.close.drained"
|
||||
|
||||
// KindGatewayReloadStarted is emitted when gateway reload starts.
|
||||
KindGatewayReloadStarted Kind = "gateway.reload.started"
|
||||
@@ -63,6 +84,14 @@ const (
|
||||
|
||||
// KindMCPServerConnected is emitted when an MCP server connects.
|
||||
KindMCPServerConnected Kind = "mcp.server.connected"
|
||||
// KindMCPServerConnecting is emitted before connecting to an MCP server.
|
||||
KindMCPServerConnecting Kind = "mcp.server.connecting"
|
||||
// KindMCPServerFailed is emitted when an MCP server fails.
|
||||
KindMCPServerFailed Kind = "mcp.server.failed"
|
||||
// KindMCPToolDiscovered is emitted when an MCP tool is discovered.
|
||||
KindMCPToolDiscovered Kind = "mcp.tool.discovered"
|
||||
// KindMCPToolCallStart is emitted when an MCP tool call starts.
|
||||
KindMCPToolCallStart Kind = "mcp.tool.call.start"
|
||||
// KindMCPToolCallEnd is emitted when an MCP tool call ends.
|
||||
KindMCPToolCallEnd Kind = "mcp.tool.call.end"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
const gatewayEventPublishTimeout = 100 * time.Millisecond
|
||||
|
||||
type gatewayReloadPayload struct {
|
||||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func publishGatewayReloadEvent(
|
||||
al *agent.AgentLoop,
|
||||
kind runtimeevents.Kind,
|
||||
startedAt time.Time,
|
||||
err error,
|
||||
) {
|
||||
if al == nil || al.RuntimeEventBus() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
severity := runtimeevents.SeverityInfo
|
||||
payload := gatewayReloadPayload{}
|
||||
if !startedAt.IsZero() {
|
||||
payload.DurationMS = time.Since(startedAt).Milliseconds()
|
||||
}
|
||||
if err != nil {
|
||||
severity = runtimeevents.SeverityError
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gatewayEventPublishTimeout)
|
||||
defer cancel()
|
||||
al.RuntimeEventBus().Publish(ctx, runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "gateway"},
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
+20
-3
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/devices"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/health"
|
||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -197,6 +198,7 @@ func Run(debug bool, homePath, configPath string, allowEmptyStartup bool) (runEr
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
|
||||
msgBus.SetEventPublisher(agentLoop.RuntimeEventBus())
|
||||
|
||||
fmt.Println("\n📦 Agent Status:")
|
||||
startupInfo := agentLoop.GetStartupInfo()
|
||||
@@ -312,10 +314,20 @@ func executeReload(
|
||||
msgBus *bus.MessageBus,
|
||||
allowEmptyStartup bool,
|
||||
debug bool,
|
||||
) error {
|
||||
) (err error) {
|
||||
startedAt := time.Now()
|
||||
publishGatewayReloadEvent(agentLoop, runtimeevents.KindGatewayReloadStarted, startedAt, nil)
|
||||
defer runningServices.reloading.Store(false)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
publishGatewayReloadEvent(agentLoop, runtimeevents.KindGatewayReloadFailed, startedAt, err)
|
||||
return
|
||||
}
|
||||
publishGatewayReloadEvent(agentLoop, runtimeevents.KindGatewayReloadCompleted, startedAt, nil)
|
||||
}()
|
||||
|
||||
return handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
|
||||
err = handleConfigReload(ctx, agentLoop, newCfg, provider, runningServices, msgBus, allowEmptyStartup, debug)
|
||||
return err
|
||||
}
|
||||
|
||||
func createStartupProvider(
|
||||
@@ -383,7 +395,12 @@ func setupAndStartServices(
|
||||
fms.Start()
|
||||
}
|
||||
|
||||
runningServices.ChannelManager, err = channels.NewManager(cfg, msgBus, runningServices.MediaStore)
|
||||
runningServices.ChannelManager, err = channels.NewManager(
|
||||
cfg,
|
||||
msgBus,
|
||||
runningServices.MediaStore,
|
||||
channels.WithRuntimeEvents(agentLoop.RuntimeEventBus()),
|
||||
)
|
||||
if err != nil {
|
||||
if fms, ok := runningServices.MediaStore.(*media.FileMediaStore); ok {
|
||||
fms.Stop()
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
const mcpEventPublishTimeout = 100 * time.Millisecond
|
||||
|
||||
func (m *Manager) publishServerEvent(
|
||||
kind runtimeevents.Kind,
|
||||
serverName string,
|
||||
cfg config.MCPServerConfig,
|
||||
toolCount int,
|
||||
err error,
|
||||
) {
|
||||
if m == nil || m.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
|
||||
severity := runtimeevents.SeverityInfo
|
||||
if err != nil {
|
||||
severity = runtimeevents.SeverityError
|
||||
}
|
||||
payload := ServerEventPayload{
|
||||
Server: serverName,
|
||||
Type: mcpTransportType(cfg),
|
||||
URL: cfg.URL,
|
||||
Command: cfg.Command,
|
||||
ToolCount: toolCount,
|
||||
}
|
||||
if err != nil {
|
||||
payload.Error = err.Error()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpEventPublishTimeout)
|
||||
defer cancel()
|
||||
m.runtimeEvents.Publish(ctx, runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "mcp", Name: serverName},
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) publishToolDiscovered(serverName string, cfg config.MCPServerConfig, toolName string) {
|
||||
if m == nil || m.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
payload := ServerEventPayload{
|
||||
Server: serverName,
|
||||
Type: mcpTransportType(cfg),
|
||||
URL: cfg.URL,
|
||||
Command: cfg.Command,
|
||||
Tool: toolName,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), mcpEventPublishTimeout)
|
||||
defer cancel()
|
||||
m.runtimeEvents.Publish(ctx, runtimeevents.Event{
|
||||
Kind: runtimeevents.KindMCPToolDiscovered,
|
||||
Source: runtimeevents.Source{Component: "mcp", Name: serverName},
|
||||
Severity: runtimeevents.SeverityInfo,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
func mcpTransportType(cfg config.MCPServerConfig) string {
|
||||
if cfg.Type != "" {
|
||||
return cfg.Type
|
||||
}
|
||||
if cfg.URL != "" {
|
||||
return "sse"
|
||||
}
|
||||
if cfg.Command != "" {
|
||||
return "stdio"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
+46
-6
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -127,19 +128,47 @@ type ServerConnection struct {
|
||||
|
||||
// Manager manages multiple MCP server connections
|
||||
type Manager struct {
|
||||
servers map[string]*ServerConnection
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
servers map[string]*ServerConnection
|
||||
runtimeEvents runtimeevents.Bus
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool // changed from bool to atomic.Bool to avoid TOCTOU race
|
||||
wg sync.WaitGroup // tracks in-flight CallTool calls
|
||||
}
|
||||
|
||||
var connectServerFunc = connectServer
|
||||
|
||||
// ManagerOption configures an MCP manager.
|
||||
type ManagerOption func(*Manager)
|
||||
|
||||
// WithRuntimeEvents injects the runtime event bus used for MCP observations.
|
||||
func WithRuntimeEvents(eventBus runtimeevents.Bus) ManagerOption {
|
||||
return func(m *Manager) {
|
||||
m.runtimeEvents = eventBus
|
||||
}
|
||||
}
|
||||
|
||||
// ServerEventPayload describes MCP server connection events.
|
||||
type ServerEventPayload struct {
|
||||
Server string `json:"server"`
|
||||
Type string `json:"type,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Command string `json:"command,omitempty"`
|
||||
Tool string `json:"tool,omitempty"`
|
||||
ToolCount int `json:"tool_count,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// NewManager creates a new MCP manager
|
||||
func NewManager() *Manager {
|
||||
return &Manager{
|
||||
func NewManager(opts ...ManagerOption) *Manager {
|
||||
m := &Manager{
|
||||
servers: make(map[string]*ServerConnection),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(m)
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// LoadFromConfig loads MCP servers from configuration
|
||||
@@ -264,8 +293,10 @@ func (m *Manager) ConnectServer(
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) error {
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerConnecting, name, cfg, 0, nil)
|
||||
conn, err := connectServerFunc(ctx, name, cfg)
|
||||
if err != nil {
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -274,10 +305,19 @@ func (m *Manager) ConnectServer(
|
||||
|
||||
if m.closed.Load() {
|
||||
_ = conn.Session.Close()
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerFailed, name, cfg, 0, fmt.Errorf("manager is closed"))
|
||||
return fmt.Errorf("manager is closed")
|
||||
}
|
||||
|
||||
m.servers[name] = conn
|
||||
for _, tool := range conn.Tools {
|
||||
toolName := ""
|
||||
if tool != nil {
|
||||
toolName = tool.Name
|
||||
}
|
||||
m.publishToolDiscovered(name, cfg, toolName)
|
||||
}
|
||||
m.publishServerEvent(runtimeevents.KindMCPServerConnected, name, cfg, len(conn.Tools), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,11 +10,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
)
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
@@ -248,6 +250,87 @@ func TestNewManager_InitialState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectServerPublishesRuntimeEvents(t *testing.T) {
|
||||
originalConnectServerFunc := connectServerFunc
|
||||
t.Cleanup(func() {
|
||||
connectServerFunc = originalConnectServerFunc
|
||||
})
|
||||
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindMCPServerConnected,
|
||||
runtimeevents.KindMCPServerFailed,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "mcp-events", Buffer: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
connectServerFunc = func(
|
||||
_ context.Context,
|
||||
name string,
|
||||
cfg config.MCPServerConfig,
|
||||
) (*ServerConnection, error) {
|
||||
if name == "bad" {
|
||||
return nil, fmt.Errorf("connect failed")
|
||||
}
|
||||
return &ServerConnection{
|
||||
Name: name,
|
||||
Config: cfg,
|
||||
Tools: []*sdkmcp.Tool{{Name: "echo"}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
mgr := NewManager(WithRuntimeEvents(eventBus))
|
||||
err = mgr.ConnectServer(context.Background(), "good", config.MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: "echo",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConnectServer(good) error = %v", err)
|
||||
}
|
||||
connected := receiveMCPRuntimeEvent(t, eventsCh)
|
||||
if connected.Kind != runtimeevents.KindMCPServerConnected ||
|
||||
connected.Source.Name != "good" ||
|
||||
connected.Severity != runtimeevents.SeverityInfo {
|
||||
t.Fatalf("connected event = %+v", connected)
|
||||
}
|
||||
|
||||
err = mgr.ConnectServer(context.Background(), "bad", config.MCPServerConfig{
|
||||
Type: "stdio",
|
||||
Command: "echo",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected ConnectServer(bad) to fail")
|
||||
}
|
||||
failed := receiveMCPRuntimeEvent(t, eventsCh)
|
||||
if failed.Kind != runtimeevents.KindMCPServerFailed ||
|
||||
failed.Source.Name != "bad" ||
|
||||
failed.Severity != runtimeevents.SeverityError {
|
||||
t.Fatalf("failed event = %+v", failed)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveMCPRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFromMCPConfig_DisabledOrEmptyServers(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
|
||||
@@ -36,6 +37,16 @@ type MCPTool struct {
|
||||
mediaStore media.MediaStore
|
||||
workspace string
|
||||
maxInlineTextRunes int
|
||||
runtimeEvents runtimeevents.Bus
|
||||
}
|
||||
|
||||
// MCPToolCallPayload describes MCP tool execution runtime events.
|
||||
type MCPToolCallPayload struct {
|
||||
Server string `json:"server"`
|
||||
Tool string `json:"tool"`
|
||||
DurationMS int64 `json:"duration_ms,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
@@ -62,6 +73,11 @@ func (t *MCPTool) SetMaxInlineTextRunes(limit int) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetEventPublisher injects the runtime event bus used for MCP tool observations.
|
||||
func (t *MCPTool) SetEventPublisher(eventBus runtimeevents.Bus) {
|
||||
t.runtimeEvents = eventBus
|
||||
}
|
||||
|
||||
const maxMCPInlineTextRunes = 16 * 1024
|
||||
|
||||
// sanitizeIdentifierComponent normalizes a string so it can be safely used
|
||||
@@ -237,26 +253,74 @@ func (t *MCPTool) Parameters() map[string]any {
|
||||
|
||||
// Execute executes the MCP tool
|
||||
func (t *MCPTool) Execute(ctx context.Context, args map[string]any) *ToolResult {
|
||||
startedAt := time.Now()
|
||||
t.publishRuntimeEvent(ctx, runtimeevents.KindMCPToolCallStart, startedAt, false, "")
|
||||
|
||||
result, err := t.manager.CallTool(ctx, t.serverName, t.tool.Name, args)
|
||||
if err != nil {
|
||||
t.publishRuntimeEvent(ctx, runtimeevents.KindMCPToolCallEnd, startedAt, true, err.Error())
|
||||
return ErrorResult(fmt.Sprintf("MCP tool execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
nilErr := fmt.Errorf("MCP tool returned nil result without error")
|
||||
t.publishRuntimeEvent(ctx, runtimeevents.KindMCPToolCallEnd, startedAt, true, nilErr.Error())
|
||||
return ErrorResult("MCP tool execution failed: nil result").WithError(nilErr)
|
||||
}
|
||||
|
||||
// Handle error result from server
|
||||
if result.IsError {
|
||||
errMsg := extractContentText(result.Content)
|
||||
t.publishRuntimeEvent(ctx, runtimeevents.KindMCPToolCallEnd, startedAt, true, errMsg)
|
||||
return ErrorResult(fmt.Sprintf("MCP tool returned error: %s", errMsg)).
|
||||
WithError(fmt.Errorf("MCP tool error: %s", errMsg))
|
||||
}
|
||||
|
||||
t.publishRuntimeEvent(ctx, runtimeevents.KindMCPToolCallEnd, startedAt, false, "")
|
||||
return t.normalizeResultContent(ctx, result.Content)
|
||||
}
|
||||
|
||||
func (t *MCPTool) publishRuntimeEvent(
|
||||
ctx context.Context,
|
||||
kind runtimeevents.Kind,
|
||||
startedAt time.Time,
|
||||
isError bool,
|
||||
errMsg string,
|
||||
) {
|
||||
if t == nil || t.runtimeEvents == nil {
|
||||
return
|
||||
}
|
||||
|
||||
scope := runtimeevents.Scope{
|
||||
AgentID: toolshared.ToolAgentID(ctx),
|
||||
SessionKey: toolshared.ToolSessionKey(ctx),
|
||||
Channel: toolshared.ToolChannel(ctx),
|
||||
ChatID: toolshared.ToolChatID(ctx),
|
||||
MessageID: toolshared.ToolMessageID(ctx),
|
||||
}
|
||||
payload := MCPToolCallPayload{
|
||||
Server: t.serverName,
|
||||
Tool: t.tool.Name,
|
||||
DurationMS: time.Since(startedAt).Milliseconds(),
|
||||
IsError: isError,
|
||||
Error: errMsg,
|
||||
}
|
||||
severity := runtimeevents.SeverityInfo
|
||||
if isError {
|
||||
severity = runtimeevents.SeverityError
|
||||
}
|
||||
|
||||
publishCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
t.runtimeEvents.Publish(publishCtx, runtimeevents.Event{
|
||||
Kind: kind,
|
||||
Source: runtimeevents.Source{Component: "mcp", Name: t.serverName},
|
||||
Scope: scope,
|
||||
Severity: severity,
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
// extractContentText extracts text from MCP content array
|
||||
func extractContentText(content []mcp.Content) string {
|
||||
var parts []string
|
||||
|
||||
@@ -7,9 +7,11 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
toolshared "github.com/sipeed/picoclaw/pkg/tools/shared"
|
||||
)
|
||||
@@ -299,6 +301,72 @@ func TestMCPTool_Execute_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTool_Execute_PublishesRuntimeEvents(t *testing.T) {
|
||||
eventBus := runtimeevents.NewBus()
|
||||
defer func() {
|
||||
if err := eventBus.Close(); err != nil {
|
||||
t.Errorf("event bus close failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, eventsCh, err := eventBus.Channel().OfKind(
|
||||
runtimeevents.KindMCPToolCallStart,
|
||||
runtimeevents.KindMCPToolCallEnd,
|
||||
).SubscribeChan(t.Context(), runtimeevents.SubscribeOptions{Name: "mcp-tool-events", Buffer: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("SubscribeChan failed: %v", err)
|
||||
}
|
||||
|
||||
manager := &MockMCPManager{}
|
||||
mcpTool := NewMCPTool(manager, "github", &mcp.Tool{Name: "search_repos"})
|
||||
mcpTool.SetEventPublisher(eventBus)
|
||||
|
||||
ctx := toolshared.WithToolContext(context.Background(), "telegram", "chat-1")
|
||||
ctx = toolshared.WithToolMessageContext(ctx, "msg-1", "")
|
||||
ctx = toolshared.WithToolSessionContext(ctx, "main", "session-1", nil)
|
||||
result := mcpTool.Execute(ctx, map[string]any{"query": "picoclaw"})
|
||||
if result == nil || result.IsError {
|
||||
t.Fatalf("Execute result = %+v", result)
|
||||
}
|
||||
|
||||
started := receiveMCPToolRuntimeEvent(t, eventsCh)
|
||||
if started.Kind != runtimeevents.KindMCPToolCallStart ||
|
||||
started.Scope.AgentID != "main" ||
|
||||
started.Scope.SessionKey != "session-1" ||
|
||||
started.Scope.Channel != "telegram" ||
|
||||
started.Scope.ChatID != "chat-1" ||
|
||||
started.Scope.MessageID != "msg-1" {
|
||||
t.Fatalf("started event = %+v", started)
|
||||
}
|
||||
|
||||
ended := receiveMCPToolRuntimeEvent(t, eventsCh)
|
||||
if ended.Kind != runtimeevents.KindMCPToolCallEnd || ended.Severity != runtimeevents.SeverityInfo {
|
||||
t.Fatalf("ended event = %+v", ended)
|
||||
}
|
||||
payload, ok := ended.Payload.(MCPToolCallPayload)
|
||||
if !ok {
|
||||
t.Fatalf("ended payload = %T, want MCPToolCallPayload", ended.Payload)
|
||||
}
|
||||
if payload.Server != "github" || payload.Tool != "search_repos" || payload.IsError {
|
||||
t.Fatalf("ended payload = %+v", payload)
|
||||
}
|
||||
}
|
||||
|
||||
func receiveMCPToolRuntimeEvent(t *testing.T, ch <-chan runtimeevents.Event) runtimeevents.Event {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case evt, ok := <-ch:
|
||||
if !ok {
|
||||
t.Fatal("runtime event channel closed before expected event")
|
||||
}
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for runtime event")
|
||||
return runtimeevents.Event{}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPTool_Execute_ManagerError tests execution when manager returns error
|
||||
func TestMCPTool_Execute_ManagerError(t *testing.T) {
|
||||
manager := &MockMCPManager{
|
||||
|
||||
Reference in New Issue
Block a user