mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
f1b659e5ef
* membench: add LLM-as-Judge evaluation mode

  Add --eval-mode=llm to membench for LLM-based answer generation and semantic scoring via an OpenAI-compatible API endpoint.

  New files:
  - llm_client.go: generic OpenAI-compatible chat completion client with support for API key, configurable timeout, and optional chat_template_kwargs (for llama.cpp thinking models)
  - eval_llm.go: LLM answer generation + LLM-as-Judge scoring for both legacy and seahorse retrieval modes

  Changes to main.go:
  - --eval-mode flag (token|llm) to select evaluation strategy
  - --api-base, --api-key, --model flags with env var fallback (MEMBENCH_API_BASE, MEMBENCH_API_KEY, MEMBENCH_MODEL)
  - --no-thinking flag for llama.cpp + Qwen thinking models
  - --limit flag to cap QA questions per sample for quick testing

* style: fix golangci-lint formatting (gofmt + golines)

* fix: address Copilot review feedback
  - Validate --model is required for LLM eval mode
  - Use rune-based truncation to preserve valid UTF-8
  - Precompute totalQA count outside inner loop
  - Log SearchMessages errors instead of silently skipping

* fix: address Copilot review round 2
  - Validate --eval-mode accepts only 'token' or 'llm'
  - Normalize base URL to avoid /v1/v1 duplication
  - Separate token/LLM results for correct PrintComparison labeling
  - Log ExpandMessages errors instead of silently ignoring
  - Short-circuit with 0 scores when no context retrieved (match token eval)
  - Add --timeout flag wired to LLMClientOptions.Timeout

* fix: address review P1+P2 — sort alignment, failure sentinel, score parser
  - P1: Replace hand-rolled sortByRank with sort.Slice (ascending, best first) matching eval.go's EvalSeahorse — ensures BudgetTruncate keeps best-ranked messages when truncation occurs
  - P2: Use -1.0 sentinel for LLM API failures and parse errors, distinct from genuine 0.0 score; aggregateMetrics skips -1.0 entries for F1 averaging while still counting HitRate
  - P2: Use regexp \b([1-5])\b for judge score extraction instead of first-digit scan — avoids misparses on '5/5', 'Score: 3' etc.

* fix: address Copilot review round 2
  - Fix F1/HitRate weighted aggregation: track ValidF1Count separately so computeModeAgg weights F1 by valid scores only, not TotalQuestions
  - No-context retrieval failure uses 0.0 (genuine bad score) instead of -1.0 sentinel (reserved for API/parse failures)
  - Validate --timeout > 0 to prevent disabling HTTP timeouts

* fix: remove hardcoded /v1 from API base URL

  Users now provide the full versioned path in --api-base (e.g. /v1, /v4). Code only appends /chat/completions. Default changed to http://127.0.0.1:8080/v1 for backward compatibility.

* fix: address Copilot review round 3
  - ValidF1Count=0 when all scores are sentinel (no forced =1)
  - Backward compat: old eval JSON without ValidF1Count falls back to TotalQuestions in computeModeAgg
  - Skip empty section in PrintComparison when tokenResults is empty
  - Update --api-base flag help to document /v1 default and version path
  - Add sentinel aggregation unit tests (partial, all, weighted)

* feat: add --retries flag with exponential backoff for transient LLM errors

  Retry on timeout, 5xx, and 429 (rate limit) with 1s/2s/4s backoff. Default 3 retries, configurable via --retries. Context cancellation is respected between retries.

* fix: address Copilot review round 4
  - runReport splits results by mode suffix into token/llm for PrintComparison
  - backward compat fallback (ValidF1Count=0 -> TotalQuestions) only for non-LLM modes; LLM modes keep ValidF1Count=0 when all scores sentinel
  - MaxRetries==0 means no retry; only negative falls back to default 3
  - truncateStr uses []rune to avoid cutting multi-byte UTF-8 characters
  - Complete() returns error on empty LLM response (vs silent empty string)

* feat: --no-thinking adapts to llama.cpp, Ollama, and GLM backends

  Send all three disable-thinking fields simultaneously:
  - chat_template_kwargs.enable_thinking=false (llama.cpp, GLM)
  - think=false (Ollama 0.9+)
  - thinking.type=disabled (GLM/Zhipu)

  Each backend picks the field it recognizes and ignores the rest. Also bumps max_tokens from 512 to 2048 for thinking models.

* feat: mixed model eval + concurrent QA workers
  - Add --judge-model, --judge-api-base, --judge-api-key flags for separate judge model
  - Add --concurrency flag (default 1) with semaphore-based goroutine pool
  - Add reasoning_content fallback for GLM/DeepSeek style responses
  - Prepend /no_think to system prompt for Ollama /v1 compatibility
  - Reduce default MaxTokens from 2048 to 512 (answers are 1-3 sentences)
  - Extract evalQAWorker and buildSeahorseContext for shared concurrent logic

---------

Co-authored-by: BeaconCat <BeaconCat@users.noreply.github.com>
183 lines
4.7 KiB
Go
183 lines
4.7 KiB
Go
package main
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
)
|
|
|
|
func TestComputeModeAggAllCategories(t *testing.T) {
|
|
results := []EvalResult{
|
|
{
|
|
Mode: "test",
|
|
SampleID: "s1",
|
|
QAResults: []QAResult{
|
|
{Category: 1, TokenF1: 0.5, HitRate: 0.8},
|
|
{Category: 2, TokenF1: 0.3, HitRate: 0.6},
|
|
{Category: 3, TokenF1: 0.1, HitRate: 0.4},
|
|
{Category: 4, TokenF1: 0.7, HitRate: 0.9},
|
|
{Category: 5, TokenF1: 0.2, HitRate: 0.1},
|
|
},
|
|
},
|
|
}
|
|
for i := range results {
|
|
results[i].Agg = aggregateMetrics(results[i].QAResults)
|
|
}
|
|
|
|
got := computeModeAgg(results)
|
|
|
|
// Should have all 5 categories
|
|
for cat := 1; cat <= 5; cat++ {
|
|
cm, ok := got.ByCategory[cat]
|
|
if !ok {
|
|
t.Errorf("ByCategory missing category %d", cat)
|
|
continue
|
|
}
|
|
if cm.QuestionCount != 1 {
|
|
t.Errorf("ByCategory[%d].QuestionCount = %d, want 1", cat, cm.QuestionCount)
|
|
}
|
|
}
|
|
|
|
// Verify specific F1 values per category
|
|
wantF1 := map[int]float64{1: 0.5, 2: 0.3, 3: 0.1, 4: 0.7, 5: 0.2}
|
|
for cat, want := range wantF1 {
|
|
if cm, ok := got.ByCategory[cat]; ok {
|
|
if math.Abs(cm.F1-want) > 1e-9 {
|
|
t.Errorf("ByCategory[%d].F1 = %.4f, want %.4f", cat, cm.F1, want)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestComputeModeAgg(t *testing.T) {
|
|
// Two samples with different question counts:
|
|
// sample-a: 2 questions, F1 = [0.4, 0.6] → avg 0.5
|
|
// sample-b: 8 questions, F1 = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] → avg 0.1
|
|
//
|
|
// Unweighted (PrintComparison bug): (0.5 + 0.1) / 2 = 0.3
|
|
// Weighted (correct): (0.4+0.6 + 0.1*8) / 10 = 1.8 / 10 = 0.18
|
|
results := []EvalResult{
|
|
{
|
|
Mode: "test",
|
|
SampleID: "sample-a",
|
|
QAResults: []QAResult{
|
|
{TokenF1: 0.4, HitRate: 0.5},
|
|
{TokenF1: 0.6, HitRate: 0.7},
|
|
},
|
|
},
|
|
{
|
|
Mode: "test",
|
|
SampleID: "sample-b",
|
|
QAResults: []QAResult{
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
{TokenF1: 0.1, HitRate: 0.2},
|
|
},
|
|
},
|
|
}
|
|
// Compute per-sample aggregates
|
|
for i := range results {
|
|
results[i].Agg = aggregateMetrics(results[i].QAResults)
|
|
}
|
|
|
|
got := computeModeAgg(results)
|
|
|
|
// Weighted: (0.4+0.6+0.1*8) / 10 = 1.8/10 = 0.18
|
|
wantF1 := 0.18
|
|
if math.Abs(got.OverallF1-wantF1) > 1e-9 {
|
|
t.Errorf("OverallF1 = %.6f, want %.6f (weighted average)", got.OverallF1, wantF1)
|
|
}
|
|
|
|
// Weighted: (0.5+0.7+0.2*8) / 10 = 2.8/10 = 0.28
|
|
wantRecall := 0.28
|
|
if math.Abs(got.OverallHitRate-wantRecall) > 1e-9 {
|
|
t.Errorf("OverallHitRate = %.6f, want %.6f (weighted average)", got.OverallHitRate, wantRecall)
|
|
}
|
|
|
|
if got.TotalQuestions != 10 {
|
|
t.Errorf("TotalQuestions = %d, want 10", got.TotalQuestions)
|
|
}
|
|
}
|
|
|
|
func TestAggregateMetricsSentinel(t *testing.T) {
|
|
qa := []QAResult{
|
|
{Category: 1, TokenF1: 0.8, HitRate: 0.5},
|
|
{Category: 1, TokenF1: -1.0, HitRate: 0.3},
|
|
{Category: 1, TokenF1: 0.4, HitRate: 0.7},
|
|
}
|
|
agg := aggregateMetrics(qa)
|
|
|
|
if agg.ValidF1Count != 2 {
|
|
t.Errorf("ValidF1Count = %d, want 2", agg.ValidF1Count)
|
|
}
|
|
if agg.TotalQuestions != 3 {
|
|
t.Errorf("TotalQuestions = %d, want 3", agg.TotalQuestions)
|
|
}
|
|
wantF1 := (0.8 + 0.4) / 2.0
|
|
if math.Abs(agg.OverallF1-wantF1) > 1e-9 {
|
|
t.Errorf("OverallF1 = %.6f, want %.6f", agg.OverallF1, wantF1)
|
|
}
|
|
wantHR := (0.5 + 0.3 + 0.7) / 3.0
|
|
if math.Abs(agg.OverallHitRate-wantHR) > 1e-9 {
|
|
t.Errorf("OverallHitRate = %.6f, want %.6f", agg.OverallHitRate, wantHR)
|
|
}
|
|
}
|
|
|
|
func TestAggregateMetricsAllSentinel(t *testing.T) {
|
|
qa := []QAResult{
|
|
{Category: 1, TokenF1: -1.0, HitRate: 0.5},
|
|
{Category: 1, TokenF1: -1.0, HitRate: 0.3},
|
|
}
|
|
agg := aggregateMetrics(qa)
|
|
|
|
if agg.ValidF1Count != 0 {
|
|
t.Errorf("ValidF1Count = %d, want 0", agg.ValidF1Count)
|
|
}
|
|
if agg.OverallF1 != 0 {
|
|
t.Errorf("OverallF1 = %.6f, want 0", agg.OverallF1)
|
|
}
|
|
}
|
|
|
|
func TestComputeModeAggSentinelWeighting(t *testing.T) {
|
|
results := []EvalResult{
|
|
{
|
|
Mode: "test",
|
|
SampleID: "s1",
|
|
QAResults: []QAResult{
|
|
{Category: 1, TokenF1: 0.8, HitRate: 0.5},
|
|
{Category: 1, TokenF1: -1.0, HitRate: 0.3},
|
|
},
|
|
},
|
|
{
|
|
Mode: "test",
|
|
SampleID: "s2",
|
|
QAResults: []QAResult{
|
|
{Category: 1, TokenF1: 0.4, HitRate: 0.6},
|
|
{Category: 1, TokenF1: 0.6, HitRate: 0.8},
|
|
},
|
|
},
|
|
}
|
|
for i := range results {
|
|
results[i].Agg = aggregateMetrics(results[i].QAResults)
|
|
}
|
|
|
|
got := computeModeAgg(results)
|
|
|
|
// s1: ValidF1Count=1, F1=0.8; s2: ValidF1Count=2, F1=0.5
|
|
// Weighted: (0.8*1 + 0.5*2) / 3 = 1.8/3 = 0.6
|
|
wantF1 := 0.6
|
|
if math.Abs(got.OverallF1-wantF1) > 1e-9 {
|
|
t.Errorf("OverallF1 = %.6f, want %.6f", got.OverallF1, wantF1)
|
|
}
|
|
if got.ValidF1Count != 3 {
|
|
t.Errorf("ValidF1Count = %d, want 3", got.ValidF1Count)
|
|
}
|
|
if got.TotalQuestions != 4 {
|
|
t.Errorf("TotalQuestions = %d, want 4", got.TotalQuestions)
|
|
}
|
|
}
|