From 59dee895fc906827df14f05fb36303a31686d080 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Wed, 1 Apr 2026 20:56:48 +0800 Subject: [PATCH] refactor(runtime): drop non-session legacy context compatibility --- pkg/agent/eventbus_test.go | 6 - pkg/agent/events.go | 4 - pkg/agent/hooks.go | 10 - pkg/agent/loop.go | 183 +++---------- pkg/agent/loop_test.go | 243 ++++++++---------- pkg/agent/registry.go | 7 +- pkg/agent/steering.go | 3 +- pkg/agent/steering_test.go | 62 +++-- pkg/bus/bus.go | 15 ++ pkg/bus/bus_test.go | 174 +++++++++---- pkg/bus/inbound_context.go | 216 +--------------- pkg/bus/outbound_context.go | 64 ++--- pkg/bus/types.go | 35 +-- pkg/channels/base.go | 46 +--- pkg/channels/base_test.go | 56 ++++ pkg/channels/dingtalk/dingtalk.go | 32 ++- pkg/channels/discord/discord.go | 6 +- pkg/channels/feishu/feishu_64.go | 35 ++- pkg/channels/irc/handler.go | 23 +- pkg/channels/line/line.go | 11 +- pkg/channels/maixcam/maixcam.go | 20 +- pkg/channels/manager.go | 66 +++-- pkg/channels/manager_test.go | 181 +++++++++---- pkg/channels/matrix/matrix.go | 26 +- pkg/channels/onebot/onebot.go | 6 +- pkg/channels/pico/client.go | 19 +- pkg/channels/pico/pico.go | 13 +- pkg/channels/qq/qq.go | 20 +- pkg/channels/slack/slack.go | 24 +- pkg/channels/telegram/telegram.go | 5 - pkg/channels/wecom/wecom.go | 3 +- pkg/channels/weixin/weixin.go | 18 +- pkg/channels/whatsapp/whatsapp.go | 22 +- .../whatsapp_native/whatsapp_native.go | 13 +- pkg/config/config.go | 6 +- pkg/devices/service.go | 3 +- pkg/heartbeat/service.go | 3 +- pkg/routing/route.go | 79 ++++-- pkg/routing/route_test.go | 73 +++--- pkg/routing/session_key.go | 218 ---------------- pkg/routing/session_key_test.go | 207 --------------- pkg/session/allocator.go | 41 +-- pkg/session/key.go | 135 +++++++++- pkg/session/key_test.go | 72 ++++++ pkg/tools/cron.go | 6 +- 45 files changed, 1083 insertions(+), 1427 deletions(-) delete mode 100644 pkg/routing/session_key.go delete mode 100644 pkg/routing/session_key_test.go create mode 100644 pkg/session/key_test.go diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go index 574d7bbcc..66046f87b 100644 --- a/pkg/agent/eventbus_test.go +++ b/pkg/agent/eventbus_test.go @@ -610,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) { if payload.SourceTool != "async_followup" { t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool) } - if payload.Channel != "cli" { - t.Fatalf("expected channel cli, got %q", payload.Channel) - } - if payload.ChatID != "direct" { - t.Fatalf("expected chat id direct, got %q", payload.ChatID) - } if payload.ContentLen != len("background result") { t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen) } diff --git a/pkg/agent/events.go b/pkg/agent/events.go index d17f5a90b..6741d0053 100644 --- a/pkg/agent/events.go +++ b/pkg/agent/events.go @@ -116,8 +116,6 @@ const ( // TurnStartPayload describes the start of a turn. type TurnStartPayload struct { - Channel string - ChatID string UserMessage string MediaCount int } @@ -217,8 +215,6 @@ type SteeringInjectedPayload struct { // FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus. type FollowUpQueuedPayload struct { SourceTool string - Channel string - ChatID string ContentLen int } diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go index c3c4b21ce..0e0c139ae 100644 --- a/pkg/agent/hooks.go +++ b/pkg/agent/hooks.go @@ -94,8 +94,6 @@ type LLMHookRequest struct { Messages []providers.Message `json:"messages,omitempty"` Tools []providers.ToolDefinition `json:"tools,omitempty"` Options map[string]any `json:"options,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` GracefulTerminal bool `json:"graceful_terminal,omitempty"` } @@ -117,8 +115,6 @@ type LLMHookResponse struct { Context *TurnContext `json:"context,omitempty"` Model string `json:"model"` Response *providers.LLMResponse `json:"response,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *LLMHookResponse) Clone() *LLMHookResponse { @@ -137,8 +133,6 @@ type ToolCallHookRequest struct { Context *TurnContext `json:"context,omitempty"` Tool string `json:"tool"` Arguments map[string]any `json:"arguments,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest { @@ -157,8 +151,6 @@ type ToolApprovalRequest struct { Context *TurnContext `json:"context,omitempty"` Tool string `json:"tool"` Arguments map[string]any `json:"arguments,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest { @@ -179,8 +171,6 @@ type ToolResultHookResponse struct { Arguments map[string]any `json:"arguments,omitempty"` Result *tools.ToolResult `json:"result,omitempty"` Duration time.Duration `json:"duration"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse { diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 70827598a..b12ad5b1d 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -107,14 +107,6 @@ const ( defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit." toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps." handledToolResponseSummary = "Requested output delivered via tool attachment." - sessionKeyAgentPrefix = "agent:" - sessionKeyOpaquePrefix = "sk_" - metadataKeyAccountID = "account_id" - metadataKeyGuildID = "guild_id" - metadataKeyTeamID = "team_id" - metadataKeyReplyToMessage = "reply_to_message_id" - metadataKeyParentPeerKind = "parent_peer_kind" - metadataKeyParentPeerID = "parent_peer_id" ) func NewAgentLoop( @@ -234,9 +226,9 @@ func registerSharedTools( messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() + outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID) return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: outboundCtx, Content: content, ReplyToMessageID: replyToMessageID, }) @@ -657,8 +649,7 @@ func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatI } al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, ""), Content: response, }) logger.InfoCF("agent", "Published outbound response", @@ -714,11 +705,7 @@ func outboundContextFromInbound( channel, chatID, replyToMessageID string, ) bus.InboundContext { if inbound == nil { - return bus.ContextFromLegacyOutbound(bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, - ReplyToMessageID: replyToMessageID, - }) + return bus.NewOutboundContext(channel, chatID, replyToMessageID) } outboundCtx := *cloneInboundContext(inbound) @@ -736,8 +723,6 @@ func outboundContextFromInbound( func outboundMessageForTurn(ts *turnState, content string) bus.OutboundMessage { return bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, Context: outboundContextFromInbound( ts.opts.InboundContext, ts.channel, @@ -894,8 +879,6 @@ func (al *AgentLoop) logEvent(evt Event) { switch payload := evt.Payload.(type) { case TurnStartPayload: - fields["channel"] = payload.Channel - fields["chat_id"] = payload.ChatID fields["user_len"] = len(payload.UserMessage) fields["media_count"] = payload.MediaCount case TurnEndPayload: @@ -948,8 +931,6 @@ func (al *AgentLoop) logEvent(evt Event) { fields["total_content_len"] = payload.TotalContentLen case FollowUpQueuedPayload: fields["source_tool"] = payload.SourceTool - fields["channel"] = payload.Channel - fields["chat_id"] = payload.ChatID fields["content_len"] = payload.ContentLen case InterruptReceivedPayload: fields["interrupt_kind"] = payload.Kind @@ -1292,8 +1273,7 @@ func (al *AgentLoop) sendTranscriptionFeedback( } err := al.channelManager.SendMessage(ctx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, messageID), Content: feedbackMsg, ReplyToMessageID: messageID, }) @@ -1369,13 +1349,15 @@ func (al *AgentLoop) ProcessDirectWithChannel( } msg := bus.InboundMessage{ - Channel: channel, - SenderID: "cron", - ChatID: chatID, + Context: bus.InboundContext{ + Channel: channel, + ChatID: chatID, + ChatType: "direct", + SenderID: "cron", + }, Content: content, SessionKey: sessionKey, } - msg.Context = bus.ContextFromLegacyInbound(msg) return al.processMessage(ctx, msg) } @@ -1481,7 +1463,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) Channel: msg.Channel, ChatID: msg.ChatID, MessageID: msg.MessageID, - ReplyToMessageID: inboundMetadata(msg, metadataKeyReplyToMessage), + ReplyToMessageID: msg.Context.ReplyToMessageID, SenderID: msg.SenderID, SenderDisplayName: msg.Sender.DisplayName, UserMessage: msg.Content, @@ -1515,18 +1497,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { registry := al.GetRegistry() inboundCtx := normalizedInboundContext(msg) - channel := strings.TrimSpace(inboundCtx.Channel) - if channel == "" { - channel = msg.Channel - } - route := registry.ResolveRoute(routing.RouteInput{ - Channel: channel, - AccountID: routeAccountID(msg), - Peer: extractPeer(msg), - ParentPeer: extractParentPeer(msg), - GuildID: routeGuildID(msg), - TeamID: routeTeamID(msg), - }) + route := registry.ResolveRoute(inboundCtx) agent, ok := registry.GetAgent(route.AgentID) if !ok { @@ -1551,8 +1522,7 @@ func resolveScopeKey(routeSessionKey, msgSessionKey string) string { } func isExplicitSessionKey(sessionKey string) bool { - sessionKey = strings.TrimSpace(strings.ToLower(sessionKey)) - return strings.HasPrefix(sessionKey, sessionKeyAgentPrefix) || strings.HasPrefix(sessionKey, sessionKeyOpaquePrefix) + return session.IsExplicitSessionKey(sessionKey) } func buildSessionAliases(canonicalKey string, keys ...string) []string { @@ -1621,8 +1591,7 @@ func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error { pubCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() return al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, + Context: msg.Context, Content: msg.Content, }) } @@ -1679,7 +1648,7 @@ func (al *AgentLoop) processSystemMessage( } // Use the origin session for context - sessionKey := routing.BuildAgentMainSessionKey(agent.ID) + sessionKey := session.BuildMainSessionKey(agent.ID) return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: sessionKey, @@ -1739,8 +1708,6 @@ func (al *AgentLoop) runAgentLoop( if opts.SendResponse && result.finalContent != "" { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, Context: outboundContextFromInbound( opts.InboundContext, opts.Channel, @@ -1796,8 +1763,7 @@ func (al *AgentLoop) handleReasoning( defer pubCancel() if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channelName, - ChatID: channelID, + Context: bus.NewOutboundContext(channelName, channelID, ""), Content: reasoningContent, }); err != nil { // Treat context.DeadlineExceeded / context.Canceled as expected @@ -1851,8 +1817,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er EventKindTurnStart, ts.eventMeta("runTurn", "turn.start"), TurnStartPayload{ - Channel: ts.channel, - ChatID: ts.chatID, UserMessage: ts.userMessage, MediaCount: len(ts.media), }, @@ -2085,8 +2049,6 @@ turnLoop: Messages: callMessages, Tools: providerToolDefs, Options: llmOpts, - Channel: ts.channel, - ChatID: ts.chatID, GracefulTerminal: gracefulTerminal, }) switch decision.normalizedAction() { @@ -2314,8 +2276,6 @@ turnLoop: Context: cloneTurnContext(ts.turnCtx), Model: llmModel, Response: response, - Channel: ts.channel, - ChatID: ts.chatID, }) switch decision.normalizedAction() { case HookActionContinue, HookActionModify: @@ -2346,7 +2306,7 @@ turnLoop: reasoningContent = response.ReasoningContent } go al.handleReasoning( - turnCtx, + ctx, reasoningContent, ts.channel, al.targetReasoningChannelID(ts.channel), @@ -2467,8 +2427,6 @@ turnLoop: Context: cloneTurnContext(ts.turnCtx), Tool: toolName, Arguments: toolArgs, - Channel: ts.channel, - ChatID: ts.chatID, }) switch decision.normalizedAction() { case HookActionContinue, HookActionModify: @@ -2514,8 +2472,6 @@ turnLoop: Context: cloneTurnContext(ts.turnCtx), Tool: toolName, Arguments: toolArgs, - Channel: ts.channel, - ChatID: ts.chatID, }) if !approval.Approved { allResponsesHandled = false @@ -2605,8 +2561,6 @@ turnLoop: ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"), FollowUpQueuedPayload{ SourceTool: asyncToolName, - Channel: ts.channel, - ChatID: ts.chatID, ContentLen: len(content), }, ) @@ -2614,10 +2568,13 @@ turnLoop: pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ - Channel: "system", - SenderID: fmt.Sprintf("async:%s", asyncToolName), - ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), - Content: content, + Context: bus.InboundContext{ + Channel: "system", + ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), + ChatType: "direct", + SenderID: fmt.Sprintf("async:%s", asyncToolName), + }, + Content: content, }) } @@ -2652,8 +2609,6 @@ turnLoop: Arguments: toolArgs, Result: toolResult, Duration: toolDuration, - Channel: ts.channel, - ChatID: ts.chatID, }) switch decision.normalizedAction() { case HookActionContinue, HookActionModify: @@ -2692,9 +2647,13 @@ turnLoop: parts = append(parts, part) } outboundMedia := bus.OutboundMediaMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Parts: parts, + Context: outboundContextFromInbound( + ts.opts.InboundContext, + ts.channel, + ts.chatID, + ts.opts.ReplyToMessageID, + ), + Parts: parts, } if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) { if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { @@ -3758,84 +3717,6 @@ func mapCommandError(result commands.ExecuteResult) string { return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err) } -// extractPeer extracts the routing peer from the inbound message's structured Peer field. -func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - if msg.Peer.Kind != "" { - peerID := msg.Peer.ID - if peerID == "" { - if msg.Peer.Kind == "direct" { - peerID = msg.SenderID - } else { - peerID = msg.ChatID - } - } - return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} - } - - inboundCtx := normalizedInboundContext(msg) - peerKind := strings.TrimSpace(inboundCtx.ChatType) - if peerKind == "" { - return nil - } - - peerID := strings.TrimSpace(inboundCtx.ChatID) - if peerKind == "direct" && peerID == "" { - peerID = strings.TrimSpace(inboundCtx.SenderID) - } - if peerID == "" { - return nil - } - return &routing.RoutePeer{Kind: peerKind, ID: peerID} -} - -func inboundMetadata(msg bus.InboundMessage, key string) string { - if msg.Metadata == nil { - return "" - } - return msg.Metadata[key] -} - -// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. -func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { - inboundCtx := normalizedInboundContext(msg) - if topicID := strings.TrimSpace(inboundCtx.TopicID); topicID != "" { - return &routing.RoutePeer{Kind: "topic", ID: topicID} - } - - parentKind := inboundMetadata(msg, metadataKeyParentPeerKind) - parentID := inboundMetadata(msg, metadataKeyParentPeerID) - if parentKind == "" || parentID == "" { - return nil - } - return &routing.RoutePeer{Kind: parentKind, ID: parentID} -} - -func routeAccountID(msg bus.InboundMessage) string { - if accountID := strings.TrimSpace(normalizedInboundContext(msg).Account); accountID != "" { - return accountID - } - return inboundMetadata(msg, metadataKeyAccountID) -} - -func routeGuildID(msg bus.InboundMessage) string { - inboundCtx := normalizedInboundContext(msg) - if strings.EqualFold(strings.TrimSpace(inboundCtx.SpaceType), "guild") { - return strings.TrimSpace(inboundCtx.SpaceID) - } - return inboundMetadata(msg, metadataKeyGuildID) -} - -func routeTeamID(msg bus.InboundMessage) string { - inboundCtx := normalizedInboundContext(msg) - switch strings.ToLower(strings.TrimSpace(inboundCtx.SpaceType)) { - case "team", "workspace": - if spaceID := strings.TrimSpace(inboundCtx.SpaceID); spaceID != "" { - return spaceID - } - } - return inboundMetadata(msg, metadataKeyTeamID) -} - // isNativeSearchProvider reports whether the given LLM provider implements // NativeSearchCapable and returns true for SupportsNativeSearch. func isNativeSearchProvider(p providers.LLMProvider) bool { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 3efb7ddfd..4aa356f88 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -140,7 +140,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) { provider := &recordingProvider{} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "discord", SenderID: "discord:123", Sender: bus.SenderInfo{ @@ -148,7 +148,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) { }, ChatID: "group-1", Content: "hello", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -199,12 +199,12 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) { provider := &recordingProvider{} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "telegram:123", ChatID: "chat-1", Content: "/use shell explain how to list files", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -289,12 +289,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) { provider := &recordingProvider{} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "telegram:123", ChatID: "chat-1", Content: "/use shell", - }) + })) if err != nil { t.Fatalf("processMessage() arm error = %v", err) } @@ -302,12 +302,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) { t.Fatalf("arm response = %q, want armed confirmation", response) } - response, err = al.processMessage(context.Background(), bus.InboundMessage{ + response, err = al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "telegram:123", ChatID: "chat-1", Content: "explain how to list files", - }) + })) if err != nil { t.Fatalf("processMessage() follow-up error = %v", err) } @@ -620,12 +620,12 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. path: imagePath, }) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -662,21 +662,21 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. if defaultAgent == nil { t.Fatal("expected default agent") } - route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("resolveMessageRoute() error = %v", err) } - sessionKey := resolveScopeKey(al.allocateRouteSession(route, bus.InboundMessage{ + sessionKey := resolveScopeKey(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }).SessionKey, "") + })).SessionKey, "") history := defaultAgent.Sessions.GetHistory(sessionKey) if len(history) == 0 { t.Fatal("expected session history to be saved") @@ -720,12 +720,12 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes loop: al, }) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -740,41 +740,6 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes } } -func TestExtractPeer_UsesInboundContextWhenLegacyPeerMissing(t *testing.T) { - msg := bus.InboundMessage{ - Context: bus.InboundContext{ - Channel: "slack", - ChatID: "C001", - ChatType: "channel", - SenderID: "U001", - }, - } - - peer := extractPeer(msg) - if peer == nil { - t.Fatal("expected peer from inbound context") - } - if peer.Kind != "channel" || peer.ID != "C001" { - t.Fatalf("peer = %+v, want channel/C001", peer) - } -} - -func TestExtractParentPeer_UsesInboundContextTopicID(t *testing.T) { - msg := bus.InboundMessage{ - Context: bus.InboundContext{ - TopicID: "thread-42", - }, - } - - parentPeer := extractParentPeer(msg) - if parentPeer == nil { - t.Fatal("expected parent peer from topic context") - } - if parentPeer.Kind != "topic" || parentPeer.ID != "thread-42" { - t.Fatalf("parent peer = %+v, want topic/thread-42", parentPeer) - } -} - func TestAppendEventContextFields_IncludesInboundRouteAndScope(t *testing.T) { fields := map[string]any{} @@ -872,7 +837,7 @@ func TestResolveMessageRoute_UsesInboundContextAccountAndSpace(t *testing.T) { msgBus := bus.NewMessageBus() al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"}) - route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{ Context: bus.InboundContext{ Channel: "slack", Account: "workspace-a", @@ -883,7 +848,7 @@ func TestResolveMessageRoute_UsesInboundContextAccountAndSpace(t *testing.T) { SpaceType: "workspace", }, Content: "hello", - }) + })) if err != nil { t.Fatalf("resolveMessageRoute() error = %v", err) } @@ -926,12 +891,12 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { path: imagePath, }) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -1518,13 +1483,39 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) defer cancel() - response, err := h.al.processMessage(timeoutCtx, msg) + response, err := h.al.processMessage(timeoutCtx, testInboundMessage(msg)) if err != nil { tb.Fatalf("processMessage failed: %v", err) } return response } +func testInboundMessage(msg bus.InboundMessage) bus.InboundMessage { + if msg.Context.Channel == "" && + msg.Context.Account == "" && + msg.Context.ChatID == "" && + msg.Context.ChatType == "" && + msg.Context.TopicID == "" && + msg.Context.SpaceID == "" && + msg.Context.SpaceType == "" && + msg.Context.SenderID == "" && + msg.Context.MessageID == "" && + !msg.Context.Mentioned && + msg.Context.ReplyToMessageID == "" && + msg.Context.ReplyToSenderID == "" && + len(msg.Context.ReplyHandles) == 0 && + len(msg.Context.Raw) == 0 { + msg.Context = bus.InboundContext{ + Channel: msg.Channel, + ChatID: msg.ChatID, + ChatType: "direct", + SenderID: msg.SenderID, + MessageID: msg.MessageID, + } + } + return bus.NormalizeInboundMessage(msg) +} + const responseTimeout = 3 * time.Second func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { @@ -1550,20 +1541,16 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) msg := bus.InboundMessage{ - Channel: "telegram", - SenderID: "user1", - ChatID: "chat1", - Content: "hello", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "hello", } - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - Peer: extractPeer(msg), - }) + route := al.registry.ResolveRoute(bus.NormalizeInboundMessage(msg).Context) sessionKey := al.allocateRouteSession(route, msg).SessionKey defaultAgent := al.registry.GetDefaultAgent() @@ -1610,21 +1597,22 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { helper := testHelper{al: al} baseMsg := bus.InboundMessage{ - Channel: "whatsapp", - SenderID: "user1", - ChatID: "chat1", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "whatsapp", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, } showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ - Channel: baseMsg.Channel, - SenderID: baseMsg.SenderID, - ChatID: baseMsg.ChatID, - Content: "/show channel", - Peer: baseMsg.Peer, + Context: bus.InboundContext{ + Channel: baseMsg.Context.Channel, + ChatID: baseMsg.Context.ChatID, + ChatType: baseMsg.Context.ChatType, + SenderID: baseMsg.Context.SenderID, + }, + Content: "/show channel", }) if showResp != "Current Channel: whatsapp" { t.Fatalf("unexpected /show reply: %q", showResp) @@ -1634,11 +1622,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { } fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ - Channel: baseMsg.Channel, - SenderID: baseMsg.SenderID, - ChatID: baseMsg.ChatID, - Content: "/foo", - Peer: baseMsg.Peer, + Context: bus.InboundContext{ + Channel: baseMsg.Context.Channel, + ChatID: baseMsg.Context.ChatID, + ChatType: baseMsg.Context.ChatType, + SenderID: baseMsg.Context.SenderID, + }, + Content: "/foo", }) if fooResp != "LLM reply" { t.Fatalf("unexpected /foo reply: %q", fooResp) @@ -1648,11 +1638,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { } newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ - Channel: baseMsg.Channel, - SenderID: baseMsg.SenderID, - ChatID: baseMsg.ChatID, - Content: "/new", - Peer: baseMsg.Peer, + Context: bus.InboundContext{ + Channel: baseMsg.Context.Channel, + ChatID: baseMsg.Context.ChatID, + ChatType: baseMsg.Context.ChatType, + SenderID: baseMsg.Context.SenderID, + }, + Content: "/new", }) if newResp != "LLM reply" { t.Fatalf("unexpected /new reply: %q", newResp) @@ -1705,10 +1697,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/switch model to deepseek", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(switchResp, "Switched model from local to deepseek") { t.Fatalf("unexpected /switch reply: %q", switchResp) @@ -1719,10 +1707,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/show model", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") { t.Fatalf("unexpected /show model reply after switch: %q", showResp) @@ -1770,10 +1754,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/switch model to missing", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if switchResp != `model "missing" not found in model_list or providers` { t.Fatalf("unexpected /switch error reply: %q", switchResp) @@ -1784,10 +1764,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/show model", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(showResp, "Current Model: local (Provider: openai)") { t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp) @@ -1854,10 +1830,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t SenderID: "user1", ChatID: "chat1", Content: "hello before switch", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if firstResp != "local reply" { t.Fatalf("unexpected response before switch: %q", firstResp) @@ -1877,10 +1849,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t SenderID: "user1", ChatID: "chat1", Content: "/switch model to deepseek", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(switchResp, "Switched model from local to deepseek") { t.Fatalf("unexpected /switch reply: %q", switchResp) @@ -1891,10 +1859,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t SenderID: "user1", ChatID: "chat1", Content: "hello after switch", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if secondResp != "remote reply" { t.Fatalf("unexpected response after switch: %q", secondResp) @@ -1984,10 +1948,6 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "hi", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if resp != "light reply" { t.Fatalf("response = %q, want %q", resp, "light reply") @@ -2260,22 +2220,16 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) { if defaultAgent == nil { t.Fatal("No default agent found") } - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: "test", - Peer: &routing.RoutePeer{ - Kind: "direct", - ID: "cron", - }, + route := al.registry.ResolveRoute(bus.InboundContext{ + Channel: "test", + ChatType: "direct", + SenderID: "cron", }) - history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, bus.InboundMessage{ + history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{ Channel: "test", SenderID: "cron", ChatID: "chat1", - Peer: bus.Peer{ - Kind: "direct", - ID: "cron", - }, - }).SessionKey) + })).SessionKey) if len(history) != 4 { t.Fatalf("history len = %d, want 4", len(history)) } @@ -2533,8 +2487,7 @@ func TestHandleReasoning(t *testing.T) { for i := 0; ; i++ { fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{ - Channel: "filler", - ChatID: "filler", + Context: bus.NewOutboundContext("filler", "filler", ""), Content: fmt.Sprintf("filler-%d", i), }) fillCancel() @@ -2608,12 +2561,12 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T chManager.RegisterChannel("telegram", &fakeChannel{id: "reason-chat"}) al.SetChannelManager(chManager) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "user1", ChatID: "chat1", Content: "hello", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -2629,6 +2582,9 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T if outbound.ChatID != "reason-chat" { t.Fatalf("reasoning chatID = %q, want %q", outbound.ChatID, "reason-chat") } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "reason-chat" { + t.Fatalf("unexpected reasoning context: %+v", outbound.Context) + } if outbound.Content != "thinking trace" { t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace") } @@ -2714,12 +2670,12 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { provider := &toolFeedbackProvider{filePath: heartbeatFile} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "user-1", ChatID: "chat-1", Content: "check tool feedback", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -2735,6 +2691,9 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { if outbound.ChatID != "chat-1" { t.Fatalf("tool feedback chatID = %q, want %q", outbound.ChatID, "chat-1") } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" { + t.Fatalf("unexpected tool feedback context: %+v", outbound.Context) + } if !strings.Contains(outbound.Content, "`read_file`") { t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content) } @@ -3157,13 +3116,13 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) { agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"}) } - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "test", ChatID: "chat1", SenderID: "user1", SessionKey: "test-session", Content: "trigger recovery", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -3199,12 +3158,12 @@ func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) { return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil } - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "test", ChatID: "chat1", SenderID: "user1", Content: "hello", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go index 58b7ce440..8aa11e37b 100644 --- a/pkg/agent/registry.go +++ b/pkg/agent/registry.go @@ -3,6 +3,7 @@ package agent import ( "sync" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" @@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) { return agent, ok } -// ResolveRoute determines which agent handles the message. -func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute { - return r.resolver.ResolveRoute(input) +// ResolveRoute determines which agent handles the normalized inbound context. +func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute { + return r.resolver.ResolveRoute(inbound) } // ListAgentIDs returns all registered agent IDs. diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index b5cf049b3..f72e761f4 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -8,7 +8,6 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" - "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -332,7 +331,7 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance { return agent } - if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil { + if parsed := session.ParseLegacyAgentSessionKey(sessionKey); parsed != nil { if agent, ok := registry.GetAgent(parsed.AgentID); ok { return agent } diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index b67ec006c..9ecd8472a 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -366,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { al := NewAgentLoop(cfg, msgBus, &mockProvider{}) activeMsg := bus.InboundMessage{ - Channel: "telegram", - SenderID: "user1", - ChatID: "chat1", - Content: "active turn", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "active turn", } activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg) if !ok { @@ -381,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { } otherMsg := bus.InboundMessage{ - Channel: "telegram", - SenderID: "user2", - ChatID: "chat2", - Content: "other session", - Peer: bus.Peer{ - Kind: "direct", - ID: "user2", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat2", + ChatType: "direct", + SenderID: "user2", }, + Content: "other session", } otherScope, _, ok := al.resolveSteeringTarget(otherMsg) if !ok { @@ -425,7 +423,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { case <-ctx.Done(): t.Fatalf("timeout waiting for requeued message on outbound bus") case requeued := <-msgBus.OutboundChan(): - if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID || + if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID || requeued.Content != otherMsg.Content { t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg) } @@ -842,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { }() first := bus.InboundMessage{ - Channel: "test", - SenderID: "user1", - ChatID: "chat1", - Content: "first message", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "first message", } late := bus.InboundMessage{ - Channel: "test", - SenderID: "user1", - ChatID: "chat1", - Content: "late append", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "late append", } pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -950,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing. }, } - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) provider := &blockingDirectProvider{ firstStarted: make(chan struct{}), releaseFirst: make(chan struct{}), @@ -1117,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) { }, } - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) msgBus := bus.NewMessageBus() al := NewAgentLoop(cfg, msgBus, provider) al.SetMediaStore(store) @@ -1225,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) al.RegisterTool(tool1) al.RegisterTool(tool2) - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) sub := al.SubscribeEvents(32) defer al.UnsubscribeEvents(sub.ID) @@ -1379,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) started := make(chan struct{}) al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started}) - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) defaultAgent := al.registry.GetDefaultAgent() if defaultAgent == nil { diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 3e7ec9cdc..45e755673 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -12,6 +12,12 @@ import ( // ErrBusClosed is returned when publishing to a closed MessageBus. var ErrBusClosed = errors.New("message bus closed") +var ( + ErrMissingInboundContext = errors.New("inbound message context is required") + ErrMissingOutboundContext = errors.New("outbound message context is required") + ErrMissingOutboundMediaContext = errors.New("outbound media context is required") +) + const defaultBusBufferSize = 64 // StreamDelegate is implemented by the channel Manager to provide streaming @@ -80,6 +86,9 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error } func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + if msg.Context.isZero() { + return ErrMissingInboundContext + } msg = NormalizeInboundMessage(msg) return publish(ctx, mb, mb.inbound, msg) } @@ -89,6 +98,9 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage { } func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + if msg.Context.isZero() { + return ErrMissingOutboundContext + } msg = NormalizeOutboundMessage(msg) return publish(ctx, mb, mb.outbound, msg) } @@ -98,6 +110,9 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage { } func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { + if msg.Context.isZero() { + return ErrMissingOutboundMediaContext + } msg = NormalizeOutboundMediaMessage(msg) return publish(ctx, mb, mb.outboundMedia, msg) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go index 087c0a65e..18d1d1df8 100644 --- a/pkg/bus/bus_test.go +++ b/pkg/bus/bus_test.go @@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) { ctx := context.Background() msg := InboundMessage{ - Channel: "test", - SenderID: "user1", - ChatID: "chat1", - Content: "hello", + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + Content: "hello", } if err := mb.PublishInbound(ctx, msg); err != nil { @@ -45,25 +48,25 @@ func TestPublishConsume(t *testing.T) { } } -func TestPublishInbound_NormalizesLegacyFieldsIntoContext(t *testing.T) { +func TestPublishInbound_NormalizesContext(t *testing.T) { mb := NewMessageBus() defer mb.Close() msg := InboundMessage{ - Channel: "slack", - SenderID: "U123", - ChatID: "C456/1712", - Content: "hello", - MessageID: "1712.01", - Peer: Peer{Kind: "group", ID: "C456"}, - Metadata: map[string]string{ - "account_id": "workspace-a", - "team_id": "T001", - "reply_to_message_id": "1700.01", - "is_mentioned": "true", - "parent_peer_kind": "topic", - "parent_peer_id": "1712", + Context: InboundContext{ + Channel: "slack", + Account: "workspace-a", + ChatID: "C456/1712", + ChatType: "group", + TopicID: "1712", + SpaceID: "T001", + SpaceType: "team", + SenderID: "U123", + MessageID: "1712.01", + ReplyToMessageID: "1700.01", + Mentioned: true, }, + Content: "hello", } if err := mb.PublishInbound(context.Background(), msg); err != nil { @@ -94,7 +97,7 @@ func TestPublishInbound_NormalizesLegacyFieldsIntoContext(t *testing.T) { } } -func TestPublishInbound_MirrorsContextIntoLegacyFields(t *testing.T) { +func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) { mb := NewMessageBus() defer mb.Close() @@ -132,27 +135,8 @@ func TestPublishInbound_MirrorsContextIntoLegacyFields(t *testing.T) { if got.MessageID != "777" { t.Fatalf("expected legacy message ID 777, got %q", got.MessageID) } - if got.Peer.Kind != "group" || got.Peer.ID != "-1001" { - t.Fatalf("expected legacy peer group/-1001, got %q/%q", got.Peer.Kind, got.Peer.ID) - } - if got.Metadata["account_id"] != "bot-a" { - t.Fatalf("expected mirrored account_id bot-a, got %q", got.Metadata["account_id"]) - } - if got.Metadata["guild_id"] != "guild-9" { - t.Fatalf("expected mirrored guild_id guild-9, got %q", got.Metadata["guild_id"]) - } - if got.Metadata["parent_peer_kind"] != "topic" || got.Metadata["parent_peer_id"] != "42" { - t.Fatalf( - "expected mirrored topic parent peer, got %q/%q", - got.Metadata["parent_peer_kind"], - got.Metadata["parent_peer_id"], - ) - } - if got.Metadata["reply_to_message_id"] != "666" { - t.Fatalf("expected mirrored reply_to_message_id 666, got %q", got.Metadata["reply_to_message_id"]) - } - if got.Metadata["is_mentioned"] != "true" { - t.Fatalf("expected mirrored is_mentioned true, got %q", got.Metadata["is_mentioned"]) + if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" { + t.Fatalf("unexpected normalized context: %+v", got.Context) } } @@ -163,8 +147,10 @@ func TestPublishOutboundSubscribe(t *testing.T) { ctx := context.Background() msg := OutboundMessage{ - Channel: "telegram", - ChatID: "123", + Context: InboundContext{ + Channel: "telegram", + ChatID: "123", + }, Content: "world", } @@ -179,6 +165,9 @@ func TestPublishOutboundSubscribe(t *testing.T) { if got.Content != "world" { t.Fatalf("expected content 'world', got %q", got.Content) } + if got.Context.Channel != "telegram" || got.Context.ChatID != "123" { + t.Fatalf("expected normalized outbound context, got %+v", got.Context) + } } func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) { @@ -241,6 +230,19 @@ func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) { } } +func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) { + ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ") + if ctx.Channel != "telegram" { + t.Fatalf("expected channel telegram, got %q", ctx.Channel) + } + if ctx.ChatID != "chat-42" { + t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID) + } + if ctx.ReplyToMessageID != "msg-9" { + t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID) + } +} + func TestPublishInbound_ContextCancel(t *testing.T) { mb := NewMessageBus() defer mb.Close() @@ -248,7 +250,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) { // Fill the buffer ctx := context.Background() for i := range defaultBusBufferSize { - if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + if err := mb.PublishInbound(ctx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-fill", + ChatType: "direct", + SenderID: "user-fill", + }, + Content: "fill", + }); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } @@ -257,7 +267,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) { cancelCtx, cancel := context.WithCancel(context.Background()) cancel() - err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + err := mb.PublishInbound(cancelCtx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-overflow", + ChatType: "direct", + SenderID: "user-overflow", + }, + Content: "overflow", + }) if err == nil { t.Fatal("expected error from canceled context, got nil") } @@ -270,7 +288,15 @@ func TestPublishInbound_BusClosed(t *testing.T) { mb := NewMessageBus() mb.Close() - err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + err := mb.PublishInbound(context.Background(), InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + Content: "test", + }) if err != ErrBusClosed { t.Fatalf("expected ErrBusClosed, got %v", err) } @@ -280,7 +306,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) { mb := NewMessageBus() mb.Close() - err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + err := mb.PublishOutbound(context.Background(), OutboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + }, + Content: "test", + }) if err != ErrBusClosed { t.Fatalf("expected ErrBusClosed, got %v", err) } @@ -292,14 +324,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) { defer mb.Close() for i := range defaultBusBufferSize { - if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil { + if err := mb.PublishInbound(context.Background(), InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-fill", + ChatType: "direct", + SenderID: "user-fill", + }, + Content: "fill", + }); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"}) + mb.PublishInbound(ctx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-cancel", + ChatType: "direct", + SenderID: "user-cancel", + }, + Content: "ContextCancel", + }) select { case <-ctx.Done(): @@ -393,7 +441,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) { // Fill the buffer for i := range defaultBusBufferSize { - if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + if err := mb.PublishInbound(ctx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-fill", + ChatType: "direct", + SenderID: "user-fill", + }, + Content: "fill", + }); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } @@ -402,7 +458,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) { timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + err := mb.PublishInbound(timeoutCtx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-overflow", + ChatType: "direct", + SenderID: "user-overflow", + }, + Content: "overflow", + }) if err == nil { t.Fatal("expected error when buffer is full and context times out") } @@ -420,7 +484,15 @@ func TestCloseIdempotent(t *testing.T) { mb.Close() // After close, publish should return ErrBusClosed - err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + err := mb.PublishInbound(context.Background(), InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + Content: "test", + }) if err != ErrBusClosed { t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) } diff --git a/pkg/bus/inbound_context.go b/pkg/bus/inbound_context.go index 501f27be4..3a19ac957 100644 --- a/pkg/bus/inbound_context.go +++ b/pkg/bus/inbound_context.go @@ -2,92 +2,19 @@ package bus import "strings" -const ( - metadataKeyAccountID = "account_id" - metadataKeyGuildID = "guild_id" - metadataKeyTeamID = "team_id" - metadataKeyReplyToMessage = "reply_to_message_id" - metadataKeyReplyToSender = "reply_to_sender_id" - metadataKeyParentPeerKind = "parent_peer_kind" - metadataKeyParentPeerID = "parent_peer_id" - metadataKeyIsMentioned = "is_mentioned" -) - -// ContextFromLegacyInbound builds a normalized inbound context from the legacy -// top-level fields on InboundMessage. This keeps older producers working while -// new producers migrate to writing Context directly. -func ContextFromLegacyInbound(msg InboundMessage) InboundContext { - ctx := InboundContext{ - Channel: strings.TrimSpace(msg.Channel), - ChatID: strings.TrimSpace(msg.ChatID), - ChatType: normalizeKind(msg.Peer.Kind), - SenderID: firstNonEmpty( - strings.TrimSpace(msg.SenderID), - strings.TrimSpace(msg.Sender.CanonicalID), - strings.TrimSpace(msg.Sender.PlatformID), - ), - MessageID: strings.TrimSpace(msg.MessageID), - Raw: cloneStringMap(msg.Metadata), - } - - if account := metadataValue(msg.Metadata, metadataKeyAccountID); account != "" { - ctx.Account = account - } - if replyToMsgID := metadataValue(msg.Metadata, metadataKeyReplyToMessage); replyToMsgID != "" { - ctx.ReplyToMessageID = replyToMsgID - } - if replyToSenderID := metadataValue(msg.Metadata, metadataKeyReplyToSender); replyToSenderID != "" { - ctx.ReplyToSenderID = replyToSenderID - } - if isTruthy(metadataValue(msg.Metadata, metadataKeyIsMentioned)) { - ctx.Mentioned = true - } - - parentKind := normalizeKind(metadataValue(msg.Metadata, metadataKeyParentPeerKind)) - parentID := metadataValue(msg.Metadata, metadataKeyParentPeerID) - if parentKind == "topic" && parentID != "" { - ctx.TopicID = parentID - } - - switch { - case metadataValue(msg.Metadata, metadataKeyGuildID) != "": - ctx.SpaceType = "guild" - ctx.SpaceID = metadataValue(msg.Metadata, metadataKeyGuildID) - case metadataValue(msg.Metadata, metadataKeyTeamID) != "": - ctx.SpaceType = "team" - ctx.SpaceID = metadataValue(msg.Metadata, metadataKeyTeamID) - } - - return normalizeInboundContext(ctx) -} - -// NormalizeInboundMessage ensures the normalized Context is present and mirrors -// missing legacy fields from it so older consumers continue to work during the -// migration period. +// NormalizeInboundMessage ensures the inbound context is normalized and keeps +// convenience mirrors in sync for runtime consumers. func NormalizeInboundMessage(msg InboundMessage) InboundMessage { - if msg.Context.isZero() { - msg.Context = ContextFromLegacyInbound(msg) - } else { - msg.Context = normalizeInboundContext(msg.Context) - } - - if msg.Channel == "" { - msg.Channel = msg.Context.Channel - } - if msg.SenderID == "" { - msg.SenderID = msg.Context.SenderID - } - if msg.ChatID == "" { - msg.ChatID = msg.Context.ChatID - } + msg.Context = normalizeInboundContext(msg.Context) + msg.Channel = msg.Context.Channel + msg.SenderID = msg.Context.SenderID + msg.ChatID = msg.Context.ChatID if msg.MessageID == "" { msg.MessageID = msg.Context.MessageID } - if msg.Peer.Kind == "" { - msg.Peer = peerFromContext(msg.Context) + if msg.Context.MessageID == "" { + msg.Context.MessageID = msg.MessageID } - - msg.Metadata = mergeLegacyMetadata(msg.Metadata, msg.Context) return msg } @@ -125,110 +52,6 @@ func normalizeInboundContext(ctx InboundContext) InboundContext { return ctx } -func peerFromContext(ctx InboundContext) Peer { - kind := normalizeKind(ctx.ChatType) - if kind == "" { - return Peer{} - } - - switch kind { - case "direct": - return Peer{ - Kind: "direct", - ID: firstNonEmpty(strings.TrimSpace(ctx.SenderID), strings.TrimSpace(ctx.ChatID)), - } - case "group", "channel": - return Peer{ - Kind: kind, - ID: strings.TrimSpace(ctx.ChatID), - } - default: - return Peer{ - Kind: kind, - ID: strings.TrimSpace(ctx.ChatID), - } - } -} - -func mergeLegacyMetadata(existing map[string]string, ctx InboundContext) map[string]string { - merged := cloneStringMap(existing) - if len(merged) == 0 { - merged = cloneStringMap(ctx.Raw) - } else { - for k, v := range ctx.Raw { - if _, ok := merged[k]; !ok { - merged[k] = v - } - } - } - - if ctx.Account != "" { - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyAccountID, ctx.Account) - } - if ctx.ReplyToMessageID != "" { - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyReplyToMessage, ctx.ReplyToMessageID) - } - if ctx.ReplyToSenderID != "" { - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyReplyToSender, ctx.ReplyToSenderID) - } - if ctx.Mentioned { - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyIsMentioned, "true") - } - if ctx.TopicID != "" { - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyParentPeerKind, "topic") - setMissing(merged, metadataKeyParentPeerID, ctx.TopicID) - } - - switch normalizeKind(ctx.SpaceType) { - case "guild": - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyGuildID, ctx.SpaceID) - case "team", "workspace": - if merged == nil { - merged = make(map[string]string) - } - setMissing(merged, metadataKeyTeamID, ctx.SpaceID) - } - - if len(merged) == 0 { - return nil - } - return merged -} - -func setMissing(dst map[string]string, key, value string) { - if value == "" { - return - } - if _, ok := dst[key]; !ok { - dst[key] = value - } -} - -func metadataValue(metadata map[string]string, key string) string { - if metadata == nil { - return "" - } - return strings.TrimSpace(metadata[key]) -} - func cloneStringMap(src map[string]string) map[string]string { if len(src) == 0 { return nil @@ -241,24 +64,11 @@ func cloneStringMap(src map[string]string) map[string]string { return dst } -func firstNonEmpty(values ...string) string { - for _, value := range values { - if value != "" { - return value - } - } - return "" -} - -func normalizeKind(value string) string { - return strings.ToLower(strings.TrimSpace(value)) -} - -func isTruthy(value string) bool { - switch strings.ToLower(strings.TrimSpace(value)) { - case "1", "t", "true", "y", "yes", "on": - return true +func normalizeKind(kind string) string { + switch strings.ToLower(strings.TrimSpace(kind)) { + case "direct", "group", "channel", "guild", "team", "workspace", "tenant", "topic": + return strings.ToLower(strings.TrimSpace(kind)) default: - return false + return strings.ToLower(strings.TrimSpace(kind)) } } diff --git a/pkg/bus/outbound_context.go b/pkg/bus/outbound_context.go index e02353ea9..b3f58f736 100644 --- a/pkg/bus/outbound_context.go +++ b/pkg/bus/outbound_context.go @@ -2,62 +2,34 @@ package bus import "strings" -// ContextFromLegacyOutbound builds a minimal outbound context from the legacy -// top-level outbound fields. This keeps older outbound publishers working -// while new publishers gradually start carrying the original InboundContext. -func ContextFromLegacyOutbound(msg OutboundMessage) InboundContext { +// NewOutboundContext builds the minimal normalized addressing context required +// to deliver an outbound text message or reply. +func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext { return normalizeInboundContext(InboundContext{ - Channel: strings.TrimSpace(msg.Channel), - ChatID: strings.TrimSpace(msg.ChatID), - ReplyToMessageID: strings.TrimSpace(msg.ReplyToMessageID), + Channel: strings.TrimSpace(channel), + ChatID: strings.TrimSpace(chatID), + ReplyToMessageID: strings.TrimSpace(replyToMessageID), }) } -// ContextFromLegacyOutboundMedia builds a minimal outbound context for media. -func ContextFromLegacyOutboundMedia(msg OutboundMediaMessage) InboundContext { - return normalizeInboundContext(InboundContext{ - Channel: strings.TrimSpace(msg.Channel), - ChatID: strings.TrimSpace(msg.ChatID), - }) -} - -// NormalizeOutboundMessage ensures Context is present and mirrors legacy -// top-level addressing fields from it so older senders keep working. +// NormalizeOutboundMessage ensures Context is normalized and keeps convenience +// mirrors in sync for runtime consumers. func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage { - if msg.Context.isZero() { - msg.Context = ContextFromLegacyOutbound(msg) - } else { - msg.Context = normalizeInboundContext(msg.Context) + msg.Context = normalizeInboundContext(msg.Context) + msg.Channel = msg.Context.Channel + msg.ChatID = msg.Context.ChatID + if msg.Context.ReplyToMessageID == "" { + msg.Context.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID) } - - if msg.Channel == "" { - msg.Channel = msg.Context.Channel - } - if msg.ChatID == "" { - msg.ChatID = msg.Context.ChatID - } - if msg.ReplyToMessageID == "" { - msg.ReplyToMessageID = msg.Context.ReplyToMessageID - } - + msg.ReplyToMessageID = msg.Context.ReplyToMessageID return msg } // NormalizeOutboundMediaMessage ensures media outbound messages also carry a -// normalized context while preserving the legacy top-level routing fields. +// normalized context while keeping convenience mirrors in sync. func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage { - if msg.Context.isZero() { - msg.Context = ContextFromLegacyOutboundMedia(msg) - } else { - msg.Context = normalizeInboundContext(msg.Context) - } - - if msg.Channel == "" { - msg.Channel = msg.Context.Channel - } - if msg.ChatID == "" { - msg.ChatID = msg.Context.ChatID - } - + msg.Context = normalizeInboundContext(msg.Context) + msg.Channel = msg.Context.Channel + msg.ChatID = msg.Context.ChatID return msg } diff --git a/pkg/bus/types.go b/pkg/bus/types.go index f844ab1e0..cccfc8baf 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -1,11 +1,5 @@ package bus -// Peer identifies the routing peer for a message (direct, group, channel, etc.) -type Peer struct { - Kind string `json:"kind"` // "direct" | "group" | "channel" | "" - ID string `json:"id"` -} - // SenderInfo provides structured sender identity information. type SenderInfo struct { Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ... @@ -16,9 +10,8 @@ type SenderInfo struct { } // InboundContext captures the normalized, platform-agnostic facts about an -// inbound message. This is the long-term source of truth for routing and -// session allocation. Legacy top-level fields on InboundMessage remain during -// the transition and are derived from this context when missing. +// inbound message. This is the source of truth for routing and session +// allocation. type InboundContext struct { Channel string `json:"channel"` Account string `json:"account,omitempty"` @@ -43,18 +36,18 @@ type InboundContext struct { } type InboundMessage struct { - Channel string `json:"channel"` - SenderID string `json:"sender_id"` - Sender SenderInfo `json:"sender"` - ChatID string `json:"chat_id"` - Context InboundContext `json:"context"` - Content string `json:"content"` - Media []string `json:"media,omitempty"` - Peer Peer `json:"peer"` // routing peer - MessageID string `json:"message_id,omitempty"` // platform message ID - MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope - SessionKey string `json:"session_key"` - Metadata map[string]string `json:"metadata,omitempty"` + Context InboundContext `json:"context"` + Sender SenderInfo `json:"sender"` + Content string `json:"content"` + Media []string `json:"media,omitempty"` + MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope + SessionKey string `json:"session_key"` + + // Convenience mirrors derived from Context for runtime consumers. + Channel string `json:"channel"` + SenderID string `json:"sender_id"` + ChatID string `json:"chat_id"` + MessageID string `json:"message_id,omitempty"` // platform message ID } type OutboundMessage struct { diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 8161fa12e..37fce7cb6 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -244,35 +244,8 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool { return false } -func (c *BaseChannel) HandleMessage( - ctx context.Context, - peer bus.Peer, - messageID, senderID, chatID, content string, - media []string, - metadata map[string]string, - senderOpts ...bus.SenderInfo, -) { - var sender bus.SenderInfo - if len(senderOpts) > 0 { - sender = senderOpts[0] - } - - inboundCtx := bus.ContextFromLegacyInbound(bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - Sender: sender, - ChatID: chatID, - Peer: peer, - MessageID: messageID, - Metadata: metadata, - }) - - c.HandleMessageWithContext(ctx, peer, chatID, content, media, inboundCtx, senderOpts...) -} - func (c *BaseChannel) HandleMessageWithContext( ctx context.Context, - peer bus.Peer, deliveryChatID, content string, media []string, inboundCtx bus.InboundContext, @@ -315,15 +288,10 @@ func (c *BaseChannel) HandleMessageWithContext( scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID) msg := bus.InboundMessage{ - Channel: c.name, - SenderID: resolvedSenderID, - Sender: sender, - ChatID: deliveryChatID, Context: inboundCtx, + Sender: sender, Content: content, Media: media, - Peer: peer, - MessageID: inboundCtx.MessageID, MediaScope: scope, } msg = bus.NormalizeInboundMessage(msg) @@ -369,6 +337,18 @@ func (c *BaseChannel) HandleMessageWithContext( } } +// HandleInboundContext publishes a normalized inbound message using only the +// structured context. +func (c *BaseChannel) HandleInboundContext( + ctx context.Context, + deliveryChatID, content string, + media []string, + inboundCtx bus.InboundContext, + senderOpts ...bus.SenderInfo, +) { + c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...) +} + func (c *BaseChannel) SetRunning(running bool) { c.running.Store(running) } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go index 6132b8bf9..04500f775 100644 --- a/pkg/channels/base_test.go +++ b/pkg/channels/base_test.go @@ -1,6 +1,7 @@ package channels import ( + "context" "testing" "github.com/sipeed/picoclaw/pkg/bus" @@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) { }) } } + +func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) { + tests := []struct { + name string + inbound bus.InboundContext + wantChat string + wantSender string + }{ + { + name: "direct uses sender as peer", + inbound: bus.InboundContext{ + Channel: "test", + ChatID: "chat-1", + ChatType: "direct", + SenderID: "user-1", + MessageID: "msg-1", + }, + wantChat: "chat-1", + wantSender: "user-1", + }, + { + name: "group uses chat as peer", + inbound: bus.InboundContext{ + Channel: "test", + ChatID: "group-1", + ChatType: "group", + SenderID: "user-2", + MessageID: "msg-2", + }, + wantChat: "group-1", + wantSender: "user-2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msgBus := bus.NewMessageBus() + defer msgBus.Close() + + ch := NewBaseChannel("test", nil, msgBus, nil) + ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound) + + msg := <-msgBus.InboundChan() + if msg.ChatID != tt.wantChat { + t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat) + } + if msg.SenderID != tt.wantSender { + t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender) + } + if msg.Context.ChatType != tt.inbound.ChatType { + t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType) + } + }) + } +} diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index 04ccec8a2..30dfffad9 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -181,16 +181,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "session_webhook": data.SessionWebhook, } - var peer bus.Peer + var ( + chatType string + isMentioned bool + ) if data.ConversationType == "1" { - peerID := senderID - if peerID == "" { - peerID = chatID - } - peer = bus.Peer{Kind: "direct", ID: peerID} + chatType = "direct" } else { - peer = bus.Peer{Kind: "group", ID: data.ConversationId} - isMentioned := data.IsInAtList + chatType = "group" + isMentioned = data.IsInAtList if isMentioned { content = stripLeadingAtMentions(content) } @@ -228,8 +227,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived( return nil, nil } - // Handle the message through the base channel - c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "dingtalk", + ChatID: chatID, + ChatType: chatType, + SenderID: resolvedSenderID, + Mentioned: isMentioned, + Raw: metadata, + } + if data.SessionWebhook != "" { + inboundCtx.ReplyHandles = map[string]string{ + "session_webhook": data.SessionWebhook, + } + } + + c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 0376dcdae..427d20779 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -461,14 +461,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag }) peerKind := "channel" - peerID := m.ChannelID if m.GuildID == "" { peerKind = "direct" - peerID = senderID } - peer := bus.Peer{Kind: peerKind, ID: peerID} - metadata := map[string]string{ "user_id": senderID, "username": m.Author.Username, @@ -494,7 +490,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag inboundCtx.ReplyToMessageID = m.MessageReference.MessageID } - c.HandleMessageWithContext(c.ctx, peer, m.ChannelID, content, mediaPaths, inboundCtx, sender) + c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender) } // startTyping starts a continuous typing indicator loop for the given chatID. diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index b0b231d09..f74fab19b 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -447,22 +447,25 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. if messageType != "" { metadata["message_type"] = messageType } - chatType := stringValue(message.ChatType) - if chatType != "" { - metadata["chat_type"] = chatType + rawChatType := stringValue(message.ChatType) + if rawChatType != "" { + metadata["chat_type"] = rawChatType } if sender != nil && sender.TenantKey != nil { metadata["tenant_key"] = *sender.TenantKey } - var peer bus.Peer - if chatType == "p2p" { - peer = bus.Peer{Kind: "direct", ID: senderID} + var ( + inboundChatType string + isMentioned bool + ) + if rawChatType == "p2p" { + inboundChatType = "direct" } else { - peer = bus.Peer{Kind: "group", ID: chatID} + inboundChatType = "group" // Check if bot was mentioned - isMentioned := c.isBotMentioned(message) + isMentioned = c.isBotMentioned(message) // Strip mention placeholders from content before group trigger check if len(message.Mentions) > 0 { @@ -484,7 +487,21 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. "preview": utils.Truncate(content, 80), }) - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo) + inboundCtx := bus.InboundContext{ + Channel: "feishu", + ChatID: chatID, + ChatType: inboundChatType, + SenderID: senderID, + MessageID: messageID, + Mentioned: isMentioned, + Raw: metadata, + } + if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" { + inboundCtx.SpaceType = "tenant" + inboundCtx.SpaceID = *sender.TenantKey + } + + c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo) return nil } diff --git a/pkg/channels/irc/handler.go b/pkg/channels/irc/handler.go index b92359da4..73df9c43c 100644 --- a/pkg/channels/irc/handler.go +++ b/pkg/channels/irc/handler.go @@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&") var chatID string - var peer bus.Peer if isDM { chatID = nick - peer = bus.Peer{Kind: "direct", ID: nick} } else { chatID = target - peer = bus.Peer{Kind: "group", ID: target} } sender := bus.SenderInfo{ @@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { return } + isMentioned := false + // For channel messages, check group trigger (mention detection) if !isDM { - isMentioned := isBotMentioned(content, currentNick) + isMentioned = isBotMentioned(content, currentNick) if isMentioned { content = stripBotMention(content, currentNick) } @@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { metadata["channel"] = target } - c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "irc", + ChatID: chatID, + SenderID: nick, + MessageID: messageID, + Mentioned: isMentioned, + Raw: metadata, + } + if isDM { + inboundCtx.ChatType = "direct" + } else { + inboundCtx.ChatType = "group" + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } // nickMentionedAt returns the byte index where botNick is mentioned in content diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 269f14997..b0853fb8b 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -368,13 +368,6 @@ func (c *LINEChannel) processEvent(event lineEvent) { "source_type": event.Source.Type, } - var peer bus.Peer - if isGroup { - peer = bus.Peer{Kind: "group", ID: chatID} - } else { - peer = bus.Peer{Kind: "direct", ID: senderID} - } - logger.DebugCF("line", "Received message", map[string]any{ "sender_id": senderID, "chat_id": chatID, @@ -396,7 +389,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { inboundCtx := bus.InboundContext{ Channel: c.Name(), ChatID: chatID, - ChatType: peer.Kind, + ChatType: map[bool]string{true: "group", false: "direct"}[isGroup], SenderID: senderID, MessageID: msg.ID, Mentioned: isMentioned, @@ -411,7 +404,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { } } - c.HandleMessageWithContext(c.ctx, peer, chatID, content, mediaPaths, inboundCtx, sender) + c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender) } // isBotMentioned checks if the bot is mentioned in the message. diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index bbbf2da56..0c77d1392 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -196,17 +196,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { return } - c.HandleMessage( - c.ctx, - bus.Peer{Kind: "channel", ID: "default"}, - "", - senderID, - chatID, - content, - []string{}, - metadata, - sender, - ) + inboundCtx := bus.InboundContext{ + Channel: "maixcam", + ChatID: chatID, + ChatType: "channel", + SenderID: senderID, + Raw: metadata, + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 76d1e67c5..60cea9e78 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -97,6 +97,22 @@ type asyncTask struct { cancel context.CancelFunc } +func outboundMessageChannel(msg bus.OutboundMessage) string { + return msg.Context.Channel +} + +func outboundMessageChatID(msg bus.OutboundMessage) string { + return msg.Context.ChatID +} + +func outboundMediaChannel(msg bus.OutboundMediaMessage) string { + return msg.Context.Channel +} + +func outboundMediaChatID(msg bus.OutboundMediaMessage) string { + return msg.Context.ChatID +} + // RecordPlaceholder registers a placeholder message for later editing. // Implements PlaceholderRecorder. func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { @@ -160,7 +176,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) { // preSend handles typing stop, reaction undo, and placeholder editing before sending a message. // Returns the delivered message IDs and true when delivery completed before a normal Send. func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) { - key := name + ":" + msg.ChatID + chatID := outboundMessageChatID(msg) + key := name + ":" + chatID // 1. Stop typing if v, loaded := m.typingStops.LoadAndDelete(key); loaded { @@ -182,9 +199,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess if entry, ok := v.(placeholderEntry); ok && entry.id != "" { // Prefer deleting the placeholder (cleaner UX than editing to same content) if deleter, ok := ch.(MessageDeleter); ok { - deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + deleter.DeleteMessage(ctx, chatID, entry.id) // best effort } else if editor, ok := ch.(MessageEditor); ok { - editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback + editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback } } } @@ -195,7 +212,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { if editor, ok := ch.(MessageEditor); ok { - if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil { + if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil { return []string{entry.id}, true } // edit failed → fall through to normal Send @@ -211,7 +228,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess // delivery never edits the placeholder because there is no text payload to // replace it with; it only attempts to delete the placeholder when possible. func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) { - key := name + ":" + msg.ChatID + chatID := outboundMediaChatID(msg) + key := name + ":" + chatID // 1. Stop typing if v, loaded := m.typingStops.LoadAndDelete(key); loaded { @@ -234,7 +252,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { if deleter, ok := ch.(MessageDeleter); ok { - deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + deleter.DeleteMessage(ctx, chatID, entry.id) // best effort } } } @@ -756,7 +774,7 @@ func (m *Manager) sendWithRetry( // All retries exhausted or permanent failure logger.ErrorCF("channels", "Send failed", map[string]any{ "channel": name, - "chat_id": msg.ChatID, + "chat_id": outboundMessageChatID(msg), "error": lastErr.Error(), "retries": maxRetries, }) @@ -818,7 +836,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { dispatchLoop( ctx, m, m.bus.OutboundChan(), - func(msg bus.OutboundMessage) string { return msg.Channel }, + func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) }, func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool { select { case w.queue <- msg: @@ -838,7 +856,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) { dispatchLoop( ctx, m, m.bus.OutboundMediaChan(), - func(msg bus.OutboundMediaMessage) string { return msg.Channel }, + func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) }, func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool { select { case w.mediaQueue <- msg: @@ -937,7 +955,7 @@ func (m *Manager) sendMediaWithRetry( // All retries exhausted or permanent failure logger.ErrorCF("channels", "SendMedia failed", map[string]any{ "channel": name, - "chat_id": msg.ChatID, + "chat_id": outboundMediaChatID(msg), "error": lastErr.Error(), "retries": maxRetries, }) @@ -1131,17 +1149,18 @@ func (m *Manager) UnregisterChannel(name string) { // a subsequent operation depends on the message having been sent. func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error { msg = bus.NormalizeOutboundMessage(msg) + channelName := outboundMessageChannel(msg) m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { - return fmt.Errorf("channel %s not found", msg.Channel) + return fmt.Errorf("channel %s not found", channelName) } if !wExists || w == nil { - return fmt.Errorf("channel %s has no active worker", msg.Channel) + return fmt.Errorf("channel %s has no active worker", channelName) } maxLen := 0 @@ -1152,10 +1171,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro for _, chunk := range SplitMessage(msg.Content, maxLen) { chunkMsg := msg chunkMsg.Content = chunk - m.sendWithRetry(ctx, msg.Channel, w, chunkMsg) + m.sendWithRetry(ctx, channelName, w, chunkMsg) } } else { - m.sendWithRetry(ctx, msg.Channel, w, msg) + m.sendWithRetry(ctx, channelName, w, msg) } return nil } @@ -1166,20 +1185,21 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro // depends on actual media delivery. func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { msg = bus.NormalizeOutboundMediaMessage(msg) + channelName := outboundMediaChannel(msg) m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { - return fmt.Errorf("channel %s not found", msg.Channel) + return fmt.Errorf("channel %s not found", channelName) } if !wExists || w == nil { - return fmt.Errorf("channel %s has no active worker", msg.Channel) + return fmt.Errorf("channel %s has no active worker", channelName) } - _, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg) + _, err := m.sendMediaWithRetry(ctx, channelName, w, msg) return err } @@ -1194,10 +1214,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten } msg := bus.OutboundMessage{ - Channel: channelName, - ChatID: chatID, + Context: bus.NewOutboundContext(channelName, chatID, ""), Content: content, } + msg = bus.NormalizeOutboundMessage(msg) if wExists && w != nil { select { diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index e76212905..29219679d 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -89,6 +89,20 @@ func newTestManager() *Manager { } } +func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage { + if msg.Context.Channel == "" && msg.Context.ChatID == "" { + msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID) + } + return bus.NormalizeOutboundMessage(msg) +} + +func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage { + if msg.Context.Channel == "" && msg.Context.ChatID == "" { + msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "") + } + return bus.NormalizeOutboundMediaMessage(msg) +} + func TestSendWithRetry_Success(t *testing.T) { m := newTestManager() var callCount int @@ -104,7 +118,7 @@ func TestSendWithRetry_Success(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -131,7 +145,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -155,7 +169,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -179,7 +193,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -206,7 +220,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) start := time.Now() m.sendWithRetry(ctx, "test", w, msg) @@ -236,7 +250,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -262,11 +276,11 @@ func TestSendMedia_Success(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err != nil { t.Fatalf("SendMedia() error = %v", err) } @@ -289,11 +303,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err == nil { t.Fatal("expected SendMedia to return error") } @@ -316,11 +330,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err == nil { t.Fatal("expected SendMedia to return error for unsupported channel") } @@ -346,11 +360,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) { m.workers["test"] = w m.RecordPlaceholder("test", "chat1", "placeholder-1") - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err != nil { t.Fatalf("SendMedia() error = %v", err) } @@ -383,7 +397,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -407,7 +421,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) // Cancel context after first Send attempt returns ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error { @@ -453,7 +467,7 @@ func TestWorkerRateLimiter(t *testing.T) { // Enqueue 4 messages for i := range 4 { - w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)} + w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}) } // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin) @@ -529,7 +543,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) { go m.runWorker(ctx, "test", w) // Send a message that should be split - w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"} + w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}) time.Sleep(100 * time.Millisecond) @@ -570,7 +584,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) start := time.Now() m.sendWithRetry(ctx, "test", w, msg) @@ -630,7 +644,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) { // Register placeholder m.RecordPlaceholder("test", "123", "456") - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if !edited { @@ -660,7 +674,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { m.RecordPlaceholder("test", "123", "456") - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if edited { @@ -719,7 +733,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) { stopCalled = true }) - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) m.preSend(context.Background(), "test", msg, ch) if !stopCalled { @@ -736,7 +750,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) { }, } - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if edited { @@ -766,7 +780,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) { }) m.RecordPlaceholder("test", "123", "456") - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if !stopCalled { @@ -830,7 +844,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) { t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls) } - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) m.preSend(context.Background(), "test", msg, &mockChannel{}) if newStopCalls != 1 { @@ -864,7 +878,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { limiter: rate.NewLimiter(rate.Inf, 1), } - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) m.sendWithRetry(context.Background(), "test", w, msg) if sendCalled { @@ -1027,7 +1041,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) { }) m.RecordPlaceholder("test", "chat1", "ph_id") - msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if !stopCalled { @@ -1130,11 +1144,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) { // Transcription feedback arrives first — it should consume the placeholder // and be delivered via EditMessage, not Send. - msgTranscript := bus.OutboundMessage{ + msgTranscript := testOutboundMessage(bus.OutboundMessage{ Channel: "mock", ChatID: "chat-1", Content: "Transcript: hello", - } + }) mgr.sendWithRetry(ctx, "mock", worker, msgTranscript) if mockCh.editedMessages != 1 { @@ -1150,11 +1164,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) { } // Final LLM response arrives — no placeholder left, so it goes through Send - msgFinal := bus.OutboundMessage{ + msgFinal := testOutboundMessage(bus.OutboundMessage{ Channel: "mock", ChatID: "chat-1", Content: "Final Answer", - } + }) mgr.sendWithRetry(ctx, "mock", worker, msgFinal) if len(mockCh.sentMessages) != 1 { @@ -1180,12 +1194,12 @@ func TestSendMessage_Synchronous(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "hello world", ReplyToMessageID: "msg-456", - } + }) err := m.SendMessage(context.Background(), msg) if err != nil { @@ -1207,11 +1221,11 @@ func TestSendMessage_Synchronous(t *testing.T) { func TestSendMessage_UnknownChannel(t *testing.T) { m := newTestManager() - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "nonexistent", ChatID: "123", Content: "hello", - } + }) err := m.SendMessage(context.Background(), msg) if err == nil { @@ -1228,11 +1242,11 @@ func TestSendMessage_NoWorker(t *testing.T) { m.channels["test"] = ch // No worker registered - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "hello", - } + }) err := m.SendMessage(context.Background(), msg) if err == nil { @@ -1261,11 +1275,11 @@ func TestSendMessage_WithRetry(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "retry me", - } + }) err := m.SendMessage(context.Background(), msg) if err != nil { @@ -1277,6 +1291,46 @@ func TestSendMessage_WithRetry(t *testing.T) { } } +func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) { + m := newTestManager() + + var received []bus.OutboundMessage + ch := &mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + received = append(received, msg) + return nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + msg := testOutboundMessage(bus.OutboundMessage{ + Context: bus.NewOutboundContext("test", "123", "msg-9"), + Content: "hello", + }) + + if err := m.SendMessage(context.Background(), msg); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(received) != 1 { + t.Fatalf("expected 1 message sent, got %d", len(received)) + } + if received[0].Channel != "test" || received[0].ChatID != "123" { + t.Fatalf("expected mirrored legacy address, got %+v", received[0]) + } + if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" { + t.Fatalf("expected context address to be preserved, got %+v", received[0].Context) + } + if received[0].ReplyToMessageID != "msg-9" { + t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID) + } +} + func TestSendMessage_WithSplitting(t *testing.T) { m := newTestManager() @@ -1298,11 +1352,11 @@ func TestSendMessage_WithSplitting(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "hello world", - } + }) err := m.SendMessage(context.Background(), msg) if err != nil { @@ -1314,6 +1368,43 @@ func TestSendMessage_WithSplitting(t *testing.T) { } } +func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) { + m := newTestManager() + + var received []bus.OutboundMediaMessage + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) { + received = append(received, msg) + return nil, nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + msg := testOutboundMediaMessage(bus.OutboundMediaMessage{ + Context: bus.NewOutboundContext("test", "media-chat", ""), + Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}}, + }) + + if err := m.SendMedia(context.Background(), msg); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(received) != 1 { + t.Fatalf("expected 1 media message sent, got %d", len(received)) + } + if received[0].Channel != "test" || received[0].ChatID != "media-chat" { + t.Fatalf("expected mirrored legacy media address, got %+v", received[0]) + } + if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" { + t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context) + } +} + func TestSendMessage_PreservesOrdering(t *testing.T) { m := newTestManager() @@ -1333,12 +1424,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) { m.workers["test"] = w // Send two messages sequentially — they must arrive in order - _ = m.SendMessage(context.Background(), bus.OutboundMessage{ + _ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "1", Content: "first", - }) - _ = m.SendMessage(context.Background(), bus.OutboundMessage{ + })) + _ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "1", Content: "second", - }) + })) if len(order) != 2 { t.Fatalf("expected 2 messages, got %d", len(order)) diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index 96db964cf..431fc5dc8 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -736,10 +736,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event } peerKind := "direct" - peerID := senderID if isGroup { peerKind = "group" - peerID = roomID } metadata := map[string]string{ @@ -752,17 +750,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event metadata["reply_to_msg_id"] = replyTo.String() } - c.HandleMessage( - c.baseContext(), - bus.Peer{Kind: peerKind, ID: peerID}, - evt.ID.String(), - senderID, - roomID, - content, - mediaPaths, - metadata, - sender, - ) + inboundCtx := bus.InboundContext{ + Channel: "matrix", + ChatID: roomID, + ChatType: peerKind, + SenderID: senderID, + MessageID: evt.ID.String(), + Raw: metadata, + } + if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" { + inboundCtx.ReplyToMessageID = replyTo.String() + } + + c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender) } // decryptEvent decrypts an encrypted event and returns the decrypted message event content. diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index e5651b046..4f8dff234 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -994,8 +994,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { var contextChatID string var contextChatType string - var peer bus.Peer - metadata := map[string]string{} if parsed.ReplyTo != "" { @@ -1007,14 +1005,12 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { chatID = "private:" + senderID contextChatID = senderID contextChatType = "direct" - peer = bus.Peer{Kind: "direct", ID: senderID} case "group": groupIDStr := strconv.FormatInt(groupID, 10) chatID = "group:" + groupIDStr contextChatID = groupIDStr contextChatType = "group" - peer = bus.Peer{Kind: "group", ID: groupIDStr} metadata["group_id"] = groupIDStr senderUserID, _ := parseJSONInt64(sender.UserID) @@ -1089,7 +1085,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { Raw: metadata, } - c.HandleMessageWithContext(c.ctx, peer, chatID, content, parsed.Media, inboundCtx, senderInfo) + c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo) } func (c *OneBotChannel) isDuplicate(messageID string) bool { diff --git a/pkg/channels/pico/client.go b/pkg/channels/pico/client.go index b4bfd09e5..91af34e4c 100644 --- a/pkg/channels/pico/client.go +++ b/pkg/channels/pico/client.go @@ -254,8 +254,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) { chatID := "pico_client:" + sessionID senderID := "pico-remote" - peer := bus.Peer{Kind: "direct", ID: chatID} - sender := bus.SenderInfo{ Platform: "pico_client", PlatformID: senderID, @@ -266,10 +264,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) { return } - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{ - "platform": "pico_client", - "session_id": sessionID, - }, sender) + inboundCtx := bus.InboundContext{ + Channel: "pico_client", + ChatID: chatID, + ChatType: "direct", + SenderID: senderID, + MessageID: msg.ID, + Raw: map[string]string{ + "platform": "pico_client", + "session_id": sessionID, + }, + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } // Send sends a message to the remote server. diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 0a7bf15a4..4f3f4aba3 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -539,8 +539,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { chatID := "pico:" + sessionID senderID := "pico-user" - peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} - metadata := map[string]string{ "platform": "pico", "session_id": sessionID, @@ -562,7 +560,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { return } - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "pico", + ChatID: chatID, + ChatType: "direct", + SenderID: senderID, + MessageID: msg.ID, + Raw: metadata, + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } // truncate truncates a string to maxLen runes. diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index ba0045da6..aa78d8e85 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -657,15 +657,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { Raw: metadata, } - c.HandleMessageWithContext( - c.ctx, - bus.Peer{Kind: "direct", ID: senderID}, - senderID, - content, - mediaPaths, - inboundCtx, - sender, - ) + c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender) return nil } @@ -744,15 +736,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { Raw: metadata, } - c.HandleMessageWithContext( - c.ctx, - bus.Peer{Kind: "group", ID: data.GroupID}, - data.GroupID, - content, - mediaPaths, - inboundCtx, - sender, - ) + c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender) return nil } diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 882cc5cb5..543f6f338 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -356,14 +356,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } peerKind := "channel" - peerID := channelID if strings.HasPrefix(channelID, "D") { peerKind = "direct" - peerID = senderID } - peer := bus.Peer{Kind: peerKind, ID: peerID} - metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, @@ -394,7 +390,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { inboundCtx.TopicID = threadTS } - c.HandleMessageWithContext(c.ctx, peer, chatID, content, mediaPaths, inboundCtx, sender) + c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { @@ -442,14 +438,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { } mentionPeerKind := "channel" - mentionPeerID := channelID if strings.HasPrefix(channelID, "D") { mentionPeerKind = "direct" - mentionPeerID = senderID } - mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID} - metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, @@ -472,7 +464,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { Raw: metadata, } - c.HandleMessageWithContext(c.ctx, mentionPeer, chatID, content, nil, inboundCtx, mentionSender) + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -520,10 +512,8 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "text": utils.Truncate(content, 50), }) peerKind := "channel" - peerID := channelID if strings.HasPrefix(channelID, "D") { peerKind = "direct" - peerID = senderID } inboundCtx := bus.InboundContext{ Channel: c.Name(), @@ -536,15 +526,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { Raw: metadata, } - c.HandleMessageWithContext( - c.ctx, - bus.Peer{Kind: peerKind, ID: peerID}, - chatID, - content, - nil, - inboundCtx, - cmdSender, - ) + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index e1532bcf9..31a5afb30 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -708,13 +708,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes }) peerKind := "direct" - peerID := fmt.Sprintf("%d", user.ID) if message.Chat.Type != "private" { peerKind = "group" - peerID = compositeChatID } - - peer := bus.Peer{Kind: peerKind, ID: peerID} messageID := fmt.Sprintf("%d", message.MessageID) metadata := map[string]string{ @@ -742,7 +738,6 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes c.HandleMessageWithContext( c.ctx, - peer, compositeChatID, content, mediaPaths, diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index 65b9b4ca4..10b95a20f 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) return err } - peer := bus.Peer{Kind: peerKind, ID: actualChatID} metadata := map[string]string{ "channel": "wecom", "req_id": reqID, @@ -596,7 +595,7 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) Raw: metadata, } - c.HandleMessageWithContext(c.ctx, peer, actualChatID, content, mediaRefs, inboundCtx, sender) + c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender) return nil } diff --git a/pkg/channels/weixin/weixin.go b/pkg/channels/weixin/weixin.go index 0e9010131..5e62a8a3b 100644 --- a/pkg/channels/weixin/weixin.go +++ b/pkg/channels/weixin/weixin.go @@ -334,8 +334,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess return } - peer := bus.Peer{Kind: "direct", ID: fromUserID} - metadata := map[string]string{ "from_user_id": fromUserID, "context_token": msg.ContextToken, @@ -354,7 +352,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess c.persistContextTokens() } - c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "weixin", + ChatID: fromUserID, + ChatType: "direct", + SenderID: fromUserID, + MessageID: messageID, + Raw: metadata, + } + if msg.ContextToken != "" { + inboundCtx.ReplyHandles = map[string]string{ + "context_token": msg.ContextToken, + } + } + + c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender) } // Send implements channels.Channel by sending a text message to the WeChat user. diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 98622fe37..7064da219 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -223,13 +223,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { metadata["user_name"] = userName } - var peer bus.Peer - if chatID == senderID { - peer = bus.Peer{Kind: "direct", ID: senderID} - } else { - peer = bus.Peer{Kind: "group", ID: chatID} - } - logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{ "sender": senderID, "preview": utils.Truncate(content, 50), @@ -248,5 +241,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { return } - c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "whatsapp", + ChatID: chatID, + SenderID: senderID, + MessageID: messageID, + Raw: metadata, + } + if chatID == senderID { + inboundCtx.ChatType = "direct" + } else { + inboundCtx.ChatType = "group" + } + + c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender) } diff --git a/pkg/channels/whatsapp_native/whatsapp_native.go b/pkg/channels/whatsapp_native/whatsapp_native.go index d0a74a405..a1e6e50cd 100644 --- a/pkg/channels/whatsapp_native/whatsapp_native.go +++ b/pkg/channels/whatsapp_native/whatsapp_native.go @@ -375,7 +375,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) { if evt.Info.Chat.Server == types.GroupServer { peerKind = "group" } - peer := bus.Peer{Kind: peerKind, ID: chatID} messageID := evt.Info.ID sender := bus.SenderInfo{ Platform: "whatsapp", @@ -393,7 +392,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) { "WhatsApp message received", map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)}, ) - c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) + + inboundCtx := bus.InboundContext{ + Channel: "whatsapp", + ChatID: chatID, + SenderID: senderID, + MessageID: messageID, + ChatType: peerKind, + Raw: metadata, + } + + c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender) } func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) { diff --git a/pkg/config/config.go b/pkg/config/config.go index 10eb07339..014c90045 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -99,7 +99,7 @@ type BuildInfo struct { } // MarshalJSON implements custom JSON marshaling for Config -// to omit providers section when empty and session when empty +// to omit providers section when empty and session when empty. func (c *Config) MarshalJSON() ([]byte, error) { type Alias Config aux := &struct { @@ -109,11 +109,8 @@ func (c *Config) MarshalJSON() ([]byte, error) { Alias: (*Alias)(c), } - // Only include session if not empty. Deprecated dm_scope is intentionally - // omitted so persisted configs converge on dimensions-based session policy. if len(c.Session.Dimensions) > 0 || len(c.Session.IdentityLinks) > 0 { sessionCfg := c.Session - sessionCfg.DMScope = "" aux.Session = &sessionCfg } @@ -199,7 +196,6 @@ type AgentBinding struct { type SessionConfig struct { Dimensions []string `json:"dimensions,omitempty"` - DMScope string `json:"dm_scope,omitempty"` // Deprecated: ignored by the new session policy path. IdentityLinks map[string][]string `json:"identity_links,omitempty"` } diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1bafe6085..1cf2a686e 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -131,8 +131,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: platform, - ChatID: userID, + Context: bus.NewOutboundContext(platform, userID, ""), Content: msg, }) diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 5dda78ea9..e5b28ec11 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -339,8 +339,7 @@ func (hs *HeartbeatService) sendResponse(response string) { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: platform, - ChatID: userID, + Context: bus.NewOutboundContext(platform, userID, ""), Content: response, }) diff --git a/pkg/routing/route.go b/pkg/routing/route.go index e5a000067..88a0006da 100644 --- a/pkg/routing/route.go +++ b/pkg/routing/route.go @@ -3,25 +3,21 @@ package routing import ( "strings" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" ) -// RouteInput contains the routing context from an inbound message. -type RouteInput struct { - Channel string - AccountID string - Peer *RoutePeer - ParentPeer *RoutePeer - GuildID string - TeamID string -} - // SessionPolicy describes how a routed message should be mapped to a session. type SessionPolicy struct { Dimensions []string IdentityLinks map[string][]string } +type RoutePeer struct { + Kind string + ID string +} + // ResolvedRoute is the result of agent routing. type ResolvedRoute struct { AgentID string @@ -41,14 +37,15 @@ func NewRouteResolver(cfg *config.Config) *RouteResolver { return &RouteResolver{cfg: cfg} } -// ResolveRoute determines which agent handles the message and returns the -// session policy that should be used to allocate session state. +// ResolveRoute determines which agent handles the message from a normalized +// inbound context and returns the session policy that should be used to +// allocate session state. // Implements the 7-level priority cascade: // peer > parent_peer > guild > team > account > channel_wildcard > default -func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { - channel := strings.ToLower(strings.TrimSpace(input.Channel)) - accountID := NormalizeAccountID(input.AccountID) - peer := input.Peer +func (r *RouteResolver) ResolveRoute(inbound bus.InboundContext) ResolvedRoute { + channel := strings.ToLower(strings.TrimSpace(inbound.Channel)) + accountID := NormalizeAccountID(inbound.Account) + peer := routePeerFromContext(inbound) sessionPolicy := r.sessionPolicy() @@ -73,7 +70,7 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { } // Priority 2: Parent peer binding - parentPeer := input.ParentPeer + parentPeer := parentPeerFromContext(inbound) if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" { if match := r.findPeerMatch(bindings, parentPeer); match != nil { return choose(match.AgentID, "binding.peer.parent") @@ -81,7 +78,7 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { } // Priority 3: Guild binding - guildID := strings.TrimSpace(input.GuildID) + guildID := routeGuildIDFromContext(inbound) if guildID != "" { if match := r.findGuildMatch(bindings, guildID); match != nil { return choose(match.AgentID, "binding.guild") @@ -89,7 +86,7 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { } // Priority 4: Team binding - teamID := strings.TrimSpace(input.TeamID) + teamID := routeTeamIDFromContext(inbound) if teamID != "" { if match := r.findTeamMatch(bindings, teamID); match != nil { return choose(match.AgentID, "binding.team") @@ -276,6 +273,46 @@ func normalizeSessionDimensions(dimensions []string) []string { return normalized } +func routePeerFromContext(ctx bus.InboundContext) *RoutePeer { + peerKind := normalizeChannel(strings.TrimSpace(ctx.ChatType)) + if peerKind == "" || peerKind == "unknown" { + return nil + } + + peerID := strings.TrimSpace(ctx.ChatID) + if peerKind == "direct" && peerID == "" { + peerID = strings.TrimSpace(ctx.SenderID) + } + if peerID == "" { + return nil + } + + return &RoutePeer{Kind: peerKind, ID: peerID} +} + +func parentPeerFromContext(ctx bus.InboundContext) *RoutePeer { + if topicID := strings.TrimSpace(ctx.TopicID); topicID != "" { + return &RoutePeer{Kind: "topic", ID: topicID} + } + return nil +} + +func routeGuildIDFromContext(ctx bus.InboundContext) string { + if strings.EqualFold(strings.TrimSpace(ctx.SpaceType), "guild") { + return strings.TrimSpace(ctx.SpaceID) + } + return "" +} + +func routeTeamIDFromContext(ctx bus.InboundContext) string { + switch strings.ToLower(strings.TrimSpace(ctx.SpaceType)) { + case "team", "workspace": + return strings.TrimSpace(ctx.SpaceID) + default: + return "" + } +} + func cloneIdentityLinks(src map[string][]string) map[string][]string { if len(src) == 0 { return nil @@ -288,3 +325,7 @@ func cloneIdentityLinks(src map[string][]string) map[string][]string { } return cloned } + +func normalizeChannel(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go index 3397bd8e8..46a0f9f13 100644 --- a/pkg/routing/route_test.go +++ b/pkg/routing/route_test.go @@ -3,6 +3,7 @@ package routing import ( "testing" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" ) @@ -26,9 +27,10 @@ func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { cfg := testConfig(nil, nil) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + ChatType: "direct", + SenderID: "user1", }) if route.AgentID != DefaultAgentID { @@ -63,9 +65,10 @@ func TestResolveRoute_PeerBinding(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + ChatType: "direct", + SenderID: "user123", }) if route.AgentID != "support" { @@ -94,10 +97,12 @@ func TestResolveRoute_GuildBinding(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "discord", - GuildID: "guild-abc", - Peer: &RoutePeer{Kind: "channel", ID: "ch1"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "discord", + ChatID: "ch1", + ChatType: "channel", + SpaceID: "guild-abc", + SpaceType: "guild", }) if route.AgentID != "gaming" { @@ -126,10 +131,12 @@ func TestResolveRoute_TeamBinding(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "slack", - TeamID: "T12345", - Peer: &RoutePeer{Kind: "channel", ID: "C001"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "slack", + ChatID: "C001", + ChatType: "channel", + SpaceID: "T12345", + SpaceType: "team", }) if route.AgentID != "work" { @@ -157,10 +164,11 @@ func TestResolveRoute_AccountBinding(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - AccountID: "bot2", - Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + Account: "bot2", + ChatType: "direct", + SenderID: "user1", }) if route.AgentID != "premium" { @@ -188,9 +196,10 @@ func TestResolveRoute_ChannelWildcard(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + ChatType: "direct", + SenderID: "user1", }) if route.AgentID != "telegram-bot" { @@ -228,10 +237,12 @@ func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "discord", - GuildID: "guild-1", - Peer: &RoutePeer{Kind: "direct", ID: "user-vip"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "discord", + ChatType: "direct", + SenderID: "user-vip", + SpaceID: "guild-1", + SpaceType: "guild", }) if route.AgentID != "vip" { @@ -258,9 +269,7 @@ func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) { cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - }) + route := r.ResolveRoute(bus.InboundContext{Channel: "telegram"}) if route.AgentID != "main" { t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID) @@ -276,9 +285,7 @@ func TestResolveRoute_DefaultAgentSelection(t *testing.T) { cfg := testConfig(agents, nil) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "cli", - }) + route := r.ResolveRoute(bus.InboundContext{Channel: "cli"}) if route.AgentID != "beta" { t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID) @@ -293,9 +300,7 @@ func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) { cfg := testConfig(agents, nil) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "cli", - }) + route := r.ResolveRoute(bus.InboundContext{Channel: "cli"}) if route.AgentID != "alpha" { t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID) diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go deleted file mode 100644 index cc3ce43f3..000000000 --- a/pkg/routing/session_key.go +++ /dev/null @@ -1,218 +0,0 @@ -package routing - -import ( - "fmt" - "strings" -) - -// DMScope controls DM session isolation granularity. -type DMScope string - -const ( - DMScopeMain DMScope = "main" - DMScopePerPeer DMScope = "per-peer" - DMScopePerChannelPeer DMScope = "per-channel-peer" - DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer" -) - -// RoutePeer represents a chat peer with kind and ID. -type RoutePeer struct { - Kind string // "direct", "group", "channel" - ID string -} - -// SessionKeyParams holds all inputs for session key construction. -type SessionKeyParams struct { - AgentID string - Channel string - AccountID string - Peer *RoutePeer - DMScope DMScope - IdentityLinks map[string][]string -} - -// ParsedSessionKey is the result of parsing an agent-scoped session key. -type ParsedSessionKey struct { - AgentID string - Rest string -} - -// BuildAgentMainSessionKey returns "agent::main". -func BuildAgentMainSessionKey(agentID string) string { - return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey) -} - -// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope. -func BuildAgentPeerSessionKey(params SessionKeyParams) string { - agentID := NormalizeAgentID(params.AgentID) - - peer := params.Peer - if peer == nil { - peer = &RoutePeer{Kind: "direct"} - } - peerKind := strings.TrimSpace(peer.Kind) - if peerKind == "" { - peerKind = "direct" - } - - if peerKind == "direct" { - dmScope := params.DMScope - if dmScope == "" { - dmScope = DMScopeMain - } - peerID := CanonicalSessionPeerID(params.Channel, peer.ID, dmScope, params.IdentityLinks) - - switch dmScope { - case DMScopePerAccountChannelPeer: - if peerID != "" { - channel := normalizeChannel(params.Channel) - accountID := NormalizeAccountID(params.AccountID) - return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID) - } - case DMScopePerChannelPeer: - if peerID != "" { - channel := normalizeChannel(params.Channel) - return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID) - } - case DMScopePerPeer: - if peerID != "" { - return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) - } - } - return BuildAgentMainSessionKey(agentID) - } - - // Group/channel peers always get per-peer sessions - channel := normalizeChannel(params.Channel) - peerID := strings.ToLower(strings.TrimSpace(peer.ID)) - if peerID == "" { - peerID = "unknown" - } - return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) -} - -// CanonicalSessionPeerID applies the current DM session canonicalization rules, -// including identity-link collapse when enabled. -func CanonicalSessionPeerID( - channel, peerID string, - dmScope DMScope, - identityLinks map[string][]string, -) string { - normalizedPeerID := strings.TrimSpace(peerID) - if normalizedPeerID == "" { - return "" - } - - if dmScope != DMScopeMain { - if linked := resolveLinkedPeerID(identityLinks, channel, normalizedPeerID); linked != "" { - normalizedPeerID = linked - } - } - - return strings.ToLower(normalizedPeerID) -} - -// CanonicalSessionIdentityID collapses an identity using identity_links when -// possible, then returns a normalized lowercase identifier. -func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string { - normalizedID := strings.TrimSpace(rawID) - if normalizedID == "" { - return "" - } - if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" { - normalizedID = linked - } - return strings.ToLower(normalizedID) -} - -// ParseAgentSessionKey extracts agentId and rest from "agent::". -func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey { - raw := strings.TrimSpace(sessionKey) - if raw == "" { - return nil - } - parts := strings.SplitN(raw, ":", 3) - if len(parts) < 3 { - return nil - } - if parts[0] != "agent" { - return nil - } - agentID := strings.TrimSpace(parts[1]) - rest := parts[2] - if agentID == "" || rest == "" { - return nil - } - return &ParsedSessionKey{AgentID: agentID, Rest: rest} -} - -// IsSubagentSessionKey returns true if the session key represents a subagent. -func IsSubagentSessionKey(sessionKey string) bool { - raw := strings.TrimSpace(sessionKey) - if raw == "" { - return false - } - if strings.HasPrefix(strings.ToLower(raw), "subagent:") { - return true - } - parsed := ParseAgentSessionKey(raw) - if parsed == nil { - return false - } - return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:") -} - -func normalizeChannel(channel string) string { - c := strings.TrimSpace(strings.ToLower(channel)) - if c == "" { - return "unknown" - } - return c -} - -func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { - if len(identityLinks) == 0 { - return "" - } - peerID = strings.TrimSpace(peerID) - if peerID == "" { - return "" - } - - candidates := make(map[string]bool) - rawCandidate := strings.ToLower(peerID) - if rawCandidate != "" { - candidates[rawCandidate] = true - } - channel = strings.ToLower(strings.TrimSpace(channel)) - if channel != "" { - scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) - candidates[scopedCandidate] = true - } - - // If peerID is already in canonical "platform:id" format, also add the - // bare ID part as a candidate for backward compatibility with identity_links - // that use raw IDs (e.g. "123" instead of "telegram:123"). - if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { - bareID := rawCandidate[idx+1:] - candidates[bareID] = true - } - - if len(candidates) == 0 { - return "" - } - - for canonical, ids := range identityLinks { - canonicalName := strings.TrimSpace(canonical) - if canonicalName == "" { - continue - } - for _, id := range ids { - normalized := strings.ToLower(strings.TrimSpace(id)) - if normalized != "" && candidates[normalized] { - return canonicalName - } - } - } - return "" -} diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go deleted file mode 100644 index ad7a1ca02..000000000 --- a/pkg/routing/session_key_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package routing - -import "testing" - -func TestBuildAgentMainSessionKey(t *testing.T) { - got := BuildAgentMainSessionKey("sales") - want := "agent:sales:main" - if got != want { - t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want) - } -} - -func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) { - got := BuildAgentMainSessionKey("Sales Bot") - want := "agent:sales-bot:main" - if got != want { - t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopeMain, - }) - want := "agent:main:main" - if got != want { - t.Errorf("DMScopeMain = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopePerPeer, - }) - want := "agent:main:direct:user123" - if got != want { - t.Errorf("DMScopePerPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopePerChannelPeer, - }) - want := "agent:main:telegram:direct:user123" - if got != want { - t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - AccountID: "bot1", - Peer: &RoutePeer{Kind: "direct", ID: "User123"}, - DMScope: DMScopePerAccountChannelPeer, - }) - want := "agent:main:telegram:bot1:direct:user123" - if got != want { - t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "group", ID: "chat456"}, - DMScope: DMScopePerPeer, - }) - want := "agent:main:telegram:group:chat456" - if got != want { - t.Errorf("GroupPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: nil, - DMScope: DMScopePerPeer, - }) - // nil peer defaults to direct with empty ID, falls to main - want := "agent:main:main" - if got != want { - t.Errorf("NilPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { - links := map[string][]string{ - "john": {"telegram:user123", "discord:john#1234"}, - } - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopePerPeer, - IdentityLinks: links, - }) - want := "agent:main:direct:john" - if got != want { - t.Errorf("IdentityLink = %q, want %q", got, want) - } -} - -func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) { - // When peerID is already in canonical "platform:id" format, - // it should match identity_links that use the bare ID. - links := map[string][]string{ - "john": {"123"}, - } - got := resolveLinkedPeerID(links, "telegram", "telegram:123") - if got != "john" { - t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john") - } -} - -func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) { - // When identity_links contain canonical IDs and peerID is canonical too - links := map[string][]string{ - "john": {"telegram:123", "discord:456"}, - } - got := resolveLinkedPeerID(links, "telegram", "telegram:123") - if got != "john" { - t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john") - } -} - -func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) { - // When peerID is bare "123" and links have "telegram:123", - // the scoped candidate "telegram:123" should match. - links := map[string][]string{ - "john": {"telegram:123"}, - } - got := resolveLinkedPeerID(links, "telegram", "123") - if got != "john" { - t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john") - } -} - -func TestResolveLinkedPeerID_NoMatch(t *testing.T) { - links := map[string][]string{ - "john": {"telegram:123"}, - } - got := resolveLinkedPeerID(links, "discord", "999") - if got != "" { - t.Errorf("resolveLinkedPeerID no match = %q, want empty", got) - } -} - -func TestParseAgentSessionKey_Valid(t *testing.T) { - parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") - if parsed == nil { - t.Fatal("expected non-nil result") - } - if parsed.AgentID != "sales" { - t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID) - } - if parsed.Rest != "telegram:direct:user123" { - t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest) - } -} - -func TestParseAgentSessionKey_Invalid(t *testing.T) { - tests := []string{ - "", - "foo:bar", - "notprefix:sales:main", - "agent::main", - "agent:sales:", - } - for _, input := range tests { - if got := ParseAgentSessionKey(input); got != nil { - t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got) - } - } -} - -func TestIsSubagentSessionKey(t *testing.T) { - tests := []struct { - input string - want bool - }{ - {"subagent:task-1", true}, - {"agent:main:subagent:task-1", true}, - {"agent:main:main", false}, - {"agent:main:telegram:direct:user123", false}, - {"", false}, - } - for _, tt := range tests { - if got := IsSubagentSessionKey(tt.input); got != tt.want { - t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want) - } - } -} diff --git a/pkg/session/allocator.go b/pkg/session/allocator.go index 6bf678deb..7045b93d6 100644 --- a/pkg/session/allocator.go +++ b/pkg/session/allocator.go @@ -32,7 +32,7 @@ type AllocationInput struct { func AllocateRouteSession(input AllocationInput) Allocation { scope := buildSessionScope(input) legacySessionAliases := buildLegacySessionAliases(input) - legacyMainSessionKey := strings.ToLower(routing.BuildAgentMainSessionKey(input.AgentID)) + legacyMainSessionKey := strings.ToLower(BuildLegacyMainAlias(input.AgentID)) return Allocation{ Scope: scope, SessionKey: BuildSessionKey(scope), @@ -85,7 +85,7 @@ func buildSessionScope(input AllocationInput) SessionScope { values["topic"] = "topic:" + strings.ToLower(topicID) } case "sender": - senderID := routing.CanonicalSessionIdentityID( + senderID := CanonicalSessionIdentityID( inbound.Channel, inbound.SenderID, input.SessionPolicy.IdentityLinks, @@ -107,11 +107,11 @@ func buildSessionScope(input AllocationInput) SessionScope { } func buildLegacySessionAliases(input AllocationInput) []string { - aliases := []string{strings.ToLower(routing.BuildAgentMainSessionKey(input.AgentID))} + aliases := []string{strings.ToLower(BuildLegacyMainAlias(input.AgentID))} inbound := input.Context if strings.EqualFold(strings.TrimSpace(inbound.ChatType), "direct") { - senderID := routing.CanonicalSessionIdentityID( + senderID := CanonicalSessionIdentityID( inbound.Channel, inbound.SenderID, input.SessionPolicy.IdentityLinks, @@ -119,20 +119,10 @@ func buildLegacySessionAliases(input AllocationInput) []string { if senderID == "" { return uniqueAliases(aliases) } - for _, dmScope := range []routing.DMScope{ - routing.DMScopePerPeer, - routing.DMScopePerChannelPeer, - routing.DMScopePerAccountChannelPeer, - } { - aliases = append(aliases, strings.ToLower(routing.BuildAgentPeerSessionKey(routing.SessionKeyParams{ - AgentID: input.AgentID, - Channel: inbound.Channel, - AccountID: inbound.Account, - Peer: &routing.RoutePeer{Kind: "direct", ID: senderID}, - DMScope: dmScope, - IdentityLinks: input.SessionPolicy.IdentityLinks, - }))) - } + aliases = append( + aliases, + BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, senderID)..., + ) return uniqueAliases(aliases) } @@ -143,15 +133,12 @@ func buildLegacySessionAliases(input AllocationInput) []string { if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" { peerID = peerID + "/" + topicID } - aliases = append(aliases, strings.ToLower(routing.BuildAgentPeerSessionKey(routing.SessionKeyParams{ - AgentID: input.AgentID, - Channel: inbound.Channel, - AccountID: inbound.Account, - Peer: &routing.RoutePeer{ - Kind: strings.ToLower(strings.TrimSpace(inbound.ChatType)), - ID: peerID, - }, - }))) + aliases = append(aliases, BuildLegacyPeerAlias( + input.AgentID, + inbound.Channel, + strings.ToLower(strings.TrimSpace(inbound.ChatType)), + peerID, + )) return uniqueAliases(aliases) } diff --git a/pkg/session/key.go b/pkg/session/key.go index 77dd115f5..6f1ee438f 100644 --- a/pkg/session/key.go +++ b/pkg/session/key.go @@ -5,9 +5,19 @@ import ( "encoding/hex" "fmt" "strings" + + "github.com/sipeed/picoclaw/pkg/routing" ) -const sessionKeyV1Prefix = "sk_v1_" +const ( + sessionKeyV1Prefix = "sk_v1_" + legacyAgentSessionKeyPrefix = "agent:" +) + +type ParsedLegacySessionKey struct { + AgentID string + Rest string +} // BuildOpaqueSessionKey returns a stable opaque session key derived from a // canonical alias string. The alias remains available through metadata for @@ -27,6 +37,129 @@ func IsOpaqueSessionKey(key string) bool { return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), sessionKeyV1Prefix) } +func IsLegacyAgentSessionKey(key string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), legacyAgentSessionKeyPrefix) +} + +func IsExplicitSessionKey(key string) bool { + return IsOpaqueSessionKey(key) || IsLegacyAgentSessionKey(key) +} + +func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return nil + } + parts := strings.SplitN(raw, ":", 3) + if len(parts) < 3 || parts[0] != "agent" { + return nil + } + agentID := strings.TrimSpace(parts[1]) + rest := parts[2] + if agentID == "" || rest == "" { + return nil + } + return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest} +} + +func BuildLegacyMainAlias(agentID string) string { + return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID)) +} + +// BuildMainSessionKey returns the canonical opaque main-session key for an +// agent. The corresponding legacy alias remains available via +// BuildLegacyMainAlias for compatibility and migration logic. +func BuildMainSessionKey(agentID string) string { + return BuildOpaqueSessionKey(BuildLegacyMainAlias(agentID)) +} + +func BuildLegacyDirectAliases(agentID, channel, account, peerID string) []string { + agentID = routing.NormalizeAgentID(agentID) + channel = normalizeLegacyChannel(channel) + account = routing.NormalizeAccountID(account) + peerID = strings.ToLower(strings.TrimSpace(peerID)) + if peerID == "" { + return nil + } + return []string{ + fmt.Sprintf("agent:%s:direct:%s", agentID, peerID), + fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID), + fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, account, peerID), + } +} + +func BuildLegacyPeerAlias(agentID, channel, peerKind, peerID string) string { + agentID = routing.NormalizeAgentID(agentID) + channel = normalizeLegacyChannel(channel) + peerKind = strings.ToLower(strings.TrimSpace(peerKind)) + if peerKind == "" { + peerKind = "unknown" + } + peerID = strings.ToLower(strings.TrimSpace(peerID)) + if peerID == "" { + peerID = "unknown" + } + return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) +} + +// CanonicalSessionIdentityID collapses an identity using identity_links when +// possible, then returns a normalized lowercase identifier. +func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string { + normalizedID := strings.TrimSpace(rawID) + if normalizedID == "" { + return "" + } + if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" { + normalizedID = linked + } + return strings.ToLower(normalizedID) +} + +func normalizeLegacyChannel(channel string) string { + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel == "" { + return "unknown" + } + return channel +} + +func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true + } + if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + candidates[rawCandidate[idx+1:]] = true + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} + // CanonicalScopeSignature returns a stable serialized representation of scope. func CanonicalScopeSignature(scope SessionScope) string { parts := []string{ diff --git a/pkg/session/key_test.go b/pkg/session/key_test.go new file mode 100644 index 000000000..ede38d468 --- /dev/null +++ b/pkg/session/key_test.go @@ -0,0 +1,72 @@ +package session + +import "testing" + +func TestIsExplicitSessionKey(t *testing.T) { + tests := []struct { + key string + want bool + }{ + {"sk_v1_abc", true}, + {"agent:main:direct:user123", true}, + {"custom-key", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsExplicitSessionKey(tt.key); got != tt.want { + t.Fatalf("IsExplicitSessionKey(%q) = %v, want %v", tt.key, got, tt.want) + } + } +} + +func TestParseLegacyAgentSessionKey(t *testing.T) { + parsed := ParseLegacyAgentSessionKey("agent:sales:telegram:direct:user123") + if parsed == nil { + t.Fatal("expected parsed legacy key, got nil") + } + if parsed.AgentID != "sales" { + t.Fatalf("AgentID = %q, want sales", parsed.AgentID) + } + if parsed.Rest != "telegram:direct:user123" { + t.Fatalf("Rest = %q, want telegram:direct:user123", parsed.Rest) + } + + if got := ParseLegacyAgentSessionKey("sk_v1_abc"); got != nil { + t.Fatalf("expected nil for opaque key, got %+v", got) + } +} + +func TestBuildLegacyDirectAliases(t *testing.T) { + aliases := BuildLegacyDirectAliases("Main", "Telegram", "BotA", "User123") + want := []string{ + "agent:main:direct:user123", + "agent:main:telegram:direct:user123", + "agent:main:telegram:bota:direct:user123", + } + if len(aliases) != len(want) { + t.Fatalf("len(aliases) = %d, want %d", len(aliases), len(want)) + } + for i := range want { + if aliases[i] != want[i] { + t.Fatalf("aliases[%d] = %q, want %q", i, aliases[i], want[i]) + } + } +} + +func TestBuildLegacyPeerAlias(t *testing.T) { + got := BuildLegacyPeerAlias("Main", "Slack", "channel", "C001") + if got != "agent:main:slack:channel:c001" { + t.Fatalf("BuildLegacyPeerAlias() = %q", got) + } +} + +func TestBuildMainSessionKey(t *testing.T) { + got := BuildMainSessionKey("Main") + if !IsOpaqueSessionKey(got) { + t.Fatalf("BuildMainSessionKey() = %q, want opaque key", got) + } + if got != BuildOpaqueSessionKey("agent:main:main") { + t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got) + } +} diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index c6ac3a129..30a8e92cd 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -311,8 +311,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, ""), Content: output, }) return "ok" @@ -335,8 +334,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, ""), Content: output, }) return "ok"