mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
perf: precompute BM25 index for repeated searches (#2177)
This commit is contained in:
+84
-67
@@ -29,18 +29,18 @@ const (
|
||||
DefaultBM25B = 0.75
|
||||
)
|
||||
|
||||
// BM25Engine is a query-time BM25 search engine over a generic corpus.
|
||||
// BM25Engine is a BM25 search engine over a generic corpus.
|
||||
// T is the document type; the caller supplies a TextFunc that extracts the
|
||||
// searchable text from each document.
|
||||
//
|
||||
// The engine is stateless between queries: no caching, no invalidation logic.
|
||||
// All indexing work is performed inside Search() on every call, making it
|
||||
// safe to use on corpora that change frequently.
|
||||
// The engine precomputes its index once at construction time and reuses it for
|
||||
// subsequent searches. If the corpus content changes, construct a new engine.
|
||||
type BM25Engine[T any] struct {
|
||||
corpus []T
|
||||
textFunc func(T) string
|
||||
k1 float64
|
||||
b float64
|
||||
index *bm25Index
|
||||
}
|
||||
|
||||
// BM25Option is a functional option to configure a BM25Engine.
|
||||
@@ -51,6 +51,17 @@ type bm25Config struct {
|
||||
b float64
|
||||
}
|
||||
|
||||
type bm25Index struct {
|
||||
entries []bm25DocEntry
|
||||
idf map[string]float32
|
||||
docLenNorm []float32
|
||||
posting map[string][]int32
|
||||
}
|
||||
|
||||
type bm25DocEntry struct {
|
||||
tf map[string]uint32
|
||||
}
|
||||
|
||||
// WithK1 overrides the term-frequency saturation constant (default 1.2).
|
||||
func WithK1(k1 float64) BM25Option {
|
||||
return func(c *bm25Config) { c.k1 = k1 }
|
||||
@@ -74,12 +85,14 @@ func NewBM25Engine[T any](corpus []T, textFunc func(T) string, opts ...BM25Optio
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
return &BM25Engine[T]{
|
||||
engine := &BM25Engine[T]{
|
||||
corpus: corpus,
|
||||
textFunc: textFunc,
|
||||
k1: cfg.k1,
|
||||
b: cfg.b,
|
||||
}
|
||||
engine.index = buildBM25Index(corpus, textFunc, cfg.k1, cfg.b)
|
||||
return engine
|
||||
}
|
||||
|
||||
// BM25Result is a single ranked result from a Search call.
|
||||
@@ -91,9 +104,8 @@ type BM25Result[T any] struct {
|
||||
// Search ranks the corpus against query and returns the top-k results.
|
||||
// Returns an empty slice (not nil) when there are no matches.
|
||||
//
|
||||
// Complexity: O(N×L) for indexing + O(|Q|×avgPostingLen) for scoring,
|
||||
// where N = corpus size, L = average document length, Q = query terms.
|
||||
// Top-k extraction uses a fixed-size min-heap: O(candidates × log k).
|
||||
// Complexity: O(|Q|×avgPostingLen + candidates × log k) per search after the
|
||||
// one-time indexing work performed by NewBM25Engine.
|
||||
func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
|
||||
if topK <= 0 {
|
||||
return []BM25Result[T]{}
|
||||
@@ -104,78 +116,24 @@ func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
N := len(e.corpus)
|
||||
if N == 0 {
|
||||
if len(e.corpus) == 0 || e.index == nil {
|
||||
return []BM25Result[T]{}
|
||||
}
|
||||
|
||||
// Step 1: build per-document tf + raw doc lengths
|
||||
type docEntry struct {
|
||||
tf map[string]uint32
|
||||
rawLen int
|
||||
}
|
||||
|
||||
entries := make([]docEntry, N)
|
||||
df := make(map[string]int, 64)
|
||||
totalLen := 0
|
||||
|
||||
for i, doc := range e.corpus {
|
||||
tokens := bm25Tokenize(e.textFunc(doc))
|
||||
totalLen += len(tokens)
|
||||
|
||||
tf := make(map[string]uint32, len(tokens))
|
||||
for _, t := range tokens {
|
||||
tf[t]++
|
||||
}
|
||||
// df: each term counts once per document (iterate the map, keys are unique)
|
||||
for t := range tf {
|
||||
df[t]++
|
||||
}
|
||||
|
||||
entries[i] = docEntry{tf: tf, rawLen: len(tokens)}
|
||||
}
|
||||
|
||||
avgDocLen := float64(totalLen) / float64(N)
|
||||
|
||||
// Step 2: pre-compute IDF and per-doc length normalization
|
||||
// IDF (Robertson smoothing): log( (N - df(t) + 0.5) / (df(t) + 0.5) + 1 )
|
||||
idf := make(map[string]float32, len(df))
|
||||
for term, freq := range df {
|
||||
idf[term] = float32(math.Log(
|
||||
(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
|
||||
))
|
||||
}
|
||||
|
||||
// docLenNorm[i] = k1 * (1 - b + b * |doc_i| / avgDocLen)
|
||||
// Stored as float32 — sufficient precision for ranking.
|
||||
docLenNorm := make([]float32, N)
|
||||
for i, entry := range entries {
|
||||
docLenNorm[i] = float32(e.k1 * (1 - e.b + e.b*float64(entry.rawLen)/avgDocLen))
|
||||
}
|
||||
|
||||
// Step 3: build inverted index (posting lists)
|
||||
// Iterate the tf map directly — map keys are already unique, no seen-set needed.
|
||||
posting := make(map[string][]int32, len(df))
|
||||
for i, entry := range entries {
|
||||
for term := range entry.tf {
|
||||
posting[term] = append(posting[term], int32(i))
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: score via posting lists
|
||||
// Deduplicate query terms to avoid double-weighting the same term.
|
||||
unique := bm25Dedupe(queryTerms)
|
||||
|
||||
scores := make(map[int32]float32)
|
||||
for _, term := range unique {
|
||||
termIDF, ok := idf[term]
|
||||
termIDF, ok := e.index.idf[term]
|
||||
if !ok {
|
||||
continue // term not in vocabulary → zero contribution
|
||||
}
|
||||
for _, docID := range posting[term] {
|
||||
freq := float32(entries[docID].tf[term])
|
||||
for _, docID := range e.index.posting[term] {
|
||||
freq := float32(e.index.entries[docID].tf[term])
|
||||
// TF_norm = freq * (k1+1) / (freq + docLenNorm)
|
||||
tfNorm := freq * float32(e.k1+1) / (freq + docLenNorm[docID])
|
||||
tfNorm := freq * float32(e.k1+1) / (freq + e.index.docLenNorm[docID])
|
||||
scores[docID] += termIDF * tfNorm
|
||||
}
|
||||
}
|
||||
@@ -212,6 +170,65 @@ func (e *BM25Engine[T]) Search(query string, topK int) []BM25Result[T] {
|
||||
return out
|
||||
}
|
||||
|
||||
func buildBM25Index[T any](corpus []T, textFunc func(T) string, k1, b float64) *bm25Index {
|
||||
N := len(corpus)
|
||||
if N == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
entries := make([]bm25DocEntry, N)
|
||||
rawLens := make([]int, N)
|
||||
df := make(map[string]int, 64)
|
||||
totalLen := 0
|
||||
|
||||
for i, doc := range corpus {
|
||||
tokens := bm25Tokenize(textFunc(doc))
|
||||
totalLen += len(tokens)
|
||||
rawLens[i] = len(tokens)
|
||||
|
||||
tf := make(map[string]uint32, len(tokens))
|
||||
for _, t := range tokens {
|
||||
tf[t]++
|
||||
}
|
||||
for term := range tf {
|
||||
df[term]++
|
||||
}
|
||||
|
||||
entries[i] = bm25DocEntry{tf: tf}
|
||||
}
|
||||
|
||||
avgDocLen := float64(totalLen) / float64(N)
|
||||
if avgDocLen == 0 {
|
||||
avgDocLen = 1
|
||||
}
|
||||
|
||||
idf := make(map[string]float32, len(df))
|
||||
for term, freq := range df {
|
||||
idf[term] = float32(math.Log(
|
||||
(float64(N)-float64(freq)+0.5)/(float64(freq)+0.5) + 1,
|
||||
))
|
||||
}
|
||||
|
||||
docLenNorm := make([]float32, N)
|
||||
for i, rawLen := range rawLens {
|
||||
docLenNorm[i] = float32(k1 * (1 - b + b*float64(rawLen)/avgDocLen))
|
||||
}
|
||||
|
||||
posting := make(map[string][]int32, len(df))
|
||||
for i, entry := range entries {
|
||||
for term := range entry.tf {
|
||||
posting[term] = append(posting[term], int32(i))
|
||||
}
|
||||
}
|
||||
|
||||
return &bm25Index{
|
||||
entries: entries,
|
||||
idf: idf,
|
||||
docLenNorm: docLenNorm,
|
||||
posting: posting,
|
||||
}
|
||||
}
|
||||
|
||||
// bm25Tokenize splits s into lowercase tokens, stripping edge punctuation.
|
||||
func bm25Tokenize(s string) []string {
|
||||
raw := strings.Fields(strings.ToLower(s))
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -173,3 +175,61 @@ func TestBM25Search_SortingStability(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBM25Search_ReusedIndex(b *testing.B) {
|
||||
corpus := benchmarkBM25Corpus(2000)
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
query := "hardware gpio i2c sensor controller latency"
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
results := engine.Search(query, 10)
|
||||
if len(results) == 0 {
|
||||
b.Fatal("expected non-empty results")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBM25Search_RebuildEachTime(b *testing.B) {
|
||||
corpus := benchmarkBM25Corpus(2000)
|
||||
query := "hardware gpio i2c sensor controller latency"
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
engine := NewBM25Engine(corpus, extractText)
|
||||
results := engine.Search(query, 10)
|
||||
if len(results) == 0 {
|
||||
b.Fatal("expected non-empty results")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkBM25Corpus(size int) []testDoc {
|
||||
corpus := make([]testDoc, size)
|
||||
topics := []string{
|
||||
"hardware gpio pwm adc sensor controller latency throughput",
|
||||
"telegram markdown parser message escape formatting bot command",
|
||||
"jsonl memory session history storage append compact recovery",
|
||||
"openai provider routing agent tool search registry hidden tools",
|
||||
"i2c spi uart serial device bus address transfer clock",
|
||||
}
|
||||
|
||||
for i := range corpus {
|
||||
topic := topics[i%len(topics)]
|
||||
corpus[i] = testDoc{
|
||||
ID: i,
|
||||
Text: fmt.Sprintf(
|
||||
"doc %d %s repeated repeated %s variant-%d %s",
|
||||
i,
|
||||
topic,
|
||||
topic,
|
||||
i%17,
|
||||
strings.Repeat("token ", (i%7)+1),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
return corpus
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user