refactor(memory): use sync.Map for session locks and skip-scan in readMessages

Address review feedback from @Zhaoyikaiii:

- Replace map[string]*sync.Mutex + separate mu with sync.Map.LoadOrStore
  for simpler, lock-free session lock management.

- Add skip parameter to readMessages so callers (GetHistory, Compact)
  can skip truncated lines without paying the json.Unmarshal cost.

- Add countLines helper for TruncateHistory's count reconciliation,
  avoiding full deserialization when only the line count is needed.
This commit is contained in:
xiaoen
2026-02-26 14:31:02 +08:00
parent b464687e2f
commit 5d73ee2d9a
2 changed files with 50 additions and 40 deletions
+48 -38
View File
@@ -36,10 +36,8 @@ type sessionMeta struct {
// GetHistory ignores lines before that offset. This keeps all writes
// append-only, which is both fast and crash-safe.
type JSONLStore struct {
dir string
mu sync.Mutex
locks map[string]*sync.Mutex
dir string
locks sync.Map // map[string]*sync.Mutex, one per session
}
// NewJSONLStore creates a new JSONL-backed store rooted at dir.
@@ -48,23 +46,13 @@ func NewJSONLStore(dir string) (*JSONLStore, error) {
if err != nil {
return nil, fmt.Errorf("memory: create directory: %w", err)
}
return &JSONLStore{
dir: dir,
locks: make(map[string]*sync.Mutex),
}, nil
return &JSONLStore{dir: dir}, nil
}
// sessionLock returns (or creates) a per-session mutex.
func (s *JSONLStore) sessionLock(key string) *sync.Mutex {
s.mu.Lock()
defer s.mu.Unlock()
l, ok := s.locks[key]
if !ok {
l = &sync.Mutex{}
s.locks[key] = l
}
return l
v, _ := s.locks.LoadOrStore(key, &sync.Mutex{})
return v.(*sync.Mutex)
}
func (s *JSONLStore) jsonlPath(key string) string {
@@ -122,9 +110,11 @@ func (s *JSONLStore) writeMeta(key string, meta sessionMeta) error {
return nil
}
// readMessages reads all valid JSON lines from a .jsonl file.
// readMessages reads valid JSON lines from a .jsonl file, skipping
// the first `skip` lines without unmarshaling them. This avoids the
// cost of json.Unmarshal on logically truncated messages.
// Malformed trailing lines (e.g. from a crash) are silently skipped.
func readMessages(path string) ([]providers.Message, error) {
func readMessages(path string, skip int) ([]providers.Message, error) {
f, err := os.Open(path)
if os.IsNotExist(err) {
return []providers.Message{}, nil
@@ -139,11 +129,16 @@ func readMessages(path string) ([]providers.Message, error) {
// Allow up to 1 MB per line for messages with large content.
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
lineNum := 0
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
lineNum++
if lineNum <= skip {
continue
}
var msg providers.Message
if json.Unmarshal(line, &msg) != nil {
// Corrupt line — likely a partial write from a crash.
@@ -162,6 +157,30 @@ func readMessages(path string) ([]providers.Message, error) {
return msgs, nil
}
// countLines counts the total number of non-empty lines in a .jsonl file.
// Used by TruncateHistory to reconcile a stale meta.Count without
// the overhead of unmarshaling every message.
func countLines(path string) (int, error) {
f, err := os.Open(path)
if os.IsNotExist(err) {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("memory: open jsonl: %w", err)
}
defer f.Close()
n := 0
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
if len(scanner.Bytes()) > 0 {
n++
}
}
return n, scanner.Err()
}
func (s *JSONLStore) AddMessage(
_ context.Context, sessionKey, role, content string,
) error {
@@ -234,18 +253,13 @@ func (s *JSONLStore) GetHistory(
return nil, err
}
msgs, err := readMessages(s.jsonlPath(sessionKey))
// Pass meta.Skip so readMessages skips those lines without
// unmarshaling them — avoids wasted CPU on truncated messages.
msgs, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return nil, err
}
// Apply logical truncation: skip the first meta.Skip messages.
if meta.Skip > 0 && meta.Skip < len(msgs) {
msgs = msgs[meta.Skip:]
} else if meta.Skip >= len(msgs) {
msgs = []providers.Message{}
}
return msgs, nil
}
@@ -299,11 +313,11 @@ func (s *JSONLStore) TruncateHistory(
// If the meta count might be stale (e.g. after a crash during
// addMsg), reconcile with the actual line count on disk.
if meta.Count == 0 {
msgs, readErr := readMessages(s.jsonlPath(sessionKey))
if readErr != nil {
return readErr
n, countErr := countLines(s.jsonlPath(sessionKey))
if countErr != nil {
return countErr
}
meta.Count = len(msgs)
meta.Count = n
}
if keepLast <= 0 {
@@ -369,17 +383,13 @@ func (s *JSONLStore) Compact(
return nil
}
all, err := readMessages(s.jsonlPath(sessionKey))
// Read only the active messages, skipping truncated lines
// without unmarshaling them.
active, err := readMessages(s.jsonlPath(sessionKey), meta.Skip)
if err != nil {
return err
}
// Keep only the active (non-skipped) messages.
var active []providers.Message
if meta.Skip < len(all) {
active = all[meta.Skip:]
}
err = s.rewriteJSONL(sessionKey, active)
if err != nil {
return err
+2 -2
View File
@@ -440,7 +440,7 @@ func TestCompact_RemovesSkippedMessages(t *testing.T) {
}
// Before compact: file still has 10 lines.
allOnDisk, err := readMessages(store.jsonlPath("compact"))
allOnDisk, err := readMessages(store.jsonlPath("compact"), 0)
if err != nil {
t.Fatalf("readMessages: %v", err)
}
@@ -455,7 +455,7 @@ func TestCompact_RemovesSkippedMessages(t *testing.T) {
}
// After compact: file should have only 3 lines.
allOnDisk, err = readMessages(store.jsonlPath("compact"))
allOnDisk, err = readMessages(store.jsonlPath("compact"), 0)
if err != nil {
t.Fatalf("readMessages: %v", err)
}