diff --git a/pkg/utils/bm25.go b/pkg/utils/bm25.go index 95c63f0e3..f8b9f6882 100644 --- a/pkg/utils/bm25.go +++ b/pkg/utils/bm25.go @@ -29,18 +29,18 @@ const ( DefaultBM25B = 0.75 ) -// BM25Engine is a query-time BM25 search engine over a generic corpus. +// BM25Engine is a BM25 search engine over a generic corpus. // T is the document type; the caller supplies a TextFunc that extracts the // searchable text from each document. // -// The engine is stateless between queries: no caching, no invalidation logic. -// All indexing work is performed inside Search() on every call, making it -// safe to use on corpora that change frequently. +// The engine precomputes its index once at construction time and reuses it for +// subsequent searches. If the corpus content changes, construct a new engine. type BM25Engine[T any] struct { corpus []T textFunc func(T) string k1 float64 b float64 + index *bm25Index } // BM25Option is a functional option to configure a BM25Engine. @@ -51,6 +51,17 @@ type bm25Config struct { b float64 } +type bm25Index struct { + entries []bm25DocEntry + idf map[string]float32 + docLenNorm []float32 + posting map[string][]int32 +} + +type bm25DocEntry struct { + tf map[string]uint32 +} + // WithK1 overrides the term-frequency saturation constant (default 1.2). func WithK1(k1 float64) BM25Option { return func(c *bm25Config) { c.k1 = k1 } @@ -74,12 +85,14 @@ func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Optio for _, o := range opts { o(&cfg) } - return &BM25Engine[T]{ + engine := &BM25Engine[T]{ corpus: corpus, textFunc: textFunc, k1: cfg.k1, b: cfg.b, } + engine.index = buildBM25Index(corpus, textFunc, cfg.k1, cfg.b) + return engine } // BM25Result is a single ranked result from a Search call. @@ -91,9 +104,8 @@ type BM25Result[T any] struct { // Search ranks the corpus against query and returns the top-k results. // Returns an empty slice (not nil) when there are no matches. // -// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring, -// where N = corpus size, L = average document length, Q = query terms. -// Top-k extraction uses a fixed-size min-heap: O(candidates × log k). +// Complexity: O(|Q|×avgPostingLen + candidates × log k) per search after the +// one-time indexing work performed by NewBM25Engine. func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] { if topK <= 0 { return []BM25Result[T]{} @@ -104,78 +116,24 @@ func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] { return []BM25Result[T]{} } - N := len(e.corpus) - if N == 0 { + if len(e.corpus) == 0 || e.index == nil { return []BM25Result[T]{} } - // Step 1: build per-document tf + raw doc lengths - type docEntry struct { - tf map[string]uint32 - rawLen int - } - - entries := make([]docEntry, N) - df := make(map[string]int, 64) - totalLen := 0 - - for i, doc := range e.corpus { - tokens := bm25Tokenize(e.textFunc(doc)) - totalLen += len(tokens) - - tf := make(map[string]uint32, len(tokens)) - for _, t := range tokens { - tf[t]++ - } - // df: each term counts once per document (iterate the map, keys are unique) - for t := range tf { - df[t]++ - } - - entries[i] = docEntry{tf: tf, rawLen: len(tokens)} - } - - avgDocLen := float64(totalLen) / float64(N) - - // Step 2: pre-compute IDF and per-doc length normalization - // IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 ) - idf := make(map[string]float32, len(df)) - for term, freq := range df { - idf[term] = float32(math.Log( - (float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1, - )) - } - - // docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen) - // Stored as float32 — sufficient precision for ranking. - docLenNorm := make([]float32, N) - for i, entry := range entries { - docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen)) - } - - // Step 3: build inverted index (posting lists) - // Iterate the tf map directly — map keys are already unique, no seen-set needed. - posting := make(map[string][]int32, len(df)) - for i, entry := range entries { - for term := range entry.tf { - posting[term] = append(posting[term], int32(i)) - } - } - // Step 4: score via posting lists // Deduplicate query terms to avoid double-weighting the same term. unique := bm25Dedupe(queryTerms) scores := make(map[int32]float32) for _, term := range unique { - termIDF, ok := idf[term] + termIDF, ok := e.index.idf[term] if !ok { continue // term not in vocabulary → zero contribution } - for _, docID := range posting[term] { - freq := float32(entries[docID].tf[term]) + for _, docID := range e.index.posting[term] { + freq := float32(e.index.entries[docID].tf[term]) // TF_norm = freq * (k1+1) / (freq + docLenNorm) - tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID]) + tfNorm := freq * float32(e.k1+1) / (freq + e.index.docLenNorm[docID]) scores[docID] += termIDF * tfNorm } } @@ -212,6 +170,65 @@ func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] { return out } +func buildBM25Index[T any](corpus []T, textFunc func(T) string, k1, b float64) *bm25Index { + N := len(corpus) + if N == 0 { + return nil + } + + entries := make([]bm25DocEntry, N) + rawLens := make([]int, N) + df := make(map[string]int, 64) + totalLen := 0 + + for i, doc := range corpus { + tokens := bm25Tokenize(textFunc(doc)) + totalLen += len(tokens) + rawLens[i] = len(tokens) + + tf := make(map[string]uint32, len(tokens)) + for _, t := range tokens { + tf[t]++ + } + for term := range tf { + df[term]++ + } + + entries[i] = bm25DocEntry{tf: tf} + } + + avgDocLen := float64(totalLen) / float64(N) + if avgDocLen == 0 { + avgDocLen = 1 + } + + idf := make(map[string]float32, len(df)) + for term, freq := range df { + idf[term] = float32(math.Log( + (float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1, + )) + } + + docLenNorm := make([]float32, N) + for i, rawLen := range rawLens { + docLenNorm[i] = float32(k1 * (1 - b + b*float64(rawLen)/avgDocLen)) + } + + posting := make(map[string][]int32, len(df)) + for i, entry := range entries { + for term := range entry.tf { + posting[term] = append(posting[term], int32(i)) + } + } + + return &bm25Index{ + entries: entries, + idf: idf, + docLenNorm: docLenNorm, + posting: posting, + } +} + // bm25Tokenize splits s into lowercase tokens, stripping edge punctuation. func bm25Tokenize(s string) []string { raw := strings.Fields(strings.ToLower(s)) diff --git a/pkg/utils/bm25_test.go b/pkg/utils/bm25_test.go index 4bc85b246..216fe733d 100644 --- a/pkg/utils/bm25_test.go +++ b/pkg/utils/bm25_test.go @@ -1,7 +1,9 @@ package utils import ( + "fmt" "reflect" + "strings" "testing" ) @@ -173,3 +175,61 @@ func TestBM25Search_SortingStability(t *testing.T) { } } } + +func BenchmarkBM25Search_ReusedIndex(b *testing.B) { + corpus := benchmarkBM25Corpus(2000) + engine := NewBM25Engine(corpus, extractText) + query := "hardware gpio i2c sensor controller latency" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + results := engine.Search(query, 10) + if len(results) == 0 { + b.Fatal("expected non-empty results") + } + } +} + +func BenchmarkBM25Search_RebuildEachTime(b *testing.B) { + corpus := benchmarkBM25Corpus(2000) + query := "hardware gpio i2c sensor controller latency" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + engine := NewBM25Engine(corpus, extractText) + results := engine.Search(query, 10) + if len(results) == 0 { + b.Fatal("expected non-empty results") + } + } +} + +func benchmarkBM25Corpus(size int) []testDoc { + corpus := make([]testDoc, size) + topics := []string{ + "hardware gpio pwm adc sensor controller latency throughput", + "telegram markdown parser message escape formatting bot command", + "jsonl memory session history storage append compact recovery", + "openai provider routing agent tool search registry hidden tools", + "i2c spi uart serial device bus address transfer clock", + } + + for i := range corpus { + topic := topics[i%len(topics)] + corpus[i] = testDoc{ + ID: i, + Text: fmt.Sprintf( + "doc %d %s repeated repeated %s variant-%d %s", + i, + topic, + topic, + i%17, + strings.Repeat("token ", (i%7)+1), + ), + } + } + + return corpus +}