fix(session): address review regressions

This commit is contained in:
Hoshina
2026-04-13 22:51:44 +08:00
parent 0c6ad33a9c
commit c5c5ea22d6
7 changed files with 119 additions and 64 deletions
+16 -3
View File
@@ -93,9 +93,7 @@ func normalizeProcessOptions(opts processOptions) processOptions {
MessageID: strings.TrimSpace(opts.MessageID),
ReplyToMessageID: strings.TrimSpace(opts.ReplyToMessageID),
}
if inbound.Channel != "" && inbound.ChatID != "" {
inbound.ChatType = "direct"
}
inbound.ChatType = inferChatTypeFromSessionScope(opts.Dispatch.SessionScope)
if inbound.Channel != "" || inbound.ChatID != "" || inbound.SenderID != "" ||
inbound.MessageID != "" || inbound.ReplyToMessageID != "" {
inbound = bus.NormalizeInboundMessage(bus.InboundMessage{Context: inbound}).Context
@@ -132,3 +130,18 @@ func normalizeProcessOptions(opts processOptions) processOptions {
return opts
}
func inferChatTypeFromSessionScope(scope *session.SessionScope) string {
if scope == nil || len(scope.Values) == 0 {
return ""
}
chatValue := strings.TrimSpace(scope.Values["chat"])
if chatValue == "" {
return ""
}
chatType, _, ok := strings.Cut(chatValue, ":")
if !ok {
return ""
}
return strings.ToLower(strings.TrimSpace(chatType))
}
+25
View File
@@ -108,3 +108,28 @@ func TestNormalizeProcessOptions_UsesDispatchAsSourceOfTruth(t *testing.T) {
t.Fatalf("SessionScope = %#v, want support scope", opts.SessionScope)
}
}
func TestNormalizeProcessOptions_InfersLegacyChatTypeFromSessionScope(t *testing.T) {
opts := normalizeProcessOptions(processOptions{
Channel: "telegram",
ChatID: "-100123",
SenderID: "user-1",
UserMessage: "hello",
SessionScope: &session.SessionScope{
Version: session.ScopeVersionV1,
AgentID: "main",
Channel: "telegram",
Dimensions: []string{"chat"},
Values: map[string]string{
"chat": "group:-100123",
},
},
})
if opts.Dispatch.InboundContext == nil {
t.Fatal("Dispatch.InboundContext is nil")
}
if opts.Dispatch.InboundContext.ChatType != "group" {
t.Fatalf("Dispatch.InboundContext.ChatType = %q, want group", opts.Dispatch.InboundContext.ChatType)
}
}
+10 -3
View File
@@ -292,16 +292,18 @@ func (al *AgentLoop) continueWithSteeringMessages(
ctx context.Context,
agent *AgentInstance,
sessionKey, channel, chatID string,
scope *session.SessionScope,
steeringMsgs []providers.Message,
) (string, error) {
dispatch := DispatchRequest{
SessionKey: sessionKey,
SessionKey: sessionKey,
SessionScope: session.CloneScope(scope),
}
if channel != "" || chatID != "" {
dispatch.InboundContext = &bus.InboundContext{
Channel: channel,
ChatID: chatID,
ChatType: "direct",
ChatType: inferChatTypeFromSessionScope(scope),
}
}
return al.runAgentLoop(ctx, agent, processOptions{
@@ -372,7 +374,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
}
}
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
var scope *session.SessionScope
if metaStore, ok := agent.Sessions.(session.MetadataAwareSessionStore); ok {
scope = metaStore.GetSessionScope(sessionKey)
}
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, scope, steeringMsgs)
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
+27
View File
@@ -278,6 +278,33 @@ func TestPublishOutbound_PreservesExplicitReplyToMessageID(t *testing.T) {
}
}
func TestPublishOutbound_PreservesExplicitReplyToMessageIDWhenContextReplyIsBlank(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
msg := OutboundMessage{
Context: InboundContext{
Channel: "telegram",
ChatID: "chat-42",
ReplyToMessageID: " ",
},
ReplyToMessageID: "msg-9",
Content: "reply",
}
if err := mb.PublishOutbound(context.Background(), msg); err != nil {
t.Fatalf("PublishOutbound failed: %v", err)
}
got := <-mb.OutboundChan()
if got.ReplyToMessageID != "msg-9" {
t.Fatalf("expected mirrored reply_to_message_id msg-9, got %q", got.ReplyToMessageID)
}
if got.Context.ReplyToMessageID != "msg-9" {
t.Fatalf("expected context reply_to_message_id msg-9, got %q", got.Context.ReplyToMessageID)
}
}
func TestPublishOutboundMedia_MirrorsContextToLegacyFields(t *testing.T) {
mb := NewMessageBus()
defer mb.Close()
+6 -1
View File
@@ -34,7 +34,12 @@ func NormalizeOutboundMessage(msg OutboundMessage) OutboundMessage {
if msg.ChatID == "" {
msg.ChatID = msg.Context.ChatID
}
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
if msg.ReplyToMessageID == "" {
msg.ReplyToMessageID = msg.Context.ReplyToMessageID
}
if msg.Context.ReplyToMessageID == "" {
msg.Context.ReplyToMessageID = msg.ReplyToMessageID
}
msg.Scope = cloneOutboundScope(msg.Scope)
return msg
}
+35 -2
View File
@@ -374,6 +374,11 @@ func (s *JSONLStore) promoteAliasHistoryLocked(
return false, nil
}
previousJSONL, hadPreviousJSONL, err := s.readRawJSONL(sessionKey)
if err != nil {
return false, err
}
now := time.Now()
if canonicalMeta.CreatedAt.IsZero() {
canonicalMeta.CreatedAt = now
@@ -387,10 +392,13 @@ func (s *JSONLStore) promoteAliasHistoryLocked(
canonicalMeta.Summary = aliasSummary
}
if err := s.writeMeta(sessionKey, canonicalMeta); err != nil {
if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil {
return false, err
}
if err := s.rewriteJSONL(sessionKey, aliasHistory); err != nil {
if err := s.writeMeta(sessionKey, canonicalMeta); err != nil {
if rollbackErr := s.restoreRawJSONL(sessionKey, previousJSONL, hadPreviousJSONL); rollbackErr != nil {
return false, fmt.Errorf("memory: write promoted meta: %w (rollback jsonl: %v)", err, rollbackErr)
}
return false, err
}
return true, nil
@@ -410,6 +418,31 @@ func (s *JSONLStore) sessionHasVisibleContentLocked(sessionKey string, meta Sess
return len(history) > 0, nil
}
func (s *JSONLStore) readRawJSONL(sessionKey string) ([]byte, bool, error) {
data, err := os.ReadFile(s.jsonlPath(sessionKey))
if os.IsNotExist(err) {
return nil, false, nil
}
if err != nil {
return nil, false, fmt.Errorf("memory: read jsonl: %w", err)
}
return data, true, nil
}
func (s *JSONLStore) restoreRawJSONL(sessionKey string, data []byte, existed bool) error {
path := s.jsonlPath(sessionKey)
if !existed {
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("memory: remove jsonl rollback: %w", err)
}
return nil
}
if err := fileutil.WriteFileAtomic(path, data, 0o644); err != nil {
return fmt.Errorf("memory: restore jsonl rollback: %w", err)
}
return nil
}
// readMessages reads valid JSON lines from a .jsonl file, skipping
// the first `skip` lines without unmarshaling them. This avoids the
// cost of json.Unmarshal on logically truncated messages.
-55
View File
@@ -92,61 +92,6 @@ func (b *JSONLBackend) EnsureSessionMetadata(sessionKey string, scope *SessionSc
if _, err := promotingStore.PromoteAliasHistory(ctx, sessionKey, rawScope, aliases); err != nil {
log.Printf("session: promote alias history: %v", err)
}
return
}
canonicalMeta, metaErr := metaStore.GetSessionMeta(ctx, sessionKey)
if metaErr != nil {
log.Printf("session: get canonical session metadata: %v", metaErr)
} else if canonicalMeta.Count > 0 || strings.TrimSpace(canonicalMeta.Summary) != "" {
return
}
canonicalHistory, historyErr := b.store.GetHistory(ctx, sessionKey)
if historyErr != nil {
log.Printf("session: get canonical history: %v", historyErr)
return
}
canonicalSummary, summaryErr := b.store.GetSummary(ctx, sessionKey)
if summaryErr != nil {
log.Printf("session: get canonical summary: %v", summaryErr)
return
}
if len(canonicalHistory) > 0 || strings.TrimSpace(canonicalSummary) != "" {
return
}
for _, alias := range aliases {
alias = strings.TrimSpace(alias)
if alias == "" || alias == sessionKey {
continue
}
aliasHistory, err := b.store.GetHistory(ctx, alias)
if err != nil {
log.Printf("session: get alias history: %v", err)
continue
}
aliasSummary, err := b.store.GetSummary(ctx, alias)
if err != nil {
log.Printf("session: get alias summary: %v", err)
continue
}
if len(aliasHistory) == 0 && strings.TrimSpace(aliasSummary) == "" {
continue
}
if err := b.store.SetHistory(ctx, sessionKey, aliasHistory); err != nil {
log.Printf("session: promote alias history: %v", err)
return
}
if strings.TrimSpace(aliasSummary) != "" {
if err := b.store.SetSummary(ctx, sessionKey, aliasSummary); err != nil {
log.Printf("session: promote alias summary: %v", err)
}
}
if err := metaStore.UpsertSessionMeta(ctx, sessionKey, rawScope, aliases); err != nil {
log.Printf("session: refresh session metadata after promotion: %v", err)
}
return
}
}