diff --git a/docs/configuration.md b/docs/configuration.md index c1c1cc498..e59d6a022 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -120,133 +120,78 @@ dammi le ultime news - Unknown slash command (for example `/foo`) passes through to normal LLM processing. - Registered but unsupported command on the current channel (for example `/show` on WhatsApp) returns an explicit user-facing error and stops further processing. -### Agent Bindings (Route messages to specific agents) +### Routing -Use `bindings` in `config.json` to route incoming messages to different agents by channel/account/context. +Routing is configured through `agents.dispatch.rules`. + +Each rule matches against the normalized inbound context produced by channels. +Rules are evaluated from top to bottom. The first matching rule wins. If no +rule matches, PicoClaw falls back to the configured default agent. + +Supported match fields: + +* `channel` +* `account` +* `space` +* `chat` +* `topic` +* `sender` +* `mentioned` + +Match values use the same scope vocabulary as the session system: + +* `space`: `workspace:t001`, `guild:123456` +* `chat`: `direct:user123`, `group:-100123`, `channel:c123` +* `topic`: `topic:42` +* `sender`: a normalized sender identifier for the platform + +Rules may optionally override the global `session.dimensions` value through +`session_dimensions`. This allows routing and session allocation to stay aligned +without reintroducing the old `bindings` or `dm_scope` formats. + +Example: ```json { "agents": { - "defaults": { - "workspace": "~/.picoclaw/workspace", - "model_name": "gpt-4o-mini" - }, "list": [ - { "id": "main", "default": true, "name": "Main Assistant" }, - { "id": "support", "name": "Support Assistant" }, - { "id": "sales", "name": "Sales Assistant" } - ] - }, - "bindings": [ - { - "agent_id": "support", - "match": { - "channel": "telegram", - "account_id": "*", - "peer": { "kind": "direct", "id": "user123" } - } - }, - { - "agent_id": "sales", - "match": { - "channel": "discord", - "account_id": "my-discord-bot", - "guild_id": "987654321" - } + { "id": "main", "default": true }, + { "id": "support" }, + { "id": "sales" } + ], + "dispatch": { + "rules": [ + { + "name": "vip in support group", + "agent": "sales", + "when": { + "channel": "telegram", + "chat": "group:-1001234567890", + "sender": "12345" + }, + "session_dimensions": ["chat", "sender"] + }, + { + "name": "telegram support group", + "agent": "support", + "when": { + "channel": "telegram", + "chat": "group:-1001234567890" + }, + "session_dimensions": ["chat"] + } + ] } - ] -} -``` - -#### `bindings` fields - -| Field | Required | Description | -|-------|----------|-------------| -| `agent_id` | Yes | Target agent id in `agents.list` | -| `match.channel` | Yes | Channel name (e.g. `telegram`, `discord`) | -| `match.account_id` | No | Channel account filter. Use `"*"` for all accounts of that channel. If omitted, only default account is matched | -| `match.peer.kind` + `match.peer.id` | No | Exact peer match (e.g. direct chat / topic / group id) | -| `match.guild_id` | No | Guild/server-level match | -| `match.team_id` | No | Team/workspace-level match | - -#### Matching priority - -When multiple bindings exist, PicoClaw resolves in this order: - -1. `peer` -2. `parent_peer` (for thread/topic parent contexts) -3. `guild_id` -4. `team_id` -5. `account_id` (non-wildcard) -6. channel wildcard (`account_id: "*"`) -7. default agent - -If a binding points to a missing `agent_id`, PicoClaw falls back to the default agent. - -#### How matching works (step-by-step) - -1. PicoClaw first filters bindings by `match.channel` (must equal current channel). -2. It then filters by `match.account_id`: - - omitted: match only the channel's default account - - `"*"`: match all accounts on this channel - - explicit value: exact account id match (case-insensitive) -3. From the remaining candidates, it applies the priority chain above and stops at the first hit. - -In other words: **channel + account form the candidate set; peer/guild/team then decide final winner**. - -#### Common recipes - -**1) Route one specific DM user to a specialist agent** - -```json -{ - "agent_id": "support", - "match": { - "channel": "telegram", - "account_id": "*", - "peer": { "kind": "direct", "id": "user123" } + }, + "session": { + "dimensions": ["chat"] } } ``` -**2) Route one Discord server (guild) to a dedicated agent** - -```json -{ - "agent_id": "sales", - "match": { - "channel": "discord", - "account_id": "my-discord-bot", - "guild_id": "987654321" - } -} -``` - -**3) Route all remaining traffic of a channel to a fallback agent** - -```json -{ - "agent_id": "main", - "match": { - "channel": "discord", - "account_id": "*" - } -} -``` - -#### Authoring guidelines (important) - -- Keep exactly one clear default agent in `agents.list` (`"default": true`). -- Put specific rules (`peer`, `guild_id`, `team_id`) and broad rules (`account_id: "*"` only) together safely; priority already guarantees specific rules win. -- Avoid duplicate rules with the same specificity and match values. If duplicates exist, the first matching entry in the config array wins. -- Ensure every `agent_id` exists in `agents.list`; unknown IDs silently fall back to default. - -#### Troubleshooting checklist - -- **Rule not taking effect?** Check `match.channel` spelling first (must be exact). -- **Expected account-specific routing but still using default?** Verify `match.account_id` equals actual runtime account id. -- **Wildcard catches too much traffic?** Add more specific `peer/guild/team` rules for critical paths. -- **Unexpected default fallback?** Confirm `agent_id` exists and is not misspelled. +In the example above, the VIP rule must appear before the broader group rule. +Because routing is strictly ordered, more specific rules should be placed +earlier and broader fallback rules later. ### 🔒 Security Sandbox diff --git a/pkg/agent/context_legacy.go b/pkg/agent/context_legacy.go index 85e331ae9..5644571fb 100644 --- a/pkg/agent/context_legacy.go +++ b/pkg/agent/context_legacy.go @@ -42,7 +42,7 @@ func (m *legacyContextManager) Compact(_ context.Context, req *CompactRequest) e if result, ok := m.forceCompression(req.SessionKey); ok { m.al.emitEvent( EventKindContextCompress, - m.al.newTurnEventScope("", req.SessionKey).meta(0, "forceCompression", "turn.context.compress"), + m.al.newTurnEventScope("", req.SessionKey, nil).meta(0, "forceCompression", "turn.context.compress"), ContextCompressPayload{ Reason: req.Reason, DroppedMessages: result.DroppedMessages, @@ -247,7 +247,7 @@ func (m *legacyContextManager) summarizeSession(agent *AgentInstance, sessionKey agent.Sessions.Save(sessionKey) m.al.emitEvent( EventKindSessionSummarize, - m.al.newTurnEventScope(agent.ID, sessionKey).meta(0, "summarizeSession", "turn.session.summarize"), + m.al.newTurnEventScope(agent.ID, sessionKey, nil).meta(0, "summarizeSession", "turn.session.summarize"), SessionSummarizePayload{ SummarizedMessages: len(validMessages), KeptMessages: keepCount, diff --git a/pkg/agent/dispatch_request.go b/pkg/agent/dispatch_request.go new file mode 100644 index 000000000..cb54264d6 --- /dev/null +++ b/pkg/agent/dispatch_request.go @@ -0,0 +1,147 @@ +package agent + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" +) + +// DispatchRequest is the normalized runtime input passed into the agent loop +// after routing and session allocation have completed. +type DispatchRequest struct { + SessionKey string + SessionAliases []string + InboundContext *bus.InboundContext + RouteResult *routing.ResolvedRoute + SessionScope *session.SessionScope + UserMessage string + Media []string +} + +func (r DispatchRequest) Channel() string { + if r.InboundContext == nil { + return "" + } + return r.InboundContext.Channel +} + +func (r DispatchRequest) ChatID() string { + if r.InboundContext == nil { + return "" + } + return r.InboundContext.ChatID +} + +func (r DispatchRequest) MessageID() string { + if r.InboundContext == nil { + return "" + } + return r.InboundContext.MessageID +} + +func (r DispatchRequest) ReplyToMessageID() string { + if r.InboundContext == nil { + return "" + } + return r.InboundContext.ReplyToMessageID +} + +func (r DispatchRequest) SenderID() string { + if r.InboundContext == nil { + return "" + } + return r.InboundContext.SenderID +} + +func normalizeProcessOptionsInPlace(opts *processOptions) { + if opts == nil { + return + } + *opts = normalizeProcessOptions(*opts) +} + +func normalizeProcessOptions(opts processOptions) processOptions { + if opts.Dispatch.SessionKey == "" { + opts.Dispatch.SessionKey = strings.TrimSpace(opts.SessionKey) + } + if len(opts.Dispatch.SessionAliases) == 0 && len(opts.SessionAliases) > 0 { + opts.Dispatch.SessionAliases = append([]string(nil), opts.SessionAliases...) + } + if opts.Dispatch.UserMessage == "" { + opts.Dispatch.UserMessage = opts.UserMessage + } + if len(opts.Dispatch.Media) == 0 && len(opts.Media) > 0 { + opts.Dispatch.Media = append([]string(nil), opts.Media...) + } + if opts.Dispatch.RouteResult == nil { + opts.Dispatch.RouteResult = cloneResolvedRoute(opts.RouteResult) + } + if opts.Dispatch.SessionScope == nil { + opts.Dispatch.SessionScope = session.CloneScope(opts.SessionScope) + } + if opts.Dispatch.InboundContext == nil { + if opts.InboundContext != nil { + opts.Dispatch.InboundContext = cloneInboundContext(opts.InboundContext) + } else if opts.Channel != "" || opts.ChatID != "" || opts.SenderID != "" || + opts.MessageID != "" || opts.ReplyToMessageID != "" { + inbound := bus.InboundContext{ + Channel: strings.TrimSpace(opts.Channel), + ChatID: strings.TrimSpace(opts.ChatID), + SenderID: strings.TrimSpace(opts.SenderID), + MessageID: strings.TrimSpace(opts.MessageID), + ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID), + } + inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope) + if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" || + inbound.MessageID != "" || inbound.ReplyToMessageID != "" { + inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context + opts.Dispatch.InboundContext = &inbound + } + } + } + + // Keep legacy mirrors populated while the rest of the runtime migrates. + opts.SessionKey = opts.Dispatch.SessionKey + opts.SessionAliases = append([]string(nil), opts.Dispatch.SessionAliases...) + opts.UserMessage = opts.Dispatch.UserMessage + opts.Media = append([]string(nil), opts.Dispatch.Media...) + opts.InboundContext = cloneInboundContext(opts.Dispatch.InboundContext) + opts.RouteResult = cloneResolvedRoute(opts.Dispatch.RouteResult) + opts.SessionScope = session.CloneScope(opts.Dispatch.SessionScope) + if opts.InboundContext != nil { + if opts.Channel == "" { + opts.Channel = opts.InboundContext.Channel + } + if opts.ChatID == "" { + opts.ChatID = opts.InboundContext.ChatID + } + if opts.MessageID == "" { + opts.MessageID = opts.InboundContext.MessageID + } + if opts.ReplyToMessageID == "" { + opts.ReplyToMessageID = opts.InboundContext.ReplyToMessageID + } + if opts.SenderID == "" { + opts.SenderID = opts.InboundContext.SenderID + } + } + + return opts +} + +func inferChatTypeFromSessionScope(scope *session.SessionScope) string { + if scope == nil || len(scope.Values) == 0 { + return "" + } + chatValue := strings.TrimSpace(scope.Values["chat"]) + if chatValue == "" { + return "" + } + chatType, _, ok := strings.Cut(chatValue, ":") + if !ok { + return "" + } + return strings.ToLower(strings.TrimSpace(chatType)) +} diff --git a/pkg/agent/dispatch_request_test.go b/pkg/agent/dispatch_request_test.go new file mode 100644 index 000000000..ec5f70339 --- /dev/null +++ b/pkg/agent/dispatch_request_test.go @@ -0,0 +1,135 @@ +package agent + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" +) + +func TestNormalizeProcessOptions_PopulatesDispatchFromLegacyFields(t *testing.T) { + opts := normalizeProcessOptions(processOptions{ + SessionKey: "session-1", + SessionAliases: []string{"legacy:one"}, + Channel: "telegram", + ChatID: "chat-1", + MessageID: "msg-1", + ReplyToMessageID: "reply-1", + SenderID: "user-1", + UserMessage: "hello", + Media: []string{"media://one"}, + }) + + if opts.Dispatch.SessionKey != "session-1" { + t.Fatalf("Dispatch.SessionKey = %q, want session-1", opts.Dispatch.SessionKey) + } + if len(opts.Dispatch.SessionAliases) != 1 || opts.Dispatch.SessionAliases[0] != "legacy:one" { + t.Fatalf("Dispatch.SessionAliases = %v, want [legacy:one]", opts.Dispatch.SessionAliases) + } + if opts.Dispatch.Channel() != "telegram" || opts.Dispatch.ChatID() != "chat-1" { + t.Fatalf( + "dispatch addressing = (%q,%q), want (telegram,chat-1)", + opts.Dispatch.Channel(), + opts.Dispatch.ChatID(), + ) + } + if opts.Dispatch.SenderID() != "user-1" || opts.Dispatch.MessageID() != "msg-1" { + t.Fatalf("dispatch sender/message = (%q,%q)", opts.Dispatch.SenderID(), opts.Dispatch.MessageID()) + } + if opts.Dispatch.ReplyToMessageID() != "reply-1" { + t.Fatalf("Dispatch.ReplyToMessageID() = %q, want reply-1", opts.Dispatch.ReplyToMessageID()) + } + if opts.Dispatch.UserMessage != "hello" { + t.Fatalf("Dispatch.UserMessage = %q, want hello", opts.Dispatch.UserMessage) + } + if len(opts.Dispatch.Media) != 1 || opts.Dispatch.Media[0] != "media://one" { + t.Fatalf("Dispatch.Media = %v, want [media://one]", opts.Dispatch.Media) + } +} + +func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) { + inbound := &bus.InboundContext{ + Channel: "slack", + ChatID: "C123", + ChatType: "channel", + SenderID: "U123", + MessageID: "m-1", + ReplyToMessageID: "parent-1", + } + route := &routing.ResolvedRoute{ + AgentID: "support", + Channel: "slack", + AccountID: "workspace-a", + MatchedBy: "dispatch.rule:test", + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat", "sender"}, + }, + } + scope := &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "support", + Channel: "slack", + Account: "workspace-a", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "channel:c123", + }, + } + + opts := normalizeProcessOptions(processOptions{ + Dispatch: DispatchRequest{ + SessionKey: "sk_v1_example", + SessionAliases: []string{"agent:support:slack:channel:c123"}, + InboundContext: inbound, + RouteResult: route, + SessionScope: scope, + UserMessage: "hello", + Media: []string{"media://one"}, + }, + }) + + if opts.SessionKey != "sk_v1_example" { + t.Fatalf("SessionKey = %q, want sk_v1_example", opts.SessionKey) + } + if opts.Channel != "slack" || opts.ChatID != "C123" { + t.Fatalf("legacy mirrors = (%q,%q), want (slack,C123)", opts.Channel, opts.ChatID) + } + if opts.SenderID != "U123" || opts.MessageID != "m-1" { + t.Fatalf("legacy sender/message = (%q,%q)", opts.SenderID, opts.MessageID) + } + if opts.ReplyToMessageID != "parent-1" { + t.Fatalf("ReplyToMessageID = %q, want parent-1", opts.ReplyToMessageID) + } + if opts.RouteResult == nil || opts.RouteResult.AgentID != "support" { + t.Fatalf("RouteResult = %#v, want support route", opts.RouteResult) + } + if opts.SessionScope == nil || opts.SessionScope.AgentID != "support" { + t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope) + } +} + +func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) { + opts := normalizeProcessOptions(processOptions{ + Channel: "telegram", + ChatID: "-100123", + SenderID: "user-1", + UserMessage: "hello", + SessionScope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "telegram", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "group:-100123", + }, + }, + }) + + if opts.Dispatch.InboundContext == nil { + t.Fatal("Dispatch.InboundContext is nil") + } + if opts.Dispatch.InboundContext.ChatType != "group" { + t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType) + } +} diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go index 2785d70a5..31b996260 100644 --- a/pkg/agent/eventbus_test.go +++ b/pkg/agent/eventbus_test.go @@ -10,6 +10,8 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -136,6 +138,31 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) { DefaultResponse: defaultResponse, EnableSummary: false, SendResponse: false, + InboundContext: &bus.InboundContext{ + Channel: "cli", + ChatID: "direct", + ChatType: "direct", + SenderID: "tester", + }, + RouteResult: &routing.ResolvedRoute{ + AgentID: "main", + Channel: "cli", + AccountID: routing.DefaultAccountID, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + MatchedBy: "default", + }, + SessionScope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "cli", + Account: routing.DefaultAccountID, + Dimensions: []string{"sender"}, + Values: map[string]string{ + "sender": "tester", + }, + }, }) if err != nil { t.Fatalf("runAgentLoop failed: %v", err) @@ -176,6 +203,18 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) { if evt.Meta.SessionKey != "session-1" { t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey) } + if evt.Context == nil || evt.Context.Inbound == nil { + t.Fatalf("event %d missing inbound turn context", i) + } + if evt.Context.Inbound.Channel != "cli" || evt.Context.Inbound.SenderID != "tester" { + t.Fatalf("event %d inbound context = %+v", i, evt.Context.Inbound) + } + if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" { + t.Fatalf("event %d missing route context: %+v", i, evt.Context.Route) + } + if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "tester" { + t.Fatalf("event %d missing session scope: %+v", i, evt.Context.Scope) + } } startPayload, ok := events[0].Payload.(TurnStartPayload) @@ -472,7 +511,6 @@ func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) { sub := al.SubscribeEvents(16) defer al.UnsubscribeEvents(sub.ID) - // Use legacyContextManager's summarizeSession via contextManager interface lcm := &legacyContextManager{al: al} lcm.summarizeSession(defaultAgent, "session-1") @@ -572,12 +610,6 @@ func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) { if payload.SourceTool != "async_followup" { t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool) } - if payload.Channel != "cli" { - t.Fatalf("expected channel cli, got %q", payload.Channel) - } - if payload.ChatID != "direct" { - t.Fatalf("expected chat id direct, got %q", payload.ChatID) - } if payload.ContentLen != len("background result") { t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen) } diff --git a/pkg/agent/events.go b/pkg/agent/events.go index 615eacf9f..f68d3eab5 100644 --- a/pkg/agent/events.go +++ b/pkg/agent/events.go @@ -86,6 +86,7 @@ type Event struct { Kind EventKind Time time.Time Meta EventMeta + Context *TurnContext Payload any } @@ -98,6 +99,7 @@ type EventMeta struct { Iteration int TracePath string Source string + turnContext *TurnContext } // TurnEndStatus describes the terminal state of a turn. @@ -114,8 +116,6 @@ const ( // TurnStartPayload describes the start of a turn. type TurnStartPayload struct { - Channel string - ChatID string UserMessage string MediaCount int } @@ -217,8 +217,6 @@ type SteeringInjectedPayload struct { // FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus. type FollowUpQueuedPayload struct { SourceTool string - Channel string - ChatID string ContentLen int } diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go index c23961dc6..687e54532 100644 --- a/pkg/agent/hooks.go +++ b/pkg/agent/hooks.go @@ -90,12 +90,11 @@ type ToolApprover interface { type LLMHookRequest struct { Meta EventMeta `json:"meta"` + Context *TurnContext `json:"context,omitempty"` Model string `json:"model"` Messages []providers.Message `json:"messages,omitempty"` Tools []providers.ToolDefinition `json:"tools,omitempty"` Options map[string]any `json:"options,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` GracefulTerminal bool `json:"graceful_terminal,omitempty"` } @@ -104,6 +103,8 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest { return nil } cloned := *r + cloned.Meta = cloneEventMeta(r.Meta) + cloned.Context = cloneTurnContext(r.Context) cloned.Messages = cloneProviderMessages(r.Messages) cloned.Tools = cloneToolDefinitions(r.Tools) cloned.Options = cloneStringAnyMap(r.Options) @@ -112,10 +113,9 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest { type LLMHookResponse struct { Meta EventMeta `json:"meta"` + Context *TurnContext `json:"context,omitempty"` Model string `json:"model"` Response *providers.LLMResponse `json:"response,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *LLMHookResponse) Clone() *LLMHookResponse { @@ -123,12 +123,15 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse { return nil } cloned := *r + cloned.Meta = cloneEventMeta(r.Meta) + cloned.Context = cloneTurnContext(r.Context) cloned.Response = cloneLLMResponse(r.Response) return &cloned } type ToolCallHookRequest struct { Meta EventMeta `json:"meta"` + Context *TurnContext `json:"context,omitempty"` Tool string `json:"tool"` Arguments map[string]any `json:"arguments,omitempty"` Channel string `json:"channel,omitempty"` @@ -141,6 +144,8 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest { return nil } cloned := *r + cloned.Meta = cloneEventMeta(r.Meta) + cloned.Context = cloneTurnContext(r.Context) cloned.Arguments = cloneStringAnyMap(r.Arguments) cloned.HookResult = cloneToolResult(r.HookResult) return &cloned @@ -148,10 +153,9 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest { type ToolApprovalRequest struct { Meta EventMeta `json:"meta"` + Context *TurnContext `json:"context,omitempty"` Tool string `json:"tool"` Arguments map[string]any `json:"arguments,omitempty"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest { @@ -159,18 +163,19 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest { return nil } cloned := *r + cloned.Meta = cloneEventMeta(r.Meta) + cloned.Context = cloneTurnContext(r.Context) cloned.Arguments = cloneStringAnyMap(r.Arguments) return &cloned } type ToolResultHookResponse struct { Meta EventMeta `json:"meta"` + Context *TurnContext `json:"context,omitempty"` Tool string `json:"tool"` Arguments map[string]any `json:"arguments,omitempty"` Result *tools.ToolResult `json:"result,omitempty"` Duration time.Duration `json:"duration"` - Channel string `json:"channel,omitempty"` - ChatID string `json:"chat_id,omitempty"` } func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse { @@ -178,6 +183,8 @@ func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse { return nil } cloned := *r + cloned.Meta = cloneEventMeta(r.Meta) + cloned.Context = cloneTurnContext(r.Context) cloned.Arguments = cloneStringAnyMap(r.Arguments) cloned.Result = cloneToolResult(r.Result) return &cloned diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go index 9049a5c72..6979fbf1e 100644 --- a/pkg/agent/hooks_test.go +++ b/pkg/agent/hooks_test.go @@ -12,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -108,7 +109,8 @@ func (p *llmHookTestProvider) GetDefaultModel() string { } type llmObserverHook struct { - eventCh chan Event + eventCh chan Event + lastInbound *bus.InboundContext } func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error { @@ -125,6 +127,9 @@ func (h *llmObserverHook) BeforeLLM( ctx context.Context, req *LLMHookRequest, ) (*LLMHookRequest, HookDecision, error) { + if req.Context != nil { + h.lastInbound = cloneInboundContext(req.Context.Inbound) + } next := req.Clone() next.Model = "hook-model" return next, HookDecision{Action: HookActionModify}, nil @@ -157,6 +162,31 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) { DefaultResponse: defaultResponse, EnableSummary: false, SendResponse: false, + InboundContext: &bus.InboundContext{ + Channel: "cli", + ChatID: "direct", + ChatType: "direct", + SenderID: "hook-user", + }, + RouteResult: &routing.ResolvedRoute{ + AgentID: "main", + Channel: "cli", + AccountID: routing.DefaultAccountID, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + MatchedBy: "default", + }, + SessionScope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "cli", + Account: routing.DefaultAccountID, + Dimensions: []string{"sender"}, + Values: map[string]string{ + "sender": "hook-user", + }, + }, }) if err != nil { t.Fatalf("runAgentLoop failed: %v", err) @@ -171,12 +201,30 @@ func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) { if lastModel != "hook-model" { t.Fatalf("expected model hook-model, got %q", lastModel) } + if hook.lastInbound == nil { + t.Fatal("expected hook to receive inbound context") + } + if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" { + t.Fatalf("hook inbound context = %+v", hook.lastInbound) + } + if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" { + t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID) + } select { case evt := <-hook.eventCh: if evt.Kind != EventKindTurnEnd { t.Fatalf("expected turn end event, got %v", evt.Kind) } + if evt.Context == nil || evt.Context.Inbound == nil { + t.Fatal("expected observer event to carry inbound context") + } + if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" { + t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route) + } + if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" { + t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope) + } case <-time.After(2 * time.Second): t.Fatal("timed out waiting for hook observer event") } @@ -725,7 +773,7 @@ func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) { sub := al.SubscribeEvents(32) defer al.UnsubscribeEvents(sub.ID) - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) type result struct { resp string @@ -801,7 +849,7 @@ func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) { sub := al.SubscribeEvents(32) defer al.UnsubscribeEvents(sub.ID) - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) type result struct { resp string diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index a1e88f5be..01e457b5a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -29,6 +29,7 @@ import ( "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" @@ -73,25 +74,30 @@ type AgentLoop struct { // processOptions configures how a message is processed type processOptions struct { - SessionKey string // Session identifier for history/context - Channel string // Target channel for tool execution - ChatID string // Target chat ID for tool execution - MessageID string // Current inbound platform message ID - ReplyToMessageID string // Current inbound reply target message ID - SenderID string // Current sender ID for dynamic context - SenderDisplayName string // Current sender display name for dynamic context - UserMessage string // User message content (may include prefix) - ForcedSkills []string // Skills explicitly requested for this message - SystemPromptOverride string // Override the default system prompt (Used by SubTurns) - Media []string // media:// refs from inbound message - InitialSteeringMessages []providers.Message // Steering messages from refactor/agent - DefaultResponse string // Response when LLM returns empty - EnableSummary bool // Whether to trigger summarization - SendResponse bool // Whether to send response via bus - AllowInterimPicoPublish bool // Whether pico tool-call interim text can be published when SendResponse is false - SuppressToolFeedback bool // Whether to suppress inline tool feedback messages - NoHistory bool // If true, don't load session history (for heartbeat) - SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) + Dispatch DispatchRequest // Normalized routed request boundary for this turn + SessionKey string // Session identifier for history/context + SessionAliases []string // Compatibility aliases for the session key + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + MessageID string // Current inbound platform message ID + ReplyToMessageID string // Current inbound reply target message ID + SenderID string // Current sender ID for dynamic context + SenderDisplayName string // Current sender display name for dynamic context + UserMessage string // User message content (may include prefix) + ForcedSkills []string // Skills explicitly requested for this message + SystemPromptOverride string // Override the default system prompt (Used by SubTurns) + Media []string // media:// refs from inbound message + InitialSteeringMessages []providers.Message // Steering messages from refactor/agent + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus + AllowInterimPicoPublish bool // Whether pico tool-call interim text can be published when SendResponse is false + SuppressToolFeedback bool // Whether to suppress inline tool feedback messages + NoHistory bool // If true, don't load session history (for heartbeat) + SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) + InboundContext *bus.InboundContext // Normalized inbound facts for events/hooks + RouteResult *routing.ResolvedRoute // Route decision snapshot for events/hooks + SessionScope *session.SessionScope // Session scope snapshot for events/hooks } type continuationTarget struct { @@ -245,12 +251,23 @@ func registerSharedTools( // Message tool if cfg.Tools.IsToolEnabled("message") { messageTool := tools.NewMessageTool() - messageTool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + messageTool.SetSendCallback(func( + ctx context.Context, + channel, chatID, content, replyToMessageID string, + ) error { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() + outboundCtx := bus.NewOutboundContext(channel, chatID, replyToMessageID) + outboundAgentID, outboundSessionKey, outboundScope := outboundTurnMetadata( + tools.ToolAgentID(ctx), + tools.ToolSessionKey(ctx), + tools.ToolSessionScope(ctx), + ) return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: outboundCtx, + AgentID: outboundAgentID, + SessionKey: outboundSessionKey, + Scope: outboundScope, Content: content, ReplyToMessageID: replyToMessageID, }) @@ -600,6 +617,19 @@ func (al *AgentLoop) Run(ctx context.Context) error { // immediately available messages, blocking for the first one until ctx is done. func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) { blocking := true + var requeue []bus.InboundMessage + defer func() { + for _, msg := range requeue { + if err := al.requeueInboundMessage(msg); err != nil { + logger.WarnCF("agent", "Failed to flush requeued inbound message", map[string]any{ + "error": err.Error(), + "channel": msg.Channel, + "sender_id": msg.SenderID, + }) + } + } + }() + for { var msg bus.InboundMessage @@ -630,13 +660,7 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, active msgScope, _, scopeOK := al.resolveSteeringTarget(msg) if !scopeOK || msgScope != activeScope { - if err := al.requeueInboundMessage(msg); err != nil { - logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{ - "error": err.Error(), - "channel": msg.Channel, - "sender_id": msg.SenderID, - }) - } + requeue = append(requeue, msg) continue } @@ -694,8 +718,7 @@ func (al *AgentLoop) PublishResponseIfNeeded(ctx context.Context, channel, chatI } al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, ""), Content: response, }) logger.InfoCF("agent", "Published outbound response", @@ -715,9 +738,10 @@ func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuat if err != nil { return nil, err } + allocation := al.allocateRouteSession(route, msg) return &continuationTarget{ - SessionKey: resolveScopeKey(route, msg.SessionKey), + SessionKey: resolveScopeKey(allocation.SessionKey, msg.SessionKey), Channel: msg.Channel, ChatID: msg.ChatID, }, nil @@ -745,6 +769,74 @@ func (al *AgentLoop) Close() { } } +func outboundContextFromInbound( + inbound *bus.InboundContext, + channel, chatID, replyToMessageID string, +) bus.InboundContext { + if inbound == nil { + return bus.NewOutboundContext(channel, chatID, replyToMessageID) + } + + outboundCtx := *cloneInboundContext(inbound) + if outboundCtx.Channel == "" { + outboundCtx.Channel = channel + } + if outboundCtx.ChatID == "" { + outboundCtx.ChatID = chatID + } + if outboundCtx.ReplyToMessageID == "" { + outboundCtx.ReplyToMessageID = replyToMessageID + } + return outboundCtx +} + +func outboundScopeFromSessionScope(scope *session.SessionScope) *bus.OutboundScope { + if scope == nil { + return nil + } + outboundScope := &bus.OutboundScope{ + Version: scope.Version, + AgentID: scope.AgentID, + Channel: scope.Channel, + Account: scope.Account, + } + if len(scope.Dimensions) > 0 { + outboundScope.Dimensions = append([]string(nil), scope.Dimensions...) + } + if len(scope.Values) > 0 { + outboundScope.Values = make(map[string]string, len(scope.Values)) + for key, value := range scope.Values { + outboundScope.Values[key] = value + } + } + return outboundScope +} + +func outboundTurnMetadata( + agentID, sessionKey string, + scope *session.SessionScope, +) (string, string, *bus.OutboundScope) { + return agentID, sessionKey, outboundScopeFromSessionScope(scope) +} + +func outboundMessageForTurn(ts *turnState, content string) bus.OutboundMessage { + agentID, sessionKey, scope := outboundTurnMetadata(ts.agent.ID, ts.sessionKey, ts.opts.Dispatch.SessionScope) + return bus.OutboundMessage{ + Channel: ts.channel, + ChatID: ts.chatID, + Context: outboundContextFromInbound( + ts.opts.Dispatch.InboundContext, + ts.channel, + ts.chatID, + ts.opts.Dispatch.ReplyToMessageID(), + ), + AgentID: agentID, + SessionKey: sessionKey, + Scope: scope, + Content: content, + } +} + // MountHook registers an in-process hook on the agent loop. func (al *AgentLoop) MountHook(reg HookRegistration) error { if al == nil || al.hooks == nil { @@ -791,32 +883,37 @@ type turnEventScope struct { agentID string sessionKey string turnID string + context *TurnContext } -func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string) turnEventScope { +func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string, turnCtx *TurnContext) turnEventScope { seq := al.turnSeq.Add(1) return turnEventScope{ agentID: agentID, sessionKey: sessionKey, turnID: fmt.Sprintf("%s-turn-%d", agentID, seq), + context: cloneTurnContext(turnCtx), } } func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta { return EventMeta{ - AgentID: ts.agentID, - TurnID: ts.turnID, - SessionKey: ts.sessionKey, - Iteration: iteration, - Source: source, - TracePath: tracePath, + AgentID: ts.agentID, + TurnID: ts.turnID, + SessionKey: ts.sessionKey, + Iteration: iteration, + Source: source, + TracePath: tracePath, + turnContext: cloneTurnContext(ts.context), } } func (al *AgentLoop) emitEvent(kind EventKind, meta EventMeta, payload any) { + clonedMeta := cloneEventMeta(meta) evt := Event{ Kind: kind, - Meta: meta, + Meta: clonedMeta, + Context: cloneTurnContext(clonedMeta.turnContext), Payload: payload, } @@ -882,10 +979,10 @@ func (al *AgentLoop) logEvent(evt Event) { fields["source"] = evt.Meta.Source } + appendEventContextFields(fields, evt.Context) + switch payload := evt.Payload.(type) { case TurnStartPayload: - fields["channel"] = payload.Channel - fields["chat_id"] = payload.ChatID fields["user_len"] = len(payload.UserMessage) fields["media_count"] = payload.MediaCount case TurnEndPayload: @@ -938,8 +1035,6 @@ func (al *AgentLoop) logEvent(evt Event) { fields["total_content_len"] = payload.TotalContentLen case FollowUpQueuedPayload: fields["source_tool"] = payload.SourceTool - fields["channel"] = payload.Channel - fields["chat_id"] = payload.ChatID fields["content_len"] = payload.ContentLen case InterruptReceivedPayload: fields["interrupt_kind"] = payload.Kind @@ -965,6 +1060,87 @@ func (al *AgentLoop) logEvent(evt Event) { logger.InfoCF("eventbus", fmt.Sprintf("Agent event: %s", evt.Kind.String()), fields) } +func appendEventContextFields(fields map[string]any, turnCtx *TurnContext) { + if turnCtx == nil { + return + } + + if inbound := turnCtx.Inbound; inbound != nil { + if inbound.Channel != "" { + fields["inbound_channel"] = inbound.Channel + } + if inbound.Account != "" { + fields["inbound_account"] = inbound.Account + } + if inbound.ChatID != "" { + fields["inbound_chat_id"] = inbound.ChatID + } + if inbound.ChatType != "" { + fields["inbound_chat_type"] = inbound.ChatType + } + if inbound.TopicID != "" { + fields["inbound_topic_id"] = inbound.TopicID + } + if inbound.SpaceType != "" { + fields["inbound_space_type"] = inbound.SpaceType + } + if inbound.SpaceID != "" { + fields["inbound_space_id"] = inbound.SpaceID + } + if inbound.SenderID != "" { + fields["inbound_sender_id"] = inbound.SenderID + } + if inbound.Mentioned { + fields["inbound_mentioned"] = true + } + } + + if route := turnCtx.Route; route != nil { + if route.AgentID != "" { + fields["route_agent_id"] = route.AgentID + } + if route.Channel != "" { + fields["route_channel"] = route.Channel + } + if route.AccountID != "" { + fields["route_account_id"] = route.AccountID + } + if route.MatchedBy != "" { + fields["route_matched_by"] = route.MatchedBy + } + if len(route.SessionPolicy.Dimensions) > 0 { + fields["route_dimensions"] = strings.Join(route.SessionPolicy.Dimensions, ",") + } + if count := len(route.SessionPolicy.IdentityLinks); count > 0 { + fields["route_identity_link_count"] = count + } + } + + if scope := turnCtx.Scope; scope != nil { + if scope.Version > 0 { + fields["scope_version"] = scope.Version + } + if scope.AgentID != "" { + fields["scope_agent_id"] = scope.AgentID + } + if scope.Channel != "" { + fields["scope_channel"] = scope.Channel + } + if scope.Account != "" { + fields["scope_account"] = scope.Account + } + if len(scope.Dimensions) > 0 { + fields["scope_dimensions"] = strings.Join(scope.Dimensions, ",") + } + for dim, value := range scope.Values { + if dim == "" || value == "" { + continue + } + fields["scope_"+dim] = value + } + } +} + func (al *AgentLoop) RegisterTool(tool tools.Tool) { registry := al.GetRegistry() for _, agentID := range registry.ListAgentIDs() { @@ -1239,8 +1415,7 @@ func (al *AgentLoop) sendTranscriptionFeedback( } err := al.channelManager.SendMessage(ctx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, messageID), Content: feedbackMsg, ReplyToMessageID: messageID, }) @@ -1316,9 +1491,12 @@ func (al *AgentLoop) ProcessDirectWithChannel( } msg := bus.InboundMessage{ - Channel: channel, - SenderID: "cron", - ChatID: chatID, + Context: bus.InboundContext{ + Channel: channel, + ChatID: chatID, + ChatType: "direct", + SenderID: "cron", + }, Content: content, SessionKey: sessionKey, } @@ -1343,11 +1521,20 @@ func (al *AgentLoop) ProcessHeartbeat( if agent == nil { return "", fmt.Errorf("no default agent for heartbeat") } + dispatch := DispatchRequest{ + SessionKey: "heartbeat", + UserMessage: content, + } + if channel != "" || chatID != "" { + dispatch.InboundContext = &bus.InboundContext{ + Channel: channel, + ChatID: chatID, + ChatType: "direct", + SenderID: "heartbeat", + } + } return al.runAgentLoop(ctx, agent, processOptions{ - SessionKey: "heartbeat", - Channel: channel, - ChatID: chatID, - UserMessage: content, + Dispatch: dispatch, DefaultResponse: defaultResponse, EnableSummary: false, SendResponse: false, @@ -1357,6 +1544,8 @@ func (al *AgentLoop) ProcessHeartbeat( } func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { + msg = bus.NormalizeInboundMessage(msg) + // Add message preview to log (show full content for error messages) var logContent string if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") { @@ -1401,30 +1590,36 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } } - // Resolve session key from route, while preserving explicit agent-scoped keys. - scopeKey := resolveScopeKey(route, msg.SessionKey) + allocation := al.allocateRouteSession(route, msg) + + // Resolve session key from the route allocation, while preserving explicit + // agent-scoped keys supplied by the caller. + scopeKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey) sessionKey := scopeKey logger.InfoCF("agent", "Routed message", map[string]any{ - "agent_id": agent.ID, - "scope_key": scopeKey, - "session_key": sessionKey, - "matched_by": route.MatchedBy, - "route_agent": route.AgentID, - "route_channel": route.Channel, + "agent_id": agent.ID, + "scope_key": scopeKey, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + "route_agent": route.AgentID, + "route_channel": route.Channel, + "route_main_session": allocation.MainSessionKey, }) opts := processOptions{ - SessionKey: sessionKey, - Channel: msg.Channel, - ChatID: msg.ChatID, - MessageID: msg.MessageID, - ReplyToMessageID: inboundMetadata(msg, metadataKeyReplyToMessage), + Dispatch: DispatchRequest{ + SessionKey: sessionKey, + SessionAliases: buildSessionAliases(sessionKey, append(allocation.SessionAliases, msg.SessionKey)...), + InboundContext: cloneInboundContext(&msg.Context), + RouteResult: cloneResolvedRoute(&route), + SessionScope: session.CloneScope(&allocation.Scope), + UserMessage: msg.Content, + Media: append([]string(nil), msg.Media...), + }, SenderID: msg.SenderID, SenderDisplayName: msg.Sender.DisplayName, - UserMessage: msg.Content, - Media: msg.Media, DefaultResponse: defaultResponse, EnableSummary: true, SendResponse: false, @@ -1437,11 +1632,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return response, nil } - if pending := al.takePendingSkills(opts.SessionKey); len(pending) > 0 { + if pending := al.takePendingSkills(opts.Dispatch.SessionKey); len(pending) > 0 { opts.ForcedSkills = append(opts.ForcedSkills, pending...) logger.InfoCF("agent", "Applying pending skill override", map[string]any{ - "session_key": opts.SessionKey, + "session_key": opts.Dispatch.SessionKey, "skills": strings.Join(pending, ","), }) } @@ -1451,14 +1646,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.ResolvedRoute, *AgentInstance, error) { registry := al.GetRegistry() - route := registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - AccountID: inboundMetadata(msg, metadataKeyAccountID), - Peer: extractPeer(msg), - ParentPeer: extractParentPeer(msg), - GuildID: inboundMetadata(msg, metadataKeyGuildID), - TeamID: inboundMetadata(msg, metadataKeyTeamID), - }) + inboundCtx := normalizedInboundContext(msg) + route := registry.ResolveRoute(inboundCtx) agent, ok := registry.GetAgent(route.AgentID) if !ok { @@ -1471,11 +1660,64 @@ func (al *AgentLoop) resolveMessageRoute(msg bus.InboundMessage) (routing.Resolv return route, agent, nil } -func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { - if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { +func normalizedInboundContext(msg bus.InboundMessage) bus.InboundContext { + return bus.NormalizeInboundMessage(msg).Context +} + +func resolveScopeKey(routeSessionKey, msgSessionKey string) string { + if isExplicitSessionKey(msgSessionKey) { return msgSessionKey } - return route.SessionKey + return routeSessionKey +} + +func isExplicitSessionKey(sessionKey string) bool { + return session.IsExplicitSessionKey(sessionKey) +} + +func buildSessionAliases(canonicalKey string, keys ...string) []string { + if len(keys) == 0 { + return nil + } + aliases := make([]string, 0, len(keys)) + seen := make(map[string]struct{}, len(keys)) + canonicalKey = strings.TrimSpace(canonicalKey) + for _, key := range keys { + key = strings.TrimSpace(key) + if key == "" || key == canonicalKey { + continue + } + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + aliases = append(aliases, key) + } + if len(aliases) == 0 { + return nil + } + return aliases +} + +func ensureSessionMetadata(store session.SessionStore, key string, scope *session.SessionScope, aliases []string) { + if key == "" || scope == nil { + return + } + metaStore, ok := store.(interface { + EnsureSessionMetadata(sessionKey string, scope *session.SessionScope, aliases []string) + }) + if !ok { + return + } + metaStore.EnsureSessionMetadata(key, scope, aliases) +} + +func (al *AgentLoop) allocateRouteSession(route routing.ResolvedRoute, msg bus.InboundMessage) session.Allocation { + return session.AllocateRouteSession(session.AllocationInput{ + AgentID: route.AgentID, + Context: normalizedInboundContext(msg), + SessionPolicy: route.SessionPolicy, + }) } func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) { @@ -1487,8 +1729,9 @@ func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, stri if err != nil || agent == nil { return "", "", false } + allocation := al.allocateRouteSession(route, msg) - return resolveScopeKey(route, msg.SessionKey), agent.ID, true + return resolveScopeKey(allocation.SessionKey, msg.SessionKey), agent.ID, true } func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error { @@ -1497,11 +1740,7 @@ func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error { } pubCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - return al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: msg.Content, - }) + return al.bus.PublishInbound(pubCtx, msg) } func (al *AgentLoop) processSystemMessage( @@ -1556,13 +1795,22 @@ func (al *AgentLoop) processSystemMessage( } // Use the origin session for context - sessionKey := routing.BuildAgentMainSessionKey(agent.ID) + sessionKey := session.BuildMainSessionKey(agent.ID) + dispatch := DispatchRequest{ + SessionKey: sessionKey, + UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), + } + if originChannel != "" || originChatID != "" { + dispatch.InboundContext = &bus.InboundContext{ + Channel: originChannel, + ChatID: originChatID, + ChatType: "direct", + SenderID: msg.SenderID, + } + } return al.runAgentLoop(ctx, agent, processOptions{ - SessionKey: sessionKey, - Channel: originChannel, - ChatID: originChatID, - UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), + Dispatch: dispatch, DefaultResponse: "Background task completed.", EnableSummary: false, SendResponse: true, @@ -1576,9 +1824,13 @@ func (al *AgentLoop) runAgentLoop( agent *AgentInstance, opts processOptions, ) (string, error) { + opts = normalizeProcessOptions(opts) + // Record last channel for heartbeat notifications (skip internal channels and cli) - if opts.Channel != "" && opts.ChatID != "" && !constants.IsInternalChannel(opts.Channel) { - channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) + if opts.Dispatch.Channel() != "" && + opts.Dispatch.ChatID() != "" && + !constants.IsInternalChannel(opts.Dispatch.Channel()) { + channelKey := fmt.Sprintf("%s:%s", opts.Dispatch.Channel(), opts.Dispatch.ChatID()) if err := al.RecordLastChannel(channelKey); err != nil { logger.WarnCF( "agent", @@ -1588,7 +1840,19 @@ func (al *AgentLoop) runAgentLoop( } } - ts := newTurnState(agent, opts, al.newTurnEventScope(agent.ID, opts.SessionKey)) + ensureSessionMetadata( + agent.Sessions, + opts.Dispatch.SessionKey, + opts.Dispatch.SessionScope, + opts.Dispatch.SessionAliases, + ) + + turnScope := al.newTurnEventScope( + agent.ID, + opts.Dispatch.SessionKey, + newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope), + ) + ts := newTurnState(agent, opts, turnScope) result, err := al.runTurn(ctx, ts) if err != nil { return "", err @@ -1608,10 +1872,22 @@ func (al *AgentLoop) runAgentLoop( } if opts.SendResponse && result.finalContent != "" { + agentID, sessionKey, scope := outboundTurnMetadata( + agent.ID, + opts.Dispatch.SessionKey, + opts.Dispatch.SessionScope, + ) al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: opts.Channel, - ChatID: opts.ChatID, - Content: result.finalContent, + Context: outboundContextFromInbound( + opts.Dispatch.InboundContext, + opts.Dispatch.Channel(), + opts.Dispatch.ChatID(), + opts.Dispatch.ReplyToMessageID(), + ), + AgentID: agentID, + SessionKey: sessionKey, + Scope: scope, + Content: result.finalContent, }) } @@ -1620,7 +1896,7 @@ func (al *AgentLoop) runAgentLoop( logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), map[string]any{ "agent_id": agent.ID, - "session_key": opts.SessionKey, + "session_key": opts.Dispatch.SessionKey, "iterations": ts.currentIteration(), "final_length": len(result.finalContent), }) @@ -1652,12 +1928,14 @@ func (al *AgentLoop) publishPicoReasoning(ctx context.Context, reasoningContent, defer pubCancel() if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: "pico", - ChatID: chatID, - Content: reasoningContent, - Metadata: map[string]string{ - metadataKeyMessageKind: messageKindThought, + Context: bus.InboundContext{ + Channel: "pico", + ChatID: chatID, + Raw: map[string]string{ + metadataKeyMessageKind: messageKindThought, + }, }, + Content: reasoningContent, }); err != nil { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) || errors.Is(err, bus.ErrBusClosed) { @@ -1695,8 +1973,7 @@ func (al *AgentLoop) handleReasoning( defer pubCancel() if err := al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channelName, - ChatID: channelID, + Context: bus.NewOutboundContext(channelName, channelID, ""), Content: reasoningContent, }); err != nil { // Treat context.DeadlineExceeded / context.Canceled as expected @@ -1750,8 +2027,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er EventKindTurnStart, ts.eventMeta("runTurn", "turn.start"), TurnStartPayload{ - Channel: ts.channel, - ChatID: ts.chatID, UserMessage: ts.userMessage, MediaCount: len(ts.media), }, @@ -1779,7 +2054,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er ts.media, ts.channel, ts.chatID, - ts.opts.SenderID, + ts.opts.Dispatch.SenderID(), ts.opts.SenderDisplayName, activeSkillNames(ts.agent, ts.opts)..., ) @@ -1816,7 +2091,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er messages = ts.agent.ContextBuilder.BuildMessages( history, summary, ts.userMessage, ts.media, ts.channel, ts.chatID, - ts.opts.SenderID, ts.opts.SenderDisplayName, + ts.opts.Dispatch.SenderID(), ts.opts.SenderDisplayName, activeSkillNames(ts.agent, ts.opts)..., ) messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) @@ -2002,12 +2277,11 @@ turnLoop: if al.hooks != nil { llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{ Meta: ts.eventMeta("runTurn", "turn.llm.request"), + Context: cloneTurnContext(ts.turnCtx), Model: llmModel, Messages: callMessages, Tools: providerToolDefs, Options: llmOpts, - Channel: ts.channel, - ChatID: ts.chatID, GracefulTerminal: gracefulTerminal, }) switch decision.normalizedAction() { @@ -2178,11 +2452,10 @@ turnLoop: ) if retry == 0 && !constants.IsInternalChannel(ts.channel) { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: "Context window exceeded. Compressing history and retrying...", - }) + al.bus.PublishOutbound(ctx, outboundMessageForTurn( + ts, + "Context window exceeded. Compressing history and retrying...", + )) } if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{ @@ -2207,7 +2480,7 @@ turnLoop: } messages = ts.agent.ContextBuilder.BuildMessages( history, summary, "", - nil, ts.channel, ts.chatID, ts.opts.SenderID, ts.opts.SenderDisplayName, + nil, ts.channel, ts.chatID, ts.opts.Dispatch.SenderID(), ts.opts.SenderDisplayName, activeSkillNames(ts.agent, ts.opts)..., ) callMessages = messages @@ -2242,10 +2515,9 @@ turnLoop: if al.hooks != nil { llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{ Meta: ts.eventMeta("runTurn", "turn.llm.response"), + Context: cloneTurnContext(ts.turnCtx), Model: llmModel, Response: response, - Channel: ts.channel, - ChatID: ts.chatID, }) switch decision.normalizedAction() { case HookActionContinue, HookActionModify: @@ -2419,10 +2691,9 @@ turnLoop: if al.hooks != nil { toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{ Meta: ts.eventMeta("runTurn", "turn.tool.before"), + Context: cloneTurnContext(ts.turnCtx), Tool: toolName, Arguments: toolArgs, - Channel: ts.channel, - ChatID: ts.chatID, }) switch decision.normalizedAction() { case HookActionContinue, HookActionModify: @@ -2487,12 +2758,14 @@ turnLoop: (ts.opts.SendResponse || hookResult.ResponseHandled) if shouldSendForUser { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: hookResult.ForUser, - Metadata: map[string]string{ - "is_tool_call": "true", + Context: bus.InboundContext{ + Channel: ts.channel, + ChatID: ts.chatID, + Raw: map[string]string{ + "is_tool_call": "true", + }, }, + Content: hookResult.ForUser, }) } @@ -2695,10 +2968,9 @@ turnLoop: if al.hooks != nil { approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{ Meta: ts.eventMeta("runTurn", "turn.tool.approve"), + Context: cloneTurnContext(ts.turnCtx), Tool: toolName, Arguments: toolArgs, - Channel: ts.channel, - ChatID: ts.chatID, }) if !approval.Approved { allResponsesHandled = false @@ -2752,11 +3024,7 @@ turnLoop: ) feedbackMsg := utils.FormatToolFeedbackMessage(tc.Name, feedbackPreview) fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second) - _ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: feedbackMsg, - }) + _ = al.bus.PublishOutbound(fbCtx, outboundMessageForTurn(ts, feedbackMsg)) fbCancel() } @@ -2769,11 +3037,7 @@ turnLoop: if !result.Silent && result.ForUser != "" { outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) defer outCancel() - _ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: result.ForUser, - }) + _ = al.bus.PublishOutbound(outCtx, outboundMessageForTurn(ts, result.ForUser)) } // Determine content for the agent loop (ForLLM or error). @@ -2796,8 +3060,6 @@ turnLoop: ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"), FollowUpQueuedPayload{ SourceTool: asyncToolName, - Channel: ts.channel, - ChatID: ts.chatID, ContentLen: len(content), }, ) @@ -2805,10 +3067,13 @@ turnLoop: pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ - Channel: "system", - SenderID: fmt.Sprintf("async:%s", asyncToolName), - ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), - Content: content, + Context: bus.InboundContext{ + Channel: "system", + ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), + ChatType: "direct", + SenderID: fmt.Sprintf("async:%s", asyncToolName), + }, + Content: content, }) } @@ -2817,8 +3082,14 @@ turnLoop: turnCtx, ts.channel, ts.chatID, - ts.opts.MessageID, - ts.opts.ReplyToMessageID, + ts.opts.Dispatch.MessageID(), + ts.opts.Dispatch.ReplyToMessageID(), + ) + execCtx = tools.WithToolSessionContext( + execCtx, + ts.agent.ID, + ts.sessionKey, + ts.opts.Dispatch.SessionScope, ) toolResult := ts.agent.Tools.ExecuteWithContext( execCtx, @@ -2838,12 +3109,11 @@ turnLoop: if al.hooks != nil { toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{ Meta: ts.eventMeta("runTurn", "turn.tool.after"), + Context: cloneTurnContext(ts.turnCtx), Tool: toolName, Arguments: toolArgs, Result: toolResult, Duration: toolDuration, - Channel: ts.channel, - ChatID: ts.chatID, }) switch decision.normalizedAction() { case HookActionContinue, HookActionModify: @@ -2869,27 +3139,6 @@ turnLoop: toolResult = tools.ErrorResult("hook returned nil tool result") } - // Send ForUser if not silent and has content. - // For ResponseHandled tools, send regardless of SendResponse setting, - // since they've already handled the response (e.g., send_tts, send_file). - shouldSendForUser := !toolResult.Silent && toolResult.ForUser != "" && - (ts.opts.SendResponse || toolResult.ResponseHandled) - if shouldSendForUser { - al.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: ts.channel, - ChatID: ts.chatID, - Content: toolResult.ForUser, - Metadata: map[string]string{ - "is_tool_call": "true", - }, - }) - logger.DebugCF("agent", "Sent tool result to user", - map[string]any{ - "tool": toolName, - "content_len": len(toolResult.ForUser), - }) - } - if len(toolResult.Media) > 0 && toolResult.ResponseHandled { parts := make([]bus.MediaPart, 0, len(toolResult.Media)) for _, ref := range toolResult.Media { @@ -2906,7 +3155,16 @@ turnLoop: outboundMedia := bus.OutboundMediaMessage{ Channel: ts.channel, ChatID: ts.chatID, - Parts: parts, + Context: outboundContextFromInbound( + ts.opts.Dispatch.InboundContext, + ts.channel, + ts.chatID, + ts.opts.Dispatch.ReplyToMessageID(), + ), + AgentID: ts.agent.ID, + SessionKey: ts.sessionKey, + Scope: outboundScopeFromSessionScope(ts.opts.Dispatch.SessionScope), + Parts: parts, } if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) { if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil { @@ -2942,6 +3200,17 @@ turnLoop: allResponsesHandled = false } + shouldSendForUser := !toolResult.Silent && + toolResult.ForUser != "" && + (ts.opts.SendResponse || toolResult.ResponseHandled) + if shouldSendForUser { + al.bus.PublishOutbound(ctx, outboundMessageForTurn(ts, toolResult.ForUser)) + logger.DebugCF("agent", "Sent tool result to user", + map[string]any{ + "tool": toolName, + "content_len": len(toolResult.ForUser), + }) + } contentForLLM := toolResult.ContentForLLM() // Filter sensitive data (API keys, tokens, secrets) before sending to LLM @@ -3371,6 +3640,8 @@ func (al *AgentLoop) handleCommand( agent *AgentInstance, opts *processOptions, ) (string, bool) { + normalizeProcessOptionsInPlace(opts) + if !commands.HasCommandPrefix(msg.Content) { return "", false } @@ -3452,6 +3723,8 @@ func (al *AgentLoop) applyExplicitSkillCommand( agent *AgentInstance, opts *processOptions, ) (matched bool, handled bool, reply string) { + normalizeProcessOptionsInPlace(opts) + cmdName, ok := commands.CommandName(raw) if !ok || cmdName != "use" { return false, false, "" @@ -3469,7 +3742,7 @@ func (al *AgentLoop) applyExplicitSkillCommand( arg := strings.TrimSpace(parts[1]) if strings.EqualFold(arg, "clear") || strings.EqualFold(arg, "off") { if opts != nil { - al.clearPendingSkills(opts.SessionKey) + al.clearPendingSkills(opts.Dispatch.SessionKey) } return true, true, "Cleared pending skill override." } @@ -3480,10 +3753,10 @@ func (al *AgentLoop) applyExplicitSkillCommand( } if len(parts) < 3 { - if opts == nil || strings.TrimSpace(opts.SessionKey) == "" { + if opts == nil || strings.TrimSpace(opts.Dispatch.SessionKey) == "" { return true, true, commandsUnavailableSkillMessage() } - al.setPendingSkills(opts.SessionKey, []string{skillName}) + al.setPendingSkills(opts.Dispatch.SessionKey, []string{skillName}) return true, true, fmt.Sprintf( "Skill %q is armed for your next message. Send your next prompt normally, or use /use clear to cancel.", skillName, @@ -3497,6 +3770,7 @@ func (al *AgentLoop) applyExplicitSkillCommand( if opts != nil { opts.ForcedSkills = append(opts.ForcedSkills, skillName) + opts.Dispatch.UserMessage = message opts.UserMessage = message } @@ -3508,6 +3782,8 @@ func (al *AgentLoop) buildCommandsRuntime( agent *AgentInstance, opts *processOptions, ) *commands.Runtime { + normalizeProcessOptionsInPlace(opts) + registry := al.GetRegistry() cfg := al.GetConfig() rt := &commands.Runtime{ @@ -3669,39 +3945,6 @@ func mapCommandError(result commands.ExecuteResult) string { return fmt.Sprintf("Failed to execute /%s: %v", result.Command, result.Err) } -// extractPeer extracts the routing peer from the inbound message's structured Peer field. -func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - if msg.Peer.Kind == "" { - return nil - } - peerID := msg.Peer.ID - if peerID == "" { - if msg.Peer.Kind == "direct" { - peerID = msg.SenderID - } else { - peerID = msg.ChatID - } - } - return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} -} - -func inboundMetadata(msg bus.InboundMessage, key string) string { - if msg.Metadata == nil { - return "" - } - return msg.Metadata[key] -} - -// extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. -func extractParentPeer(msg bus.InboundMessage) *routing.RoutePeer { - parentKind := inboundMetadata(msg, metadataKeyParentPeerKind) - parentID := inboundMetadata(msg, metadataKeyParentPeerID) - if parentKind == "" || parentID == "" { - return nil - } - return &routing.RoutePeer{Kind: parentKind, ID: parentID} -} - // isNativeSearchProvider reports whether the given LLM provider implements // NativeSearchCapable and returns true for SupportsNativeSearch. func isNativeSearchProvider(p providers.LLMProvider) bool { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 7fe5836b3..9cca84b6b 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -20,6 +20,7 @@ import ( "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -38,7 +39,13 @@ func (f *fakeChannel) ReasoningChannelID() string { return f.id type fakeMediaChannel struct { fakeChannel - sentMedia []bus.OutboundMediaMessage + sentMessages []bus.OutboundMessage + sentMedia []bus.OutboundMediaMessage +} + +func (f *fakeMediaChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) { + f.sentMessages = append(f.sentMessages, msg) + return nil, nil } func (f *fakeMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) { @@ -139,7 +146,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) { provider := &recordingProvider{} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "discord", SenderID: "discord:123", Sender: bus.SenderInfo{ @@ -147,7 +154,7 @@ func TestProcessMessage_IncludesCurrentSenderInDynamicContext(t *testing.T) { }, ChatID: "group-1", Content: "hello", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -198,12 +205,12 @@ func TestProcessMessage_UseCommandLoadsRequestedSkill(t *testing.T) { provider := &recordingProvider{} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "telegram:123", ChatID: "chat-1", Content: "/use shell explain how to list files", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -288,12 +295,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) { provider := &recordingProvider{} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "telegram:123", ChatID: "chat-1", Content: "/use shell", - }) + })) if err != nil { t.Fatalf("processMessage() arm error = %v", err) } @@ -301,12 +308,12 @@ func TestProcessMessage_UseCommandArmsSkillForNextMessage(t *testing.T) { t.Fatalf("arm response = %q, want armed confirmation", response) } - response, err = al.processMessage(context.Background(), bus.InboundMessage{ + response, err = al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "telegram:123", ChatID: "chat-1", Content: "explain how to list files", - }) + })) if err != nil { t.Fatalf("processMessage() follow-up error = %v", err) } @@ -619,12 +626,12 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. path: imagePath, }) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -661,16 +668,21 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. if defaultAgent == nil { t.Fatal("expected default agent") } - route, _, err := al.resolveMessageRoute(bus.InboundMessage{ + route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("resolveMessageRoute() error = %v", err) } - sessionKey := resolveScopeKey(route, "") + sessionKey := resolveScopeKey(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + })).SessionKey, "") history := defaultAgent.Sessions.GetHistory(sessionKey) if len(history) == 0 { t.Fatal("expected session history to be saved") @@ -714,12 +726,12 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes loop: al, }) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -734,6 +746,263 @@ func TestProcessMessage_HandledToolProcessesQueuedSteeringBeforeReturning(t *tes } } +func TestRunAgentLoop_ResponseHandledToolPublishesForUserWhenSendResponseDisabled(t *testing.T) { + tmpDir := t.TempDir() + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = tmpDir + cfg.Agents.Defaults.ModelName = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + + msgBus := bus.NewMessageBus() + provider := &handledUserProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + store := media.NewFileMediaStore() + al.SetMediaStore(store) + telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}} + al.SetChannelManager(newStartedTestChannelManager(t, msgBus, store, "telegram", telegramChannel)) + al.RegisterTool(&handledUserTool{}) + + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + + response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{ + Dispatch: DispatchRequest{ + SessionKey: "session-1", + UserMessage: "take a screenshot of the screen and send it to me", + SessionScope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: defaultAgent.ID, + Channel: "telegram", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "direct:chat1", + }, + }, + InboundContext: &bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + }, + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop() error = %v", err) + } + if response != "" { + t.Fatalf("expected no final response when tool already handled delivery, got %q", response) + } + + deadline := time.Now().Add(2 * time.Second) + for len(telegramChannel.sentMessages) == 0 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } + if len(telegramChannel.sentMessages) != 1 { + t.Fatalf("expected exactly 1 sent text message, got %d", len(telegramChannel.sentMessages)) + } + if telegramChannel.sentMessages[0].Content != "Handled user output from tool." { + t.Fatalf("unexpected sent text message: %+v", telegramChannel.sentMessages[0]) + } + if telegramChannel.sentMessages[0].AgentID != defaultAgent.ID { + t.Fatalf("sent text agent_id = %q, want %q", telegramChannel.sentMessages[0].AgentID, defaultAgent.ID) + } + if telegramChannel.sentMessages[0].SessionKey != "session-1" { + t.Fatalf("sent text session_key = %q, want session-1", telegramChannel.sentMessages[0].SessionKey) + } + if telegramChannel.sentMessages[0].Scope == nil || + telegramChannel.sentMessages[0].Scope.Values["chat"] != "direct:chat1" { + t.Fatalf("unexpected sent text scope: %+v", telegramChannel.sentMessages[0].Scope) + } +} + +func TestAppendEventContextFields_IncludesInboundRouteAndScope(t *testing.T) { + fields := map[string]any{} + + appendEventContextFields(fields, &TurnContext{ + Inbound: &bus.InboundContext{ + Channel: "slack", + Account: "workspace-a", + ChatID: "C123", + ChatType: "channel", + TopicID: "thread-42", + SpaceType: "workspace", + SpaceID: "T001", + SenderID: "U123", + Mentioned: true, + }, + Route: &routing.ResolvedRoute{ + AgentID: "support", + Channel: "slack", + AccountID: "workspace-a", + MatchedBy: "default", + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat", "sender"}, + IdentityLinks: map[string][]string{ + "canonical-user": {"slack:U123"}, + }, + }, + }, + Scope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "support", + Channel: "slack", + Account: "workspace-a", + Dimensions: []string{"chat", "sender"}, + Values: map[string]string{ + "chat": "channel:c123", + "sender": "u123", + }, + }, + }) + + if fields["inbound_channel"] != "slack" { + t.Fatalf("inbound_channel = %v, want slack", fields["inbound_channel"]) + } + if fields["inbound_topic_id"] != "thread-42" { + t.Fatalf("inbound_topic_id = %v, want thread-42", fields["inbound_topic_id"]) + } + if fields["route_matched_by"] != "default" { + t.Fatalf("route_matched_by = %v, want default", fields["route_matched_by"]) + } + if fields["route_dimensions"] != "chat,sender" { + t.Fatalf("route_dimensions = %v, want chat,sender", fields["route_dimensions"]) + } + if fields["route_identity_link_count"] != 1 { + t.Fatalf("route_identity_link_count = %v, want 1", fields["route_identity_link_count"]) + } + if fields["scope_dimensions"] != "chat,sender" { + t.Fatalf("scope_dimensions = %v, want chat,sender", fields["scope_dimensions"]) + } + if fields["scope_chat"] != "channel:c123" { + t.Fatalf("scope_chat = %v, want channel:c123", fields["scope_chat"]) + } + if fields["scope_sender"] != "u123" { + t.Fatalf("scope_sender = %v, want u123", fields["scope_sender"]) + } +} + +func TestResolveMessageRoute_UsesInboundContextAccount(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + }, + List: []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "work"}, + }, + }, + Session: config.SessionConfig{ + Dimensions: []string{"sender"}, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"}) + + route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "slack", + Account: "workspace-a", + ChatID: "C123", + ChatType: "channel", + SenderID: "U123", + SpaceID: "T001", + SpaceType: "workspace", + }, + Content: "hello", + })) + if err != nil { + t.Fatalf("resolveMessageRoute() error = %v", err) + } + if route.AgentID != "main" { + t.Fatalf("AgentID = %q, want main", route.AgentID) + } + if route.MatchedBy != "default" { + t.Fatalf("MatchedBy = %q, want default", route.MatchedBy) + } + if route.AccountID != "workspace-a" { + t.Fatalf("AccountID = %q, want workspace-a", route.AccountID) + } +} + +func TestResolveMessageRoute_UsesDispatchRulesInOrder(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + }, + List: []config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "support"}, + {ID: "sales"}, + }, + Dispatch: &config.DispatchConfig{ + Rules: []config.DispatchRule{ + { + Name: "support-group", + Agent: "support", + When: config.DispatchSelector{ + Channel: "telegram", + Chat: "group:-100123", + }, + SessionDimensions: []string{"chat"}, + }, + { + Name: "vip-in-group", + Agent: "sales", + When: config.DispatchSelector{ + Channel: "telegram", + Chat: "group:-100123", + Sender: "12345", + }, + SessionDimensions: []string{"chat", "sender"}, + }, + }, + }, + }, + Session: config.SessionConfig{ + Dimensions: []string{"sender"}, + }, + } + + msgBus := bus.NewMessageBus() + al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "ok"}) + + route, _, err := al.resolveMessageRoute(testInboundMessage(bus.InboundMessage{ + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-100123", + ChatType: "group", + SenderID: "12345", + }, + Content: "hello", + })) + if err != nil { + t.Fatalf("resolveMessageRoute() error = %v", err) + } + if route.AgentID != "support" { + t.Fatalf("AgentID = %q, want support", route.AgentID) + } + if route.MatchedBy != "dispatch.rule:support-group" { + t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy) + } + if got := route.SessionPolicy.Dimensions; len(got) != 1 || got[0] != "chat" { + t.Fatalf("SessionPolicy.Dimensions = %v, want [chat]", got) + } +} + func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { tmpDir := t.TempDir() cfg := config.DefaultConfig() @@ -765,12 +1034,12 @@ func TestProcessMessage_MediaArtifactCanBeForwardedBySendFile(t *testing.T) { path: imagePath, }) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", ChatID: "chat1", SenderID: "user1", Content: "take a screenshot of the screen and send it to me", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -975,6 +1244,66 @@ func (m *handledMediaProvider) GetDefaultModel() string { return "handled-media-model" } +type handledUserProvider struct { + calls int +} + +func (m *handledUserProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "Delivering the result now.", + ToolCalls: []providers.ToolCall{{ + ID: "call_handled_user", + Type: "function", + Name: "handled_user_tool", + Arguments: map[string]any{}, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *handledUserProvider) GetDefaultModel() string { + return "handled-user-model" +} + +type messageToolProvider struct { + calls int +} + +func (m *messageToolProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + m.calls++ + if m.calls == 1 { + return &providers.LLMResponse{ + Content: "", + ToolCalls: []providers.ToolCall{{ + ID: "call_message", + Type: "function", + Name: "message", + Arguments: map[string]any{"content": "direct tool message"}, + }}, + }, nil + } + return &providers.LLMResponse{}, nil +} + +func (m *messageToolProvider) GetDefaultModel() string { + return "message-tool-model" +} + type artifactThenSendProvider struct { calls int } @@ -1178,6 +1507,24 @@ func (m *handledMediaTool) Execute(ctx context.Context, args map[string]any) *to return tools.MediaResult("Attachment delivered by tool.", []string{ref}).WithResponseHandled() } +type handledUserTool struct{} + +func (m *handledUserTool) Name() string { return "handled_user_tool" } +func (m *handledUserTool) Description() string { + return "Returns a user-visible result and marks delivery as handled" +} + +func (m *handledUserTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{}, + } +} + +func (m *handledUserTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + return tools.UserResult("Handled user output from tool.").WithResponseHandled() +} + type handledMediaWithSteeringProvider struct { calls int } @@ -1391,13 +1738,39 @@ func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, ms timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) defer cancel() - response, err := h.al.processMessage(timeoutCtx, msg) + response, err := h.al.processMessage(timeoutCtx, testInboundMessage(msg)) if err != nil { tb.Fatalf("processMessage failed: %v", err) } return response } +func testInboundMessage(msg bus.InboundMessage) bus.InboundMessage { + if msg.Context.Channel == "" && + msg.Context.Account == "" && + msg.Context.ChatID == "" && + msg.Context.ChatType == "" && + msg.Context.TopicID == "" && + msg.Context.SpaceID == "" && + msg.Context.SpaceType == "" && + msg.Context.SenderID == "" && + msg.Context.MessageID == "" && + !msg.Context.Mentioned && + msg.Context.ReplyToMessageID == "" && + msg.Context.ReplyToSenderID == "" && + len(msg.Context.ReplyHandles) == 0 && + len(msg.Context.Raw) == 0 { + msg.Context = bus.InboundContext{ + Channel: msg.Channel, + ChatID: msg.ChatID, + ChatType: "direct", + SenderID: msg.SenderID, + MessageID: msg.MessageID, + } + } + return bus.NormalizeInboundMessage(msg) +} + const responseTimeout = 3 * time.Second func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { @@ -1423,21 +1796,17 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) msg := bus.InboundMessage{ - Channel: "telegram", - SenderID: "user1", - ChatID: "chat1", - Content: "hello", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "hello", } - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: msg.Channel, - Peer: extractPeer(msg), - }) - sessionKey := route.SessionKey + route := al.registry.ResolveRoute(bus.NormalizeInboundMessage(msg).Context) + sessionKey := al.allocateRouteSession(route, msg).SessionKey defaultAgent := al.registry.GetDefaultAgent() if defaultAgent == nil { @@ -1473,7 +1842,7 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { }, }, Session: config.SessionConfig{ - DMScope: "per-channel-peer", + Dimensions: []string{"chat"}, }, } @@ -1483,21 +1852,22 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { helper := testHelper{al: al} baseMsg := bus.InboundMessage{ - Channel: "whatsapp", - SenderID: "user1", - ChatID: "chat1", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "whatsapp", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, } showResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ - Channel: baseMsg.Channel, - SenderID: baseMsg.SenderID, - ChatID: baseMsg.ChatID, - Content: "/show channel", - Peer: baseMsg.Peer, + Context: bus.InboundContext{ + Channel: baseMsg.Context.Channel, + ChatID: baseMsg.Context.ChatID, + ChatType: baseMsg.Context.ChatType, + SenderID: baseMsg.Context.SenderID, + }, + Content: "/show channel", }) if showResp != "Current Channel: whatsapp" { t.Fatalf("unexpected /show reply: %q", showResp) @@ -1507,11 +1877,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { } fooResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ - Channel: baseMsg.Channel, - SenderID: baseMsg.SenderID, - ChatID: baseMsg.ChatID, - Content: "/foo", - Peer: baseMsg.Peer, + Context: bus.InboundContext{ + Channel: baseMsg.Context.Channel, + ChatID: baseMsg.Context.ChatID, + ChatType: baseMsg.Context.ChatType, + SenderID: baseMsg.Context.SenderID, + }, + Content: "/foo", }) if fooResp != "LLM reply" { t.Fatalf("unexpected /foo reply: %q", fooResp) @@ -1521,11 +1893,13 @@ func TestProcessMessage_CommandOutcomes(t *testing.T) { } newResp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{ - Channel: baseMsg.Channel, - SenderID: baseMsg.SenderID, - ChatID: baseMsg.ChatID, - Content: "/new", - Peer: baseMsg.Peer, + Context: bus.InboundContext{ + Channel: baseMsg.Context.Channel, + ChatID: baseMsg.Context.ChatID, + ChatType: baseMsg.Context.ChatType, + SenderID: baseMsg.Context.SenderID, + }, + Content: "/new", }) if newResp != "LLM reply" { t.Fatalf("unexpected /new reply: %q", newResp) @@ -1578,10 +1952,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/switch model to deepseek", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(switchResp, "Switched model from local to deepseek") { t.Fatalf("unexpected /switch reply: %q", switchResp) @@ -1592,10 +1962,6 @@ func TestProcessMessage_SwitchModelShowModelConsistency(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/show model", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(showResp, "Current Model: deepseek (Provider: openrouter)") { t.Fatalf("unexpected /show model reply after switch: %q", showResp) @@ -1643,10 +2009,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/switch model to missing", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if switchResp != `model "missing" not found in model_list or providers` { t.Fatalf("unexpected /switch error reply: %q", switchResp) @@ -1657,10 +2019,6 @@ func TestProcessMessage_SwitchModelRejectsUnknownAlias(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "/show model", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(showResp, "Current Model: local (Provider: openai)") { t.Fatalf("unexpected /show model reply after rejected switch: %q", showResp) @@ -1727,10 +2085,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t SenderID: "user1", ChatID: "chat1", Content: "hello before switch", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if firstResp != "local reply" { t.Fatalf("unexpected response before switch: %q", firstResp) @@ -1750,10 +2104,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t SenderID: "user1", ChatID: "chat1", Content: "/switch model to deepseek", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if !strings.Contains(switchResp, "Switched model from local to deepseek") { t.Fatalf("unexpected /switch reply: %q", switchResp) @@ -1764,10 +2114,6 @@ func TestProcessMessage_SwitchModelRoutesSubsequentRequestsToSelectedProvider(t SenderID: "user1", ChatID: "chat1", Content: "hello after switch", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if secondResp != "remote reply" { t.Fatalf("unexpected response after switch: %q", secondResp) @@ -1857,10 +2203,6 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "hi", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", - }, }) if resp != "light reply" { t.Fatalf("response = %q, want %q", resp, "light reply") @@ -1942,7 +2284,6 @@ func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) { SenderID: "user1", ChatID: "chat1", Content: "hi", - Peer: bus.Peer{Kind: "direct", ID: "user1"}, }) if resp != "fallback reply" { @@ -2020,7 +2361,6 @@ func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t * SenderID: "user1", ChatID: "chat1", Content: "hi", - Peer: bus.Peer{Kind: "direct", ID: "user1"}, }) if resp != "active provider reply" { @@ -2291,14 +2631,16 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) { if defaultAgent == nil { t.Fatal("No default agent found") } - route := al.registry.ResolveRoute(routing.RouteInput{ - Channel: "test", - Peer: &routing.RoutePeer{ - Kind: "direct", - ID: "cron", - }, + route := al.registry.ResolveRoute(bus.InboundContext{ + Channel: "test", + ChatType: "direct", + SenderID: "cron", }) - history := defaultAgent.Sessions.GetHistory(route.SessionKey) + history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, testInboundMessage(bus.InboundMessage{ + Channel: "test", + SenderID: "cron", + ChatID: "chat1", + })).SessionKey) if len(history) != 4 { t.Fatalf("history len = %d, want 4", len(history)) } @@ -2556,8 +2898,7 @@ func TestHandleReasoning(t *testing.T) { for i := 0; ; i++ { fillCtx, fillCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) err := msgBus.PublishOutbound(fillCtx, bus.OutboundMessage{ - Channel: "filler", - ChatID: "filler", + Context: bus.NewOutboundContext("filler", "filler", ""), Content: fmt.Sprintf("filler-%d", i), }) fillCancel() @@ -2631,12 +2972,12 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T chManager.RegisterChannel("telegram", &fakeChannel{id: "reason-chat"}) al.SetChannelManager(chManager) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "user1", ChatID: "chat1", Content: "hello", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -2652,6 +2993,9 @@ func TestProcessMessage_PublishesReasoningContentToReasoningChannel(t *testing.T if outbound.ChatID != "reason-chat" { t.Fatalf("reasoning chatID = %q, want %q", outbound.ChatID, "reason-chat") } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "reason-chat" { + t.Fatalf("unexpected reasoning context: %+v", outbound.Context) + } if outbound.Content != "thinking trace" { t.Fatalf("reasoning content = %q, want %q", outbound.Content, "thinking trace") } @@ -2711,8 +3055,12 @@ func TestProcessMessage_PicoPublishesReasoningAsThoughtMessage(t *testing.T) { if thoughtMsg.Channel != "pico" || thoughtMsg.ChatID != "pico:test-session" { t.Fatalf("thought message route = %s/%s, want pico/pico:test-session", thoughtMsg.Channel, thoughtMsg.ChatID) } - if thoughtMsg.Metadata[metadataKeyMessageKind] != messageKindThought { - t.Fatalf("thought metadata kind = %q, want %q", thoughtMsg.Metadata[metadataKeyMessageKind], messageKindThought) + if thoughtMsg.Context.Raw[metadataKeyMessageKind] != messageKindThought { + t.Fatalf( + "thought metadata kind = %q, want %q", + thoughtMsg.Context.Raw[metadataKeyMessageKind], + messageKindThought, + ) } } @@ -2793,12 +3141,12 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { provider := &toolFeedbackProvider{filePath: heartbeatFile} al := NewAgentLoop(cfg, msgBus, provider) - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "telegram", SenderID: "user-1", ChatID: "chat-1", Content: "check tool feedback", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -2814,14 +3162,73 @@ func TestProcessMessage_PublishesToolFeedbackWhenEnabled(t *testing.T) { if outbound.ChatID != "chat-1" { t.Fatalf("tool feedback chatID = %q, want %q", outbound.ChatID, "chat-1") } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" { + t.Fatalf("unexpected tool feedback context: %+v", outbound.Context) + } if !strings.Contains(outbound.Content, "`read_file`") { t.Fatalf("tool feedback content = %q, want read_file preview", outbound.Content) } + if outbound.AgentID != "main" { + t.Fatalf("tool feedback agent_id = %q, want main", outbound.AgentID) + } + if outbound.SessionKey == "" { + t.Fatal("expected tool feedback to carry session_key") + } + if outbound.Scope == nil || outbound.Scope.AgentID != "main" || outbound.Scope.Channel != "telegram" { + t.Fatalf("expected tool feedback scope, got %+v", outbound.Scope) + } case <-time.After(2 * time.Second): t.Fatal("expected outbound tool feedback for regular messages") } } +func TestProcessMessage_MessageToolPublishesOutboundWithTurnMetadata(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Workspace = t.TempDir() + cfg.Agents.Defaults.ModelName = "test-model" + cfg.Agents.Defaults.MaxTokens = 4096 + cfg.Agents.Defaults.MaxToolIterations = 10 + cfg.Session.Dimensions = []string{"chat"} + + msgBus := bus.NewMessageBus() + provider := &messageToolProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ + Channel: "telegram", + SenderID: "user-1", + ChatID: "chat-1", + Content: "send a direct message", + })) + if err != nil { + t.Fatalf("processMessage() error = %v", err) + } + if response == "" { + t.Fatal("expected processMessage() to return a final loop response") + } + + select { + case outbound := <-msgBus.OutboundChan(): + if outbound.Content != "direct tool message" { + t.Fatalf("outbound content = %q, want direct tool message", outbound.Content) + } + if outbound.AgentID != "main" { + t.Fatalf("outbound agent_id = %q, want main", outbound.AgentID) + } + if outbound.SessionKey == "" { + t.Fatal("expected message tool outbound to carry session_key") + } + if outbound.Scope == nil || outbound.Scope.Values["chat"] != "direct:chat-1" { + t.Fatalf("unexpected message tool outbound scope: %+v", outbound.Scope) + } + if outbound.Context.Channel != "telegram" || outbound.Context.ChatID != "chat-1" { + t.Fatalf("unexpected message tool outbound context: %+v", outbound.Context) + } + case <-time.After(2 * time.Second): + t.Fatal("expected message tool outbound") + } +} + func TestRun_PicoPublishesAssistantContentDuringToolCallsWithoutFinalDuplicate(t *testing.T) { tmpDir := t.TempDir() @@ -3363,13 +3770,13 @@ func TestProcessMessage_ContextOverflowRecovery(t *testing.T) { agent.Sessions.AddFullMessage(sessionKey, providers.Message{Role: "assistant", Content: "response"}) } - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "test", ChatID: "chat1", SenderID: "user1", SessionKey: "test-session", Content: "trigger recovery", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } @@ -3405,12 +3812,12 @@ func TestProcessMessage_ContextOverflow_AnthropicStyle(t *testing.T) { return &providers.LLMResponse{Content: "Anthropic recovery success"}, nil } - response, err := al.processMessage(context.Background(), bus.InboundMessage{ + response, err := al.processMessage(context.Background(), testInboundMessage(bus.InboundMessage{ Channel: "test", ChatID: "chat1", SenderID: "user1", Content: "hello", - }) + })) if err != nil { t.Fatalf("processMessage() error = %v", err) } diff --git a/pkg/agent/registry.go b/pkg/agent/registry.go index 58b7ce440..8aa11e37b 100644 --- a/pkg/agent/registry.go +++ b/pkg/agent/registry.go @@ -3,6 +3,7 @@ package agent import ( "sync" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" @@ -64,9 +65,9 @@ func (r *AgentRegistry) GetAgent(agentID string) (*AgentInstance, bool) { return agent, ok } -// ResolveRoute determines which agent handles the message. -func (r *AgentRegistry) ResolveRoute(input routing.RouteInput) routing.ResolvedRoute { - return r.resolver.ResolveRoute(input) +// ResolveRoute determines which agent handles the normalized inbound context. +func (r *AgentRegistry) ResolveRoute(inbound bus.InboundContext) routing.ResolvedRoute { + return r.resolver.ResolveRoute(inbound) } // ListAgentIDs returns all registered agent IDs. diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index ad6613e8c..a2e5fec21 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -3,12 +3,14 @@ package agent import ( "context" "fmt" + "sort" "strings" "sync" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" - "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -290,12 +292,22 @@ func (al *AgentLoop) continueWithSteeringMessages( ctx context.Context, agent *AgentInstance, sessionKey, channel, chatID string, + scope *session.SessionScope, steeringMsgs []providers.Message, ) (string, error) { + dispatch := DispatchRequest{ + SessionKey: sessionKey, + SessionScope: session.CloneScope(scope), + } + if channel != "" || chatID != "" { + dispatch.InboundContext = &bus.InboundContext{ + Channel: channel, + ChatID: chatID, + ChatType: inferChatTypeFromSessionScope(scope), + } + } return al.runAgentLoop(ctx, agent, processOptions{ - SessionKey: sessionKey, - Channel: channel, - ChatID: chatID, + Dispatch: dispatch, DefaultResponse: defaultResponse, EnableSummary: true, SendResponse: false, @@ -310,9 +322,19 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance { return nil } - if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil { - if agent, ok := registry.GetAgent(parsed.AgentID); ok { - return agent + agentIDs := registry.ListAgentIDs() + sort.Strings(agentIDs) + for _, agentID := range agentIDs { + agent, ok := registry.GetAgent(agentID) + if !ok || agent == nil { + continue + } + resolvedAgentID := session.ResolveAgentID(agent.Sessions, sessionKey) + if resolvedAgentID == "" { + continue + } + if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok { + return scopedAgent } } @@ -352,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s } } - return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs) + var scope *session.SessionScope + if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok { + scope = metaStore.GetSessionScope(sessionKey) + } + + return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs) } func (al *AgentLoop) InterruptGraceful(hint string) error { diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go index 75ba9861d..8e6063f08 100644 --- a/pkg/agent/steering_test.go +++ b/pkg/agent/steering_test.go @@ -17,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -357,7 +358,7 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { }, }, Session: config.SessionConfig{ - DMScope: "per-peer", + Dimensions: []string{"sender"}, }, } @@ -365,14 +366,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { al := NewAgentLoop(cfg, msgBus, &mockProvider{}) activeMsg := bus.InboundMessage{ - Channel: "telegram", - SenderID: "user1", - ChatID: "chat1", - Content: "active turn", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "active turn", } activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg) if !ok { @@ -380,14 +380,13 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { } otherMsg := bus.InboundMessage{ - Channel: "telegram", - SenderID: "user2", - ChatID: "chat2", - Content: "other session", - Peer: bus.Peer{ - Kind: "direct", - ID: "user2", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "chat2", + ChatType: "direct", + SenderID: "user2", }, + Content: "other session", } otherScope, _, ok := al.resolveSteeringTarget(otherMsg) if !ok { @@ -422,9 +421,9 @@ func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) { select { case <-ctx.Done(): - t.Fatalf("timeout waiting for requeued message on outbound bus") - case requeued := <-msgBus.OutboundChan(): - if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID || + t.Fatalf("timeout waiting for requeued message on inbound bus") + case requeued := <-msgBus.InboundChan(): + if requeued.Context.Channel != otherMsg.Context.Channel || requeued.Context.ChatID != otherMsg.Context.ChatID || requeued.Content != otherMsg.Content { t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg) } @@ -841,24 +840,22 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) { }() first := bus.InboundMessage{ - Channel: "test", - SenderID: "user1", - ChatID: "chat1", - Content: "first message", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "first message", } late := bus.InboundMessage{ - Channel: "test", - SenderID: "user1", - ChatID: "chat1", - Content: "late append", - Peer: bus.Peer{ - Kind: "direct", - ID: "user1", + Context: bus.InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", }, + Content: "late append", } pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second) @@ -949,7 +946,7 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing. }, } - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) provider := &blockingDirectProvider{ firstStarted: make(chan struct{}), releaseFirst: make(chan struct{}), @@ -1013,6 +1010,62 @@ func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing. } } +func TestAgentLoop_AgentForSession_UsesStoredScopeMetadata(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + List: []config.AgentConfig{ + {ID: "sales", Default: true}, + {ID: "support"}, + }, + }, + } + + al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{}) + support, ok := al.registry.GetAgent("support") + if !ok || support == nil { + t.Fatal("expected support agent") + } + + metaStore, ok := support.Sessions.(session.MetadataAwareSessionStore) + if !ok { + t.Fatal("support session store does not support metadata") + } + + alias := "agent:support:slack:channel:c001" + key := session.BuildOpaqueSessionKey(alias) + scope := &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "support", + Channel: "slack", + Account: "default", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "channel:c001", + }, + } + metaStore.EnsureSessionMetadata(key, scope, []string{alias}) + + got := al.agentForSession(key) + if got == nil { + t.Fatal("agentForSession() returned nil") + } + if got.ID != "support" { + t.Fatalf("agentForSession() = %q, want %q", got.ID, "support") + } +} + func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) { tmpDir, err := os.MkdirTemp("", "agent-test-*") if err != nil { @@ -1060,7 +1113,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) { }, } - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) msgBus := bus.NewMessageBus() al := NewAgentLoop(cfg, msgBus, provider) al.SetMediaStore(store) @@ -1168,7 +1221,7 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) al.RegisterTool(tool1) al.RegisterTool(tool2) - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) sub := al.SubscribeEvents(32) defer al.UnsubscribeEvents(sub.ID) @@ -1322,7 +1375,7 @@ func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) { al := NewAgentLoop(cfg, msgBus, provider) started := make(chan struct{}) al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started}) - sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID) + sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID) defaultAgent := al.registry.GetDefaultAgent() if defaultAgent == nil { diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 9ee7b15c9..cd193017b 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -351,15 +351,17 @@ func spawnSubTurn( } // Create processOptions for the child turn + dispatch := DispatchRequest{ + SessionKey: childID, + UserMessage: cfg.SystemPrompt, + Media: nil, + InboundContext: cloneInboundContext(parentTS.opts.Dispatch.InboundContext), + } opts := processOptions{ - SessionKey: childID, - Channel: parentTS.channel, - ChatID: parentTS.chatID, - SenderID: parentTS.opts.SenderID, + Dispatch: dispatch, + SenderID: parentTS.opts.Dispatch.SenderID(), SenderDisplayName: parentTS.opts.SenderDisplayName, - UserMessage: cfg.SystemPrompt, // Task description becomes the first user message SystemPromptOverride: cfg.ActualSystemPrompt, - Media: nil, InitialSteeringMessages: cfg.InitialMessages, DefaultResponse: "", EnableSummary: false, @@ -369,7 +371,11 @@ func spawnSubTurn( } // Create event scope for the child turn - scope := al.newTurnEventScope(agent.ID, childID) + scope := al.newTurnEventScope( + agent.ID, + childID, + newTurnContext(opts.Dispatch.InboundContext, opts.Dispatch.RouteResult, opts.Dispatch.SessionScope), + ) // Create child turnState using the new API childTS := newTurnState(&agent, opts, scope) diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go index 8f099ed1d..a061742e3 100644 --- a/pkg/agent/turn.go +++ b/pkg/agent/turn.go @@ -56,6 +56,7 @@ type turnState struct { turnID string agentID string sessionKey string + turnCtx *TurnContext channel string chatID string @@ -115,11 +116,12 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop scope: scope, turnID: scope.turnID, agentID: agent.ID, - sessionKey: opts.SessionKey, - channel: opts.Channel, - chatID: opts.ChatID, - userMessage: opts.UserMessage, - media: append([]string(nil), opts.Media...), + sessionKey: opts.Dispatch.SessionKey, + turnCtx: cloneTurnContext(scope.context), + channel: opts.Dispatch.Channel(), + chatID: opts.Dispatch.ChatID(), + userMessage: opts.Dispatch.UserMessage, + media: append([]string(nil), opts.Dispatch.Media...), phase: TurnPhaseSetup, startedAt: time.Now(), } @@ -127,7 +129,7 @@ func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScop // Bind session store and capture initial history length for rollback logic if agent != nil && agent.Sessions != nil { ts.session = agent.Sessions - ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey)) + ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.Dispatch.SessionKey)) } return ts @@ -302,12 +304,13 @@ func (ts *turnState) hardAbortRequested() bool { func (ts *turnState) eventMeta(source, tracePath string) EventMeta { snap := ts.snapshot() return EventMeta{ - AgentID: snap.AgentID, - TurnID: snap.TurnID, - SessionKey: snap.SessionKey, - Iteration: snap.Iteration, - Source: source, - TracePath: tracePath, + AgentID: snap.AgentID, + TurnID: snap.TurnID, + SessionKey: snap.SessionKey, + Iteration: snap.Iteration, + Source: source, + TracePath: tracePath, + turnContext: cloneTurnContext(ts.turnCtx), } } diff --git a/pkg/agent/turn_context.go b/pkg/agent/turn_context.go new file mode 100644 index 000000000..8913993aa --- /dev/null +++ b/pkg/agent/turn_context.go @@ -0,0 +1,92 @@ +package agent + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" +) + +// TurnContext carries normalized turn-scoped facts that can be shared across +// events, hooks, and other runtime observers without re-parsing legacy fields. +type TurnContext struct { + Inbound *bus.InboundContext `json:"inbound,omitempty"` + Route *routing.ResolvedRoute `json:"route,omitempty"` + Scope *session.SessionScope `json:"scope,omitempty"` +} + +func newTurnContext( + inbound *bus.InboundContext, + route *routing.ResolvedRoute, + scope *session.SessionScope, +) *TurnContext { + if inbound == nil && route == nil && scope == nil { + return nil + } + return &TurnContext{ + Inbound: cloneInboundContext(inbound), + Route: cloneResolvedRoute(route), + Scope: session.CloneScope(scope), + } +} + +func cloneTurnContext(ctx *TurnContext) *TurnContext { + if ctx == nil { + return nil + } + cloned := *ctx + cloned.Inbound = cloneInboundContext(ctx.Inbound) + cloned.Route = cloneResolvedRoute(ctx.Route) + cloned.Scope = session.CloneScope(ctx.Scope) + return &cloned +} + +func cloneInboundContext(ctx *bus.InboundContext) *bus.InboundContext { + if ctx == nil { + return nil + } + cloned := *ctx + cloned.ReplyHandles = cloneStringMap(ctx.ReplyHandles) + cloned.Raw = cloneStringMap(ctx.Raw) + return &cloned +} + +func cloneStringMap(src map[string]string) map[string]string { + if len(src) == 0 { + return nil + } + cloned := make(map[string]string, len(src)) + for k, v := range src { + cloned[k] = v + } + return cloned +} + +func cloneEventMeta(meta EventMeta) EventMeta { + meta.turnContext = cloneTurnContext(meta.turnContext) + return meta +} + +func cloneResolvedRoute(route *routing.ResolvedRoute) *routing.ResolvedRoute { + if route == nil { + return nil + } + cloned := *route + cloned.SessionPolicy = routing.SessionPolicy{ + Dimensions: append([]string(nil), route.SessionPolicy.Dimensions...), + IdentityLinks: cloneIdentityLinks(route.SessionPolicy.IdentityLinks), + } + return &cloned +} + +func cloneIdentityLinks(src map[string][]string) map[string][]string { + if len(src) == 0 { + return nil + } + cloned := make(map[string][]string, len(src)) + for canonical, ids := range src { + dup := make([]string, len(ids)) + copy(dup, ids) + cloned[canonical] = dup + } + return cloned +} diff --git a/pkg/audio/asr/agent.go b/pkg/audio/asr/agent.go index 32ce0c92a..c483a0778 100644 --- a/pkg/audio/asr/agent.go +++ b/pkg/audio/asr/agent.go @@ -226,8 +226,7 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) { logger.ErrorCF("voice-agent", "Failed to publish leave control", map[string]any{"error": err}) } if err := a.bus.PublishOutbound(ctx, bus.OutboundMessage{ - Channel: channelType, - ChatID: acc.chatID, + Context: bus.NewOutboundContext(channelType, acc.chatID, ""), Content: "Goodbye! Leaving the voice channel.", }); err != nil { logger.ErrorCF("voice-agent", "Failed to publish goodbye message", map[string]any{"error": err}) @@ -238,14 +237,16 @@ func (a *Agent) processUtterance(ctx context.Context, acc *speechAccumulator) { oralPrompt := "\n\n[SYSTEM]: The user just spoke this to you over voice chat. Please reply in a highly concise, conversational, oral style suitable for text-to-speech. Do not use markdown, emojis, asterisks, or code blocks. Speak naturally." if err := a.bus.PublishInbound(ctx, bus.InboundMessage{ - Channel: channelType, - SenderID: acc.speakerID, - ChatID: acc.chatID, - Content: res.Text + oralPrompt, - Peer: bus.Peer{Kind: "channel", ID: acc.chatID}, - Metadata: map[string]string{ - "is_voice": "true", + Context: bus.InboundContext{ + Channel: channelType, + ChatID: acc.chatID, + ChatType: "channel", + SenderID: acc.speakerID, + Raw: map[string]string{ + "is_voice": "true", + }, }, + Content: res.Text + oralPrompt, }); err != nil { logger.ErrorCF("voice-agent", "Failed to publish inbound message", map[string]any{"error": err}) } diff --git a/pkg/audio/asr/agent_test.go b/pkg/audio/asr/agent_test.go index cc1b008a4..0f9bcb3b2 100644 --- a/pkg/audio/asr/agent_test.go +++ b/pkg/audio/asr/agent_test.go @@ -185,8 +185,8 @@ func TestAgentCheckSilencePublishesInboundAndCleansUp(t *testing.T) { if !strings.Contains(msg.Content, "hello there") { t.Fatalf("unexpected inbound content: %q", msg.Content) } - if msg.Metadata["is_voice"] != "true" { - t.Fatalf("expected is_voice metadata, got %#v", msg.Metadata) + if msg.Context.Raw["is_voice"] != "true" { + t.Fatalf("expected is_voice metadata, got %#v", msg.Context.Raw) } case <-time.After(500 * time.Millisecond): t.Fatal("expected inbound publish") diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index a9c74ef90..9a05d4f95 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -12,6 +12,12 @@ import ( // ErrBusClosed is returned when publishing to a closed MessageBus. var ErrBusClosed = errors.New("message bus closed") +var ( + ErrMissingInboundContext = errors.New("inbound message context is required") + ErrMissingOutboundContext = errors.New("outbound message context is required") + ErrMissingOutboundMediaContext = errors.New("outbound media context is required") +) + const defaultBusBufferSize = 64 // StreamDelegate is implemented by the channel Manager to provide streaming @@ -49,7 +55,7 @@ func NewMessageBus() *MessageBus { inbound: make(chan InboundMessage, defaultBusBufferSize), outbound: make(chan OutboundMessage, defaultBusBufferSize), outboundMedia: make(chan OutboundMediaMessage, defaultBusBufferSize), - audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer + audioChunks: make(chan AudioChunk, defaultBusBufferSize*4), // Audio chunks need more buffer. voiceControls: make(chan VoiceControl, defaultBusBufferSize), done: make(chan struct{}), } @@ -84,6 +90,10 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error } func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + msg = NormalizeInboundMessage(msg) + if msg.Context.isZero() { + return ErrMissingInboundContext + } return publish(ctx, mb, mb.inbound, msg) } @@ -92,6 +102,10 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage { } func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + msg = NormalizeOutboundMessage(msg) + if msg.Context.isZero() { + return ErrMissingOutboundContext + } return publish(ctx, mb, mb.outbound, msg) } @@ -100,6 +114,10 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage { } func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { + msg = NormalizeOutboundMediaMessage(msg) + if msg.Context.isZero() { + return ErrMissingOutboundMediaContext + } return publish(ctx, mb, mb.outboundMedia, msg) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go index 9b6324ca6..5145d4759 100644 --- a/pkg/bus/bus_test.go +++ b/pkg/bus/bus_test.go @@ -14,10 +14,13 @@ func TestPublishConsume(t *testing.T) { ctx := context.Background() msg := InboundMessage{ - Channel: "test", - SenderID: "user1", - ChatID: "chat1", - Content: "hello", + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + Content: "hello", } if err := mb.PublishInbound(ctx, msg); err != nil { @@ -34,6 +37,138 @@ func TestPublishConsume(t *testing.T) { if got.Channel != "test" { t.Fatalf("expected channel 'test', got %q", got.Channel) } + if got.Context.Channel != "test" { + t.Fatalf("expected context channel 'test', got %q", got.Context.Channel) + } + if got.Context.ChatID != "chat1" { + t.Fatalf("expected context chat ID 'chat1', got %q", got.Context.ChatID) + } + if got.Context.SenderID != "user1" { + t.Fatalf("expected context sender ID 'user1', got %q", got.Context.SenderID) + } +} + +func TestPublishInbound_NormalizesContext(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := InboundMessage{ + Context: InboundContext{ + Channel: "slack", + Account: "workspace-a", + ChatID: "C456/1712", + ChatType: "group", + TopicID: "1712", + SpaceID: "T001", + SpaceType: "team", + SenderID: "U123", + MessageID: "1712.01", + ReplyToMessageID: "1700.01", + Mentioned: true, + }, + Content: "hello", + } + + if err := mb.PublishInbound(context.Background(), msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got := <-mb.InboundChan() + if got.Context.Channel != "slack" { + t.Fatalf("expected context channel slack, got %q", got.Context.Channel) + } + if got.Context.Account != "workspace-a" { + t.Fatalf("expected context account workspace-a, got %q", got.Context.Account) + } + if got.Context.ChatType != "group" { + t.Fatalf("expected context chat type group, got %q", got.Context.ChatType) + } + if got.Context.TopicID != "1712" { + t.Fatalf("expected topic 1712, got %q", got.Context.TopicID) + } + if got.Context.SpaceType != "team" || got.Context.SpaceID != "T001" { + t.Fatalf("expected team space T001, got %q/%q", got.Context.SpaceType, got.Context.SpaceID) + } + if !got.Context.Mentioned { + t.Fatal("expected mentioned=true in context") + } + if got.Context.ReplyToMessageID != "1700.01" { + t.Fatalf("expected reply_to_message_id 1700.01, got %q", got.Context.ReplyToMessageID) + } +} + +func TestPublishInbound_MirrorsContextIntoConvenienceFields(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := InboundMessage{ + Context: InboundContext{ + Channel: "telegram", + Account: "bot-a", + ChatID: "-1001", + ChatType: "group", + TopicID: "42", + SpaceID: "guild-9", + SpaceType: "guild", + SenderID: "user-1", + MessageID: "777", + Mentioned: true, + ReplyToMessageID: "666", + }, + Content: "hi", + } + + if err := mb.PublishInbound(context.Background(), msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got := <-mb.InboundChan() + if got.Channel != "telegram" { + t.Fatalf("expected legacy channel telegram, got %q", got.Channel) + } + if got.ChatID != "-1001" { + t.Fatalf("expected legacy chat ID -1001, got %q", got.ChatID) + } + if got.SenderID != "user-1" { + t.Fatalf("expected legacy sender ID user-1, got %q", got.SenderID) + } + if got.MessageID != "777" { + t.Fatalf("expected legacy message ID 777, got %q", got.MessageID) + } + if got.Context.Account != "bot-a" || got.Context.SpaceID != "guild-9" || got.Context.TopicID != "42" { + t.Fatalf("unexpected normalized context: %+v", got.Context) + } +} + +func TestPublishInbound_BackfillsContextFromLegacyFields(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := InboundMessage{ + Channel: "pico", + ChatID: "session-1", + SenderID: "user-1", + MessageID: "msg-1", + Content: "hello", + } + + if err := mb.PublishInbound(context.Background(), msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got := <-mb.InboundChan() + if got.Context.Channel != "pico" { + t.Fatalf("expected context channel pico, got %q", got.Context.Channel) + } + if got.Context.ChatID != "session-1" { + t.Fatalf("expected context chat ID session-1, got %q", got.Context.ChatID) + } + if got.Context.SenderID != "user-1" { + t.Fatalf("expected context sender ID user-1, got %q", got.Context.SenderID) + } + if got.Context.MessageID != "msg-1" { + t.Fatalf("expected context message ID msg-1, got %q", got.Context.MessageID) + } } func TestPublishOutboundSubscribe(t *testing.T) { @@ -43,8 +178,10 @@ func TestPublishOutboundSubscribe(t *testing.T) { ctx := context.Background() msg := OutboundMessage{ - Channel: "telegram", - ChatID: "123", + Context: InboundContext{ + Channel: "telegram", + ChatID: "123", + }, Content: "world", } @@ -59,6 +196,222 @@ func TestPublishOutboundSubscribe(t *testing.T) { if got.Content != "world" { t.Fatalf("expected content 'world', got %q", got.Content) } + if got.Context.Channel != "telegram" || got.Context.ChatID != "123" { + t.Fatalf("expected normalized outbound context, got %+v", got.Context) + } +} + +func TestPublishOutbound_MirrorsContextToLegacyFields(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := OutboundMessage{ + Context: InboundContext{ + Channel: "telegram", + ChatID: "chat-42", + ReplyToMessageID: "msg-9", + }, + AgentID: "main", + SessionKey: "sk_v1_123", + Scope: &OutboundScope{ + Version: 1, + AgentID: "main", + Channel: "telegram", + Account: "bot-a", + Dimensions: []string{"chat", "sender"}, + Values: map[string]string{ + "chat": "direct:chat-42", + "sender": "user-1", + }, + }, + Content: "reply", + } + + if err := mb.PublishOutbound(context.Background(), msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got := <-mb.OutboundChan() + if got.Channel != "telegram" { + t.Fatalf("expected legacy channel telegram, got %q", got.Channel) + } + if got.ChatID != "chat-42" { + t.Fatalf("expected legacy chat ID chat-42, got %q", got.ChatID) + } + if got.ReplyToMessageID != "msg-9" { + t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID) + } + if got.AgentID != "main" || got.SessionKey != "sk_v1_123" { + t.Fatalf("unexpected outbound turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey) + } + if got.Scope == nil || got.Scope.AgentID != "main" || got.Scope.Values["chat"] != "direct:chat-42" { + t.Fatalf("unexpected outbound scope: %+v", got.Scope) + } + if got.Context.Channel != "telegram" || got.Context.ChatID != "chat-42" { + t.Fatalf("unexpected outbound context: %+v", got.Context) + } +} + +func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := OutboundMessage{ + Context: InboundContext{ + Channel: "telegram", + ChatID: "chat-42", + }, + ReplyToMessageID: "msg-9", + Content: "reply", + } + + if err := mb.PublishOutbound(context.Background(), msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got := <-mb.OutboundChan() + if got.ReplyToMessageID != "msg-9" { + t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID) + } + if got.Context.ReplyToMessageID != "msg-9" { + t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID) + } +} + +func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := OutboundMessage{ + Context: InboundContext{ + Channel: "telegram", + ChatID: "chat-42", + ReplyToMessageID: " ", + }, + ReplyToMessageID: "msg-9", + Content: "reply", + } + + if err := mb.PublishOutbound(context.Background(), msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got := <-mb.OutboundChan() + if got.ReplyToMessageID != "msg-9" { + t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID) + } + if got.Context.ReplyToMessageID != "msg-9" { + t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID) + } +} + +func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := OutboundMediaMessage{ + Context: InboundContext{ + Channel: "slack", + ChatID: "C001", + }, + AgentID: "support", + SessionKey: "sk_v1_media", + Scope: &OutboundScope{ + Version: 1, + AgentID: "support", + Channel: "slack", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "channel:c001", + }, + }, + Parts: []MediaPart{{Type: "image", Ref: "media://1"}}, + } + + if err := mb.PublishOutboundMedia(context.Background(), msg); err != nil { + t.Fatalf("PublishOutboundMedia failed: %v", err) + } + + got := <-mb.OutboundMediaChan() + if got.Channel != "slack" { + t.Fatalf("expected legacy channel slack, got %q", got.Channel) + } + if got.ChatID != "C001" { + t.Fatalf("expected legacy chat ID C001, got %q", got.ChatID) + } + if got.AgentID != "support" || got.SessionKey != "sk_v1_media" { + t.Fatalf("unexpected outbound media turn metadata: agent=%q session=%q", got.AgentID, got.SessionKey) + } + if got.Scope == nil || got.Scope.Values["chat"] != "channel:c001" { + t.Fatalf("unexpected outbound media scope: %+v", got.Scope) + } + if got.Context.Channel != "slack" || got.Context.ChatID != "C001" { + t.Fatalf("unexpected outbound media context: %+v", got.Context) + } +} + +func TestPublishAudioChunkSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + chunk := AudioChunk{ + SessionID: "voice-1", + SpeakerID: "speaker-1", + ChatID: "chat-1", + Channel: "discord", + Sequence: 7, + Format: "opus", + Data: []byte{0x01, 0x02}, + } + + if err := mb.PublishAudioChunk(context.Background(), chunk); err != nil { + t.Fatalf("PublishAudioChunk failed: %v", err) + } + + got, ok := <-mb.AudioChunksChan() + if !ok { + t.Fatal("AudioChunksChan returned ok=false") + } + if got.SessionID != "voice-1" || got.Sequence != 7 { + t.Fatalf("unexpected audio chunk: %+v", got) + } +} + +func TestPublishVoiceControlSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctrl := VoiceControl{ + SessionID: "voice-1", + ChatID: "chat-1", + Type: "command", + Action: "start", + } + + if err := mb.PublishVoiceControl(context.Background(), ctrl); err != nil { + t.Fatalf("PublishVoiceControl failed: %v", err) + } + + got, ok := <-mb.VoiceControlsChan() + if !ok { + t.Fatal("VoiceControlsChan returned ok=false") + } + if got.Type != "command" || got.Action != "start" { + t.Fatalf("unexpected voice control: %+v", got) + } +} + +func TestNewOutboundContext_NormalizesReplyAddress(t *testing.T) { + ctx := NewOutboundContext(" telegram ", " chat-42 ", " msg-9 ") + if ctx.Channel != "telegram" { + t.Fatalf("expected channel telegram, got %q", ctx.Channel) + } + if ctx.ChatID != "chat-42" { + t.Fatalf("expected chat_id chat-42, got %q", ctx.ChatID) + } + if ctx.ReplyToMessageID != "msg-9" { + t.Fatalf("expected reply_to_message_id msg-9, got %q", ctx.ReplyToMessageID) + } } func TestPublishInbound_ContextCancel(t *testing.T) { @@ -68,7 +421,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) { // Fill the buffer ctx := context.Background() for i := range defaultBusBufferSize { - if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + if err := mb.PublishInbound(ctx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-fill", + ChatType: "direct", + SenderID: "user-fill", + }, + Content: "fill", + }); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } @@ -77,7 +438,15 @@ func TestPublishInbound_ContextCancel(t *testing.T) { cancelCtx, cancel := context.WithCancel(context.Background()) cancel() - err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + err := mb.PublishInbound(cancelCtx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-overflow", + ChatType: "direct", + SenderID: "user-overflow", + }, + Content: "overflow", + }) if err == nil { t.Fatal("expected error from canceled context, got nil") } @@ -90,7 +459,15 @@ func TestPublishInbound_BusClosed(t *testing.T) { mb := NewMessageBus() mb.Close() - err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + err := mb.PublishInbound(context.Background(), InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + Content: "test", + }) if err != ErrBusClosed { t.Fatalf("expected ErrBusClosed, got %v", err) } @@ -100,7 +477,13 @@ func TestPublishOutbound_BusClosed(t *testing.T) { mb := NewMessageBus() mb.Close() - err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + err := mb.PublishOutbound(context.Background(), OutboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + }, + Content: "test", + }) if err != ErrBusClosed { t.Fatalf("expected ErrBusClosed, got %v", err) } @@ -112,14 +495,30 @@ func TestConsumeInbound_ContextCancel(t *testing.T) { defer mb.Close() for i := range defaultBusBufferSize { - if err := mb.PublishInbound(context.Background(), InboundMessage{Content: "fill"}); err != nil { + if err := mb.PublishInbound(context.Background(), InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-fill", + ChatType: "direct", + SenderID: "user-fill", + }, + Content: "fill", + }); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - mb.PublishInbound(ctx, InboundMessage{Content: "ContextCancel"}) + mb.PublishInbound(ctx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-cancel", + ChatType: "direct", + SenderID: "user-cancel", + }, + Content: "ContextCancel", + }) select { case <-ctx.Done(): @@ -213,7 +612,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) { // Fill the buffer for i := range defaultBusBufferSize { - if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + if err := mb.PublishInbound(ctx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-fill", + ChatType: "direct", + SenderID: "user-fill", + }, + Content: "fill", + }); err != nil { t.Fatalf("fill failed at %d: %v", i, err) } } @@ -222,7 +629,15 @@ func TestPublishInbound_FullBuffer(t *testing.T) { timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + err := mb.PublishInbound(timeoutCtx, InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat-overflow", + ChatType: "direct", + SenderID: "user-overflow", + }, + Content: "overflow", + }) if err == nil { t.Fatal("expected error when buffer is full and context times out") } @@ -240,7 +655,15 @@ func TestCloseIdempotent(t *testing.T) { mb.Close() // After close, publish should return ErrBusClosed - err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + err := mb.PublishInbound(context.Background(), InboundMessage{ + Context: InboundContext{ + Channel: "test", + ChatID: "chat1", + ChatType: "direct", + SenderID: "user1", + }, + Content: "test", + }) if err != ErrBusClosed { t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) } diff --git a/pkg/bus/inbound_context.go b/pkg/bus/inbound_context.go new file mode 100644 index 000000000..d6be80565 --- /dev/null +++ b/pkg/bus/inbound_context.go @@ -0,0 +1,81 @@ +package bus + +import "strings" + +// NormalizeInboundMessage ensures the inbound context is normalized and keeps +// convenience mirrors in sync for runtime consumers. +func NormalizeInboundMessage(msg InboundMessage) InboundMessage { + if msg.Context.Channel == "" { + msg.Context.Channel = msg.Channel + } + if msg.Context.ChatID == "" { + msg.Context.ChatID = msg.ChatID + } + if msg.Context.SenderID == "" { + msg.Context.SenderID = msg.SenderID + } + if msg.Context.MessageID == "" { + msg.Context.MessageID = msg.MessageID + } + msg.Context = normalizeInboundContext(msg.Context) + msg.Channel = msg.Context.Channel + msg.SenderID = msg.Context.SenderID + msg.ChatID = msg.Context.ChatID + if msg.MessageID == "" { + msg.MessageID = msg.Context.MessageID + } + if msg.Context.MessageID == "" { + msg.Context.MessageID = msg.MessageID + } + return msg +} + +func (ctx InboundContext) isZero() bool { + return ctx.Channel == "" && + ctx.Account == "" && + ctx.ChatID == "" && + ctx.ChatType == "" && + ctx.TopicID == "" && + ctx.SpaceID == "" && + ctx.SpaceType == "" && + ctx.SenderID == "" && + ctx.MessageID == "" && + !ctx.Mentioned && + ctx.ReplyToMessageID == "" && + ctx.ReplyToSenderID == "" && + len(ctx.ReplyHandles) == 0 && + len(ctx.Raw) == 0 +} + +func normalizeInboundContext(ctx InboundContext) InboundContext { + ctx.Channel = strings.TrimSpace(ctx.Channel) + ctx.Account = strings.TrimSpace(ctx.Account) + ctx.ChatID = strings.TrimSpace(ctx.ChatID) + ctx.ChatType = normalizeKind(ctx.ChatType) + ctx.TopicID = strings.TrimSpace(ctx.TopicID) + ctx.SpaceID = strings.TrimSpace(ctx.SpaceID) + ctx.SpaceType = normalizeKind(ctx.SpaceType) + ctx.SenderID = strings.TrimSpace(ctx.SenderID) + ctx.MessageID = strings.TrimSpace(ctx.MessageID) + ctx.ReplyToMessageID = strings.TrimSpace(ctx.ReplyToMessageID) + ctx.ReplyToSenderID = strings.TrimSpace(ctx.ReplyToSenderID) + ctx.ReplyHandles = cloneStringMap(ctx.ReplyHandles) + ctx.Raw = cloneStringMap(ctx.Raw) + return ctx +} + +func cloneStringMap(src map[string]string) map[string]string { + if len(src) == 0 { + return nil + } + + dst := make(map[string]string, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func normalizeKind(kind string) string { + return strings.ToLower(strings.TrimSpace(kind)) +} diff --git a/pkg/bus/outbound_context.go b/pkg/bus/outbound_context.go new file mode 100644 index 000000000..cbbbc99c7 --- /dev/null +++ b/pkg/bus/outbound_context.go @@ -0,0 +1,84 @@ +package bus + +import "strings" + +// NewOutboundContext builds the minimal normalized addressing context required +// to deliver an outbound text message or reply. +func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext { + return normalizeInboundContext(InboundContext{ + Channel: strings.TrimSpace(channel), + ChatID: strings.TrimSpace(chatID), + ReplyToMessageID: strings.TrimSpace(replyToMessageID), + }) +} + +// NormalizeOutboundMessage ensures Context is normalized and keeps convenience +// mirrors in sync for runtime consumers. +func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage { + msg.Channel = strings.TrimSpace(msg.Channel) + msg.ChatID = strings.TrimSpace(msg.ChatID) + msg.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID) + if msg.Context.Channel == "" { + msg.Context.Channel = msg.Channel + } + if msg.Context.ChatID == "" { + msg.Context.ChatID = msg.ChatID + } + if msg.Context.ReplyToMessageID == "" { + msg.Context.ReplyToMessageID = msg.ReplyToMessageID + } + msg.Context = normalizeInboundContext(msg.Context) + if msg.Channel == "" { + msg.Channel = msg.Context.Channel + } + if msg.ChatID == "" { + msg.ChatID = msg.Context.ChatID + } + if msg.ReplyToMessageID == "" { + msg.ReplyToMessageID = msg.Context.ReplyToMessageID + } + if msg.Context.ReplyToMessageID == "" { + msg.Context.ReplyToMessageID = msg.ReplyToMessageID + } + msg.Scope = cloneOutboundScope(msg.Scope) + return msg +} + +// NormalizeOutboundMediaMessage ensures media outbound messages also carry a +// normalized context while keeping convenience mirrors in sync. +func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage { + msg.Channel = strings.TrimSpace(msg.Channel) + msg.ChatID = strings.TrimSpace(msg.ChatID) + if msg.Context.Channel == "" { + msg.Context.Channel = msg.Channel + } + if msg.Context.ChatID == "" { + msg.Context.ChatID = msg.ChatID + } + msg.Context = normalizeInboundContext(msg.Context) + if msg.Channel == "" { + msg.Channel = msg.Context.Channel + } + if msg.ChatID == "" { + msg.ChatID = msg.Context.ChatID + } + msg.Scope = cloneOutboundScope(msg.Scope) + return msg +} + +func cloneOutboundScope(scope *OutboundScope) *OutboundScope { + if scope == nil { + return nil + } + cloned := *scope + if len(scope.Dimensions) > 0 { + cloned.Dimensions = append([]string(nil), scope.Dimensions...) + } + if len(scope.Values) > 0 { + cloned.Values = make(map[string]string, len(scope.Values)) + for key, value := range scope.Values { + cloned.Values[key] = value + } + } + return &cloned +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 27cf61b5f..aa06ca173 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -1,11 +1,5 @@ package bus -// Peer identifies the routing peer for a message (direct, group, channel, etc.) -type Peer struct { - Kind string `json:"kind"` // "direct" | "group" | "channel" | "" - ID string `json:"id"` -} - // SenderInfo provides structured sender identity information. type SenderInfo struct { Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ... @@ -15,26 +9,67 @@ type SenderInfo struct { DisplayName string `json:"display_name,omitempty"` // display name } +// InboundContext captures the normalized, platform-agnostic facts about an +// inbound message. This is the source of truth for routing and session +// allocation. +type InboundContext struct { + Channel string `json:"channel"` + Account string `json:"account,omitempty"` + + ChatID string `json:"chat_id"` + ChatType string `json:"chat_type,omitempty"` // direct / group / channel + TopicID string `json:"topic_id,omitempty"` + + SpaceID string `json:"space_id,omitempty"` + SpaceType string `json:"space_type,omitempty"` // guild / team / workspace / tenant + + SenderID string `json:"sender_id"` + MessageID string `json:"message_id,omitempty"` + + Mentioned bool `json:"mentioned,omitempty"` + + ReplyToMessageID string `json:"reply_to_message_id,omitempty"` + ReplyToSenderID string `json:"reply_to_sender_id,omitempty"` + + ReplyHandles map[string]string `json:"reply_handles,omitempty"` + Raw map[string]string `json:"raw,omitempty"` +} + type InboundMessage struct { - Channel string `json:"channel"` - SenderID string `json:"sender_id"` - Sender SenderInfo `json:"sender"` - ChatID string `json:"chat_id"` - Content string `json:"content"` - Media []string `json:"media,omitempty"` - Peer Peer `json:"peer"` // routing peer - MessageID string `json:"message_id,omitempty"` // platform message ID - MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope - SessionKey string `json:"session_key"` - Metadata map[string]string `json:"metadata,omitempty"` + Context InboundContext `json:"context"` + Sender SenderInfo `json:"sender"` + Content string `json:"content"` + Media []string `json:"media,omitempty"` + MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope + SessionKey string `json:"session_key"` + + // Convenience mirrors derived from Context for runtime consumers. + Channel string `json:"channel"` + SenderID string `json:"sender_id"` + ChatID string `json:"chat_id"` + MessageID string `json:"message_id,omitempty"` // platform message ID +} + +// OutboundScope captures the structured session scope associated with an +// outbound turn result without depending on the session package. +type OutboundScope struct { + Version int `json:"version,omitempty"` + AgentID string `json:"agent_id,omitempty"` + Channel string `json:"channel,omitempty"` + Account string `json:"account,omitempty"` + Dimensions []string `json:"dimensions,omitempty"` + Values map[string]string `json:"values,omitempty"` } type OutboundMessage struct { - Channel string `json:"channel"` - ChatID string `json:"chat_id"` - Content string `json:"content"` - ReplyToMessageID string `json:"reply_to_message_id,omitempty"` - Metadata map[string]string `json:"metadata,omitempty"` + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Context InboundContext `json:"context"` + AgentID string `json:"agent_id,omitempty"` + SessionKey string `json:"session_key,omitempty"` + Scope *OutboundScope `json:"scope,omitempty"` + Content string `json:"content"` + ReplyToMessageID string `json:"reply_to_message_id,omitempty"` } // MediaPart describes a single media attachment to send. @@ -48,9 +83,13 @@ type MediaPart struct { // OutboundMediaMessage carries media attachments from Agent to channels via the bus. type OutboundMediaMessage struct { - Channel string `json:"channel"` - ChatID string `json:"chat_id"` - Parts []MediaPart `json:"parts"` + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Context InboundContext `json:"context"` + AgentID string `json:"agent_id,omitempty"` + SessionKey string `json:"session_key,omitempty"` + Scope *OutboundScope `json:"scope,omitempty"` + Parts []MediaPart `json:"parts"` } // AudioChunk represents a chunk of streaming voice data. diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 876291186..3585fb075 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -260,12 +260,11 @@ func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool { return false } -func (c *BaseChannel) HandleMessage( +func (c *BaseChannel) HandleMessageWithContext( ctx context.Context, - peer bus.Peer, - messageID, senderID, chatID, content string, + deliveryChatID, content string, media []string, - metadata map[string]string, + inboundCtx bus.InboundContext, senderOpts ...bus.SenderInfo, ) { // Use SenderInfo-based allow check when available, else fall back to string @@ -273,6 +272,7 @@ func (c *BaseChannel) HandleMessage( if len(senderOpts) > 0 { sender = senderOpts[0] } + senderID := strings.TrimSpace(inboundCtx.SenderID) if sender.CanonicalID != "" || sender.PlatformID != "" { if !c.IsAllowedSender(sender) { return @@ -289,20 +289,28 @@ func (c *BaseChannel) HandleMessage( resolvedSenderID = sender.CanonicalID } - scope := BuildMediaScope(c.name, chatID, messageID) + if resolvedSenderID == "" { + resolvedSenderID = senderID + } + + inboundCtx.Channel = c.name + if inboundCtx.ChatID == "" { + inboundCtx.ChatID = deliveryChatID + } + if inboundCtx.SenderID == "" { + inboundCtx.SenderID = resolvedSenderID + } + + scope := BuildMediaScope(c.name, deliveryChatID, inboundCtx.MessageID) msg := bus.InboundMessage{ - Channel: c.name, - SenderID: resolvedSenderID, + Context: inboundCtx, Sender: sender, - ChatID: chatID, Content: content, Media: media, - Peer: peer, - MessageID: messageID, MediaScope: scope, - Metadata: metadata, } + msg = bus.NormalizeInboundMessage(msg) // Auto-trigger typing indicator, message reaction, and placeholder before publishing. // Each capability is independent — all three may fire for the same message. @@ -313,14 +321,14 @@ func (c *BaseChannel) HandleMessage( if c.owner != nil && c.placeholderRecorder != nil { // Typing if tc, ok := c.owner.(TypingCapable); ok { - if stop, err := tc.StartTyping(ctx, chatID); err == nil { - c.placeholderRecorder.RecordTypingStop(c.name, chatID, stop) + if stop, err := tc.StartTyping(ctx, deliveryChatID); err == nil { + c.placeholderRecorder.RecordTypingStop(c.name, deliveryChatID, stop) } } // Reaction - if rc, ok := c.owner.(ReactionCapable); ok && messageID != "" { - if undo, err := rc.ReactToMessage(ctx, chatID, messageID); err == nil { - c.placeholderRecorder.RecordReactionUndo(c.name, chatID, undo) + if rc, ok := c.owner.(ReactionCapable); ok && msg.MessageID != "" { + if undo, err := rc.ReactToMessage(ctx, deliveryChatID, msg.MessageID); err == nil { + c.placeholderRecorder.RecordReactionUndo(c.name, deliveryChatID, undo) } } // Placeholder — independent pipeline. @@ -329,8 +337,8 @@ func (c *BaseChannel) HandleMessage( // "Thinking…" only once the voice has been processed. if !audioAnnotationRe.MatchString(content) { if pc, ok := c.owner.(PlaceholderCapable); ok { - if phID, err := pc.SendPlaceholder(ctx, chatID); err == nil && phID != "" { - c.placeholderRecorder.RecordPlaceholder(c.name, chatID, phID) + if phID, err := pc.SendPlaceholder(ctx, deliveryChatID); err == nil && phID != "" { + c.placeholderRecorder.RecordPlaceholder(c.name, deliveryChatID, phID) } } } @@ -339,12 +347,24 @@ func (c *BaseChannel) HandleMessage( if err := c.bus.PublishInbound(ctx, msg); err != nil { logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{ "channel": c.name, - "chat_id": chatID, + "chat_id": deliveryChatID, "error": err.Error(), }) } } +// HandleInboundContext publishes a normalized inbound message using only the +// structured context. +func (c *BaseChannel) HandleInboundContext( + ctx context.Context, + deliveryChatID, content string, + media []string, + inboundCtx bus.InboundContext, + senderOpts ...bus.SenderInfo, +) { + c.HandleMessageWithContext(ctx, deliveryChatID, content, media, inboundCtx, senderOpts...) +} + func (c *BaseChannel) SetRunning(running bool) { c.running.Store(running) } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go index 6132b8bf9..04500f775 100644 --- a/pkg/channels/base_test.go +++ b/pkg/channels/base_test.go @@ -1,6 +1,7 @@ package channels import ( + "context" "testing" "github.com/sipeed/picoclaw/pkg/bus" @@ -263,3 +264,58 @@ func TestIsAllowedSender(t *testing.T) { }) } } + +func TestHandleInboundContext_PublishesNormalizedContext(t *testing.T) { + tests := []struct { + name string + inbound bus.InboundContext + wantChat string + wantSender string + }{ + { + name: "direct uses sender as peer", + inbound: bus.InboundContext{ + Channel: "test", + ChatID: "chat-1", + ChatType: "direct", + SenderID: "user-1", + MessageID: "msg-1", + }, + wantChat: "chat-1", + wantSender: "user-1", + }, + { + name: "group uses chat as peer", + inbound: bus.InboundContext{ + Channel: "test", + ChatID: "group-1", + ChatType: "group", + SenderID: "user-2", + MessageID: "msg-2", + }, + wantChat: "group-1", + wantSender: "user-2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + msgBus := bus.NewMessageBus() + defer msgBus.Close() + + ch := NewBaseChannel("test", nil, msgBus, nil) + ch.HandleInboundContext(context.Background(), tt.inbound.ChatID, "hello", nil, tt.inbound) + + msg := <-msgBus.InboundChan() + if msg.ChatID != tt.wantChat { + t.Fatalf("ChatID = %q, want %q", msg.ChatID, tt.wantChat) + } + if msg.SenderID != tt.wantSender { + t.Fatalf("SenderID = %q, want %q", msg.SenderID, tt.wantSender) + } + if msg.Context.ChatType != tt.inbound.ChatType { + t.Fatalf("ChatType = %q, want %q", msg.Context.ChatType, tt.inbound.ChatType) + } + }) + } +} diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index e7c3685f3..9cd461bc8 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -185,16 +185,15 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "session_webhook": data.SessionWebhook, } - var peer bus.Peer + var ( + chatType string + isMentioned bool + ) if data.ConversationType == "1" { - peerID := senderID - if peerID == "" { - peerID = chatID - } - peer = bus.Peer{Kind: "direct", ID: peerID} + chatType = "direct" } else { - peer = bus.Peer{Kind: "group", ID: data.ConversationId} - isMentioned := data.IsInAtList + chatType = "group" + isMentioned = data.IsInAtList if isMentioned { content = stripLeadingAtMentions(content) } @@ -232,8 +231,21 @@ func (c *DingTalkChannel) onChatBotMessageReceived( return nil, nil } - // Handle the message through the base channel - c.HandleMessage(ctx, peer, "", resolvedSenderID, chatID, content, nil, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "dingtalk", + ChatID: chatID, + ChatType: chatType, + SenderID: resolvedSenderID, + Mentioned: isMentioned, + Raw: metadata, + } + if data.SessionWebhook != "" { + inboundCtx.ReplyHandles = map[string]string{ + "session_webhook": data.SessionWebhook, + } + } + + c.HandleInboundContext(ctx, chatID, content, nil, inboundCtx, sender) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus diff --git a/pkg/channels/dingtalk/dingtalk_test.go b/pkg/channels/dingtalk/dingtalk_test.go index 50c99046f..6dfc44730 100644 --- a/pkg/channels/dingtalk/dingtalk_test.go +++ b/pkg/channels/dingtalk/dingtalk_test.go @@ -75,8 +75,8 @@ func TestOnChatBotMessageReceived_GroupMentionOnlyUsesIsInAtListAndStripsMention if inbound.ChatID != "group-abc" { t.Fatalf("chat_id=%q", inbound.ChatID) } - if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-abc" { - t.Fatalf("peer=%+v", inbound.Peer) + if inbound.Context.ChatType != "group" { + t.Fatalf("chat_type=%q", inbound.Context.ChatType) } if inbound.Content != "/help" { t.Fatalf("content=%q", inbound.Content) @@ -103,12 +103,15 @@ func TestOnChatBotMessageReceived_DirectFallbackSenderIDUsesConversationID(t *te if inbound.ChatID != "conv-direct-42" { t.Fatalf("chat_id=%q", inbound.ChatID) } - if inbound.Peer.Kind != "direct" || inbound.Peer.ID != "openid-user-42" { - t.Fatalf("peer=%+v", inbound.Peer) + if inbound.Context.ChatType != "direct" { + t.Fatalf("chat_type=%q", inbound.Context.ChatType) } - if inbound.SenderID != "dingtalk:openid-user-42" { + if inbound.SenderID != "openid-user-42" { t.Fatalf("sender_id=%q", inbound.SenderID) } + if inbound.Sender.CanonicalID != "dingtalk:openid-user-42" { + t.Fatalf("sender canonical_id=%q", inbound.Sender.CanonicalID) + } if _, ok := ch.sessionWebhooks.Load("conv-direct-42"); !ok { t.Fatal("expected session webhook keyed by conversation_id") diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 50d060fd8..28f7277d3 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -408,8 +408,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag // In guild (group) channels, apply unified group trigger filtering // DMs (GuildID is empty) always get a response + isMentioned := false if m.GuildID != "" { - isMentioned := false for _, mention := range m.Mentions { if mention.ID == c.botUserID { isMentioned = true @@ -506,14 +506,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag }) peerKind := "channel" - peerID := m.ChannelID if m.GuildID == "" { peerKind = "direct" - peerID = senderID } - peer := bus.Peer{Kind: peerKind, ID: peerID} - metadata := map[string]string{ "user_id": senderID, "username": m.Author.Username, @@ -522,8 +518,24 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), } + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + ChatID: m.ChannelID, + ChatType: peerKind, + SenderID: senderID, + MessageID: m.ID, + Mentioned: isMentioned, + Raw: metadata, + } + if m.GuildID != "" { + inboundCtx.SpaceID = m.GuildID + inboundCtx.SpaceType = "guild" + } + if m.MessageReference != nil { + inboundCtx.ReplyToMessageID = m.MessageReference.MessageID + } - c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender) + c.HandleInboundContext(c.ctx, m.ChannelID, content, mediaPaths, inboundCtx, sender) } // startTyping starts a continuous typing indicator loop for the given chatID. diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index ecb3da894..02ee47d69 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -445,17 +445,23 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. // Append media tags to content (like Telegram does) content = appendMediaTags(content, messageType, mediaRefs) + if content == "" { + content = "[empty message]" + } chatType := stringValue(message.ChatType) metadata := buildInboundMetadata(message, sender) - var peer bus.Peer + var ( + inboundChatType string + isMentioned bool + ) if chatType == "p2p" { - peer = bus.Peer{Kind: "direct", ID: senderID} + inboundChatType = "direct" } else { - peer = bus.Peer{Kind: "group", ID: chatID} + inboundChatType = "group" // Check if bot was mentioned - isMentioned := c.isBotMentioned(message) + isMentioned = c.isBotMentioned(message) // Strip mention placeholders from content before group trigger check if len(message.Mentions) > 0 { @@ -490,7 +496,21 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. "thread_id": stringValue(message.ThreadId), }) - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo) + inboundCtx := bus.InboundContext{ + Channel: "feishu", + ChatID: chatID, + ChatType: inboundChatType, + SenderID: senderID, + MessageID: messageID, + Mentioned: isMentioned, + Raw: metadata, + } + if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" { + inboundCtx.SpaceType = "tenant" + inboundCtx.SpaceID = *sender.TenantKey + } + + c.HandleInboundContext(ctx, chatID, content, mediaRefs, inboundCtx, senderInfo) return nil } diff --git a/pkg/channels/irc/handler.go b/pkg/channels/irc/handler.go index b92359da4..73df9c43c 100644 --- a/pkg/channels/irc/handler.go +++ b/pkg/channels/irc/handler.go @@ -51,14 +51,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { isDM := !strings.HasPrefix(target, "#") && !strings.HasPrefix(target, "&") var chatID string - var peer bus.Peer if isDM { chatID = nick - peer = bus.Peer{Kind: "direct", ID: nick} } else { chatID = target - peer = bus.Peer{Kind: "group", ID: target} } sender := bus.SenderInfo{ @@ -73,9 +70,11 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { return } + isMentioned := false + // For channel messages, check group trigger (mention detection) if !isDM { - isMentioned := isBotMentioned(content, currentNick) + isMentioned = isBotMentioned(content, currentNick) if isMentioned { content = stripBotMention(content, currentNick) } @@ -100,7 +99,21 @@ func (c *IRCChannel) onPrivmsg(conn *ircevent.Connection, e ircmsg.Message) { metadata["channel"] = target } - c.HandleMessage(c.ctx, peer, messageID, nick, chatID, content, nil, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "irc", + ChatID: chatID, + SenderID: nick, + MessageID: messageID, + Mentioned: isMentioned, + Raw: metadata, + } + if isDM { + inboundCtx.ChatType = "direct" + } else { + inboundCtx.ChatType = "group" + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } // nickMentionedAt returns the byte index where botNick is mentioned in content diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index c2515a5ac..760506a31 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -354,8 +354,9 @@ func (c *LINEChannel) processEvent(event lineEvent) { } // In group chats, apply unified group trigger filtering + isMentioned := false if isGroup { - isMentioned := c.isBotMentioned(msg) + isMentioned = c.isBotMentioned(msg) respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) if !respond { logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{ @@ -371,13 +372,6 @@ func (c *LINEChannel) processEvent(event lineEvent) { "source_type": event.Source.Type, } - var peer bus.Peer - if isGroup { - peer = bus.Peer{Kind: "group", ID: chatID} - } else { - peer = bus.Peer{Kind: "direct", ID: senderID} - } - logger.DebugCF("line", "Received message", map[string]any{ "sender_id": senderID, "chat_id": chatID, @@ -396,7 +390,25 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + ChatID: chatID, + ChatType: map[bool]string{true: "group", false: "direct"}[isGroup], + SenderID: senderID, + MessageID: msg.ID, + Mentioned: isMentioned, + Raw: metadata, + } + if event.ReplyToken != "" { + inboundCtx.ReplyHandles = map[string]string{ + "reply_token": event.ReplyToken, + } + if msg.QuoteToken != "" { + inboundCtx.ReplyHandles["quote_token"] = msg.QuoteToken + } + } + + c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender) } // isBotMentioned checks if the bot is mentioned in the message. diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index c9bf4d25e..b81206c59 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -200,17 +200,15 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { return } - c.HandleMessage( - c.ctx, - bus.Peer{Kind: "channel", ID: "default"}, - "", - senderID, - chatID, - content, - []string{}, - metadata, - sender, - ) + inboundCtx := bus.InboundContext{ + Channel: "maixcam", + ChatID: chatID, + ChatType: "channel", + SenderID: senderID, + Raw: metadata, + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 5d5e6f9f0..4d8e47c0f 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -98,6 +98,22 @@ type asyncTask struct { cancel context.CancelFunc } +func outboundMessageChannel(msg bus.OutboundMessage) string { + return msg.Context.Channel +} + +func outboundMessageChatID(msg bus.OutboundMessage) string { + return msg.ChatID +} + +func outboundMediaChannel(msg bus.OutboundMediaMessage) string { + return msg.Context.Channel +} + +func outboundMediaChatID(msg bus.OutboundMediaMessage) string { + return msg.ChatID +} + // RecordPlaceholder registers a placeholder message for later editing. // Implements PlaceholderRecorder. func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { @@ -161,7 +177,8 @@ func (m *Manager) RecordReactionUndo(channel, chatID string, undo func()) { // preSend handles typing stop, reaction undo, and placeholder editing before sending a message. // Returns the delivered message IDs and true when delivery completed before a normal Send. func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) ([]string, bool) { - key := name + ":" + msg.ChatID + chatID := outboundMessageChatID(msg) + key := name + ":" + chatID // 1. Stop typing if v, loaded := m.typingStops.LoadAndDelete(key); loaded { @@ -183,9 +200,9 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess if entry, ok := v.(placeholderEntry); ok && entry.id != "" { // Prefer deleting the placeholder (cleaner UX than editing to same content) if deleter, ok := ch.(MessageDeleter); ok { - deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + deleter.DeleteMessage(ctx, chatID, entry.id) // best effort } else if editor, ok := ch.(MessageEditor); ok { - editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content) // fallback + editor.EditMessage(ctx, chatID, entry.id, msg.Content) // fallback } } } @@ -196,7 +213,7 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { if editor, ok := ch.(MessageEditor); ok { - if err := editor.EditMessage(ctx, msg.ChatID, entry.id, msg.Content); err == nil { + if err := editor.EditMessage(ctx, chatID, entry.id, msg.Content); err == nil { return []string{entry.id}, true } // edit failed → fall through to normal Send @@ -212,7 +229,8 @@ func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMess // delivery never edits the placeholder because there is no text payload to // replace it with; it only attempts to delete the placeholder when possible. func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.OutboundMediaMessage, ch Channel) { - key := name + ":" + msg.ChatID + chatID := outboundMediaChatID(msg) + key := name + ":" + chatID // 1. Stop typing if v, loaded := m.typingStops.LoadAndDelete(key); loaded { @@ -235,7 +253,7 @@ func (m *Manager) preSendMedia(ctx context.Context, name string, msg bus.Outboun if v, loaded := m.placeholders.LoadAndDelete(key); loaded { if entry, ok := v.(placeholderEntry); ok && entry.id != "" { if deleter, ok := ch.(MessageDeleter); ok { - deleter.DeleteMessage(ctx, msg.ChatID, entry.id) // best effort + deleter.DeleteMessage(ctx, chatID, entry.id) // best effort } } } @@ -820,7 +838,7 @@ func (m *Manager) sendWithRetry( // All retries exhausted or permanent failure logger.ErrorCF("channels", "Send failed", map[string]any{ "channel": name, - "chat_id": msg.ChatID, + "chat_id": outboundMessageChatID(msg), "error": lastErr.Error(), "retries": maxRetries, }) @@ -882,7 +900,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { dispatchLoop( ctx, m, m.bus.OutboundChan(), - func(msg bus.OutboundMessage) string { return msg.Channel }, + func(msg bus.OutboundMessage) string { return outboundMessageChannel(msg) }, func(ctx context.Context, w *channelWorker, msg bus.OutboundMessage) bool { select { case w.queue <- msg: @@ -902,7 +920,7 @@ func (m *Manager) dispatchOutboundMedia(ctx context.Context) { dispatchLoop( ctx, m, m.bus.OutboundMediaChan(), - func(msg bus.OutboundMediaMessage) string { return msg.Channel }, + func(msg bus.OutboundMediaMessage) string { return outboundMediaChannel(msg) }, func(ctx context.Context, w *channelWorker, msg bus.OutboundMediaMessage) bool { select { case w.mediaQueue <- msg: @@ -1001,7 +1019,7 @@ func (m *Manager) sendMediaWithRetry( // All retries exhausted or permanent failure logger.ErrorCF("channels", "SendMedia failed", map[string]any{ "channel": name, - "chat_id": msg.ChatID, + "chat_id": outboundMediaChatID(msg), "error": lastErr.Error(), "retries": maxRetries, }) @@ -1200,16 +1218,19 @@ func (m *Manager) UnregisterChannel(name string) { // delivered (or all retries are exhausted), which preserves ordering when // a subsequent operation depends on the message having been sent. func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) error { + msg = bus.NormalizeOutboundMessage(msg) + channelName := outboundMessageChannel(msg) + m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { - return fmt.Errorf("channel %s not found", msg.Channel) + return fmt.Errorf("channel %s not found", channelName) } if !wExists || w == nil { - return fmt.Errorf("channel %s has no active worker", msg.Channel) + return fmt.Errorf("channel %s has no active worker", channelName) } maxLen := 0 @@ -1220,10 +1241,10 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro for _, chunk := range SplitMessage(msg.Content, maxLen) { chunkMsg := msg chunkMsg.Content = chunk - m.sendWithRetry(ctx, msg.Channel, w, chunkMsg) + m.sendWithRetry(ctx, channelName, w, chunkMsg) } } else { - m.sendWithRetry(ctx, msg.Channel, w, msg) + m.sendWithRetry(ctx, channelName, w, msg) } return nil } @@ -1233,19 +1254,22 @@ func (m *Manager) SendMessage(ctx context.Context, msg bus.OutboundMessage) erro // retries are exhausted), which preserves ordering when later agent behavior // depends on actual media delivery. func (m *Manager) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + msg = bus.NormalizeOutboundMediaMessage(msg) + channelName := outboundMediaChannel(msg) + m.mu.RLock() - _, exists := m.channels[msg.Channel] - w, wExists := m.workers[msg.Channel] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { - return fmt.Errorf("channel %s not found", msg.Channel) + return fmt.Errorf("channel %s not found", channelName) } if !wExists || w == nil { - return fmt.Errorf("channel %s has no active worker", msg.Channel) + return fmt.Errorf("channel %s has no active worker", channelName) } - _, err := m.sendMediaWithRetry(ctx, msg.Channel, w, msg) + _, err := m.sendMediaWithRetry(ctx, channelName, w, msg) return err } @@ -1260,10 +1284,10 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten } msg := bus.OutboundMessage{ - Channel: channelName, - ChatID: chatID, + Context: bus.NewOutboundContext(channelName, chatID, ""), Content: content, } + msg = bus.NormalizeOutboundMessage(msg) if wExists && w != nil { select { diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 6b261b2dd..881993d9c 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -175,11 +175,11 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) { pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second) defer pubCancel() - if err := m.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ + if err := m.bus.PublishOutbound(pubCtx, testOutboundMessage(bus.OutboundMessage{ Channel: "good", ChatID: "chat-1", Content: "hello", - }); err != nil { + })); err != nil { t.Fatalf("PublishOutbound() error = %v", err) } @@ -197,6 +197,20 @@ func TestStartAll_PartialFailure_StartsSuccessfulWorkers(t *testing.T) { } } +func testOutboundMessage(msg bus.OutboundMessage) bus.OutboundMessage { + if msg.Context.Channel == "" && msg.Context.ChatID == "" { + msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, msg.ReplyToMessageID) + } + return bus.NormalizeOutboundMessage(msg) +} + +func testOutboundMediaMessage(msg bus.OutboundMediaMessage) bus.OutboundMediaMessage { + if msg.Context.Channel == "" && msg.Context.ChatID == "" { + msg.Context = bus.NewOutboundContext(msg.Channel, msg.ChatID, "") + } + return bus.NormalizeOutboundMediaMessage(msg) +} + func TestSendWithRetry_Success(t *testing.T) { m := newTestManager() var callCount int @@ -212,7 +226,7 @@ func TestSendWithRetry_Success(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -239,7 +253,7 @@ func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -263,7 +277,7 @@ func TestSendWithRetry_PermanentFailure(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -287,7 +301,7 @@ func TestSendWithRetry_NotRunning(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -314,7 +328,7 @@ func TestSendWithRetry_RateLimitRetry(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) start := time.Now() m.sendWithRetry(ctx, "test", w, msg) @@ -344,7 +358,7 @@ func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -370,11 +384,11 @@ func TestSendMedia_Success(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err != nil { t.Fatalf("SendMedia() error = %v", err) } @@ -397,11 +411,11 @@ func TestSendMedia_PropagatesFailure(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err == nil { t.Fatal("expected SendMedia to return error") } @@ -424,11 +438,11 @@ func TestSendMedia_UnsupportedChannelReturnsError(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err == nil { t.Fatal("expected SendMedia to return error for unsupported channel") } @@ -454,11 +468,11 @@ func TestSendMedia_DeletesPlaceholderBeforeSending(t *testing.T) { m.workers["test"] = w m.RecordPlaceholder("test", "chat1", "placeholder-1") - err := m.SendMedia(context.Background(), bus.OutboundMediaMessage{ + err := m.SendMedia(context.Background(), testOutboundMediaMessage(bus.OutboundMediaMessage{ Channel: "test", ChatID: "chat1", Parts: []bus.MediaPart{{Ref: "media://abc"}}, - }) + })) if err != nil { t.Fatalf("SendMedia() error = %v", err) } @@ -491,7 +505,7 @@ func TestSendWithRetry_UnknownError(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) m.sendWithRetry(ctx, "test", w, msg) @@ -515,7 +529,7 @@ func TestSendWithRetry_ContextCancelled(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) // Cancel context after first Send attempt returns ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error { @@ -561,7 +575,7 @@ func TestWorkerRateLimiter(t *testing.T) { // Enqueue 4 messages for i := range 4 { - w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)} + w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)}) } // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin) @@ -637,7 +651,7 @@ func TestRunWorker_MessageSplitting(t *testing.T) { go m.runWorker(ctx, "test", w) // Send a message that should be split - w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"} + w.queue <- testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"}) time.Sleep(100 * time.Millisecond) @@ -678,7 +692,7 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) { } ctx := context.Background() - msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"}) start := time.Now() m.sendWithRetry(ctx, "test", w, msg) @@ -738,7 +752,7 @@ func TestPreSend_PlaceholderEditSuccess(t *testing.T) { // Register placeholder m.RecordPlaceholder("test", "123", "456") - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if !edited { @@ -768,7 +782,7 @@ func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { m.RecordPlaceholder("test", "123", "456") - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if edited { @@ -827,7 +841,7 @@ func TestPreSend_TypingStopCalled(t *testing.T) { stopCalled = true }) - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) m.preSend(context.Background(), "test", msg, ch) if !stopCalled { @@ -844,7 +858,7 @@ func TestPreSend_NoRegisteredState(t *testing.T) { }, } - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if edited { @@ -874,7 +888,7 @@ func TestPreSend_TypingAndPlaceholder(t *testing.T) { }) m.RecordPlaceholder("test", "123", "456") - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if !stopCalled { @@ -938,7 +952,7 @@ func TestRecordTypingStop_ReplacesExistingStop(t *testing.T) { t.Fatalf("expected replacement typing stop to stay active until preSend, got %d calls", newStopCalls) } - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) m.preSend(context.Background(), "test", msg, &mockChannel{}) if newStopCalls != 1 { @@ -972,7 +986,7 @@ func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { limiter: rate.NewLimiter(rate.Inf, 1), } - msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"}) m.sendWithRetry(context.Background(), "test", w, msg) if sendCalled { @@ -1135,7 +1149,7 @@ func TestPreSendStillWorksWithWrappedTypes(t *testing.T) { }) m.RecordPlaceholder("test", "chat1", "ph_id") - msg := bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"} + msg := testOutboundMessage(bus.OutboundMessage{Channel: "test", ChatID: "chat1", Content: "response"}) _, edited := m.preSend(context.Background(), "test", msg, ch) if !stopCalled { @@ -1238,11 +1252,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) { // Transcription feedback arrives first — it should consume the placeholder // and be delivered via EditMessage, not Send. - msgTranscript := bus.OutboundMessage{ + msgTranscript := testOutboundMessage(bus.OutboundMessage{ Channel: "mock", ChatID: "chat-1", Content: "Transcript: hello", - } + }) mgr.sendWithRetry(ctx, "mock", worker, msgTranscript) if mockCh.editedMessages != 1 { @@ -1258,11 +1272,11 @@ func TestManager_PlaceholderConsumedByResponse(t *testing.T) { } // Final LLM response arrives — no placeholder left, so it goes through Send - msgFinal := bus.OutboundMessage{ + msgFinal := testOutboundMessage(bus.OutboundMessage{ Channel: "mock", ChatID: "chat-1", Content: "Final Answer", - } + }) mgr.sendWithRetry(ctx, "mock", worker, msgFinal) if len(mockCh.sentMessages) != 1 { @@ -1288,12 +1302,12 @@ func TestSendMessage_Synchronous(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "hello world", ReplyToMessageID: "msg-456", - } + }) err := m.SendMessage(context.Background(), msg) if err != nil { @@ -1315,11 +1329,11 @@ func TestSendMessage_Synchronous(t *testing.T) { func TestSendMessage_UnknownChannel(t *testing.T) { m := newTestManager() - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "nonexistent", ChatID: "123", Content: "hello", - } + }) err := m.SendMessage(context.Background(), msg) if err == nil { @@ -1336,11 +1350,11 @@ func TestSendMessage_NoWorker(t *testing.T) { m.channels["test"] = ch // No worker registered - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "hello", - } + }) err := m.SendMessage(context.Background(), msg) if err == nil { @@ -1369,11 +1383,11 @@ func TestSendMessage_WithRetry(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "retry me", - } + }) err := m.SendMessage(context.Background(), msg) if err != nil { @@ -1385,6 +1399,46 @@ func TestSendMessage_WithRetry(t *testing.T) { } } +func TestSendMessage_ContextOnlyUsesContextAddressing(t *testing.T) { + m := newTestManager() + + var received []bus.OutboundMessage + ch := &mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + received = append(received, msg) + return nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + msg := testOutboundMessage(bus.OutboundMessage{ + Context: bus.NewOutboundContext("test", "123", "msg-9"), + Content: "hello", + }) + + if err := m.SendMessage(context.Background(), msg); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(received) != 1 { + t.Fatalf("expected 1 message sent, got %d", len(received)) + } + if received[0].Channel != "test" || received[0].ChatID != "123" { + t.Fatalf("expected mirrored legacy address, got %+v", received[0]) + } + if received[0].Context.Channel != "test" || received[0].Context.ChatID != "123" { + t.Fatalf("expected context address to be preserved, got %+v", received[0].Context) + } + if received[0].ReplyToMessageID != "msg-9" { + t.Fatalf("expected reply_to_message_id msg-9, got %q", received[0].ReplyToMessageID) + } +} + func TestSendMessage_WithSplitting(t *testing.T) { m := newTestManager() @@ -1406,11 +1460,11 @@ func TestSendMessage_WithSplitting(t *testing.T) { m.channels["test"] = ch m.workers["test"] = w - msg := bus.OutboundMessage{ + msg := testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "123", Content: "hello world", - } + }) err := m.SendMessage(context.Background(), msg) if err != nil { @@ -1422,6 +1476,43 @@ func TestSendMessage_WithSplitting(t *testing.T) { } } +func TestSendMedia_ContextOnlyUsesContextAddressing(t *testing.T) { + m := newTestManager() + + var received []bus.OutboundMediaMessage + ch := &mockMediaChannel{ + sendMediaFn: func(_ context.Context, msg bus.OutboundMediaMessage) ([]string, error) { + received = append(received, msg) + return nil, nil + }, + } + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + m.channels["test"] = ch + m.workers["test"] = w + + msg := testOutboundMediaMessage(bus.OutboundMediaMessage{ + Context: bus.NewOutboundContext("test", "media-chat", ""), + Parts: []bus.MediaPart{{Type: "image", Ref: "media://1"}}, + }) + + if err := m.SendMedia(context.Background(), msg); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(received) != 1 { + t.Fatalf("expected 1 media message sent, got %d", len(received)) + } + if received[0].Channel != "test" || received[0].ChatID != "media-chat" { + t.Fatalf("expected mirrored legacy media address, got %+v", received[0]) + } + if received[0].Context.Channel != "test" || received[0].Context.ChatID != "media-chat" { + t.Fatalf("expected media context address to be preserved, got %+v", received[0].Context) + } +} + func TestSendMessage_PreservesOrdering(t *testing.T) { m := newTestManager() @@ -1441,12 +1532,12 @@ func TestSendMessage_PreservesOrdering(t *testing.T) { m.workers["test"] = w // Send two messages sequentially — they must arrive in order - _ = m.SendMessage(context.Background(), bus.OutboundMessage{ + _ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "1", Content: "first", - }) - _ = m.SendMessage(context.Background(), bus.OutboundMessage{ + })) + _ = m.SendMessage(context.Background(), testOutboundMessage(bus.OutboundMessage{ Channel: "test", ChatID: "1", Content: "second", - }) + })) if len(order) != 2 { t.Fatalf("expected 2 messages, got %d", len(order)) diff --git a/pkg/channels/matrix/matrix.go b/pkg/channels/matrix/matrix.go index a4061c409..40e1b0a36 100644 --- a/pkg/channels/matrix/matrix.go +++ b/pkg/channels/matrix/matrix.go @@ -739,10 +739,8 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event } peerKind := "direct" - peerID := senderID if isGroup { peerKind = "group" - peerID = roomID } metadata := map[string]string{ @@ -755,17 +753,19 @@ func (c *MatrixChannel) handleMessageEvent(ctx context.Context, evt *event.Event metadata["reply_to_msg_id"] = replyTo.String() } - c.HandleMessage( - c.baseContext(), - bus.Peer{Kind: peerKind, ID: peerID}, - evt.ID.String(), - senderID, - roomID, - content, - mediaPaths, - metadata, - sender, - ) + inboundCtx := bus.InboundContext{ + Channel: "matrix", + ChatID: roomID, + ChatType: peerKind, + SenderID: senderID, + MessageID: evt.ID.String(), + Raw: metadata, + } + if replyTo := msgEvt.GetRelatesTo().GetReplyTo(); replyTo != "" { + inboundCtx.ReplyToMessageID = replyTo.String() + } + + c.HandleInboundContext(c.baseContext(), roomID, content, mediaPaths, inboundCtx, sender) } // decryptEvent decrypts an encrypted event and returns the decrypted message event content. diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index f576bf1d0..f0d0a890f 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -995,8 +995,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { senderID := strconv.FormatInt(userID, 10) var chatID string - - var peer bus.Peer + var contextChatID string + var contextChatType string metadata := map[string]string{} @@ -1007,12 +1007,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { switch raw.MessageType { case "private": chatID = "private:" + senderID - peer = bus.Peer{Kind: "direct", ID: senderID} + contextChatID = senderID + contextChatType = "direct" case "group": groupIDStr := strconv.FormatInt(groupID, 10) chatID = "group:" + groupIDStr - peer = bus.Peer{Kind: "group", ID: groupIDStr} + contextChatID = groupIDStr + contextChatType = "group" metadata["group_id"] = groupIDStr senderUserID, _ := parseJSONInt64(sender.UserID) @@ -1076,7 +1078,18 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { return } - c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo) + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + ChatID: contextChatID, + ChatType: contextChatType, + SenderID: senderID, + MessageID: messageID, + Mentioned: isBotMentioned, + ReplyToMessageID: parsed.ReplyTo, + Raw: metadata, + } + + c.HandleInboundContext(c.ctx, chatID, content, parsed.Media, inboundCtx, senderInfo) } func (c *OneBotChannel) isDuplicate(messageID string) bool { diff --git a/pkg/channels/pico/client.go b/pkg/channels/pico/client.go index cdfaa9e44..009900e01 100644 --- a/pkg/channels/pico/client.go +++ b/pkg/channels/pico/client.go @@ -259,8 +259,6 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) { chatID := "pico_client:" + sessionID senderID := "pico-remote" - peer := bus.Peer{Kind: "direct", ID: chatID} - sender := bus.SenderInfo{ Platform: "pico_client", PlatformID: senderID, @@ -271,10 +269,19 @@ func (c *PicoClientChannel) handleServerMessage(pc *picoConn, msg PicoMessage) { return } - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, map[string]string{ - "platform": "pico_client", - "session_id": sessionID, - }, sender) + inboundCtx := bus.InboundContext{ + Channel: "pico_client", + ChatID: chatID, + ChatType: "direct", + SenderID: senderID, + MessageID: msg.ID, + Raw: map[string]string{ + "platform": "pico_client", + "session_id": sessionID, + }, + } + + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, sender) } // Send sends a message to the remote server. diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index c22cd34d3..f998712c8 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -39,11 +39,11 @@ var allowedInlineImageMIMETypes = map[string]struct{}{ "image/bmp": {}, } -func outboundMessageIsThought(metadata map[string]string) bool { - if len(metadata) == 0 { +func outboundMessageIsThought(msg bus.OutboundMessage) bool { + if len(msg.Context.Raw) == 0 { return false } - return strings.EqualFold(strings.TrimSpace(metadata["message_kind"]), MessageKindThought) + return strings.EqualFold(strings.TrimSpace(msg.Context.Raw["message_kind"]), MessageKindThought) } // writeJSON sends a JSON message to the connection with write locking. @@ -260,7 +260,7 @@ func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]stri if !c.IsRunning() { return nil, channels.ErrNotRunning } - isThought := outboundMessageIsThought(msg.Metadata) + isThought := outboundMessageIsThought(msg) outMsg := newMessage(TypeMessageCreate, map[string]any{ PayloadKeyContent: msg.Content, @@ -578,8 +578,6 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { chatID := "pico:" + sessionID senderID := "pico-user" - peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} - metadata := map[string]string{ "platform": "pico", "session_id": sessionID, @@ -602,7 +600,16 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { return } - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, media, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "pico", + ChatID: chatID, + ChatType: "direct", + SenderID: senderID, + MessageID: msg.ID, + Raw: metadata, + } + + c.HandleInboundContext(c.ctx, chatID, content, media, inboundCtx, sender) } // truncate truncates a string to maxLen runes. diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index e21ff2951..71cba5548 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -590,12 +590,22 @@ func qqFileType(partType string) uint64 { } func (c *QQChannel) maxBase64FileSizeBytes() int64 { + if c.config == nil { + return 0 + } if c.config.MaxBase64FileSizeMiB <= 0 { return 0 } return c.config.MaxBase64FileSizeMiB * bytesPerMiB } +func (c *QQChannel) accountID() string { + if c.config == nil { + return "" + } + return c.config.AppID +} + // handleC2CMessage handles QQ private messages. func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error { @@ -649,17 +659,17 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { metadata := map[string]string{ "account_id": senderID, } + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + Account: c.accountID(), + ChatID: senderID, + ChatType: "direct", + SenderID: senderID, + MessageID: data.ID, + Raw: metadata, + } - c.HandleMessage(c.ctx, - bus.Peer{Kind: "direct", ID: senderID}, - data.ID, - senderID, - senderID, - content, - mediaPaths, - metadata, - sender, - ) + c.HandleInboundContext(c.ctx, senderID, content, mediaPaths, inboundCtx, sender) return nil } @@ -727,17 +737,18 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { "account_id": senderID, "group_id": data.GroupID, } + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + Account: c.accountID(), + ChatID: data.GroupID, + ChatType: "group", + SenderID: senderID, + MessageID: data.ID, + Mentioned: true, + Raw: metadata, + } - c.HandleMessage(c.ctx, - bus.Peer{Kind: "group", ID: data.GroupID}, - data.ID, - senderID, - data.GroupID, - content, - mediaPaths, - metadata, - sender, - ) + c.HandleInboundContext(c.ctx, data.GroupID, content, mediaPaths, inboundCtx, sender) return nil } diff --git a/pkg/channels/qq/qq_test.go b/pkg/channels/qq/qq_test.go index c3cac1eba..2ab03ab54 100644 --- a/pkg/channels/qq/qq_test.go +++ b/pkg/channels/qq/qq_test.go @@ -54,8 +54,8 @@ func TestHandleC2CMessage_IncludesAccountIDMetadata(t *testing.T) { if !ok { t.Fatal("expected inbound message") } - if inbound.Metadata["account_id"] != "7750283E123456" { - t.Fatalf("account_id metadata = %q, want %q", inbound.Metadata["account_id"], "7750283E123456") + if inbound.Context.Raw["account_id"] != "7750283E123456" { + t.Fatalf("account_id raw = %q, want %q", inbound.Context.Raw["account_id"], "7750283E123456") } return } @@ -165,8 +165,8 @@ func TestHandleGroupATMessage_AttachmentOnlyPublishesMedia(t *testing.T) { if !strings.HasPrefix(inbound.Media[0], "media://") { t.Fatalf("inbound.Media[0] = %q, want media:// ref", inbound.Media[0]) } - if inbound.Peer.Kind != "group" || inbound.Peer.ID != "group-1" { - t.Fatalf("inbound.Peer = %+v, want group/group-1", inbound.Peer) + if inbound.Context.ChatType != "group" { + t.Fatalf("inbound.Context.ChatType = %q, want group", inbound.Context.ChatType) } } diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 579c97556..19e7b737c 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -117,7 +117,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str return nil, channels.ErrNotRunning } - channelID, threadTS := parseSlackChatID(msg.ChatID) + deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget(msg.ChatID, &msg.Context) if channelID == "" { return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) } @@ -139,7 +139,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str return nil, fmt.Errorf("slack send: %w", channels.ErrTemporary) } - if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + if ref, ok := c.pendingAcks.LoadAndDelete(deliveryChatID); ok { msgRef := ref.(slackMessageRef) c.api.AddReaction("white_check_mark", slack.ItemRef{ Channel: msgRef.ChannelID, @@ -161,7 +161,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa return nil, channels.ErrNotRunning } - channelID, _ := parseSlackChatID(msg.ChatID) + _, channelID, threadTS := resolveSlackMediaOutboundTarget(msg.ChatID, &msg.Context) if channelID == "" { return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) } @@ -192,10 +192,11 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa } _, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{ - Channel: channelID, - File: localPath, - Filename: filename, - Title: title, + Channel: channelID, + ThreadTimestamp: threadTS, + File: localPath, + Filename: filename, + Title: title, }) if err != nil { logger.ErrorCF("slack", "Failed to upload media", map[string]any{ @@ -360,14 +361,10 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } peerKind := "channel" - peerID := channelID if strings.HasPrefix(channelID, "D") { peerKind = "direct" - peerID = senderID } - peer := bus.Peer{Kind: peerKind, ID: peerID} - metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, @@ -383,7 +380,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "has_thread": threadTS != "", }) - c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + Account: c.teamID, + ChatID: channelID, + ChatType: peerKind, + SenderID: senderID, + MessageID: messageTS, + SpaceID: c.teamID, + SpaceType: "workspace", + Raw: metadata, + } + if threadTS != "" { + inboundCtx.TopicID = threadTS + } + + c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { @@ -431,14 +443,10 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { } mentionPeerKind := "channel" - mentionPeerID := channelID if strings.HasPrefix(channelID, "D") { mentionPeerKind = "direct" - mentionPeerID = senderID } - mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID} - metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, @@ -447,8 +455,21 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { "is_mention": "true", "team_id": c.teamID, } + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + Account: c.teamID, + ChatID: channelID, + ChatType: mentionPeerKind, + TopicID: threadTS, + SenderID: senderID, + MessageID: messageTS, + SpaceID: c.teamID, + SpaceType: "workspace", + Mentioned: true, + Raw: metadata, + } - c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender) + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, mentionSender) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -495,18 +516,22 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "command": cmd.Command, "text": utils.Truncate(content, 50), }) + peerKind := "channel" + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + } + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + Account: c.teamID, + ChatID: channelID, + ChatType: peerKind, + SenderID: senderID, + SpaceID: c.teamID, + SpaceType: "workspace", + Raw: metadata, + } - c.HandleMessage( - c.ctx, - bus.Peer{Kind: "channel", ID: channelID}, - "", - senderID, - chatID, - content, - nil, - metadata, - cmdSender, - ) + c.HandleInboundContext(c.ctx, chatID, content, nil, inboundCtx, cmdSender) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { @@ -541,3 +566,33 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) { } return channelID, threadTS } + +func resolveSlackOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) { + deliveryChatID := strings.TrimSpace(chatID) + if deliveryChatID == "" && outboundCtx != nil { + deliveryChatID = strings.TrimSpace(outboundCtx.ChatID) + } + channelID, threadTS := parseSlackChatID(deliveryChatID) + if threadTS == "" && outboundCtx != nil { + threadTS = strings.TrimSpace(outboundCtx.TopicID) + if threadTS != "" && channelID != "" { + deliveryChatID = channelID + "/" + threadTS + } + } + return deliveryChatID, channelID, threadTS +} + +func resolveSlackMediaOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) { + deliveryChatID := strings.TrimSpace(chatID) + if deliveryChatID == "" && outboundCtx != nil { + deliveryChatID = strings.TrimSpace(outboundCtx.ChatID) + } + channelID, threadTS := parseSlackChatID(deliveryChatID) + if threadTS == "" && outboundCtx != nil { + threadTS = strings.TrimSpace(outboundCtx.TopicID) + if threadTS != "" && channelID != "" { + deliveryChatID = channelID + "/" + threadTS + } + } + return deliveryChatID, channelID, threadTS +} diff --git a/pkg/channels/slack/slack_test.go b/pkg/channels/slack/slack_test.go index e4629efb3..a72521d67 100644 --- a/pkg/channels/slack/slack_test.go +++ b/pkg/channels/slack/slack_test.go @@ -53,6 +53,24 @@ func TestParseSlackChatID(t *testing.T) { } } +func TestResolveSlackOutboundTarget_PrefersContextTopicID(t *testing.T) { + deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget("C123456", &bus.InboundContext{ + Channel: "slack", + ChatID: "C123456", + TopicID: "1234567890.123456", + }) + + if deliveryChatID != "C123456/1234567890.123456" { + t.Fatalf("deliveryChatID = %q, want %q", deliveryChatID, "C123456/1234567890.123456") + } + if channelID != "C123456" { + t.Fatalf("channelID = %q, want %q", channelID, "C123456") + } + if threadTS != "1234567890.123456" { + t.Fatalf("threadTS = %q, want %q", threadTS, "1234567890.123456") + } +} + func TestStripBotMention(t *testing.T) { ch := &SlackChannel{botUserID: "U12345BOT"} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index ae0291f09..2a9cfe4ae 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -182,7 +182,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([] useMarkdownV2 := c.tgCfg.UseMarkdownV2 - chatID, threadID, err := parseTelegramChatID(msg.ChatID) + chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context) if err != nil { return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } @@ -469,7 +469,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe return nil, channels.ErrNotRunning } - chatID, threadID, err := parseTelegramChatID(msg.ChatID) + chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context) if err != nil { return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } @@ -697,8 +697,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes } // In group chats, apply unified group trigger filtering + isMentioned := false if message.Chat.Type != "private" { - isMentioned := c.isBotMentioned(message) + isMentioned = c.isBotMentioned(message) if isMentioned { content = c.stripBotMention(content) } @@ -744,13 +745,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes }) peerKind := "direct" - peerID := fmt.Sprintf("%d", user.ID) if message.Chat.Type != "private" { peerKind = "group" - peerID = compositeChatID } - - peer := bus.Peer{Kind: peerKind, ID: peerID} messageID := fmt.Sprintf("%d", message.MessageID) metadata := map[string]string{ @@ -759,24 +756,29 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - if message.ReplyToMessage != nil { - metadata["reply_to_message_id"] = fmt.Sprintf("%d", message.ReplyToMessage.MessageID) - } - // Set parent_peer metadata for per-topic agent binding. + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + ChatID: fmt.Sprintf("%d", chatID), + ChatType: peerKind, + SenderID: platformID, + MessageID: messageID, + Mentioned: isMentioned, + Raw: metadata, + } if message.Chat.IsForum && threadID != 0 { - metadata["parent_peer_kind"] = "topic" - metadata["parent_peer_id"] = fmt.Sprintf("%d", threadID) + inboundCtx.TopicID = fmt.Sprintf("%d", threadID) + } + if message.ReplyToMessage != nil { + inboundCtx.ReplyToMessageID = fmt.Sprintf("%d", message.ReplyToMessage.MessageID) } - c.HandleMessage(c.ctx, - peer, - messageID, - platformID, + c.HandleMessageWithContext( + c.ctx, compositeChatID, content, mediaPaths, - metadata, + inboundCtx, sender, ) return nil @@ -964,6 +966,28 @@ func parseTelegramChatID(chatID string) (int64, int, error) { return cid, tid, nil } +func resolveTelegramOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (int64, int, error) { + targetChatID := strings.TrimSpace(chatID) + if targetChatID == "" && outboundCtx != nil { + targetChatID = strings.TrimSpace(outboundCtx.ChatID) + } + resolvedChatID, resolvedThreadID, err := parseTelegramChatID(targetChatID) + if err != nil { + return 0, 0, err + } + if resolvedThreadID != 0 || outboundCtx == nil { + return resolvedChatID, resolvedThreadID, nil + } + topicID := strings.TrimSpace(outboundCtx.TopicID) + if topicID == "" { + return resolvedChatID, resolvedThreadID, nil + } + if threadID, convErr := strconv.Atoi(topicID); convErr == nil { + return resolvedChatID, threadID, nil + } + return resolvedChatID, resolvedThreadID, nil +} + func logParseFailed(err error, useMarkdownV2 bool) { parsingName := "HTML" if useMarkdownV2 { diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index ddf890e71..3d147b337 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -528,6 +528,38 @@ func TestSend_WithForumThreadID(t *testing.T) { assert.Len(t, caller.calls, 1) } +func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "-1001234567890", + Content: "Hello from topic context", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + TopicID: "42", + }, + }) + + require.NoError(t, err) + require.Len(t, caller.calls, 1) + + var params struct { + ChatID int64 `json:"chat_id"` + MessageThreadID int `json:"message_thread_id"` + Text string `json:"text"` + } + require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, ¶ms)) + assert.Equal(t, int64(-1001234567890), params.ChatID) + assert.Equal(t, 42, params.MessageThreadID) + assert.Equal(t, "Hello from topic context", params.Text) +} + func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) { messageBus := bus.NewMessageBus() ch := &TelegramChannel{ @@ -557,16 +589,10 @@ func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) { inbound, ok := <-messageBus.InboundChan() require.True(t, ok, "expected inbound message") - // Composite chatID should include thread ID - assert.Equal(t, "-1001234567890/42", inbound.ChatID) - - // Peer ID should include thread ID for session key isolation - assert.Equal(t, "group", inbound.Peer.Kind) - assert.Equal(t, "-1001234567890/42", inbound.Peer.ID) - - // Parent peer metadata should be set for agent binding - assert.Equal(t, "topic", inbound.Metadata["parent_peer_kind"]) - assert.Equal(t, "42", inbound.Metadata["parent_peer_id"]) + // ChatID remains the parent chat; TopicID isolates the sub-conversation. + assert.Equal(t, "-1001234567890", inbound.ChatID) + assert.Equal(t, "group", inbound.Context.ChatType) + assert.Equal(t, "42", inbound.Context.TopicID) } func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) { @@ -599,13 +625,8 @@ func TestHandleMessage_NoForum_NoThreadMetadata(t *testing.T) { // Plain chatID without thread suffix assert.Equal(t, "-100999", inbound.ChatID) - // Peer ID should be raw chat ID (no thread suffix) - assert.Equal(t, "group", inbound.Peer.Kind) - assert.Equal(t, "-100999", inbound.Peer.ID) - - // No parent peer metadata - assert.Empty(t, inbound.Metadata["parent_peer_kind"]) - assert.Empty(t, inbound.Metadata["parent_peer_id"]) + assert.Equal(t, "group", inbound.Context.ChatType) + assert.Empty(t, inbound.Context.TopicID) } func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) { @@ -642,13 +663,8 @@ func TestHandleMessage_ReplyThread_NonForum_NoIsolation(t *testing.T) { // chatID should NOT include thread suffix for non-forum groups assert.Equal(t, "-100999", inbound.ChatID) - // Peer ID should be raw chat ID (shared session for whole group) - assert.Equal(t, "group", inbound.Peer.Kind) - assert.Equal(t, "-100999", inbound.Peer.ID) - - // No parent peer metadata - assert.Empty(t, inbound.Metadata["parent_peer_kind"]) - assert.Empty(t, inbound.Metadata["parent_peer_id"]) + assert.Equal(t, "group", inbound.Context.ChatType) + assert.Empty(t, inbound.Context.TopicID) } func assertHandleMessageQuotedUserReply( @@ -701,7 +717,7 @@ func assertHandleMessageQuotedUserReply( inbound, ok := <-messageBus.InboundChan() require.True(t, ok) - assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Metadata["reply_to_message_id"]) + assert.Equal(t, strconv.Itoa(replyMessageID), inbound.Context.ReplyToMessageID) assert.Equal(t, expectedContent, inbound.Content) } @@ -787,7 +803,7 @@ func TestHandleMessage_ReplyToOwnBotMessage_UsesAssistantRole(t *testing.T) { inbound, ok := <-messageBus.InboundChan() require.True(t, ok) - assert.Equal(t, "101", inbound.Metadata["reply_to_message_id"]) + assert.Equal(t, "101", inbound.Context.ReplyToMessageID) assert.Equal( t, "[quoted assistant message from afjcjsbx_picoclaw_bot]: Fatto! Ho creato il file notizie_2026_03_28.md\n\nti ricordi questo file?", diff --git a/pkg/channels/vk/vk.go b/pkg/channels/vk/vk.go index 47c1091b8..b27431ba0 100644 --- a/pkg/channels/vk/vk.go +++ b/pkg/channels/vk/vk.go @@ -172,14 +172,11 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) { _ = groupTrigger } - peerKind := "direct" - peerIDStr := userID + chatType := "direct" if isGroupChat { - peerKind = "group" - peerIDStr = chatID + chatType = "group" } - peer := bus.Peer{Kind: peerKind, ID: peerIDStr} messageID := strconv.Itoa(msg.ConversationMessageID) metadata := map[string]string{ @@ -187,16 +184,15 @@ func (c *VKChannel) handleMessage(msg object.MessagesMessage) { "is_group": fmt.Sprintf("%t", isGroupChat), } - c.HandleMessage(c.ctx, - peer, - messageID, - userID, - chatID, - text, - nil, - metadata, - sender, - ) + c.HandleInboundContext(c.ctx, chatID, text, nil, bus.InboundContext{ + Channel: "vk", + ChatID: chatID, + ChatType: chatType, + SenderID: userID, + MessageID: messageID, + Mentioned: isGroupChat && c.isMentioned(msg), + Raw: metadata, + }, sender) } func (c *VKChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) { diff --git a/pkg/channels/wecom/wecom.go b/pkg/channels/wecom/wecom.go index dc40f0c69..a0a23feda 100644 --- a/pkg/channels/wecom/wecom.go +++ b/pkg/channels/wecom/wecom.go @@ -570,7 +570,6 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) return err } - peer := bus.Peer{Kind: peerKind, ID: actualChatID} metadata := map[string]string{ "channel": "wecom", "req_id": reqID, @@ -583,7 +582,20 @@ func (c *WeComChannel) dispatchIncoming(reqID string, msg wecomIncomingMessage) metadata["quote_text"] = quoteText } - c.HandleMessage(c.ctx, peer, msg.MsgID, senderID, actualChatID, content, mediaRefs, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: c.Name(), + Account: strings.TrimSpace(msg.AIBotID), + ChatID: actualChatID, + ChatType: peerKind, + SenderID: senderID, + MessageID: msg.MsgID, + ReplyHandles: map[string]string{ + "req_id": reqID, + }, + Raw: metadata, + } + + c.HandleInboundContext(c.ctx, actualChatID, content, mediaRefs, inboundCtx, sender) return nil } diff --git a/pkg/channels/wecom/wecom_test.go b/pkg/channels/wecom/wecom_test.go index 1e79afae9..85a2f6ef7 100644 --- a/pkg/channels/wecom/wecom_test.go +++ b/pkg/channels/wecom/wecom_test.go @@ -50,11 +50,11 @@ func TestDispatchIncoming_UsesActualChatIDAndStoresReqIDRoute(t *testing.T) { if inbound.MessageID != "msg-1" { t.Fatalf("inbound MessageID = %q, want msg-1", inbound.MessageID) } - if inbound.Peer.ID != "chat-1" { - t.Fatalf("inbound Peer.ID = %q, want chat-1", inbound.Peer.ID) + if inbound.Context.ChatType != "direct" { + t.Fatalf("inbound Context.ChatType = %q, want direct", inbound.Context.ChatType) } - if inbound.Metadata["req_id"] != "req-1" { - t.Fatalf("inbound req_id = %q, want req-1", inbound.Metadata["req_id"]) + if inbound.Context.ReplyHandles["req_id"] != "req-1" { + t.Fatalf("inbound req_id = %q, want req-1", inbound.Context.ReplyHandles["req_id"]) } default: t.Fatal("expected inbound message to be published") diff --git a/pkg/channels/weixin/weixin.go b/pkg/channels/weixin/weixin.go index 589cf164e..2897d2422 100644 --- a/pkg/channels/weixin/weixin.go +++ b/pkg/channels/weixin/weixin.go @@ -357,8 +357,6 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess return } - peer := bus.Peer{Kind: "direct", ID: fromUserID} - metadata := map[string]string{ "from_user_id": fromUserID, "context_token": msg.ContextToken, @@ -377,7 +375,21 @@ func (c *WeixinChannel) handleInboundMessage(ctx context.Context, msg WeixinMess c.persistContextTokens() } - c.HandleMessage(ctx, peer, messageID, fromUserID, fromUserID, content, mediaRefs, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "weixin", + ChatID: fromUserID, + ChatType: "direct", + SenderID: fromUserID, + MessageID: messageID, + Raw: metadata, + } + if msg.ContextToken != "" { + inboundCtx.ReplyHandles = map[string]string{ + "context_token": msg.ContextToken, + } + } + + c.HandleInboundContext(ctx, fromUserID, content, mediaRefs, inboundCtx, sender) } // Send implements channels.Channel by sending a text message to the WeChat user. diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 5c2962a94..4c338b5f4 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -227,13 +227,6 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { metadata["user_name"] = userName } - var peer bus.Peer - if chatID == senderID { - peer = bus.Peer{Kind: "direct", ID: senderID} - } else { - peer = bus.Peer{Kind: "group", ID: chatID} - } - logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{ "sender": senderID, "preview": utils.Truncate(content, 50), @@ -252,5 +245,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { return } - c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) + inboundCtx := bus.InboundContext{ + Channel: "whatsapp", + ChatID: chatID, + SenderID: senderID, + MessageID: messageID, + Raw: metadata, + } + if chatID == senderID { + inboundCtx.ChatType = "direct" + } else { + inboundCtx.ChatType = "group" + } + + c.HandleInboundContext(c.ctx, chatID, content, mediaPaths, inboundCtx, sender) } diff --git a/pkg/channels/whatsapp_native/whatsapp_native.go b/pkg/channels/whatsapp_native/whatsapp_native.go index 32ae085ac..de4ecfd44 100644 --- a/pkg/channels/whatsapp_native/whatsapp_native.go +++ b/pkg/channels/whatsapp_native/whatsapp_native.go @@ -377,7 +377,6 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) { if evt.Info.Chat.Server == types.GroupServer { peerKind = "group" } - peer := bus.Peer{Kind: peerKind, ID: chatID} messageID := evt.Info.ID sender := bus.SenderInfo{ Platform: "whatsapp", @@ -395,7 +394,17 @@ func (c *WhatsAppNativeChannel) handleIncoming(evt *events.Message) { "WhatsApp message received", map[string]any{"sender_id": senderID, "content_preview": utils.Truncate(content, 50)}, ) - c.HandleMessage(c.runCtx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) + + inboundCtx := bus.InboundContext{ + Channel: "whatsapp", + ChatID: chatID, + SenderID: senderID, + MessageID: messageID, + ChatType: peerKind, + Raw: metadata, + } + + c.HandleInboundContext(c.runCtx, chatID, content, mediaPaths, inboundCtx, sender) } func (c *WhatsAppNativeChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) { diff --git a/pkg/config/config.go b/pkg/config/config.go index fe259fd23..9488fd96c 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,10 +30,10 @@ func init() { // Config is the current config structure with version support. type Config struct { - Version int `json:"version" yaml:"-"` // Config schema version for migration + // Config schema version for migration. + Version int `json:"version" yaml:"-"` Isolation IsolationConfig `json:"isolation,omitempty" yaml:"-"` Agents AgentsConfig `json:"agents" yaml:"-"` - Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"` Session SessionConfig `json:"session,omitempty" yaml:"-"` Channels ChannelsConfig `json:"channel_list" yaml:"channel_list"` ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration @@ -120,7 +120,7 @@ type BuildInfo struct { } // MarshalJSON implements custom JSON marshaling for Config -// to omit providers section when empty and session when empty +// to omit providers section when empty and session when empty. func (c *Config) MarshalJSON() ([]byte, error) { type Alias Config aux := &struct { @@ -130,17 +130,18 @@ func (c *Config) MarshalJSON() ([]byte, error) { Alias: (*Alias)(c), } - // Only include session if not empty - if c.Session.DMScope != "" || len(c.Session.IdentityLinks) > 0 { - aux.Session = &c.Session + if len(c.Session.Dimensions) > 0 || len(c.Session.IdentityLinks) > 0 { + sessionCfg := c.Session + aux.Session = &sessionCfg } return json.Marshal(aux) } type AgentsConfig struct { - Defaults AgentDefaults `json:"defaults"` - List []AgentConfig `json:"list,omitempty"` + Defaults AgentDefaults `json:"defaults"` + List []AgentConfig `json:"list,omitempty"` + Dispatch *DispatchConfig `json:"dispatch,omitempty"` } // AgentModelConfig supports both string and structured model config. @@ -197,26 +198,29 @@ type SubagentsConfig struct { Model *AgentModelConfig `json:"model,omitempty"` } -type PeerMatch struct { - Kind string `json:"kind"` - ID string `json:"id"` +type DispatchConfig struct { + Rules []DispatchRule `json:"rules,omitempty"` } -type BindingMatch struct { - Channel string `json:"channel"` - AccountID string `json:"account_id,omitempty"` - Peer *PeerMatch `json:"peer,omitempty"` - GuildID string `json:"guild_id,omitempty"` - TeamID string `json:"team_id,omitempty"` +type DispatchRule struct { + Name string `json:"name,omitempty"` + Agent string `json:"agent"` + When DispatchSelector `json:"when"` + SessionDimensions []string `json:"session_dimensions,omitempty"` } -type AgentBinding struct { - AgentID string `json:"agent_id"` - Match BindingMatch `json:"match"` +type DispatchSelector struct { + Channel string `json:"channel,omitempty"` + Account string `json:"account,omitempty"` + Space string `json:"space,omitempty"` + Chat string `json:"chat,omitempty"` + Topic string `json:"topic,omitempty"` + Sender string `json:"sender,omitempty"` + Mentioned *bool `json:"mentioned,omitempty"` } type SessionConfig struct { - DMScope string `json:"dm_scope,omitempty"` + Dimensions []string `json:"dimensions,omitempty"` IdentityLinks map[string][]string `json:"identity_links,omitempty"` } @@ -509,9 +513,10 @@ type DevicesConfig struct { } type VoiceConfig struct { - ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"` - TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"` - EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` + ModelName string `json:"model_name,omitempty" env:"PICOCLAW_VOICE_MODEL_NAME"` + TTSModelName string `json:"tts_model_name,omitempty" env:"PICOCLAW_VOICE_TTS_MODEL_NAME"` + EchoTranscription bool `json:"echo_transcription" env:"PICOCLAW_VOICE_ECHO_TRANSCRIPTION"` + ElevenLabsAPIKey string `json:"elevenlabs_api_key,omitempty" env:"PICOCLAW_VOICE_ELEVENLABS_API_KEY"` } // ModelConfig represents a model-centric provider configuration. @@ -1066,6 +1071,8 @@ func LoadConfig(path string) (*Config, error) { return nil, fmt.Errorf("unsupported config version: %d", versionInfo.Version) } + applyLegacyBindingsMigration(data, cfg) + if err = env.Parse(cfg); err != nil { return nil, err } @@ -1256,6 +1263,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig { ThinkingLevel: m.ThinkingLevel, ExtraBody: m.ExtraBody, CustomHeaders: m.CustomHeaders, + UserAgent: m.UserAgent, isVirtual: true, } expanded = append(expanded, additionalEntry) @@ -1277,6 +1285,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig { ThinkingLevel: m.ThinkingLevel, ExtraBody: m.ExtraBody, CustomHeaders: m.CustomHeaders, + UserAgent: m.UserAgent, APIKeys: SimpleSecureStrings(keys[0]), } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 501bdb5c8..42e2d266c 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -109,18 +109,8 @@ func TestAgentConfig_FullParse(t *testing.T) { } ] }, - "bindings": [ - { - "agent_id": "support", - "match": { - "channel": "telegram", - "account_id": "*", - "peer": {"kind": "direct", "id": "user123"} - } - } - ], "session": { - "dm_scope": "per-peer", + "dimensions": ["sender"], "identity_links": { "john": ["telegram:123", "discord:john#1234"] } @@ -158,19 +148,8 @@ func TestAgentConfig_FullParse(t *testing.T) { t.Errorf("support.Subagents = %+v", support.Subagents) } - if len(cfg.Bindings) != 1 { - t.Fatalf("bindings len = %d, want 1", len(cfg.Bindings)) - } - binding := cfg.Bindings[0] - if binding.AgentID != "support" || binding.Match.Channel != "telegram" { - t.Errorf("binding = %+v", binding) - } - if binding.Match.Peer == nil || binding.Match.Peer.Kind != "direct" || binding.Match.Peer.ID != "user123" { - t.Errorf("binding.Match.Peer = %+v", binding.Match.Peer) - } - - if cfg.Session.DMScope != "per-peer" { - t.Errorf("Session.DMScope = %q", cfg.Session.DMScope) + if len(cfg.Session.Dimensions) != 1 || cfg.Session.Dimensions[0] != "sender" { + t.Errorf("Session.Dimensions = %v", cfg.Session.Dimensions) } if len(cfg.Session.IdentityLinks) != 1 { t.Errorf("Session.IdentityLinks = %v", cfg.Session.IdentityLinks) @@ -236,8 +215,242 @@ func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) { if len(cfg.Agents.List) != 0 { t.Errorf("agents.list should be empty for backward compat, got %d", len(cfg.Agents.List)) } - if len(cfg.Bindings) != 0 { - t.Errorf("bindings should be empty, got %d", len(cfg.Bindings)) +} + +func TestAgentConfig_ParsesDispatchRules(t *testing.T) { + jsonData := `{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7" + }, + "list": [ + { "id": "main", "default": true }, + { "id": "support" } + ], + "dispatch": { + "rules": [ + { + "name": "support-vip", + "agent": "support", + "when": { + "channel": "telegram", + "chat": "group:-100123", + "sender": "12345", + "mentioned": true + }, + "session_dimensions": ["chat", "sender"] + } + ] + } + } + }` + + cfg := DefaultConfig() + if err := json.Unmarshal([]byte(jsonData), cfg); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if cfg.Agents.Dispatch == nil { + t.Fatal("Agents.Dispatch should not be nil") + } + if len(cfg.Agents.Dispatch.Rules) != 1 { + t.Fatalf("Dispatch.Rules len = %d, want 1", len(cfg.Agents.Dispatch.Rules)) + } + rule := cfg.Agents.Dispatch.Rules[0] + if rule.Name != "support-vip" || rule.Agent != "support" { + t.Fatalf("rule = %+v", rule) + } + if rule.When.Channel != "telegram" || rule.When.Chat != "group:-100123" || rule.When.Sender != "12345" { + t.Fatalf("rule.When = %+v", rule.When) + } + if rule.When.Mentioned == nil || !*rule.When.Mentioned { + t.Fatalf("rule.When.Mentioned = %+v, want true", rule.When.Mentioned) + } + if got := rule.SessionDimensions; len(got) != 2 || got[0] != "chat" || got[1] != "sender" { + t.Fatalf("rule.SessionDimensions = %v, want [chat sender]", got) + } +} + +func TestLoadConfig_MigratesLegacyBindingsToDispatchRules(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + raw := `{ + "version": 2, + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7" + }, + "list": [ + { "id": "main", "default": true }, + { "id": "support" }, + { "id": "ops" }, + { "id": "slack" } + ] + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "peer": { "kind": "group", "id": "-100123" } + } + }, + { + "agent_id": "ops", + "match": { + "channel": "discord", + "guild_id": "guild-1" + } + }, + { + "agent_id": "slack", + "match": { + "channel": "slack", + "account_id": "*" + } + } + ] + }` + if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Agents.Dispatch == nil { + t.Fatal("Agents.Dispatch should not be nil") + } + if len(cfg.Agents.Dispatch.Rules) != 3 { + t.Fatalf("Dispatch.Rules len = %d, want 3", len(cfg.Agents.Dispatch.Rules)) + } + + first := cfg.Agents.Dispatch.Rules[0] + if first.Agent != "support" { + t.Fatalf("first.Agent = %q, want %q", first.Agent, "support") + } + if first.When.Channel != "telegram" || first.When.Chat != "group:-100123" { + t.Fatalf("first.When = %+v", first.When) + } + if first.When.Account != legacyDefaultAccountID { + t.Fatalf("first.When.Account = %q, want %q", first.When.Account, legacyDefaultAccountID) + } + + second := cfg.Agents.Dispatch.Rules[1] + if second.Agent != "ops" || second.When.Space != "guild:guild-1" { + t.Fatalf("second = %+v", second) + } + + third := cfg.Agents.Dispatch.Rules[2] + if third.Agent != "slack" { + t.Fatalf("third.Agent = %q, want %q", third.Agent, "slack") + } + if third.When.Channel != "slack" || third.When.Account != "" { + t.Fatalf("third.When = %+v", third.When) + } +} + +func TestLoadConfig_PrefersDispatchRulesOverLegacyBindings(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + raw := `{ + "version": 2, + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7" + }, + "list": [ + { "id": "main", "default": true }, + { "id": "support" } + ], + "dispatch": { + "rules": [ + { + "name": "explicit", + "agent": "support", + "when": { + "channel": "telegram", + "chat": "group:-100123" + } + } + ] + } + }, + "bindings": [ + { + "agent_id": "main", + "match": { + "channel": "telegram", + "account_id": "*" + } + } + ] + }` + if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Agents.Dispatch == nil { + t.Fatal("Agents.Dispatch should not be nil") + } + if len(cfg.Agents.Dispatch.Rules) != 1 { + t.Fatalf("Dispatch.Rules len = %d, want 1", len(cfg.Agents.Dispatch.Rules)) + } + if cfg.Agents.Dispatch.Rules[0].Name != "explicit" { + t.Fatalf("Dispatch.Rules[0].Name = %q, want %q", cfg.Agents.Dispatch.Rules[0].Name, "explicit") + } +} + +func TestLoadConfig_MigratesLegacyDirectBindingsWithIdentityLinks(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + raw := `{ + "version": 2, + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7" + }, + "list": [ + { "id": "main", "default": true }, + { "id": "support" } + ] + }, + "session": { + "identity_links": { + "john": ["telegram:123", "123"] + } + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "peer": { "kind": "direct", "id": "123" } + } + } + ] + }` + if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Agents.Dispatch == nil || len(cfg.Agents.Dispatch.Rules) != 1 { + t.Fatalf("Dispatch.Rules = %+v, want 1 migrated rule", cfg.Agents.Dispatch) + } + if got := cfg.Agents.Dispatch.Rules[0].When.Sender; got != "john" { + t.Fatalf("migrated sender selector = %q, want %q", got, "john") } } @@ -374,13 +587,6 @@ func TestDefaultConfig_WebTools(t *testing.T) { } } -func TestDefaultConfig_ReadFileMode(t *testing.T) { - cfg := DefaultConfig() - if cfg.Tools.ReadFile.EffectiveMode() != ReadFileModeBytes { - t.Fatalf("expected default read_file mode %q, got %q", ReadFileModeBytes, cfg.Tools.ReadFile.EffectiveMode()) - } -} - func TestSaveConfig_FilePermissions(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("file permission bits are not enforced on Windows") @@ -825,7 +1031,7 @@ func TestLoadConfig_HooksProcessConfig(t *testing.T) { } } -// TestDefaultConfig_DMScope verifies the default dm_scope value +// TestDefaultConfig_SessionDimensions verifies the default session dimensions // TestDefaultConfig_SummarizationThresholds verifies summarization defaults func TestDefaultConfig_SummarizationThresholds(t *testing.T) { cfg := DefaultConfig() @@ -838,11 +1044,11 @@ func TestDefaultConfig_SummarizationThresholds(t *testing.T) { } } -func TestDefaultConfig_DMScope(t *testing.T) { +func TestDefaultConfig_SessionDimensions(t *testing.T) { cfg := DefaultConfig() - if cfg.Session.DMScope != "per-channel-peer" { - t.Errorf("Session.DMScope = %q, want 'per-channel-peer'", cfg.Session.DMScope) + if len(cfg.Session.Dimensions) != 1 || cfg.Session.Dimensions[0] != "chat" { + t.Errorf("Session.Dimensions = %v, want [chat]", cfg.Session.Dimensions) } } @@ -1076,7 +1282,6 @@ func TestLoadConfig_TelegramPlaceholderTextAcceptsSingleString(t *testing.T) { data := `{ "version": 1, "agents": { "defaults": { "workspace": "", "model": "", "max_tokens": 0, "max_tool_iterations": 0 } }, - "bindings": [], "session": {}, "channels": { "telegram": { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 40f7d5d52..b2054b90c 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -41,9 +41,8 @@ func DefaultConfig() *Config { SplitOnMarker: false, }, }, - Bindings: []AgentBinding{}, Session: SessionConfig{ - DMScope: "per-channel-peer", + Dimensions: []string{"chat"}, }, Channels: defaultChannels(), Hooks: HooksConfig{ @@ -422,7 +421,9 @@ func DefaultConfig() *Config { }, Voice: VoiceConfig{ ModelName: "", + TTSModelName: "", EchoTranscription: false, + ElevenLabsAPIKey: "", }, BuildInfo: BuildInfo{ Version: Version, diff --git a/pkg/config/legacy_bindings.go b/pkg/config/legacy_bindings.go new file mode 100644 index 000000000..751a35de7 --- /dev/null +++ b/pkg/config/legacy_bindings.go @@ -0,0 +1,267 @@ +package config + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +const legacyDefaultAccountID = "default" + +type legacyBindingsEnvelope struct { + Bindings json.RawMessage `json:"bindings"` +} + +type legacyAgentBinding struct { + AgentID string `json:"agent_id"` + Match legacyBindingMatch `json:"match"` +} + +type legacyBindingMatch struct { + Channel string `json:"channel"` + AccountID string `json:"account_id,omitempty"` + Peer *legacyPeerMatch `json:"peer,omitempty"` + GuildID string `json:"guild_id,omitempty"` + TeamID string `json:"team_id,omitempty"` +} + +type legacyPeerMatch struct { + Kind string `json:"kind"` + ID string `json:"id"` +} + +func applyLegacyBindingsMigration(data []byte, cfg *Config) { + if cfg == nil { + return + } + + bindings, found, err := decodeLegacyBindings(data) + if err != nil { + logger.WarnF( + "legacy bindings config detected but could not be decoded", + map[string]any{"error": err}, + ) + return + } + if !found { + return + } + + if cfg.Agents.Dispatch != nil && len(cfg.Agents.Dispatch.Rules) > 0 { + logger.WarnF( + "legacy bindings config is deprecated and ignored because agents.dispatch.rules is configured", + map[string]any{"bindings": len(bindings), "dispatch_rules": len(cfg.Agents.Dispatch.Rules)}, + ) + return + } + + rules, dropped := migrateLegacyBindings(bindings, cfg.Session.IdentityLinks) + if len(rules) == 0 { + logger.WarnF( + "legacy bindings config is deprecated and could not be migrated", + map[string]any{"bindings": len(bindings), "dropped_bindings": dropped}, + ) + return + } + + if cfg.Agents.Dispatch == nil { + cfg.Agents.Dispatch = &DispatchConfig{} + } + cfg.Agents.Dispatch.Rules = rules + + fields := map[string]any{ + "bindings": len(bindings), + "dispatch_rules": len(rules), + } + if dropped > 0 { + fields["dropped_bindings"] = dropped + } + logger.WarnF("legacy bindings config is deprecated; migrated to agents.dispatch.rules in memory", fields) +} + +func decodeLegacyBindings(data []byte) ([]legacyAgentBinding, bool, error) { + var envelope legacyBindingsEnvelope + if err := json.Unmarshal(data, &envelope); err != nil { + return nil, false, err + } + if len(envelope.Bindings) == 0 { + return nil, false, nil + } + + var bindings []legacyAgentBinding + if err := json.Unmarshal(envelope.Bindings, &bindings); err != nil { + return nil, true, err + } + return bindings, true, nil +} + +func migrateLegacyBindings(bindings []legacyAgentBinding, identityLinks map[string][]string) ([]DispatchRule, int) { + if len(bindings) == 0 { + return nil, 0 + } + + type prioritizedRule struct { + rule DispatchRule + index int + kind int + } + + prioritized := make([]prioritizedRule, 0, len(bindings)) + dropped := 0 + for i, binding := range bindings { + rule, kind, ok := migrateLegacyBinding(binding, i, identityLinks) + if !ok { + dropped++ + continue + } + prioritized = append(prioritized, prioritizedRule{rule: rule, index: i, kind: kind}) + } + if len(prioritized) == 0 { + return nil, dropped + } + + rules := make([]DispatchRule, 0, len(prioritized)) + for kind := 0; kind <= 4; kind++ { + for _, item := range prioritized { + if item.kind == kind { + rules = append(rules, item.rule) + } + } + } + return rules, dropped +} + +func migrateLegacyBinding( + binding legacyAgentBinding, + index int, + identityLinks map[string][]string, +) (DispatchRule, int, bool) { + channel := strings.ToLower(strings.TrimSpace(binding.Match.Channel)) + agentID := strings.TrimSpace(binding.AgentID) + if channel == "" || agentID == "" { + return DispatchRule{}, 0, false + } + + rule := DispatchRule{ + Name: fmt.Sprintf("legacy-binding-%d", index+1), + Agent: agentID, + When: DispatchSelector{ + Channel: channel, + }, + } + + switch normalizeLegacyAccountSelector(binding.Match.AccountID) { + case "": + case "*": + default: + rule.When.Account = normalizeLegacyAccountSelector(binding.Match.AccountID) + } + + if peer := binding.Match.Peer; peer != nil { + peerKind := strings.ToLower(strings.TrimSpace(peer.Kind)) + peerID := strings.TrimSpace(peer.ID) + if peerID == "" { + return DispatchRule{}, 0, false + } + switch peerKind { + case "direct": + rule.When.Sender = canonicalLegacyBindingSenderID(channel, peerID, identityLinks) + return rule, 0, true + case "group", "channel": + rule.When.Chat = peerKind + ":" + peerID + return rule, 0, true + case "topic": + rule.When.Topic = "topic:" + peerID + return rule, 0, true + default: + return DispatchRule{}, 0, false + } + } + + if guildID := strings.TrimSpace(binding.Match.GuildID); guildID != "" { + rule.When.Space = "guild:" + guildID + return rule, 1, true + } + + if teamID := strings.TrimSpace(binding.Match.TeamID); teamID != "" { + rule.When.Space = "team:" + teamID + return rule, 2, true + } + + accountSelector := normalizeLegacyAccountSelector(binding.Match.AccountID) + if accountSelector == "*" { + rule.When.Account = "" + return rule, 4, true + } + + rule.When.Account = accountSelector + return rule, 3, true +} + +func normalizeLegacyAccountSelector(accountID string) string { + accountID = strings.TrimSpace(accountID) + switch accountID { + case "": + return legacyDefaultAccountID + case "*": + return "*" + default: + return strings.ToLower(accountID) + } +} + +func canonicalLegacyBindingSenderID(channel, peerID string, identityLinks map[string][]string) string { + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + if linked := resolveLegacyBindingLinkedID(identityLinks, channel, peerID); linked != "" { + return strings.ToLower(linked) + } + + return strings.ToLower(peerID) +} + +func resolveLegacyBindingLinkedID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]struct{}) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = struct{}{} + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + candidates[channel+":"+rawCandidate] = struct{}{} + } + if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + candidates[rawCandidate[idx+1:]] = struct{}{} + } + + for canonical, ids := range identityLinks { + canonical = strings.TrimSpace(canonical) + if canonical == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized == "" { + continue + } + if _, ok := candidates[normalized]; ok { + return canonical + } + } + } + + return "" +} diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1bafe6085..1cf2a686e 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -131,8 +131,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: platform, - ChatID: userID, + Context: bus.NewOutboundContext(platform, userID, ""), Content: msg, }) diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 5dda78ea9..e5b28ec11 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -339,8 +339,7 @@ func (hs *HeartbeatService) sendResponse(response string) { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: platform, - ChatID: userID, + Context: bus.NewOutboundContext(platform, userID, ""), Content: response, }) diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go index fc1ec8eb1..8d3320f3f 100644 --- a/pkg/memory/jsonl.go +++ b/pkg/memory/jsonl.go @@ -32,14 +32,19 @@ const ( maxLineSize = 10 * 1024 * 1024 // 10 MB ) -// sessionMeta holds per-session metadata stored in a .meta.json file. -type sessionMeta struct { - Key string `json:"key"` - Summary string `json:"summary"` - Skip int `json:"skip"` - Count int `json:"count"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` +// SessionMeta holds per-session metadata stored in a .meta.json file. +// +// Scope is stored as raw JSON so pkg/memory can stay decoupled from the +// higher-level session package while still preserving structured scope data. +type SessionMeta struct { + Key string `json:"key"` + Summary string `json:"summary"` + Skip int `json:"skip"` + Count int `json:"count"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Scope json.RawMessage `json:"scope,omitempty"` + Aliases []string `json:"aliases,omitempty"` } // JSONLStore implements Store using append-only JSONL files. @@ -98,25 +103,31 @@ func sanitizeKey(key string) string { // readMeta loads the metadata file for a session. // Returns a zero-value sessionMeta if the file does not exist. -func (s *JSONLStore) readMeta(key string) (sessionMeta, error) { +func (s *JSONLStore) readMeta(key string) (SessionMeta, error) { data, err := os.ReadFile(s.metaPath(key)) if os.IsNotExist(err) { - return sessionMeta{Key: key}, nil + return SessionMeta{Key: key}, nil } if err != nil { - return sessionMeta{}, fmt.Errorf("memory: read meta: %w", err) + return SessionMeta{}, fmt.Errorf("memory: read meta: %w", err) } - var meta sessionMeta + var meta SessionMeta err = json.Unmarshal(data, &meta) if err != nil { - return sessionMeta{}, fmt.Errorf("memory: decode meta: %w", err) + return SessionMeta{}, fmt.Errorf("memory: decode meta: %w", err) + } + if meta.Key == "" { + meta.Key = key } return meta, nil } // writeMeta atomically writes the metadata file using the project's // standard WriteFileAtomic (temp + fsync + rename). -func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error { +func (s *JSONLStore) writeMeta(key string, meta SessionMeta) error { + if strings.TrimSpace(meta.Key) == "" { + meta.Key = key + } data, err := json.MarshalIndent(meta, "", " ") if err != nil { return fmt.Errorf("memory: encode meta: %w", err) @@ -124,6 +135,314 @@ func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error { return fileutil.WriteFileAtomic(s.metaPath(key), data, 0o644) } +func cloneRawJSON(data json.RawMessage) json.RawMessage { + if len(data) == 0 { + return nil + } + return append(json.RawMessage(nil), data...) +} + +func normalizeAliases(canonicalKey string, aliases []string) []string { + if len(aliases) == 0 { + return nil + } + normalized := make([]string, 0, len(aliases)) + seen := make(map[string]struct{}, len(aliases)) + canonicalKey = strings.TrimSpace(canonicalKey) + for _, alias := range aliases { + alias = strings.TrimSpace(alias) + if alias == "" || alias == canonicalKey { + continue + } + if _, ok := seen[alias]; ok { + continue + } + seen[alias] = struct{}{} + normalized = append(normalized, alias) + } + if len(normalized) == 0 { + return nil + } + return normalized +} + +func (s *JSONLStore) sessionExists(key string) bool { + if key == "" { + return false + } + if _, err := os.Stat(s.jsonlPath(key)); err == nil { + return true + } + if _, err := os.Stat(s.metaPath(key)); err == nil { + return true + } + return false +} + +// GetSessionMeta returns the current metadata snapshot for sessionKey. +func (s *JSONLStore) GetSessionMeta(_ context.Context, sessionKey string) (SessionMeta, error) { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return SessionMeta{}, err + } + meta.Scope = cloneRawJSON(meta.Scope) + if len(meta.Aliases) > 0 { + meta.Aliases = append([]string(nil), meta.Aliases...) + } + return meta, nil +} + +// UpsertSessionMeta stores structured session metadata while preserving +// summary/count/skip timestamps maintained by the core JSONL store. +func (s *JSONLStore) UpsertSessionMeta( + _ context.Context, + sessionKey string, + scope json.RawMessage, + aliases []string, +) error { + l := s.sessionLock(sessionKey) + l.Lock() + defer l.Unlock() + + meta, err := s.readMeta(sessionKey) + if err != nil { + return err + } + meta.Scope = cloneRawJSON(scope) + meta.Aliases = normalizeAliases(sessionKey, aliases) + now := time.Now() + if meta.CreatedAt.IsZero() { + meta.CreatedAt = now + } + meta.UpdatedAt = now + + return s.writeMeta(sessionKey, meta) +} + +// PromoteAliasHistory atomically promotes the first non-empty alias session +// into the canonical session when the canonical session is still empty. +func (s *JSONLStore) PromoteAliasHistory( + _ context.Context, + sessionKey string, + scope json.RawMessage, + aliases []string, +) (bool, error) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return false, nil + } + + aliases = normalizeAliases(sessionKey, aliases) + for _, alias := range aliases { + unlock := s.lockSessionPair(sessionKey, alias) + promoted, err := s.promoteAliasHistoryLocked(sessionKey, alias, scope, aliases) + unlock() + if err != nil || promoted { + return promoted, err + } + } + + return false, nil +} + +// ResolveSessionKey returns the canonical session key for a candidate key. +// It short-circuits direct canonical keys when possible, then scans metadata +// once to resolve aliases or canonical metadata keys. +func (s *JSONLStore) ResolveSessionKey(_ context.Context, sessionKey string) (string, bool, error) { + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return "", false, nil + } + + hasDirectSession := s.sessionExists(sessionKey) + if hasDirectSession && shouldShortCircuitSessionResolve(sessionKey) { + return sessionKey, true, nil + } + + entries, err := os.ReadDir(s.dir) + if err != nil { + return "", false, fmt.Errorf("memory: read sessions dir: %w", err) + } + + var directMetaMatch string + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") { + continue + } + + data, readErr := os.ReadFile(filepath.Join(s.dir, entry.Name())) + if readErr != nil { + log.Printf("memory: skipping unreadable meta %s: %v", entry.Name(), readErr) + continue + } + + var meta SessionMeta + if err := json.Unmarshal(data, &meta); err != nil { + log.Printf("memory: skipping corrupt meta %s: %v", entry.Name(), err) + continue + } + + if meta.Key == "" { + continue + } + + if meta.Key == sessionKey { + directMetaMatch = meta.Key + } + + for _, alias := range meta.Aliases { + if alias == sessionKey && meta.Key != sessionKey { + return meta.Key, true, nil + } + } + } + + if directMetaMatch != "" { + return directMetaMatch, true, nil + } + + if hasDirectSession { + return sessionKey, true, nil + } + + return "", false, nil +} + +func shouldShortCircuitSessionResolve(sessionKey string) bool { + sessionKey = strings.TrimSpace(strings.ToLower(sessionKey)) + if sessionKey == "" { + return false + } + return !strings.ContainsAny(sessionKey, ":/\\") +} + +func (s *JSONLStore) lockSessionPair(keyA, keyB string) func() { + lockA := s.sessionLock(keyA) + lockB := s.sessionLock(keyB) + if lockA == lockB { + lockA.Lock() + return func() { lockA.Unlock() } + } + if keyA <= keyB { + lockA.Lock() + lockB.Lock() + return func() { + lockB.Unlock() + lockA.Unlock() + } + } + lockB.Lock() + lockA.Lock() + return func() { + lockA.Unlock() + lockB.Unlock() + } +} + +func (s *JSONLStore) promoteAliasHistoryLocked( + sessionKey string, + alias string, + scope json.RawMessage, + aliases []string, +) (bool, error) { + canonicalMeta, err := s.readMeta(sessionKey) + if err != nil { + return false, err + } + canonicalHasContent, err := s.sessionHasVisibleContentLocked(sessionKey, canonicalMeta) + if err != nil { + return false, err + } + if canonicalHasContent { + return false, nil + } + + aliasMeta, err := s.readMeta(alias) + if err != nil { + return false, err + } + aliasHistory, err := readMessages(s.jsonlPath(alias), aliasMeta.Skip) + if err != nil { + return false, err + } + aliasSummary := strings.TrimSpace(aliasMeta.Summary) + if len(aliasHistory) == 0 && aliasSummary == "" { + return false, nil + } + + previousJSONL, hadPreviousJSONL, err := s.readRawJSONL(sessionKey) + if err != nil { + return false, err + } + + now := time.Now() + if canonicalMeta.CreatedAt.IsZero() { + canonicalMeta.CreatedAt = now + } + canonicalMeta.Scope = cloneRawJSON(scope) + canonicalMeta.Aliases = normalizeAliases(sessionKey, aliases) + canonicalMeta.Skip = 0 + canonicalMeta.Count = len(aliasHistory) + canonicalMeta.UpdatedAt = now + if aliasSummary != "" { + canonicalMeta.Summary = aliasSummary + } + + if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil { + return false, err + } + if err := s.writeMeta(sessionKey, canonicalMeta); err != nil { + if rollbackErr := s.restoreRawJSONL(sessionKey, previousJSONL, hadPreviousJSONL); rollbackErr != nil { + return false, fmt.Errorf("memory: write promoted meta: %w (rollback jsonl: %v)", err, rollbackErr) + } + return false, err + } + return true, nil +} + +func (s *JSONLStore) sessionHasVisibleContentLocked(sessionKey string, meta SessionMeta) (bool, error) { + if meta.Count-meta.Skip > 0 || strings.TrimSpace(meta.Summary) != "" { + return true, nil + } + if meta.Count != 0 || meta.Skip != 0 { + return false, nil + } + history, err := readMessages(s.jsonlPath(sessionKey), meta.Skip) + if err != nil { + return false, err + } + return len(history) > 0, nil +} + +func (s *JSONLStore) readRawJSONL(sessionKey string) ([]byte, bool, error) { + data, err := os.ReadFile(s.jsonlPath(sessionKey)) + if os.IsNotExist(err) { + return nil, false, nil + } + if err != nil { + return nil, false, fmt.Errorf("memory: read jsonl: %w", err) + } + return data, true, nil +} + +func (s *JSONLStore) restoreRawJSONL(sessionKey string, data []byte, existed bool) error { + path := s.jsonlPath(sessionKey) + if !existed { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("memory: remove jsonl rollback: %w", err) + } + return nil + } + if err := fileutil.WriteFileAtomic(path, data, 0o644); err != nil { + return fmt.Errorf("memory: restore jsonl rollback: %w", err) + } + return nil +} + // readMessages reads valid JSON lines from a .jsonl file, skipping // the first `skip` lines without unmarshaling them. This avoids the // cost of json.Unmarshal on logically truncated messages. @@ -471,7 +790,7 @@ func (s *JSONLStore) ListSessions() []string { if err != nil { continue } - var meta sessionMeta + var meta SessionMeta if err := json.Unmarshal(data, &meta); err != nil { continue } diff --git a/pkg/memory/jsonl_test.go b/pkg/memory/jsonl_test.go index 356ff14ff..b64c1b25f 100644 --- a/pkg/memory/jsonl_test.go +++ b/pkg/memory/jsonl_test.go @@ -2,8 +2,10 @@ package memory import ( "context" + "encoding/json" "os" "path/filepath" + "reflect" "sync" "testing" @@ -241,6 +243,142 @@ func TestSetSummary_GetSummary(t *testing.T) { } } +func TestSessionMetaScopeAndAliasesPersist(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + scope := json.RawMessage(`{"version":1,"channel":"telegram","values":{"chat":"group:c1"}}`) + aliases := []string{"legacy:one", "legacy:one", "canonical"} + if err := store.UpsertSessionMeta(ctx, "canonical", scope, aliases); err != nil { + t.Fatalf("UpsertSessionMeta() error = %v", err) + } + + meta, err := store.GetSessionMeta(ctx, "canonical") + if err != nil { + t.Fatalf("GetSessionMeta() error = %v", err) + } + var gotScope map[string]any + if err := json.Unmarshal(meta.Scope, &gotScope); err != nil { + t.Fatalf("Unmarshal(meta.Scope) error = %v", err) + } + var wantScope map[string]any + if err := json.Unmarshal(scope, &wantScope); err != nil { + t.Fatalf("Unmarshal(scope) error = %v", err) + } + if !reflect.DeepEqual(gotScope, wantScope) { + t.Fatalf("meta.Scope = %#v, want %#v", gotScope, wantScope) + } + if len(meta.Aliases) != 1 || meta.Aliases[0] != "legacy:one" { + t.Fatalf("meta.Aliases = %#v, want [legacy:one]", meta.Aliases) + } +} + +func TestResolveSessionKeyByAlias(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil { + t.Fatalf("AddMessage() error = %v", err) + } + if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil { + t.Fatalf("UpsertSessionMeta() error = %v", err) + } + + resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key") + if err != nil { + t.Fatalf("ResolveSessionKey() error = %v", err) + } + if !found { + t.Fatal("ResolveSessionKey() did not find alias") + } + if resolved != "canonical" { + t.Fatalf("resolved = %q, want %q", resolved, "canonical") + } +} + +func TestResolveSessionKeyByAlias_PrefersMetadataOverLegacyFile(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + if err := store.AddMessage(ctx, "legacy:key", "user", "legacy"); err != nil { + t.Fatalf("AddMessage(legacy) error = %v", err) + } + if err := store.AddMessage(ctx, "canonical", "user", "canonical"); err != nil { + t.Fatalf("AddMessage(canonical) error = %v", err) + } + if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil { + t.Fatalf("UpsertSessionMeta() error = %v", err) + } + + resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key") + if err != nil { + t.Fatalf("ResolveSessionKey() error = %v", err) + } + if !found { + t.Fatal("ResolveSessionKey() did not find alias") + } + if resolved != "canonical" { + t.Fatalf("resolved = %q, want %q", resolved, "canonical") + } +} + +func TestResolveSessionKey_DirectHitSkipsCorruptMetadata(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil { + t.Fatalf("AddMessage() error = %v", err) + } + if err := os.WriteFile( + filepath.Join(store.dir, "broken.meta.json"), + []byte("{not-json"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(broken.meta.json) error = %v", err) + } + + resolved, found, err := store.ResolveSessionKey(ctx, "canonical") + if err != nil { + t.Fatalf("ResolveSessionKey() error = %v", err) + } + if !found { + t.Fatal("ResolveSessionKey() did not find direct session") + } + if resolved != "canonical" { + t.Fatalf("resolved = %q, want %q", resolved, "canonical") + } +} + +func TestResolveSessionKey_SkipsCorruptMetadataDuringAliasScan(t *testing.T) { + store := newTestStore(t) + ctx := context.Background() + + if err := store.AddMessage(ctx, "canonical", "user", "hello"); err != nil { + t.Fatalf("AddMessage() error = %v", err) + } + if err := store.UpsertSessionMeta(ctx, "canonical", nil, []string{"legacy:key"}); err != nil { + t.Fatalf("UpsertSessionMeta() error = %v", err) + } + if err := os.WriteFile( + filepath.Join(store.dir, "broken.meta.json"), + []byte("{not-json"), + 0o644, + ); err != nil { + t.Fatalf("WriteFile(broken.meta.json) error = %v", err) + } + + resolved, found, err := store.ResolveSessionKey(ctx, "legacy:key") + if err != nil { + t.Fatalf("ResolveSessionKey() error = %v", err) + } + if !found { + t.Fatal("ResolveSessionKey() did not find alias") + } + if resolved != "canonical" { + t.Fatalf("resolved = %q, want %q", resolved, "canonical") + } +} + func TestTruncateHistory_KeepLast(t *testing.T) { store := newTestStore(t) ctx := context.Background() diff --git a/pkg/routing/route.go b/pkg/routing/route.go index 9eb060c53..023f35a25 100644 --- a/pkg/routing/route.go +++ b/pkg/routing/route.go @@ -1,32 +1,29 @@ package routing import ( + "fmt" "strings" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" ) -// RouteInput contains the routing context from an inbound message. -type RouteInput struct { - Channel string - AccountID string - Peer *RoutePeer - ParentPeer *RoutePeer - GuildID string - TeamID string +// SessionPolicy describes how a routed message should be mapped to a session. +type SessionPolicy struct { + Dimensions []string + IdentityLinks map[string][]string } // ResolvedRoute is the result of agent routing. type ResolvedRoute struct { - AgentID string - Channel string - AccountID string - SessionKey string - MainSessionKey string - MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" + AgentID string + Channel string + AccountID string + SessionPolicy SessionPolicy + MatchedBy string } -// RouteResolver determines which agent handles a message based on config bindings. +// RouteResolver determines which agent handles a message. type RouteResolver struct { cfg *config.Config } @@ -36,182 +33,32 @@ func NewRouteResolver(cfg *config.Config) *RouteResolver { return &RouteResolver{cfg: cfg} } -// ResolveRoute determines which agent handles the message and constructs session keys. -// Implements the 7-level priority cascade: -// peer > parent_peer > guild > team > account > channel_wildcard > default -func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { - channel := strings.ToLower(strings.TrimSpace(input.Channel)) - accountID := NormalizeAccountID(input.AccountID) - peer := input.Peer +// ResolveRoute determines which agent handles the message from a normalized +// inbound context and returns the session policy that should be used to +// allocate session state. +func (r *RouteResolver) ResolveRoute(inbound bus.InboundContext) ResolvedRoute { + channel := strings.ToLower(strings.TrimSpace(inbound.Channel)) + accountID := NormalizeAccountID(inbound.Account) + identityLinks := cloneIdentityLinks(r.cfg.Session.IdentityLinks) + view := buildDispatchView(inbound, identityLinks) - dmScope := DMScope(r.cfg.Session.DMScope) - if dmScope == "" { - dmScope = DMScopeMain - } - identityLinks := r.cfg.Session.IdentityLinks - - bindings := r.filterBindings(channel, accountID) - - choose := func(agentID string, matchedBy string) ResolvedRoute { - resolvedAgentID := r.pickAgentID(agentID) - sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: resolvedAgentID, + if rule := r.matchDispatchRule(view); rule != nil { + return ResolvedRoute{ + AgentID: r.pickAgentID(rule.Agent), Channel: channel, AccountID: accountID, - Peer: peer, - DMScope: dmScope, - IdentityLinks: identityLinks, - })) - mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID)) - return ResolvedRoute{ - AgentID: resolvedAgentID, - Channel: channel, - AccountID: accountID, - SessionKey: sessionKey, - MainSessionKey: mainSessionKey, - MatchedBy: matchedBy, + SessionPolicy: r.sessionPolicy(rule), + MatchedBy: matchedByForRule(rule), } } - // Priority 1: Peer binding - if peer != nil && strings.TrimSpace(peer.ID) != "" { - if match := r.findPeerMatch(bindings, peer); match != nil { - return choose(match.AgentID, "binding.peer") - } + return ResolvedRoute{ + AgentID: r.pickAgentID(r.resolveDefaultAgentID()), + Channel: channel, + AccountID: accountID, + SessionPolicy: r.sessionPolicy(nil), + MatchedBy: "default", } - - // Priority 2: Parent peer binding - parentPeer := input.ParentPeer - if parentPeer != nil && strings.TrimSpace(parentPeer.ID) != "" { - if match := r.findPeerMatch(bindings, parentPeer); match != nil { - return choose(match.AgentID, "binding.peer.parent") - } - } - - // Priority 3: Guild binding - guildID := strings.TrimSpace(input.GuildID) - if guildID != "" { - if match := r.findGuildMatch(bindings, guildID); match != nil { - return choose(match.AgentID, "binding.guild") - } - } - - // Priority 4: Team binding - teamID := strings.TrimSpace(input.TeamID) - if teamID != "" { - if match := r.findTeamMatch(bindings, teamID); match != nil { - return choose(match.AgentID, "binding.team") - } - } - - // Priority 5: Account binding - if match := r.findAccountMatch(bindings); match != nil { - return choose(match.AgentID, "binding.account") - } - - // Priority 6: Channel wildcard binding - if match := r.findChannelWildcardMatch(bindings); match != nil { - return choose(match.AgentID, "binding.channel") - } - - // Priority 7: Default agent - return choose(r.resolveDefaultAgentID(), "default") -} - -func (r *RouteResolver) filterBindings(channel, accountID string) []config.AgentBinding { - var filtered []config.AgentBinding - for _, b := range r.cfg.Bindings { - matchChannel := strings.ToLower(strings.TrimSpace(b.Match.Channel)) - if matchChannel == "" || matchChannel != channel { - continue - } - if !matchesAccountID(b.Match.AccountID, accountID) { - continue - } - filtered = append(filtered, b) - } - return filtered -} - -func matchesAccountID(matchAccountID, actual string) bool { - trimmed := strings.TrimSpace(matchAccountID) - if trimmed == "" { - return actual == DefaultAccountID - } - if trimmed == "*" { - return true - } - return strings.ToLower(trimmed) == strings.ToLower(actual) -} - -func (r *RouteResolver) findPeerMatch(bindings []config.AgentBinding, peer *RoutePeer) *config.AgentBinding { - for i := range bindings { - b := &bindings[i] - if b.Match.Peer == nil { - continue - } - peerKind := strings.ToLower(strings.TrimSpace(b.Match.Peer.Kind)) - peerID := strings.TrimSpace(b.Match.Peer.ID) - if peerKind == "" || peerID == "" { - continue - } - if peerKind == strings.ToLower(peer.Kind) && peerID == peer.ID { - return b - } - } - return nil -} - -func (r *RouteResolver) findGuildMatch(bindings []config.AgentBinding, guildID string) *config.AgentBinding { - for i := range bindings { - b := &bindings[i] - matchGuild := strings.TrimSpace(b.Match.GuildID) - if matchGuild != "" && matchGuild == guildID { - return &bindings[i] - } - } - return nil -} - -func (r *RouteResolver) findTeamMatch(bindings []config.AgentBinding, teamID string) *config.AgentBinding { - for i := range bindings { - b := &bindings[i] - matchTeam := strings.TrimSpace(b.Match.TeamID) - if matchTeam != "" && matchTeam == teamID { - return &bindings[i] - } - } - return nil -} - -func (r *RouteResolver) findAccountMatch(bindings []config.AgentBinding) *config.AgentBinding { - for i := range bindings { - b := &bindings[i] - accountID := strings.TrimSpace(b.Match.AccountID) - if accountID == "*" { - continue - } - if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { - continue - } - return &bindings[i] - } - return nil -} - -func (r *RouteResolver) findChannelWildcardMatch(bindings []config.AgentBinding) *config.AgentBinding { - for i := range bindings { - b := &bindings[i] - accountID := strings.TrimSpace(b.Match.AccountID) - if accountID != "*" { - continue - } - if b.Match.Peer != nil || b.Match.GuildID != "" || b.Match.TeamID != "" { - continue - } - return &bindings[i] - } - return nil } func (r *RouteResolver) pickAgentID(agentID string) string { @@ -250,3 +97,217 @@ func (r *RouteResolver) resolveDefaultAgentID() string { } return DefaultAgentID } + +func (r *RouteResolver) sessionPolicy(rule *config.DispatchRule) SessionPolicy { + dimensions := r.cfg.Session.Dimensions + if rule != nil && len(rule.SessionDimensions) > 0 { + dimensions = rule.SessionDimensions + } + return SessionPolicy{ + Dimensions: normalizeSessionDimensions(dimensions), + IdentityLinks: cloneIdentityLinks(r.cfg.Session.IdentityLinks), + } +} + +func normalizeSessionDimensions(dimensions []string) []string { + if len(dimensions) == 0 { + return nil + } + + normalized := make([]string, 0, len(dimensions)) + seen := make(map[string]struct{}, len(dimensions)) + for _, dimension := range dimensions { + dimension = strings.ToLower(strings.TrimSpace(dimension)) + switch dimension { + case "space", "chat", "topic", "sender": + default: + continue + } + if _, ok := seen[dimension]; ok { + continue + } + seen[dimension] = struct{}{} + normalized = append(normalized, dimension) + } + if len(normalized) == 0 { + return nil + } + return normalized +} + +func cloneIdentityLinks(src map[string][]string) map[string][]string { + if len(src) == 0 { + return nil + } + cloned := make(map[string][]string, len(src)) + for canonical, ids := range src { + dup := make([]string, len(ids)) + copy(dup, ids) + cloned[canonical] = dup + } + return cloned +} + +type dispatchView struct { + Channel string + Account string + Space string + Chat string + Topic string + Sender string + Mentioned bool +} + +func (r *RouteResolver) matchDispatchRule(view dispatchView) *config.DispatchRule { + if r.cfg == nil || r.cfg.Agents.Dispatch == nil || len(r.cfg.Agents.Dispatch.Rules) == 0 { + return nil + } + + for i := range r.cfg.Agents.Dispatch.Rules { + rule := &r.cfg.Agents.Dispatch.Rules[i] + if !selectorHasAnyConstraint(rule.When) { + continue + } + if ruleMatchesView(*rule, view) { + return rule + } + } + return nil +} + +func ruleMatchesView(rule config.DispatchRule, view dispatchView) bool { + when := normalizeDispatchSelector(rule.When) + if when.Channel != "" && when.Channel != view.Channel { + return false + } + if when.Account != "" && when.Account != view.Account { + return false + } + if when.Space != "" && when.Space != view.Space { + return false + } + if when.Chat != "" && when.Chat != view.Chat { + return false + } + if when.Topic != "" && when.Topic != view.Topic { + return false + } + if when.Sender != "" && when.Sender != view.Sender { + return false + } + if when.Mentioned != nil && *when.Mentioned != view.Mentioned { + return false + } + return true +} + +func matchedByForRule(rule *config.DispatchRule) string { + if rule == nil { + return "default" + } + name := strings.TrimSpace(rule.Name) + if name == "" { + return "dispatch.rule" + } + return "dispatch.rule:" + strings.ToLower(name) +} + +func buildDispatchView(inbound bus.InboundContext, identityLinks map[string][]string) dispatchView { + view := dispatchView{ + Channel: strings.ToLower(strings.TrimSpace(inbound.Channel)), + Account: NormalizeAccountID(inbound.Account), + Mentioned: inbound.Mentioned, + } + + if spaceID := strings.TrimSpace(inbound.SpaceID); spaceID != "" { + spaceType := strings.ToLower(strings.TrimSpace(inbound.SpaceType)) + if spaceType == "" { + spaceType = "space" + } + view.Space = fmt.Sprintf("%s:%s", spaceType, strings.ToLower(spaceID)) + } + + if chatID := strings.TrimSpace(inbound.ChatID); chatID != "" { + chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType)) + if chatType == "" { + chatType = "direct" + } + view.Chat = fmt.Sprintf("%s:%s", chatType, strings.ToLower(chatID)) + } + + if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" { + view.Topic = "topic:" + strings.ToLower(topicID) + } + + view.Sender = canonicalDispatchSenderID(inbound.Channel, inbound.SenderID, identityLinks) + + return view +} + +func normalizeDispatchSelector(selector config.DispatchSelector) config.DispatchSelector { + selector.Channel = strings.ToLower(strings.TrimSpace(selector.Channel)) + selector.Account = NormalizeAccountID(selector.Account) + selector.Space = strings.ToLower(strings.TrimSpace(selector.Space)) + selector.Chat = strings.ToLower(strings.TrimSpace(selector.Chat)) + selector.Topic = strings.ToLower(strings.TrimSpace(selector.Topic)) + selector.Sender = strings.ToLower(strings.TrimSpace(selector.Sender)) + return selector +} + +func selectorHasAnyConstraint(selector config.DispatchSelector) bool { + return strings.TrimSpace(selector.Channel) != "" || + strings.TrimSpace(selector.Account) != "" || + strings.TrimSpace(selector.Space) != "" || + strings.TrimSpace(selector.Chat) != "" || + strings.TrimSpace(selector.Topic) != "" || + strings.TrimSpace(selector.Sender) != "" || + selector.Mentioned != nil +} + +func canonicalDispatchSenderID(channel, rawID string, identityLinks map[string][]string) string { + normalizedID := strings.TrimSpace(rawID) + if normalizedID == "" { + return "" + } + if linked := resolveLinkedDispatchID(identityLinks, channel, normalizedID); linked != "" { + normalizedID = linked + } + return strings.ToLower(normalizedID) +} + +func resolveLinkedDispatchID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true + } + if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + candidates[rawCandidate[idx+1:]] = true + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go index fdfc899f9..729e880fe 100644 --- a/pkg/routing/route_test.go +++ b/pkg/routing/route_test.go @@ -3,10 +3,11 @@ package routing import ( "testing" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" ) -func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *config.Config { +func testConfig(agents []config.AgentConfig) *config.Config { return &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ @@ -15,20 +16,20 @@ func testConfig(agents []config.AgentConfig, bindings []config.AgentBinding) *co }, List: agents, }, - Bindings: bindings, Session: config.SessionConfig{ - DMScope: "per-peer", + Dimensions: []string{"sender"}, }, } } func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { - cfg := testConfig(nil, nil) + cfg := testConfig(nil) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + ChatType: "direct", + SenderID: "user1", }) if route.AgentID != DefaultAgentID { @@ -37,202 +38,152 @@ func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { if route.MatchedBy != "default" { t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) } + if len(route.SessionPolicy.Dimensions) != 1 || route.SessionPolicy.Dimensions[0] != "sender" { + t.Errorf("SessionPolicy.Dimensions = %v, want [sender]", route.SessionPolicy.Dimensions) + } + if route.SessionPolicy.IdentityLinks != nil { + t.Errorf("SessionPolicy.IdentityLinks = %v, want nil", route.SessionPolicy.IdentityLinks) + } } -func TestResolveRoute_PeerBinding(t *testing.T) { - agents := []config.AgentConfig{ - {ID: "sales", Default: true}, - {ID: "support"}, +func TestResolveRoute_UsesNormalizedInboundContextFields(t *testing.T) { + cfg := testConfig([]config.AgentConfig{{ID: "sales", Default: true}}) + r := NewRouteResolver(cfg) + + route := r.ResolveRoute(bus.InboundContext{ + Channel: "Telegram", + Account: "Bot2", + ChatType: "direct", + SenderID: "user123", + }) + + if route.AgentID != "sales" { + t.Errorf("AgentID = %q, want 'sales'", route.AgentID) } - bindings := []config.AgentBinding{ - { - AgentID: "support", - Match: config.BindingMatch{ - Channel: "telegram", - AccountID: "*", - Peer: &config.PeerMatch{Kind: "direct", ID: "user123"}, + if route.Channel != "telegram" { + t.Errorf("Channel = %q, want 'telegram'", route.Channel) + } + if route.AccountID != "bot2" { + t.Errorf("AccountID = %q, want 'bot2'", route.AccountID) + } + if route.MatchedBy != "default" { + t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) + } +} + +func TestResolveRoute_DispatchFirstMatchWins(t *testing.T) { + cfg := testConfig([]config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "support"}, + {ID: "sales"}, + }) + cfg.Agents.Dispatch = &config.DispatchConfig{ + Rules: []config.DispatchRule{ + { + Name: "support-group", + Agent: "support", + When: config.DispatchSelector{ + Channel: "telegram", + Chat: "group:-100123", + }, + }, + { + Name: "vip-in-group", + Agent: "sales", + When: config.DispatchSelector{ + Channel: "telegram", + Chat: "group:-100123", + Sender: "12345", + }, }, }, } - cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + ChatID: "-100123", + ChatType: "group", + SenderID: "12345", }) if route.AgentID != "support" { - t.Errorf("AgentID = %q, want 'support'", route.AgentID) + t.Fatalf("AgentID = %q, want support", route.AgentID) } - if route.MatchedBy != "binding.peer" { - t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + if route.MatchedBy != "dispatch.rule:support-group" { + t.Fatalf("MatchedBy = %q, want dispatch.rule:support-group", route.MatchedBy) } } -func TestResolveRoute_GuildBinding(t *testing.T) { - agents := []config.AgentConfig{ - {ID: "general", Default: true}, - {ID: "gaming"}, - } - bindings := []config.AgentBinding{ - { - AgentID: "gaming", - Match: config.BindingMatch{ - Channel: "discord", - AccountID: "*", - GuildID: "guild-abc", - }, - }, - } - cfg := testConfig(agents, bindings) - r := NewRouteResolver(cfg) - - route := r.ResolveRoute(RouteInput{ - Channel: "discord", - GuildID: "guild-abc", - Peer: &RoutePeer{Kind: "channel", ID: "ch1"}, - }) - - if route.AgentID != "gaming" { - t.Errorf("AgentID = %q, want 'gaming'", route.AgentID) - } - if route.MatchedBy != "binding.guild" { - t.Errorf("MatchedBy = %q, want 'binding.guild'", route.MatchedBy) - } -} - -func TestResolveRoute_TeamBinding(t *testing.T) { - agents := []config.AgentConfig{ - {ID: "general", Default: true}, - {ID: "work"}, - } - bindings := []config.AgentBinding{ - { - AgentID: "work", - Match: config.BindingMatch{ - Channel: "slack", - AccountID: "*", - TeamID: "T12345", - }, - }, - } - cfg := testConfig(agents, bindings) - r := NewRouteResolver(cfg) - - route := r.ResolveRoute(RouteInput{ - Channel: "slack", - TeamID: "T12345", - Peer: &RoutePeer{Kind: "channel", ID: "C001"}, - }) - - if route.AgentID != "work" { - t.Errorf("AgentID = %q, want 'work'", route.AgentID) - } - if route.MatchedBy != "binding.team" { - t.Errorf("MatchedBy = %q, want 'binding.team'", route.MatchedBy) - } -} - -func TestResolveRoute_AccountBinding(t *testing.T) { - agents := []config.AgentConfig{ - {ID: "default-agent", Default: true}, - {ID: "premium"}, - } - bindings := []config.AgentBinding{ - { - AgentID: "premium", - Match: config.BindingMatch{ - Channel: "telegram", - AccountID: "bot2", - }, - }, - } - cfg := testConfig(agents, bindings) - r := NewRouteResolver(cfg) - - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - AccountID: "bot2", - Peer: &RoutePeer{Kind: "direct", ID: "user1"}, - }) - - if route.AgentID != "premium" { - t.Errorf("AgentID = %q, want 'premium'", route.AgentID) - } - if route.MatchedBy != "binding.account" { - t.Errorf("MatchedBy = %q, want 'binding.account'", route.MatchedBy) - } -} - -func TestResolveRoute_ChannelWildcard(t *testing.T) { - agents := []config.AgentConfig{ +func TestResolveRoute_DispatchOverridesSessionDimensions(t *testing.T) { + cfg := testConfig([]config.AgentConfig{ {ID: "main", Default: true}, - {ID: "telegram-bot"}, - } - bindings := []config.AgentBinding{ - { - AgentID: "telegram-bot", - Match: config.BindingMatch{ - Channel: "telegram", - AccountID: "*", + {ID: "support"}, + }) + cfg.Session.Dimensions = []string{"chat"} + cfg.Agents.Dispatch = &config.DispatchConfig{ + Rules: []config.DispatchRule{ + { + Name: "support-dm", + Agent: "support", + When: config.DispatchSelector{ + Channel: "telegram", + Chat: "direct:user-1", + }, + SessionDimensions: []string{"chat", "sender"}, }, }, } - cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user1"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "telegram", + ChatID: "user-1", + ChatType: "direct", + SenderID: "user-1", }) - if route.AgentID != "telegram-bot" { - t.Errorf("AgentID = %q, want 'telegram-bot'", route.AgentID) + if route.AgentID != "support" { + t.Fatalf("AgentID = %q, want support", route.AgentID) } - if route.MatchedBy != "binding.channel" { - t.Errorf("MatchedBy = %q, want 'binding.channel'", route.MatchedBy) + if got := route.SessionPolicy.Dimensions; len(got) != 2 || got[0] != "chat" || got[1] != "sender" { + t.Fatalf("SessionPolicy.Dimensions = %v, want [chat sender]", got) } } -func TestResolveRoute_PriorityOrder_PeerBeatsGuild(t *testing.T) { - agents := []config.AgentConfig{ - {ID: "general", Default: true}, - {ID: "vip"}, - {ID: "gaming"}, - } - bindings := []config.AgentBinding{ - { - AgentID: "vip", - Match: config.BindingMatch{ - Channel: "discord", - AccountID: "*", - Peer: &config.PeerMatch{Kind: "direct", ID: "user-vip"}, - }, - }, - { - AgentID: "gaming", - Match: config.BindingMatch{ - Channel: "discord", - AccountID: "*", - GuildID: "guild-1", +func TestResolveRoute_DispatchMentionedRule(t *testing.T) { + cfg := testConfig([]config.AgentConfig{ + {ID: "main", Default: true}, + {ID: "support"}, + }) + mentioned := true + cfg.Agents.Dispatch = &config.DispatchConfig{ + Rules: []config.DispatchRule{ + { + Name: "slack-mentions", + Agent: "support", + When: config.DispatchSelector{ + Channel: "slack", + Space: "workspace:t001", + Mentioned: &mentioned, + }, }, }, } - cfg := testConfig(agents, bindings) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "discord", - GuildID: "guild-1", - Peer: &RoutePeer{Kind: "direct", ID: "user-vip"}, + route := r.ResolveRoute(bus.InboundContext{ + Channel: "slack", + ChatID: "C123", + ChatType: "channel", + SpaceID: "T001", + SpaceType: "workspace", + SenderID: "U123", + Mentioned: true, }) - if route.AgentID != "vip" { - t.Errorf("AgentID = %q, want 'vip' (peer should beat guild)", route.AgentID) - } - if route.MatchedBy != "binding.peer" { - t.Errorf("MatchedBy = %q, want 'binding.peer'", route.MatchedBy) + if route.AgentID != "support" { + t.Fatalf("AgentID = %q, want support", route.AgentID) } } @@ -240,21 +191,10 @@ func TestResolveRoute_InvalidAgentFallsToDefault(t *testing.T) { agents := []config.AgentConfig{ {ID: "main", Default: true}, } - bindings := []config.AgentBinding{ - { - AgentID: "nonexistent", - Match: config.BindingMatch{ - Channel: "telegram", - AccountID: "*", - }, - }, - } - cfg := testConfig(agents, bindings) + cfg := testConfig(agents) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "telegram", - }) + route := r.ResolveRoute(bus.InboundContext{Channel: "telegram"}) if route.AgentID != "main" { t.Errorf("AgentID = %q, want 'main' (invalid agent should fall to default)", route.AgentID) @@ -267,12 +207,10 @@ func TestResolveRoute_DefaultAgentSelection(t *testing.T) { {ID: "beta", Default: true}, {ID: "gamma"}, } - cfg := testConfig(agents, nil) + cfg := testConfig(agents) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "cli", - }) + route := r.ResolveRoute(bus.InboundContext{Channel: "cli"}) if route.AgentID != "beta" { t.Errorf("AgentID = %q, want 'beta' (marked as default)", route.AgentID) @@ -284,12 +222,10 @@ func TestResolveRoute_NoDefaultUsesFirst(t *testing.T) { {ID: "alpha"}, {ID: "beta"}, } - cfg := testConfig(agents, nil) + cfg := testConfig(agents) r := NewRouteResolver(cfg) - route := r.ResolveRoute(RouteInput{ - Channel: "cli", - }) + route := r.ResolveRoute(bus.InboundContext{Channel: "cli"}) if route.AgentID != "alpha" { t.Errorf("AgentID = %q, want 'alpha' (first in list)", route.AgentID) diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go deleted file mode 100644 index eab592bec..000000000 --- a/pkg/routing/session_key.go +++ /dev/null @@ -1,192 +0,0 @@ -package routing - -import ( - "fmt" - "strings" -) - -// DMScope controls DM session isolation granularity. -type DMScope string - -const ( - DMScopeMain DMScope = "main" - DMScopePerPeer DMScope = "per-peer" - DMScopePerChannelPeer DMScope = "per-channel-peer" - DMScopePerAccountChannelPeer DMScope = "per-account-channel-peer" -) - -// RoutePeer represents a chat peer with kind and ID. -type RoutePeer struct { - Kind string // "direct", "group", "channel" - ID string -} - -// SessionKeyParams holds all inputs for session key construction. -type SessionKeyParams struct { - AgentID string - Channel string - AccountID string - Peer *RoutePeer - DMScope DMScope - IdentityLinks map[string][]string -} - -// ParsedSessionKey is the result of parsing an agent-scoped session key. -type ParsedSessionKey struct { - AgentID string - Rest string -} - -// BuildAgentMainSessionKey returns "agent::main". -func BuildAgentMainSessionKey(agentID string) string { - return fmt.Sprintf("agent:%s:%s", NormalizeAgentID(agentID), DefaultMainKey) -} - -// BuildAgentPeerSessionKey constructs a session key based on agent, channel, peer, and DM scope. -func BuildAgentPeerSessionKey(params SessionKeyParams) string { - agentID := NormalizeAgentID(params.AgentID) - - peer := params.Peer - if peer == nil { - peer = &RoutePeer{Kind: "direct"} - } - peerKind := strings.TrimSpace(peer.Kind) - if peerKind == "" { - peerKind = "direct" - } - - if peerKind == "direct" { - dmScope := params.DMScope - if dmScope == "" { - dmScope = DMScopeMain - } - peerID := strings.TrimSpace(peer.ID) - - // Resolve identity links (cross-platform collapse) - if dmScope != DMScopeMain && peerID != "" { - if linked := resolveLinkedPeerID(params.IdentityLinks, params.Channel, peerID); linked != "" { - peerID = linked - } - } - peerID = strings.ToLower(peerID) - - switch dmScope { - case DMScopePerAccountChannelPeer: - if peerID != "" { - channel := normalizeChannel(params.Channel) - accountID := NormalizeAccountID(params.AccountID) - return fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, accountID, peerID) - } - case DMScopePerChannelPeer: - if peerID != "" { - channel := normalizeChannel(params.Channel) - return fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID) - } - case DMScopePerPeer: - if peerID != "" { - return fmt.Sprintf("agent:%s:direct:%s", agentID, peerID) - } - } - return BuildAgentMainSessionKey(agentID) - } - - // Group/channel peers always get per-peer sessions - channel := normalizeChannel(params.Channel) - peerID := strings.ToLower(strings.TrimSpace(peer.ID)) - if peerID == "" { - peerID = "unknown" - } - return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) -} - -// ParseAgentSessionKey extracts agentId and rest from "agent::". -func ParseAgentSessionKey(sessionKey string) *ParsedSessionKey { - raw := strings.TrimSpace(sessionKey) - if raw == "" { - return nil - } - parts := strings.SplitN(raw, ":", 3) - if len(parts) < 3 { - return nil - } - if parts[0] != "agent" { - return nil - } - agentID := strings.TrimSpace(parts[1]) - rest := parts[2] - if agentID == "" || rest == "" { - return nil - } - return &ParsedSessionKey{AgentID: agentID, Rest: rest} -} - -// IsSubagentSessionKey returns true if the session key represents a subagent. -func IsSubagentSessionKey(sessionKey string) bool { - raw := strings.TrimSpace(sessionKey) - if raw == "" { - return false - } - if strings.HasPrefix(strings.ToLower(raw), "subagent:") { - return true - } - parsed := ParseAgentSessionKey(raw) - if parsed == nil { - return false - } - return strings.HasPrefix(strings.ToLower(parsed.Rest), "subagent:") -} - -func normalizeChannel(channel string) string { - c := strings.TrimSpace(strings.ToLower(channel)) - if c == "" { - return "unknown" - } - return c -} - -func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { - if len(identityLinks) == 0 { - return "" - } - peerID = strings.TrimSpace(peerID) - if peerID == "" { - return "" - } - - candidates := make(map[string]bool) - rawCandidate := strings.ToLower(peerID) - if rawCandidate != "" { - candidates[rawCandidate] = true - } - channel = strings.ToLower(strings.TrimSpace(channel)) - if channel != "" { - scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) - candidates[scopedCandidate] = true - } - - // If peerID is already in canonical "platform:id" format, also add the - // bare ID part as a candidate for backward compatibility with identity_links - // that use raw IDs (e.g. "123" instead of "telegram:123"). - if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { - bareID := rawCandidate[idx+1:] - candidates[bareID] = true - } - - if len(candidates) == 0 { - return "" - } - - for canonical, ids := range identityLinks { - canonicalName := strings.TrimSpace(canonical) - if canonicalName == "" { - continue - } - for _, id := range ids { - normalized := strings.ToLower(strings.TrimSpace(id)) - if normalized != "" && candidates[normalized] { - return canonicalName - } - } - } - return "" -} diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go deleted file mode 100644 index ad7a1ca02..000000000 --- a/pkg/routing/session_key_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package routing - -import "testing" - -func TestBuildAgentMainSessionKey(t *testing.T) { - got := BuildAgentMainSessionKey("sales") - want := "agent:sales:main" - if got != want { - t.Errorf("BuildAgentMainSessionKey('sales') = %q, want %q", got, want) - } -} - -func TestBuildAgentMainSessionKey_Normalizes(t *testing.T) { - got := BuildAgentMainSessionKey("Sales Bot") - want := "agent:sales-bot:main" - if got != want { - t.Errorf("BuildAgentMainSessionKey('Sales Bot') = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopeMain(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopeMain, - }) - want := "agent:main:main" - if got != want { - t.Errorf("DMScopeMain = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopePerPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopePerPeer, - }) - want := "agent:main:direct:user123" - if got != want { - t.Errorf("DMScopePerPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopePerChannelPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopePerChannelPeer, - }) - want := "agent:main:telegram:direct:user123" - if got != want { - t.Errorf("DMScopePerChannelPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_DMScopePerAccountChannelPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - AccountID: "bot1", - Peer: &RoutePeer{Kind: "direct", ID: "User123"}, - DMScope: DMScopePerAccountChannelPeer, - }) - want := "agent:main:telegram:bot1:direct:user123" - if got != want { - t.Errorf("DMScopePerAccountChannelPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_GroupPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "group", ID: "chat456"}, - DMScope: DMScopePerPeer, - }) - want := "agent:main:telegram:group:chat456" - if got != want { - t.Errorf("GroupPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_NilPeer(t *testing.T) { - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: nil, - DMScope: DMScopePerPeer, - }) - // nil peer defaults to direct with empty ID, falls to main - want := "agent:main:main" - if got != want { - t.Errorf("NilPeer = %q, want %q", got, want) - } -} - -func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { - links := map[string][]string{ - "john": {"telegram:user123", "discord:john#1234"}, - } - got := BuildAgentPeerSessionKey(SessionKeyParams{ - AgentID: "main", - Channel: "telegram", - Peer: &RoutePeer{Kind: "direct", ID: "user123"}, - DMScope: DMScopePerPeer, - IdentityLinks: links, - }) - want := "agent:main:direct:john" - if got != want { - t.Errorf("IdentityLink = %q, want %q", got, want) - } -} - -func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) { - // When peerID is already in canonical "platform:id" format, - // it should match identity_links that use the bare ID. - links := map[string][]string{ - "john": {"123"}, - } - got := resolveLinkedPeerID(links, "telegram", "telegram:123") - if got != "john" { - t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john") - } -} - -func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) { - // When identity_links contain canonical IDs and peerID is canonical too - links := map[string][]string{ - "john": {"telegram:123", "discord:456"}, - } - got := resolveLinkedPeerID(links, "telegram", "telegram:123") - if got != "john" { - t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john") - } -} - -func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) { - // When peerID is bare "123" and links have "telegram:123", - // the scoped candidate "telegram:123" should match. - links := map[string][]string{ - "john": {"telegram:123"}, - } - got := resolveLinkedPeerID(links, "telegram", "123") - if got != "john" { - t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john") - } -} - -func TestResolveLinkedPeerID_NoMatch(t *testing.T) { - links := map[string][]string{ - "john": {"telegram:123"}, - } - got := resolveLinkedPeerID(links, "discord", "999") - if got != "" { - t.Errorf("resolveLinkedPeerID no match = %q, want empty", got) - } -} - -func TestParseAgentSessionKey_Valid(t *testing.T) { - parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") - if parsed == nil { - t.Fatal("expected non-nil result") - } - if parsed.AgentID != "sales" { - t.Errorf("AgentID = %q, want 'sales'", parsed.AgentID) - } - if parsed.Rest != "telegram:direct:user123" { - t.Errorf("Rest = %q, want 'telegram:direct:user123'", parsed.Rest) - } -} - -func TestParseAgentSessionKey_Invalid(t *testing.T) { - tests := []string{ - "", - "foo:bar", - "notprefix:sales:main", - "agent::main", - "agent:sales:", - } - for _, input := range tests { - if got := ParseAgentSessionKey(input); got != nil { - t.Errorf("ParseAgentSessionKey(%q) = %+v, want nil", input, got) - } - } -} - -func TestIsSubagentSessionKey(t *testing.T) { - tests := []struct { - input string - want bool - }{ - {"subagent:task-1", true}, - {"agent:main:subagent:task-1", true}, - {"agent:main:main", false}, - {"agent:main:telegram:direct:user123", false}, - {"", false}, - } - for _, tt := range tests { - if got := IsSubagentSessionKey(tt.input); got != tt.want { - t.Errorf("IsSubagentSessionKey(%q) = %v, want %v", tt.input, got, tt.want) - } - } -} diff --git a/pkg/session/allocator.go b/pkg/session/allocator.go new file mode 100644 index 000000000..509550cb2 --- /dev/null +++ b/pkg/session/allocator.go @@ -0,0 +1,213 @@ +package session + +import ( + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/routing" +) + +// Allocation contains the concrete session keys selected for a routed turn. +// The current implementation intentionally preserves the legacy session-key +// layout while moving key construction out of the router. +type Allocation struct { + Scope SessionScope + SessionKey string + SessionAliases []string + MainSessionKey string + MainAliases []string +} + +// AllocationInput contains the routing result and peer context needed to +// derive the session keys for a turn. +type AllocationInput struct { + AgentID string + Context bus.InboundContext + SessionPolicy routing.SessionPolicy +} + +// AllocateRouteSession maps a route decision onto a structured scope and the +// current opaque session-key format. +func AllocateRouteSession(input AllocationInput) Allocation { + scope := buildSessionScope(input) + legacySessionAliases := buildLegacySessionAliases(input) + legacyMainSessionKey := strings.ToLower(BuildLegacyMainAlias(input.AgentID)) + return Allocation{ + Scope: scope, + SessionKey: BuildSessionKey(scope), + SessionAliases: legacySessionAliases, + MainSessionKey: BuildOpaqueSessionKey(legacyMainSessionKey), + MainAliases: []string{legacyMainSessionKey}, + } +} + +func buildSessionScope(input AllocationInput) SessionScope { + inbound := input.Context + includeTopicInChatDimension := shouldPreserveTelegramForumIsolation(input) + scope := SessionScope{ + Version: ScopeVersionV1, + AgentID: routing.NormalizeAgentID(input.AgentID), + Channel: strings.ToLower(strings.TrimSpace(inbound.Channel)), + Account: routing.NormalizeAccountID(inbound.Account), + } + if scope.Channel == "" { + scope.Channel = "unknown" + } + + dimensions := make([]string, 0, len(input.SessionPolicy.Dimensions)) + values := make(map[string]string, len(input.SessionPolicy.Dimensions)) + + for _, dimension := range input.SessionPolicy.Dimensions { + switch dimension { + case "space": + if spaceID := strings.TrimSpace(inbound.SpaceID); spaceID != "" { + spaceType := strings.ToLower(strings.TrimSpace(inbound.SpaceType)) + if spaceType == "" { + spaceType = "space" + } + dimensions = append(dimensions, "space") + values["space"] = fmt.Sprintf("%s:%s", spaceType, strings.ToLower(spaceID)) + } + case "chat": + chatID := strings.TrimSpace(inbound.ChatID) + if chatID == "" { + continue + } + if includeTopicInChatDimension { + if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" { + chatID = chatID + "/" + topicID + } + } + chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType)) + if chatType == "" { + chatType = "direct" + } + dimensions = append(dimensions, "chat") + values["chat"] = fmt.Sprintf("%s:%s", chatType, strings.ToLower(chatID)) + case "topic": + if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" { + dimensions = append(dimensions, "topic") + values["topic"] = "topic:" + strings.ToLower(topicID) + } + case "sender": + senderID := CanonicalSessionIdentityID( + inbound.Channel, + inbound.SenderID, + input.SessionPolicy.IdentityLinks, + ) + if senderID == "" { + continue + } + dimensions = append(dimensions, "sender") + values["sender"] = senderID + } + } + + if len(dimensions) > 0 { + scope.Dimensions = dimensions + scope.Values = values + } + + return scope +} + +func buildLegacySessionAliases(input AllocationInput) []string { + aliases := []string{strings.ToLower(BuildLegacyMainAlias(input.AgentID))} + inbound := input.Context + + if strings.EqualFold(strings.TrimSpace(inbound.ChatType), "direct") { + peerIDs := buildLegacyDirectPeerIDs(input) + if len(peerIDs) == 0 { + return uniqueAliases(aliases) + } + for _, peerID := range peerIDs { + aliases = append( + aliases, + BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, peerID)..., + ) + } + return uniqueAliases(aliases) + } + + peerID := strings.TrimSpace(inbound.ChatID) + if peerID == "" { + return uniqueAliases(aliases) + } + if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" { + peerID = peerID + "/" + topicID + } + aliases = append(aliases, BuildLegacyPeerAlias( + input.AgentID, + inbound.Channel, + strings.ToLower(strings.TrimSpace(inbound.ChatType)), + peerID, + )) + + return uniqueAliases(aliases) +} + +func shouldPreserveTelegramForumIsolation(input AllocationInput) bool { + inbound := input.Context + if !strings.EqualFold(strings.TrimSpace(inbound.Channel), "telegram") { + return false + } + if strings.TrimSpace(inbound.TopicID) == "" { + return false + } + for _, dimension := range input.SessionPolicy.Dimensions { + if strings.EqualFold(strings.TrimSpace(dimension), "topic") { + return false + } + } + return true +} + +func buildLegacyDirectPeerIDs(input AllocationInput) []string { + inbound := input.Context + peerIDs := make([]string, 0, 3) + + rawSenderID := strings.TrimSpace(inbound.SenderID) + if rawSenderID != "" { + peerIDs = append(peerIDs, strings.ToLower(rawSenderID)) + } + + canonicalSenderID := CanonicalSessionIdentityID( + inbound.Channel, + inbound.SenderID, + input.SessionPolicy.IdentityLinks, + ) + if canonicalSenderID != "" { + peerIDs = append(peerIDs, canonicalSenderID) + } + + chatID := strings.TrimSpace(inbound.ChatID) + if chatID != "" { + peerIDs = append(peerIDs, strings.ToLower(chatID)) + } + + return uniqueAliases(peerIDs) +} + +func uniqueAliases(aliases []string) []string { + if len(aliases) == 0 { + return nil + } + normalized := make([]string, 0, len(aliases)) + seen := make(map[string]struct{}, len(aliases)) + for _, alias := range aliases { + alias = strings.TrimSpace(strings.ToLower(alias)) + if alias == "" { + continue + } + if _, ok := seen[alias]; ok { + continue + } + seen[alias] = struct{}{} + normalized = append(normalized, alias) + } + if len(normalized) == 0 { + return nil + } + return normalized +} diff --git a/pkg/session/allocator_test.go b/pkg/session/allocator_test.go new file mode 100644 index 000000000..9750ffc39 --- /dev/null +++ b/pkg/session/allocator_test.go @@ -0,0 +1,160 @@ +package session + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/routing" +) + +func TestAllocateRouteSession_PerPeerDM(t *testing.T) { + allocation := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "telegram", + Account: "default", + ChatID: "dm-123", + ChatType: "direct", + SenderID: "User123", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + }) + + if allocation.SessionKey == "" || !IsOpaqueSessionKey(allocation.SessionKey) { + t.Fatalf("SessionKey = %q, want opaque session key", allocation.SessionKey) + } + if !containsAlias(allocation.SessionAliases, "agent:main:direct:user123") { + t.Fatalf("SessionAliases = %v, want to contain agent:main:direct:user123", allocation.SessionAliases) + } + if allocation.MainSessionKey == "" || !IsOpaqueSessionKey(allocation.MainSessionKey) { + t.Fatalf("MainSessionKey = %q, want opaque session key", allocation.MainSessionKey) + } + if len(allocation.MainAliases) != 1 || allocation.MainAliases[0] != "agent:main:main" { + t.Fatalf("MainAliases = %v, want [agent:main:main]", allocation.MainAliases) + } + if allocation.Scope.Version != ScopeVersionV1 { + t.Fatalf("Scope.Version = %d, want %d", allocation.Scope.Version, ScopeVersionV1) + } + if len(allocation.Scope.Dimensions) != 1 || allocation.Scope.Dimensions[0] != "sender" { + t.Fatalf("Scope.Dimensions = %v, want [sender]", allocation.Scope.Dimensions) + } + if allocation.Scope.Values["sender"] != "user123" { + t.Fatalf("Scope.Values[sender] = %q, want user123", allocation.Scope.Values["sender"]) + } +} + +func TestAllocateRouteSession_GroupPeer(t *testing.T) { + allocation := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "slack", + Account: "workspace-a", + ChatID: "C001", + ChatType: "channel", + SenderID: "U001", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat"}, + }, + }) + + if allocation.SessionKey == "" || !IsOpaqueSessionKey(allocation.SessionKey) { + t.Fatalf("SessionKey = %q, want opaque session key", allocation.SessionKey) + } + if !containsAlias(allocation.SessionAliases, "agent:main:slack:channel:c001") { + t.Fatalf("SessionAliases = %v, want to contain agent:main:slack:channel:c001", allocation.SessionAliases) + } + if allocation.MainSessionKey == "" || !IsOpaqueSessionKey(allocation.MainSessionKey) { + t.Fatalf("MainSessionKey = %q, want opaque session key", allocation.MainSessionKey) + } + if len(allocation.MainAliases) != 1 || allocation.MainAliases[0] != "agent:main:main" { + t.Fatalf("MainAliases = %v, want [agent:main:main]", allocation.MainAliases) + } + if len(allocation.Scope.Dimensions) != 1 || allocation.Scope.Dimensions[0] != "chat" { + t.Fatalf("Scope.Dimensions = %v, want [chat]", allocation.Scope.Dimensions) + } + if allocation.Scope.Values["chat"] != "channel:c001" { + t.Fatalf("Scope.Values[chat] = %q, want channel:c001", allocation.Scope.Values["chat"]) + } +} + +func TestAllocateRouteSession_TelegramForumTopicsRemainIsolatedByDefault(t *testing.T) { + first := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + ChatType: "group", + TopicID: "42", + SenderID: "7", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat"}, + }, + }) + second := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + ChatType: "group", + TopicID: "99", + SenderID: "7", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat"}, + }, + }) + + if first.SessionKey == second.SessionKey { + t.Fatalf("forum topics should not share default session key: %q", first.SessionKey) + } + if got := first.Scope.Values["chat"]; got != "group:-1001234567890/42" { + t.Fatalf("first.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/42") + } + if got := second.Scope.Values["chat"]; got != "group:-1001234567890/99" { + t.Fatalf("second.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/99") + } +} + +func TestAllocateRouteSession_PicoDirectAliasesIncludeLegacyChatKey(t *testing.T) { + allocation := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "pico", + Account: "default", + ChatID: "pico:session-123", + ChatType: "direct", + SenderID: "pico-user", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + }) + + if !containsAlias(allocation.SessionAliases, "agent:main:pico:direct:pico:session-123") { + t.Fatalf("SessionAliases = %v, want pico legacy alias", allocation.SessionAliases) + } +} + +func TestBuildOpaqueSessionKey_IsStable(t *testing.T) { + first := BuildOpaqueSessionKey("agent:main:direct:user123") + second := BuildOpaqueSessionKey("agent:main:direct:user123") + if first != second { + t.Fatalf("BuildOpaqueSessionKey() mismatch: %q != %q", first, second) + } + if !IsOpaqueSessionKey(first) { + t.Fatalf("expected opaque session key, got %q", first) + } +} + +func containsAlias(aliases []string, want string) bool { + for _, alias := range aliases { + if alias == want { + return true + } + } + return false +} diff --git a/pkg/session/jsonl_backend.go b/pkg/session/jsonl_backend.go index 5a2297e30..68ef2d753 100644 --- a/pkg/session/jsonl_backend.go +++ b/pkg/session/jsonl_backend.go @@ -2,7 +2,9 @@ package session import ( "context" + "encoding/json" "log" + "strings" "github.com/sipeed/picoclaw/pkg/memory" "github.com/sipeed/picoclaw/pkg/providers" @@ -15,24 +17,123 @@ type JSONLBackend struct { store memory.Store } +type metaAwareStore interface { + GetSessionMeta(ctx context.Context, sessionKey string) (memory.SessionMeta, error) + UpsertSessionMeta(ctx context.Context, sessionKey string, scope json.RawMessage, aliases []string) error + ResolveSessionKey(ctx context.Context, sessionKey string) (string, bool, error) +} + +type aliasPromotingStore interface { + PromoteAliasHistory(ctx context.Context, sessionKey string, scope json.RawMessage, aliases []string) (bool, error) +} + +// MetadataAwareSessionStore exposes structured session metadata operations. +type MetadataAwareSessionStore interface { + EnsureSessionMetadata(sessionKey string, scope *SessionScope, aliases []string) + ResolveSessionKey(sessionKey string) string + GetSessionScope(sessionKey string) *SessionScope +} + // NewJSONLBackend wraps a memory.Store for use as a SessionStore. func NewJSONLBackend(store memory.Store) *JSONLBackend { return &JSONLBackend{store: store} } +func (b *JSONLBackend) resolveSessionKey(sessionKey string) string { + metaStore, ok := b.store.(metaAwareStore) + if !ok { + return sessionKey + } + resolved, found, err := metaStore.ResolveSessionKey(context.Background(), sessionKey) + if err != nil { + log.Printf("session: resolve session key: %v", err) + return sessionKey + } + if found && resolved != "" { + return resolved + } + return sessionKey +} + +// ResolveSessionKey maps aliases onto their canonical session key when the +// underlying store supports structured metadata. Unknown aliases fall back to +// the original input so existing callers remain compatible. +func (b *JSONLBackend) ResolveSessionKey(sessionKey string) string { + return b.resolveSessionKey(sessionKey) +} + +// EnsureSessionMetadata persists scope and alias metadata for a session. +func (b *JSONLBackend) EnsureSessionMetadata(sessionKey string, scope *SessionScope, aliases []string) { + metaStore, ok := b.store.(metaAwareStore) + if !ok { + return + } + sessionKey = strings.TrimSpace(sessionKey) + if sessionKey == "" { + return + } + + var rawScope json.RawMessage + if scope != nil { + data, err := json.Marshal(scope) + if err != nil { + log.Printf("session: encode session scope: %v", err) + return + } + rawScope = data + } + ctx := context.Background() + if err := metaStore.UpsertSessionMeta(ctx, sessionKey, rawScope, aliases); err != nil { + log.Printf("session: upsert session metadata: %v", err) + return + } + + if promotingStore, ok := b.store.(aliasPromotingStore); ok { + if _, err := promotingStore.PromoteAliasHistory(ctx, sessionKey, rawScope, aliases); err != nil { + log.Printf("session: promote alias history: %v", err) + } + } +} + +// GetSessionScope reads structured scope metadata for a session key or alias. +func (b *JSONLBackend) GetSessionScope(sessionKey string) *SessionScope { + metaStore, ok := b.store.(metaAwareStore) + if !ok { + return nil + } + sessionKey = b.resolveSessionKey(sessionKey) + meta, err := metaStore.GetSessionMeta(context.Background(), sessionKey) + if err != nil { + log.Printf("session: get session metadata: %v", err) + return nil + } + if len(meta.Scope) == 0 { + return nil + } + var scope SessionScope + if err := json.Unmarshal(meta.Scope, &scope); err != nil { + log.Printf("session: decode session scope: %v", err) + return nil + } + return CloneScope(&scope) +} + func (b *JSONLBackend) AddMessage(sessionKey, role, content string) { + sessionKey = b.resolveSessionKey(sessionKey) if err := b.store.AddMessage(context.Background(), sessionKey, role, content); err != nil { log.Printf("session: add message: %v", err) } } func (b *JSONLBackend) AddFullMessage(sessionKey string, msg providers.Message) { + sessionKey = b.resolveSessionKey(sessionKey) if err := b.store.AddFullMessage(context.Background(), sessionKey, msg); err != nil { log.Printf("session: add full message: %v", err) } } func (b *JSONLBackend) GetHistory(key string) []providers.Message { + key = b.resolveSessionKey(key) msgs, err := b.store.GetHistory(context.Background(), key) if err != nil { log.Printf("session: get history: %v", err) @@ -42,6 +143,7 @@ func (b *JSONLBackend) GetHistory(key string) []providers.Message { } func (b *JSONLBackend) GetSummary(key string) string { + key = b.resolveSessionKey(key) summary, err := b.store.GetSummary(context.Background(), key) if err != nil { log.Printf("session: get summary: %v", err) @@ -51,18 +153,21 @@ func (b *JSONLBackend) GetSummary(key string) string { } func (b *JSONLBackend) SetSummary(key, summary string) { + key = b.resolveSessionKey(key) if err := b.store.SetSummary(context.Background(), key, summary); err != nil { log.Printf("session: set summary: %v", err) } } func (b *JSONLBackend) SetHistory(key string, history []providers.Message) { + key = b.resolveSessionKey(key) if err := b.store.SetHistory(context.Background(), key, history); err != nil { log.Printf("session: set history: %v", err) } } func (b *JSONLBackend) TruncateHistory(key string, keepLast int) { + key = b.resolveSessionKey(key) if err := b.store.TruncateHistory(context.Background(), key, keepLast); err != nil { log.Printf("session: truncate history: %v", err) } @@ -72,6 +177,7 @@ func (b *JSONLBackend) TruncateHistory(key string, keepLast int) { // immediately, the data is already durable. Save runs compaction to reclaim // space from logically truncated messages (no-op when there are none). func (b *JSONLBackend) Save(key string) error { + key = b.resolveSessionKey(key) return b.store.Compact(context.Background(), key) } diff --git a/pkg/session/jsonl_backend_test.go b/pkg/session/jsonl_backend_test.go index 40fa019cb..0b79ad84d 100644 --- a/pkg/session/jsonl_backend_test.go +++ b/pkg/session/jsonl_backend_test.go @@ -4,8 +4,10 @@ import ( "fmt" "testing" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/memory" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/session" ) @@ -177,3 +179,126 @@ func TestJSONLBackend_SummarizeFlow(t *testing.T) { t.Errorf("first message = %q, want %q", history[0].Content, "msg 16") } } + +func TestJSONLBackend_ResolveAliasAndPersistMetadata(t *testing.T) { + b := newBackend(t) + + scope := &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "telegram", + Account: "default", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "group:c1", + }, + } + b.EnsureSessionMetadata("canonical", scope, []string{"legacy"}) + + if got := b.ResolveSessionKey("legacy"); got != "canonical" { + t.Fatalf("ResolveSessionKey() = %q, want %q", got, "canonical") + } + + b.AddMessage("legacy", "user", "hello through alias") + history := b.GetHistory("canonical") + if len(history) != 1 { + t.Fatalf("len(history) = %d, want 1", len(history)) + } + if history[0].Content != "hello through alias" { + t.Fatalf("history[0].Content = %q, want %q", history[0].Content, "hello through alias") + } + + resolvedScope := b.GetSessionScope("legacy") + if resolvedScope == nil { + t.Fatal("GetSessionScope() returned nil") + } + if resolvedScope.AgentID != scope.AgentID || resolvedScope.Values["chat"] != scope.Values["chat"] { + t.Fatalf("GetSessionScope() = %+v, want %+v", resolvedScope, scope) + } +} + +func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyAliasHistory(t *testing.T) { + b := newBackend(t) + + legacyKey := "agent:main:direct:legacy-user" + b.AddMessage(legacyKey, "user", "legacy history") + b.SetSummary(legacyKey, "legacy summary") + + canonicalKey := session.BuildOpaqueSessionKey(legacyKey) + b.EnsureSessionMetadata(canonicalKey, &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + }, []string{legacyKey}) + + if got := b.ResolveSessionKey(legacyKey); got != canonicalKey { + t.Fatalf("ResolveSessionKey() = %q, want %q", got, canonicalKey) + } + history := b.GetHistory(canonicalKey) + if len(history) != 1 || history[0].Content != "legacy history" { + t.Fatalf("promoted history = %+v", history) + } + if summary := b.GetSummary(canonicalKey); summary != "legacy summary" { + t.Fatalf("promoted summary = %q, want %q", summary, "legacy summary") + } +} + +func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyPicoDirectAliasHistory(t *testing.T) { + b := newBackend(t) + + legacyKey := "agent:main:pico:direct:pico:session-123" + b.AddMessage(legacyKey, "user", "legacy pico history") + + scope := &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "pico", + Account: "default", + Dimensions: []string{"sender"}, + Values: map[string]string{ + "sender": "pico-user", + }, + } + allocation := session.AllocateRouteSession(session.AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "pico", + Account: "default", + ChatID: "pico:session-123", + ChatType: "direct", + SenderID: "pico-user", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + }) + + b.EnsureSessionMetadata(allocation.SessionKey, scope, allocation.SessionAliases) + + if got := b.ResolveSessionKey(legacyKey); got != allocation.SessionKey { + t.Fatalf("ResolveSessionKey() = %q, want %q", got, allocation.SessionKey) + } + history := b.GetHistory(allocation.SessionKey) + if len(history) != 1 || history[0].Content != "legacy pico history" { + t.Fatalf("promoted history = %+v", history) + } +} + +func TestJSONLBackend_EnsureSessionMetadata_DoesNotOverwriteNonEmptyCanonicalHistory(t *testing.T) { + b := newBackend(t) + + canonicalKey := session.BuildOpaqueSessionKey("agent:main:direct:current-user") + legacyKey := "agent:main:direct:legacy-user" + + b.AddMessage(canonicalKey, "user", "current canonical history") + b.AddMessage(legacyKey, "user", "legacy history") + + b.EnsureSessionMetadata(canonicalKey, &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + }, []string{legacyKey}) + + history := b.GetHistory(canonicalKey) + if len(history) != 1 || history[0].Content != "current canonical history" { + t.Fatalf("canonical history overwritten: %+v", history) + } +} diff --git a/pkg/session/key.go b/pkg/session/key.go new file mode 100644 index 000000000..fb0836bc1 --- /dev/null +++ b/pkg/session/key.go @@ -0,0 +1,205 @@ +package session + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + + "github.com/sipeed/picoclaw/pkg/routing" +) + +const ( + sessionKeyV1Prefix = "sk_v1_" + legacyAgentSessionKeyPrefix = "agent:" +) + +type ParsedLegacySessionKey struct { + AgentID string + Rest string +} + +// BuildOpaqueSessionKey returns a stable opaque session key derived from a +// canonical alias string. The alias remains available through metadata for +// compatibility and migration purposes. +func BuildOpaqueSessionKey(alias string) string { + normalized := strings.TrimSpace(strings.ToLower(alias)) + if normalized == "" { + return "" + } + sum := sha256.Sum256([]byte(normalized)) + return sessionKeyV1Prefix + hex.EncodeToString(sum[:]) +} + +// IsOpaqueSessionKey returns true when the key matches the current opaque +// session-key format. +func IsOpaqueSessionKey(key string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), sessionKeyV1Prefix) +} + +func IsLegacyAgentSessionKey(key string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(key)), legacyAgentSessionKeyPrefix) +} + +func IsExplicitSessionKey(key string) bool { + return IsOpaqueSessionKey(key) || IsLegacyAgentSessionKey(key) +} + +func ParseLegacyAgentSessionKey(sessionKey string) *ParsedLegacySessionKey { + raw := strings.TrimSpace(sessionKey) + if raw == "" { + return nil + } + parts := strings.SplitN(raw, ":", 3) + if len(parts) < 3 || parts[0] != "agent" { + return nil + } + agentID := strings.TrimSpace(parts[1]) + rest := parts[2] + if agentID == "" || rest == "" { + return nil + } + return &ParsedLegacySessionKey{AgentID: agentID, Rest: rest} +} + +// ResolveAgentID returns the routed agent ID associated with a session. It +// prefers structured session scope metadata when available and falls back to +// legacy agent-scoped session keys for compatibility. +func ResolveAgentID(store any, sessionKey string) string { + if scopeReader, ok := store.(interface { + GetSessionScope(sessionKey string) *SessionScope + }); ok { + scope := scopeReader.GetSessionScope(sessionKey) + if scope != nil && strings.TrimSpace(scope.AgentID) != "" { + return routing.NormalizeAgentID(scope.AgentID) + } + } + + if parsed := ParseLegacyAgentSessionKey(sessionKey); parsed != nil { + return routing.NormalizeAgentID(parsed.AgentID) + } + + return "" +} + +func BuildLegacyMainAlias(agentID string) string { + return fmt.Sprintf("agent:%s:main", routing.NormalizeAgentID(agentID)) +} + +// BuildMainSessionKey returns the canonical opaque main-session key for an +// agent. The corresponding legacy alias remains available via +// BuildLegacyMainAlias for compatibility and migration logic. +func BuildMainSessionKey(agentID string) string { + return BuildOpaqueSessionKey(BuildLegacyMainAlias(agentID)) +} + +func BuildLegacyDirectAliases(agentID, channel, account, peerID string) []string { + agentID = routing.NormalizeAgentID(agentID) + channel = normalizeLegacyChannel(channel) + account = routing.NormalizeAccountID(account) + peerID = strings.ToLower(strings.TrimSpace(peerID)) + if peerID == "" { + return nil + } + return []string{ + fmt.Sprintf("agent:%s:direct:%s", agentID, peerID), + fmt.Sprintf("agent:%s:%s:direct:%s", agentID, channel, peerID), + fmt.Sprintf("agent:%s:%s:%s:direct:%s", agentID, channel, account, peerID), + } +} + +func BuildLegacyPeerAlias(agentID, channel, peerKind, peerID string) string { + agentID = routing.NormalizeAgentID(agentID) + channel = normalizeLegacyChannel(channel) + peerKind = strings.ToLower(strings.TrimSpace(peerKind)) + if peerKind == "" { + peerKind = "unknown" + } + peerID = strings.ToLower(strings.TrimSpace(peerID)) + if peerID == "" { + peerID = "unknown" + } + return fmt.Sprintf("agent:%s:%s:%s:%s", agentID, channel, peerKind, peerID) +} + +// CanonicalSessionIdentityID collapses an identity using identity_links when +// possible, then returns a normalized lowercase identifier. +func CanonicalSessionIdentityID(channel, rawID string, identityLinks map[string][]string) string { + normalizedID := strings.TrimSpace(rawID) + if normalizedID == "" { + return "" + } + if linked := resolveLinkedPeerID(identityLinks, channel, normalizedID); linked != "" { + normalizedID = linked + } + return strings.ToLower(normalizedID) +} + +func normalizeLegacyChannel(channel string) string { + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel == "" { + return "unknown" + } + return channel +} + +func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]bool) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = true + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + candidates[fmt.Sprintf("%s:%s", channel, rawCandidate)] = true + } + if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + candidates[rawCandidate[idx+1:]] = true + } + + for canonical, ids := range identityLinks { + canonicalName := strings.TrimSpace(canonical) + if canonicalName == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized != "" && candidates[normalized] { + return canonicalName + } + } + } + return "" +} + +// CanonicalScopeSignature returns a stable serialized representation of scope. +func CanonicalScopeSignature(scope SessionScope) string { + parts := []string{ + fmt.Sprintf("v=%d", scope.Version), + fmt.Sprintf("agent=%s", strings.TrimSpace(strings.ToLower(scope.AgentID))), + fmt.Sprintf("channel=%s", strings.TrimSpace(strings.ToLower(scope.Channel))), + fmt.Sprintf("account=%s", strings.TrimSpace(strings.ToLower(scope.Account))), + } + for _, dimension := range scope.Dimensions { + dimension = strings.TrimSpace(strings.ToLower(dimension)) + if dimension == "" { + continue + } + value := strings.TrimSpace(strings.ToLower(scope.Values[dimension])) + parts = append(parts, fmt.Sprintf("%s=%s", dimension, value)) + } + return strings.Join(parts, "|") +} + +// BuildSessionKey returns the current opaque key for a structured session scope. +func BuildSessionKey(scope SessionScope) string { + return BuildOpaqueSessionKey(CanonicalScopeSignature(scope)) +} diff --git a/pkg/session/key_test.go b/pkg/session/key_test.go new file mode 100644 index 000000000..6cdf397e1 --- /dev/null +++ b/pkg/session/key_test.go @@ -0,0 +1,100 @@ +package session + +import "testing" + +type testScopeReader struct { + scope *SessionScope +} + +func (r testScopeReader) GetSessionScope(sessionKey string) *SessionScope { + return CloneScope(r.scope) +} + +func TestIsExplicitSessionKey(t *testing.T) { + tests := []struct { + key string + want bool + }{ + {"sk_v1_abc", true}, + {"agent:main:direct:user123", true}, + {"custom-key", false}, + {"", false}, + } + + for _, tt := range tests { + if got := IsExplicitSessionKey(tt.key); got != tt.want { + t.Fatalf("IsExplicitSessionKey(%q) = %v, want %v", tt.key, got, tt.want) + } + } +} + +func TestParseLegacyAgentSessionKey(t *testing.T) { + parsed := ParseLegacyAgentSessionKey("agent:sales:telegram:direct:user123") + if parsed == nil { + t.Fatal("expected parsed legacy key, got nil") + } + if parsed.AgentID != "sales" { + t.Fatalf("AgentID = %q, want sales", parsed.AgentID) + } + if parsed.Rest != "telegram:direct:user123" { + t.Fatalf("Rest = %q, want telegram:direct:user123", parsed.Rest) + } + + if got := ParseLegacyAgentSessionKey("sk_v1_abc"); got != nil { + t.Fatalf("expected nil for opaque key, got %+v", got) + } +} + +func TestBuildLegacyDirectAliases(t *testing.T) { + aliases := BuildLegacyDirectAliases("Main", "Telegram", "BotA", "User123") + want := []string{ + "agent:main:direct:user123", + "agent:main:telegram:direct:user123", + "agent:main:telegram:bota:direct:user123", + } + if len(aliases) != len(want) { + t.Fatalf("len(aliases) = %d, want %d", len(aliases), len(want)) + } + for i := range want { + if aliases[i] != want[i] { + t.Fatalf("aliases[%d] = %q, want %q", i, aliases[i], want[i]) + } + } +} + +func TestBuildLegacyPeerAlias(t *testing.T) { + got := BuildLegacyPeerAlias("Main", "Slack", "channel", "C001") + if got != "agent:main:slack:channel:c001" { + t.Fatalf("BuildLegacyPeerAlias() = %q", got) + } +} + +func TestBuildMainSessionKey(t *testing.T) { + got := BuildMainSessionKey("Main") + if !IsOpaqueSessionKey(got) { + t.Fatalf("BuildMainSessionKey() = %q, want opaque key", got) + } + if got != BuildOpaqueSessionKey("agent:main:main") { + t.Fatalf("BuildMainSessionKey() = %q, want stable main-key hash", got) + } +} + +func TestResolveAgentID_PrefersSessionScope(t *testing.T) { + store := testScopeReader{ + scope: &SessionScope{ + Version: ScopeVersionV1, + AgentID: "Support", + Channel: "slack", + }, + } + + if got := ResolveAgentID(store, "sk_v1_anything"); got != "support" { + t.Fatalf("ResolveAgentID() = %q, want support", got) + } +} + +func TestResolveAgentID_FallsBackToLegacyKey(t *testing.T) { + if got := ResolveAgentID(nil, "agent:Sales:telegram:direct:user123"); got != "sales" { + t.Fatalf("ResolveAgentID() = %q, want sales", got) + } +} diff --git a/pkg/session/scope.go b/pkg/session/scope.go new file mode 100644 index 000000000..efb026ea3 --- /dev/null +++ b/pkg/session/scope.go @@ -0,0 +1,32 @@ +package session + +// ScopeVersionV1 is the first structured session-scope schema version. +const ScopeVersionV1 = 1 + +// SessionScope describes the semantic session partition selected for a turn. +type SessionScope struct { + Version int `json:"version"` + AgentID string `json:"agent_id"` + Channel string `json:"channel"` + Account string `json:"account"` + Dimensions []string `json:"dimensions"` + Values map[string]string `json:"values"` +} + +// CloneScope returns a deep copy of scope. +func CloneScope(scope *SessionScope) *SessionScope { + if scope == nil { + return nil + } + cloned := *scope + if len(scope.Dimensions) > 0 { + cloned.Dimensions = append([]string(nil), scope.Dimensions...) + } + if len(scope.Values) > 0 { + cloned.Values = make(map[string]string, len(scope.Values)) + for key, value := range scope.Values { + cloned.Values[key] = value + } + } + return &cloned +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index afee95692..e1f9aacc0 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -1,6 +1,10 @@ package tools -import "context" +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/session" +) // Tool is the interface that all tools must implement. type Tool interface { @@ -25,6 +29,9 @@ var ( ctxKeyChatID = &toolCtxKey{"chatID"} ctxKeyMessageID = &toolCtxKey{"messageID"} ctxKeyReplyToMessageID = &toolCtxKey{"replyToMessageID"} + ctxKeyAgentID = &toolCtxKey{"agentID"} + ctxKeySessionKey = &toolCtxKey{"sessionKey"} + ctxKeySessionScope = &toolCtxKey{"sessionScope"} ) // WithToolContext returns a child context carrying channel and chatID. @@ -51,6 +58,18 @@ func WithToolInboundContext( return ctx } +// WithToolSessionContext returns a child context carrying turn-scoped session metadata. +func WithToolSessionContext( + ctx context.Context, + agentID, sessionKey string, + scope *session.SessionScope, +) context.Context { + ctx = context.WithValue(ctx, ctxKeyAgentID, agentID) + ctx = context.WithValue(ctx, ctxKeySessionKey, sessionKey) + ctx = context.WithValue(ctx, ctxKeySessionScope, session.CloneScope(scope)) + return ctx +} + // ToolChannel extracts the channel from ctx, or "" if unset. func ToolChannel(ctx context.Context) string { v, _ := ctx.Value(ctxKeyChannel).(string) @@ -75,6 +94,24 @@ func ToolReplyToMessageID(ctx context.Context) string { return v } +// ToolAgentID extracts the active turn's agent ID from ctx, or "" if unset. +func ToolAgentID(ctx context.Context) string { + v, _ := ctx.Value(ctxKeyAgentID).(string) + return v +} + +// ToolSessionKey extracts the active turn's session key from ctx, or "" if unset. +func ToolSessionKey(ctx context.Context) string { + v, _ := ctx.Value(ctxKeySessionKey).(string) + return v +} + +// ToolSessionScope extracts the active turn's structured session scope from ctx. +func ToolSessionScope(ctx context.Context) *session.SessionScope { + scope, _ := ctx.Value(ctxKeySessionScope).(*session.SessionScope) + return session.CloneScope(scope) +} + // AsyncCallback is a function type that async tools use to notify completion. // When an async tool finishes its work, it calls this callback with the result. // diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index c6ac3a129..30a8e92cd 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -311,8 +311,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, ""), Content: output, }) return "ok" @@ -335,8 +334,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) defer pubCancel() t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ - Channel: channel, - ChatID: chatID, + Context: bus.NewOutboundContext(channel, chatID, ""), Content: output, }) return "ok" diff --git a/pkg/tools/message.go b/pkg/tools/message.go index 5a384b37e..39440e5a3 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -6,7 +6,7 @@ import ( "sync" ) -type SendCallback func(channel, chatID, content, replyToMessageID string) error +type SendCallbackWithContext func(ctx context.Context, channel, chatID, content, replyToMessageID string) error // sentTarget records the channel+chatID that the message tool sent to. type sentTarget struct { @@ -15,7 +15,7 @@ type sentTarget struct { } type MessageTool struct { - sendCallback SendCallback + sendCallback SendCallbackWithContext mu sync.Mutex sentTargets []sentTarget // Tracks all targets sent to in the current round } @@ -86,7 +86,7 @@ func (t *MessageTool) HasSentTo(channel, chatID string) bool { return false } -func (t *MessageTool) SetSendCallback(callback SendCallback) { +func (t *MessageTool) SetSendCallback(callback SendCallbackWithContext) { t.sendCallback = callback } @@ -115,7 +115,7 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]any) *ToolRes return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } - if err := t.sendCallback(channel, chatID, content, replyToMessageID); err != nil { + if err := t.sendCallback(ctx, channel, chatID, content, replyToMessageID); err != nil { return &ToolResult{ ForLLM: fmt.Sprintf("sending message: %v", err), IsError: true, diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go index 93a611ee0..649593252 100644 --- a/pkg/tools/message_test.go +++ b/pkg/tools/message_test.go @@ -4,16 +4,22 @@ import ( "context" "errors" "testing" + + "github.com/sipeed/picoclaw/pkg/session" ) func TestMessageTool_Execute_Success(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID, sentContent string - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { sentChannel = channel sentChatID = chatID sentContent = content + if ToolAgentID(ctx) != "" || ToolSessionKey(ctx) != "" || ToolSessionScope(ctx) != nil { + t.Fatalf("expected empty turn metadata in basic context, got agent=%q session=%q scope=%+v", + ToolAgentID(ctx), ToolSessionKey(ctx), ToolSessionScope(ctx)) + } return nil }) @@ -61,7 +67,7 @@ func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { tool := NewMessageTool() var sentChannel, sentChatID string - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { sentChannel = channel sentChatID = chatID return nil @@ -96,7 +102,7 @@ func TestMessageTool_Execute_SendFailure(t *testing.T) { tool := NewMessageTool() sendErr := errors.New("network error") - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { return sendErr }) @@ -149,7 +155,7 @@ func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { tool := NewMessageTool() // No WithToolContext — channel/chatID are empty - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { return nil }) @@ -266,7 +272,7 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { tool := NewMessageTool() var sentReplyTo string - tool.SetSendCallback(func(channel, chatID, content, replyToMessageID string) error { + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { sentReplyTo = replyToMessageID return nil }) @@ -285,3 +291,41 @@ func TestMessageTool_Execute_WithReplyToMessageID(t *testing.T) { t.Fatalf("expected reply_to_message_id msg-123, got %q", sentReplyTo) } } + +func TestMessageTool_Execute_PropagatesTurnSessionMetadata(t *testing.T) { + tool := NewMessageTool() + + var gotAgentID, gotSessionKey string + var gotScope *session.SessionScope + tool.SetSendCallback(func(ctx context.Context, channel, chatID, content, replyToMessageID string) error { + gotAgentID = ToolAgentID(ctx) + gotSessionKey = ToolSessionKey(ctx) + gotScope = ToolSessionScope(ctx) + return nil + }) + + ctx := WithToolContext(context.Background(), "test-channel", "test-chat-id") + ctx = WithToolSessionContext(ctx, "main", "sk_v1_tool", &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "telegram", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "direct:test-chat-id", + }, + }) + + result := tool.Execute(ctx, map[string]any{"content": "Hello, world!"}) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + if gotAgentID != "main" { + t.Fatalf("ToolAgentID() = %q, want main", gotAgentID) + } + if gotSessionKey != "sk_v1_tool" { + t.Fatalf("ToolSessionKey() = %q, want sk_v1_tool", gotSessionKey) + } + if gotScope == nil || gotScope.Values["chat"] != "direct:test-chat-id" { + t.Fatalf("ToolSessionScope() = %+v, want chat scope", gotScope) + } +} diff --git a/web/backend/api/session.go b/web/backend/api/session.go index 9bb6055e2..054b78b73 100644 --- a/web/backend/api/session.go +++ b/web/backend/api/session.go @@ -13,7 +13,9 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/memory" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -49,26 +51,12 @@ type sessionChatMessage struct { Media []string `json:"media,omitempty"` } -type sessionMetaFile struct { - Key string `json:"key"` - Summary string `json:"summary"` - Skip int `json:"skip"` - Count int `json:"count"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` -} - -// picoSessionPrefix is the key prefix used by the gateway's routing for Pico -// channel sessions. The full key format is: -// -// agent:main:pico:direct:pico: -// -// The sanitized filename replaces ':' with '_', so on disk it becomes: -// -// agent_main_pico_direct_pico_.json +// legacyPicoSessionPrefix is the legacy key prefix used by older Pico JSON/JSONL +// sessions before structured scope metadata existed. const ( - picoSessionPrefix = "agent:main:pico:direct:pico:" - sanitizedPicoSessionPrefix = "agent_main_pico_direct_pico_" + legacyPicoSessionPrefix = "agent:main:pico:direct:pico:" + picoSessionPrefix = legacyPicoSessionPrefix + // Keep the session API aligned with the shared JSONL store reader limit in // pkg/memory/jsonl.go so oversized lines fail consistently everywhere. maxSessionJSONLLineSize = 10 * 1024 * 1024 @@ -82,28 +70,23 @@ func defaultToolFeedbackMaxArgsLength() int { return defaults.GetToolFeedbackMaxArgsLength() } -// extractPicoSessionID extracts the session UUID from a full session key. +// extractLegacyPicoSessionID extracts the session UUID from an old Pico key. // Returns the UUID and true if the key matches the Pico session pattern. -func extractPicoSessionID(key string) (string, bool) { - if strings.HasPrefix(key, picoSessionPrefix) { - return strings.TrimPrefix(key, picoSessionPrefix), true - } - return "", false -} - -func extractPicoSessionIDFromSanitizedKey(key string) (string, bool) { - if strings.HasPrefix(key, sanitizedPicoSessionPrefix) { - return strings.TrimPrefix(key, sanitizedPicoSessionPrefix), true +func extractLegacyPicoSessionID(key string) (string, bool) { + if strings.HasPrefix(key, legacyPicoSessionPrefix) { + return strings.TrimPrefix(key, legacyPicoSessionPrefix), true } return "", false } func sanitizeSessionKey(key string) string { - return strings.ReplaceAll(key, ":", "_") + key = strings.ReplaceAll(key, ":", "_") + key = strings.ReplaceAll(key, "/", "_") + key = strings.ReplaceAll(key, "\\", "_") + return key } -func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error) { - path := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)+".json") +func (h *Handler) readLegacySession(path string) (sessionFile, error) { data, err := os.ReadFile(path) if err != nil { return sessionFile{}, err @@ -116,18 +99,18 @@ func (h *Handler) readLegacySession(dir, sessionID string) (sessionFile, error) return sess, nil } -func (h *Handler) readSessionMeta(path, sessionKey string) (sessionMetaFile, error) { +func (h *Handler) readSessionMeta(path, sessionKey string) (memory.SessionMeta, error) { data, err := os.ReadFile(path) if os.IsNotExist(err) { - return sessionMetaFile{Key: sessionKey}, nil + return memory.SessionMeta{Key: sessionKey}, nil } if err != nil { - return sessionMetaFile{}, err + return memory.SessionMeta{}, err } - var meta sessionMetaFile + var meta memory.SessionMeta if err := json.Unmarshal(data, &meta); err != nil { - return sessionMetaFile{}, err + return memory.SessionMeta{}, err } if meta.Key == "" { meta.Key = sessionKey @@ -170,8 +153,7 @@ func (h *Handler) readSessionMessages(path string, skip int) ([]providers.Messag return msgs, nil } -func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) { - sessionKey := picoSessionPrefix + sessionID +func (h *Handler) readJSONLSession(dir, sessionKey string) (sessionFile, error) { base := filepath.Join(dir, sanitizeSessionKey(sessionKey)) jsonlPath := base + ".jsonl" metaPath := base + ".meta.json" @@ -208,6 +190,213 @@ func (h *Handler) readJSONLSession(dir, sessionID string) (sessionFile, error) { }, nil } +type picoJSONLSessionRef struct { + ID string + Key string +} + +type picoLegacySessionRef struct { + ID string + Path string +} + +func extractPicoSessionIDFromScope(scope session.SessionScope) (string, bool) { + if !strings.EqualFold(strings.TrimSpace(scope.Channel), "pico") { + return "", false + } + + candidates := []string{ + strings.TrimSpace(scope.Values["sender"]), + strings.TrimSpace(scope.Values["chat"]), + } + for _, candidate := range candidates { + if candidate == "" { + continue + } + if idx := strings.Index(candidate, "pico:"); idx >= 0 { + sessionID := strings.TrimSpace(candidate[idx+len("pico:"):]) + if sessionID != "" { + return sessionID, true + } + } + } + return "", false +} + +func sessionRefFromMeta(meta memory.SessionMeta) (picoJSONLSessionRef, bool) { + if len(meta.Scope) == 0 { + if sessionID, ok := extractLegacyPicoSessionID(meta.Key); ok { + return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true + } + for _, alias := range meta.Aliases { + if sessionID, ok := extractLegacyPicoSessionID(alias); ok { + return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true + } + } + return picoJSONLSessionRef{}, false + } + var scope session.SessionScope + if err := json.Unmarshal(meta.Scope, &scope); err != nil { + return picoJSONLSessionRef{}, false + } + sessionID, ok := extractPicoSessionIDFromScope(scope) + if !ok { + if legacySessionID, ok := extractLegacyPicoSessionID(meta.Key); ok { + return picoJSONLSessionRef{ID: legacySessionID, Key: meta.Key}, true + } + for _, alias := range meta.Aliases { + if legacySessionID, ok := extractLegacyPicoSessionID(alias); ok { + return picoJSONLSessionRef{ID: legacySessionID, Key: meta.Key}, true + } + } + return picoJSONLSessionRef{}, false + } + return picoJSONLSessionRef{ID: sessionID, Key: meta.Key}, true +} + +func (h *Handler) findPicoJSONLSessions(dir string) ([]picoJSONLSessionRef, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + refs := make([]picoJSONLSessionRef, 0) + seen := make(map[string]struct{}) + metaBackedBases := make(map[string]struct{}) + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") { + continue + } + name := entry.Name() + metaPath := filepath.Join(dir, name) + meta, err := h.readSessionMeta(metaPath, "") + if err != nil { + continue + } + ref, ok := sessionRefFromMeta(meta) + if !ok || ref.Key == "" || ref.ID == "" { + continue + } + metaBackedBases[strings.TrimSuffix(name, ".meta.json")] = struct{}{} + if _, exists := seen[ref.ID]; exists { + continue + } + seen[ref.ID] = struct{}{} + refs = append(refs, ref) + } + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") { + continue + } + name := entry.Name() + base := strings.TrimSuffix(name, ".jsonl") + if _, ok := metaBackedBases[base]; ok { + continue + } + ref, ok := jsonlSessionRefFromFilename(name) + if !ok || ref.Key == "" || ref.ID == "" { + continue + } + if _, exists := seen[ref.ID]; exists { + continue + } + seen[ref.ID] = struct{}{} + refs = append(refs, ref) + } + return refs, nil +} + +func (h *Handler) findPicoJSONLSession(dir, sessionID string) (picoJSONLSessionRef, error) { + refs, err := h.findPicoJSONLSessions(dir) + if err != nil { + return picoJSONLSessionRef{}, err + } + for _, ref := range refs { + if ref.ID == sessionID { + return ref, nil + } + } + return picoJSONLSessionRef{}, os.ErrNotExist +} + +func (h *Handler) findLegacyPicoSessions(dir string) ([]picoLegacySessionRef, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + refs := make([]picoLegacySessionRef, 0) + seen := make(map[string]struct{}) + for _, entry := range entries { + name := entry.Name() + if entry.IsDir() || filepath.Ext(name) != ".json" || strings.HasSuffix(name, ".meta.json") { + continue + } + + path := filepath.Join(dir, entry.Name()) + sess, err := h.readLegacySession(path) + if err != nil || isEmptySession(sess) { + continue + } + + sessionID, ok := extractLegacyPicoSessionID(sess.Key) + if !ok || sessionID == "" { + continue + } + if _, exists := seen[sessionID]; exists { + continue + } + seen[sessionID] = struct{}{} + refs = append(refs, picoLegacySessionRef{ID: sessionID, Path: path}) + } + return refs, nil +} + +func jsonlSessionRefFromFilename(name string) (picoJSONLSessionRef, bool) { + if !strings.HasSuffix(name, ".jsonl") { + return picoJSONLSessionRef{}, false + } + base := strings.TrimSuffix(name, ".jsonl") + if base == "" { + return picoJSONLSessionRef{}, false + } + + legacyPrefix := sanitizeSessionKey(legacyPicoSessionPrefix) + if strings.HasPrefix(base, legacyPrefix) { + sessionID := strings.TrimPrefix(base, legacyPrefix) + if sessionID == "" { + return picoJSONLSessionRef{}, false + } + return picoJSONLSessionRef{ + ID: sessionID, + Key: legacyPicoSessionPrefix + sessionID, + }, true + } + + if session.IsOpaqueSessionKey(base) { + return picoJSONLSessionRef{ + ID: base, + Key: base, + }, true + } + + return picoJSONLSessionRef{}, false +} + +func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessionRef, error) { + refs, err := h.findLegacyPicoSessions(dir) + if err != nil { + return picoLegacySessionRef{}, err + } + for _, ref := range refs { + if ref.ID == sessionID { + return ref, nil + } + } + return picoLegacySessionRef{}, os.ErrNotExist +} + func buildSessionListItem(sessionID string, sess sessionFile, toolFeedbackMaxArgsLength int) sessionListItem { preview := "" for _, msg := range sess.Messages { @@ -458,8 +647,7 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) { return } - entries, err := os.ReadDir(dir) - if err != nil { + if _, err := os.ReadDir(dir); err != nil { // Directory doesn't exist yet = no sessions w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode([]sessionListItem{}) @@ -469,74 +657,29 @@ func (h *Handler) handleListSessions(w http.ResponseWriter, r *http.Request) { items := []sessionListItem{} seen := make(map[string]struct{}) - for _, entry := range entries { - if entry.IsDir() { - continue + if refs, findErr := h.findPicoJSONLSessions(dir); findErr == nil { + for _, ref := range refs { + sess, loadErr := h.readJSONLSession(dir, ref.Key) + if loadErr != nil || isEmptySession(sess) { + continue + } + seen[ref.ID] = struct{}{} + items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength)) } + } - name := entry.Name() - var ( - sessionID string - sess sessionFile - loadErr error - ok bool - ) - - switch { - case strings.HasSuffix(name, ".jsonl"): - sessionID, ok = extractPicoSessionIDFromSanitizedKey(strings.TrimSuffix(name, ".jsonl")) - if !ok { + if legacyRefs, findErr := h.findLegacyPicoSessions(dir); findErr == nil { + for _, ref := range legacyRefs { + if _, exists := seen[ref.ID]; exists { continue } - sess, loadErr = h.readJSONLSession(dir, sessionID) - if loadErr == nil && isEmptySession(sess) { + sess, loadErr := h.readLegacySession(ref.Path) + if loadErr != nil || isEmptySession(sess) { continue } - case strings.HasSuffix(name, ".meta.json"): - continue - case filepath.Ext(name) == ".json": - base := strings.TrimSuffix(name, ".json") - if _, statErr := os.Stat(filepath.Join(dir, base+".jsonl")); statErr == nil { - if jsonlSessionID, found := extractPicoSessionIDFromSanitizedKey(base); found { - if jsonlSess, jsonlErr := h.readJSONLSession( - dir, - jsonlSessionID, - ); jsonlErr == nil && - !isEmptySession(jsonlSess) { - continue - } - } - } - data, err := os.ReadFile(filepath.Join(dir, name)) - if err != nil { - continue - } - if err := json.Unmarshal(data, &sess); err != nil { - continue - } - if isEmptySession(sess) { - continue - } - sessionID, ok = extractPicoSessionID(sess.Key) - if !ok { - continue - } - if _, exists := seen[sessionID]; exists { - continue - } - default: - continue + seen[ref.ID] = struct{}{} + items = append(items, buildSessionListItem(ref.ID, sess, toolFeedbackMaxArgsLength)) } - - if loadErr != nil { - continue - } - if _, exists := seen[sessionID]; exists { - continue - } - - seen[sessionID] = struct{}{} - items = append(items, buildSessionListItem(sessionID, sess, toolFeedbackMaxArgsLength)) } // Sort by updated descending (most recent first) @@ -590,13 +733,20 @@ func (h *Handler) handleGetSession(w http.ResponseWriter, r *http.Request) { return } - sess, err := h.readJSONLSession(dir, sessionID) + ref, refErr := h.findPicoJSONLSession(dir, sessionID) + var sess sessionFile + err = refErr + if refErr == nil { + sess, err = h.readJSONLSession(dir, ref.Key) + } if err == nil && isEmptySession(sess) { err = os.ErrNotExist } if err != nil { if errors.Is(err, os.ErrNotExist) { - sess, err = h.readLegacySession(dir, sessionID) + if legacyRef, legacyErr := h.findLegacyPicoSession(dir, sessionID); legacyErr == nil { + sess, err = h.readLegacySession(legacyRef.Path) + } if err == nil && isEmptySession(sess) { err = os.ErrNotExist } @@ -639,21 +789,30 @@ func (h *Handler) handleDeleteSession(w http.ResponseWriter, r *http.Request) { return } - base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+sessionID)) - jsonlPath := base + ".jsonl" - metaPath := base + ".meta.json" - legacyPath := base + ".json" - removed := false - for _, path := range []string{jsonlPath, metaPath, legacyPath} { - if err := os.Remove(path); err != nil { - if os.IsNotExist(err) { - continue + if ref, err := h.findPicoJSONLSession(dir, sessionID); err == nil { + base := filepath.Join(dir, sanitizeSessionKey(ref.Key)) + for _, path := range []string{base + ".jsonl", base + ".meta.json"} { + if err := os.Remove(path); err != nil { + if os.IsNotExist(err) { + continue + } + http.Error(w, "failed to delete session", http.StatusInternalServerError) + return } - http.Error(w, "failed to delete session", http.StatusInternalServerError) - return + removed = true + } + } + + if legacyRef, err := h.findLegacyPicoSession(dir, sessionID); err == nil { + if err := os.Remove(legacyRef.Path); err != nil { + if !os.IsNotExist(err) { + http.Error(w, "failed to delete session", http.StatusInternalServerError) + return + } + } else { + removed = true } - removed = true } if !removed { diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go index 599921bfe..e40a8c77c 100644 --- a/web/backend/api/session_test.go +++ b/web/backend/api/session_test.go @@ -36,12 +36,12 @@ func TestHandleListSessions_JSONLStorage(t *testing.T) { defer cleanup() dir := sessionsTestDir(t, configPath) - store, err := memory.NewJSONLStore(dir) - if err != nil { - t.Fatalf("NewJSONLStore() error = %v", err) + store, storeErr := memory.NewJSONLStore(dir) + if storeErr != nil { + t.Fatalf("NewJSONLStore() error = %v", storeErr) } - sessionKey := picoSessionPrefix + "history-jsonl" + sessionKey := legacyPicoSessionPrefix + "history-jsonl" if err := store.AddFullMessage(nil, sessionKey, providers.Message{ Role: "user", Content: "Explain why the history API is empty after migration.", @@ -106,12 +106,12 @@ func TestHandleListSessions_TitleUsesFirstUserMessage(t *testing.T) { defer cleanup() dir := sessionsTestDir(t, configPath) - store, err := memory.NewJSONLStore(dir) - if err != nil { - t.Fatalf("NewJSONLStore() error = %v", err) + store, storeErr := memory.NewJSONLStore(dir) + if storeErr != nil { + t.Fatalf("NewJSONLStore() error = %v", storeErr) } - sessionKey := picoSessionPrefix + "summary-title" + sessionKey := legacyPicoSessionPrefix + "summary-title" if err := store.AddFullMessage(nil, sessionKey, providers.Message{ Role: "user", Content: "fallback preview", @@ -164,7 +164,7 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) { t.Fatalf("NewJSONLStore() error = %v", err) } - sessionKey := picoSessionPrefix + "detail-jsonl" + sessionKey := legacyPicoSessionPrefix + "detail-jsonl" for _, msg := range []providers.Message{ {Role: "user", Content: "first"}, {Role: "assistant", Content: "second"}, @@ -218,6 +218,81 @@ func TestHandleGetSession_JSONLStorage(t *testing.T) { } } +func TestHandleSessions_JSONLScopeDiscovery(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + store, storeErr := memory.NewJSONLStore(dir) + if storeErr != nil { + t.Fatalf("NewJSONLStore() error = %v", storeErr) + } + + sessionKey := "sk_v1_scope_discovery" + if err := store.AddFullMessage(nil, sessionKey, providers.Message{ + Role: "user", + Content: "scope discovered session", + }); err != nil { + t.Fatalf("AddFullMessage() error = %v", err) + } + if err := store.SetSummary(nil, sessionKey, "scope summary"); err != nil { + t.Fatalf("SetSummary() error = %v", err) + } + + scopeData, err := json.Marshal(session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "pico", + Account: "default", + Dimensions: []string{"sender"}, + Values: map[string]string{ + "sender": "pico:scope-jsonl", + }, + }) + if err != nil { + t.Fatalf("Marshal(scope) error = %v", err) + } + if err := store.UpsertSessionMeta(nil, sessionKey, scopeData, nil); err != nil { + t.Fatalf("UpsertSessionMeta() error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(listRec, listReq) + if listRec.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal(list) error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + if items[0].ID != "scope-jsonl" { + t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "scope-jsonl") + } + + detailRec := httptest.NewRecorder() + detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/scope-jsonl", nil) + mux.ServeHTTP(detailRec, detailReq) + if detailRec.Code != http.StatusOK { + t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String()) + } + + deleteRec := httptest.NewRecorder() + deleteReq := httptest.NewRequest(http.MethodDelete, "/api/sessions/scope-jsonl", nil) + mux.ServeHTTP(deleteRec, deleteReq) + if deleteRec.Code != http.StatusNoContent { + t.Fatalf("delete status = %d, want %d, body=%s", deleteRec.Code, http.StatusNoContent, deleteRec.Body.String()) + } +} + func TestHandleGetSession_OmitsTransientThoughtMessages(t *testing.T) { configPath, cleanup := setupOAuthTestEnv(t) defer cleanup() @@ -784,7 +859,7 @@ func TestHandleDeleteSession_JSONLStorage(t *testing.T) { t.Fatalf("NewJSONLStore() error = %v", err) } - sessionKey := picoSessionPrefix + "delete-jsonl" + sessionKey := legacyPicoSessionPrefix + "delete-jsonl" if err := store.AddFullMessage(nil, sessionKey, providers.Message{ Role: "user", Content: "delete me", @@ -821,7 +896,7 @@ func TestHandleGetSession_LegacyJSONFallback(t *testing.T) { dir := sessionsTestDir(t, configPath) manager := session.NewSessionManager(dir) - sessionKey := picoSessionPrefix + "legacy-json" + sessionKey := legacyPicoSessionPrefix + "legacy-json" manager.AddMessage(sessionKey, "user", "legacy user") manager.AddMessage(sessionKey, "assistant", "legacy assistant") if err := manager.Save(sessionKey); err != nil { @@ -846,7 +921,7 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) { defer cleanup() dir := sessionsTestDir(t, configPath) - base := filepath.Join(dir, sanitizeSessionKey(picoSessionPrefix+"empty-jsonl")) + base := filepath.Join(dir, sanitizeSessionKey(legacyPicoSessionPrefix+"empty-jsonl")) if err := os.WriteFile(base+".jsonl", []byte{}, 0o644); err != nil { t.Fatalf("WriteFile(jsonl) error = %v", err) } @@ -879,3 +954,82 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) { t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusNotFound, detailRec.Body.String()) } } + +func TestHandleSessions_ListsLegacyJSONLWithoutMeta(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + sessionKey := legacyPicoSessionPrefix + "missing-meta" + base := filepath.Join(dir, sanitizeSessionKey(sessionKey)) + line, err := json.Marshal(providers.Message{Role: "user", Content: "recover me"}) + if err != nil { + t.Fatalf("Marshal(message) error = %v", err) + } + if err := os.WriteFile(base+".jsonl", append(line, '\n'), 0o644); err != nil { + t.Fatalf("WriteFile(jsonl) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(listRec, listReq) + + if listRec.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal(list) error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + if items[0].ID != "missing-meta" { + t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "missing-meta") + } + + detailRec := httptest.NewRecorder() + detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/missing-meta", nil) + mux.ServeHTTP(detailRec, detailReq) + + if detailRec.Code != http.StatusOK { + t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String()) + } +} + +func TestHandleSessions_IgnoresMetaJSONInLegacyFallback(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + metaOnly := filepath.Join(dir, "agent_main_pico_direct_pico_meta-only.meta.json") + metaOnlyContent := []byte(`{"key":"agent:main:pico:direct:pico:meta-only","summary":"meta only"}`) + if err := os.WriteFile(metaOnly, metaOnlyContent, 0o644); err != nil { + t.Fatalf("WriteFile(meta) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(listRec, listReq) + + if listRec.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal(list) error = %v", err) + } + if len(items) != 0 { + t.Fatalf("len(items) = %d, want 0", len(items)) + } +}