Files
picoclaw/pkg/utils/bm25_test.go
T
Mauro b89f6445d1 feat(mcp): tool search tools (#1243)
* feat(mcp): tool search tools

* removed unused call_discovered_tool

* improvements and optimizations

* fix gate mcp enabled

* fix TOCTOU race BM25 cache version check

* fix encapsulation bypass on registry internals

* safety comment on TickTTL

* added more unit tests

* enhanced logs
2026-03-09 18:21:49 +01:00

176 lines
4.6 KiB
Go

package utils
import (
"reflect"
"testing"
)
// testDoc is the document type indexed by the BM25 tests in this file.
//
// The field order (ID, then Text) matters: test fixtures may construct
// values with positional composite literals such as testDoc{1, "hello"}.
type testDoc struct {
	// ID identifies the document so tests can assert which one ranked where.
	ID int
	// Text is the document content returned by extractText.
	Text string
}
// extractText returns the Text field of d unchanged. The tests in this file
// pass it as the second argument to NewBM25Engine.
func extractText(d testDoc) string {
	return d.Text
}
// TestBM25Search_EdgeCases checks that Search yields an empty, non-nil slice
// for degenerate inputs: non-positive topK, an empty query, a query made only
// of punctuation, and a query term that appears in no document.
func TestBM25Search_EdgeCases(t *testing.T) {
	docs := []testDoc{
		{ID: 1, Text: "hello world"},
		{ID: 2, Text: "foo bar"},
	}
	engine := NewBM25Engine(docs, extractText)

	cases := []struct {
		name  string
		query string
		topK  int
	}{
		{name: "Zero topK", query: "hello", topK: 0},
		{name: "Negative topK", query: "hello", topK: -1},
		{name: "Empty query", query: "", topK: 5},
		{name: "Query with only punctuation", query: "...,,,!!!", topK: 5},
		{name: "No matches found", query: "golang", topK: 5},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			hits := engine.Search(tc.query, tc.topK)
			if n := len(hits); n != 0 {
				t.Errorf("expected 0 results, got %d", n)
			}
			// "No results" must be an empty slice, never nil.
			if hits == nil {
				t.Errorf("expected empty slice, got nil")
			}
		})
	}
}
// TestBM25Search_EmptyCorpus checks that searching an engine built from an
// empty corpus returns an empty, non-nil slice.
func TestBM25Search_EmptyCorpus(t *testing.T) {
	engine := NewBM25Engine([]testDoc{}, extractText)
	hits := engine.Search("hello", 5)

	// Accept only a non-nil slice of length zero.
	if hits != nil && len(hits) == 0 {
		return
	}
	t.Errorf("expected empty slice from empty corpus, got %v", hits)
}
// TestBM25Search_RankingLogic exercises three ranking properties on a small
// corpus: a higher term frequency ranks first, a shorter document wins when
// term frequency is equal, and topK caps the number of results.
func TestBM25Search_RankingLogic(t *testing.T) {
	docs := []testDoc{
		{ID: 1, Text: "the quick brown fox jumps over the lazy dog"},
		{ID: 2, Text: "quick fox"},
		{ID: 3, Text: "quick quick quick fox"}, // "quick" appears three times
		{ID: 4, Text: "completely irrelevant document here"},
	}
	engine := NewBM25Engine(docs, extractText)

	t.Run("Term Frequency (TF) boosts score", func(t *testing.T) {
		hits := engine.Search("quick", 5)
		if n := len(hits); n < 3 {
			t.Fatalf("expected at least 3 results, got %d", n)
		}
		// Doc 3 repeats "quick" three times, so it is expected to outrank doc 2.
		if got := hits[0].Document.ID; got != 3 {
			t.Errorf("expected doc 3 to rank first due to high TF, got doc %d", got)
		}
	})

	t.Run("Document Length penalty", func(t *testing.T) {
		hits := engine.Search("fox", 5)
		if n := len(hits); n < 3 {
			t.Fatalf("expected at least 3 results, got %d", n)
		}
		// Every matching doc contains "fox" exactly once; doc 2 is the
		// shortest of them, so it is expected to come out on top.
		if got := hits[0].Document.ID; got != 2 {
			t.Errorf("expected doc 2 to rank first due to shorter length, got doc %d", got)
		}
	})

	t.Run("TopK limits results", func(t *testing.T) {
		// Three documents contain "quick"; only two may be returned.
		if n := len(engine.Search("quick", 2)); n != 2 {
			t.Errorf("expected exactly 2 results, got %d", n)
		}
	})
}
// TestBM25Tokenize is a table-driven check of bm25Tokenize covering
// lowercasing, whitespace handling, punctuation stripping, and hyphens.
// An empty result is accepted whether it is nil or a zero-length slice.
func TestBM25Tokenize(t *testing.T) {
	cases := []struct {
		input string
		want  []string
	}{
		{input: "Hello World", want: []string{"hello", "world"}},
		{input: " spaces everywhere ", want: []string{"spaces", "everywhere"}},
		{input: "punctuation... test!!!", want: []string{"punctuation", "test"}},
		// Surrounding parentheses are stripped; the inner hyphen stays.
		{input: "(parentheses) and-hyphens", want: []string{"parentheses", "and-hyphens"}},
		{input: "internal-hyphen is kept", want: []string{"internal-hyphen", "is", "kept"}},
		// Nothing is left once the punctuation is removed.
		{input: ".,;?!", want: []string{}},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.input, func(t *testing.T) {
			got := bm25Tokenize(tc.input)

			// DeepEqual distinguishes nil from an empty slice, so treat
			// "both empty" as a match before falling back to it.
			bothEmpty := len(got) == 0 && len(tc.want) == 0
			if bothEmpty || reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("bm25Tokenize(%q) = %v, want %v", tc.input, got, tc.want)
		})
	}
}
// TestBM25Dedupe checks that bm25Dedupe drops repeated strings while keeping
// the remaining ones in order of first appearance.
func TestBM25Dedupe(t *testing.T) {
	var (
		input = []string{"apple", "banana", "apple", "orange", "banana"}
		want  = []string{"apple", "banana", "orange"}
	)

	if got := bm25Dedupe(input); !reflect.DeepEqual(got, want) {
		t.Errorf("bm25Dedupe() = %v, want %v", got, want)
	}
}
// TestBM25Options checks that the WithK1 and WithB options passed to
// NewBM25Engine end up in the engine's k1 and b fields.
func TestBM25Options(t *testing.T) {
	docs := []testDoc{{ID: 1, Text: "test"}}
	engine := NewBM25Engine(docs, extractText, WithK1(2.5), WithB(0.9))

	if got := engine.k1; got != 2.5 {
		t.Errorf("expected k1 to be 2.5, got %v", got)
	}
	if got := engine.b; got != 0.9 {
		t.Errorf("expected b to be 0.9, got %v", got)
	}
}
// TestBM25Search_SortingStability checks that Search returns all four
// matching documents ordered by descending score.
//
// The check is non-increasing rather than strictly decreasing: equal
// neighbouring scores pass, and only a later result scoring higher than the
// one before it is reported.
func TestBM25Search_SortingStability(t *testing.T) {
	docs := []testDoc{
		{ID: 1, Text: "golang is good"},
		{ID: 2, Text: "golang golang"},
		{ID: 3, Text: "golang golang golang"},
		{ID: 4, Text: "golang golang golang golang"},
	}
	engine := NewBM25Engine(docs, extractText)

	hits := engine.Search("golang", 10)
	if n := len(hits); n != 4 {
		t.Fatalf("expected 4 results, got %d", n)
	}

	// Walk adjacent pairs: hits[i] is the predecessor of cur (hits[i+1]).
	for i, cur := range hits[1:] {
		prev := hits[i]
		if cur.Score > prev.Score {
			t.Errorf("results not sorted correctly: result %d score (%v) > result %d score (%v)",
				i+1, cur.Score, i, prev.Score)
		}
	}
}