From 296077eabf7ad4ce3a65aa3aba34ce9b0f6c25d9 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Wed, 8 Apr 2026 00:32:53 +0800 Subject: [PATCH] fix(session): restore thread and legacy compatibility --- pkg/agent/loop.go | 4 ++ pkg/agent/steering.go | 6 +- pkg/bus/bus.go | 6 +- pkg/bus/inbound_context.go | 7 +-- pkg/bus/outbound_context.go | 39 ++++++++++--- pkg/channels/manager.go | 4 +- pkg/channels/slack/slack.go | 45 ++++++++++++--- pkg/channels/slack/slack_test.go | 18 ++++++ pkg/channels/telegram/telegram.go | 26 ++++++++- pkg/channels/telegram/telegram_test.go | 32 +++++++++++ pkg/config/config_test.go | 46 +++++++++++++++ pkg/config/legacy_bindings.go | 68 ++++++++++++++++++++-- pkg/session/allocator.go | 66 +++++++++++++++++---- pkg/session/allocator_test.go | 59 +++++++++++++++++++ pkg/session/jsonl_backend.go | 7 +++ pkg/session/jsonl_backend_test.go | 43 ++++++++++++++ web/backend/api/session.go | 59 ++++++++++++++++++- web/backend/api/session_test.go | 79 ++++++++++++++++++++++++++ 18 files changed, 568 insertions(+), 46 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 26b35c2f1..1512ff824 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -805,6 +805,8 @@ func outboundTurnMetadata( func outboundMessageForTurn(ts *turnState, content string) bus.OutboundMessage { agentID, sessionKey, scope := outboundTurnMetadata(ts.agent.ID, ts.sessionKey, ts.opts.Dispatch.SessionScope) return bus.OutboundMessage{ + Channel: ts.channel, + ChatID: ts.chatID, Context: outboundContextFromInbound( ts.opts.Dispatch.InboundContext, ts.channel, @@ -2827,6 +2829,8 @@ turnLoop: parts = append(parts, part) } outboundMedia := bus.OutboundMediaMessage{ + Channel: ts.channel, + ChatID: ts.chatID, Context: outboundContextFromInbound( ts.opts.Dispatch.InboundContext, ts.channel, diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index a7051890d..d70c92731 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -3,6 +3,7 @@ package agent import ( "context" "fmt" + "sort" "strings" "sync" @@ -319,7 +320,9 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance { return nil } - for _, agentID := range registry.ListAgentIDs() { + agentIDs := registry.ListAgentIDs() + sort.Strings(agentIDs) + for _, agentID := range agentIDs { agent, ok := registry.GetAgent(agentID) if !ok || agent == nil { continue @@ -331,7 +334,6 @@ func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance { if scopedAgent, ok := registry.GetAgent(resolvedAgentID); ok { return scopedAgent } - return agent } return registry.GetDefaultAgent() diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 03ef3123f..9a05d4f95 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -90,10 +90,10 @@ func publish[T any](ctx context.Context, mb *MessageBus, ch chan T, msg T) error } func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + msg = NormalizeInboundMessage(msg) if msg.Context.isZero() { return ErrMissingInboundContext } - msg = NormalizeInboundMessage(msg) return publish(ctx, mb, mb.inbound, msg) } @@ -102,10 +102,10 @@ func (mb *MessageBus) InboundChan() <-chan InboundMessage { } func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + msg = NormalizeOutboundMessage(msg) if msg.Context.isZero() { return ErrMissingOutboundContext } - msg = NormalizeOutboundMessage(msg) return publish(ctx, mb, mb.outbound, msg) } @@ -114,10 +114,10 @@ func (mb *MessageBus) OutboundChan() <-chan OutboundMessage { } func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { + msg = NormalizeOutboundMediaMessage(msg) if msg.Context.isZero() { return ErrMissingOutboundMediaContext } - msg = NormalizeOutboundMediaMessage(msg) return publish(ctx, mb, mb.outboundMedia, msg) } diff --git a/pkg/bus/inbound_context.go b/pkg/bus/inbound_context.go index 3a19ac957..320424178 100644 --- a/pkg/bus/inbound_context.go +++ b/pkg/bus/inbound_context.go @@ -65,10 +65,5 @@ func cloneStringMap(src map[string]string) map[string]string { } func normalizeKind(kind string) string { - switch strings.ToLower(strings.TrimSpace(kind)) { - case "direct", "group", "channel", "guild", "team", "workspace", "tenant", "topic": - return strings.ToLower(strings.TrimSpace(kind)) - default: - return strings.ToLower(strings.TrimSpace(kind)) - } + return strings.ToLower(strings.TrimSpace(kind)) } diff --git a/pkg/bus/outbound_context.go b/pkg/bus/outbound_context.go index 416a26861..4861483a1 100644 --- a/pkg/bus/outbound_context.go +++ b/pkg/bus/outbound_context.go @@ -15,23 +15,48 @@ func NewOutboundContext(channel, chatID, replyToMessageID string) InboundContext // NormalizeOutboundMessage ensures Context is normalized and keeps convenience // mirrors in sync for runtime consumers. func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage { - msg.Context = normalizeInboundContext(msg.Context) - msg.Channel = msg.Context.Channel - msg.ChatID = msg.Context.ChatID - msg.Scope = cloneOutboundScope(msg.Scope) + msg.Channel = strings.TrimSpace(msg.Channel) + msg.ChatID = strings.TrimSpace(msg.ChatID) + msg.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID) + if msg.Context.Channel == "" { + msg.Context.Channel = msg.Channel + } + if msg.Context.ChatID == "" { + msg.Context.ChatID = msg.ChatID + } if msg.Context.ReplyToMessageID == "" { - msg.Context.ReplyToMessageID = strings.TrimSpace(msg.ReplyToMessageID) + msg.Context.ReplyToMessageID = msg.ReplyToMessageID + } + msg.Context = normalizeInboundContext(msg.Context) + if msg.Channel == "" { + msg.Channel = msg.Context.Channel + } + if msg.ChatID == "" { + msg.ChatID = msg.Context.ChatID } msg.ReplyToMessageID = msg.Context.ReplyToMessageID + msg.Scope = cloneOutboundScope(msg.Scope) return msg } // NormalizeOutboundMediaMessage ensures media outbound messages also carry a // normalized context while keeping convenience mirrors in sync. func NormalizeOutboundMediaMessage(msg OutboundMediaMessage) OutboundMediaMessage { + msg.Channel = strings.TrimSpace(msg.Channel) + msg.ChatID = strings.TrimSpace(msg.ChatID) + if msg.Context.Channel == "" { + msg.Context.Channel = msg.Channel + } + if msg.Context.ChatID == "" { + msg.Context.ChatID = msg.ChatID + } msg.Context = normalizeInboundContext(msg.Context) - msg.Channel = msg.Context.Channel - msg.ChatID = msg.Context.ChatID + if msg.Channel == "" { + msg.Channel = msg.Context.Channel + } + if msg.ChatID == "" { + msg.ChatID = msg.Context.ChatID + } msg.Scope = cloneOutboundScope(msg.Scope) return msg } diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 7c4013676..f62438eca 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -103,7 +103,7 @@ func outboundMessageChannel(msg bus.OutboundMessage) string { } func outboundMessageChatID(msg bus.OutboundMessage) string { - return msg.Context.ChatID + return msg.ChatID } func outboundMediaChannel(msg bus.OutboundMediaMessage) string { @@ -111,7 +111,7 @@ func outboundMediaChannel(msg bus.OutboundMediaMessage) string { } func outboundMediaChatID(msg bus.OutboundMediaMessage) string { - return msg.Context.ChatID + return msg.ChatID } // RecordPlaceholder registers a placeholder message for later editing. diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 543f6f338..53d112e6c 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -113,7 +113,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str return nil, channels.ErrNotRunning } - channelID, threadTS := parseSlackChatID(msg.ChatID) + deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget(msg.ChatID, &msg.Context) if channelID == "" { return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) } @@ -135,7 +135,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]str return nil, fmt.Errorf("slack send: %w", channels.ErrTemporary) } - if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + if ref, ok := c.pendingAcks.LoadAndDelete(deliveryChatID); ok { msgRef := ref.(slackMessageRef) c.api.AddReaction("white_check_mark", slack.ItemRef{ Channel: msgRef.ChannelID, @@ -157,7 +157,7 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa return nil, channels.ErrNotRunning } - channelID, _ := parseSlackChatID(msg.ChatID) + _, channelID, threadTS := resolveSlackMediaOutboundTarget(msg.ChatID, &msg.Context) if channelID == "" { return nil, fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) } @@ -188,10 +188,11 @@ func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessa } _, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{ - Channel: channelID, - File: localPath, - Filename: filename, - Title: title, + Channel: channelID, + ThreadTimestamp: threadTS, + File: localPath, + Filename: filename, + Title: title, }) if err != nil { logger.ErrorCF("slack", "Failed to upload media", map[string]any{ @@ -561,3 +562,33 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) { } return channelID, threadTS } + +func resolveSlackOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) { + deliveryChatID := strings.TrimSpace(chatID) + if deliveryChatID == "" && outboundCtx != nil { + deliveryChatID = strings.TrimSpace(outboundCtx.ChatID) + } + channelID, threadTS := parseSlackChatID(deliveryChatID) + if threadTS == "" && outboundCtx != nil { + threadTS = strings.TrimSpace(outboundCtx.TopicID) + if threadTS != "" && channelID != "" { + deliveryChatID = channelID + "/" + threadTS + } + } + return deliveryChatID, channelID, threadTS +} + +func resolveSlackMediaOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (string, string, string) { + deliveryChatID := strings.TrimSpace(chatID) + if deliveryChatID == "" && outboundCtx != nil { + deliveryChatID = strings.TrimSpace(outboundCtx.ChatID) + } + channelID, threadTS := parseSlackChatID(deliveryChatID) + if threadTS == "" && outboundCtx != nil { + threadTS = strings.TrimSpace(outboundCtx.TopicID) + if threadTS != "" && channelID != "" { + deliveryChatID = channelID + "/" + threadTS + } + } + return deliveryChatID, channelID, threadTS +} diff --git a/pkg/channels/slack/slack_test.go b/pkg/channels/slack/slack_test.go index d1980a7c9..a81c2193c 100644 --- a/pkg/channels/slack/slack_test.go +++ b/pkg/channels/slack/slack_test.go @@ -53,6 +53,24 @@ func TestParseSlackChatID(t *testing.T) { } } +func TestResolveSlackOutboundTarget_PrefersContextTopicID(t *testing.T) { + deliveryChatID, channelID, threadTS := resolveSlackOutboundTarget("C123456", &bus.InboundContext{ + Channel: "slack", + ChatID: "C123456", + TopicID: "1234567890.123456", + }) + + if deliveryChatID != "C123456/1234567890.123456" { + t.Fatalf("deliveryChatID = %q, want %q", deliveryChatID, "C123456/1234567890.123456") + } + if channelID != "C123456" { + t.Fatalf("channelID = %q, want %q", channelID, "C123456") + } + if threadTS != "1234567890.123456" { + t.Fatalf("threadTS = %q, want %q", threadTS, "1234567890.123456") + } +} + func TestStripBotMention(t *testing.T) { ch := &SlackChannel{botUserID: "U12345BOT"} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 20a659266..270d44131 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -176,7 +176,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([] useMarkdownV2 := c.config.Channels.Telegram.UseMarkdownV2 - chatID, threadID, err := parseTelegramChatID(msg.ChatID) + chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context) if err != nil { return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } @@ -463,7 +463,7 @@ func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMe return nil, channels.ErrNotRunning } - chatID, threadID, err := parseTelegramChatID(msg.ChatID) + chatID, threadID, err := resolveTelegramOutboundTarget(msg.ChatID, &msg.Context) if err != nil { return nil, fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } @@ -960,6 +960,28 @@ func parseTelegramChatID(chatID string) (int64, int, error) { return cid, tid, nil } +func resolveTelegramOutboundTarget(chatID string, outboundCtx *bus.InboundContext) (int64, int, error) { + targetChatID := strings.TrimSpace(chatID) + if targetChatID == "" && outboundCtx != nil { + targetChatID = strings.TrimSpace(outboundCtx.ChatID) + } + resolvedChatID, resolvedThreadID, err := parseTelegramChatID(targetChatID) + if err != nil { + return 0, 0, err + } + if resolvedThreadID != 0 || outboundCtx == nil { + return resolvedChatID, resolvedThreadID, nil + } + topicID := strings.TrimSpace(outboundCtx.TopicID) + if topicID == "" { + return resolvedChatID, resolvedThreadID, nil + } + if threadID, convErr := strconv.Atoi(topicID); convErr == nil { + return resolvedChatID, threadID, nil + } + return resolvedChatID, resolvedThreadID, nil +} + func logParseFailed(err error, useMarkdownV2 bool) { parsingName := "HTML" if useMarkdownV2 { diff --git a/pkg/channels/telegram/telegram_test.go b/pkg/channels/telegram/telegram_test.go index 0b5d21e2b..8e8fc7053 100644 --- a/pkg/channels/telegram/telegram_test.go +++ b/pkg/channels/telegram/telegram_test.go @@ -527,6 +527,38 @@ func TestSend_WithForumThreadID(t *testing.T) { assert.Len(t, caller.calls, 1) } +func TestSend_UsesContextTopicIDWhenChatIDDoesNotIncludeThread(t *testing.T) { + caller := &stubCaller{ + callFn: func(ctx context.Context, url string, data *ta.RequestData) (*ta.Response, error) { + return successResponse(t), nil + }, + } + ch := newTestChannel(t, caller) + + _, err := ch.Send(context.Background(), bus.OutboundMessage{ + ChatID: "-1001234567890", + Content: "Hello from topic context", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + TopicID: "42", + }, + }) + + require.NoError(t, err) + require.Len(t, caller.calls, 1) + + var params struct { + ChatID int64 `json:"chat_id"` + MessageThreadID int `json:"message_thread_id"` + Text string `json:"text"` + } + require.NoError(t, json.Unmarshal(caller.calls[0].Data.BodyRaw, ¶ms)) + assert.Equal(t, int64(-1001234567890), params.ChatID) + assert.Equal(t, 42, params.MessageThreadID) + assert.Equal(t, "Hello from topic context", params.Text) +} + func TestHandleMessage_ForumTopic_SetsMetadata(t *testing.T) { messageBus := bus.NewMessageBus() ch := &TelegramChannel{ diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 74e5cc9fe..9aa91e4d9 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -425,6 +425,52 @@ func TestLoadConfig_PrefersDispatchRulesOverLegacyBindings(t *testing.T) { } } +func TestLoadConfig_MigratesLegacyDirectBindingsWithIdentityLinks(t *testing.T) { + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + raw := `{ + "version": 2, + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7" + }, + "list": [ + { "id": "main", "default": true }, + { "id": "support" } + ] + }, + "session": { + "identity_links": { + "john": ["telegram:123", "123"] + } + }, + "bindings": [ + { + "agent_id": "support", + "match": { + "channel": "telegram", + "peer": { "kind": "direct", "id": "123" } + } + } + ] + }` + if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil { + t.Fatalf("WriteFile(configPath): %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + if cfg.Agents.Dispatch == nil || len(cfg.Agents.Dispatch.Rules) != 1 { + t.Fatalf("Dispatch.Rules = %+v, want 1 migrated rule", cfg.Agents.Dispatch) + } + if got := cfg.Agents.Dispatch.Rules[0].When.Sender; got != "john" { + t.Fatalf("migrated sender selector = %q, want %q", got, "john") + } +} + // TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { cfg := DefaultConfig() diff --git a/pkg/config/legacy_bindings.go b/pkg/config/legacy_bindings.go index 83fa08669..751a35de7 100644 --- a/pkg/config/legacy_bindings.go +++ b/pkg/config/legacy_bindings.go @@ -57,7 +57,7 @@ func applyLegacyBindingsMigration(data []byte, cfg *Config) { return } - rules, dropped := migrateLegacyBindings(bindings) + rules, dropped := migrateLegacyBindings(bindings, cfg.Session.IdentityLinks) if len(rules) == 0 { logger.WarnF( "legacy bindings config is deprecated and could not be migrated", @@ -97,7 +97,7 @@ func decodeLegacyBindings(data []byte) ([]legacyAgentBinding, bool, error) { return bindings, true, nil } -func migrateLegacyBindings(bindings []legacyAgentBinding) ([]DispatchRule, int) { +func migrateLegacyBindings(bindings []legacyAgentBinding, identityLinks map[string][]string) ([]DispatchRule, int) { if len(bindings) == 0 { return nil, 0 } @@ -111,7 +111,7 @@ func migrateLegacyBindings(bindings []legacyAgentBinding) ([]DispatchRule, int) prioritized := make([]prioritizedRule, 0, len(bindings)) dropped := 0 for i, binding := range bindings { - rule, kind, ok := migrateLegacyBinding(binding, i) + rule, kind, ok := migrateLegacyBinding(binding, i, identityLinks) if !ok { dropped++ continue @@ -133,7 +133,11 @@ func migrateLegacyBindings(bindings []legacyAgentBinding) ([]DispatchRule, int) return rules, dropped } -func migrateLegacyBinding(binding legacyAgentBinding, index int) (DispatchRule, int, bool) { +func migrateLegacyBinding( + binding legacyAgentBinding, + index int, + identityLinks map[string][]string, +) (DispatchRule, int, bool) { channel := strings.ToLower(strings.TrimSpace(binding.Match.Channel)) agentID := strings.TrimSpace(binding.AgentID) if channel == "" || agentID == "" { @@ -163,7 +167,7 @@ func migrateLegacyBinding(binding legacyAgentBinding, index int) (DispatchRule, } switch peerKind { case "direct": - rule.When.Sender = peerID + rule.When.Sender = canonicalLegacyBindingSenderID(channel, peerID, identityLinks) return rule, 0, true case "group", "channel": rule.When.Chat = peerKind + ":" + peerID @@ -207,3 +211,57 @@ func normalizeLegacyAccountSelector(accountID string) string { return strings.ToLower(accountID) } } + +func canonicalLegacyBindingSenderID(channel, peerID string, identityLinks map[string][]string) string { + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + if linked := resolveLegacyBindingLinkedID(identityLinks, channel, peerID); linked != "" { + return strings.ToLower(linked) + } + + return strings.ToLower(peerID) +} + +func resolveLegacyBindingLinkedID(identityLinks map[string][]string, channel, peerID string) string { + if len(identityLinks) == 0 { + return "" + } + peerID = strings.TrimSpace(peerID) + if peerID == "" { + return "" + } + + candidates := make(map[string]struct{}) + rawCandidate := strings.ToLower(peerID) + if rawCandidate != "" { + candidates[rawCandidate] = struct{}{} + } + channel = strings.ToLower(strings.TrimSpace(channel)) + if channel != "" { + candidates[channel+":"+rawCandidate] = struct{}{} + } + if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + candidates[rawCandidate[idx+1:]] = struct{}{} + } + + for canonical, ids := range identityLinks { + canonical = strings.TrimSpace(canonical) + if canonical == "" { + continue + } + for _, id := range ids { + normalized := strings.ToLower(strings.TrimSpace(id)) + if normalized == "" { + continue + } + if _, ok := candidates[normalized]; ok { + return canonical + } + } + } + + return "" +} diff --git a/pkg/session/allocator.go b/pkg/session/allocator.go index 7045b93d6..509550cb2 100644 --- a/pkg/session/allocator.go +++ b/pkg/session/allocator.go @@ -44,6 +44,7 @@ func AllocateRouteSession(input AllocationInput) Allocation { func buildSessionScope(input AllocationInput) SessionScope { inbound := input.Context + includeTopicInChatDimension := shouldPreserveTelegramForumIsolation(input) scope := SessionScope{ Version: ScopeVersionV1, AgentID: routing.NormalizeAgentID(input.AgentID), @@ -73,6 +74,11 @@ func buildSessionScope(input AllocationInput) SessionScope { if chatID == "" { continue } + if includeTopicInChatDimension { + if topicID := strings.TrimSpace(inbound.TopicID); topicID != "" { + chatID = chatID + "/" + topicID + } + } chatType := strings.ToLower(strings.TrimSpace(inbound.ChatType)) if chatType == "" { chatType = "direct" @@ -111,18 +117,16 @@ func buildLegacySessionAliases(input AllocationInput) []string { inbound := input.Context if strings.EqualFold(strings.TrimSpace(inbound.ChatType), "direct") { - senderID := CanonicalSessionIdentityID( - inbound.Channel, - inbound.SenderID, - input.SessionPolicy.IdentityLinks, - ) - if senderID == "" { + peerIDs := buildLegacyDirectPeerIDs(input) + if len(peerIDs) == 0 { return uniqueAliases(aliases) } - aliases = append( - aliases, - BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, senderID)..., - ) + for _, peerID := range peerIDs { + aliases = append( + aliases, + BuildLegacyDirectAliases(input.AgentID, inbound.Channel, inbound.Account, peerID)..., + ) + } return uniqueAliases(aliases) } @@ -143,6 +147,48 @@ func buildLegacySessionAliases(input AllocationInput) []string { return uniqueAliases(aliases) } +func shouldPreserveTelegramForumIsolation(input AllocationInput) bool { + inbound := input.Context + if !strings.EqualFold(strings.TrimSpace(inbound.Channel), "telegram") { + return false + } + if strings.TrimSpace(inbound.TopicID) == "" { + return false + } + for _, dimension := range input.SessionPolicy.Dimensions { + if strings.EqualFold(strings.TrimSpace(dimension), "topic") { + return false + } + } + return true +} + +func buildLegacyDirectPeerIDs(input AllocationInput) []string { + inbound := input.Context + peerIDs := make([]string, 0, 3) + + rawSenderID := strings.TrimSpace(inbound.SenderID) + if rawSenderID != "" { + peerIDs = append(peerIDs, strings.ToLower(rawSenderID)) + } + + canonicalSenderID := CanonicalSessionIdentityID( + inbound.Channel, + inbound.SenderID, + input.SessionPolicy.IdentityLinks, + ) + if canonicalSenderID != "" { + peerIDs = append(peerIDs, canonicalSenderID) + } + + chatID := strings.TrimSpace(inbound.ChatID) + if chatID != "" { + peerIDs = append(peerIDs, strings.ToLower(chatID)) + } + + return uniqueAliases(peerIDs) +} + func uniqueAliases(aliases []string) []string { if len(aliases) == 0 { return nil diff --git a/pkg/session/allocator_test.go b/pkg/session/allocator_test.go index c688fe0bf..9750ffc39 100644 --- a/pkg/session/allocator_test.go +++ b/pkg/session/allocator_test.go @@ -80,6 +80,65 @@ func TestAllocateRouteSession_GroupPeer(t *testing.T) { } } +func TestAllocateRouteSession_TelegramForumTopicsRemainIsolatedByDefault(t *testing.T) { + first := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + ChatType: "group", + TopicID: "42", + SenderID: "7", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat"}, + }, + }) + second := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "telegram", + ChatID: "-1001234567890", + ChatType: "group", + TopicID: "99", + SenderID: "7", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"chat"}, + }, + }) + + if first.SessionKey == second.SessionKey { + t.Fatalf("forum topics should not share default session key: %q", first.SessionKey) + } + if got := first.Scope.Values["chat"]; got != "group:-1001234567890/42" { + t.Fatalf("first.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/42") + } + if got := second.Scope.Values["chat"]; got != "group:-1001234567890/99" { + t.Fatalf("second.Scope.Values[chat] = %q, want %q", got, "group:-1001234567890/99") + } +} + +func TestAllocateRouteSession_PicoDirectAliasesIncludeLegacyChatKey(t *testing.T) { + allocation := AllocateRouteSession(AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "pico", + Account: "default", + ChatID: "pico:session-123", + ChatType: "direct", + SenderID: "pico-user", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + }) + + if !containsAlias(allocation.SessionAliases, "agent:main:pico:direct:pico:session-123") { + t.Fatalf("SessionAliases = %v, want pico legacy alias", allocation.SessionAliases) + } +} + func TestBuildOpaqueSessionKey_IsStable(t *testing.T) { first := BuildOpaqueSessionKey("agent:main:direct:user123") second := BuildOpaqueSessionKey("agent:main:direct:user123") diff --git a/pkg/session/jsonl_backend.go b/pkg/session/jsonl_backend.go index 06044b618..4e4f96029 100644 --- a/pkg/session/jsonl_backend.go +++ b/pkg/session/jsonl_backend.go @@ -84,6 +84,13 @@ func (b *JSONLBackend) EnsureSessionMetadata(sessionKey string, scope *SessionSc return } + canonicalMeta, metaErr := metaStore.GetSessionMeta(ctx, sessionKey) + if metaErr != nil { + log.Printf("session: get canonical session metadata: %v", metaErr) + } else if canonicalMeta.Count > 0 || strings.TrimSpace(canonicalMeta.Summary) != "" { + return + } + canonicalHistory, historyErr := b.store.GetHistory(ctx, sessionKey) if historyErr != nil { log.Printf("session: get canonical history: %v", historyErr) diff --git a/pkg/session/jsonl_backend_test.go b/pkg/session/jsonl_backend_test.go index 411e3e8c5..362619125 100644 --- a/pkg/session/jsonl_backend_test.go +++ b/pkg/session/jsonl_backend_test.go @@ -4,8 +4,10 @@ import ( "fmt" "testing" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/memory" "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/session" ) @@ -239,3 +241,44 @@ func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyAliasHistory(t *testin t.Fatalf("promoted summary = %q, want %q", summary, "legacy summary") } } + +func TestJSONLBackend_EnsureSessionMetadata_PromotesLegacyPicoDirectAliasHistory(t *testing.T) { + b := newBackend(t) + + legacyKey := "agent:main:pico:direct:pico:session-123" + b.AddMessage(legacyKey, "user", "legacy pico history") + + scope := &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "pico", + Account: "default", + Dimensions: []string{"sender"}, + Values: map[string]string{ + "sender": "pico-user", + }, + } + allocation := session.AllocateRouteSession(session.AllocationInput{ + AgentID: "main", + Context: bus.InboundContext{ + Channel: "pico", + Account: "default", + ChatID: "pico:session-123", + ChatType: "direct", + SenderID: "pico-user", + }, + SessionPolicy: routing.SessionPolicy{ + Dimensions: []string{"sender"}, + }, + }) + + b.EnsureSessionMetadata(allocation.SessionKey, scope, allocation.SessionAliases) + + if got := b.ResolveSessionKey(legacyKey); got != allocation.SessionKey { + t.Fatalf("ResolveSessionKey() = %q, want %q", got, allocation.SessionKey) + } + history := b.GetHistory(allocation.SessionKey) + if len(history) != 1 || history[0].Content != "legacy pico history" { + t.Fatalf("promoted history = %+v", history) + } +} diff --git a/web/backend/api/session.go b/web/backend/api/session.go index 914e075f9..f3dd03dc0 100644 --- a/web/backend/api/session.go +++ b/web/backend/api/session.go @@ -256,11 +256,13 @@ func (h *Handler) findPicoJSONLSessions(dir string) ([]picoJSONLSessionRef, erro refs := make([]picoJSONLSessionRef, 0) seen := make(map[string]struct{}) + metaBackedBases := make(map[string]struct{}) for _, entry := range entries { if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") { continue } - metaPath := filepath.Join(dir, entry.Name()) + name := entry.Name() + metaPath := filepath.Join(dir, name) meta, err := h.readSessionMeta(metaPath, "") if err != nil { continue @@ -269,6 +271,27 @@ func (h *Handler) findPicoJSONLSessions(dir string) ([]picoJSONLSessionRef, erro if !ok || ref.Key == "" || ref.ID == "" { continue } + metaBackedBases[strings.TrimSuffix(name, ".meta.json")] = struct{}{} + if _, exists := seen[ref.ID]; exists { + continue + } + seen[ref.ID] = struct{}{} + refs = append(refs, ref) + } + + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") { + continue + } + name := entry.Name() + base := strings.TrimSuffix(name, ".jsonl") + if _, ok := metaBackedBases[base]; ok { + continue + } + ref, ok := jsonlSessionRefFromFilename(name) + if !ok || ref.Key == "" || ref.ID == "" { + continue + } if _, exists := seen[ref.ID]; exists { continue } @@ -300,7 +323,8 @@ func (h *Handler) findLegacyPicoSessions(dir string) ([]picoLegacySessionRef, er refs := make([]picoLegacySessionRef, 0) seen := make(map[string]struct{}) for _, entry := range entries { - if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" { + name := entry.Name() + if entry.IsDir() || filepath.Ext(name) != ".json" || strings.HasSuffix(name, ".meta.json") { continue } @@ -323,6 +347,37 @@ func (h *Handler) findLegacyPicoSessions(dir string) ([]picoLegacySessionRef, er return refs, nil } +func jsonlSessionRefFromFilename(name string) (picoJSONLSessionRef, bool) { + if !strings.HasSuffix(name, ".jsonl") { + return picoJSONLSessionRef{}, false + } + base := strings.TrimSuffix(name, ".jsonl") + if base == "" { + return picoJSONLSessionRef{}, false + } + + legacyPrefix := sanitizeSessionKey(legacyPicoSessionPrefix) + if strings.HasPrefix(base, legacyPrefix) { + sessionID := strings.TrimPrefix(base, legacyPrefix) + if sessionID == "" { + return picoJSONLSessionRef{}, false + } + return picoJSONLSessionRef{ + ID: sessionID, + Key: legacyPicoSessionPrefix + sessionID, + }, true + } + + if session.IsOpaqueSessionKey(base) { + return picoJSONLSessionRef{ + ID: base, + Key: base, + }, true + } + + return picoJSONLSessionRef{}, false +} + func (h *Handler) findLegacyPicoSession(dir, sessionID string) (picoLegacySessionRef, error) { refs, err := h.findLegacyPicoSessions(dir) if err != nil { diff --git a/web/backend/api/session_test.go b/web/backend/api/session_test.go index 4c871ee30..6b7205057 100644 --- a/web/backend/api/session_test.go +++ b/web/backend/api/session_test.go @@ -750,3 +750,82 @@ func TestHandleSessions_FiltersEmptyJSONLFiles(t *testing.T) { t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusNotFound, detailRec.Body.String()) } } + +func TestHandleSessions_ListsLegacyJSONLWithoutMeta(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + sessionKey := legacyPicoSessionPrefix + "missing-meta" + base := filepath.Join(dir, sanitizeSessionKey(sessionKey)) + line, err := json.Marshal(providers.Message{Role: "user", Content: "recover me"}) + if err != nil { + t.Fatalf("Marshal(message) error = %v", err) + } + if err := os.WriteFile(base+".jsonl", append(line, '\n'), 0o644); err != nil { + t.Fatalf("WriteFile(jsonl) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(listRec, listReq) + + if listRec.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal(list) error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + if items[0].ID != "missing-meta" { + t.Fatalf("items[0].ID = %q, want %q", items[0].ID, "missing-meta") + } + + detailRec := httptest.NewRecorder() + detailReq := httptest.NewRequest(http.MethodGet, "/api/sessions/missing-meta", nil) + mux.ServeHTTP(detailRec, detailReq) + + if detailRec.Code != http.StatusOK { + t.Fatalf("detail status = %d, want %d, body=%s", detailRec.Code, http.StatusOK, detailRec.Body.String()) + } +} + +func TestHandleSessions_IgnoresMetaJSONInLegacyFallback(t *testing.T) { + configPath, cleanup := setupOAuthTestEnv(t) + defer cleanup() + + dir := sessionsTestDir(t, configPath) + metaOnly := filepath.Join(dir, "agent_main_pico_direct_pico_meta-only.meta.json") + metaOnlyContent := []byte(`{"key":"agent:main:pico:direct:pico:meta-only","summary":"meta only"}`) + if err := os.WriteFile(metaOnly, metaOnlyContent, 0o644); err != nil { + t.Fatalf("WriteFile(meta) error = %v", err) + } + + h := NewHandler(configPath) + mux := http.NewServeMux() + h.RegisterRoutes(mux) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest(http.MethodGet, "/api/sessions", nil) + mux.ServeHTTP(listRec, listReq) + + if listRec.Code != http.StatusOK { + t.Fatalf("list status = %d, want %d, body=%s", listRec.Code, http.StatusOK, listRec.Body.String()) + } + + var items []sessionListItem + if err := json.Unmarshal(listRec.Body.Bytes(), &items); err != nil { + t.Fatalf("Unmarshal(list) error = %v", err) + } + if len(items) != 0 { + t.Fatalf("len(items) = %d, want 0", len(items)) + } +}