mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
refactor(runtime): drop non-session legacy context compatibility
This commit is contained in:
@@ -610,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
||||
if payload.SourceTool != "async_followup" {
|
||||
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
|
||||
}
|
||||
if payload.Channel != "cli" {
|
||||
t.Fatalf("expected channel cli, got %q", payload.Channel)
|
||||
}
|
||||
if payload.ChatID != "direct" {
|
||||
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
|
||||
}
|
||||
if payload.ContentLen != len("background result") {
|
||||
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
||||
}
|
||||
|
||||
@@ -116,8 +116,6 @@ const (
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
type TurnStartPayload struct {
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
MediaCount int
|
||||
}
|
||||
@@ -217,8 +215,6 @@ type SteeringInjectedPayload struct {
|
||||
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
|
||||
type FollowUpQueuedPayload struct {
|
||||
SourceTool string
|
||||
Channel string
|
||||
ChatID string
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
|
||||
@@ -94,8 +94,6 @@ type LLMHookRequest struct {
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
||||
}
|
||||
|
||||
@@ -117,8 +115,6 @@ type LLMHookResponse struct {
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
@@ -137,8 +133,6 @@ type ToolCallHookRequest struct {
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
@@ -157,8 +151,6 @@ type ToolApprovalRequest struct {
|
||||
Context *TurnContext `json:"context,omitempty"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
@@ -179,8 +171,6 @@ type ToolResultHookResponse struct {
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
|
||||
+32
-151
@@ -107,14 +107,6 @@ const (
|
||||
defaultResponse = "The model returned an empty response. This may indicate a provider error or token limit."
|
||||
toolLimitResponse = "I've reached `max_tool_iterations` without a final response. Increase `max_tool_iterations` in config.json if this task needs more tool steps."
|
||||
handledToolResponseSummary = "Requested output delivered via tool attachment."
|
||||
sessionKeyAgentPrefix = "agent:"
|
||||
sessionKeyOpaquePrefix = "sk_"
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
metadataKeyReplyToMessage = "reply_to_message_id"
|
||||
metadataKeyParentPeerKind = "parent_peer_kind"
|
||||
metadataKeyParentPeerID = "parent_peer_id"
|
||||
)
|
||||
|
||||
func NewAgentLoop(
|
||||
@@ -234,9 +226,9 @@ func registerSharedTools(
|
||||
messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID)
|
||||
return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: outboundCtx,
|
||||
Content: content,
|
||||
ReplyToMessageID: replyToMessageID,
|
||||
})
|
||||
@@ -657,8 +649,7 @@ func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatI
|
||||
}
|
||||
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channel, chatID, ""),
|
||||
Content: response,
|
||||
})
|
||||
logger.InfoCF("agent", "Published outbound response",
|
||||
@@ -714,11 +705,7 @@ func outboundContextFromInbound(
|
||||
channel, chatID, replyToMessageID string,
|
||||
) bus.InboundContext {
|
||||
if inbound == nil {
|
||||
return bus.ContextFromLegacyOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
ReplyToMessageID: replyToMessageID,
|
||||
})
|
||||
return bus.NewOutboundContext(channel, chatID, replyToMessageID)
|
||||
}
|
||||
|
||||
outboundCtx := *cloneInboundContext(inbound)
|
||||
@@ -736,8 +723,6 @@ func outboundContextFromInbound(
|
||||
|
||||
func outboundMessageForTurn(ts *turnState, content string) bus.OutboundMessage {
|
||||
return bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Context: outboundContextFromInbound(
|
||||
ts.opts.InboundContext,
|
||||
ts.channel,
|
||||
@@ -894,8 +879,6 @@ func (al *AgentLoop) logEvent(evt Event) {
|
||||
|
||||
switch payload := evt.Payload.(type) {
|
||||
case TurnStartPayload:
|
||||
fields["channel"] = payload.Channel
|
||||
fields["chat_id"] = payload.ChatID
|
||||
fields["user_len"] = len(payload.UserMessage)
|
||||
fields["media_count"] = payload.MediaCount
|
||||
case TurnEndPayload:
|
||||
@@ -948,8 +931,6 @@ func (al *AgentLoop) logEvent(evt Event) {
|
||||
fields["total_content_len"] = payload.TotalContentLen
|
||||
case FollowUpQueuedPayload:
|
||||
fields["source_tool"] = payload.SourceTool
|
||||
fields["channel"] = payload.Channel
|
||||
fields["chat_id"] = payload.ChatID
|
||||
fields["content_len"] = payload.ContentLen
|
||||
case InterruptReceivedPayload:
|
||||
fields["interrupt_kind"] = payload.Kind
|
||||
@@ -1292,8 +1273,7 @@ func (al *AgentLoop) sendTranscriptionFeedback(
|
||||
}
|
||||
|
||||
err := al.channelManager.SendMessage(ctx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channel, chatID, messageID),
|
||||
Content: feedbackMsg,
|
||||
ReplyToMessageID: messageID,
|
||||
})
|
||||
@@ -1369,13 +1349,15 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
}
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: channel,
|
||||
SenderID: "cron",
|
||||
ChatID: chatID,
|
||||
Context: bus.InboundContext{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: "cron",
|
||||
},
|
||||
Content: content,
|
||||
SessionKey: sessionKey,
|
||||
}
|
||||
msg.Context = bus.ContextFromLegacyInbound(msg)
|
||||
|
||||
return al.processMessage(ctx, msg)
|
||||
}
|
||||
@@ -1481,7 +1463,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
MessageID: msg.MessageID,
|
||||
ReplyToMessageID: inboundMetadata(msg, metadataKeyReplyToMessage),
|
||||
ReplyToMessageID: msg.Context.ReplyToMessageID,
|
||||
SenderID: msg.SenderID,
|
||||
SenderDisplayName: msg.Sender.DisplayName,
|
||||
UserMessage: msg.Content,
|
||||
@@ -1515,18 +1497,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) {
|
||||
registry := al.GetRegistry()
|
||||
inboundCtx := normalizedInboundContext(msg)
|
||||
channel := strings.TrimSpace(inboundCtx.Channel)
|
||||
if channel == "" {
|
||||
channel = msg.Channel
|
||||
}
|
||||
route := registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: channel,
|
||||
AccountID: routeAccountID(msg),
|
||||
Peer: extractPeer(msg),
|
||||
ParentPeer: extractParentPeer(msg),
|
||||
GuildID: routeGuildID(msg),
|
||||
TeamID: routeTeamID(msg),
|
||||
})
|
||||
route := registry.ResolveRoute(inboundCtx)
|
||||
|
||||
agent, ok := registry.GetAgent(route.AgentID)
|
||||
if !ok {
|
||||
@@ -1551,8 +1522,7 @@ func resolveScopeKey(routeSessionKey, msgSessionKey string) string {
|
||||
}
|
||||
|
||||
func isExplicitSessionKey(sessionKey string) bool {
|
||||
sessionKey = strings.TrimSpace(strings.ToLower(sessionKey))
|
||||
return strings.HasPrefix(sessionKey, sessionKeyAgentPrefix) || strings.HasPrefix(sessionKey, sessionKeyOpaquePrefix)
|
||||
return session.IsExplicitSessionKey(sessionKey)
|
||||
}
|
||||
|
||||
func buildSessionAliases(canonicalKey string, keys ...string) []string {
|
||||
@@ -1621,8 +1591,7 @@ func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error {
|
||||
pubCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
return al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Context: msg.Context,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
@@ -1679,7 +1648,7 @@ func (al *AgentLoop) processSystemMessage(
|
||||
}
|
||||
|
||||
// Use the origin session for context
|
||||
sessionKey := routing.BuildAgentMainSessionKey(agent.ID)
|
||||
sessionKey := session.BuildMainSessionKey(agent.ID)
|
||||
|
||||
return al.runAgentLoop(ctx, agent, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
@@ -1739,8 +1708,6 @@ func (al *AgentLoop) runAgentLoop(
|
||||
|
||||
if opts.SendResponse && result.finalContent != "" {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Context: outboundContextFromInbound(
|
||||
opts.InboundContext,
|
||||
opts.Channel,
|
||||
@@ -1796,8 +1763,7 @@ func (al *AgentLoop) handleReasoning(
|
||||
defer pubCancel()
|
||||
|
||||
if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channelName,
|
||||
ChatID: channelID,
|
||||
Context: bus.NewOutboundContext(channelName, channelID, ""),
|
||||
Content: reasoningContent,
|
||||
}); err != nil {
|
||||
// Treat context.DeadlineExceeded / context.Canceled as expected
|
||||
@@ -1851,8 +1817,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
EventKindTurnStart,
|
||||
ts.eventMeta("runTurn", "turn.start"),
|
||||
TurnStartPayload{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
UserMessage: ts.userMessage,
|
||||
MediaCount: len(ts.media),
|
||||
},
|
||||
@@ -2085,8 +2049,6 @@ turnLoop:
|
||||
Messages: callMessages,
|
||||
Tools: providerToolDefs,
|
||||
Options: llmOpts,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
GracefulTerminal: gracefulTerminal,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
@@ -2314,8 +2276,6 @@ turnLoop:
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Model: llmModel,
|
||||
Response: response,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
@@ -2346,7 +2306,7 @@ turnLoop:
|
||||
reasoningContent = response.ReasoningContent
|
||||
}
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
ctx,
|
||||
reasoningContent,
|
||||
ts.channel,
|
||||
al.targetReasoningChannelID(ts.channel),
|
||||
@@ -2467,8 +2427,6 @@ turnLoop:
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
@@ -2514,8 +2472,6 @@ turnLoop:
|
||||
Context: cloneTurnContext(ts.turnCtx),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
if !approval.Approved {
|
||||
allResponsesHandled = false
|
||||
@@ -2605,8 +2561,6 @@ turnLoop:
|
||||
ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
|
||||
FollowUpQueuedPayload{
|
||||
SourceTool: asyncToolName,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
ContentLen: len(content),
|
||||
},
|
||||
)
|
||||
@@ -2614,10 +2568,13 @@ turnLoop:
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", asyncToolName),
|
||||
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
|
||||
Content: content,
|
||||
Context: bus.InboundContext{
|
||||
Channel: "system",
|
||||
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
|
||||
ChatType: "direct",
|
||||
SenderID: fmt.Sprintf("async:%s", asyncToolName),
|
||||
},
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2652,8 +2609,6 @@ turnLoop:
|
||||
Arguments: toolArgs,
|
||||
Result: toolResult,
|
||||
Duration: toolDuration,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
@@ -2692,9 +2647,13 @@ turnLoop:
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
Context: outboundContextFromInbound(
|
||||
ts.opts.InboundContext,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
ts.opts.ReplyToMessageID,
|
||||
),
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
@@ -3758,84 +3717,6 @@ func mapCommandError(result commands.ExecuteResult) string {
|
||||
return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err)
|
||||
}
|
||||
|
||||
// extractPeer extracts the routing peer from the inbound message's structured Peer field.
|
||||
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
if msg.Peer.Kind != "" {
|
||||
peerID := msg.Peer.ID
|
||||
if peerID == "" {
|
||||
if msg.Peer.Kind == "direct" {
|
||||
peerID = msg.SenderID
|
||||
} else {
|
||||
peerID = msg.ChatID
|
||||
}
|
||||
}
|
||||
return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID}
|
||||
}
|
||||
|
||||
inboundCtx := normalizedInboundContext(msg)
|
||||
peerKind := strings.TrimSpace(inboundCtx.ChatType)
|
||||
if peerKind == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerID := strings.TrimSpace(inboundCtx.ChatID)
|
||||
if peerKind == "direct" && peerID == "" {
|
||||
peerID = strings.TrimSpace(inboundCtx.SenderID)
|
||||
}
|
||||
if peerID == "" {
|
||||
return nil
|
||||
}
|
||||
return &routing.RoutePeer{Kind: peerKind, ID: peerID}
|
||||
}
|
||||
|
||||
func inboundMetadata(msg bus.InboundMessage, key string) string {
|
||||
if msg.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
return msg.Metadata[key]
|
||||
}
|
||||
|
||||
// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata.
|
||||
func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer {
|
||||
inboundCtx := normalizedInboundContext(msg)
|
||||
if topicID := strings.TrimSpace(inboundCtx.TopicID); topicID != "" {
|
||||
return &routing.RoutePeer{Kind: "topic", ID: topicID}
|
||||
}
|
||||
|
||||
parentKind := inboundMetadata(msg, metadataKeyParentPeerKind)
|
||||
parentID := inboundMetadata(msg, metadataKeyParentPeerID)
|
||||
if parentKind == "" || parentID == "" {
|
||||
return nil
|
||||
}
|
||||
return &routing.RoutePeer{Kind: parentKind, ID: parentID}
|
||||
}
|
||||
|
||||
func routeAccountID(msg bus.InboundMessage) string {
|
||||
if accountID := strings.TrimSpace(normalizedInboundContext(msg).Account); accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
return inboundMetadata(msg, metadataKeyAccountID)
|
||||
}
|
||||
|
||||
func routeGuildID(msg bus.InboundMessage) string {
|
||||
inboundCtx := normalizedInboundContext(msg)
|
||||
if strings.EqualFold(strings.TrimSpace(inboundCtx.SpaceType), "guild") {
|
||||
return strings.TrimSpace(inboundCtx.SpaceID)
|
||||
}
|
||||
return inboundMetadata(msg, metadataKeyGuildID)
|
||||
}
|
||||
|
||||
func routeTeamID(msg bus.InboundMessage) string {
|
||||
inboundCtx := normalizedInboundContext(msg)
|
||||
switch strings.ToLower(strings.TrimSpace(inboundCtx.SpaceType)) {
|
||||
case "team", "workspace":
|
||||
if spaceID := strings.TrimSpace(inboundCtx.SpaceID); spaceID != "" {
|
||||
return spaceID
|
||||
}
|
||||
}
|
||||
return inboundMetadata(msg, metadataKeyTeamID)
|
||||
}
|
||||
|
||||
// isNativeSearchProvider reports whether the given LLM provider implements
|
||||
// NativeSearchCapable and returns true for SupportsNativeSearch.
|
||||
func isNativeSearchProvider(p providers.LLMProvider) bool {
|
||||
|
||||
+101
-142
@@ -140,7 +140,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "discord",
|
||||
SenderID: "discord:123",
|
||||
Sender: bus.SenderInfo{
|
||||
@@ -148,7 +148,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) {
|
||||
},
|
||||
ChatID: "group-1",
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -199,12 +199,12 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) {
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/use shell explain how to list files",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -289,12 +289,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
||||
provider := &recordingProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "/use shell",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() arm error = %v", err)
|
||||
}
|
||||
@@ -302,12 +302,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) {
|
||||
t.Fatalf("arm response = %q, want armed confirmation", response)
|
||||
}
|
||||
|
||||
response, err = al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err = al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "telegram:123",
|
||||
ChatID: "chat-1",
|
||||
Content: "explain how to list files",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() follow-up error = %v", err)
|
||||
}
|
||||
@@ -620,12 +620,12 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
path: imagePath,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -662,21 +662,21 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing.
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
route, _, err := al.resolveMessageRoute(bus.InboundMessage{
|
||||
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
sessionKey := resolveScopeKey(al.allocateRouteSession(route, bus.InboundMessage{
|
||||
sessionKey := resolveScopeKey(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
}).SessionKey, "")
|
||||
})).SessionKey, "")
|
||||
history := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if len(history) == 0 {
|
||||
t.Fatal("expected session history to be saved")
|
||||
@@ -720,12 +720,12 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
|
||||
loop: al,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -740,41 +740,6 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPeer_UsesInboundContextWhenLegacyPeerMissing(t *testing.T) {
|
||||
msg := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C001",
|
||||
ChatType: "channel",
|
||||
SenderID: "U001",
|
||||
},
|
||||
}
|
||||
|
||||
peer := extractPeer(msg)
|
||||
if peer == nil {
|
||||
t.Fatal("expected peer from inbound context")
|
||||
}
|
||||
if peer.Kind != "channel" || peer.ID != "C001" {
|
||||
t.Fatalf("peer = %+v, want channel/C001", peer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractParentPeer_UsesInboundContextTopicID(t *testing.T) {
|
||||
msg := bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
TopicID: "thread-42",
|
||||
},
|
||||
}
|
||||
|
||||
parentPeer := extractParentPeer(msg)
|
||||
if parentPeer == nil {
|
||||
t.Fatal("expected parent peer from topic context")
|
||||
}
|
||||
if parentPeer.Kind != "topic" || parentPeer.ID != "thread-42" {
|
||||
t.Fatalf("parent peer = %+v, want topic/thread-42", parentPeer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendEventContextFields_IncludesInboundRouteAndScope(t *testing.T) {
|
||||
fields := map[string]any{}
|
||||
|
||||
@@ -872,7 +837,7 @@ func TestResolveMessageRoute_UsesInboundContextAccountAndSpace(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"})
|
||||
|
||||
route, _, err := al.resolveMessageRoute(bus.InboundMessage{
|
||||
route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{
|
||||
Context: bus.InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
@@ -883,7 +848,7 @@ func TestResolveMessageRoute_UsesInboundContextAccountAndSpace(t *testing.T) {
|
||||
SpaceType: "workspace",
|
||||
},
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("resolveMessageRoute() error = %v", err)
|
||||
}
|
||||
@@ -926,12 +891,12 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) {
|
||||
path: imagePath,
|
||||
})
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "take a screenshot of the screen and send it to me",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -1518,13 +1483,39 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
response, err := h.al.processMessage(timeoutCtx, msg)
|
||||
response, err := h.al.processMessage(timeoutCtx, testInboundMessage(msg))
|
||||
if err != nil {
|
||||
tb.Fatalf("processMessage failed: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func testInboundMessage(msg bus.InboundMessage) bus.InboundMessage {
|
||||
if msg.Context.Channel == "" &&
|
||||
msg.Context.Account == "" &&
|
||||
msg.Context.ChatID == "" &&
|
||||
msg.Context.ChatType == "" &&
|
||||
msg.Context.TopicID == "" &&
|
||||
msg.Context.SpaceID == "" &&
|
||||
msg.Context.SpaceType == "" &&
|
||||
msg.Context.SenderID == "" &&
|
||||
msg.Context.MessageID == "" &&
|
||||
!msg.Context.Mentioned &&
|
||||
msg.Context.ReplyToMessageID == "" &&
|
||||
msg.Context.ReplyToSenderID == "" &&
|
||||
len(msg.Context.ReplyHandles) == 0 &&
|
||||
len(msg.Context.Raw) == 0 {
|
||||
msg.Context = bus.InboundContext{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
ChatType: "direct",
|
||||
SenderID: msg.SenderID,
|
||||
MessageID: msg.MessageID,
|
||||
}
|
||||
}
|
||||
return bus.NormalizeInboundMessage(msg)
|
||||
}
|
||||
|
||||
const responseTimeout = 3 * time.Second
|
||||
|
||||
func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
||||
@@ -1550,20 +1541,16 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: msg.Channel,
|
||||
Peer: extractPeer(msg),
|
||||
})
|
||||
route := al.registry.ResolveRoute(bus.NormalizeInboundMessage(msg).Context)
|
||||
sessionKey := al.allocateRouteSession(route, msg).SessionKey
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
@@ -1610,21 +1597,22 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
helper := testHelper{al: al}
|
||||
|
||||
baseMsg := bus.InboundMessage{
|
||||
Channel: "whatsapp",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/show channel",
|
||||
Peer: baseMsg.Peer,
|
||||
Context: bus.InboundContext{
|
||||
Channel: baseMsg.Context.Channel,
|
||||
ChatID: baseMsg.Context.ChatID,
|
||||
ChatType: baseMsg.Context.ChatType,
|
||||
SenderID: baseMsg.Context.SenderID,
|
||||
},
|
||||
Content: "/show channel",
|
||||
})
|
||||
if showResp != "Current Channel: whatsapp" {
|
||||
t.Fatalf("unexpected /show reply: %q", showResp)
|
||||
@@ -1634,11 +1622,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
}
|
||||
|
||||
fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/foo",
|
||||
Peer: baseMsg.Peer,
|
||||
Context: bus.InboundContext{
|
||||
Channel: baseMsg.Context.Channel,
|
||||
ChatID: baseMsg.Context.ChatID,
|
||||
ChatType: baseMsg.Context.ChatType,
|
||||
SenderID: baseMsg.Context.SenderID,
|
||||
},
|
||||
Content: "/foo",
|
||||
})
|
||||
if fooResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /foo reply: %q", fooResp)
|
||||
@@ -1648,11 +1638,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) {
|
||||
}
|
||||
|
||||
newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: baseMsg.Channel,
|
||||
SenderID: baseMsg.SenderID,
|
||||
ChatID: baseMsg.ChatID,
|
||||
Content: "/new",
|
||||
Peer: baseMsg.Peer,
|
||||
Context: bus.InboundContext{
|
||||
Channel: baseMsg.Context.Channel,
|
||||
ChatID: baseMsg.Context.ChatID,
|
||||
ChatType: baseMsg.Context.ChatType,
|
||||
SenderID: baseMsg.Context.SenderID,
|
||||
},
|
||||
Content: "/new",
|
||||
})
|
||||
if newResp != "LLM reply" {
|
||||
t.Fatalf("unexpected /new reply: %q", newResp)
|
||||
@@ -1705,10 +1697,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
@@ -1719,10 +1707,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") {
|
||||
t.Fatalf("unexpected /show model reply after switch: %q", showResp)
|
||||
@@ -1770,10 +1754,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to missing",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if switchResp != `model "missing" not found in model_list or providers` {
|
||||
t.Fatalf("unexpected /switch error reply: %q", switchResp)
|
||||
@@ -1784,10 +1764,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/show model",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(showResp, "Current Model: local (Provider: openai)") {
|
||||
t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp)
|
||||
@@ -1854,10 +1830,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello before switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if firstResp != "local reply" {
|
||||
t.Fatalf("unexpected response before switch: %q", firstResp)
|
||||
@@ -1877,10 +1849,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "/switch model to deepseek",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if !strings.Contains(switchResp, "Switched model from local to deepseek") {
|
||||
t.Fatalf("unexpected /switch reply: %q", switchResp)
|
||||
@@ -1891,10 +1859,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello after switch",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if secondResp != "remote reply" {
|
||||
t.Fatalf("unexpected response after switch: %q", secondResp)
|
||||
@@ -1984,10 +1948,6 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
})
|
||||
if resp != "light reply" {
|
||||
t.Fatalf("response = %q, want %q", resp, "light reply")
|
||||
@@ -2260,22 +2220,16 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) {
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("No default agent found")
|
||||
}
|
||||
route := al.registry.ResolveRoute(routing.RouteInput{
|
||||
Channel: "test",
|
||||
Peer: &routing.RoutePeer{
|
||||
Kind: "direct",
|
||||
ID: "cron",
|
||||
},
|
||||
route := al.registry.ResolveRoute(bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatType: "direct",
|
||||
SenderID: "cron",
|
||||
})
|
||||
history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, bus.InboundMessage{
|
||||
history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "cron",
|
||||
ChatID: "chat1",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "cron",
|
||||
},
|
||||
}).SessionKey)
|
||||
})).SessionKey)
|
||||
if len(history) != 4 {
|
||||
t.Fatalf("history len = %d, want 4", len(history))
|
||||
}
|
||||
@@ -2533,8 +2487,7 @@ func TestHandleReasoning(t *testing.T) {
|
||||
for i := 0; ; i++ {
|
||||
fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{
|
||||
Channel: "filler",
|
||||
ChatID: "filler",
|
||||
Context: bus.NewOutboundContext("filler", "filler", ""),
|
||||
Content: fmt.Sprintf("filler-%d", i),
|
||||
})
|
||||
fillCancel()
|
||||
@@ -2608,12 +2561,12 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
|
||||
chManager.RegisterChannel("telegram", &fakeChannel{id: "reason-chat"})
|
||||
al.SetChannelManager(chManager)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -2629,6 +2582,9 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T
|
||||
if outbound.ChatID != "reason-chat" {
|
||||
t.Fatalf("reasoning chatID = %q, want %q", outbound.ChatID, "reason-chat")
|
||||
}
|
||||
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "reason-chat" {
|
||||
t.Fatalf("unexpected reasoning context: %+v", outbound.Context)
|
||||
}
|
||||
if outbound.Content != "thinking trace" {
|
||||
t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace")
|
||||
}
|
||||
@@ -2714,12 +2670,12 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
provider := &toolFeedbackProvider{filePath: heartbeatFile}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user-1",
|
||||
ChatID: "chat-1",
|
||||
Content: "check tool feedback",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -2735,6 +2691,9 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) {
|
||||
if outbound.ChatID != "chat-1" {
|
||||
t.Fatalf("tool feedback chatID = %q, want %q", outbound.ChatID, "chat-1")
|
||||
}
|
||||
if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" {
|
||||
t.Fatalf("unexpected tool feedback context: %+v", outbound.Context)
|
||||
}
|
||||
if !strings.Contains(outbound.Content, "`read_file`") {
|
||||
t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content)
|
||||
}
|
||||
@@ -3157,13 +3116,13 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) {
|
||||
agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"})
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
SessionKey: "test-session",
|
||||
Content: "trigger recovery",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
@@ -3199,12 +3158,12 @@ func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) {
|
||||
return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil
|
||||
}
|
||||
|
||||
response, err := al.processMessage(context.Background(), bus.InboundMessage{
|
||||
response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
SenderID: "user1",
|
||||
Content: "hello",
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("processMessage() error = %v", err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
@@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) {
|
||||
return agent, ok
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message.
|
||||
func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(input)
|
||||
// ResolveRoute determines which agent handles the normalized inbound context.
|
||||
func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute {
|
||||
return r.resolver.ResolveRoute(inbound)
|
||||
}
|
||||
|
||||
// ListAgentIDs returns all registered agent IDs.
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
@@ -332,7 +331,7 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
|
||||
return agent
|
||||
}
|
||||
|
||||
if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
|
||||
if parsed := session.ParseLegacyAgentSessionKey(sessionKey); parsed != nil {
|
||||
if agent, ok := registry.GetAgent(parsed.AgentID); ok {
|
||||
return agent
|
||||
}
|
||||
|
||||
+29
-33
@@ -366,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
|
||||
|
||||
activeMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "active turn",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "active turn",
|
||||
}
|
||||
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
|
||||
if !ok {
|
||||
@@ -381,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
otherMsg := bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user2",
|
||||
ChatID: "chat2",
|
||||
Content: "other session",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user2",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "chat2",
|
||||
ChatType: "direct",
|
||||
SenderID: "user2",
|
||||
},
|
||||
Content: "other session",
|
||||
}
|
||||
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
|
||||
if !ok {
|
||||
@@ -425,7 +423,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
|
||||
case <-ctx.Done():
|
||||
t.Fatalf("timeout waiting for requeued message on outbound bus")
|
||||
case requeued := <-msgBus.OutboundChan():
|
||||
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
|
||||
if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID ||
|
||||
requeued.Content != otherMsg.Content {
|
||||
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
|
||||
}
|
||||
@@ -842,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "first message",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "first message",
|
||||
}
|
||||
late := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "late append",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
Context: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "late append",
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
@@ -950,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
provider := &blockingDirectProvider{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
@@ -1117,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.SetMediaStore(store)
|
||||
@@ -1225,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
@@ -1379,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
|
||||
@@ -12,6 +12,12 @@ import (
|
||||
// ErrBusClosed is returned when publishing to a closed MessageBus.
|
||||
var ErrBusClosed = errors.New("message bus closed")
|
||||
|
||||
var (
|
||||
ErrMissingInboundContext = errors.New("inbound message context is required")
|
||||
ErrMissingOutboundContext = errors.New("outbound message context is required")
|
||||
ErrMissingOutboundMediaContext = errors.New("outbound media context is required")
|
||||
)
|
||||
|
||||
const defaultBusBufferSize = 64
|
||||
|
||||
// StreamDelegate is implemented by the channel Manager to provide streaming
|
||||
@@ -80,6 +86,9 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error {
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingInboundContext
|
||||
}
|
||||
msg = NormalizeInboundMessage(msg)
|
||||
return publish(ctx, mb, mb.inbound, msg)
|
||||
}
|
||||
@@ -89,6 +98,9 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error {
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingOutboundContext
|
||||
}
|
||||
msg = NormalizeOutboundMessage(msg)
|
||||
return publish(ctx, mb, mb.outbound, msg)
|
||||
}
|
||||
@@ -98,6 +110,9 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error {
|
||||
if msg.Context.isZero() {
|
||||
return ErrMissingOutboundMediaContext
|
||||
}
|
||||
msg = NormalizeOutboundMediaMessage(msg)
|
||||
return publish(ctx, mb, mb.outboundMedia, msg)
|
||||
}
|
||||
|
||||
+123
-51
@@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hello",
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(ctx, msg); err != nil {
|
||||
@@ -45,25 +48,25 @@ func TestPublishConsume(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_NormalizesLegacyFieldsIntoContext(t *testing.T) {
|
||||
func TestPublishInbound_NormalizesContext(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
msg := InboundMessage{
|
||||
Channel: "slack",
|
||||
SenderID: "U123",
|
||||
ChatID: "C456/1712",
|
||||
Content: "hello",
|
||||
MessageID: "1712.01",
|
||||
Peer: Peer{Kind: "group", ID: "C456"},
|
||||
Metadata: map[string]string{
|
||||
"account_id": "workspace-a",
|
||||
"team_id": "T001",
|
||||
"reply_to_message_id": "1700.01",
|
||||
"is_mentioned": "true",
|
||||
"parent_peer_kind": "topic",
|
||||
"parent_peer_id": "1712",
|
||||
Context: InboundContext{
|
||||
Channel: "slack",
|
||||
Account: "workspace-a",
|
||||
ChatID: "C456/1712",
|
||||
ChatType: "group",
|
||||
TopicID: "1712",
|
||||
SpaceID: "T001",
|
||||
SpaceType: "team",
|
||||
SenderID: "U123",
|
||||
MessageID: "1712.01",
|
||||
ReplyToMessageID: "1700.01",
|
||||
Mentioned: true,
|
||||
},
|
||||
Content: "hello",
|
||||
}
|
||||
|
||||
if err := mb.PublishInbound(context.Background(), msg); err != nil {
|
||||
@@ -94,7 +97,7 @@ func TestPublishInbound_NormalizesLegacyFieldsIntoContext(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_MirrorsContextIntoLegacyFields(t *testing.T) {
|
||||
func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
|
||||
@@ -132,27 +135,8 @@ func TestPublishInbound_MirrorsContextIntoLegacyFields(t *testing.T) {
|
||||
if got.MessageID != "777" {
|
||||
t.Fatalf("expected legacy message ID 777, got %q", got.MessageID)
|
||||
}
|
||||
if got.Peer.Kind != "group" || got.Peer.ID != "-1001" {
|
||||
t.Fatalf("expected legacy peer group/-1001, got %q/%q", got.Peer.Kind, got.Peer.ID)
|
||||
}
|
||||
if got.Metadata["account_id"] != "bot-a" {
|
||||
t.Fatalf("expected mirrored account_id bot-a, got %q", got.Metadata["account_id"])
|
||||
}
|
||||
if got.Metadata["guild_id"] != "guild-9" {
|
||||
t.Fatalf("expected mirrored guild_id guild-9, got %q", got.Metadata["guild_id"])
|
||||
}
|
||||
if got.Metadata["parent_peer_kind"] != "topic" || got.Metadata["parent_peer_id"] != "42" {
|
||||
t.Fatalf(
|
||||
"expected mirrored topic parent peer, got %q/%q",
|
||||
got.Metadata["parent_peer_kind"],
|
||||
got.Metadata["parent_peer_id"],
|
||||
)
|
||||
}
|
||||
if got.Metadata["reply_to_message_id"] != "666" {
|
||||
t.Fatalf("expected mirrored reply_to_message_id 666, got %q", got.Metadata["reply_to_message_id"])
|
||||
}
|
||||
if got.Metadata["is_mentioned"] != "true" {
|
||||
t.Fatalf("expected mirrored is_mentioned true, got %q", got.Metadata["is_mentioned"])
|
||||
if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" {
|
||||
t.Fatalf("unexpected normalized context: %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,8 +147,10 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
msg := OutboundMessage{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
Context: InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatID: "123",
|
||||
},
|
||||
Content: "world",
|
||||
}
|
||||
|
||||
@@ -179,6 +165,9 @@ func TestPublishOutboundSubscribe(t *testing.T) {
|
||||
if got.Content != "world" {
|
||||
t.Fatalf("expected content 'world', got %q", got.Content)
|
||||
}
|
||||
if got.Context.Channel != "telegram" || got.Context.ChatID != "123" {
|
||||
t.Fatalf("expected normalized outbound context, got %+v", got.Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) {
|
||||
@@ -241,6 +230,19 @@ func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) {
|
||||
ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ")
|
||||
if ctx.Channel != "telegram" {
|
||||
t.Fatalf("expected channel telegram, got %q", ctx.Channel)
|
||||
}
|
||||
if ctx.ChatID != "chat-42" {
|
||||
t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID)
|
||||
}
|
||||
if ctx.ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
defer mb.Close()
|
||||
@@ -248,7 +250,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
// Fill the buffer
|
||||
ctx := context.Background()
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -257,7 +267,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) {
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"})
|
||||
err := mb.PublishInbound(cancelCtx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-overflow",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-overflow",
|
||||
},
|
||||
Content: "overflow",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from canceled context, got nil")
|
||||
}
|
||||
@@ -270,7 +288,15 @@ func TestPublishInbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
@@ -280,7 +306,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) {
|
||||
mb := NewMessageBus()
|
||||
mb.Close()
|
||||
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"})
|
||||
err := mb.PublishOutbound(context.Background(), OutboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed, got %v", err)
|
||||
}
|
||||
@@ -292,14 +324,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) {
|
||||
defer mb.Close()
|
||||
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"})
|
||||
mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-cancel",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-cancel",
|
||||
},
|
||||
Content: "ContextCancel",
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -393,7 +441,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
|
||||
// Fill the buffer
|
||||
for i := range defaultBusBufferSize {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil {
|
||||
if err := mb.PublishInbound(ctx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-fill",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-fill",
|
||||
},
|
||||
Content: "fill",
|
||||
}); err != nil {
|
||||
t.Fatalf("fill failed at %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
@@ -402,7 +458,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) {
|
||||
timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"})
|
||||
err := mb.PublishInbound(timeoutCtx, InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-overflow",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-overflow",
|
||||
},
|
||||
Content: "overflow",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when buffer is full and context times out")
|
||||
}
|
||||
@@ -420,7 +484,15 @@ func TestCloseIdempotent(t *testing.T) {
|
||||
mb.Close()
|
||||
|
||||
// After close, publish should return ErrBusClosed
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"})
|
||||
err := mb.PublishInbound(context.Background(), InboundMessage{
|
||||
Context: InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
},
|
||||
Content: "test",
|
||||
})
|
||||
if err != ErrBusClosed {
|
||||
t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err)
|
||||
}
|
||||
|
||||
+13
-203
@@ -2,92 +2,19 @@ package bus
|
||||
|
||||
import "strings"
|
||||
|
||||
const (
|
||||
metadataKeyAccountID = "account_id"
|
||||
metadataKeyGuildID = "guild_id"
|
||||
metadataKeyTeamID = "team_id"
|
||||
metadataKeyReplyToMessage = "reply_to_message_id"
|
||||
metadataKeyReplyToSender = "reply_to_sender_id"
|
||||
metadataKeyParentPeerKind = "parent_peer_kind"
|
||||
metadataKeyParentPeerID = "parent_peer_id"
|
||||
metadataKeyIsMentioned = "is_mentioned"
|
||||
)
|
||||
|
||||
// ContextFromLegacyInbound builds a normalized inbound context from the legacy
|
||||
// top-level fields on InboundMessage. This keeps older producers working while
|
||||
// new producers migrate to writing Context directly.
|
||||
func ContextFromLegacyInbound(msg InboundMessage) InboundContext {
|
||||
ctx := InboundContext{
|
||||
Channel: strings.TrimSpace(msg.Channel),
|
||||
ChatID: strings.TrimSpace(msg.ChatID),
|
||||
ChatType: normalizeKind(msg.Peer.Kind),
|
||||
SenderID: firstNonEmpty(
|
||||
strings.TrimSpace(msg.SenderID),
|
||||
strings.TrimSpace(msg.Sender.CanonicalID),
|
||||
strings.TrimSpace(msg.Sender.PlatformID),
|
||||
),
|
||||
MessageID: strings.TrimSpace(msg.MessageID),
|
||||
Raw: cloneStringMap(msg.Metadata),
|
||||
}
|
||||
|
||||
if account := metadataValue(msg.Metadata, metadataKeyAccountID); account != "" {
|
||||
ctx.Account = account
|
||||
}
|
||||
if replyToMsgID := metadataValue(msg.Metadata, metadataKeyReplyToMessage); replyToMsgID != "" {
|
||||
ctx.ReplyToMessageID = replyToMsgID
|
||||
}
|
||||
if replyToSenderID := metadataValue(msg.Metadata, metadataKeyReplyToSender); replyToSenderID != "" {
|
||||
ctx.ReplyToSenderID = replyToSenderID
|
||||
}
|
||||
if isTruthy(metadataValue(msg.Metadata, metadataKeyIsMentioned)) {
|
||||
ctx.Mentioned = true
|
||||
}
|
||||
|
||||
parentKind := normalizeKind(metadataValue(msg.Metadata, metadataKeyParentPeerKind))
|
||||
parentID := metadataValue(msg.Metadata, metadataKeyParentPeerID)
|
||||
if parentKind == "topic" && parentID != "" {
|
||||
ctx.TopicID = parentID
|
||||
}
|
||||
|
||||
switch {
|
||||
case metadataValue(msg.Metadata, metadataKeyGuildID) != "":
|
||||
ctx.SpaceType = "guild"
|
||||
ctx.SpaceID = metadataValue(msg.Metadata, metadataKeyGuildID)
|
||||
case metadataValue(msg.Metadata, metadataKeyTeamID) != "":
|
||||
ctx.SpaceType = "team"
|
||||
ctx.SpaceID = metadataValue(msg.Metadata, metadataKeyTeamID)
|
||||
}
|
||||
|
||||
return normalizeInboundContext(ctx)
|
||||
}
|
||||
|
||||
// NormalizeInboundMessage ensures the normalized Context is present and mirrors
|
||||
// missing legacy fields from it so older consumers continue to work during the
|
||||
// migration period.
|
||||
// NormalizeInboundMessage ensures the inbound context is normalized and keeps
|
||||
// convenience mirrors in sync for runtime consumers.
|
||||
func NormalizeInboundMessage(msg InboundMessage) InboundMessage {
|
||||
if msg.Context.isZero() {
|
||||
msg.Context = ContextFromLegacyInbound(msg)
|
||||
} else {
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
}
|
||||
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.SenderID == "" {
|
||||
msg.SenderID = msg.Context.SenderID
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
msg.Channel = msg.Context.Channel
|
||||
msg.SenderID = msg.Context.SenderID
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
if msg.MessageID == "" {
|
||||
msg.MessageID = msg.Context.MessageID
|
||||
}
|
||||
if msg.Peer.Kind == "" {
|
||||
msg.Peer = peerFromContext(msg.Context)
|
||||
if msg.Context.MessageID == "" {
|
||||
msg.Context.MessageID = msg.MessageID
|
||||
}
|
||||
|
||||
msg.Metadata = mergeLegacyMetadata(msg.Metadata, msg.Context)
|
||||
return msg
|
||||
}
|
||||
|
||||
@@ -125,110 +52,6 @@ func normalizeInboundContext(ctx InboundContext) InboundContext {
|
||||
return ctx
|
||||
}
|
||||
|
||||
func peerFromContext(ctx InboundContext) Peer {
|
||||
kind := normalizeKind(ctx.ChatType)
|
||||
if kind == "" {
|
||||
return Peer{}
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case "direct":
|
||||
return Peer{
|
||||
Kind: "direct",
|
||||
ID: firstNonEmpty(strings.TrimSpace(ctx.SenderID), strings.TrimSpace(ctx.ChatID)),
|
||||
}
|
||||
case "group", "channel":
|
||||
return Peer{
|
||||
Kind: kind,
|
||||
ID: strings.TrimSpace(ctx.ChatID),
|
||||
}
|
||||
default:
|
||||
return Peer{
|
||||
Kind: kind,
|
||||
ID: strings.TrimSpace(ctx.ChatID),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mergeLegacyMetadata(existing map[string]string, ctx InboundContext) map[string]string {
|
||||
merged := cloneStringMap(existing)
|
||||
if len(merged) == 0 {
|
||||
merged = cloneStringMap(ctx.Raw)
|
||||
} else {
|
||||
for k, v := range ctx.Raw {
|
||||
if _, ok := merged[k]; !ok {
|
||||
merged[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ctx.Account != "" {
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyAccountID, ctx.Account)
|
||||
}
|
||||
if ctx.ReplyToMessageID != "" {
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyReplyToMessage, ctx.ReplyToMessageID)
|
||||
}
|
||||
if ctx.ReplyToSenderID != "" {
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyReplyToSender, ctx.ReplyToSenderID)
|
||||
}
|
||||
if ctx.Mentioned {
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyIsMentioned, "true")
|
||||
}
|
||||
if ctx.TopicID != "" {
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyParentPeerKind, "topic")
|
||||
setMissing(merged, metadataKeyParentPeerID, ctx.TopicID)
|
||||
}
|
||||
|
||||
switch normalizeKind(ctx.SpaceType) {
|
||||
case "guild":
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyGuildID, ctx.SpaceID)
|
||||
case "team", "workspace":
|
||||
if merged == nil {
|
||||
merged = make(map[string]string)
|
||||
}
|
||||
setMissing(merged, metadataKeyTeamID, ctx.SpaceID)
|
||||
}
|
||||
|
||||
if len(merged) == 0 {
|
||||
return nil
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func setMissing(dst map[string]string, key, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := dst[key]; !ok {
|
||||
dst[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
func metadataValue(metadata map[string]string, key string) string {
|
||||
if metadata == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(metadata[key])
|
||||
}
|
||||
|
||||
func cloneStringMap(src map[string]string) map[string]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -241,24 +64,11 @@ func cloneStringMap(src map[string]string) map[string]string {
|
||||
return dst
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func normalizeKind(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
func isTruthy(value string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(value)) {
|
||||
case "1", "t", "true", "y", "yes", "on":
|
||||
return true
|
||||
func normalizeKind(kind string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(kind)) {
|
||||
case "direct", "group", "channel", "guild", "team", "workspace", "tenant", "topic":
|
||||
return strings.ToLower(strings.TrimSpace(kind))
|
||||
default:
|
||||
return false
|
||||
return strings.ToLower(strings.TrimSpace(kind))
|
||||
}
|
||||
}
|
||||
|
||||
+18
-46
@@ -2,62 +2,34 @@ package bus
|
||||
|
||||
import "strings"
|
||||
|
||||
// ContextFromLegacyOutbound builds a minimal outbound context from the legacy
|
||||
// top-level outbound fields. This keeps older outbound publishers working
|
||||
// while new publishers gradually start carrying the original InboundContext.
|
||||
func ContextFromLegacyOutbound(msg OutboundMessage) InboundContext {
|
||||
// NewOutboundContext builds the minimal normalized addressing context required
|
||||
// to deliver an outbound text message or reply.
|
||||
func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext {
|
||||
return normalizeInboundContext(InboundContext{
|
||||
Channel: strings.TrimSpace(msg.Channel),
|
||||
ChatID: strings.TrimSpace(msg.ChatID),
|
||||
ReplyToMessageID: strings.TrimSpace(msg.ReplyToMessageID),
|
||||
Channel: strings.TrimSpace(channel),
|
||||
ChatID: strings.TrimSpace(chatID),
|
||||
ReplyToMessageID: strings.TrimSpace(replyToMessageID),
|
||||
})
|
||||
}
|
||||
|
||||
// ContextFromLegacyOutboundMedia builds a minimal outbound context for media.
|
||||
func ContextFromLegacyOutboundMedia(msg OutboundMediaMessage) InboundContext {
|
||||
return normalizeInboundContext(InboundContext{
|
||||
Channel: strings.TrimSpace(msg.Channel),
|
||||
ChatID: strings.TrimSpace(msg.ChatID),
|
||||
})
|
||||
}
|
||||
|
||||
// NormalizeOutboundMessage ensures Context is present and mirrors legacy
|
||||
// top-level addressing fields from it so older senders keep working.
|
||||
// NormalizeOutboundMessage ensures Context is normalized and keeps convenience
|
||||
// mirrors in sync for runtime consumers.
|
||||
func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage {
|
||||
if msg.Context.isZero() {
|
||||
msg.Context = ContextFromLegacyOutbound(msg)
|
||||
} else {
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
msg.Channel = msg.Context.Channel
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
if msg.Context.ReplyToMessageID == "" {
|
||||
msg.Context.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID)
|
||||
}
|
||||
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
if msg.ReplyToMessageID == "" {
|
||||
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
|
||||
}
|
||||
|
||||
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
|
||||
return msg
|
||||
}
|
||||
|
||||
// NormalizeOutboundMediaMessage ensures media outbound messages also carry a
|
||||
// normalized context while preserving the legacy top-level routing fields.
|
||||
// normalized context while keeping convenience mirrors in sync.
|
||||
func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage {
|
||||
if msg.Context.isZero() {
|
||||
msg.Context = ContextFromLegacyOutboundMedia(msg)
|
||||
} else {
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
}
|
||||
|
||||
if msg.Channel == "" {
|
||||
msg.Channel = msg.Context.Channel
|
||||
}
|
||||
if msg.ChatID == "" {
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
}
|
||||
|
||||
msg.Context = normalizeInboundContext(msg.Context)
|
||||
msg.Channel = msg.Context.Channel
|
||||
msg.ChatID = msg.Context.ChatID
|
||||
return msg
|
||||
}
|
||||
|
||||
+14
-21
@@ -1,11 +1,5 @@
|
||||
package bus
|
||||
|
||||
// Peer identifies the routing peer for a message (direct, group, channel, etc.)
|
||||
type Peer struct {
|
||||
Kind string `json:"kind"` // "direct" | "group" | "channel" | ""
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// SenderInfo provides structured sender identity information.
|
||||
type SenderInfo struct {
|
||||
Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ...
|
||||
@@ -16,9 +10,8 @@ type SenderInfo struct {
|
||||
}
|
||||
|
||||
// InboundContext captures the normalized, platform-agnostic facts about an
|
||||
// inbound message. This is the long-term source of truth for routing and
|
||||
// session allocation. Legacy top-level fields on InboundMessage remain during
|
||||
// the transition and are derived from this context when missing.
|
||||
// inbound message. This is the source of truth for routing and session
|
||||
// allocation.
|
||||
type InboundContext struct {
|
||||
Channel string `json:"channel"`
|
||||
Account string `json:"account,omitempty"`
|
||||
@@ -43,18 +36,18 @@ type InboundContext struct {
|
||||
}
|
||||
|
||||
type InboundMessage struct {
|
||||
Channel string `json:"channel"`
|
||||
SenderID string `json:"sender_id"`
|
||||
Sender SenderInfo `json:"sender"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Context InboundContext `json:"context"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
Peer Peer `json:"peer"` // routing peer
|
||||
MessageID string `json:"message_id,omitempty"` // platform message ID
|
||||
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
|
||||
SessionKey string `json:"session_key"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
Context InboundContext `json:"context"`
|
||||
Sender SenderInfo `json:"sender"`
|
||||
Content string `json:"content"`
|
||||
Media []string `json:"media,omitempty"`
|
||||
MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope
|
||||
SessionKey string `json:"session_key"`
|
||||
|
||||
// Convenience mirrors derived from Context for runtime consumers.
|
||||
Channel string `json:"channel"`
|
||||
SenderID string `json:"sender_id"`
|
||||
ChatID string `json:"chat_id"`
|
||||
MessageID string `json:"message_id,omitempty"` // platform message ID
|
||||
}
|
||||
|
||||
type OutboundMessage struct {
|
||||
|
||||
+13
-33
@@ -244,35 +244,8 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *BaseChannel) HandleMessage(
|
||||
ctx context.Context,
|
||||
peer bus.Peer,
|
||||
messageID, senderID, chatID, content string,
|
||||
media []string,
|
||||
metadata map[string]string,
|
||||
senderOpts ...bus.SenderInfo,
|
||||
) {
|
||||
var sender bus.SenderInfo
|
||||
if len(senderOpts) > 0 {
|
||||
sender = senderOpts[0]
|
||||
}
|
||||
|
||||
inboundCtx := bus.ContextFromLegacyInbound(bus.InboundMessage{
|
||||
Channel: c.name,
|
||||
SenderID: senderID,
|
||||
Sender: sender,
|
||||
ChatID: chatID,
|
||||
Peer: peer,
|
||||
MessageID: messageID,
|
||||
Metadata: metadata,
|
||||
})
|
||||
|
||||
c.HandleMessageWithContext(ctx, peer, chatID, content, media, inboundCtx, senderOpts...)
|
||||
}
|
||||
|
||||
func (c *BaseChannel) HandleMessageWithContext(
|
||||
ctx context.Context,
|
||||
peer bus.Peer,
|
||||
deliveryChatID, content string,
|
||||
media []string,
|
||||
inboundCtx bus.InboundContext,
|
||||
@@ -315,15 +288,10 @@ func (c *BaseChannel) HandleMessageWithContext(
|
||||
scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID)
|
||||
|
||||
msg := bus.InboundMessage{
|
||||
Channel: c.name,
|
||||
SenderID: resolvedSenderID,
|
||||
Sender: sender,
|
||||
ChatID: deliveryChatID,
|
||||
Context: inboundCtx,
|
||||
Sender: sender,
|
||||
Content: content,
|
||||
Media: media,
|
||||
Peer: peer,
|
||||
MessageID: inboundCtx.MessageID,
|
||||
MediaScope: scope,
|
||||
}
|
||||
msg = bus.NormalizeInboundMessage(msg)
|
||||
@@ -369,6 +337,18 @@ func (c *BaseChannel) HandleMessageWithContext(
|
||||
}
|
||||
}
|
||||
|
||||
// HandleInboundContext publishes a normalized inbound message using only the
|
||||
// structured context.
|
||||
func (c *BaseChannel) HandleInboundContext(
|
||||
ctx context.Context,
|
||||
deliveryChatID, content string,
|
||||
media []string,
|
||||
inboundCtx bus.InboundContext,
|
||||
senderOpts ...bus.SenderInfo,
|
||||
) {
|
||||
c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...)
|
||||
}
|
||||
|
||||
func (c *BaseChannel) SetRunning(running bool) {
|
||||
c.running.Store(running)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inbound bus.InboundContext
|
||||
wantChat string
|
||||
wantSender string
|
||||
}{
|
||||
{
|
||||
name: "direct uses sender as peer",
|
||||
inbound: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "chat-1",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-1",
|
||||
MessageID: "msg-1",
|
||||
},
|
||||
wantChat: "chat-1",
|
||||
wantSender: "user-1",
|
||||
},
|
||||
{
|
||||
name: "group uses chat as peer",
|
||||
inbound: bus.InboundContext{
|
||||
Channel: "test",
|
||||
ChatID: "group-1",
|
||||
ChatType: "group",
|
||||
SenderID: "user-2",
|
||||
MessageID: "msg-2",
|
||||
},
|
||||
wantChat: "group-1",
|
||||
wantSender: "user-2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
defer msgBus.Close()
|
||||
|
||||
ch := NewBaseChannel("test", nil, msgBus, nil)
|
||||
ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound)
|
||||
|
||||
msg := <-msgBus.InboundChan()
|
||||
if msg.ChatID != tt.wantChat {
|
||||
t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat)
|
||||
}
|
||||
if msg.SenderID != tt.wantSender {
|
||||
t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender)
|
||||
}
|
||||
if msg.Context.ChatType != tt.inbound.ChatType {
|
||||
t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,16 +181,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
var (
|
||||
chatType string
|
||||
isMentioned bool
|
||||
)
|
||||
if data.ConversationType == "1" {
|
||||
peerID := senderID
|
||||
if peerID == "" {
|
||||
peerID = chatID
|
||||
}
|
||||
peer = bus.Peer{Kind: "direct", ID: peerID}
|
||||
chatType = "direct"
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: data.ConversationId}
|
||||
isMentioned := data.IsInAtList
|
||||
chatType = "group"
|
||||
isMentioned = data.IsInAtList
|
||||
if isMentioned {
|
||||
content = stripLeadingAtMentions(content)
|
||||
}
|
||||
@@ -228,8 +227,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "dingtalk",
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
SenderID: resolvedSenderID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if data.SessionWebhook != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender)
|
||||
|
||||
// Return nil to indicate we've handled the message asynchronously
|
||||
// The response will be sent through the message bus
|
||||
|
||||
@@ -461,14 +461,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
})
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := m.ChannelID
|
||||
if m.GuildID == "" {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"user_id": senderID,
|
||||
"username": m.Author.Username,
|
||||
@@ -494,7 +490,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
inboundCtx.ReplyToMessageID = m.MessageReference.MessageID
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(c.ctx, peer, m.ChannelID, content, mediaPaths, inboundCtx, sender)
|
||||
c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// startTyping starts a continuous typing indicator loop for the given chatID.
|
||||
|
||||
@@ -447,22 +447,25 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
rawChatType := stringValue(message.ChatType)
|
||||
if rawChatType != "" {
|
||||
metadata["chat_type"] = rawChatType
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if chatType == "p2p" {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
var (
|
||||
inboundChatType string
|
||||
isMentioned bool
|
||||
)
|
||||
if rawChatType == "p2p" {
|
||||
inboundChatType = "direct"
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
inboundChatType = "group"
|
||||
|
||||
// Check if bot was mentioned
|
||||
isMentioned := c.isBotMentioned(message)
|
||||
isMentioned = c.isBotMentioned(message)
|
||||
|
||||
// Strip mention placeholders from content before group trigger check
|
||||
if len(message.Mentions) > 0 {
|
||||
@@ -484,7 +487,21 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "feishu",
|
||||
ChatID: chatID,
|
||||
ChatType: inboundChatType,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
|
||||
inboundCtx.SpaceType = "tenant"
|
||||
inboundCtx.SpaceID = *sender.TenantKey
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&")
|
||||
|
||||
var chatID string
|
||||
var peer bus.Peer
|
||||
|
||||
if isDM {
|
||||
chatID = nick
|
||||
peer = bus.Peer{Kind: "direct", ID: nick}
|
||||
} else {
|
||||
chatID = target
|
||||
peer = bus.Peer{Kind: "group", ID: target}
|
||||
}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
@@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
isMentioned := false
|
||||
|
||||
// For channel messages, check group trigger (mention detection)
|
||||
if !isDM {
|
||||
isMentioned := isBotMentioned(content, currentNick)
|
||||
isMentioned = isBotMentioned(content, currentNick)
|
||||
if isMentioned {
|
||||
content = stripBotMention(content, currentNick)
|
||||
}
|
||||
@@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) {
|
||||
metadata["channel"] = target
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "irc",
|
||||
ChatID: chatID,
|
||||
SenderID: nick,
|
||||
MessageID: messageID,
|
||||
Mentioned: isMentioned,
|
||||
Raw: metadata,
|
||||
}
|
||||
if isDM {
|
||||
inboundCtx.ChatType = "direct"
|
||||
} else {
|
||||
inboundCtx.ChatType = "group"
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// nickMentionedAt returns the byte index where botNick is mentioned in content
|
||||
|
||||
@@ -368,13 +368,6 @@ func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
"source_type": event.Source.Type,
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if isGroup {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
}
|
||||
|
||||
logger.DebugCF("line", "Received message", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
@@ -396,7 +389,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
ChatID: chatID,
|
||||
ChatType: peer.Kind,
|
||||
ChatType: map[bool]string{true: "group", false: "direct"}[isGroup],
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Mentioned: isMentioned,
|
||||
@@ -411,7 +404,7 @@ func (c *LINEChannel) processEvent(event lineEvent) {
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(c.ctx, peer, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// isBotMentioned checks if the bot is mentioned in the message.
|
||||
|
||||
@@ -196,17 +196,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "channel", ID: "default"},
|
||||
"",
|
||||
senderID,
|
||||
chatID,
|
||||
content,
|
||||
[]string{},
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "maixcam",
|
||||
ChatID: chatID,
|
||||
ChatType: "channel",
|
||||
SenderID: senderID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) {
|
||||
|
||||
+43
-23
@@ -97,6 +97,22 @@ type asyncTask struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func outboundMessageChannel(msg bus.OutboundMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
|
||||
func outboundMessageChatID(msg bus.OutboundMessage) string {
|
||||
return msg.Context.ChatID
|
||||
}
|
||||
|
||||
func outboundMediaChannel(msg bus.OutboundMediaMessage) string {
|
||||
return msg.Context.Channel
|
||||
}
|
||||
|
||||
func outboundMediaChatID(msg bus.OutboundMediaMessage) string {
|
||||
return msg.Context.ChatID
|
||||
}
|
||||
|
||||
// RecordPlaceholder registers a placeholder message for later editing.
|
||||
// Implements PlaceholderRecorder.
|
||||
func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) {
|
||||
@@ -160,7 +176,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) {
|
||||
// preSend handles typing stop, reaction undo, and placeholder editing before sending a message.
|
||||
// Returns the delivered message IDs and true when delivery completed before a normal Send.
|
||||
func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) {
|
||||
key := name + ":" + msg.ChatID
|
||||
chatID := outboundMessageChatID(msg)
|
||||
key := name + ":" + chatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
@@ -182,9 +199,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
// Prefer deleting the placeholder (cleaner UX than editing to same content)
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
} else if editor, ok := ch.(MessageEditor); ok {
|
||||
editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback
|
||||
editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,7 +212,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if editor, ok := ch.(MessageEditor); ok {
|
||||
if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil {
|
||||
if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil {
|
||||
return []string{entry.id}, true
|
||||
}
|
||||
// edit failed → fall through to normal Send
|
||||
@@ -211,7 +228,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess
|
||||
// delivery never edits the placeholder because there is no text payload to
|
||||
// replace it with; it only attempts to delete the placeholder when possible.
|
||||
func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) {
|
||||
key := name + ":" + msg.ChatID
|
||||
chatID := outboundMediaChatID(msg)
|
||||
key := name + ":" + chatID
|
||||
|
||||
// 1. Stop typing
|
||||
if v, loaded := m.typingStops.LoadAndDelete(key); loaded {
|
||||
@@ -234,7 +252,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun
|
||||
if v, loaded := m.placeholders.LoadAndDelete(key); loaded {
|
||||
if entry, ok := v.(placeholderEntry); ok && entry.id != "" {
|
||||
if deleter, ok := ch.(MessageDeleter); ok {
|
||||
deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort
|
||||
deleter.DeleteMessage(ctx, chatID, entry.id) // best effort
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -756,7 +774,7 @@ func (m *Manager) sendWithRetry(
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "Send failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_id": outboundMessageChatID(msg),
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
@@ -818,7 +836,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.OutboundChan(),
|
||||
func(msg bus.OutboundMessage) string { return msg.Channel },
|
||||
func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool {
|
||||
select {
|
||||
case w.queue <- msg:
|
||||
@@ -838,7 +856,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) {
|
||||
dispatchLoop(
|
||||
ctx, m,
|
||||
m.bus.OutboundMediaChan(),
|
||||
func(msg bus.OutboundMediaMessage) string { return msg.Channel },
|
||||
func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) },
|
||||
func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool {
|
||||
select {
|
||||
case w.mediaQueue <- msg:
|
||||
@@ -937,7 +955,7 @@ func (m *Manager) sendMediaWithRetry(
|
||||
// All retries exhausted or permanent failure
|
||||
logger.ErrorCF("channels", "SendMedia failed", map[string]any{
|
||||
"channel": name,
|
||||
"chat_id": msg.ChatID,
|
||||
"chat_id": outboundMediaChatID(msg),
|
||||
"error": lastErr.Error(),
|
||||
"retries": maxRetries,
|
||||
})
|
||||
@@ -1131,17 +1149,18 @@ func (m *Manager) UnregisterChannel(name string) {
|
||||
// a subsequent operation depends on the message having been sent.
|
||||
func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
msg = bus.NormalizeOutboundMessage(msg)
|
||||
channelName := outboundMessageChannel(msg)
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channelName]
|
||||
w, wExists := m.workers[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
return fmt.Errorf("channel %s not found", channelName)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
return fmt.Errorf("channel %s has no active worker", channelName)
|
||||
}
|
||||
|
||||
maxLen := 0
|
||||
@@ -1152,10 +1171,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
for _, chunk := range SplitMessage(msg.Content, maxLen) {
|
||||
chunkMsg := msg
|
||||
chunkMsg.Content = chunk
|
||||
m.sendWithRetry(ctx, msg.Channel, w, chunkMsg)
|
||||
m.sendWithRetry(ctx, channelName, w, chunkMsg)
|
||||
}
|
||||
} else {
|
||||
m.sendWithRetry(ctx, msg.Channel, w, msg)
|
||||
m.sendWithRetry(ctx, channelName, w, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1166,20 +1185,21 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
// depends on actual media delivery.
|
||||
func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error {
|
||||
msg = bus.NormalizeOutboundMediaMessage(msg)
|
||||
channelName := outboundMediaChannel(msg)
|
||||
|
||||
m.mu.RLock()
|
||||
_, exists := m.channels[msg.Channel]
|
||||
w, wExists := m.workers[msg.Channel]
|
||||
_, exists := m.channels[channelName]
|
||||
w, wExists := m.workers[channelName]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return fmt.Errorf("channel %s not found", msg.Channel)
|
||||
return fmt.Errorf("channel %s not found", channelName)
|
||||
}
|
||||
if !wExists || w == nil {
|
||||
return fmt.Errorf("channel %s has no active worker", msg.Channel)
|
||||
return fmt.Errorf("channel %s has no active worker", channelName)
|
||||
}
|
||||
|
||||
_, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg)
|
||||
_, err := m.sendMediaWithRetry(ctx, channelName, w, msg)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1194,10 +1214,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
Channel: channelName,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channelName, chatID, ""),
|
||||
Content: content,
|
||||
}
|
||||
msg = bus.NormalizeOutboundMessage(msg)
|
||||
|
||||
if wExists && w != nil {
|
||||
select {
|
||||
|
||||
+136
-45
@@ -89,6 +89,20 @@ func newTestManager() *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID)
|
||||
}
|
||||
return bus.NormalizeOutboundMessage(msg)
|
||||
}
|
||||
|
||||
func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage {
|
||||
if msg.Context.Channel == "" && msg.Context.ChatID == "" {
|
||||
msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "")
|
||||
}
|
||||
return bus.NormalizeOutboundMediaMessage(msg)
|
||||
}
|
||||
|
||||
func TestSendWithRetry_Success(t *testing.T) {
|
||||
m := newTestManager()
|
||||
var callCount int
|
||||
@@ -104,7 +118,7 @@ func TestSendWithRetry_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -131,7 +145,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -155,7 +169,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -179,7 +193,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -206,7 +220,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
@@ -236,7 +250,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -262,11 +276,11 @@ func TestSendMedia_Success(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
@@ -289,11 +303,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error")
|
||||
}
|
||||
@@ -316,11 +330,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err == nil {
|
||||
t.Fatal("expected SendMedia to return error for unsupported channel")
|
||||
}
|
||||
@@ -346,11 +360,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) {
|
||||
m.workers["test"] = w
|
||||
m.RecordPlaceholder("test", "chat1", "placeholder-1")
|
||||
|
||||
err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{
|
||||
err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Channel: "test",
|
||||
ChatID: "chat1",
|
||||
Parts: []bus.MediaPart{{Ref: "media://abc"}},
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
t.Fatalf("SendMedia() error = %v", err)
|
||||
}
|
||||
@@ -383,7 +397,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
|
||||
@@ -407,7 +421,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
// Cancel context after first Send attempt returns
|
||||
ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error {
|
||||
@@ -453,7 +467,7 @@ func TestWorkerRateLimiter(t *testing.T) {
|
||||
|
||||
// Enqueue 4 messages
|
||||
for i := range 4 {
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)})
|
||||
}
|
||||
|
||||
// Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin)
|
||||
@@ -529,7 +543,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) {
|
||||
go m.runWorker(ctx, "test", w)
|
||||
|
||||
// Send a message that should be split
|
||||
w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}
|
||||
w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"})
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
@@ -570,7 +584,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) {
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"})
|
||||
|
||||
start := time.Now()
|
||||
m.sendWithRetry(ctx, "test", w, msg)
|
||||
@@ -630,7 +644,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) {
|
||||
// Register placeholder
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !edited {
|
||||
@@ -660,7 +674,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) {
|
||||
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if edited {
|
||||
@@ -719,7 +733,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) {
|
||||
stopCalled = true
|
||||
})
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -736,7 +750,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if edited {
|
||||
@@ -766,7 +780,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) {
|
||||
})
|
||||
m.RecordPlaceholder("test", "123", "456")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -830,7 +844,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) {
|
||||
t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls)
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.preSend(context.Background(), "test", msg, &mockChannel{})
|
||||
|
||||
if newStopCalls != 1 {
|
||||
@@ -864,7 +878,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) {
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"})
|
||||
m.sendWithRetry(context.Background(), "test", w, msg)
|
||||
|
||||
if sendCalled {
|
||||
@@ -1027,7 +1041,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) {
|
||||
})
|
||||
m.RecordPlaceholder("test", "chat1", "ph_id")
|
||||
|
||||
msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}
|
||||
msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"})
|
||||
_, edited := m.preSend(context.Background(), "test", msg, ch)
|
||||
|
||||
if !stopCalled {
|
||||
@@ -1130,11 +1144,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
|
||||
// Transcription feedback arrives first — it should consume the placeholder
|
||||
// and be delivered via EditMessage, not Send.
|
||||
msgTranscript := bus.OutboundMessage{
|
||||
msgTranscript := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "mock",
|
||||
ChatID: "chat-1",
|
||||
Content: "Transcript: hello",
|
||||
}
|
||||
})
|
||||
mgr.sendWithRetry(ctx, "mock", worker, msgTranscript)
|
||||
|
||||
if mockCh.editedMessages != 1 {
|
||||
@@ -1150,11 +1164,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
// Final LLM response arrives — no placeholder left, so it goes through Send
|
||||
msgFinal := bus.OutboundMessage{
|
||||
msgFinal := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "mock",
|
||||
ChatID: "chat-1",
|
||||
Content: "Final Answer",
|
||||
}
|
||||
})
|
||||
mgr.sendWithRetry(ctx, "mock", worker, msgFinal)
|
||||
|
||||
if len(mockCh.sentMessages) != 1 {
|
||||
@@ -1180,12 +1194,12 @@ func TestSendMessage_Synchronous(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello world",
|
||||
ReplyToMessageID: "msg-456",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1207,11 +1221,11 @@ func TestSendMessage_Synchronous(t *testing.T) {
|
||||
func TestSendMessage_UnknownChannel(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "nonexistent",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err == nil {
|
||||
@@ -1228,11 +1242,11 @@ func TestSendMessage_NoWorker(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
// No worker registered
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err == nil {
|
||||
@@ -1261,11 +1275,11 @@ func TestSendMessage_WithRetry(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "retry me",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1277,6 +1291,46 @@ func TestSendMessage_WithRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var received []bus.OutboundMessage
|
||||
ch := &mockChannel{
|
||||
sendFn: func(_ context.Context, msg bus.OutboundMessage) error {
|
||||
received = append(received, msg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Context: bus.NewOutboundContext("test", "123", "msg-9"),
|
||||
Content: "hello",
|
||||
})
|
||||
|
||||
if err := m.SendMessage(context.Background(), msg); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 message sent, got %d", len(received))
|
||||
}
|
||||
if received[0].Channel != "test" || received[0].ChatID != "123" {
|
||||
t.Fatalf("expected mirrored legacy address, got %+v", received[0])
|
||||
}
|
||||
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" {
|
||||
t.Fatalf("expected context address to be preserved, got %+v", received[0].Context)
|
||||
}
|
||||
if received[0].ReplyToMessageID != "msg-9" {
|
||||
t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
@@ -1298,11 +1352,11 @@ func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := bus.OutboundMessage{
|
||||
msg := testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test",
|
||||
ChatID: "123",
|
||||
Content: "hello world",
|
||||
}
|
||||
})
|
||||
|
||||
err := m.SendMessage(context.Background(), msg)
|
||||
if err != nil {
|
||||
@@ -1314,6 +1368,43 @@ func TestSendMessage_WithSplitting(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
var received []bus.OutboundMediaMessage
|
||||
ch := &mockMediaChannel{
|
||||
sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
received = append(received, msg)
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
w := &channelWorker{
|
||||
ch: ch,
|
||||
limiter: rate.NewLimiter(rate.Inf, 1),
|
||||
}
|
||||
m.channels["test"] = ch
|
||||
m.workers["test"] = w
|
||||
|
||||
msg := testOutboundMediaMessage(bus.OutboundMediaMessage{
|
||||
Context: bus.NewOutboundContext("test", "media-chat", ""),
|
||||
Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}},
|
||||
})
|
||||
|
||||
if err := m.SendMedia(context.Background(), msg); err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 media message sent, got %d", len(received))
|
||||
}
|
||||
if received[0].Channel != "test" || received[0].ChatID != "media-chat" {
|
||||
t.Fatalf("expected mirrored legacy media address, got %+v", received[0])
|
||||
}
|
||||
if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" {
|
||||
t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendMessage_PreservesOrdering(t *testing.T) {
|
||||
m := newTestManager()
|
||||
|
||||
@@ -1333,12 +1424,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) {
|
||||
m.workers["test"] = w
|
||||
|
||||
// Send two messages sequentially — they must arrive in order
|
||||
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
|
||||
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test", ChatID: "1", Content: "first",
|
||||
})
|
||||
_ = m.SendMessage(context.Background(), bus.OutboundMessage{
|
||||
}))
|
||||
_ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{
|
||||
Channel: "test", ChatID: "1", Content: "second",
|
||||
})
|
||||
}))
|
||||
|
||||
if len(order) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(order))
|
||||
|
||||
@@ -736,10 +736,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
}
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := senderID
|
||||
if isGroup {
|
||||
peerKind = "group"
|
||||
peerID = roomID
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -752,17 +750,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event
|
||||
metadata["reply_to_msg_id"] = replyTo.String()
|
||||
}
|
||||
|
||||
c.HandleMessage(
|
||||
c.baseContext(),
|
||||
bus.Peer{Kind: peerKind, ID: peerID},
|
||||
evt.ID.String(),
|
||||
senderID,
|
||||
roomID,
|
||||
content,
|
||||
mediaPaths,
|
||||
metadata,
|
||||
sender,
|
||||
)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "matrix",
|
||||
ChatID: roomID,
|
||||
ChatType: peerKind,
|
||||
SenderID: senderID,
|
||||
MessageID: evt.ID.String(),
|
||||
Raw: metadata,
|
||||
}
|
||||
if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" {
|
||||
inboundCtx.ReplyToMessageID = replyTo.String()
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// decryptEvent decrypts an encrypted event and returns the decrypted message event content.
|
||||
|
||||
@@ -994,8 +994,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
var contextChatID string
|
||||
var contextChatType string
|
||||
|
||||
var peer bus.Peer
|
||||
|
||||
metadata := map[string]string{}
|
||||
|
||||
if parsed.ReplyTo != "" {
|
||||
@@ -1007,14 +1005,12 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
chatID = "private:" + senderID
|
||||
contextChatID = senderID
|
||||
contextChatType = "direct"
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
|
||||
case "group":
|
||||
groupIDStr := strconv.FormatInt(groupID, 10)
|
||||
chatID = "group:" + groupIDStr
|
||||
contextChatID = groupIDStr
|
||||
contextChatType = "group"
|
||||
peer = bus.Peer{Kind: "group", ID: groupIDStr}
|
||||
metadata["group_id"] = groupIDStr
|
||||
|
||||
senderUserID, _ := parseJSONInt64(sender.UserID)
|
||||
@@ -1089,7 +1085,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) {
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(c.ctx, peer, chatID, content, parsed.Media, inboundCtx, senderInfo)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo)
|
||||
}
|
||||
|
||||
func (c *OneBotChannel) isDuplicate(messageID string) bool {
|
||||
|
||||
@@ -254,8 +254,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
|
||||
chatID := "pico_client:" + sessionID
|
||||
senderID := "pico-remote"
|
||||
peer := bus.Peer{Kind: "direct", ID: chatID}
|
||||
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "pico_client",
|
||||
PlatformID: senderID,
|
||||
@@ -266,10 +264,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
}, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "pico_client",
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Raw: map[string]string{
|
||||
"platform": "pico_client",
|
||||
"session_id": sessionID,
|
||||
},
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// Send sends a message to the remote server.
|
||||
|
||||
@@ -539,8 +539,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
||||
chatID := "pico:" + sessionID
|
||||
senderID := "pico-user"
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"platform": "pico",
|
||||
"session_id": sessionID,
|
||||
@@ -562,7 +560,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "pico",
|
||||
ChatID: chatID,
|
||||
ChatType: "direct",
|
||||
SenderID: senderID,
|
||||
MessageID: msg.ID,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// truncate truncates a string to maxLen runes.
|
||||
|
||||
+2
-18
@@ -657,15 +657,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler {
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "direct", ID: senderID},
|
||||
senderID,
|
||||
content,
|
||||
mediaPaths,
|
||||
inboundCtx,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -744,15 +736,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler {
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: "group", ID: data.GroupID},
|
||||
data.GroupID,
|
||||
content,
|
||||
mediaPaths,
|
||||
inboundCtx,
|
||||
sender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -356,14 +356,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
}
|
||||
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
@@ -394,7 +390,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
inboundCtx.TopicID = threadTS
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(c.ctx, peer, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
@@ -442,14 +438,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
}
|
||||
|
||||
mentionPeerKind := "channel"
|
||||
mentionPeerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
mentionPeerKind = "direct"
|
||||
mentionPeerID = senderID
|
||||
}
|
||||
|
||||
mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
@@ -472,7 +464,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(c.ctx, mentionPeer, chatID, content, nil, inboundCtx, mentionSender)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
@@ -520,10 +512,8 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
"text": utils.Truncate(content, 50),
|
||||
})
|
||||
peerKind := "channel"
|
||||
peerID := channelID
|
||||
if strings.HasPrefix(channelID, "D") {
|
||||
peerKind = "direct"
|
||||
peerID = senderID
|
||||
}
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: c.Name(),
|
||||
@@ -536,15 +526,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(
|
||||
c.ctx,
|
||||
bus.Peer{Kind: peerKind, ID: peerID},
|
||||
chatID,
|
||||
content,
|
||||
nil,
|
||||
inboundCtx,
|
||||
cmdSender,
|
||||
)
|
||||
c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
|
||||
@@ -708,13 +708,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
})
|
||||
|
||||
peerKind := "direct"
|
||||
peerID := fmt.Sprintf("%d", user.ID)
|
||||
if message.Chat.Type != "private" {
|
||||
peerKind = "group"
|
||||
peerID = compositeChatID
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: peerID}
|
||||
messageID := fmt.Sprintf("%d", message.MessageID)
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -742,7 +738,6 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes
|
||||
|
||||
c.HandleMessageWithContext(
|
||||
c.ctx,
|
||||
peer,
|
||||
compositeChatID,
|
||||
content,
|
||||
mediaPaths,
|
||||
|
||||
@@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
|
||||
return err
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: peerKind, ID: actualChatID}
|
||||
metadata := map[string]string{
|
||||
"channel": "wecom",
|
||||
"req_id": reqID,
|
||||
@@ -596,7 +595,7 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage)
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleMessageWithContext(c.ctx, peer, actualChatID, content, mediaRefs, inboundCtx, sender)
|
||||
c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -334,8 +334,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
return
|
||||
}
|
||||
|
||||
peer := bus.Peer{Kind: "direct", ID: fromUserID}
|
||||
|
||||
metadata := map[string]string{
|
||||
"from_user_id": fromUserID,
|
||||
"context_token": msg.ContextToken,
|
||||
@@ -354,7 +352,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess
|
||||
c.persistContextTokens()
|
||||
}
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "weixin",
|
||||
ChatID: fromUserID,
|
||||
ChatType: "direct",
|
||||
SenderID: fromUserID,
|
||||
MessageID: messageID,
|
||||
Raw: metadata,
|
||||
}
|
||||
if msg.ContextToken != "" {
|
||||
inboundCtx.ReplyHandles = map[string]string{
|
||||
"context_token": msg.ContextToken,
|
||||
}
|
||||
}
|
||||
|
||||
c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender)
|
||||
}
|
||||
|
||||
// Send implements channels.Channel by sending a text message to the WeChat user.
|
||||
|
||||
@@ -223,13 +223,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
metadata["user_name"] = userName
|
||||
}
|
||||
|
||||
var peer bus.Peer
|
||||
if chatID == senderID {
|
||||
peer = bus.Peer{Kind: "direct", ID: senderID}
|
||||
} else {
|
||||
peer = bus.Peer{Kind: "group", ID: chatID}
|
||||
}
|
||||
|
||||
logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{
|
||||
"sender": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
@@ -248,5 +241,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) {
|
||||
return
|
||||
}
|
||||
|
||||
c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: chatID,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
Raw: metadata,
|
||||
}
|
||||
if chatID == senderID {
|
||||
inboundCtx.ChatType = "direct"
|
||||
} else {
|
||||
inboundCtx.ChatType = "group"
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
@@ -375,7 +375,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
|
||||
if evt.Info.Chat.Server == types.GroupServer {
|
||||
peerKind = "group"
|
||||
}
|
||||
peer := bus.Peer{Kind: peerKind, ID: chatID}
|
||||
messageID := evt.Info.ID
|
||||
sender := bus.SenderInfo{
|
||||
Platform: "whatsapp",
|
||||
@@ -393,7 +392,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) {
|
||||
"WhatsApp message received",
|
||||
map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)},
|
||||
)
|
||||
c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender)
|
||||
|
||||
inboundCtx := bus.InboundContext{
|
||||
Channel: "whatsapp",
|
||||
ChatID: chatID,
|
||||
SenderID: senderID,
|
||||
MessageID: messageID,
|
||||
ChatType: peerKind,
|
||||
Raw: metadata,
|
||||
}
|
||||
|
||||
c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender)
|
||||
}
|
||||
|
||||
func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
|
||||
@@ -99,7 +99,7 @@ type BuildInfo struct {
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON marshaling for Config
|
||||
// to omit providers section when empty and session when empty
|
||||
// to omit providers section when empty and session when empty.
|
||||
func (c *Config) MarshalJSON() ([]byte, error) {
|
||||
type Alias Config
|
||||
aux := &struct {
|
||||
@@ -109,11 +109,8 @@ func (c *Config) MarshalJSON() ([]byte, error) {
|
||||
Alias: (*Alias)(c),
|
||||
}
|
||||
|
||||
// Only include session if not empty. Deprecated dm_scope is intentionally
|
||||
// omitted so persisted configs converge on dimensions-based session policy.
|
||||
if len(c.Session.Dimensions) > 0 || len(c.Session.IdentityLinks) > 0 {
|
||||
sessionCfg := c.Session
|
||||
sessionCfg.DMScope = ""
|
||||
aux.Session = &sessionCfg
|
||||
}
|
||||
|
||||
@@ -199,7 +196,6 @@ type AgentBinding struct {
|
||||
|
||||
type SessionConfig struct {
|
||||
Dimensions []string `json:"dimensions,omitempty"`
|
||||
DMScope string `json:"dm_scope,omitempty"` // Deprecated: ignored by the new session policy path.
|
||||
IdentityLinks map[string][]string `json:"identity_links,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@@ -131,8 +131,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Context: bus.NewOutboundContext(platform, userID, ""),
|
||||
Content: msg,
|
||||
})
|
||||
|
||||
|
||||
@@ -339,8 +339,7 @@ func (hs *HeartbeatService) sendResponse(response string) {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Context: bus.NewOutboundContext(platform, userID, ""),
|
||||
Content: response,
|
||||
})
|
||||
|
||||
|
||||
+60
-19
@@ -3,25 +3,21 @@ package routing
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// RouteInput contains the routing context from an inbound message.
|
||||
type RouteInput struct {
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
ParentPeer *RoutePeer
|
||||
GuildID string
|
||||
TeamID string
|
||||
}
|
||||
|
||||
// SessionPolicy describes how a routed message should be mapped to a session.
|
||||
type SessionPolicy struct {
|
||||
Dimensions []string
|
||||
IdentityLinks map[string][]string
|
||||
}
|
||||
|
||||
type RoutePeer struct {
|
||||
Kind string
|
||||
ID string
|
||||
}
|
||||
|
||||
// ResolvedRoute is the result of agent routing.
|
||||
type ResolvedRoute struct {
|
||||
AgentID string
|
||||
@@ -41,14 +37,15 @@ func NewRouteResolver(cfg *config.Config) *RouteResolver {
|
||||
return &RouteResolver{cfg: cfg}
|
||||
}
|
||||
|
||||
// ResolveRoute determines which agent handles the message and returns the
|
||||
// session policy that should be used to allocate session state.
|
||||
// ResolveRoute determines which agent handles the message from a normalized
|
||||
// inbound context and returns the session policy that should be used to
|
||||
// allocate session state.
|
||||
// Implements the 7-level priority cascade:
|
||||
// peer > parent_peer > guild > team > account > channel_wildcard > default
|
||||
func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
channel := strings.ToLower(strings.TrimSpace(input.Channel))
|
||||
accountID := NormalizeAccountID(input.AccountID)
|
||||
peer := input.Peer
|
||||
func (r *RouteResolver) ResolveRoute(inbound bus.InboundContext) ResolvedRoute {
|
||||
channel := strings.ToLower(strings.TrimSpace(inbound.Channel))
|
||||
accountID := NormalizeAccountID(inbound.Account)
|
||||
peer := routePeerFromContext(inbound)
|
||||
|
||||
sessionPolicy := r.sessionPolicy()
|
||||
|
||||
@@ -73,7 +70,7 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
}
|
||||
|
||||
// Priority 2: Parent peer binding
|
||||
parentPeer := input.ParentPeer
|
||||
parentPeer := parentPeerFromContext(inbound)
|
||||
if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" {
|
||||
if match := r.findPeerMatch(bindings, parentPeer); match != nil {
|
||||
return choose(match.AgentID, "binding.peer.parent")
|
||||
@@ -81,7 +78,7 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
}
|
||||
|
||||
// Priority 3: Guild binding
|
||||
guildID := strings.TrimSpace(input.GuildID)
|
||||
guildID := routeGuildIDFromContext(inbound)
|
||||
if guildID != "" {
|
||||
if match := r.findGuildMatch(bindings, guildID); match != nil {
|
||||
return choose(match.AgentID, "binding.guild")
|
||||
@@ -89,7 +86,7 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute {
|
||||
}
|
||||
|
||||
// Priority 4: Team binding
|
||||
teamID := strings.TrimSpace(input.TeamID)
|
||||
teamID := routeTeamIDFromContext(inbound)
|
||||
if teamID != "" {
|
||||
if match := r.findTeamMatch(bindings, teamID); match != nil {
|
||||
return choose(match.AgentID, "binding.team")
|
||||
@@ -276,6 +273,46 @@ func normalizeSessionDimensions(dimensions []string) []string {
|
||||
return normalized
|
||||
}
|
||||
|
||||
func routePeerFromContext(ctx bus.InboundContext) *RoutePeer {
|
||||
peerKind := normalizeChannel(strings.TrimSpace(ctx.ChatType))
|
||||
if peerKind == "" || peerKind == "unknown" {
|
||||
return nil
|
||||
}
|
||||
|
||||
peerID := strings.TrimSpace(ctx.ChatID)
|
||||
if peerKind == "direct" && peerID == "" {
|
||||
peerID = strings.TrimSpace(ctx.SenderID)
|
||||
}
|
||||
if peerID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &RoutePeer{Kind: peerKind, ID: peerID}
|
||||
}
|
||||
|
||||
func parentPeerFromContext(ctx bus.InboundContext) *RoutePeer {
|
||||
if topicID := strings.TrimSpace(ctx.TopicID); topicID != "" {
|
||||
return &RoutePeer{Kind: "topic", ID: topicID}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func routeGuildIDFromContext(ctx bus.InboundContext) string {
|
||||
if strings.EqualFold(strings.TrimSpace(ctx.SpaceType), "guild") {
|
||||
return strings.TrimSpace(ctx.SpaceID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func routeTeamIDFromContext(ctx bus.InboundContext) string {
|
||||
switch strings.ToLower(strings.TrimSpace(ctx.SpaceType)) {
|
||||
case "team", "workspace":
|
||||
return strings.TrimSpace(ctx.SpaceID)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func cloneIdentityLinks(src map[string][]string) map[string][]string {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -288,3 +325,7 @@ func cloneIdentityLinks(src map[string][]string) map[string][]string {
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func normalizeChannel(value string) string {
|
||||
return strings.ToLower(strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
+39
-34
@@ -3,6 +3,7 @@ package routing
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
@@ -26,9 +27,10 @@ func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) {
|
||||
cfg := testConfig(nil, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
})
|
||||
|
||||
if route.AgentID != DefaultAgentID {
|
||||
@@ -63,9 +65,10 @@ func TestResolveRoute_PeerBinding(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatType: "direct",
|
||||
SenderID: "user123",
|
||||
})
|
||||
|
||||
if route.AgentID != "support" {
|
||||
@@ -94,10 +97,12 @@ func TestResolveRoute_GuildBinding(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-abc",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "ch1"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "discord",
|
||||
ChatID: "ch1",
|
||||
ChatType: "channel",
|
||||
SpaceID: "guild-abc",
|
||||
SpaceType: "guild",
|
||||
})
|
||||
|
||||
if route.AgentID != "gaming" {
|
||||
@@ -126,10 +131,12 @@ func TestResolveRoute_TeamBinding(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "slack",
|
||||
TeamID: "T12345",
|
||||
Peer: &RoutePeer{Kind: "channel", ID: "C001"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "slack",
|
||||
ChatID: "C001",
|
||||
ChatType: "channel",
|
||||
SpaceID: "T12345",
|
||||
SpaceType: "team",
|
||||
})
|
||||
|
||||
if route.AgentID != "work" {
|
||||
@@ -157,10 +164,11 @@ func TestResolveRoute_AccountBinding(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
AccountID: "bot2",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
Account: "bot2",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
})
|
||||
|
||||
if route.AgentID != "premium" {
|
||||
@@ -188,9 +196,10 @@ func TestResolveRoute_ChannelWildcard(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user1"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "telegram",
|
||||
ChatType: "direct",
|
||||
SenderID: "user1",
|
||||
})
|
||||
|
||||
if route.AgentID != "telegram-bot" {
|
||||
@@ -228,10 +237,12 @@ func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "discord",
|
||||
GuildID: "guild-1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user-vip"},
|
||||
route := r.ResolveRoute(bus.InboundContext{
|
||||
Channel: "discord",
|
||||
ChatType: "direct",
|
||||
SenderID: "user-vip",
|
||||
SpaceID: "guild-1",
|
||||
SpaceType: "guild",
|
||||
})
|
||||
|
||||
if route.AgentID != "vip" {
|
||||
@@ -258,9 +269,7 @@ func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) {
|
||||
cfg := testConfig(agents, bindings)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "telegram",
|
||||
})
|
||||
route := r.ResolveRoute(bus.InboundContext{Channel: "telegram"})
|
||||
|
||||
if route.AgentID != "main" {
|
||||
t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID)
|
||||
@@ -276,9 +285,7 @@ func TestResolveRoute_DefaultAgentSelection(t *testing.T) {
|
||||
cfg := testConfig(agents, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
route := r.ResolveRoute(bus.InboundContext{Channel: "cli"})
|
||||
|
||||
if route.AgentID != "beta" {
|
||||
t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID)
|
||||
@@ -293,9 +300,7 @@ func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) {
|
||||
cfg := testConfig(agents, nil)
|
||||
r := NewRouteResolver(cfg)
|
||||
|
||||
route := r.ResolveRoute(RouteInput{
|
||||
Channel: "cli",
|
||||
})
|
||||
route := r.ResolveRoute(bus.InboundContext{Channel: "cli"})
|
||||
|
||||
if route.AgentID != "alpha" {
|
||||
t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID)
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DMScope controls DM session isolation granularity.
|
||||
type DMScope string
|
||||
|
||||
const (
|
||||
DMScopeMain DMScope = "main"
|
||||
DMScopePerPeer DMScope = "per-peer"
|
||||
DMScopePerChannelPeer DMScope = "per-channel-peer"
|
||||
DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer"
|
||||
)
|
||||
|
||||
// RoutePeer represents a chat peer with kind and ID.
|
||||
type RoutePeer struct {
|
||||
Kind string // "direct", "group", "channel"
|
||||
ID string
|
||||
}
|
||||
|
||||
// SessionKeyParams holds all inputs for session key construction.
|
||||
type SessionKeyParams struct {
|
||||
AgentID string
|
||||
Channel string
|
||||
AccountID string
|
||||
Peer *RoutePeer
|
||||
DMScope DMScope
|
||||
IdentityLinks map[string][]string
|
||||
}
|
||||
|
||||
// ParsedSessionKey is the result of parsing an agent-scoped session key.
|
||||
type ParsedSessionKey struct {
|
||||
AgentID string
|
||||
Rest string
|
||||
}
|
||||
|
||||
// BuildAgentMainSessionKey returns "agent:<agentId>:main".
|
||||
func BuildAgentMainSessionKey(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey)
|
||||
}
|
||||
|
||||
// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope.
|
||||
func BuildAgentPeerSessionKey(params SessionKeyParams) string {
|
||||
agentID := NormalizeAgentID(params.AgentID)
|
||||
|
||||
peer := params.Peer
|
||||
if peer == nil {
|
||||
peer = &RoutePeer{Kind: "direct"}
|
||||
}
|
||||
peerKind := strings.TrimSpace(peer.Kind)
|
||||
if peerKind == "" {
|
||||
peerKind = "direct"
|
||||
}
|
||||
|
||||
if peerKind == "direct" {
|
||||
dmScope := params.DMScope
|
||||
if dmScope == "" {
|
||||
dmScope = DMScopeMain
|
||||
}
|
||||
peerID := CanonicalSessionPeerID(params.Channel, peer.ID, dmScope, params.IdentityLinks)
|
||||
|
||||
switch dmScope {
|
||||
case DMScopePerAccountChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
accountID := NormalizeAccountID(params.AccountID)
|
||||
return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID)
|
||||
}
|
||||
case DMScopePerChannelPeer:
|
||||
if peerID != "" {
|
||||
channel := normalizeChannel(params.Channel)
|
||||
return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID)
|
||||
}
|
||||
case DMScopePerPeer:
|
||||
if peerID != "" {
|
||||
return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID)
|
||||
}
|
||||
}
|
||||
return BuildAgentMainSessionKey(agentID)
|
||||
}
|
||||
|
||||
// Group/channel peers always get per-peer sessions
|
||||
channel := normalizeChannel(params.Channel)
|
||||
peerID := strings.ToLower(strings.TrimSpace(peer.ID))
|
||||
if peerID == "" {
|
||||
peerID = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
|
||||
}
|
||||
|
||||
// CanonicalSessionPeerID applies the current DM session canonicalization rules,
|
||||
// including identity-link collapse when enabled.
|
||||
func CanonicalSessionPeerID(
|
||||
channel, peerID string,
|
||||
dmScope DMScope,
|
||||
identityLinks map[string][]string,
|
||||
) string {
|
||||
normalizedPeerID := strings.TrimSpace(peerID)
|
||||
if normalizedPeerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if dmScope != DMScopeMain {
|
||||
if linked := resolveLinkedPeerID(identityLinks, channel, normalizedPeerID); linked != "" {
|
||||
normalizedPeerID = linked
|
||||
}
|
||||
}
|
||||
|
||||
return strings.ToLower(normalizedPeerID)
|
||||
}
|
||||
|
||||
// CanonicalSessionIdentityID collapses an identity using identity_links when
|
||||
// possible, then returns a normalized lowercase identifier.
|
||||
func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string {
|
||||
normalizedID := strings.TrimSpace(rawID)
|
||||
if normalizedID == "" {
|
||||
return ""
|
||||
}
|
||||
if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" {
|
||||
normalizedID = linked
|
||||
}
|
||||
return strings.ToLower(normalizedID)
|
||||
}
|
||||
|
||||
// ParseAgentSessionKey extracts agentId and rest from "agent:<agentId>:<rest>".
|
||||
func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.SplitN(raw, ":", 3)
|
||||
if len(parts) < 3 {
|
||||
return nil
|
||||
}
|
||||
if parts[0] != "agent" {
|
||||
return nil
|
||||
}
|
||||
agentID := strings.TrimSpace(parts[1])
|
||||
rest := parts[2]
|
||||
if agentID == "" || rest == "" {
|
||||
return nil
|
||||
}
|
||||
return &ParsedSessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
// IsSubagentSessionKey returns true if the session key represents a subagent.
|
||||
func IsSubagentSessionKey(sessionKey string) bool {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(raw), "subagent:") {
|
||||
return true
|
||||
}
|
||||
parsed := ParseAgentSessionKey(raw)
|
||||
if parsed == nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:")
|
||||
}
|
||||
|
||||
func normalizeChannel(channel string) string {
|
||||
c := strings.TrimSpace(strings.ToLower(channel))
|
||||
if c == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID))
|
||||
candidates[scopedCandidate] = true
|
||||
}
|
||||
|
||||
// If peerID is already in canonical "platform:id" format, also add the
|
||||
// bare ID part as a candidate for backward compatibility with identity_links
|
||||
// that use raw IDs (e.g. "123" instead of "telegram:123").
|
||||
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
|
||||
bareID := rawCandidate[idx+1:]
|
||||
candidates[bareID] = true
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,207 +0,0 @@
|
||||
package routing
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildAgentMainSessionKey(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("sales")
|
||||
want := "agent:sales:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) {
|
||||
got := BuildAgentMainSessionKey("Sales Bot")
|
||||
want := "agent:sales-bot:main"
|
||||
if got != want {
|
||||
t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopeMain,
|
||||
})
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("DMScopeMain = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
AccountID: "bot1",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "User123"},
|
||||
DMScope: DMScopePerAccountChannelPeer,
|
||||
})
|
||||
want := "agent:main:telegram:bot1:direct:user123"
|
||||
if got != want {
|
||||
t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "group", ID: "chat456"},
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
want := "agent:main:telegram:group:chat456"
|
||||
if got != want {
|
||||
t.Errorf("GroupPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) {
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: nil,
|
||||
DMScope: DMScopePerPeer,
|
||||
})
|
||||
// nil peer defaults to direct with empty ID, falls to main
|
||||
want := "agent:main:main"
|
||||
if got != want {
|
||||
t.Errorf("NilPeer = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) {
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:user123", "discord:john#1234"},
|
||||
}
|
||||
got := BuildAgentPeerSessionKey(SessionKeyParams{
|
||||
AgentID: "main",
|
||||
Channel: "telegram",
|
||||
Peer: &RoutePeer{Kind: "direct", ID: "user123"},
|
||||
DMScope: DMScopePerPeer,
|
||||
IdentityLinks: links,
|
||||
})
|
||||
want := "agent:main:direct:john"
|
||||
if got != want {
|
||||
t.Errorf("IdentityLink = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) {
|
||||
// When peerID is already in canonical "platform:id" format,
|
||||
// it should match identity_links that use the bare ID.
|
||||
links := map[string][]string{
|
||||
"john": {"123"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "telegram", "telegram:123")
|
||||
if got != "john" {
|
||||
t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) {
|
||||
// When identity_links contain canonical IDs and peerID is canonical too
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:123", "discord:456"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "telegram", "telegram:123")
|
||||
if got != "john" {
|
||||
t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) {
|
||||
// When peerID is bare "123" and links have "telegram:123",
|
||||
// the scoped candidate "telegram:123" should match.
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:123"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "telegram", "123")
|
||||
if got != "john" {
|
||||
t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinkedPeerID_NoMatch(t *testing.T) {
|
||||
links := map[string][]string{
|
||||
"john": {"telegram:123"},
|
||||
}
|
||||
got := resolveLinkedPeerID(links, "discord", "999")
|
||||
if got != "" {
|
||||
t.Errorf("resolveLinkedPeerID no match = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Valid(t *testing.T) {
|
||||
parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123")
|
||||
if parsed == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if parsed.AgentID != "sales" {
|
||||
t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID)
|
||||
}
|
||||
if parsed.Rest != "telegram:direct:user123" {
|
||||
t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAgentSessionKey_Invalid(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"foo:bar",
|
||||
"notprefix:sales:main",
|
||||
"agent::main",
|
||||
"agent:sales:",
|
||||
}
|
||||
for _, input := range tests {
|
||||
if got := ParseAgentSessionKey(input); got != nil {
|
||||
t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSubagentSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"subagent:task-1", true},
|
||||
{"agent:main:subagent:task-1", true},
|
||||
{"agent:main:main", false},
|
||||
{"agent:main:telegram:direct:user123", false},
|
||||
{"", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := IsSubagentSessionKey(tt.input); got != tt.want {
|
||||
t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+14
-27
@@ -32,7 +32,7 @@ type AllocationInput struct {
|
||||
func AllocateRouteSession(input AllocationInput) Allocation {
|
||||
scope := buildSessionScope(input)
|
||||
legacySessionAliases := buildLegacySessionAliases(input)
|
||||
legacyMainSessionKey := strings.ToLower(routing.BuildAgentMainSessionKey(input.AgentID))
|
||||
legacyMainSessionKey := strings.ToLower(BuildLegacyMainAlias(input.AgentID))
|
||||
return Allocation{
|
||||
Scope: scope,
|
||||
SessionKey: BuildSessionKey(scope),
|
||||
@@ -85,7 +85,7 @@ func buildSessionScope(input AllocationInput) SessionScope {
|
||||
values["topic"] = "topic:" + strings.ToLower(topicID)
|
||||
}
|
||||
case "sender":
|
||||
senderID := routing.CanonicalSessionIdentityID(
|
||||
senderID := CanonicalSessionIdentityID(
|
||||
inbound.Channel,
|
||||
inbound.SenderID,
|
||||
input.SessionPolicy.IdentityLinks,
|
||||
@@ -107,11 +107,11 @@ func buildSessionScope(input AllocationInput) SessionScope {
|
||||
}
|
||||
|
||||
func buildLegacySessionAliases(input AllocationInput) []string {
|
||||
aliases := []string{strings.ToLower(routing.BuildAgentMainSessionKey(input.AgentID))}
|
||||
aliases := []string{strings.ToLower(BuildLegacyMainAlias(input.AgentID))}
|
||||
inbound := input.Context
|
||||
|
||||
if strings.EqualFold(strings.TrimSpace(inbound.ChatType), "direct") {
|
||||
senderID := routing.CanonicalSessionIdentityID(
|
||||
senderID := CanonicalSessionIdentityID(
|
||||
inbound.Channel,
|
||||
inbound.SenderID,
|
||||
input.SessionPolicy.IdentityLinks,
|
||||
@@ -119,20 +119,10 @@ func buildLegacySessionAliases(input AllocationInput) []string {
|
||||
if senderID == "" {
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
for _, dmScope := range []routing.DMScope{
|
||||
routing.DMScopePerPeer,
|
||||
routing.DMScopePerChannelPeer,
|
||||
routing.DMScopePerAccountChannelPeer,
|
||||
} {
|
||||
aliases = append(aliases, strings.ToLower(routing.BuildAgentPeerSessionKey(routing.SessionKeyParams{
|
||||
AgentID: input.AgentID,
|
||||
Channel: inbound.Channel,
|
||||
AccountID: inbound.Account,
|
||||
Peer: &routing.RoutePeer{Kind: "direct", ID: senderID},
|
||||
DMScope: dmScope,
|
||||
IdentityLinks: input.SessionPolicy.IdentityLinks,
|
||||
})))
|
||||
}
|
||||
aliases = append(
|
||||
aliases,
|
||||
BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, senderID)...,
|
||||
)
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
|
||||
@@ -143,15 +133,12 @@ func buildLegacySessionAliases(input AllocationInput) []string {
|
||||
if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" {
|
||||
peerID = peerID + "/" + topicID
|
||||
}
|
||||
aliases = append(aliases, strings.ToLower(routing.BuildAgentPeerSessionKey(routing.SessionKeyParams{
|
||||
AgentID: input.AgentID,
|
||||
Channel: inbound.Channel,
|
||||
AccountID: inbound.Account,
|
||||
Peer: &routing.RoutePeer{
|
||||
Kind: strings.ToLower(strings.TrimSpace(inbound.ChatType)),
|
||||
ID: peerID,
|
||||
},
|
||||
})))
|
||||
aliases = append(aliases, BuildLegacyPeerAlias(
|
||||
input.AgentID,
|
||||
inbound.Channel,
|
||||
strings.ToLower(strings.TrimSpace(inbound.ChatType)),
|
||||
peerID,
|
||||
))
|
||||
|
||||
return uniqueAliases(aliases)
|
||||
}
|
||||
|
||||
+134
-1
@@ -5,9 +5,19 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
)
|
||||
|
||||
const sessionKeyV1Prefix = "sk_v1_"
|
||||
const (
|
||||
sessionKeyV1Prefix = "sk_v1_"
|
||||
legacyAgentSessionKeyPrefix = "agent:"
|
||||
)
|
||||
|
||||
type ParsedLegacySessionKey struct {
|
||||
AgentID string
|
||||
Rest string
|
||||
}
|
||||
|
||||
// BuildOpaqueSessionKey returns a stable opaque session key derived from a
|
||||
// canonical alias string. The alias remains available through metadata for
|
||||
@@ -27,6 +37,129 @@ func IsOpaqueSessionKey(key string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), sessionKeyV1Prefix)
|
||||
}
|
||||
|
||||
func IsLegacyAgentSessionKey(key string) bool {
|
||||
return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), legacyAgentSessionKeyPrefix)
|
||||
}
|
||||
|
||||
func IsExplicitSessionKey(key string) bool {
|
||||
return IsOpaqueSessionKey(key) || IsLegacyAgentSessionKey(key)
|
||||
}
|
||||
|
||||
func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey {
|
||||
raw := strings.TrimSpace(sessionKey)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.SplitN(raw, ":", 3)
|
||||
if len(parts) < 3 || parts[0] != "agent" {
|
||||
return nil
|
||||
}
|
||||
agentID := strings.TrimSpace(parts[1])
|
||||
rest := parts[2]
|
||||
if agentID == "" || rest == "" {
|
||||
return nil
|
||||
}
|
||||
return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest}
|
||||
}
|
||||
|
||||
func BuildLegacyMainAlias(agentID string) string {
|
||||
return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID))
|
||||
}
|
||||
|
||||
// BuildMainSessionKey returns the canonical opaque main-session key for an
|
||||
// agent. The corresponding legacy alias remains available via
|
||||
// BuildLegacyMainAlias for compatibility and migration logic.
|
||||
func BuildMainSessionKey(agentID string) string {
|
||||
return BuildOpaqueSessionKey(BuildLegacyMainAlias(agentID))
|
||||
}
|
||||
|
||||
func BuildLegacyDirectAliases(agentID, channel, account, peerID string) []string {
|
||||
agentID = routing.NormalizeAgentID(agentID)
|
||||
channel = normalizeLegacyChannel(channel)
|
||||
account = routing.NormalizeAccountID(account)
|
||||
peerID = strings.ToLower(strings.TrimSpace(peerID))
|
||||
if peerID == "" {
|
||||
return nil
|
||||
}
|
||||
return []string{
|
||||
fmt.Sprintf("agent:%s:direct:%s", agentID, peerID),
|
||||
fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID),
|
||||
fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, account, peerID),
|
||||
}
|
||||
}
|
||||
|
||||
func BuildLegacyPeerAlias(agentID, channel, peerKind, peerID string) string {
|
||||
agentID = routing.NormalizeAgentID(agentID)
|
||||
channel = normalizeLegacyChannel(channel)
|
||||
peerKind = strings.ToLower(strings.TrimSpace(peerKind))
|
||||
if peerKind == "" {
|
||||
peerKind = "unknown"
|
||||
}
|
||||
peerID = strings.ToLower(strings.TrimSpace(peerID))
|
||||
if peerID == "" {
|
||||
peerID = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID)
|
||||
}
|
||||
|
||||
// CanonicalSessionIdentityID collapses an identity using identity_links when
|
||||
// possible, then returns a normalized lowercase identifier.
|
||||
func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string {
|
||||
normalizedID := strings.TrimSpace(rawID)
|
||||
if normalizedID == "" {
|
||||
return ""
|
||||
}
|
||||
if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" {
|
||||
normalizedID = linked
|
||||
}
|
||||
return strings.ToLower(normalizedID)
|
||||
}
|
||||
|
||||
func normalizeLegacyChannel(channel string) string {
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel == "" {
|
||||
return "unknown"
|
||||
}
|
||||
return channel
|
||||
}
|
||||
|
||||
func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string {
|
||||
if len(identityLinks) == 0 {
|
||||
return ""
|
||||
}
|
||||
peerID = strings.TrimSpace(peerID)
|
||||
if peerID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
candidates := make(map[string]bool)
|
||||
rawCandidate := strings.ToLower(peerID)
|
||||
if rawCandidate != "" {
|
||||
candidates[rawCandidate] = true
|
||||
}
|
||||
channel = strings.ToLower(strings.TrimSpace(channel))
|
||||
if channel != "" {
|
||||
candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true
|
||||
}
|
||||
if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 {
|
||||
candidates[rawCandidate[idx+1:]] = true
|
||||
}
|
||||
|
||||
for canonical, ids := range identityLinks {
|
||||
canonicalName := strings.TrimSpace(canonical)
|
||||
if canonicalName == "" {
|
||||
continue
|
||||
}
|
||||
for _, id := range ids {
|
||||
normalized := strings.ToLower(strings.TrimSpace(id))
|
||||
if normalized != "" && candidates[normalized] {
|
||||
return canonicalName
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CanonicalScopeSignature returns a stable serialized representation of scope.
|
||||
func CanonicalScopeSignature(scope SessionScope) string {
|
||||
parts := []string{
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package session
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsExplicitSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
key string
|
||||
want bool
|
||||
}{
|
||||
{"sk_v1_abc", true},
|
||||
{"agent:main:direct:user123", true},
|
||||
{"custom-key", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := IsExplicitSessionKey(tt.key); got != tt.want {
|
||||
t.Fatalf("IsExplicitSessionKey(%q) = %v, want %v", tt.key, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseLegacyAgentSessionKey(t *testing.T) {
|
||||
parsed := ParseLegacyAgentSessionKey("agent:sales:telegram:direct:user123")
|
||||
if parsed == nil {
|
||||
t.Fatal("expected parsed legacy key, got nil")
|
||||
}
|
||||
if parsed.AgentID != "sales" {
|
||||
t.Fatalf("AgentID = %q, want sales", parsed.AgentID)
|
||||
}
|
||||
if parsed.Rest != "telegram:direct:user123" {
|
||||
t.Fatalf("Rest = %q, want telegram:direct:user123", parsed.Rest)
|
||||
}
|
||||
|
||||
if got := ParseLegacyAgentSessionKey("sk_v1_abc"); got != nil {
|
||||
t.Fatalf("expected nil for opaque key, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLegacyDirectAliases(t *testing.T) {
|
||||
aliases := BuildLegacyDirectAliases("Main", "Telegram", "BotA", "User123")
|
||||
want := []string{
|
||||
"agent:main:direct:user123",
|
||||
"agent:main:telegram:direct:user123",
|
||||
"agent:main:telegram:bota:direct:user123",
|
||||
}
|
||||
if len(aliases) != len(want) {
|
||||
t.Fatalf("len(aliases) = %d, want %d", len(aliases), len(want))
|
||||
}
|
||||
for i := range want {
|
||||
if aliases[i] != want[i] {
|
||||
t.Fatalf("aliases[%d] = %q, want %q", i, aliases[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLegacyPeerAlias(t *testing.T) {
|
||||
got := BuildLegacyPeerAlias("Main", "Slack", "channel", "C001")
|
||||
if got != "agent:main:slack:channel:c001" {
|
||||
t.Fatalf("BuildLegacyPeerAlias() = %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMainSessionKey(t *testing.T) {
|
||||
got := BuildMainSessionKey("Main")
|
||||
if !IsOpaqueSessionKey(got) {
|
||||
t.Fatalf("BuildMainSessionKey() = %q, want opaque key", got)
|
||||
}
|
||||
if got != BuildOpaqueSessionKey("agent:main:main") {
|
||||
t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got)
|
||||
}
|
||||
}
|
||||
+2
-4
@@ -311,8 +311,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channel, chatID, ""),
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
@@ -335,8 +334,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer pubCancel()
|
||||
t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Context: bus.NewOutboundContext(channel, chatID, ""),
|
||||
Content: output,
|
||||
})
|
||||
return "ok"
|
||||
|
||||
Reference in New Issue
Block a user