mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
membench: add LLM-as-Judge evaluation mode (#2484)
* membench: add LLM-as-Judge evaluation mode Add --eval-mode=llm to membench for LLM-based answer generation and semantic scoring via an OpenAI-compatible API endpoint. New files: - llm_client.go: generic OpenAI-compatible chat completion client with support for API key, configurable timeout, and optional chat_template_kwargs (for llama.cpp thinking models) - eval_llm.go: LLM answer generation + LLM-as-Judge scoring for both legacy and seahorse retrieval modes Changes to main.go: - --eval-mode flag (token|llm) to select evaluation strategy - --api-base, --api-key, --model flags with env var fallback (MEMBENCH_API_BASE, MEMBENCH_API_KEY, MEMBENCH_MODEL) - --no-thinking flag for llama.cpp + Qwen thinking models - --limit flag to cap QA questions per sample for quick testing * style: fix golangci-lint formatting (gofmt + golines) * fix: address Copilot review feedback - Validate --model is required for LLM eval mode - Use rune-based truncation to preserve valid UTF-8 - Precompute totalQA count outside inner loop - Log SearchMessages errors instead of silently skipping * fix: address Copilot review round 2 - Validate --eval-mode accepts only 'token' or 'llm' - Normalize base URL to avoid /v1/v1 duplication - Separate token/LLM results for correct PrintComparison labeling - Log ExpandMessages errors instead of silently ignoring - Short-circuit with 0 scores when no context retrieved (match token eval) - Add --timeout flag wired to LLMClientOptions.Timeout * fix: address review P1+P2 — sort alignment, failure sentinel, score parser - P1: Replace hand-rolled sortByRank with sort.Slice (ascending, best first) matching eval.go's EvalSeahorse — ensures BudgetTruncate keeps best-ranked messages when truncation occurs - P2: Use -1.0 sentinel for LLM API failures and parse errors, distinct from genuine 0.0 score; aggregateMetrics skips -1.0 entries for F1 averaging while still counting HitRate - P2: Use regexp \b([1-5])\b for judge score extraction instead of first-digit scan — avoids misparses on '5/5', 'Score: 3' etc. * fix: address Copilot review round 2 - Fix F1/HitRate weighted aggregation: track ValidF1Count separately so computeModeAgg weights F1 by valid scores only, not TotalQuestions - No-context retrieval failure uses 0.0 (genuine bad score) instead of -1.0 sentinel (reserved for API/parse failures) - Validate --timeout > 0 to prevent disabling HTTP timeouts * fix: remove hardcoded /v1 from API base URL Users now provide the full versioned path in --api-base (e.g. /v1, /v4). Code only appends /chat/completions. Default changed to http://127.0.0.1:8080/v1 for backward compatibility. * fix: address Copilot review round 3 - ValidF1Count=0 when all scores are sentinel (no forced =1) - Backward compat: old eval JSON without ValidF1Count falls back to TotalQuestions in computeModeAgg - Skip empty section in PrintComparison when tokenResults is empty - Update --api-base flag help to document /v1 default and version path - Add sentinel aggregation unit tests (partial, all, weighted) * feat: add --retries flag with exponential backoff for transient LLM errors Retry on timeout, 5xx, and 429 (rate limit) with 1s/2s/4s backoff. Default 3 retries, configurable via --retries. Context cancellation is respected between retries. * fix: address Copilot review round 4 - runReport splits results by mode suffix into token/llm for PrintComparison - backward compat fallback (ValidF1Count=0 -> TotalQuestions) only for non-LLM modes; LLM modes keep ValidF1Count=0 when all scores sentinel - MaxRetries==0 means no retry; only negative falls back to default 3 - truncateStr uses []rune to avoid cutting multi-byte UTF-8 characters - Complete() returns error on empty LLM response (vs silent empty string) * feat: --no-thinking adapts to llama.cpp, Ollama, and GLM backends Send all three disable-thinking fields simultaneously: - chat_template_kwargs.enable_thinking=false (llama.cpp, GLM) - think=false (Ollama 0.9+) - thinking.type=disabled (GLM/Zhipu) Each backend picks the field it recognizes and ignores the rest. Also bumps max_tokens from 512 to 2048 for thinking models. * feat: mixed model eval + concurrent QA workers - Add --judge-model, --judge-api-base, --judge-api-key flags for separate judge model - Add --concurrency flag (default 1) with semaphore-based goroutine pool - Add reasoning_content fallback for GLM/DeepSeek style responses - Prepend /no_think to system prompt for Ollama /v1 compatibility - Reduce default MaxTokens from 2048 to 512 (answers are 1-3 sentences) - Extract evalQAWorker and buildSeahorseContext for shared concurrent logic --------- Co-authored-by: BeaconCat <BeaconCat@users.noreply.github.com>
This commit is contained in:
+74
-28
@@ -36,6 +36,7 @@ type AggMetrics struct {
|
|||||||
OverallHitRate float64 `json:"overallHitRate"`
|
OverallHitRate float64 `json:"overallHitRate"`
|
||||||
ByCategory map[int]*CatMetrics `json:"byCategory"`
|
ByCategory map[int]*CatMetrics `json:"byCategory"`
|
||||||
TotalQuestions int `json:"totalQuestions"`
|
TotalQuestions int `json:"totalQuestions"`
|
||||||
|
ValidF1Count int `json:"validF1Count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CatMetrics holds metrics for a single category.
|
// CatMetrics holds metrics for a single category.
|
||||||
@@ -43,6 +44,7 @@ type CatMetrics struct {
|
|||||||
F1 float64 `json:"f1"`
|
F1 float64 `json:"f1"`
|
||||||
HitRate float64 `json:"hitRate"`
|
HitRate float64 `json:"hitRate"`
|
||||||
QuestionCount int `json:"questionCount"`
|
QuestionCount int `json:"questionCount"`
|
||||||
|
ValidF1Count int `json:"validF1Count"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EvalLegacy evaluates using legacy session store (raw history + budget truncation).
|
// EvalLegacy evaluates using legacy session store (raw history + budget truncation).
|
||||||
@@ -201,38 +203,64 @@ func EvalSeahorse(
|
|||||||
|
|
||||||
// aggregateMetrics computes overall and per-category metrics.
|
// aggregateMetrics computes overall and per-category metrics.
|
||||||
func aggregateMetrics(qaResults []QAResult) AggMetrics {
|
func aggregateMetrics(qaResults []QAResult) AggMetrics {
|
||||||
byCat := map[int]*CatMetrics{}
|
type catAccum struct {
|
||||||
|
f1Sum float64
|
||||||
|
f1Count int
|
||||||
|
hitRateSum float64
|
||||||
|
hitRateCount int
|
||||||
|
}
|
||||||
|
byCatAcc := map[int]*catAccum{}
|
||||||
totalF1 := 0.0
|
totalF1 := 0.0
|
||||||
totalHitRate := 0.0
|
totalHitRate := 0.0
|
||||||
|
validF1Count := 0
|
||||||
for _, qr := range qaResults {
|
for _, qr := range qaResults {
|
||||||
totalF1 += qr.TokenF1
|
// Skip sentinel -1.0 scores (LLM API/parse failures) from F1 averaging.
|
||||||
totalHitRate += qr.HitRate
|
if qr.TokenF1 >= 0 {
|
||||||
cat, ok := byCat[qr.Category]
|
totalF1 += qr.TokenF1
|
||||||
if !ok {
|
validF1Count++
|
||||||
cat = &CatMetrics{}
|
|
||||||
byCat[qr.Category] = cat
|
|
||||||
}
|
}
|
||||||
cat.F1 += qr.TokenF1
|
totalHitRate += qr.HitRate
|
||||||
cat.HitRate += qr.HitRate
|
acc, ok := byCatAcc[qr.Category]
|
||||||
cat.QuestionCount++
|
if !ok {
|
||||||
|
acc = &catAccum{}
|
||||||
|
byCatAcc[qr.Category] = acc
|
||||||
|
}
|
||||||
|
if qr.TokenF1 >= 0 {
|
||||||
|
acc.f1Sum += qr.TokenF1
|
||||||
|
acc.f1Count++
|
||||||
|
}
|
||||||
|
acc.hitRateSum += qr.HitRate
|
||||||
|
acc.hitRateCount++
|
||||||
}
|
}
|
||||||
n := len(qaResults)
|
nHit := len(qaResults)
|
||||||
if n == 0 {
|
if nHit == 0 {
|
||||||
n = 1
|
nHit = 1
|
||||||
}
|
}
|
||||||
agg := AggMetrics{
|
byCat := map[int]*CatMetrics{}
|
||||||
OverallF1: totalF1 / float64(n),
|
for cat, acc := range byCatAcc {
|
||||||
OverallHitRate: totalHitRate / float64(n),
|
cm := &CatMetrics{
|
||||||
|
QuestionCount: acc.hitRateCount,
|
||||||
|
ValidF1Count: acc.f1Count,
|
||||||
|
}
|
||||||
|
if acc.f1Count > 0 {
|
||||||
|
cm.F1 = acc.f1Sum / float64(acc.f1Count)
|
||||||
|
}
|
||||||
|
if acc.hitRateCount > 0 {
|
||||||
|
cm.HitRate = acc.hitRateSum / float64(acc.hitRateCount)
|
||||||
|
}
|
||||||
|
byCat[cat] = cm
|
||||||
|
}
|
||||||
|
var overallF1 float64
|
||||||
|
if validF1Count > 0 {
|
||||||
|
overallF1 = totalF1 / float64(validF1Count)
|
||||||
|
}
|
||||||
|
return AggMetrics{
|
||||||
|
OverallF1: overallF1,
|
||||||
|
OverallHitRate: totalHitRate / float64(nHit),
|
||||||
ByCategory: byCat,
|
ByCategory: byCat,
|
||||||
TotalQuestions: len(qaResults),
|
TotalQuestions: len(qaResults),
|
||||||
|
ValidF1Count: validF1Count,
|
||||||
}
|
}
|
||||||
for _, cat := range agg.ByCategory {
|
|
||||||
if cat.QuestionCount > 0 {
|
|
||||||
cat.F1 /= float64(cat.QuestionCount)
|
|
||||||
cat.HitRate /= float64(cat.QuestionCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return agg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveResults writes per-sample eval results to JSON files.
|
// SaveResults writes per-sample eval results to JSON files.
|
||||||
@@ -277,27 +305,43 @@ func SaveAggregated(results []EvalResult, outDir string) error {
|
|||||||
func computeModeAgg(results []EvalResult) AggMetrics {
|
func computeModeAgg(results []EvalResult) AggMetrics {
|
||||||
agg := AggMetrics{ByCategory: map[int]*CatMetrics{}}
|
agg := AggMetrics{ByCategory: map[int]*CatMetrics{}}
|
||||||
for _, r := range results {
|
for _, r := range results {
|
||||||
agg.OverallF1 += r.Agg.OverallF1 * float64(r.Agg.TotalQuestions)
|
// Backward compat: old eval JSON (token mode) without ValidF1Count → use TotalQuestions.
|
||||||
|
// LLM modes may legitimately have ValidF1Count==0 (all failures).
|
||||||
|
vf1 := r.Agg.ValidF1Count
|
||||||
|
if vf1 == 0 && r.Agg.TotalQuestions > 0 && !strings.HasSuffix(r.Mode, "-llm") {
|
||||||
|
vf1 = r.Agg.TotalQuestions
|
||||||
|
}
|
||||||
|
agg.OverallF1 += r.Agg.OverallF1 * float64(vf1)
|
||||||
agg.OverallHitRate += r.Agg.OverallHitRate * float64(r.Agg.TotalQuestions)
|
agg.OverallHitRate += r.Agg.OverallHitRate * float64(r.Agg.TotalQuestions)
|
||||||
agg.TotalQuestions += r.Agg.TotalQuestions
|
agg.TotalQuestions += r.Agg.TotalQuestions
|
||||||
|
agg.ValidF1Count += vf1
|
||||||
for cat, cm := range r.Agg.ByCategory {
|
for cat, cm := range r.Agg.ByCategory {
|
||||||
existing, ok := agg.ByCategory[cat]
|
existing, ok := agg.ByCategory[cat]
|
||||||
if !ok {
|
if !ok {
|
||||||
existing = &CatMetrics{}
|
existing = &CatMetrics{}
|
||||||
agg.ByCategory[cat] = existing
|
agg.ByCategory[cat] = existing
|
||||||
}
|
}
|
||||||
existing.F1 += cm.F1 * float64(cm.QuestionCount)
|
cvf1 := cm.ValidF1Count
|
||||||
|
if cvf1 == 0 && cm.QuestionCount > 0 && !strings.HasSuffix(r.Mode, "-llm") {
|
||||||
|
cvf1 = cm.QuestionCount
|
||||||
|
}
|
||||||
|
existing.F1 += cm.F1 * float64(cvf1)
|
||||||
existing.HitRate += cm.HitRate * float64(cm.QuestionCount)
|
existing.HitRate += cm.HitRate * float64(cm.QuestionCount)
|
||||||
existing.QuestionCount += cm.QuestionCount
|
existing.QuestionCount += cm.QuestionCount
|
||||||
|
existing.ValidF1Count += cvf1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if agg.ValidF1Count > 0 {
|
||||||
|
agg.OverallF1 /= float64(agg.ValidF1Count)
|
||||||
|
}
|
||||||
if agg.TotalQuestions > 0 {
|
if agg.TotalQuestions > 0 {
|
||||||
agg.OverallF1 /= float64(agg.TotalQuestions)
|
|
||||||
agg.OverallHitRate /= float64(agg.TotalQuestions)
|
agg.OverallHitRate /= float64(agg.TotalQuestions)
|
||||||
}
|
}
|
||||||
for _, cat := range agg.ByCategory {
|
for _, cat := range agg.ByCategory {
|
||||||
|
if cat.ValidF1Count > 0 {
|
||||||
|
cat.F1 /= float64(cat.ValidF1Count)
|
||||||
|
}
|
||||||
if cat.QuestionCount > 0 {
|
if cat.QuestionCount > 0 {
|
||||||
cat.F1 /= float64(cat.QuestionCount)
|
|
||||||
cat.HitRate /= float64(cat.QuestionCount)
|
cat.HitRate /= float64(cat.QuestionCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -359,7 +403,9 @@ func printSection(title string, results []EvalResult) {
|
|||||||
|
|
||||||
// PrintComparison outputs a human-readable comparison table to stdout.
|
// PrintComparison outputs a human-readable comparison table to stdout.
|
||||||
func PrintComparison(results []EvalResult, llmResults []EvalResult) {
|
func PrintComparison(results []EvalResult, llmResults []EvalResult) {
|
||||||
printSection("No LLM generation", results)
|
if len(results) > 0 {
|
||||||
|
printSection("No LLM generation", results)
|
||||||
|
}
|
||||||
if len(llmResults) > 0 {
|
if len(llmResults) > 0 {
|
||||||
printSection("With LLM", llmResults)
|
printSection("With LLM", llmResults)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,346 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||||
|
)
|
||||||
|
|
||||||
|
const answerSystemPrompt = `You are a helpful assistant. Given conversation context, answer the question concisely and accurately. If the answer is not in the context, say "I don't know". Answer in 1-3 sentences maximum.`
|
||||||
|
|
||||||
|
const judgeSystemPrompt = `You are an impartial judge evaluating answer quality.
|
||||||
|
Compare the candidate answer against the reference answer.
|
||||||
|
Consider semantic equivalence — different wording expressing the same meaning should score high.
|
||||||
|
|
||||||
|
Output ONLY a single integer score from 1 to 5:
|
||||||
|
1 = completely wrong or irrelevant
|
||||||
|
2 = partially related but mostly incorrect
|
||||||
|
3 = partially correct, missing key details
|
||||||
|
4 = mostly correct with minor omissions
|
||||||
|
5 = fully correct, semantically equivalent
|
||||||
|
|
||||||
|
Output ONLY the number, nothing else.`
|
||||||
|
|
||||||
|
// generateAnswer asks the LLM to answer a question given retrieved context.
|
||||||
|
func generateAnswer(ctx context.Context, client *LLMClient, contextText, question string) (string, error) {
|
||||||
|
// Truncate context to avoid exceeding model limits while preserving valid UTF-8.
|
||||||
|
contextRunes := []rune(contextText)
|
||||||
|
if len(contextRunes) > 6000 {
|
||||||
|
contextText = string(contextRunes[:6000]) + "\n... [truncated]"
|
||||||
|
}
|
||||||
|
|
||||||
|
userPrompt := fmt.Sprintf("## Conversation Context\n\n%s\n\n## Question\n\n%s", contextText, question)
|
||||||
|
return client.Complete(ctx, answerSystemPrompt, userPrompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// scoreRe matches the first standalone integer 1-5 in the judge response.
|
||||||
|
var scoreRe = regexp.MustCompile(`\b([1-5])\b`)
|
||||||
|
|
||||||
|
// judgeAnswer asks the LLM to score the candidate answer vs the gold answer.
|
||||||
|
// Returns a score from 0.0 to 1.0, or -1.0 on parse failure.
|
||||||
|
func judgeAnswer(
|
||||||
|
ctx context.Context,
|
||||||
|
judgeClient *LLMClient,
|
||||||
|
question, goldAnswer, candidateAnswer string,
|
||||||
|
) (float64, error) {
|
||||||
|
userPrompt := fmt.Sprintf(
|
||||||
|
"Question: %s\n\nReference Answer: %s\n\nCandidate Answer: %s\n\nScore:",
|
||||||
|
question, goldAnswer, candidateAnswer,
|
||||||
|
)
|
||||||
|
|
||||||
|
response, err := judgeClient.Complete(ctx, judgeSystemPrompt, userPrompt)
|
||||||
|
if err != nil {
|
||||||
|
return -1.0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
response = strings.TrimSpace(response)
|
||||||
|
if m := scoreRe.FindStringSubmatch(response); len(m) == 2 {
|
||||||
|
score, _ := strconv.Atoi(m[1])
|
||||||
|
return float64(score-1) / 4.0, nil // Normalize 1-5 to 0.0-1.0
|
||||||
|
}
|
||||||
|
log.Printf("WARNING: could not parse judge score from: %q, returning -1", response)
|
||||||
|
return -1.0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// qaWork describes one QA evaluation unit.
|
||||||
|
type qaWork struct {
|
||||||
|
sampleID string
|
||||||
|
qaIndex int
|
||||||
|
globalIndex int
|
||||||
|
totalQA int
|
||||||
|
qa *LocomoQA
|
||||||
|
contextText string
|
||||||
|
sample *LocomoSample
|
||||||
|
}
|
||||||
|
|
||||||
|
// qaResult collects one QA evaluation output.
|
||||||
|
type qaResultOut struct {
|
||||||
|
index int // position in the flat QA list for ordering
|
||||||
|
result QAResult
|
||||||
|
answer string
|
||||||
|
score float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// evalQAWorker processes a single QA item: generate answer + judge score.
|
||||||
|
func evalQAWorker(
|
||||||
|
ctx context.Context,
|
||||||
|
w qaWork,
|
||||||
|
answerClient, judgeClient *LLMClient,
|
||||||
|
logPrefix string,
|
||||||
|
) qaResultOut {
|
||||||
|
llmAnswer, err := generateAnswer(ctx, answerClient, w.contextText, w.qa.Question)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("WARN: LLM generation failed for sample %s Q%d: %v", w.sampleID, w.qaIndex, err)
|
||||||
|
llmAnswer = ""
|
||||||
|
}
|
||||||
|
|
||||||
|
score := -1.0
|
||||||
|
if llmAnswer != "" {
|
||||||
|
score, err = judgeAnswer(ctx, judgeClient, w.qa.Question, w.qa.AnswerString(), llmAnswer)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("WARN: LLM judge failed for sample %s Q%d: %v", w.sampleID, w.qaIndex, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hitRate := RecallHitRate(w.qa.Evidence, w.sample, w.contextText)
|
||||||
|
|
||||||
|
log.Printf("[%s] sample=%s q=%d/%d score=%.2f answer=%q",
|
||||||
|
logPrefix, w.sampleID, w.globalIndex, w.totalQA, score, truncateStr(llmAnswer, 80))
|
||||||
|
|
||||||
|
return qaResultOut{
|
||||||
|
index: w.globalIndex,
|
||||||
|
result: QAResult{
|
||||||
|
Question: w.qa.Question,
|
||||||
|
Category: w.qa.Category,
|
||||||
|
GoldAnswer: w.qa.AnswerString(),
|
||||||
|
TokenF1: score,
|
||||||
|
HitRate: hitRate,
|
||||||
|
},
|
||||||
|
answer: llmAnswer,
|
||||||
|
score: score,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EvalLegacyLLM evaluates legacy store using LLM generation + LLM-as-Judge.
|
||||||
|
func EvalLegacyLLM(
|
||||||
|
ctx context.Context,
|
||||||
|
samples []LocomoSample,
|
||||||
|
legacy *LegacyStore,
|
||||||
|
budgetTokens int,
|
||||||
|
answerClient, judgeClient *LLMClient,
|
||||||
|
concurrency int,
|
||||||
|
) []EvalResult {
|
||||||
|
if concurrency < 1 {
|
||||||
|
concurrency = 1
|
||||||
|
}
|
||||||
|
totalQA := countTotalQA(samples)
|
||||||
|
results := make([]EvalResult, 0, len(samples))
|
||||||
|
|
||||||
|
for si := range samples {
|
||||||
|
sample := &samples[si]
|
||||||
|
history := legacy.GetHistory(sample.SampleID)
|
||||||
|
|
||||||
|
allContent := make([]string, 0, len(history))
|
||||||
|
for _, msg := range history {
|
||||||
|
allContent = append(allContent, msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
truncated, _ := BudgetTruncate(allContent, budgetTokens)
|
||||||
|
contextText := StringListToContent(truncated)
|
||||||
|
|
||||||
|
qaResults := make([]QAResult, len(sample.QA))
|
||||||
|
|
||||||
|
if concurrency <= 1 {
|
||||||
|
for qi := range sample.QA {
|
||||||
|
out := evalQAWorker(ctx, qaWork{
|
||||||
|
sampleID: sample.SampleID, qaIndex: qi,
|
||||||
|
globalIndex: si*len(sample.QA) + qi + 1, totalQA: totalQA,
|
||||||
|
qa: &sample.QA[qi], contextText: contextText, sample: sample,
|
||||||
|
}, answerClient, judgeClient, "legacy-llm")
|
||||||
|
qaResults[qi] = out.result
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sem := make(chan struct{}, concurrency)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for qi := range sample.QA {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
sem <- struct{}{}
|
||||||
|
defer func() { <-sem }()
|
||||||
|
out := evalQAWorker(ctx, qaWork{
|
||||||
|
sampleID: sample.SampleID, qaIndex: qi,
|
||||||
|
globalIndex: si*len(sample.QA) + qi + 1, totalQA: totalQA,
|
||||||
|
qa: &sample.QA[qi], contextText: contextText, sample: sample,
|
||||||
|
}, answerClient, judgeClient, "legacy-llm")
|
||||||
|
qaResults[qi] = out.result // safe: each goroutine writes distinct index
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, EvalResult{
|
||||||
|
Mode: "legacy-llm",
|
||||||
|
SampleID: sample.SampleID,
|
||||||
|
QAResults: qaResults,
|
||||||
|
Agg: aggregateMetrics(qaResults),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSeahorseContext retrieves context for a seahorse QA item.
|
||||||
|
func buildSeahorseContext(
|
||||||
|
ctx context.Context,
|
||||||
|
ir *SeahorseIngestResult,
|
||||||
|
sample *LocomoSample,
|
||||||
|
qa *LocomoQA,
|
||||||
|
budgetTokens int,
|
||||||
|
) string {
|
||||||
|
store := ir.Engine.GetRetrieval().Store()
|
||||||
|
retrieval := ir.Engine.GetRetrieval()
|
||||||
|
convID := ir.ConvMap[sample.SampleID]
|
||||||
|
|
||||||
|
keywords := ExtractKeywords(qa.Question)
|
||||||
|
bestRank := map[int64]float64{}
|
||||||
|
for _, kw := range keywords {
|
||||||
|
searchResults, err := store.SearchMessages(ctx, seahorse.SearchInput{
|
||||||
|
Pattern: kw,
|
||||||
|
ConversationID: convID,
|
||||||
|
Limit: 20,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, sr := range searchResults {
|
||||||
|
if sr.MessageID > 0 {
|
||||||
|
if prev, ok := bestRank[sr.MessageID]; !ok || sr.Rank < prev {
|
||||||
|
bestRank[sr.MessageID] = sr.Rank
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messageIDs := make([]int64, 0, len(bestRank))
|
||||||
|
for id := range bestRank {
|
||||||
|
messageIDs = append(messageIDs, id)
|
||||||
|
}
|
||||||
|
sort.Slice(messageIDs, func(i, j int) bool {
|
||||||
|
return bestRank[messageIDs[i]] < bestRank[messageIDs[j]]
|
||||||
|
})
|
||||||
|
|
||||||
|
var contentParts []string
|
||||||
|
if len(messageIDs) > 0 {
|
||||||
|
expandResult, err := retrieval.ExpandMessages(ctx, messageIDs)
|
||||||
|
if err == nil {
|
||||||
|
for _, msg := range expandResult.Messages {
|
||||||
|
contentParts = append(contentParts, msg.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(contentParts) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
truncated, _ := BudgetTruncate(contentParts, budgetTokens)
|
||||||
|
return StringListToContent(truncated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EvalSeahorseLLM evaluates seahorse retrieval using LLM generation + LLM-as-Judge.
|
||||||
|
func EvalSeahorseLLM(
|
||||||
|
ctx context.Context,
|
||||||
|
samples []LocomoSample,
|
||||||
|
ir *SeahorseIngestResult,
|
||||||
|
budgetTokens int,
|
||||||
|
answerClient, judgeClient *LLMClient,
|
||||||
|
concurrency int,
|
||||||
|
) []EvalResult {
|
||||||
|
if concurrency < 1 {
|
||||||
|
concurrency = 1
|
||||||
|
}
|
||||||
|
totalQA := countTotalQA(samples)
|
||||||
|
results := make([]EvalResult, 0, len(samples))
|
||||||
|
|
||||||
|
for si := range samples {
|
||||||
|
sample := &samples[si]
|
||||||
|
if _, ok := ir.ConvMap[sample.SampleID]; !ok {
|
||||||
|
log.Printf("WARN: no conversation ID for sample %s", sample.SampleID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
qaResults := make([]QAResult, len(sample.QA))
|
||||||
|
|
||||||
|
evalOne := func(qi int) {
|
||||||
|
qa := &sample.QA[qi]
|
||||||
|
contextText := buildSeahorseContext(ctx, ir, sample, qa, budgetTokens)
|
||||||
|
if contextText == "" {
|
||||||
|
qaResults[qi] = QAResult{
|
||||||
|
Question: qa.Question,
|
||||||
|
Category: qa.Category,
|
||||||
|
GoldAnswer: qa.AnswerString(),
|
||||||
|
TokenF1: 0.0,
|
||||||
|
HitRate: 0.0,
|
||||||
|
}
|
||||||
|
log.Printf("[seahorse-llm] sample=%s q=%d/%d score=0.00 answer=(no context)",
|
||||||
|
sample.SampleID, si*len(sample.QA)+qi+1, totalQA)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := evalQAWorker(ctx, qaWork{
|
||||||
|
sampleID: sample.SampleID, qaIndex: qi,
|
||||||
|
globalIndex: si*len(sample.QA) + qi + 1, totalQA: totalQA,
|
||||||
|
qa: qa, contextText: contextText, sample: sample,
|
||||||
|
}, answerClient, judgeClient, "seahorse-llm")
|
||||||
|
qaResults[qi] = out.result
|
||||||
|
}
|
||||||
|
|
||||||
|
if concurrency <= 1 {
|
||||||
|
for qi := range sample.QA {
|
||||||
|
evalOne(qi)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sem := make(chan struct{}, concurrency)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for qi := range sample.QA {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
sem <- struct{}{}
|
||||||
|
defer func() { <-sem }()
|
||||||
|
evalOne(qi)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, EvalResult{
|
||||||
|
Mode: "seahorse-llm",
|
||||||
|
SampleID: sample.SampleID,
|
||||||
|
QAResults: qaResults,
|
||||||
|
Agg: aggregateMetrics(qaResults),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTotalQA(samples []LocomoSample) int {
|
||||||
|
n := 0
|
||||||
|
for i := range samples {
|
||||||
|
n += len(samples[i].QA)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateStr(s string, maxLen int) string {
|
||||||
|
s = strings.ReplaceAll(s, "\n", " ")
|
||||||
|
runes := []rune(s)
|
||||||
|
if len(runes) > maxLen {
|
||||||
|
return string(runes[:maxLen]) + "..."
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
@@ -102,3 +102,81 @@ func TestComputeModeAgg(t *testing.T) {
|
|||||||
t.Errorf("TotalQuestions = %d, want 10", got.TotalQuestions)
|
t.Errorf("TotalQuestions = %d, want 10", got.TotalQuestions)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAggregateMetricsSentinel(t *testing.T) {
|
||||||
|
qa := []QAResult{
|
||||||
|
{Category: 1, TokenF1: 0.8, HitRate: 0.5},
|
||||||
|
{Category: 1, TokenF1: -1.0, HitRate: 0.3},
|
||||||
|
{Category: 1, TokenF1: 0.4, HitRate: 0.7},
|
||||||
|
}
|
||||||
|
agg := aggregateMetrics(qa)
|
||||||
|
|
||||||
|
if agg.ValidF1Count != 2 {
|
||||||
|
t.Errorf("ValidF1Count = %d, want 2", agg.ValidF1Count)
|
||||||
|
}
|
||||||
|
if agg.TotalQuestions != 3 {
|
||||||
|
t.Errorf("TotalQuestions = %d, want 3", agg.TotalQuestions)
|
||||||
|
}
|
||||||
|
wantF1 := (0.8 + 0.4) / 2.0
|
||||||
|
if math.Abs(agg.OverallF1-wantF1) > 1e-9 {
|
||||||
|
t.Errorf("OverallF1 = %.6f, want %.6f", agg.OverallF1, wantF1)
|
||||||
|
}
|
||||||
|
wantHR := (0.5 + 0.3 + 0.7) / 3.0
|
||||||
|
if math.Abs(agg.OverallHitRate-wantHR) > 1e-9 {
|
||||||
|
t.Errorf("OverallHitRate = %.6f, want %.6f", agg.OverallHitRate, wantHR)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAggregateMetricsAllSentinel(t *testing.T) {
|
||||||
|
qa := []QAResult{
|
||||||
|
{Category: 1, TokenF1: -1.0, HitRate: 0.5},
|
||||||
|
{Category: 1, TokenF1: -1.0, HitRate: 0.3},
|
||||||
|
}
|
||||||
|
agg := aggregateMetrics(qa)
|
||||||
|
|
||||||
|
if agg.ValidF1Count != 0 {
|
||||||
|
t.Errorf("ValidF1Count = %d, want 0", agg.ValidF1Count)
|
||||||
|
}
|
||||||
|
if agg.OverallF1 != 0 {
|
||||||
|
t.Errorf("OverallF1 = %.6f, want 0", agg.OverallF1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComputeModeAggSentinelWeighting(t *testing.T) {
|
||||||
|
results := []EvalResult{
|
||||||
|
{
|
||||||
|
Mode: "test",
|
||||||
|
SampleID: "s1",
|
||||||
|
QAResults: []QAResult{
|
||||||
|
{Category: 1, TokenF1: 0.8, HitRate: 0.5},
|
||||||
|
{Category: 1, TokenF1: -1.0, HitRate: 0.3},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Mode: "test",
|
||||||
|
SampleID: "s2",
|
||||||
|
QAResults: []QAResult{
|
||||||
|
{Category: 1, TokenF1: 0.4, HitRate: 0.6},
|
||||||
|
{Category: 1, TokenF1: 0.6, HitRate: 0.8},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i := range results {
|
||||||
|
results[i].Agg = aggregateMetrics(results[i].QAResults)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := computeModeAgg(results)
|
||||||
|
|
||||||
|
// s1: ValidF1Count=1, F1=0.8; s2: ValidF1Count=2, F1=0.5
|
||||||
|
// Weighted: (0.8*1 + 0.5*2) / 3 = 1.8/3 = 0.6
|
||||||
|
wantF1 := 0.6
|
||||||
|
if math.Abs(got.OverallF1-wantF1) > 1e-9 {
|
||||||
|
t.Errorf("OverallF1 = %.6f, want %.6f", got.OverallF1, wantF1)
|
||||||
|
}
|
||||||
|
if got.ValidF1Count != 3 {
|
||||||
|
t.Errorf("ValidF1Count = %d, want 3", got.ValidF1Count)
|
||||||
|
}
|
||||||
|
if got.TotalQuestions != 4 {
|
||||||
|
t.Errorf("TotalQuestions = %d, want 4", got.TotalQuestions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LLMClient wraps an OpenAI-compatible chat completion endpoint.
|
||||||
|
type LLMClient struct {
|
||||||
|
BaseURL string
|
||||||
|
Model string
|
||||||
|
APIKey string
|
||||||
|
NoThinking bool // send chat_template_kwargs to disable thinking (llama.cpp specific)
|
||||||
|
MaxRetries int // max retry attempts for transient errors (0 = no retry)
|
||||||
|
Client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// LLMClientOptions configures the LLM client.
|
||||||
|
type LLMClientOptions struct {
|
||||||
|
BaseURL string
|
||||||
|
Model string
|
||||||
|
APIKey string
|
||||||
|
Timeout time.Duration
|
||||||
|
NoThinking bool
|
||||||
|
MaxRetries int // max retry attempts (default 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLLMClient creates a client for an OpenAI-compatible chat completion API.
|
||||||
|
func NewLLMClient(opts LLMClientOptions) *LLMClient {
|
||||||
|
if opts.Timeout == 0 {
|
||||||
|
opts.Timeout = 120 * time.Second
|
||||||
|
}
|
||||||
|
maxRetries := opts.MaxRetries
|
||||||
|
if maxRetries < 0 {
|
||||||
|
maxRetries = 3
|
||||||
|
}
|
||||||
|
return &LLMClient{
|
||||||
|
BaseURL: strings.TrimRight(opts.BaseURL, "/"),
|
||||||
|
Model: opts.Model,
|
||||||
|
APIKey: opts.APIKey,
|
||||||
|
NoThinking: opts.NoThinking,
|
||||||
|
MaxRetries: maxRetries,
|
||||||
|
Client: &http.Client{
|
||||||
|
Timeout: opts.Timeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []chatMessage `json:"messages"`
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
ChatTemplateKwargs map[string]any `json:"chat_template_kwargs,omitempty"` // llama.cpp
|
||||||
|
Think *bool `json:"think,omitempty"` // Ollama
|
||||||
|
Thinking map[string]any `json:"thinking,omitempty"` // GLM (智谱)
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type chatResponse struct {
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete sends a chat completion request and returns the assistant's reply.
|
||||||
|
func (c *LLMClient) Complete(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||||
|
sysContent := systemPrompt
|
||||||
|
if c.NoThinking && sysContent != "" {
|
||||||
|
// Prepend /no_think tag — works with Ollama /v1 endpoint and
|
||||||
|
// Qwen chat templates where the JSON think field is ignored.
|
||||||
|
sysContent = "/no_think\n" + sysContent
|
||||||
|
}
|
||||||
|
messages := []chatMessage{}
|
||||||
|
if sysContent != "" {
|
||||||
|
messages = append(messages, chatMessage{Role: "system", Content: sysContent})
|
||||||
|
}
|
||||||
|
messages = append(messages, chatMessage{Role: "user", Content: userPrompt})
|
||||||
|
|
||||||
|
body := chatRequest{
|
||||||
|
Model: c.Model,
|
||||||
|
Messages: messages,
|
||||||
|
Temperature: 0.1,
|
||||||
|
MaxTokens: 512,
|
||||||
|
}
|
||||||
|
if c.NoThinking {
|
||||||
|
// llama.cpp: chat_template_kwargs
|
||||||
|
body.ChatTemplateKwargs = map[string]any{
|
||||||
|
"enable_thinking": false,
|
||||||
|
}
|
||||||
|
// Ollama (0.9+): think field
|
||||||
|
thinkFalse := false
|
||||||
|
body.Think = &thinkFalse
|
||||||
|
// GLM (智谱): thinking field
|
||||||
|
body.Thinking = map[string]any{
|
||||||
|
"type": "disabled",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBody, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := strings.TrimRight(c.BaseURL, "/") + "/chat/completions"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if c.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
var respBody []byte
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 0; attempt <= c.MaxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
backoff := time.Duration(1<<(attempt-1)) * time.Second // 1s, 2s, 4s, ...
|
||||||
|
log.Printf("LLM retry %d/%d after %v: %v", attempt, c.MaxRetries, backoff, lastErr)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return "", ctx.Err()
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
// Rebuild request (body reader is consumed)
|
||||||
|
req, err = http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if c.APIKey != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp *http.Response
|
||||||
|
resp, lastErr = c.Client.Do(req)
|
||||||
|
if lastErr != nil {
|
||||||
|
continue // network/timeout error → retry
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, lastErr = io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if lastErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == 429 || resp.StatusCode >= 500 {
|
||||||
|
lastErr = fmt.Errorf("API error %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
continue // rate limit or server error → retry
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return "", fmt.Errorf("API error %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = nil
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if lastErr != nil {
|
||||||
|
return "", fmt.Errorf("after %d retries: %w", c.MaxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chatResp chatResponse
|
||||||
|
if err := json.Unmarshal(respBody, &chatResp); err != nil {
|
||||||
|
return "", fmt.Errorf("parse response: %w", err)
|
||||||
|
}
|
||||||
|
if len(chatResp.Choices) == 0 {
|
||||||
|
return "", fmt.Errorf("no choices in response")
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(chatResp.Choices[0].Message.Content)
|
||||||
|
// Strip any residual <think>...</think> blocks
|
||||||
|
if idx := strings.Index(content, "</think>"); idx >= 0 {
|
||||||
|
content = strings.TrimSpace(content[idx+len("</think>"):])
|
||||||
|
}
|
||||||
|
// Fallback: GLM/DeepSeek put thinking output in reasoning_content when thinking is enabled
|
||||||
|
if content == "" && chatResp.Choices[0].Message.ReasoningContent != "" {
|
||||||
|
content = strings.TrimSpace(chatResp.Choices[0].Message.ReasoningContent)
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
return "", fmt.Errorf("empty LLM response")
|
||||||
|
}
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
+166
-13
@@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
@@ -15,10 +16,22 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
flagData string
|
flagData string
|
||||||
flagOut string
|
flagOut string
|
||||||
flagMode string
|
flagMode string
|
||||||
flagBudget int
|
flagBudget int
|
||||||
|
flagEvalMode string
|
||||||
|
flagAPIBase string
|
||||||
|
flagAPIKey string
|
||||||
|
flagModel string
|
||||||
|
flagNoThinking bool
|
||||||
|
flagLimit int
|
||||||
|
flagTimeout int
|
||||||
|
flagRetries int
|
||||||
|
flagJudgeModel string
|
||||||
|
flagJudgeAPIBase string
|
||||||
|
flagJudgeAPIKey string
|
||||||
|
flagConcurrency int
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -48,6 +61,22 @@ func main() {
|
|||||||
evalCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
evalCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
||||||
evalCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to evaluate: legacy, seahorse, or all")
|
evalCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to evaluate: legacy, seahorse, or all")
|
||||||
evalCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
|
evalCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
|
||||||
|
evalCmd.Flags().
|
||||||
|
StringVar(&flagEvalMode, "eval-mode", "token", "evaluation mode: token (direct match) or llm (LLM-as-Judge)")
|
||||||
|
evalCmd.Flags().
|
||||||
|
StringVar(&flagAPIBase, "api-base", "", "API base URL with version path, e.g. http://host/v1 (default: http://127.0.0.1:8080/v1, env: MEMBENCH_API_BASE)")
|
||||||
|
evalCmd.Flags().StringVar(&flagAPIKey, "api-key", "", "API key for the LLM endpoint (env: MEMBENCH_API_KEY)")
|
||||||
|
evalCmd.Flags().StringVar(&flagModel, "model", "", "model name for LLM eval (env: MEMBENCH_MODEL)")
|
||||||
|
evalCmd.Flags().
|
||||||
|
BoolVar(&flagNoThinking, "no-thinking", false, "disable thinking mode via chat_template_kwargs (llama.cpp + Qwen)")
|
||||||
|
evalCmd.Flags().IntVar(&flagLimit, "limit", 0, "max QA questions per sample (0 = all)")
|
||||||
|
evalCmd.Flags().IntVar(&flagTimeout, "timeout", 120, "HTTP timeout in seconds for LLM requests")
|
||||||
|
evalCmd.Flags().IntVar(&flagRetries, "retries", 3, "max retry attempts for transient LLM errors (timeout/5xx/429)")
|
||||||
|
evalCmd.Flags().StringVar(&flagJudgeModel, "judge-model", "", "model for judge scoring (defaults to --model)")
|
||||||
|
evalCmd.Flags().
|
||||||
|
StringVar(&flagJudgeAPIBase, "judge-api-base", "", "API base URL for judge model (defaults to --api-base)")
|
||||||
|
evalCmd.Flags().StringVar(&flagJudgeAPIKey, "judge-api-key", "", "API key for judge model (defaults to --api-key)")
|
||||||
|
evalCmd.Flags().IntVar(&flagConcurrency, "concurrency", 1, "number of concurrent QA evaluations")
|
||||||
|
|
||||||
reportCmd := &cobra.Command{
|
reportCmd := &cobra.Command{
|
||||||
Use: "report",
|
Use: "report",
|
||||||
@@ -65,6 +94,22 @@ func main() {
|
|||||||
runCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
runCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
||||||
runCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to run: legacy, seahorse, or all")
|
runCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to run: legacy, seahorse, or all")
|
||||||
runCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
|
runCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
|
||||||
|
runCmd.Flags().
|
||||||
|
StringVar(&flagEvalMode, "eval-mode", "token", "evaluation mode: token (direct match) or llm (LLM-as-Judge)")
|
||||||
|
runCmd.Flags().
|
||||||
|
StringVar(&flagAPIBase, "api-base", "", "API base URL with version path, e.g. http://host/v1 (default: http://127.0.0.1:8080/v1, env: MEMBENCH_API_BASE)")
|
||||||
|
runCmd.Flags().StringVar(&flagAPIKey, "api-key", "", "API key for the LLM endpoint (env: MEMBENCH_API_KEY)")
|
||||||
|
runCmd.Flags().StringVar(&flagModel, "model", "", "model name for LLM eval (env: MEMBENCH_MODEL)")
|
||||||
|
runCmd.Flags().
|
||||||
|
BoolVar(&flagNoThinking, "no-thinking", false, "disable thinking mode via chat_template_kwargs (llama.cpp + Qwen)")
|
||||||
|
runCmd.Flags().IntVar(&flagLimit, "limit", 0, "max QA questions per sample (0 = all)")
|
||||||
|
runCmd.Flags().IntVar(&flagTimeout, "timeout", 120, "HTTP timeout in seconds for LLM requests")
|
||||||
|
runCmd.Flags().IntVar(&flagRetries, "retries", 3, "max retry attempts for transient LLM errors (timeout/5xx/429)")
|
||||||
|
runCmd.Flags().StringVar(&flagJudgeModel, "judge-model", "", "model for judge scoring (defaults to --model)")
|
||||||
|
runCmd.Flags().
|
||||||
|
StringVar(&flagJudgeAPIBase, "judge-api-base", "", "API base URL for judge model (defaults to --api-base)")
|
||||||
|
runCmd.Flags().StringVar(&flagJudgeAPIKey, "judge-api-key", "", "API key for judge model (defaults to --api-key)")
|
||||||
|
runCmd.Flags().IntVar(&flagConcurrency, "concurrency", 1, "number of concurrent QA evaluations")
|
||||||
|
|
||||||
rootCmd.AddCommand(ingestCmd, evalCmd, reportCmd, runCmd)
|
rootCmd.AddCommand(ingestCmd, evalCmd, reportCmd, runCmd)
|
||||||
|
|
||||||
@@ -136,7 +181,50 @@ func runEval(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
log.Printf("Loaded %d samples", len(samples))
|
log.Printf("Loaded %d samples", len(samples))
|
||||||
|
|
||||||
var allResults []EvalResult
|
if flagLimit > 0 {
|
||||||
|
for i := range samples {
|
||||||
|
if len(samples[i].QA) > flagLimit {
|
||||||
|
samples[i].QA = samples[i].QA[:flagLimit]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Printf("Limited to %d QA per sample", flagLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
evalMode := strings.ToLower(strings.TrimSpace(flagEvalMode))
|
||||||
|
var useLLM bool
|
||||||
|
switch evalMode {
|
||||||
|
case "token":
|
||||||
|
useLLM = false
|
||||||
|
case "llm":
|
||||||
|
useLLM = true
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("invalid --eval-mode %q: must be token or llm", flagEvalMode)
|
||||||
|
}
|
||||||
|
var answerClient, judgeClient *LLMClient
|
||||||
|
if useLLM {
|
||||||
|
opts, err := buildLLMOptions()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
answerClient = NewLLMClient(opts)
|
||||||
|
judgeClient = answerClient // default: same client
|
||||||
|
if flagJudgeModel != "" {
|
||||||
|
jOpts := opts // copy base settings
|
||||||
|
jOpts.Model = flagJudgeModel
|
||||||
|
if flagJudgeAPIBase != "" {
|
||||||
|
jOpts.BaseURL = flagJudgeAPIBase
|
||||||
|
}
|
||||||
|
if flagJudgeAPIKey != "" {
|
||||||
|
jOpts.APIKey = flagJudgeAPIKey
|
||||||
|
}
|
||||||
|
judgeClient = NewLLMClient(jOpts)
|
||||||
|
log.Printf("Judge model: model=%s base=%s no-thinking=%v", jOpts.Model, jOpts.BaseURL, jOpts.NoThinking)
|
||||||
|
}
|
||||||
|
log.Printf("LLM eval mode: model=%s base=%s no-thinking=%v concurrency=%d",
|
||||||
|
opts.Model, opts.BaseURL, opts.NoThinking, flagConcurrency)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResults, llmResults []EvalResult
|
||||||
|
|
||||||
for _, mode := range modes {
|
for _, mode := range modes {
|
||||||
switch mode {
|
switch mode {
|
||||||
@@ -145,21 +233,34 @@ func runEval(cmd *cobra.Command, args []string) error {
|
|||||||
for i := range samples {
|
for i := range samples {
|
||||||
legacy.IngestSample(&samples[i])
|
legacy.IngestSample(&samples[i])
|
||||||
}
|
}
|
||||||
results := EvalLegacy(ctx, samples, legacy, flagBudget)
|
if useLLM {
|
||||||
allResults = append(allResults, results...)
|
results := EvalLegacyLLM(ctx, samples, legacy, flagBudget, answerClient, judgeClient, flagConcurrency)
|
||||||
log.Printf("legacy: evaluated %d samples", len(results))
|
llmResults = append(llmResults, results...)
|
||||||
|
log.Printf("legacy-llm: evaluated %d samples", len(results))
|
||||||
|
} else {
|
||||||
|
results := EvalLegacy(ctx, samples, legacy, flagBudget)
|
||||||
|
tokenResults = append(tokenResults, results...)
|
||||||
|
log.Printf("legacy: evaluated %d samples", len(results))
|
||||||
|
}
|
||||||
case "seahorse":
|
case "seahorse":
|
||||||
dbPath := filepath.Join(flagOut, "seahorse.db")
|
dbPath := filepath.Join(flagOut, "seahorse.db")
|
||||||
ir, err := IngestSeahorse(ctx, samples, dbPath)
|
ir, err := IngestSeahorse(ctx, samples, dbPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ingest seahorse: %w", err)
|
return fmt.Errorf("ingest seahorse: %w", err)
|
||||||
}
|
}
|
||||||
results := EvalSeahorse(ctx, samples, ir, flagBudget)
|
if useLLM {
|
||||||
allResults = append(allResults, results...)
|
results := EvalSeahorseLLM(ctx, samples, ir, flagBudget, answerClient, judgeClient, flagConcurrency)
|
||||||
log.Printf("seahorse: evaluated %d samples", len(results))
|
llmResults = append(llmResults, results...)
|
||||||
|
log.Printf("seahorse-llm: evaluated %d samples", len(results))
|
||||||
|
} else {
|
||||||
|
results := EvalSeahorse(ctx, samples, ir, flagBudget)
|
||||||
|
tokenResults = append(tokenResults, results...)
|
||||||
|
log.Printf("seahorse: evaluated %d samples", len(results))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allResults := append(tokenResults, llmResults...)
|
||||||
if err := SaveResults(allResults, flagOut); err != nil {
|
if err := SaveResults(allResults, flagOut); err != nil {
|
||||||
return fmt.Errorf("save results: %w", err)
|
return fmt.Errorf("save results: %w", err)
|
||||||
}
|
}
|
||||||
@@ -167,7 +268,7 @@ func runEval(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("save aggregated: %w", err)
|
return fmt.Errorf("save aggregated: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
PrintComparison(allResults, nil)
|
PrintComparison(tokenResults, llmResults)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,10 +300,62 @@ func runReport(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("no eval results found in %s", flagOut)
|
return fmt.Errorf("no eval results found in %s", flagOut)
|
||||||
}
|
}
|
||||||
|
|
||||||
PrintComparison(allResults, nil)
|
var tokenResults, llmResults []EvalResult
|
||||||
|
for _, r := range allResults {
|
||||||
|
if strings.HasSuffix(r.Mode, "-llm") {
|
||||||
|
llmResults = append(llmResults, r)
|
||||||
|
} else {
|
||||||
|
tokenResults = append(tokenResults, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PrintComparison(tokenResults, llmResults)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func runAll(cmd *cobra.Command, args []string) error {
|
func runAll(cmd *cobra.Command, args []string) error {
|
||||||
return runEval(cmd, args)
|
return runEval(cmd, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// envOrFlag returns the flag value if non-empty, otherwise falls back to the
|
||||||
|
// environment variable.
|
||||||
|
func envOrFlag(flag, envKey string) string {
|
||||||
|
if flag != "" {
|
||||||
|
return flag
|
||||||
|
}
|
||||||
|
return os.Getenv(envKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLLMOptions resolves LLM client configuration from flags and environment
|
||||||
|
// variables. Flag values take precedence over environment variables.
|
||||||
|
//
|
||||||
|
// Environment variables:
|
||||||
|
//
|
||||||
|
// MEMBENCH_API_BASE – OpenAI-compatible base URL (default http://127.0.0.1:8080/v1)
|
||||||
|
// MEMBENCH_API_KEY – Bearer token for the endpoint
|
||||||
|
// MEMBENCH_MODEL – Model name to send in the request
|
||||||
|
func buildLLMOptions() (LLMClientOptions, error) {
|
||||||
|
base := envOrFlag(flagAPIBase, "MEMBENCH_API_BASE")
|
||||||
|
if base == "" {
|
||||||
|
base = "http://127.0.0.1:8080/v1"
|
||||||
|
}
|
||||||
|
model := envOrFlag(flagModel, "MEMBENCH_MODEL")
|
||||||
|
if model == "" {
|
||||||
|
return LLMClientOptions{}, fmt.Errorf(
|
||||||
|
"--model or MEMBENCH_MODEL is required for LLM eval mode",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
apiKey := envOrFlag(flagAPIKey, "MEMBENCH_API_KEY")
|
||||||
|
|
||||||
|
if flagTimeout <= 0 {
|
||||||
|
return LLMClientOptions{}, fmt.Errorf("--timeout must be > 0, got %d", flagTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
return LLMClientOptions{
|
||||||
|
BaseURL: base,
|
||||||
|
Model: model,
|
||||||
|
APIKey: apiKey,
|
||||||
|
NoThinking: flagNoThinking,
|
||||||
|
Timeout: time.Duration(flagTimeout) * time.Second,
|
||||||
|
MaxRetries: flagRetries,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user