From 2095ec8700343935b2a296102d4f77fad38eb07a Mon Sep 17 00:00:00 2001 From: Hoshina Date: Wed, 1 Apr 2026 14:08:44 +0800 Subject: [PATCH] refactor(agent): route using inbound context --- pkg/agent/loop.go | 80 +++++++++++++++++++++++++++++++------ pkg/agent/loop_test.go | 89 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 13 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 84b783985..78b91068a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1372,13 +1372,18 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { registry := al.GetRegistry() + inboundCtx := normalizedInboundContext(msg) + channel := strings.TrimSpace(inboundCtx.Channel) + if channel == "" { + channel = msg.Channel + } route := registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - AccountID: inboundMetadata(msg, metadataKeyAccountID), + Channel: channel, + AccountID: routeAccountID(msg), Peer: extractPeer(msg), ParentPeer: extractParentPeer(msg), - GuildID: inboundMetadata(msg, metadataKeyGuildID), - TeamID: inboundMetadata(msg, metadataKeyTeamID), + GuildID: routeGuildID(msg), + TeamID: routeTeamID(msg), }) agent, ok := registry.GetAgent(route.AgentID) @@ -1392,6 +1397,10 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv return route, agent, nil } +func normalizedInboundContext(msg bus.InboundMessage) bus.InboundContext { + return bus.NormalizeInboundMessage(msg).Context +} + func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { return msgSessionKey @@ -3553,18 +3562,32 @@ func mapCommandError(result commands.ExecuteResult) string { // extractPeer extracts the routing peer from the inbound message's structured Peer field. func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - if msg.Peer.Kind == "" { + if msg.Peer.Kind != "" { + peerID := msg.Peer.ID + if peerID == "" { + if msg.Peer.Kind == "direct" { + peerID = msg.SenderID + } else { + peerID = msg.ChatID + } + } + return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} + } + + inboundCtx := normalizedInboundContext(msg) + peerKind := strings.TrimSpace(inboundCtx.ChatType) + if peerKind == "" { return nil } - peerID := msg.Peer.ID - if peerID == "" { - if msg.Peer.Kind == "direct" { - peerID = msg.SenderID - } else { - peerID = msg.ChatID - } + + peerID := strings.TrimSpace(inboundCtx.ChatID) + if peerKind == "direct" && peerID == "" { + peerID = strings.TrimSpace(inboundCtx.SenderID) } - return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} + if peerID == "" { + return nil + } + return &routing.RoutePeer{Kind: peerKind, ID: peerID} } func inboundMetadata(msg bus.InboundMessage, key string) string { @@ -3576,6 +3599,11 @@ func inboundMetadata(msg bus.InboundMessage, key string) string { // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { + inboundCtx := normalizedInboundContext(msg) + if topicID := strings.TrimSpace(inboundCtx.TopicID); topicID != "" { + return &routing.RoutePeer{Kind: "topic", ID: topicID} + } + parentKind := inboundMetadata(msg, metadataKeyParentPeerKind) parentID := inboundMetadata(msg, metadataKeyParentPeerID) if parentKind == "" || parentID == "" { @@ -3584,6 +3612,32 @@ func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { return &routing.RoutePeer{Kind: parentKind, ID: parentID} } +func routeAccountID(msg bus.InboundMessage) string { + if accountID := strings.TrimSpace(normalizedInboundContext(msg).Account); accountID != "" { + return accountID + } + return inboundMetadata(msg, metadataKeyAccountID) +} + +func routeGuildID(msg bus.InboundMessage) string { + inboundCtx := normalizedInboundContext(msg) + if strings.EqualFold(strings.TrimSpace(inboundCtx.SpaceType), "guild") { + return strings.TrimSpace(inboundCtx.SpaceID) + } + return inboundMetadata(msg, metadataKeyGuildID) +} + +func routeTeamID(msg bus.InboundMessage) string { + inboundCtx := normalizedInboundContext(msg) + switch strings.ToLower(strings.TrimSpace(inboundCtx.SpaceType)) { + case "team", "workspace": + if spaceID := strings.TrimSpace(inboundCtx.SpaceID); spaceID != "" { + return spaceID + } + } + return inboundMetadata(msg, metadataKeyTeamID) +} + // isNativeSearchProvider reports whether the given LLM provider implements // NativeSearchCapable and returns true for SupportsNativeSearch. func isNativeSearchProvider(p providers.LLMProvider) bool { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 9513d8aca..54235b23a 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -734,6 +734,95 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes } } +func TestExtractPeer_UsesInboundContextWhenLegacyPeerMissing(t *testing.T) { + msg := bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "slack", + ChatID: "C001", + ChatType: "channel", + SenderID: "U001", + }, + } + + peer := extractPeer(msg) + if peer == nil { + t.Fatal("expected peer from inbound context") + } + if peer.Kind != "channel" || peer.ID != "C001" { + t.Fatalf("peer = %+v, want channel/C001", peer) + } +} + +func TestExtractParentPeer_UsesInboundContextTopicID(t *testing.T) { + msg := bus.InboundMessage{ + Context: bus.InboundContext{ + TopicID: "thread-42", + }, + } + + parentPeer := extractParentPeer(msg) + if parentPeer == nil { + t.Fatal("expected parent peer from topic context") + } + if parentPeer.Kind != "topic" || parentPeer.ID != "thread-42" { + t.Fatalf("parent peer = %+v, want topic/thread-42", parentPeer) + } +} + +func TestResolveMessageRoute_UsesInboundContextAccountAndSpace(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + }, + List: []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "work"}, + }, + }, + Bindings: []config.AgentBinding{ + { + AgentID: "work", + Match: config.BindingMatch{ + Channel: "slack", + AccountID: "*", + TeamID: "T001", + }, + }, + }, + Session: config.SessionConfig{ + DMScope: "per-peer", + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"}) + + route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "slack", + Account: "workspace-a", + ChatID: "C123", + ChatType: "channel", + SenderID: "U123", + SpaceID: "T001", + SpaceType: "workspace", + }, + Content: "hello", + }) + if err != nil { + t.Fatalf("resolveMessageRoute() error = %v", err) + } + if route.AgentID != "work" { + t.Fatalf("AgentID = %q, want work", route.AgentID) + } + if route.MatchedBy != "binding.team" { + t.Fatalf("MatchedBy = %q, want binding.team", route.MatchedBy) + } +} + func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { tmpDir := t.TempDir() cfg := config.DefaultConfig()