picoclaw/cmd/membench/main.go
BeaconCat f1b659e5ef membench: add LLM-as-Judge evaluation mode (#2484)
* membench: add LLM-as-Judge evaluation mode

Add --eval-mode=llm to membench for LLM-based answer generation and
semantic scoring via an OpenAI-compatible API endpoint.

New files:
- llm_client.go: generic OpenAI-compatible chat completion client
  with support for API key, configurable timeout, and optional
  chat_template_kwargs (for llama.cpp thinking models)
- eval_llm.go: LLM answer generation + LLM-as-Judge scoring for
  both legacy and seahorse retrieval modes

Changes to main.go:
- --eval-mode flag (token|llm) to select evaluation strategy
- --api-base, --api-key, --model flags with env var fallback
  (MEMBENCH_API_BASE, MEMBENCH_API_KEY, MEMBENCH_MODEL)
- --no-thinking flag for llama.cpp + Qwen thinking models
- --limit flag to cap QA questions per sample for quick testing

* style: fix golangci-lint formatting (gofmt + golines)

* fix: address Copilot review feedback

- Validate --model is required for LLM eval mode
- Use rune-based truncation to preserve valid UTF-8
- Precompute totalQA count outside inner loop
- Log SearchMessages errors instead of silently skipping

* fix: address Copilot review round 2

- Validate --eval-mode accepts only 'token' or 'llm'
- Normalize base URL to avoid /v1/v1 duplication
- Separate token/LLM results for correct PrintComparison labeling
- Log ExpandMessages errors instead of silently ignoring
- Short-circuit with 0 scores when no context retrieved (match token eval)
- Add --timeout flag wired to LLMClientOptions.Timeout

* fix: address review P1+P2 — sort alignment, failure sentinel, score parser

- P1: Replace hand-rolled sortByRank with sort.Slice (ascending, best
  first) matching eval.go's EvalSeahorse — ensures BudgetTruncate keeps
  best-ranked messages when truncation occurs
- P2: Use -1.0 sentinel for LLM API failures and parse errors, distinct
  from genuine 0.0 score; aggregateMetrics skips -1.0 entries for F1
  averaging while still counting HitRate
- P2: Use regexp \b([1-5])\b for judge score extraction instead of
  first-digit scan — avoids misparses on '5/5', 'Score: 3' etc.
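
Here is a minimal sketch of the judge-score extraction and the -1.0 failure sentinel, assuming the judge is asked for a 1-5 rating. `parseJudgeScore` and `failedScore` are hypothetical names. How eval_llm.go maps the 1-5 rating onto its reported score is not shown in this file, so the sketch stops at the raw rating.

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// failedScore mirrors the -1.0 sentinel: it marks an API or parse failure
// and must not be confused with a genuine 0.0 score.
const failedScore = -1.0

// judgeScoreRe matches a standalone digit 1-5. The \b anchors (ASCII word
// boundaries in Go's RE2 syntax) reject digits inside numbers like "10".
var judgeScoreRe = regexp.MustCompile(`\b([1-5])\b`)

// parseJudgeScore is a hypothetical helper returning the first standalone
// 1-5 digit in the judge's reply, or failedScore if there is none.
func parseJudgeScore(reply string) float64 {
	m := judgeScoreRe.FindStringSubmatch(reply)
	if m == nil {
		return failedScore
	}
	v, _ := strconv.Atoi(m[1]) // cannot fail: the group is a single digit
	return float64(v)
}

func main() {
	for _, reply := range []string{"5/5", "Score: 3", "10 out of 10", "n/a"} {
		fmt.Printf("%q -> %v\n", reply, parseJudgeScore(reply))
	}
}
```

Because the first standalone digit wins, a reply that restates the scale before rating (such as "on a 1-5 scale, 4") would still parse as 1. That case needs a stricter prompt or pattern.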

* fix: address Copilot review round 2

- Fix F1/HitRate weighted aggregation: track ValidF1Count separately so
  computeModeAgg weights F1 by valid scores only, not TotalQuestions
- No-context retrieval failure uses 0.0 (genuine bad score) instead of
  -1.0 sentinel (reserved for API/parse failures)
- Validate --timeout > 0 to prevent disabling HTTP timeouts
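
The sentinel rule gives per-sample aggregation two denominators. F1 is averaged over valid scores only, while hit rate is counted over every question. The sketch below shows that split. `qaScore`, `sampleAgg`, and `aggregate` are hypothetical names, and modeling a hit as a per-question boolean is an assumption about how HitRate is computed. A genuine 0.0, such as the no-context case, still counts toward F1; only negative sentinel values are skipped.

```go
package main

import "fmt"

// qaScore is a hypothetical per-question record: F1 is the judge-derived
// score (or -1.0 on API/parse failure); Hit marks a question counted as a hit.
type qaScore struct {
	F1  float64
	Hit bool
}

// sampleAgg is a hypothetical per-sample aggregate.
type sampleAgg struct {
	MeanF1         float64 // mean over valid (non-sentinel) scores only
	ValidF1Count   int     // number of scores that contributed to MeanF1
	HitRate        float64 // fraction of all questions, sentinels included
	TotalQuestions int
}

func aggregate(scores []qaScore) sampleAgg {
	agg := sampleAgg{TotalQuestions: len(scores)}
	var sumF1 float64
	hits := 0
	for _, s := range scores {
		if s.Hit {
			hits++ // hit rate counts every question
		}
		if s.F1 < 0 {
			continue // -1.0 sentinel: excluded from the F1 average
		}
		sumF1 += s.F1
		agg.ValidF1Count++
	}
	if agg.ValidF1Count > 0 {
		agg.MeanF1 = sumF1 / float64(agg.ValidF1Count)
	}
	if agg.TotalQuestions > 0 {
		agg.HitRate = float64(hits) / float64(agg.TotalQuestions)
	}
	return agg
}

func main() {
	agg := aggregate([]qaScore{
		{F1: 1.0, Hit: true},
		{F1: -1.0, Hit: true}, // judge call failed
		{F1: 0.0, Hit: false}, // genuine zero, e.g. no context retrieved
		{F1: 0.5, Hit: true},
	})
	fmt.Printf("%+v\n", agg)
}
```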

* fix: remove hardcoded /v1 from API base URL

Users now provide the full versioned path in --api-base (e.g. /v1, /v4).
Code only appends /chat/completions. Default changed to
http://127.0.0.1:8080/v1 for backward compatibility.
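
Because the user supplies the version path, building the request URL reduces to a single append. This sketch uses a hypothetical `chatCompletionsURL` helper. It also trims trailing slashes so that a base ending in "/" does not produce "//"; whether llm_client.go does the same is not shown here.

```go
package main

import (
	"fmt"
	"strings"
)

// chatCompletionsURL is a hypothetical helper: the caller supplies the full
// versioned base (for example /v1 or /v4), and only "/chat/completions"
// is appended. Trailing slashes are trimmed first to avoid "//".
func chatCompletionsURL(base string) string {
	return strings.TrimRight(base, "/") + "/chat/completions"
}

func main() {
	fmt.Println(chatCompletionsURL("http://127.0.0.1:8080/v1"))
	fmt.Println(chatCompletionsURL("http://example.com/api/v4/"))
}
```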

* fix: address Copilot review round 3

- ValidF1Count=0 when all scores are sentinel (no forced =1)
- Backward compat: old eval JSON without ValidF1Count falls back to
  TotalQuestions in computeModeAgg
- Skip empty section in PrintComparison when tokenResults is empty
- Update --api-base flag help to document /v1 default and version path
- Add sentinel aggregation unit tests (partial, all, weighted)

* feat: add --retries flag with exponential backoff for transient LLM errors

Retry on timeout, 5xx, and 429 (rate limit) with 1s/2s/4s backoff.
Default 3 retries, configurable via --retries. Context cancellation
is respected between retries.

* fix: address Copilot review round 4

- runReport splits results by mode suffix into token/llm for PrintComparison
- backward compat fallback (ValidF1Count=0 -> TotalQuestions) only for
  non-LLM modes; LLM modes keep ValidF1Count=0 when all scores sentinel
- MaxRetries==0 means no retry; only negative falls back to default 3
- truncateStr uses []rune to avoid cutting multi-byte UTF-8 characters
- Complete() returns error on empty LLM response (vs silent empty string)
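
The next sketch reads the answer out of an OpenAI-compatible response so that an empty reply becomes an error instead of a silent empty string. It also includes the reasoning_content fallback for GLM/DeepSeek-style responses that a later commit in this squash adds. `chatResponse` and `extractAnswer` are hypothetical names. Only the `choices[0].message` fields are modeled, and preferring `content` over `reasoning_content` is an assumption about the fallback order.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// chatResponse models only the fields this sketch reads from a
// /chat/completions response; reasoning_content is the extra field some
// GLM/DeepSeek-style backends return.
type chatResponse struct {
	Choices []struct {
		Message struct {
			Content          string `json:"content"`
			ReasoningContent string `json:"reasoning_content"`
		} `json:"message"`
	} `json:"choices"`
}

// extractAnswer prefers content, falls back to reasoning_content, and
// returns an error when both are empty.
func extractAnswer(body []byte) (string, error) {
	var resp chatResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return "", fmt.Errorf("decode response: %w", err)
	}
	if len(resp.Choices) == 0 {
		return "", errors.New("empty LLM response: no choices")
	}
	msg := resp.Choices[0].Message
	if text := strings.TrimSpace(msg.Content); text != "" {
		return text, nil
	}
	if text := strings.TrimSpace(msg.ReasoningContent); text != "" {
		return text, nil
	}
	return "", errors.New("empty LLM response")
}

func main() {
	ans, err := extractAnswer([]byte(`{"choices":[{"message":{"content":"","reasoning_content":"Paris"}}]}`))
	fmt.Println(ans, err)
}
```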

* feat: --no-thinking adapts to llama.cpp, Ollama, and GLM backends

Send all three disable-thinking fields simultaneously:
- chat_template_kwargs.enable_thinking=false (llama.cpp, GLM)
- think=false (Ollama 0.9+)
- thinking.type=disabled (GLM/Zhipu)
Each backend picks the field it recognizes and ignores the rest.
Also bumps max_tokens from 512 to 2048 for thinking models.

* feat: mixed model eval + concurrent QA workers

- Add --judge-model, --judge-api-base, --judge-api-key flags for separate judge model
- Add --concurrency flag (default 1) with semaphore-based goroutine pool
- Add reasoning_content fallback for GLM/DeepSeek style responses
- Prepend /no_think to system prompt for Ollama /v1 compatibility
- Reduce default MaxTokens from 2048 to 512 (answers are 1-3 sentences)
- Extract evalQAWorker and buildSeahorseContext for shared concurrent logic

---------

Co-authored-by: BeaconCat <BeaconCat@users.noreply.github.com>
2026-04-15 21:15:17 +08:00

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/spf13/cobra"

	"github.com/sipeed/picoclaw/pkg/logger"
)

var (
	flagData         string
	flagOut          string
	flagMode         string
	flagBudget       int
	flagEvalMode     string
	flagAPIBase      string
	flagAPIKey       string
	flagModel        string
	flagNoThinking   bool
	flagLimit        int
	flagTimeout      int
	flagRetries      int
	flagJudgeModel   string
	flagJudgeAPIBase string
	flagJudgeAPIKey  string
	flagConcurrency  int
)

func main() {
	// Suppress seahorse INFO logs during benchmark
	logger.SetLevel(logger.WARN)

	rootCmd := &cobra.Command{
		Use:   "membench",
		Short: "Memory benchmark tool for picoclaw",
	}

	ingestCmd := &cobra.Command{
		Use:   "ingest",
		Short: "Load LOCOMO data into storage backends",
		RunE:  runIngest,
	}
	ingestCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
	ingestCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
	ingestCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to ingest: legacy, seahorse, or all")

	evalCmd := &cobra.Command{
		Use:   "eval",
		Short: "Run QA evaluation against ingested data",
		RunE:  runEval,
	}
	evalCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
	evalCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
	evalCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to evaluate: legacy, seahorse, or all")
	evalCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
	evalCmd.Flags().
		StringVar(&flagEvalMode, "eval-mode", "token", "evaluation mode: token (direct match) or llm (LLM-as-Judge)")
	evalCmd.Flags().
		StringVar(&flagAPIBase, "api-base", "", "API base URL with version path, e.g. http://host/v1 (default: http://127.0.0.1:8080/v1, env: MEMBENCH_API_BASE)")
	evalCmd.Flags().StringVar(&flagAPIKey, "api-key", "", "API key for the LLM endpoint (env: MEMBENCH_API_KEY)")
	evalCmd.Flags().StringVar(&flagModel, "model", "", "model name for LLM eval (env: MEMBENCH_MODEL)")
	evalCmd.Flags().
		BoolVar(&flagNoThinking, "no-thinking", false, "disable thinking mode via chat_template_kwargs (llama.cpp + Qwen)")
	evalCmd.Flags().IntVar(&flagLimit, "limit", 0, "max QA questions per sample (0 = all)")
	evalCmd.Flags().IntVar(&flagTimeout, "timeout", 120, "HTTP timeout in seconds for LLM requests")
	evalCmd.Flags().IntVar(&flagRetries, "retries", 3, "max retry attempts for transient LLM errors (timeout/5xx/429)")
	evalCmd.Flags().StringVar(&flagJudgeModel, "judge-model", "", "model for judge scoring (defaults to --model)")
	evalCmd.Flags().
		StringVar(&flagJudgeAPIBase, "judge-api-base", "", "API base URL for judge model (defaults to --api-base)")
	evalCmd.Flags().StringVar(&flagJudgeAPIKey, "judge-api-key", "", "API key for judge model (defaults to --api-key)")
	evalCmd.Flags().IntVar(&flagConcurrency, "concurrency", 1, "number of concurrent QA evaluations")

	reportCmd := &cobra.Command{
		Use:   "report",
		Short: "Output comparison results from evaluation",
		RunE:  runReport,
	}
	reportCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")

	runCmd := &cobra.Command{
		Use:   "run",
		Short: "Convenience: eval + report (ingestion is done inline)",
		RunE:  runAll,
	}
	runCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
	runCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
	runCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to run: legacy, seahorse, or all")
	runCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
	runCmd.Flags().
		StringVar(&flagEvalMode, "eval-mode", "token", "evaluation mode: token (direct match) or llm (LLM-as-Judge)")
	runCmd.Flags().
		StringVar(&flagAPIBase, "api-base", "", "API base URL with version path, e.g. http://host/v1 (default: http://127.0.0.1:8080/v1, env: MEMBENCH_API_BASE)")
	runCmd.Flags().StringVar(&flagAPIKey, "api-key", "", "API key for the LLM endpoint (env: MEMBENCH_API_KEY)")
	runCmd.Flags().StringVar(&flagModel, "model", "", "model name for LLM eval (env: MEMBENCH_MODEL)")
	runCmd.Flags().
		BoolVar(&flagNoThinking, "no-thinking", false, "disable thinking mode via chat_template_kwargs (llama.cpp + Qwen)")
	runCmd.Flags().IntVar(&flagLimit, "limit", 0, "max QA questions per sample (0 = all)")
	runCmd.Flags().IntVar(&flagTimeout, "timeout", 120, "HTTP timeout in seconds for LLM requests")
	runCmd.Flags().IntVar(&flagRetries, "retries", 3, "max retry attempts for transient LLM errors (timeout/5xx/429)")
	runCmd.Flags().StringVar(&flagJudgeModel, "judge-model", "", "model for judge scoring (defaults to --model)")
	runCmd.Flags().
		StringVar(&flagJudgeAPIBase, "judge-api-base", "", "API base URL for judge model (defaults to --api-base)")
	runCmd.Flags().StringVar(&flagJudgeAPIKey, "judge-api-key", "", "API key for judge model (defaults to --api-key)")
	runCmd.Flags().IntVar(&flagConcurrency, "concurrency", 1, "number of concurrent QA evaluations")

	rootCmd.AddCommand(ingestCmd, evalCmd, reportCmd, runCmd)
	if err := rootCmd.Execute(); err != nil {
		os.Exit(1)
	}
}

func modesFromFlag() []string {
	switch strings.ToLower(flagMode) {
	case "all":
		return []string{"legacy", "seahorse"}
	default:
		return []string{strings.ToLower(flagMode)}
	}
}

func runIngest(cmd *cobra.Command, args []string) error {
	if flagData == "" {
		return fmt.Errorf("--data is required")
	}
	modes := modesFromFlag()
	if len(modes) == 0 {
		return nil
	}

	ctx := context.Background()
	samples, err := LoadDataset(flagData)
	if err != nil {
		return fmt.Errorf("load dataset: %w", err)
	}
	log.Printf("Loaded %d samples from %s", len(samples), flagData)

	for _, mode := range modes {
		switch mode {
		case "legacy":
			legacy := NewLegacyStore()
			for i := range samples {
				legacy.IngestSample(&samples[i])
			}
			log.Printf("legacy: ingested %d samples", len(samples))
		case "seahorse":
			dbPath := filepath.Join(flagOut, "seahorse.db")
			if err := os.MkdirAll(flagOut, 0o755); err != nil {
				return fmt.Errorf("create out dir: %w", err)
			}
			_, err := IngestSeahorse(ctx, samples, dbPath)
			if err != nil {
				return fmt.Errorf("ingest seahorse: %w", err)
			}
		}
	}
	return nil
}

func runEval(cmd *cobra.Command, args []string) error {
	if flagData == "" {
		return fmt.Errorf("--data is required")
	}
	modes := modesFromFlag()
	if len(modes) == 0 {
		return nil
	}

	ctx := context.Background()
	samples, err := LoadDataset(flagData)
	if err != nil {
		return fmt.Errorf("load dataset: %w", err)
	}
	log.Printf("Loaded %d samples", len(samples))

	if flagLimit > 0 {
		for i := range samples {
			if len(samples[i].QA) > flagLimit {
				samples[i].QA = samples[i].QA[:flagLimit]
			}
		}
		log.Printf("Limited to %d QA per sample", flagLimit)
	}

	evalMode := strings.ToLower(strings.TrimSpace(flagEvalMode))
	var useLLM bool
	switch evalMode {
	case "token":
		useLLM = false
	case "llm":
		useLLM = true
	default:
		return fmt.Errorf("invalid --eval-mode %q: must be token or llm", flagEvalMode)
	}

	var answerClient, judgeClient *LLMClient
	if useLLM {
		opts, err := buildLLMOptions()
		if err != nil {
			return err
		}
		answerClient = NewLLMClient(opts)
		judgeClient = answerClient // default: same client
		if flagJudgeModel != "" {
			jOpts := opts // copy base settings
			jOpts.Model = flagJudgeModel
			if flagJudgeAPIBase != "" {
				jOpts.BaseURL = flagJudgeAPIBase
			}
			if flagJudgeAPIKey != "" {
				jOpts.APIKey = flagJudgeAPIKey
			}
			judgeClient = NewLLMClient(jOpts)
			log.Printf("Judge model: model=%s base=%s no-thinking=%v", jOpts.Model, jOpts.BaseURL, jOpts.NoThinking)
		}
		log.Printf("LLM eval mode: model=%s base=%s no-thinking=%v concurrency=%d",
			opts.Model, opts.BaseURL, opts.NoThinking, flagConcurrency)
	}

	var tokenResults, llmResults []EvalResult
	for _, mode := range modes {
		switch mode {
		case "legacy":
			legacy := NewLegacyStore()
			for i := range samples {
				legacy.IngestSample(&samples[i])
			}
			if useLLM {
				results := EvalLegacyLLM(ctx, samples, legacy, flagBudget, answerClient, judgeClient, flagConcurrency)
				llmResults = append(llmResults, results...)
				log.Printf("legacy-llm: evaluated %d samples", len(results))
			} else {
				results := EvalLegacy(ctx, samples, legacy, flagBudget)
				tokenResults = append(tokenResults, results...)
				log.Printf("legacy: evaluated %d samples", len(results))
			}
		case "seahorse":
			dbPath := filepath.Join(flagOut, "seahorse.db")
			ir, err := IngestSeahorse(ctx, samples, dbPath)
			if err != nil {
				return fmt.Errorf("ingest seahorse: %w", err)
			}
			if useLLM {
				results := EvalSeahorseLLM(ctx, samples, ir, flagBudget, answerClient, judgeClient, flagConcurrency)
				llmResults = append(llmResults, results...)
				log.Printf("seahorse-llm: evaluated %d samples", len(results))
			} else {
				results := EvalSeahorse(ctx, samples, ir, flagBudget)
				tokenResults = append(tokenResults, results...)
				log.Printf("seahorse: evaluated %d samples", len(results))
			}
		}
	}

	allResults := append(tokenResults, llmResults...)
	if err := SaveResults(allResults, flagOut); err != nil {
		return fmt.Errorf("save results: %w", err)
	}
	if err := SaveAggregated(allResults, flagOut); err != nil {
		return fmt.Errorf("save aggregated: %w", err)
	}
	PrintComparison(tokenResults, llmResults)
	return nil
}

func runReport(cmd *cobra.Command, args []string) error {
	entries, err := os.ReadDir(flagOut)
	if err != nil {
		return fmt.Errorf("read out dir: %w", err)
	}

	var allResults []EvalResult
	for _, entry := range entries {
		if !entry.IsDir() && strings.HasPrefix(entry.Name(), "eval_") && strings.HasSuffix(entry.Name(), ".json") {
			path := filepath.Join(flagOut, entry.Name())
			var r EvalResult
			data, err := os.ReadFile(path)
			if err != nil {
				log.Printf("WARN: read %s: %v", path, err)
				continue
			}
			if err := json.Unmarshal(data, &r); err != nil {
				log.Printf("WARN: parse %s: %v", path, err)
				continue
			}
			allResults = append(allResults, r)
		}
	}
	if len(allResults) == 0 {
		return fmt.Errorf("no eval results found in %s", flagOut)
	}

	var tokenResults, llmResults []EvalResult
	for _, r := range allResults {
		if strings.HasSuffix(r.Mode, "-llm") {
			llmResults = append(llmResults, r)
		} else {
			tokenResults = append(tokenResults, r)
		}
	}
	PrintComparison(tokenResults, llmResults)
	return nil
}

func runAll(cmd *cobra.Command, args []string) error {
	return runEval(cmd, args)
}

// envOrFlag returns the flag value if non-empty, otherwise falls back to the
// environment variable.
func envOrFlag(flag, envKey string) string {
	if flag != "" {
		return flag
	}
	return os.Getenv(envKey)
}

// buildLLMOptions resolves LLM client configuration from flags and environment
// variables. Flag values take precedence over environment variables.
//
// Environment variables:
//
//	MEMBENCH_API_BASE  OpenAI-compatible base URL (default http://127.0.0.1:8080/v1)
//	MEMBENCH_API_KEY   Bearer token for the endpoint
//	MEMBENCH_MODEL     Model name to send in the request
func buildLLMOptions() (LLMClientOptions, error) {
	base := envOrFlag(flagAPIBase, "MEMBENCH_API_BASE")
	if base == "" {
		base = "http://127.0.0.1:8080/v1"
	}
	model := envOrFlag(flagModel, "MEMBENCH_MODEL")
	if model == "" {
		return LLMClientOptions{}, fmt.Errorf(
			"--model or MEMBENCH_MODEL is required for LLM eval mode",
		)
	}
	apiKey := envOrFlag(flagAPIKey, "MEMBENCH_API_KEY")
	if flagTimeout <= 0 {
		return LLMClientOptions{}, fmt.Errorf("--timeout must be > 0, got %d", flagTimeout)
	}
	return LLMClientOptions{
		BaseURL:    base,
		Model:      model,
		APIKey:     apiKey,
		NoThinking: flagNoThinking,
		Timeout:    time.Duration(flagTimeout) * time.Second,
		MaxRetries: flagRetries,
	}, nil
}