Files
picoclaw/cmd/membench/metrics_test.go
T
Liu Yuan 1175f4a62b feat(membench): add LOCOMO memory benchmark tool (#2353)
Benchmark tool comparing legacy session manager vs seahorse short memory
retrieval on the LOCOMO long-term conversational memory dataset.

- cmd/membench/: CLI with ingest/eval/report/run subcommands
- Mode A (legacy): recency-biased budget truncation baseline
- Mode B (seahorse): per-keyword trigram FTS5 search + expand
- Metrics: Token-Overlap F1 and Recall Hit Rate
- `make mem` builds, downloads data, runs benchmark end-to-end
2026-04-06 17:26:43 +08:00

240 lines
6.9 KiB
Go

package main
import (
"encoding/json"
"math"
"testing"
)
// TestSplitEvidenceIDs checks that SplitEvidenceIDs tokenizes evidence
// strings whose IDs are separated by semicolons and/or spaces, normalizes
// zero-padded turn numbers, and yields no IDs for malformed or empty input.
func TestSplitEvidenceIDs(t *testing.T) {
	cases := []struct {
		input string
		want  []string
	}{
		{"D1:3", []string{"D1:3"}},
		{"D8:6; D9:17", []string{"D8:6", "D9:17"}},
		{"D9:1 D4:4 D4:6", []string{"D9:1", "D4:4", "D4:6"}},
		{"D22:1 D22:2 D9:10 D9:11", []string{"D22:1", "D22:2", "D9:10", "D9:11"}},
		{"D21:18 D21:22 D11:15 D11:19", []string{"D21:18", "D21:22", "D11:15", "D11:19"}},
		{"D30:05", []string{"D30:5"}},
		{"D", nil},
		{"D:", nil},
		{"", nil},
	}

	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			ids := SplitEvidenceIDs(tc.input)
			// A length mismatch makes element-wise comparison meaningless.
			if len(ids) != len(tc.want) {
				t.Fatalf("SplitEvidenceIDs(%q) = %v, want %v", tc.input, ids, tc.want)
			}
			for idx, id := range ids {
				if want := tc.want[idx]; id != want {
					t.Errorf("[%d] = %q, want %q", idx, id, want)
				}
			}
		})
	}
}
// TestNormalizeDiaID checks that NormalizeDiaID strips leading zeros from
// the turn number of a dialogue ID. Already-normalized IDs, including a
// turn number of zero, must come back unchanged.
//
// Each case runs as its own subtest, like the other table-driven tests in
// this file. A failure is then reported per input and can be selected
// with -run.
func TestNormalizeDiaID(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{"D1:3", "D1:3"},
		{"D30:05", "D30:5"},
		{"D10:003", "D10:3"},
		{"D1:0", "D1:0"},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			if got := NormalizeDiaID(tt.input); got != tt.want {
				t.Errorf("NormalizeDiaID(%q) = %q, want %q", tt.input, got, tt.want)
			}
		})
	}
}
// TestTokenOverlapF1 checks the token-overlap F1 score between a prediction
// and a reference:
//   - identical strings score 1 and disjoint strings score 0;
//   - two empty strings score 1, while a single empty side scores 0;
//   - comparison ignores case.
//
// In the partial-overlap case the prediction has 6 tokens, the reference
// has 5, and 4 are shared when repeated tokens ("the" twice) are counted
// with multiplicity. That gives P=4/6 and R=4/5, so F1 = 8/11.
func TestTokenOverlapF1(t *testing.T) {
	const tolerance = 1e-9

	cases := []struct {
		name       string
		prediction string
		reference  string
		want       float64
	}{
		{"exact match", "hello world", "hello world", 1.0},
		{"no overlap", "foo bar", "baz qux", 0.0},
		{"empty both", "", "", 1.0},
		{"empty prediction", "", "hello", 0.0},
		{"empty reference", "hello", "", 0.0},
		{"partial overlap", "the cat sat on the mat", "the cat on the floor", 8.0 / 11.0},
		{"case insensitive", "Hello World", "hello world", 1.0},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			score := TokenOverlapF1(tc.prediction, tc.reference)
			if diff := math.Abs(score - tc.want); diff > tolerance {
				t.Errorf("TokenOverlapF1(%q, %q) = %.4f, want %.4f",
					tc.prediction, tc.reference, score, tc.want)
			}
		})
	}
}
// TestBudgetTruncate exercises BudgetTruncate. It receives messages ordered
// best-ranked first plus a token budget, and returns the kept messages and
// a token total.
//
// The subtests cover:
//   - everything fitting within a generous budget;
//   - keeping the front (best-ranked) messages when the budget is tight;
//   - preserving input order;
//   - empty input yielding no messages and zero tokens.
func TestBudgetTruncate(t *testing.T) {
	t.Run("within budget returns all", func(t *testing.T) {
		msgs := []string{"short", "message", "here"}
		result, total := BudgetTruncate(msgs, 1000)
		if len(result) != 3 {
			t.Errorf("expected 3 messages, got %d", len(result))
		}
		if total == 0 {
			t.Error("expected non-zero token count")
		}
	})

	t.Run("over budget keeps best first", func(t *testing.T) {
		msgs := []string{
			"best message that is quite long and takes up tokens",
			"good message also fairly long content",
			"worst short",
		}
		// NOTE(review): the budget is very small compared with the first
		// message. This asserts that the best-ranked message still comes
		// back, presumably even if it alone exceeds the budget. Confirm
		// against BudgetTruncate's contract.
		result, _ := BudgetTruncate(msgs, 5)
		if len(result) == 0 {
			t.Fatal("expected at least one message")
		}
		// Best-ranked (first) should be kept.
		if result[0] != "best message that is quite long and takes up tokens" {
			t.Errorf("expected best message kept first, got %q", result[0])
		}
	})

	t.Run("over budget keeps best ranked first", func(t *testing.T) {
		// Messages are sorted by bm25 rank ascending (best/most-negative first).
		// When the budget is insufficient, BudgetTruncate must keep the front
		// (best-ranked) messages, not the tail (worst-ranked).
		msgs := []string{
			"best ranked message with some content here",
			"second best message also has content",
			"third message here too",
			"worst ranked short",
		}
		// The budget is intended to fit only about one message, assuming
		// roughly 10 estimated tokens per message.
		result, _ := BudgetTruncate(msgs, 12)
		if len(result) == 0 {
			t.Fatal("expected at least one message")
		}
		if result[0] != "best ranked message with some content here" {
			t.Errorf("expected best-ranked (first) message kept, got %q", result[0])
		}
		// Worst-ranked (last) must NOT appear.
		for _, m := range result {
			if m == "worst ranked short" {
				t.Error("worst-ranked message should have been truncated")
			}
		}
	})

	t.Run("preserves original order", func(t *testing.T) {
		msgs := []string{"alpha", "beta", "gamma"}
		result, _ := BudgetTruncate(msgs, 100)
		// Check the length first. Without this, an empty or short result
		// would pass the element loop vacuously, and a longer result would
		// index past the end of msgs.
		if len(result) != len(msgs) {
			t.Fatalf("expected %d messages, got %d", len(msgs), len(result))
		}
		for i, got := range result {
			if got != msgs[i] {
				t.Errorf("result[%d] = %q, want %q", i, got, msgs[i])
			}
		}
	})

	t.Run("empty input", func(t *testing.T) {
		result, total := BudgetTruncate(nil, 100)
		if len(result) != 0 {
			t.Errorf("expected 0 messages, got %d", len(result))
		}
		if total != 0 {
			t.Errorf("expected 0 tokens, got %d", total)
		}
	})
}
// TestRecallHitRate checks RecallHitRate: the fraction of evidence turns
// judged present in the retrieved text. The expected behavior is:
//   - evidence IDs that do not resolve to a turn in the sample are excluded
//     from the denominator;
//   - an empty evidence list scores 1.0;
//   - retrieving nothing scores 0.0.
//
// All comparisons share one tolerance so the float checks are consistent.
func TestRecallHitRate(t *testing.T) {
	// near reports whether two rates are equal within floating-point noise.
	near := func(got, want float64) bool {
		return math.Abs(got-want) <= 1e-9
	}

	// Build a sample with known turns.
	sample := &LocomoSample{
		SampleID: "test-sample",
		Conversation: map[string]json.RawMessage{
			"session_1": json.RawMessage(`[
{"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message with enough length"},
{"speaker":"B","dia_id":"D1:2","text":"another message for testing recall computation purposes here"},
{"speaker":"A","dia_id":"D1:3","text":"third turn with some more content to test"}
]`),
		},
	}

	t.Run("all evidence found", func(t *testing.T) {
		retrieved := "hello world this is a test message with enough length another message for testing recall computation purposes here"
		got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
		if !near(got, 1.0) {
			t.Errorf("RecallHitRate all found = %.4f, want 1.0", got)
		}
	})

	t.Run("partial evidence found", func(t *testing.T) {
		retrieved := "hello world this is a test message with enough length"
		got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
		if !near(got, 0.5) {
			t.Errorf("RecallHitRate partial = %.4f, want 0.5", got)
		}
	})

	t.Run("no evidence found", func(t *testing.T) {
		// Nothing was retrieved, so no resolvable evidence turn can be a hit.
		// This guards against an implementation that always returns 1.0.
		got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, "")
		if !near(got, 0.0) {
			t.Errorf("RecallHitRate none found = %.4f, want 0.0", got)
		}
	})

	t.Run("no evidence required", func(t *testing.T) {
		got := RecallHitRate(nil, sample, "anything")
		if !near(got, 1.0) {
			t.Errorf("RecallHitRate no evidence = %.4f, want 1.0", got)
		}
	})

	t.Run("missing turn excluded from denominator", func(t *testing.T) {
		// D1:1 is found, and D99:1 does not exist in the sample.
		// Only resolvable turns should count in the denominator.
		retrieved := "hello world this is a test message with enough length"
		got := RecallHitRate([]string{"D1:1", "D99:1"}, sample, retrieved)
		if !near(got, 1.0) {
			t.Errorf("RecallHitRate missing turn = %.4f, want 1.0 (unresolvable excluded)", got)
		}
	})
}
// TestExtractKeywords checks that ExtractKeywords:
//   - lowercases the query;
//   - drops stop words and very short tokens;
//   - returns at most six keywords, in the order they appear.
//
// An empty query yields no keywords.
func TestExtractKeywords(t *testing.T) {
	cases := []struct {
		name  string
		input string
		want  []string
	}{
		{"simple", "What is the capital of France", []string{"capital", "france"}},
		{
			"stops removed",
			"Who is the president of the United States",
			[]string{"president", "united", "states"},
		},
		{
			"max 6 keywords",
			"one two three four five six seven eight nine ten",
			[]string{"one", "two", "three", "four", "five", "six"},
		},
		{"short words filtered", "I am a go to the store", []string{"am", "store"}},
		{"empty", "", nil},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			keywords := ExtractKeywords(tc.input)
			if n, wantN := len(keywords), len(tc.want); n != wantN {
				t.Fatalf("ExtractKeywords(%q) = %v (len %d), want %v (len %d)",
					tc.input, keywords, n, tc.want, wantN)
			}
			for i, kw := range keywords {
				if kw != tc.want[i] {
					t.Errorf("[%d] = %q, want %q", i, kw, tc.want[i])
				}
			}
		})
	}
}