perf: precompute BM25 index for repeated searches (#2177)

This commit is contained in:
mattn
2026-03-30 17:30:25 +09:00
committed by GitHub
parent 5e1b6a3971
commit 5e7545a22a
2 changed files with 144 additions and 67 deletions
+84 -67
View File
@@ -29,18 +29,18 @@ const (
DefaultBM25B = 0.75
)
// BM25Engine is a query-time BM25 search engine over a generic corpus.
// BM25Engine is a BM25 search engine over a generic corpus.
// T is the document type; the caller supplies a TextFunc that extracts the
// searchable text from each document.
//
// The engine is stateless between queries: no caching, no invalidation logic.
// All indexing work is performed inside Search() on every call, making it
// safe to use on corpora that change frequently.
// The engine precomputes its index once at construction time and reuses it for
// subsequent searches. If the corpus content changes, construct a new engine.
type BM25Engine[T any] struct {
corpus []T
textFunc func(T) string
k1 float64
b float64
index *bm25Index
}
// BM25Option is a functional option to configure a BM25Engine.
@@ -51,6 +51,17 @@ type bm25Config struct {
b float64
}
type bm25Index struct {
entries []bm25DocEntry
idf map[string]float32
docLenNorm []float32
posting map[string][]int32
}
type bm25DocEntry struct {
tf map[string]uint32
}
// WithK1 overrides the term-frequency saturation constant (default 1.2).
func WithK1(k1 float64) BM25Option {
return func(c *bm25Config) { c.k1 = k1 }
@@ -74,12 +85,14 @@ func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Optio
for _, o := range opts {
o(&cfg)
}
return &BM25Engine[T]{
engine := &BM25Engine[T]{
corpus: corpus,
textFunc: textFunc,
k1: cfg.k1,
b: cfg.b,
}
engine.index = buildBM25Index(corpus, textFunc, cfg.k1, cfg.b)
return engine
}
// BM25Result is a single ranked result from a Search call.
@@ -91,9 +104,8 @@ type BM25Result[T any] struct {
// Search ranks the corpus against query and returns the top-k results.
// Returns an empty slice (not nil) when there are no matches.
//
// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring,
// where N = corpus size, L = average document length, Q = query terms.
// Top-k extraction uses a fixed-size min-heap: O(candidates × log k).
// Complexity: O(|Q|×avgPostingLen + candidates × log k) per search after the
// one-time indexing work performed by NewBM25Engine.
func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
if topK <= 0 {
return []BM25Result[T]{}
@@ -104,78 +116,24 @@ func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
return []BM25Result[T]{}
}
N := len(e.corpus)
if N == 0 {
if len(e.corpus) == 0 || e.index == nil {
return []BM25Result[T]{}
}
// Step 1: build per-document tf + raw doc lengths
type docEntry struct {
tf map[string]uint32
rawLen int
}
entries := make([]docEntry, N)
df := make(map[string]int, 64)
totalLen := 0
for i, doc := range e.corpus {
tokens := bm25Tokenize(e.textFunc(doc))
totalLen += len(tokens)
tf := make(map[string]uint32, len(tokens))
for _, t := range tokens {
tf[t]++
}
// df: each term counts once per document (iterate the map, keys are unique)
for t := range tf {
df[t]++
}
entries[i] = docEntry{tf: tf, rawLen: len(tokens)}
}
avgDocLen := float64(totalLen) / float64(N)
// Step 2: pre-compute IDF and per-doc length normalization
// IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 )
idf := make(map[string]float32, len(df))
for term, freq := range df {
idf[term] = float32(math.Log(
(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
))
}
// docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen)
// Stored as float32 — sufficient precision for ranking.
docLenNorm := make([]float32, N)
for i, entry := range entries {
docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen))
}
// Step 3: build inverted index (posting lists)
// Iterate the tf map directly — map keys are already unique, no seen-set needed.
posting := make(map[string][]int32, len(df))
for i, entry := range entries {
for term := range entry.tf {
posting[term] = append(posting[term], int32(i))
}
}
// Step 4: score via posting lists
// Deduplicate query terms to avoid double-weighting the same term.
unique := bm25Dedupe(queryTerms)
scores := make(map[int32]float32)
for _, term := range unique {
termIDF, ok := idf[term]
termIDF, ok := e.index.idf[term]
if !ok {
continue // term not in vocabulary → zero contribution
}
for _, docID := range posting[term] {
freq := float32(entries[docID].tf[term])
for _, docID := range e.index.posting[term] {
freq := float32(e.index.entries[docID].tf[term])
// TF_norm = freq * (k1+1) / (freq + docLenNorm)
tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID])
tfNorm := freq * float32(e.k1+1) / (freq + e.index.docLenNorm[docID])
scores[docID] += termIDF * tfNorm
}
}
@@ -212,6 +170,65 @@ func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
return out
}
func buildBM25Index[T any](corpus []T, textFunc func(T) string, k1, b float64) *bm25Index {
N := len(corpus)
if N == 0 {
return nil
}
entries := make([]bm25DocEntry, N)
rawLens := make([]int, N)
df := make(map[string]int, 64)
totalLen := 0
for i, doc := range corpus {
tokens := bm25Tokenize(textFunc(doc))
totalLen += len(tokens)
rawLens[i] = len(tokens)
tf := make(map[string]uint32, len(tokens))
for _, t := range tokens {
tf[t]++
}
for term := range tf {
df[term]++
}
entries[i] = bm25DocEntry{tf: tf}
}
avgDocLen := float64(totalLen) / float64(N)
if avgDocLen == 0 {
avgDocLen = 1
}
idf := make(map[string]float32, len(df))
for term, freq := range df {
idf[term] = float32(math.Log(
(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
))
}
docLenNorm := make([]float32, N)
for i, rawLen := range rawLens {
docLenNorm[i] = float32(k1 * (1 - b + b*float64(rawLen)/avgDocLen))
}
posting := make(map[string][]int32, len(df))
for i, entry := range entries {
for term := range entry.tf {
posting[term] = append(posting[term], int32(i))
}
}
return &bm25Index{
entries: entries,
idf: idf,
docLenNorm: docLenNorm,
posting: posting,
}
}
// bm25Tokenize splits s into lowercase tokens, stripping edge punctuation.
func bm25Tokenize(s string) []string {
raw := strings.Fields(strings.ToLower(s))
+60
View File
@@ -1,7 +1,9 @@
package utils
import (
"fmt"
"reflect"
"strings"
"testing"
)
@@ -173,3 +175,61 @@ func TestBM25Search_SortingStability(t *testing.T) {
}
}
}
func BenchmarkBM25Search_ReusedIndex(b *testing.B) {
corpus := benchmarkBM25Corpus(2000)
engine := NewBM25Engine(corpus, extractText)
query := "hardware gpio i2c sensor controller latency"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
results := engine.Search(query, 10)
if len(results) == 0 {
b.Fatal("expected non-empty results")
}
}
}
func BenchmarkBM25Search_RebuildEachTime(b *testing.B) {
corpus := benchmarkBM25Corpus(2000)
query := "hardware gpio i2c sensor controller latency"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
engine := NewBM25Engine(corpus, extractText)
results := engine.Search(query, 10)
if len(results) == 0 {
b.Fatal("expected non-empty results")
}
}
}
func benchmarkBM25Corpus(size int) []testDoc {
corpus := make([]testDoc, size)
topics := []string{
"hardware gpio pwm adc sensor controller latency throughput",
"telegram markdown parser message escape formatting bot command",
"jsonl memory session history storage append compact recovery",
"openai provider routing agent tool search registry hidden tools",
"i2c spi uart serial device bus address transfer clock",
}
for i := range corpus {
topic := topics[i%len(topics)]
corpus[i] = testDoc{
ID: i,
Text: fmt.Sprintf(
"doc %d %s repeated repeated %s variant-%d %s",
i,
topic,
topic,
i%17,
strings.Repeat("token ", (i%7)+1),
),
}
}
return corpus
}