mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
f1b659e5ef
* membench: add LLM-as-Judge evaluation mode Add --eval-mode=llm to membench for LLM-based answer generation and semantic scoring via an OpenAI-compatible API endpoint. New files: - llm_client.go: generic OpenAI-compatible chat completion client with support for API key, configurable timeout, and optional chat_template_kwargs (for llama.cpp thinking models) - eval_llm.go: LLM answer generation + LLM-as-Judge scoring for both legacy and seahorse retrieval modes Changes to main.go: - --eval-mode flag (token|llm) to select evaluation strategy - --api-base, --api-key, --model flags with env var fallback (MEMBENCH_API_BASE, MEMBENCH_API_KEY, MEMBENCH_MODEL) - --no-thinking flag for llama.cpp + Qwen thinking models - --limit flag to cap QA questions per sample for quick testing * style: fix golangci-lint formatting (gofmt + golines) * fix: address Copilot review feedback - Validate --model is required for LLM eval mode - Use rune-based truncation to preserve valid UTF-8 - Precompute totalQA count outside inner loop - Log SearchMessages errors instead of silently skipping * fix: address Copilot review round 2 - Validate --eval-mode accepts only 'token' or 'llm' - Normalize base URL to avoid /v1/v1 duplication - Separate token/LLM results for correct PrintComparison labeling - Log ExpandMessages errors instead of silently ignoring - Short-circuit with 0 scores when no context retrieved (match token eval) - Add --timeout flag wired to LLMClientOptions.Timeout * fix: address review P1+P2 — sort alignment, failure sentinel, score parser - P1: Replace hand-rolled sortByRank with sort.Slice (ascending, best first) matching eval.go's EvalSeahorse — ensures BudgetTruncate keeps best-ranked messages when truncation occurs - P2: Use -1.0 sentinel for LLM API failures and parse errors, distinct from genuine 0.0 score; aggregateMetrics skips -1.0 entries for F1 averaging while still counting HitRate - P2: Use regexp \b([1-5])\b for judge score extraction instead of first-digit scan — 
avoids misparses on '5/5', 'Score: 3' etc. * fix: address Copilot review round 2 - Fix F1/HitRate weighted aggregation: track ValidF1Count separately so computeModeAgg weights F1 by valid scores only, not TotalQuestions - No-context retrieval failure uses 0.0 (genuine bad score) instead of -1.0 sentinel (reserved for API/parse failures) - Validate --timeout > 0 to prevent disabling HTTP timeouts * fix: remove hardcoded /v1 from API base URL Users now provide the full versioned path in --api-base (e.g. /v1, /v4). Code only appends /chat/completions. Default changed to http://127.0.0.1:8080/v1 for backward compatibility. * fix: address Copilot review round 3 - ValidF1Count=0 when all scores are sentinel (no forced =1) - Backward compat: old eval JSON without ValidF1Count falls back to TotalQuestions in computeModeAgg - Skip empty section in PrintComparison when tokenResults is empty - Update --api-base flag help to document /v1 default and version path - Add sentinel aggregation unit tests (partial, all, weighted) * feat: add --retries flag with exponential backoff for transient LLM errors Retry on timeout, 5xx, and 429 (rate limit) with 1s/2s/4s backoff. Default 3 retries, configurable via --retries. Context cancellation is respected between retries. 
* fix: address Copilot review round 4 - runReport splits results by mode suffix into token/llm for PrintComparison - backward compat fallback (ValidF1Count=0 -> TotalQuestions) only for non-LLM modes; LLM modes keep ValidF1Count=0 when all scores sentinel - MaxRetries==0 means no retry; only negative falls back to default 3 - truncateStr uses []rune to avoid cutting multi-byte UTF-8 characters - Complete() returns error on empty LLM response (vs silent empty string) * feat: --no-thinking adapts to llama.cpp, Ollama, and GLM backends Send all three disable-thinking fields simultaneously: - chat_template_kwargs.enable_thinking=false (llama.cpp, GLM) - think=false (Ollama 0.9+) - thinking.type=disabled (GLM/Zhipu) Each backend picks the field it recognizes and ignores the rest. Also bumps max_tokens from 512 to 2048 for thinking models. * feat: mixed model eval + concurrent QA workers - Add --judge-model, --judge-api-base, --judge-api-key flags for separate judge model - Add --concurrency flag (default 1) with semaphore-based goroutine pool - Add reasoning_content fallback for GLM/DeepSeek style responses - Prepend /no_think to system prompt for Ollama /v1 compatibility - Reduce default MaxTokens from 2048 to 512 (answers are 1-3 sentences) - Extract evalQAWorker and buildSeahorseContext for shared concurrent logic --------- Co-authored-by: BeaconCat <BeaconCat@users.noreply.github.com>
413 lines
11 KiB
Go
413 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/seahorse"
|
|
)
|
|
|
|
// EvalResult holds per-sample evaluation results for one mode.
|
|
type EvalResult struct {
|
|
Mode string `json:"mode"`
|
|
SampleID string `json:"sampleId"`
|
|
QAResults []QAResult `json:"qaResults"`
|
|
Agg AggMetrics `json:"aggregated"`
|
|
}
|
|
|
|
// QAResult records the scores obtained for one question/answer pair.
type QAResult struct {
	// Question is the question text.
	Question string `json:"question"`
	// Category is the question's category number; metrics are also
	// aggregated per category.
	Category int `json:"category"`
	// GoldAnswer is the reference answer the score was computed against.
	GoldAnswer string `json:"goldAnswer"`
	// TokenF1 is the F1 score for this pair. aggregateMetrics treats a
	// negative value as a failure sentinel and leaves it out of F1 averages.
	TokenF1 float64 `json:"tokenF1"`
	// HitRate is the evidence recall score for this pair.
	HitRate float64 `json:"hitRate"`
}
|
|
|
|
// AggMetrics holds aggregated evaluation metrics.
|
|
type AggMetrics struct {
|
|
OverallF1 float64 `json:"overallF1"`
|
|
OverallHitRate float64 `json:"overallHitRate"`
|
|
ByCategory map[int]*CatMetrics `json:"byCategory"`
|
|
TotalQuestions int `json:"totalQuestions"`
|
|
ValidF1Count int `json:"validF1Count"`
|
|
}
|
|
|
|
// CatMetrics is the roll-up of the QA results that share one category.
type CatMetrics struct {
	// F1 is the mean F1 over the results counted in ValidF1Count.
	F1 float64 `json:"f1"`
	// HitRate is the mean hit rate over QuestionCount.
	HitRate float64 `json:"hitRate"`
	// QuestionCount is the number of QA results in the category.
	QuestionCount int `json:"questionCount"`
	// ValidF1Count is the number of results whose F1 entered the average
	// (failure-sentinel scores are excluded).
	ValidF1Count int `json:"validF1Count"`
}
|
|
|
|
// EvalLegacy evaluates using legacy session store (raw history + budget truncation).
|
|
func EvalLegacy(
|
|
ctx context.Context,
|
|
samples []LocomoSample,
|
|
legacy *LegacyStore,
|
|
budgetTokens int,
|
|
) []EvalResult {
|
|
results := make([]EvalResult, 0, len(samples))
|
|
for si := range samples {
|
|
sample := &samples[si]
|
|
history := legacy.GetHistory(sample.SampleID)
|
|
|
|
// Convert messages to content strings
|
|
allContent := make([]string, 0, len(history))
|
|
for _, msg := range history {
|
|
allContent = append(allContent, msg.Content)
|
|
}
|
|
|
|
qaResults := make([]QAResult, 0, len(sample.QA))
|
|
for qi := range sample.QA {
|
|
qa := &sample.QA[qi]
|
|
// Budget truncate the full history
|
|
truncated, _ := BudgetTruncate(allContent, budgetTokens)
|
|
context := StringListToContent(truncated)
|
|
|
|
f1 := TokenOverlapF1(context, qa.AnswerString())
|
|
hitRate := RecallHitRate(qa.Evidence, sample, context)
|
|
|
|
qaResults = append(qaResults, QAResult{
|
|
Question: qa.Question,
|
|
Category: qa.Category,
|
|
GoldAnswer: qa.AnswerString(),
|
|
TokenF1: f1,
|
|
HitRate: hitRate,
|
|
})
|
|
}
|
|
|
|
results = append(results, EvalResult{
|
|
Mode: "legacy",
|
|
SampleID: sample.SampleID,
|
|
QAResults: qaResults,
|
|
Agg: aggregateMetrics(qaResults),
|
|
})
|
|
}
|
|
return results
|
|
}
|
|
|
|
// EvalSeahorse evaluates using seahorse short memory (per-keyword search + expand).
|
|
func EvalSeahorse(
|
|
ctx context.Context,
|
|
samples []LocomoSample,
|
|
ir *SeahorseIngestResult,
|
|
budgetTokens int,
|
|
) []EvalResult {
|
|
store := ir.Engine.GetRetrieval().Store()
|
|
retrieval := ir.Engine.GetRetrieval()
|
|
|
|
results := make([]EvalResult, 0, len(samples))
|
|
for si := range samples {
|
|
sample := &samples[si]
|
|
convID, ok := ir.ConvMap[sample.SampleID]
|
|
if !ok {
|
|
log.Printf("WARN: no conversation ID for sample %s", sample.SampleID)
|
|
continue
|
|
}
|
|
|
|
qaResults := make([]QAResult, 0, len(sample.QA))
|
|
for qi := range sample.QA {
|
|
qa := &sample.QA[qi]
|
|
keywords := ExtractKeywords(qa.Question)
|
|
|
|
// Search each keyword individually and union results,
|
|
// tracking best BM25 rank per message for relevance sorting.
|
|
bestRank := map[int64]float64{}
|
|
for _, kw := range keywords {
|
|
searchResults, err := store.SearchMessages(ctx, seahorse.SearchInput{
|
|
Pattern: kw,
|
|
ConversationID: convID,
|
|
Limit: 20,
|
|
})
|
|
if err != nil {
|
|
log.Printf("WARN: search failed for keyword %q: %v", kw, err)
|
|
continue
|
|
}
|
|
for _, sr := range searchResults {
|
|
if sr.MessageID > 0 {
|
|
if prev, ok := bestRank[sr.MessageID]; !ok || sr.Rank < prev {
|
|
bestRank[sr.MessageID] = sr.Rank
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Sort messageIDs by rank ascending (best/most-negative first).
|
|
// BudgetTruncate walks from the front, keeping best-ranked messages.
|
|
// Note: SQLite FTS5 bm25() returns negative values where more
|
|
// negative = better match.
|
|
messageIDs := make([]int64, 0, len(bestRank))
|
|
for id := range bestRank {
|
|
messageIDs = append(messageIDs, id)
|
|
}
|
|
sort.Slice(messageIDs, func(i, j int) bool {
|
|
return bestRank[messageIDs[i]] < bestRank[messageIDs[j]]
|
|
})
|
|
|
|
// Expand messages to get full content
|
|
var contentParts []string
|
|
if len(messageIDs) > 0 {
|
|
expandResult, err := retrieval.ExpandMessages(ctx, messageIDs)
|
|
if err != nil {
|
|
log.Printf("WARN: expand failed for sample %s: %v", sample.SampleID, err)
|
|
} else {
|
|
for _, msg := range expandResult.Messages {
|
|
contentParts = append(contentParts, msg.Content)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(contentParts) == 0 {
|
|
qaResults = append(qaResults, QAResult{
|
|
Question: qa.Question,
|
|
Category: qa.Category,
|
|
GoldAnswer: qa.AnswerString(),
|
|
TokenF1: 0.0,
|
|
HitRate: 0.0,
|
|
})
|
|
continue
|
|
}
|
|
|
|
// Budget truncate (drop worst-ranked)
|
|
truncated, _ := BudgetTruncate(contentParts, budgetTokens)
|
|
context := StringListToContent(truncated)
|
|
|
|
f1 := TokenOverlapF1(context, qa.AnswerString())
|
|
hitRate := RecallHitRate(qa.Evidence, sample, context)
|
|
|
|
qaResults = append(qaResults, QAResult{
|
|
Question: qa.Question,
|
|
Category: qa.Category,
|
|
GoldAnswer: qa.AnswerString(),
|
|
TokenF1: f1,
|
|
HitRate: hitRate,
|
|
})
|
|
}
|
|
|
|
results = append(results, EvalResult{
|
|
Mode: "seahorse",
|
|
SampleID: sample.SampleID,
|
|
QAResults: qaResults,
|
|
Agg: aggregateMetrics(qaResults),
|
|
})
|
|
}
|
|
return results
|
|
}
|
|
|
|
// aggregateMetrics computes overall and per-category metrics.
|
|
func aggregateMetrics(qaResults []QAResult) AggMetrics {
|
|
type catAccum struct {
|
|
f1Sum float64
|
|
f1Count int
|
|
hitRateSum float64
|
|
hitRateCount int
|
|
}
|
|
byCatAcc := map[int]*catAccum{}
|
|
totalF1 := 0.0
|
|
totalHitRate := 0.0
|
|
validF1Count := 0
|
|
for _, qr := range qaResults {
|
|
// Skip sentinel -1.0 scores (LLM API/parse failures) from F1 averaging.
|
|
if qr.TokenF1 >= 0 {
|
|
totalF1 += qr.TokenF1
|
|
validF1Count++
|
|
}
|
|
totalHitRate += qr.HitRate
|
|
acc, ok := byCatAcc[qr.Category]
|
|
if !ok {
|
|
acc = &catAccum{}
|
|
byCatAcc[qr.Category] = acc
|
|
}
|
|
if qr.TokenF1 >= 0 {
|
|
acc.f1Sum += qr.TokenF1
|
|
acc.f1Count++
|
|
}
|
|
acc.hitRateSum += qr.HitRate
|
|
acc.hitRateCount++
|
|
}
|
|
nHit := len(qaResults)
|
|
if nHit == 0 {
|
|
nHit = 1
|
|
}
|
|
byCat := map[int]*CatMetrics{}
|
|
for cat, acc := range byCatAcc {
|
|
cm := &CatMetrics{
|
|
QuestionCount: acc.hitRateCount,
|
|
ValidF1Count: acc.f1Count,
|
|
}
|
|
if acc.f1Count > 0 {
|
|
cm.F1 = acc.f1Sum / float64(acc.f1Count)
|
|
}
|
|
if acc.hitRateCount > 0 {
|
|
cm.HitRate = acc.hitRateSum / float64(acc.hitRateCount)
|
|
}
|
|
byCat[cat] = cm
|
|
}
|
|
var overallF1 float64
|
|
if validF1Count > 0 {
|
|
overallF1 = totalF1 / float64(validF1Count)
|
|
}
|
|
return AggMetrics{
|
|
OverallF1: overallF1,
|
|
OverallHitRate: totalHitRate / float64(nHit),
|
|
ByCategory: byCat,
|
|
TotalQuestions: len(qaResults),
|
|
ValidF1Count: validF1Count,
|
|
}
|
|
}
|
|
|
|
// SaveResults writes per-sample eval results to JSON files.
|
|
func SaveResults(results []EvalResult, outDir string) error {
|
|
if err := os.MkdirAll(outDir, 0o755); err != nil {
|
|
return fmt.Errorf("create output dir: %w", err)
|
|
}
|
|
for _, r := range results {
|
|
path := filepath.Join(outDir, fmt.Sprintf("eval_%s_%s.json", r.Mode, r.SampleID))
|
|
data, err := json.MarshalIndent(r, "", " ")
|
|
if err != nil {
|
|
return fmt.Errorf("marshal result: %w", err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0o644); err != nil {
|
|
return fmt.Errorf("write result: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SaveAggregated writes a combined results.json with all modes.
|
|
func SaveAggregated(results []EvalResult, outDir string) error {
|
|
byMode := map[string][]EvalResult{}
|
|
for _, r := range results {
|
|
byMode[r.Mode] = append(byMode[r.Mode], r)
|
|
}
|
|
|
|
aggMap := map[string]AggMetrics{}
|
|
for mode, modeResults := range byMode {
|
|
aggMap[mode] = computeModeAgg(modeResults)
|
|
}
|
|
|
|
data, err := json.MarshalIndent(aggMap, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return os.WriteFile(filepath.Join(outDir, "results.json"), data, 0o644)
|
|
}
|
|
|
|
// computeModeAgg aggregates results for a single mode using weighted averaging
|
|
// (weighted by question count per sample). All modes must have the same Mode field.
|
|
func computeModeAgg(results []EvalResult) AggMetrics {
|
|
agg := AggMetrics{ByCategory: map[int]*CatMetrics{}}
|
|
for _, r := range results {
|
|
// Backward compat: old eval JSON (token mode) without ValidF1Count → use TotalQuestions.
|
|
// LLM modes may legitimately have ValidF1Count==0 (all failures).
|
|
vf1 := r.Agg.ValidF1Count
|
|
if vf1 == 0 && r.Agg.TotalQuestions > 0 && !strings.HasSuffix(r.Mode, "-llm") {
|
|
vf1 = r.Agg.TotalQuestions
|
|
}
|
|
agg.OverallF1 += r.Agg.OverallF1 * float64(vf1)
|
|
agg.OverallHitRate += r.Agg.OverallHitRate * float64(r.Agg.TotalQuestions)
|
|
agg.TotalQuestions += r.Agg.TotalQuestions
|
|
agg.ValidF1Count += vf1
|
|
for cat, cm := range r.Agg.ByCategory {
|
|
existing, ok := agg.ByCategory[cat]
|
|
if !ok {
|
|
existing = &CatMetrics{}
|
|
agg.ByCategory[cat] = existing
|
|
}
|
|
cvf1 := cm.ValidF1Count
|
|
if cvf1 == 0 && cm.QuestionCount > 0 && !strings.HasSuffix(r.Mode, "-llm") {
|
|
cvf1 = cm.QuestionCount
|
|
}
|
|
existing.F1 += cm.F1 * float64(cvf1)
|
|
existing.HitRate += cm.HitRate * float64(cm.QuestionCount)
|
|
existing.QuestionCount += cm.QuestionCount
|
|
existing.ValidF1Count += cvf1
|
|
}
|
|
}
|
|
if agg.ValidF1Count > 0 {
|
|
agg.OverallF1 /= float64(agg.ValidF1Count)
|
|
}
|
|
if agg.TotalQuestions > 0 {
|
|
agg.OverallHitRate /= float64(agg.TotalQuestions)
|
|
}
|
|
for _, cat := range agg.ByCategory {
|
|
if cat.ValidF1Count > 0 {
|
|
cat.F1 /= float64(cat.ValidF1Count)
|
|
}
|
|
if cat.QuestionCount > 0 {
|
|
cat.HitRate /= float64(cat.QuestionCount)
|
|
}
|
|
}
|
|
return agg
|
|
}
|
|
|
|
// printSection prints a single comparison table section.
|
|
func printSection(title string, results []EvalResult) {
|
|
fmt.Printf("\n--- %s ---\n", title)
|
|
byMode := map[string][]EvalResult{}
|
|
for _, r := range results {
|
|
byMode[r.Mode] = append(byMode[r.Mode], r)
|
|
}
|
|
|
|
modes := map[string]AggMetrics{}
|
|
for mode, modeResults := range byMode {
|
|
modes[mode] = computeModeAgg(modeResults)
|
|
}
|
|
|
|
modeKeys := make([]string, 0, len(modes))
|
|
for k := range modes {
|
|
modeKeys = append(modeKeys, k)
|
|
}
|
|
sort.Strings(modeKeys)
|
|
|
|
// Collect all category keys across modes
|
|
catSet := map[int]bool{}
|
|
for _, agg := range modes {
|
|
for cat := range agg.ByCategory {
|
|
catSet[cat] = true
|
|
}
|
|
}
|
|
cats := make([]int, 0, len(catSet))
|
|
for cat := range catSet {
|
|
cats = append(cats, cat)
|
|
}
|
|
sort.Ints(cats)
|
|
|
|
fmt.Printf("%-10s %-8s %-8s", "Mode", "HitRate", "F1")
|
|
for _, cat := range cats {
|
|
fmt.Printf(" %-7s", fmt.Sprintf("C%d", cat))
|
|
}
|
|
fmt.Println()
|
|
fmt.Println(strings.Repeat("-", 10+8+8+7*len(cats)+8))
|
|
|
|
for _, mode := range modeKeys {
|
|
agg := modes[mode]
|
|
fmt.Printf("%-10s %-8.4f %-8.4f", mode, agg.OverallHitRate, agg.OverallF1)
|
|
for _, cat := range cats {
|
|
if cm, ok := agg.ByCategory[cat]; ok {
|
|
fmt.Printf(" %-7.4f", cm.HitRate)
|
|
} else {
|
|
fmt.Printf(" %-7s", "N/A")
|
|
}
|
|
}
|
|
fmt.Println()
|
|
}
|
|
}
|
|
|
|
// PrintComparison outputs a human-readable comparison table to stdout.
|
|
func PrintComparison(results []EvalResult, llmResults []EvalResult) {
|
|
if len(results) > 0 {
|
|
printSection("No LLM generation", results)
|
|
}
|
|
if len(llmResults) > 0 {
|
|
printSection("With LLM", llmResults)
|
|
}
|
|
}
|