mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
1175f4a62b
Benchmark tool comparing legacy session manager vs seahorse short memory retrieval on the LOCOMO long-term conversational memory dataset. - cmd/membench/: CLI with ingest/eval/report/run subcommands - Mode A (legacy): recency-biased budget truncation baseline - Mode B (seahorse): per-keyword trigram FTS5 search + expand - Metrics: Token-Overlap F1 and Recall Hit Rate - `make mem` builds, downloads data, runs benchmark end-to-end
240 lines
6.9 KiB
Go
240 lines
6.9 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"math"
|
|
"testing"
|
|
)
|
|
|
|
func TestSplitEvidenceIDs(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want []string
|
|
}{
|
|
{"D1:3", []string{"D1:3"}},
|
|
{"D8:6; D9:17", []string{"D8:6", "D9:17"}},
|
|
{"D9:1 D4:4 D4:6", []string{"D9:1", "D4:4", "D4:6"}},
|
|
{"D22:1 D22:2 D9:10 D9:11", []string{"D22:1", "D22:2", "D9:10", "D9:11"}},
|
|
{"D21:18 D21:22 D11:15 D11:19", []string{"D21:18", "D21:22", "D11:15", "D11:19"}},
|
|
{"D30:05", []string{"D30:5"}},
|
|
{"D", nil},
|
|
{"D:", nil},
|
|
{"", nil},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
got := SplitEvidenceIDs(tt.input)
|
|
if len(got) != len(tt.want) {
|
|
t.Fatalf("SplitEvidenceIDs(%q) = %v, want %v", tt.input, got, tt.want)
|
|
}
|
|
for i := range got {
|
|
if got[i] != tt.want[i] {
|
|
t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNormalizeDiaID(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
want string
|
|
}{
|
|
{"D1:3", "D1:3"},
|
|
{"D30:05", "D30:5"},
|
|
{"D10:003", "D10:3"},
|
|
{"D1:0", "D1:0"},
|
|
}
|
|
for _, tt := range tests {
|
|
got := NormalizeDiaID(tt.input)
|
|
if got != tt.want {
|
|
t.Errorf("NormalizeDiaID(%q) = %q, want %q", tt.input, got, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTokenOverlapF1(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
prediction string
|
|
reference string
|
|
want float64
|
|
}{
|
|
{"exact match", "hello world", "hello world", 1.0},
|
|
{"no overlap", "foo bar", "baz qux", 0.0},
|
|
{"empty both", "", "", 1.0},
|
|
{"empty prediction", "", "hello", 0.0},
|
|
{"empty reference", "hello", "", 0.0},
|
|
{"partial overlap", "the cat sat on the mat", "the cat on the floor", 8.0 / 11.0},
|
|
{"case insensitive", "Hello World", "hello world", 1.0},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := TokenOverlapF1(tt.prediction, tt.reference)
|
|
if math.Abs(got-tt.want) > 1e-9 {
|
|
t.Errorf("TokenOverlapF1(%q, %q) = %.4f, want %.4f",
|
|
tt.prediction, tt.reference, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBudgetTruncate(t *testing.T) {
|
|
t.Run("within budget returns all", func(t *testing.T) {
|
|
msgs := []string{"short", "message", "here"}
|
|
result, total := BudgetTruncate(msgs, 1000)
|
|
if len(result) != 3 {
|
|
t.Errorf("expected 3 messages, got %d", len(result))
|
|
}
|
|
if total == 0 {
|
|
t.Error("expected non-zero token count")
|
|
}
|
|
})
|
|
|
|
t.Run("over budget keeps best first", func(t *testing.T) {
|
|
msgs := []string{
|
|
"best message that is quite long and takes up tokens",
|
|
"good message also fairly long content",
|
|
"worst short",
|
|
}
|
|
result, _ := BudgetTruncate(msgs, 5) // very small budget
|
|
if len(result) == 0 {
|
|
t.Fatal("expected at least one message")
|
|
}
|
|
// Best-ranked (first) should be kept
|
|
if result[0] != "best message that is quite long and takes up tokens" {
|
|
t.Errorf("expected best message kept first, got %q", result[0])
|
|
}
|
|
})
|
|
|
|
t.Run("over budget keeps best ranked first", func(t *testing.T) {
|
|
// Messages are sorted by bm25 rank ascending (best/most-negative first).
|
|
// When budget is insufficient, BudgetTruncate must keep the front
|
|
// (best-ranked) messages, not the tail (worst-ranked).
|
|
msgs := []string{
|
|
"best ranked message with some content here",
|
|
"second best message also has content",
|
|
"third message here too",
|
|
"worst ranked short",
|
|
}
|
|
// Budget only fits ~1 message (~10 tokens per message, budget=12)
|
|
result, _ := BudgetTruncate(msgs, 12)
|
|
if len(result) == 0 {
|
|
t.Fatal("expected at least one message")
|
|
}
|
|
if result[0] != "best ranked message with some content here" {
|
|
t.Errorf("expected best-ranked (first) message kept, got %q", result[0])
|
|
}
|
|
// Worst-ranked (last) must NOT appear
|
|
for _, m := range result {
|
|
if m == "worst ranked short" {
|
|
t.Error("worst-ranked message should have been truncated")
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("preserves original order", func(t *testing.T) {
|
|
msgs := []string{"alpha", "beta", "gamma"}
|
|
result, _ := BudgetTruncate(msgs, 100)
|
|
for i, got := range result {
|
|
if got != msgs[i] {
|
|
t.Errorf("result[%d] = %q, want %q", i, got, msgs[i])
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("empty input", func(t *testing.T) {
|
|
result, total := BudgetTruncate(nil, 100)
|
|
if len(result) != 0 {
|
|
t.Errorf("expected 0 messages, got %d", len(result))
|
|
}
|
|
if total != 0 {
|
|
t.Errorf("expected 0 tokens, got %d", total)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestRecallHitRate(t *testing.T) {
|
|
// Build a sample with known turns
|
|
sample := &LocomoSample{
|
|
SampleID: "test-sample",
|
|
Conversation: map[string]json.RawMessage{
|
|
"session_1": json.RawMessage(`[
|
|
{"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message with enough length"},
|
|
{"speaker":"B","dia_id":"D1:2","text":"another message for testing recall computation purposes here"},
|
|
{"speaker":"A","dia_id":"D1:3","text":"third turn with some more content to test"}
|
|
]`),
|
|
},
|
|
}
|
|
|
|
t.Run("all evidence found", func(t *testing.T) {
|
|
retrieved := "hello world this is a test message with enough length another message for testing recall computation purposes here"
|
|
got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
|
|
if math.Abs(got-1.0) > 1e-9 {
|
|
t.Errorf("RecallHitRate all found = %.4f, want 1.0", got)
|
|
}
|
|
})
|
|
|
|
t.Run("partial evidence found", func(t *testing.T) {
|
|
retrieved := "hello world this is a test message with enough length"
|
|
got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
|
|
if math.Abs(got-0.5) > 1e-9 {
|
|
t.Errorf("RecallHitRate partial = %.4f, want 0.5", got)
|
|
}
|
|
})
|
|
|
|
t.Run("no evidence required", func(t *testing.T) {
|
|
got := RecallHitRate(nil, sample, "anything")
|
|
if got != 1.0 {
|
|
t.Errorf("RecallHitRate no evidence = %.4f, want 1.0", got)
|
|
}
|
|
})
|
|
|
|
t.Run("missing turn excluded from denominator", func(t *testing.T) {
|
|
// D1:1 is found, D99:1 does not exist in sample
|
|
// Should only count resolvable turns in denominator
|
|
retrieved := "hello world this is a test message with enough length"
|
|
got := RecallHitRate([]string{"D1:1", "D99:1"}, sample, retrieved)
|
|
if math.Abs(got-1.0) > 1e-9 {
|
|
t.Errorf("RecallHitRate missing turn = %.4f, want 1.0 (unresolvable excluded)", got)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestExtractKeywords(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
want []string
|
|
}{
|
|
{"simple", "What is the capital of France", []string{"capital", "france"}},
|
|
{
|
|
"stops removed",
|
|
"Who is the president of the United States",
|
|
[]string{"president", "united", "states"},
|
|
},
|
|
{
|
|
"max 6 keywords",
|
|
"one two three four five six seven eight nine ten",
|
|
[]string{"one", "two", "three", "four", "five", "six"},
|
|
},
|
|
{"short words filtered", "I am a go to the store", []string{"am", "store"}},
|
|
{"empty", "", nil},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := ExtractKeywords(tt.input)
|
|
if len(got) != len(tt.want) {
|
|
t.Fatalf("ExtractKeywords(%q) = %v (len %d), want %v (len %d)",
|
|
tt.input, got, len(got), tt.want, len(tt.want))
|
|
}
|
|
for i := range got {
|
|
if got[i] != tt.want[i] {
|
|
t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|