diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 78b91068a..39a2e1539 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -27,6 +27,7 @@ import ( "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" + "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" @@ -672,9 +673,10 @@ func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuat if err != nil { return nil, err } + allocation := al.allocateRouteSession(route, msg) return &continuationTarget{ - SessionKey: resolveScopeKey(route, msg.SessionKey), + SessionKey: resolveScopeKey(allocation.SessionKey, msg.SessionKey), Channel: msg.Channel, ChatID: msg.ChatID, }, nil @@ -1323,18 +1325,22 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) } } - // Resolve session key from route, while preserving explicit agent-scoped keys. - scopeKey := resolveScopeKey(route, msg.SessionKey) + allocation := al.allocateRouteSession(route, msg) + + // Resolve session key from the route allocation, while preserving explicit + // agent-scoped keys supplied by the caller. + scopeKey := resolveScopeKey(allocation.SessionKey, msg.SessionKey) sessionKey := scopeKey logger.InfoCF("agent", "Routed message", map[string]any{ - "agent_id": agent.ID, - "scope_key": scopeKey, - "session_key": sessionKey, - "matched_by": route.MatchedBy, - "route_agent": route.AgentID, - "route_channel": route.Channel, + "agent_id": agent.ID, + "scope_key": scopeKey, + "session_key": sessionKey, + "matched_by": route.MatchedBy, + "route_agent": route.AgentID, + "route_channel": route.Channel, + "route_main_session": allocation.MainSessionKey, }) opts := processOptions{ @@ -1401,11 +1407,21 @@ func normalizedInboundContext(msg bus.InboundMessage) bus.InboundContext { return bus.NormalizeInboundMessage(msg).Context } -func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string { +func resolveScopeKey(routeSessionKey, msgSessionKey string) string { if msgSessionKey != "" && strings.HasPrefix(msgSessionKey, sessionKeyAgentPrefix) { return msgSessionKey } - return route.SessionKey + return routeSessionKey +} + +func (al *AgentLoop) allocateRouteSession(route routing.ResolvedRoute, msg bus.InboundMessage) session.Allocation { + return session.AllocateRouteSession(session.AllocationInput{ + AgentID: route.AgentID, + Channel: route.Channel, + AccountID: route.AccountID, + Peer: extractPeer(msg), + SessionPolicy: route.SessionPolicy, + }) } func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) { @@ -1417,8 +1433,9 @@ func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, stri if err != nil || agent == nil { return "", "", false } + allocation := al.allocateRouteSession(route, msg) - return resolveScopeKey(route, msg.SessionKey), agent.ID, true + return resolveScopeKey(allocation.SessionKey, msg.SessionKey), agent.ID, true } func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error { diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 54235b23a..1f99a5085 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -670,7 +670,12 @@ func TestProcessMessage_MediaToolHandledSkipsFollowUpLLMAndFinalText(t *testing. if err != nil { t.Fatalf("resolveMessageRoute() error = %v", err) } - sessionKey := resolveScopeKey(route, "") + sessionKey := resolveScopeKey(al.allocateRouteSession(route, bus.InboundMessage{ + Channel: "telegram", + ChatID: "chat1", + SenderID: "user1", + Content: "take a screenshot of the screen and send it to me", + }).SessionKey, "") history := defaultAgent.Sessions.GetHistory(sessionKey) if len(history) == 0 { t.Fatal("expected session history to be saved") @@ -1492,7 +1497,7 @@ func TestProcessMessage_UsesRouteSessionKey(t *testing.T) { Channel: msg.Channel, Peer: extractPeer(msg), }) - sessionKey := route.SessionKey + sessionKey := al.allocateRouteSession(route, msg).SessionKey defaultAgent := al.registry.GetDefaultAgent() if defaultAgent == nil { @@ -2195,7 +2200,15 @@ func TestAgentLoop_ToolLimitUsesDedicatedFallback(t *testing.T) { ID: "cron", }, }) - history := defaultAgent.Sessions.GetHistory(route.SessionKey) + history := defaultAgent.Sessions.GetHistory(al.allocateRouteSession(route, bus.InboundMessage{ + Channel: "test", + SenderID: "cron", + ChatID: "chat1", + Peer: bus.Peer{ + Kind: "direct", + ID: "cron", + }, + }).SessionKey) if len(history) != 4 { t.Fatalf("history len = %d, want 4", len(history)) } diff --git a/pkg/routing/route.go b/pkg/routing/route.go index 9eb060c53..494aefabb 100644 --- a/pkg/routing/route.go +++ b/pkg/routing/route.go @@ -16,14 +16,21 @@ type RouteInput struct { TeamID string } +// SessionPolicy describes how a routed message should be mapped to a session. +// The current implementation preserves the legacy dm_scope and identity_link +// semantics while moving session-key construction out of the router. +type SessionPolicy struct { + DMScope DMScope + IdentityLinks map[string][]string +} + // ResolvedRoute is the result of agent routing. type ResolvedRoute struct { - AgentID string - Channel string - AccountID string - SessionKey string - MainSessionKey string - MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" + AgentID string + Channel string + AccountID string + SessionPolicy SessionPolicy + MatchedBy string // "binding.peer", "binding.peer.parent", "binding.guild", "binding.team", "binding.account", "binding.channel", "default" } // RouteResolver determines which agent handles a message based on config bindings. @@ -36,7 +43,8 @@ func NewRouteResolver(cfg *config.Config) *RouteResolver { return &RouteResolver{cfg: cfg} } -// ResolveRoute determines which agent handles the message and constructs session keys. +// ResolveRoute determines which agent handles the message and returns the +// session policy that should be used to allocate session state. // Implements the 7-level priority cascade: // peer > parent_peer > guild > team > account > channel_wildcard > default func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { @@ -44,32 +52,18 @@ func (r *RouteResolver) ResolveRoute(input RouteInput) ResolvedRoute { accountID := NormalizeAccountID(input.AccountID) peer := input.Peer - dmScope := DMScope(r.cfg.Session.DMScope) - if dmScope == "" { - dmScope = DMScopeMain - } - identityLinks := r.cfg.Session.IdentityLinks + sessionPolicy := r.sessionPolicy() bindings := r.filterBindings(channel, accountID) choose := func(agentID string, matchedBy string) ResolvedRoute { resolvedAgentID := r.pickAgentID(agentID) - sessionKey := strings.ToLower(BuildAgentPeerSessionKey(SessionKeyParams{ + return ResolvedRoute{ AgentID: resolvedAgentID, Channel: channel, AccountID: accountID, - Peer: peer, - DMScope: dmScope, - IdentityLinks: identityLinks, - })) - mainSessionKey := strings.ToLower(BuildAgentMainSessionKey(resolvedAgentID)) - return ResolvedRoute{ - AgentID: resolvedAgentID, - Channel: channel, - AccountID: accountID, - SessionKey: sessionKey, - MainSessionKey: mainSessionKey, - MatchedBy: matchedBy, + SessionPolicy: sessionPolicy, + MatchedBy: matchedBy, } } @@ -250,3 +244,27 @@ func (r *RouteResolver) resolveDefaultAgentID() string { } return DefaultAgentID } + +func (r *RouteResolver) sessionPolicy() SessionPolicy { + dmScope := DMScope(r.cfg.Session.DMScope) + if dmScope == "" { + dmScope = DMScopeMain + } + return SessionPolicy{ + DMScope: dmScope, + IdentityLinks: cloneIdentityLinks(r.cfg.Session.IdentityLinks), + } +} + +func cloneIdentityLinks(src map[string][]string) map[string][]string { + if len(src) == 0 { + return nil + } + cloned := make(map[string][]string, len(src)) + for canonical, ids := range src { + dup := make([]string, len(ids)) + copy(dup, ids) + cloned[canonical] = dup + } + return cloned +} diff --git a/pkg/routing/route_test.go b/pkg/routing/route_test.go index fdfc899f9..ab1a7a4e2 100644 --- a/pkg/routing/route_test.go +++ b/pkg/routing/route_test.go @@ -37,6 +37,12 @@ func TestResolveRoute_DefaultAgent_NoBindings(t *testing.T) { if route.MatchedBy != "default" { t.Errorf("MatchedBy = %q, want 'default'", route.MatchedBy) } + if route.SessionPolicy.DMScope != DMScopePerPeer { + t.Errorf("SessionPolicy.DMScope = %q, want %q", route.SessionPolicy.DMScope, DMScopePerPeer) + } + if route.SessionPolicy.IdentityLinks != nil { + t.Errorf("SessionPolicy.IdentityLinks = %v, want nil", route.SessionPolicy.IdentityLinks) + } } func TestResolveRoute_PeerBinding(t *testing.T) { diff --git a/pkg/session/allocator.go b/pkg/session/allocator.go new file mode 100644 index 000000000..675e577f8 --- /dev/null +++ b/pkg/session/allocator.go @@ -0,0 +1,43 @@ +package session + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/routing" +) + +// Allocation contains the concrete session keys selected for a routed turn. +// The current implementation intentionally preserves the legacy session-key +// layout while moving key construction out of the router. +type Allocation struct { + SessionKey string + MainSessionKey string +} + +// AllocationInput contains the routing result and peer context needed to +// derive the session keys for a turn. +type AllocationInput struct { + AgentID string + Channel string + AccountID string + Peer *routing.RoutePeer + SessionPolicy routing.SessionPolicy +} + +// AllocateRouteSession maps a route decision onto the current legacy +// agent-scoped session-key format. +func AllocateRouteSession(input AllocationInput) Allocation { + sessionKey := strings.ToLower(routing.BuildAgentPeerSessionKey(routing.SessionKeyParams{ + AgentID: input.AgentID, + Channel: input.Channel, + AccountID: input.AccountID, + Peer: input.Peer, + DMScope: input.SessionPolicy.DMScope, + IdentityLinks: input.SessionPolicy.IdentityLinks, + })) + mainSessionKey := strings.ToLower(routing.BuildAgentMainSessionKey(input.AgentID)) + return Allocation{ + SessionKey: sessionKey, + MainSessionKey: mainSessionKey, + } +} diff --git a/pkg/session/allocator_test.go b/pkg/session/allocator_test.go new file mode 100644 index 000000000..a6e84e09d --- /dev/null +++ b/pkg/session/allocator_test.go @@ -0,0 +1,51 @@ +package session + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/routing" +) + +func TestAllocateRouteSession_PerPeerDM(t *testing.T) { + allocation := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Channel: "telegram", + AccountID: "default", + Peer: &routing.RoutePeer{ + Kind: "direct", + ID: "User123", + }, + SessionPolicy: routing.SessionPolicy{ + DMScope: routing.DMScopePerPeer, + }, + }) + + if allocation.SessionKey != "agent:main:direct:user123" { + t.Fatalf("SessionKey = %q, want %q", allocation.SessionKey, "agent:main:direct:user123") + } + if allocation.MainSessionKey != "agent:main:main" { + t.Fatalf("MainSessionKey = %q, want %q", allocation.MainSessionKey, "agent:main:main") + } +} + +func TestAllocateRouteSession_GroupPeer(t *testing.T) { + allocation := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Channel: "slack", + AccountID: "workspace-a", + Peer: &routing.RoutePeer{ + Kind: "channel", + ID: "C001", + }, + SessionPolicy: routing.SessionPolicy{ + DMScope: routing.DMScopePerAccountChannelPeer, + }, + }) + + if allocation.SessionKey != "agent:main:slack:channel:c001" { + t.Fatalf("SessionKey = %q, want %q", allocation.SessionKey, "agent:main:slack:channel:c001") + } + if allocation.MainSessionKey != "agent:main:main" { + t.Fatalf("MainSessionKey = %q, want %q", allocation.MainSessionKey, "agent:main:main") + } +}