diff --git a/pkg/agent/dispatch_request.go b/pkg/agent/dispatch_request.go index 40548c41a..cb54264d6 100644 --- a/pkg/agent/dispatch_request.go +++ b/pkg/agent/dispatch_request.go @@ -93,9 +93,7 @@ func normalizeProcessOptions(opts processOptions) processOptions { MessageID: strings.TrimSpace(opts.MessageID), ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID), } - if inbound.Channel != "" && inbound.ChatID != "" { - inbound.ChatType = "direct" - } + inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope) if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" || inbound.MessageID != "" || inbound.ReplyToMessageID != "" { inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context @@ -132,3 +130,18 @@ func normalizeProcessOptions(opts processOptions) processOptions { return opts } + +func inferChatTypeFromSessionScope(scope *session.SessionScope) string { + if scope == nil || len(scope.Values) == 0 { + return "" + } + chatValue := strings.TrimSpace(scope.Values["chat"]) + if chatValue == "" { + return "" + } + chatType, _, ok := strings.Cut(chatValue, ":") + if !ok { + return "" + } + return strings.ToLower(strings.TrimSpace(chatType)) +} diff --git a/pkg/agent/dispatch_request_test.go b/pkg/agent/dispatch_request_test.go index 89fc01a3b..ec5f70339 100644 --- a/pkg/agent/dispatch_request_test.go +++ b/pkg/agent/dispatch_request_test.go @@ -108,3 +108,28 @@ func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) { t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope) } } + +func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) { + opts := normalizeProcessOptions(processOptions{ + Channel: "telegram", + ChatID: "-100123", + SenderID: "user-1", + UserMessage: "hello", + SessionScope: &session.SessionScope{ + Version: session.ScopeVersionV1, + AgentID: "main", + Channel: "telegram", + Dimensions: []string{"chat"}, + Values: map[string]string{ + "chat": "group:-100123", + }, + }, + }) + + if opts.Dispatch.InboundContext == nil { + t.Fatal("Dispatch.InboundContext is nil") + } + if opts.Dispatch.InboundContext.ChatType != "group" { + t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType) + } +} diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index d70c92731..a2e5fec21 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -292,16 +292,18 @@ func (al *AgentLoop) continueWithSteeringMessages( ctx context.Context, agent *AgentInstance, sessionKey, channel, chatID string, + scope *session.SessionScope, steeringMsgs []providers.Message, ) (string, error) { dispatch := DispatchRequest{ - SessionKey: sessionKey, + SessionKey: sessionKey, + SessionScope: session.CloneScope(scope), } if channel != "" || chatID != "" { dispatch.InboundContext = &bus.InboundContext{ Channel: channel, ChatID: chatID, - ChatType: "direct", + ChatType: inferChatTypeFromSessionScope(scope), } } return al.runAgentLoop(ctx, agent, processOptions{ @@ -372,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s } } - return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs) + var scope *session.SessionScope + if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok { + scope = metaStore.GetSessionScope(sessionKey) + } + + return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs) } func (al *AgentLoop) InterruptGraceful(hint string) error { diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go index fc1f8b611..5145d4759 100644 --- a/pkg/bus/bus_test.go +++ b/pkg/bus/bus_test.go @@ -278,6 +278,33 @@ func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) { } } +func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + msg := OutboundMessage{ + Context: InboundContext{ + Channel: "telegram", + ChatID: "chat-42", + ReplyToMessageID: " ", + }, + ReplyToMessageID: "msg-9", + Content: "reply", + } + + if err := mb.PublishOutbound(context.Background(), msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got := <-mb.OutboundChan() + if got.ReplyToMessageID != "msg-9" { + t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID) + } + if got.Context.ReplyToMessageID != "msg-9" { + t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID) + } +} + func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) { mb := NewMessageBus() defer mb.Close() diff --git a/pkg/bus/outbound_context.go b/pkg/bus/outbound_context.go index 4861483a1..cbbbc99c7 100644 --- a/pkg/bus/outbound_context.go +++ b/pkg/bus/outbound_context.go @@ -34,7 +34,12 @@ func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage { if msg.ChatID == "" { msg.ChatID = msg.Context.ChatID } - msg.ReplyToMessageID = msg.Context.ReplyToMessageID + if msg.ReplyToMessageID == "" { + msg.ReplyToMessageID = msg.Context.ReplyToMessageID + } + if msg.Context.ReplyToMessageID == "" { + msg.Context.ReplyToMessageID = msg.ReplyToMessageID + } msg.Scope = cloneOutboundScope(msg.Scope) return msg } diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go index a1b794b97..8d3320f3f 100644 --- a/pkg/memory/jsonl.go +++ b/pkg/memory/jsonl.go @@ -374,6 +374,11 @@ func (s *JSONLStore) promoteAliasHistoryLocked( return false, nil } + previousJSONL, hadPreviousJSONL, err := s.readRawJSONL(sessionKey) + if err != nil { + return false, err + } + now := time.Now() if canonicalMeta.CreatedAt.IsZero() { canonicalMeta.CreatedAt = now @@ -387,10 +392,13 @@ func (s *JSONLStore) promoteAliasHistoryLocked( canonicalMeta.Summary = aliasSummary } - if err := s.writeMeta(sessionKey, canonicalMeta); err != nil { + if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil { return false, err } - if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil { + if err := s.writeMeta(sessionKey, canonicalMeta); err != nil { + if rollbackErr := s.restoreRawJSONL(sessionKey, previousJSONL, hadPreviousJSONL); rollbackErr != nil { + return false, fmt.Errorf("memory: write promoted meta: %w (rollback jsonl: %v)", err, rollbackErr) + } return false, err } return true, nil @@ -410,6 +418,31 @@ func (s *JSONLStore) sessionHasVisibleContentLocked(sessionKey string, meta Sess return len(history) > 0, nil } +func (s *JSONLStore) readRawJSONL(sessionKey string) ([]byte, bool, error) { + data, err := os.ReadFile(s.jsonlPath(sessionKey)) + if os.IsNotExist(err) { + return nil, false, nil + } + if err != nil { + return nil, false, fmt.Errorf("memory: read jsonl: %w", err) + } + return data, true, nil +} + +func (s *JSONLStore) restoreRawJSONL(sessionKey string, data []byte, existed bool) error { + path := s.jsonlPath(sessionKey) + if !existed { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("memory: remove jsonl rollback: %w", err) + } + return nil + } + if err := fileutil.WriteFileAtomic(path, data, 0o644); err != nil { + return fmt.Errorf("memory: restore jsonl rollback: %w", err) + } + return nil +} + // readMessages reads valid JSON lines from a .jsonl file, skipping // the first `skip` lines without unmarshaling them. This avoids the // cost of json.Unmarshal on logically truncated messages. diff --git a/pkg/session/jsonl_backend.go b/pkg/session/jsonl_backend.go index 2c4eb4e5a..68ef2d753 100644 --- a/pkg/session/jsonl_backend.go +++ b/pkg/session/jsonl_backend.go @@ -92,61 +92,6 @@ func (b *JSONLBackend) EnsureSessionMetadata(sessionKey string, scope *SessionSc if _, err := promotingStore.PromoteAliasHistory(ctx, sessionKey, rawScope, aliases); err != nil { log.Printf("session: promote alias history: %v", err) } - return - } - - canonicalMeta, metaErr := metaStore.GetSessionMeta(ctx, sessionKey) - if metaErr != nil { - log.Printf("session: get canonical session metadata: %v", metaErr) - } else if canonicalMeta.Count > 0 || strings.TrimSpace(canonicalMeta.Summary) != "" { - return - } - - canonicalHistory, historyErr := b.store.GetHistory(ctx, sessionKey) - if historyErr != nil { - log.Printf("session: get canonical history: %v", historyErr) - return - } - canonicalSummary, summaryErr := b.store.GetSummary(ctx, sessionKey) - if summaryErr != nil { - log.Printf("session: get canonical summary: %v", summaryErr) - return - } - if len(canonicalHistory) > 0 || strings.TrimSpace(canonicalSummary) != "" { - return - } - - for _, alias := range aliases { - alias = strings.TrimSpace(alias) - if alias == "" || alias == sessionKey { - continue - } - aliasHistory, err := b.store.GetHistory(ctx, alias) - if err != nil { - log.Printf("session: get alias history: %v", err) - continue - } - aliasSummary, err := b.store.GetSummary(ctx, alias) - if err != nil { - log.Printf("session: get alias summary: %v", err) - continue - } - if len(aliasHistory) == 0 && strings.TrimSpace(aliasSummary) == "" { - continue - } - if err := b.store.SetHistory(ctx, sessionKey, aliasHistory); err != nil { - log.Printf("session: promote alias history: %v", err) - return - } - if strings.TrimSpace(aliasSummary) != "" { - if err := b.store.SetSummary(ctx, sessionKey, aliasSummary); err != nil { - log.Printf("session: promote alias summary: %v", err) - } - } - if err := metaStore.UpsertSessionMeta(ctx, sessionKey, rawScope, aliases); err != nil { - log.Printf("session: refresh session metadata after promotion: %v", err) - } - return } }