mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 51eecde01e | |||
| 8b3e502690 | |||
| 7d16764674 | |||
| ee29aaa871 | |||
| 330de0c382 | |||
| 6ce0306c66 | |||
| 1fc2710999 | |||
| 6a8552a664 | |||
| 7bf6cbe1fa | |||
| 38a498e202 | |||
| 778f939302 | |||
| 84edc462d6 | |||
| f0e6b7aa37 | |||
| 661ce5e311 | |||
| c3e7396a3d | |||
| 29277d4b3b | |||
| 9ec27835cf | |||
| 1175f4a62b | |||
| 15a70ac45c | |||
| 71337b6f52 | |||
| 84e42d6904 | |||
| e8d92e4a36 | |||
| cbd0798a56 | |||
| d8c5183d9a |
@@ -41,10 +41,11 @@ jobs:
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
|
||||
- name: Install govulncheck
|
||||
run: go install golang.org/x/vuln/cmd/govulncheck@v1.1.4
|
||||
|
||||
- name: Run Govulncheck
|
||||
uses: golang/govulncheck-action@v1
|
||||
with:
|
||||
go-package: ./...
|
||||
run: govulncheck -C . -format text ./...
|
||||
|
||||
test:
|
||||
name: Tests
|
||||
|
||||
@@ -67,3 +67,5 @@ web/backend/dist/*
|
||||
.claude/
|
||||
|
||||
docker/data
|
||||
|
||||
.omc/
|
||||
|
||||
@@ -12,6 +12,7 @@ linters:
|
||||
- exhaustruct
|
||||
- funcorder
|
||||
- gochecknoglobals
|
||||
- gosmopolitan # Project legitimately uses CJK text in tests (FTS5, token counting)
|
||||
- godot
|
||||
- intrange
|
||||
- ireturn
|
||||
|
||||
@@ -349,6 +349,25 @@ build-macos-app:build-launcher
|
||||
@./scripts/build-macos-app.sh $(PLATFORM)-$(ARCH)
|
||||
@echo "macOS .app bundle created: $(BUILD_DIR)/PicoClaw.app"
|
||||
|
||||
## mem: Build membench, download LOCOMO data (if needed), run benchmark, and show results
|
||||
mem:
|
||||
@echo "Building membench..."
|
||||
@mkdir -p $(BUILD_DIR)
|
||||
@$(GO) build -o $(BUILD_DIR)/membench ./cmd/membench
|
||||
@echo "Build complete: $(BUILD_DIR)/membench"
|
||||
@if [ ! -f $(BUILD_DIR)/memdata/locomo10.json ]; then \
|
||||
echo "Downloading LOCOMO dataset..."; \
|
||||
mkdir -p $(BUILD_DIR)/memdata; \
|
||||
curl -sfL "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json" \
|
||||
-o $(BUILD_DIR)/memdata/locomo10.json && [ -s $(BUILD_DIR)/memdata/locomo10.json ] || { echo "Error: LOCOMO download failed"; exit 1; }; \
|
||||
echo "Download complete"; \
|
||||
else \
|
||||
echo "LOCOMO dataset already exists, skipping download"; \
|
||||
fi
|
||||
@echo "Running benchmark..."
|
||||
@rm -rf $(BUILD_DIR)/memout
|
||||
@$(BUILD_DIR)/membench run --data $(BUILD_DIR)/memdata --out $(BUILD_DIR)/memout --budget 4000
|
||||
|
||||
## help: Show this help message
|
||||
help:
|
||||
@echo "picoclaw Makefile"
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 365 KiB After Width: | Height: | Size: 362 KiB |
@@ -0,0 +1,366 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||
)
|
||||
|
||||
// EvalResult holds per-sample evaluation results for one mode.
|
||||
type EvalResult struct {
|
||||
Mode string `json:"mode"`
|
||||
SampleID string `json:"sampleId"`
|
||||
QAResults []QAResult `json:"qaResults"`
|
||||
Agg AggMetrics `json:"aggregated"`
|
||||
}
|
||||
|
||||
// QAResult holds metrics for a single QA pair.
|
||||
type QAResult struct {
|
||||
Question string `json:"question"`
|
||||
Category int `json:"category"`
|
||||
GoldAnswer string `json:"goldAnswer"`
|
||||
TokenF1 float64 `json:"tokenF1"`
|
||||
HitRate float64 `json:"hitRate"`
|
||||
}
|
||||
|
||||
// AggMetrics holds aggregated evaluation metrics.
|
||||
type AggMetrics struct {
|
||||
OverallF1 float64 `json:"overallF1"`
|
||||
OverallHitRate float64 `json:"overallHitRate"`
|
||||
ByCategory map[int]*CatMetrics `json:"byCategory"`
|
||||
TotalQuestions int `json:"totalQuestions"`
|
||||
}
|
||||
|
||||
// CatMetrics holds metrics for a single category.
|
||||
type CatMetrics struct {
|
||||
F1 float64 `json:"f1"`
|
||||
HitRate float64 `json:"hitRate"`
|
||||
QuestionCount int `json:"questionCount"`
|
||||
}
|
||||
|
||||
// EvalLegacy evaluates using legacy session store (raw history + budget truncation).
|
||||
func EvalLegacy(
|
||||
ctx context.Context,
|
||||
samples []LocomoSample,
|
||||
legacy *LegacyStore,
|
||||
budgetTokens int,
|
||||
) []EvalResult {
|
||||
results := make([]EvalResult, 0, len(samples))
|
||||
for si := range samples {
|
||||
sample := &samples[si]
|
||||
history := legacy.GetHistory(sample.SampleID)
|
||||
|
||||
// Convert messages to content strings
|
||||
allContent := make([]string, 0, len(history))
|
||||
for _, msg := range history {
|
||||
allContent = append(allContent, msg.Content)
|
||||
}
|
||||
|
||||
qaResults := make([]QAResult, 0, len(sample.QA))
|
||||
for qi := range sample.QA {
|
||||
qa := &sample.QA[qi]
|
||||
// Budget truncate the full history
|
||||
truncated, _ := BudgetTruncate(allContent, budgetTokens)
|
||||
context := StringListToContent(truncated)
|
||||
|
||||
f1 := TokenOverlapF1(context, qa.AnswerString())
|
||||
hitRate := RecallHitRate(qa.Evidence, sample, context)
|
||||
|
||||
qaResults = append(qaResults, QAResult{
|
||||
Question: qa.Question,
|
||||
Category: qa.Category,
|
||||
GoldAnswer: qa.AnswerString(),
|
||||
TokenF1: f1,
|
||||
HitRate: hitRate,
|
||||
})
|
||||
}
|
||||
|
||||
results = append(results, EvalResult{
|
||||
Mode: "legacy",
|
||||
SampleID: sample.SampleID,
|
||||
QAResults: qaResults,
|
||||
Agg: aggregateMetrics(qaResults),
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// EvalSeahorse evaluates using seahorse short memory (per-keyword search + expand).
|
||||
func EvalSeahorse(
|
||||
ctx context.Context,
|
||||
samples []LocomoSample,
|
||||
ir *SeahorseIngestResult,
|
||||
budgetTokens int,
|
||||
) []EvalResult {
|
||||
store := ir.Engine.GetRetrieval().Store()
|
||||
retrieval := ir.Engine.GetRetrieval()
|
||||
|
||||
results := make([]EvalResult, 0, len(samples))
|
||||
for si := range samples {
|
||||
sample := &samples[si]
|
||||
convID, ok := ir.ConvMap[sample.SampleID]
|
||||
if !ok {
|
||||
log.Printf("WARN: no conversation ID for sample %s", sample.SampleID)
|
||||
continue
|
||||
}
|
||||
|
||||
qaResults := make([]QAResult, 0, len(sample.QA))
|
||||
for qi := range sample.QA {
|
||||
qa := &sample.QA[qi]
|
||||
keywords := ExtractKeywords(qa.Question)
|
||||
|
||||
// Search each keyword individually and union results,
|
||||
// tracking best BM25 rank per message for relevance sorting.
|
||||
bestRank := map[int64]float64{}
|
||||
for _, kw := range keywords {
|
||||
searchResults, err := store.SearchMessages(ctx, seahorse.SearchInput{
|
||||
Pattern: kw,
|
||||
ConversationID: convID,
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("WARN: search failed for keyword %q: %v", kw, err)
|
||||
continue
|
||||
}
|
||||
for _, sr := range searchResults {
|
||||
if sr.MessageID > 0 {
|
||||
if prev, ok := bestRank[sr.MessageID]; !ok || sr.Rank < prev {
|
||||
bestRank[sr.MessageID] = sr.Rank
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort messageIDs by rank ascending (best/most-negative first).
|
||||
// BudgetTruncate walks from the front, keeping best-ranked messages.
|
||||
// Note: SQLite FTS5 bm25() returns negative values where more
|
||||
// negative = better match.
|
||||
messageIDs := make([]int64, 0, len(bestRank))
|
||||
for id := range bestRank {
|
||||
messageIDs = append(messageIDs, id)
|
||||
}
|
||||
sort.Slice(messageIDs, func(i, j int) bool {
|
||||
return bestRank[messageIDs[i]] < bestRank[messageIDs[j]]
|
||||
})
|
||||
|
||||
// Expand messages to get full content
|
||||
var contentParts []string
|
||||
if len(messageIDs) > 0 {
|
||||
expandResult, err := retrieval.ExpandMessages(ctx, messageIDs)
|
||||
if err != nil {
|
||||
log.Printf("WARN: expand failed for sample %s: %v", sample.SampleID, err)
|
||||
} else {
|
||||
for _, msg := range expandResult.Messages {
|
||||
contentParts = append(contentParts, msg.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(contentParts) == 0 {
|
||||
qaResults = append(qaResults, QAResult{
|
||||
Question: qa.Question,
|
||||
Category: qa.Category,
|
||||
GoldAnswer: qa.AnswerString(),
|
||||
TokenF1: 0.0,
|
||||
HitRate: 0.0,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Budget truncate (drop worst-ranked)
|
||||
truncated, _ := BudgetTruncate(contentParts, budgetTokens)
|
||||
context := StringListToContent(truncated)
|
||||
|
||||
f1 := TokenOverlapF1(context, qa.AnswerString())
|
||||
hitRate := RecallHitRate(qa.Evidence, sample, context)
|
||||
|
||||
qaResults = append(qaResults, QAResult{
|
||||
Question: qa.Question,
|
||||
Category: qa.Category,
|
||||
GoldAnswer: qa.AnswerString(),
|
||||
TokenF1: f1,
|
||||
HitRate: hitRate,
|
||||
})
|
||||
}
|
||||
|
||||
results = append(results, EvalResult{
|
||||
Mode: "seahorse",
|
||||
SampleID: sample.SampleID,
|
||||
QAResults: qaResults,
|
||||
Agg: aggregateMetrics(qaResults),
|
||||
})
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// aggregateMetrics computes overall and per-category metrics.
|
||||
func aggregateMetrics(qaResults []QAResult) AggMetrics {
|
||||
byCat := map[int]*CatMetrics{}
|
||||
totalF1 := 0.0
|
||||
totalHitRate := 0.0
|
||||
for _, qr := range qaResults {
|
||||
totalF1 += qr.TokenF1
|
||||
totalHitRate += qr.HitRate
|
||||
cat, ok := byCat[qr.Category]
|
||||
if !ok {
|
||||
cat = &CatMetrics{}
|
||||
byCat[qr.Category] = cat
|
||||
}
|
||||
cat.F1 += qr.TokenF1
|
||||
cat.HitRate += qr.HitRate
|
||||
cat.QuestionCount++
|
||||
}
|
||||
n := len(qaResults)
|
||||
if n == 0 {
|
||||
n = 1
|
||||
}
|
||||
agg := AggMetrics{
|
||||
OverallF1: totalF1 / float64(n),
|
||||
OverallHitRate: totalHitRate / float64(n),
|
||||
ByCategory: byCat,
|
||||
TotalQuestions: len(qaResults),
|
||||
}
|
||||
for _, cat := range agg.ByCategory {
|
||||
if cat.QuestionCount > 0 {
|
||||
cat.F1 /= float64(cat.QuestionCount)
|
||||
cat.HitRate /= float64(cat.QuestionCount)
|
||||
}
|
||||
}
|
||||
return agg
|
||||
}
|
||||
|
||||
// SaveResults writes per-sample eval results to JSON files.
|
||||
func SaveResults(results []EvalResult, outDir string) error {
|
||||
if err := os.MkdirAll(outDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create output dir: %w", err)
|
||||
}
|
||||
for _, r := range results {
|
||||
path := filepath.Join(outDir, fmt.Sprintf("eval_%s_%s.json", r.Mode, r.SampleID))
|
||||
data, err := json.MarshalIndent(r, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal result: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
return fmt.Errorf("write result: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveAggregated writes a combined results.json with all modes.
|
||||
func SaveAggregated(results []EvalResult, outDir string) error {
|
||||
byMode := map[string][]EvalResult{}
|
||||
for _, r := range results {
|
||||
byMode[r.Mode] = append(byMode[r.Mode], r)
|
||||
}
|
||||
|
||||
aggMap := map[string]AggMetrics{}
|
||||
for mode, modeResults := range byMode {
|
||||
aggMap[mode] = computeModeAgg(modeResults)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(aggMap, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(outDir, "results.json"), data, 0o644)
|
||||
}
|
||||
|
||||
// computeModeAgg aggregates results for a single mode using weighted averaging
|
||||
// (weighted by question count per sample). All modes must have the same Mode field.
|
||||
func computeModeAgg(results []EvalResult) AggMetrics {
|
||||
agg := AggMetrics{ByCategory: map[int]*CatMetrics{}}
|
||||
for _, r := range results {
|
||||
agg.OverallF1 += r.Agg.OverallF1 * float64(r.Agg.TotalQuestions)
|
||||
agg.OverallHitRate += r.Agg.OverallHitRate * float64(r.Agg.TotalQuestions)
|
||||
agg.TotalQuestions += r.Agg.TotalQuestions
|
||||
for cat, cm := range r.Agg.ByCategory {
|
||||
existing, ok := agg.ByCategory[cat]
|
||||
if !ok {
|
||||
existing = &CatMetrics{}
|
||||
agg.ByCategory[cat] = existing
|
||||
}
|
||||
existing.F1 += cm.F1 * float64(cm.QuestionCount)
|
||||
existing.HitRate += cm.HitRate * float64(cm.QuestionCount)
|
||||
existing.QuestionCount += cm.QuestionCount
|
||||
}
|
||||
}
|
||||
if agg.TotalQuestions > 0 {
|
||||
agg.OverallF1 /= float64(agg.TotalQuestions)
|
||||
agg.OverallHitRate /= float64(agg.TotalQuestions)
|
||||
}
|
||||
for _, cat := range agg.ByCategory {
|
||||
if cat.QuestionCount > 0 {
|
||||
cat.F1 /= float64(cat.QuestionCount)
|
||||
cat.HitRate /= float64(cat.QuestionCount)
|
||||
}
|
||||
}
|
||||
return agg
|
||||
}
|
||||
|
||||
// printSection prints a single comparison table section.
|
||||
func printSection(title string, results []EvalResult) {
|
||||
fmt.Printf("\n--- %s ---\n", title)
|
||||
byMode := map[string][]EvalResult{}
|
||||
for _, r := range results {
|
||||
byMode[r.Mode] = append(byMode[r.Mode], r)
|
||||
}
|
||||
|
||||
modes := map[string]AggMetrics{}
|
||||
for mode, modeResults := range byMode {
|
||||
modes[mode] = computeModeAgg(modeResults)
|
||||
}
|
||||
|
||||
modeKeys := make([]string, 0, len(modes))
|
||||
for k := range modes {
|
||||
modeKeys = append(modeKeys, k)
|
||||
}
|
||||
sort.Strings(modeKeys)
|
||||
|
||||
// Collect all category keys across modes
|
||||
catSet := map[int]bool{}
|
||||
for _, agg := range modes {
|
||||
for cat := range agg.ByCategory {
|
||||
catSet[cat] = true
|
||||
}
|
||||
}
|
||||
cats := make([]int, 0, len(catSet))
|
||||
for cat := range catSet {
|
||||
cats = append(cats, cat)
|
||||
}
|
||||
sort.Ints(cats)
|
||||
|
||||
fmt.Printf("%-10s %-8s %-8s", "Mode", "HitRate", "F1")
|
||||
for _, cat := range cats {
|
||||
fmt.Printf(" %-7s", fmt.Sprintf("C%d", cat))
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println(strings.Repeat("-", 10+8+8+7*len(cats)+8))
|
||||
|
||||
for _, mode := range modeKeys {
|
||||
agg := modes[mode]
|
||||
fmt.Printf("%-10s %-8.4f %-8.4f", mode, agg.OverallHitRate, agg.OverallF1)
|
||||
for _, cat := range cats {
|
||||
if cm, ok := agg.ByCategory[cat]; ok {
|
||||
fmt.Printf(" %-7.4f", cm.HitRate)
|
||||
} else {
|
||||
fmt.Printf(" %-7s", "N/A")
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
|
||||
// PrintComparison outputs a human-readable comparison table to stdout.
|
||||
func PrintComparison(results []EvalResult, llmResults []EvalResult) {
|
||||
printSection("No LLM generation", results)
|
||||
if len(llmResults) > 0 {
|
||||
printSection("With LLM", llmResults)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestComputeModeAggAllCategories(t *testing.T) {
|
||||
results := []EvalResult{
|
||||
{
|
||||
Mode: "test",
|
||||
SampleID: "s1",
|
||||
QAResults: []QAResult{
|
||||
{Category: 1, TokenF1: 0.5, HitRate: 0.8},
|
||||
{Category: 2, TokenF1: 0.3, HitRate: 0.6},
|
||||
{Category: 3, TokenF1: 0.1, HitRate: 0.4},
|
||||
{Category: 4, TokenF1: 0.7, HitRate: 0.9},
|
||||
{Category: 5, TokenF1: 0.2, HitRate: 0.1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for i := range results {
|
||||
results[i].Agg = aggregateMetrics(results[i].QAResults)
|
||||
}
|
||||
|
||||
got := computeModeAgg(results)
|
||||
|
||||
// Should have all 5 categories
|
||||
for cat := 1; cat <= 5; cat++ {
|
||||
cm, ok := got.ByCategory[cat]
|
||||
if !ok {
|
||||
t.Errorf("ByCategory missing category %d", cat)
|
||||
continue
|
||||
}
|
||||
if cm.QuestionCount != 1 {
|
||||
t.Errorf("ByCategory[%d].QuestionCount = %d, want 1", cat, cm.QuestionCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify specific F1 values per category
|
||||
wantF1 := map[int]float64{1: 0.5, 2: 0.3, 3: 0.1, 4: 0.7, 5: 0.2}
|
||||
for cat, want := range wantF1 {
|
||||
if cm, ok := got.ByCategory[cat]; ok {
|
||||
if math.Abs(cm.F1-want) > 1e-9 {
|
||||
t.Errorf("ByCategory[%d].F1 = %.4f, want %.4f", cat, cm.F1, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeModeAgg(t *testing.T) {
|
||||
// Two samples with different question counts:
|
||||
// sample-a: 2 questions, F1 = [0.4, 0.6] → avg 0.5
|
||||
// sample-b: 8 questions, F1 = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] → avg 0.1
|
||||
//
|
||||
// Unweighted (PrintComparison bug): (0.5 + 0.1) / 2 = 0.3
|
||||
// Weighted (correct): (0.4+0.6 + 0.1*8) / 10 = 1.8 / 10 = 0.18
|
||||
results := []EvalResult{
|
||||
{
|
||||
Mode: "test",
|
||||
SampleID: "sample-a",
|
||||
QAResults: []QAResult{
|
||||
{TokenF1: 0.4, HitRate: 0.5},
|
||||
{TokenF1: 0.6, HitRate: 0.7},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mode: "test",
|
||||
SampleID: "sample-b",
|
||||
QAResults: []QAResult{
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
{TokenF1: 0.1, HitRate: 0.2},
|
||||
},
|
||||
},
|
||||
}
|
||||
// Compute per-sample aggregates
|
||||
for i := range results {
|
||||
results[i].Agg = aggregateMetrics(results[i].QAResults)
|
||||
}
|
||||
|
||||
got := computeModeAgg(results)
|
||||
|
||||
// Weighted: (0.4+0.6+0.1*8) / 10 = 1.8/10 = 0.18
|
||||
wantF1 := 0.18
|
||||
if math.Abs(got.OverallF1-wantF1) > 1e-9 {
|
||||
t.Errorf("OverallF1 = %.6f, want %.6f (weighted average)", got.OverallF1, wantF1)
|
||||
}
|
||||
|
||||
// Weighted: (0.5+0.7+0.2*8) / 10 = 2.8/10 = 0.28
|
||||
wantRecall := 0.28
|
||||
if math.Abs(got.OverallHitRate-wantRecall) > 1e-9 {
|
||||
t.Errorf("OverallHitRate = %.6f, want %.6f (weighted average)", got.OverallHitRate, wantRecall)
|
||||
}
|
||||
|
||||
if got.TotalQuestions != 10 {
|
||||
t.Errorf("TotalQuestions = %d, want 10", got.TotalQuestions)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||
)
|
||||
|
||||
// ConvMap stores the mapping from sampleID to seahorse ConversationID.
|
||||
type ConvMap map[string]int64
|
||||
|
||||
// SeahorseIngestResult holds the results of ingesting into seahorse.
|
||||
type SeahorseIngestResult struct {
|
||||
Engine *seahorse.Engine
|
||||
ConvMap ConvMap // sampleID → conversationID
|
||||
}
|
||||
|
||||
// IngestSeahorse loads all LOCOMO samples into a seahorse Engine.
|
||||
// Returns the engine and a mapping from sampleID to conversationID for scoped retrieval.
|
||||
func IngestSeahorse(ctx context.Context, samples []LocomoSample, dbPath string) (*SeahorseIngestResult, error) {
|
||||
noopFn := func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
engine, err := seahorse.NewEngine(seahorse.Config{
|
||||
DBPath: dbPath,
|
||||
}, noopFn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create seahorse engine: %w", err)
|
||||
}
|
||||
|
||||
store := engine.GetRetrieval().Store()
|
||||
convMap := make(ConvMap)
|
||||
|
||||
for si := range samples {
|
||||
sample := &samples[si]
|
||||
sessionKey := "locomo-" + sample.SampleID
|
||||
|
||||
// Check if conversation already exists (idempotent)
|
||||
existing, _ := store.GetConversationBySessionKey(ctx, sessionKey)
|
||||
if existing != nil {
|
||||
convMap[sample.SampleID] = existing.ConversationID
|
||||
log.Printf("Skipping existing sample %s: convID=%d", sample.SampleID, existing.ConversationID)
|
||||
continue
|
||||
}
|
||||
|
||||
turns := GetTurns(sample)
|
||||
|
||||
// Convert turns to seahorse messages
|
||||
msgs := make([]seahorse.Message, 0, len(turns))
|
||||
for _, turn := range turns {
|
||||
content := turn.Speaker + ": " + turn.Text
|
||||
msgs = append(msgs, seahorse.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
TokenCount: len(turn.Text) / 4,
|
||||
})
|
||||
}
|
||||
|
||||
// Ingest all turns for this sample
|
||||
_, err := engine.Ingest(ctx, sessionKey, msgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ingest sample %s: %w", sample.SampleID, err)
|
||||
}
|
||||
|
||||
// Get the conversation ID for scoped retrieval
|
||||
conv, err := store.GetConversationBySessionKey(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation for %s: %w", sample.SampleID, err)
|
||||
}
|
||||
if conv == nil {
|
||||
return nil, fmt.Errorf("conversation not found for %s after ingest", sample.SampleID)
|
||||
}
|
||||
convMap[sample.SampleID] = conv.ConversationID
|
||||
log.Printf("Ingested sample %s: %d turns, convID=%d", sample.SampleID, len(turns), conv.ConversationID)
|
||||
}
|
||||
|
||||
log.Printf("Seahorse ingestion complete: %d samples, %d conversations", len(samples), len(convMap))
|
||||
return &SeahorseIngestResult{
|
||||
Engine: engine,
|
||||
ConvMap: convMap,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||
)
|
||||
|
||||
func TestIngestSeahorseIdempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "test.db")
|
||||
|
||||
// Minimal test data
|
||||
samples := []LocomoSample{
|
||||
{
|
||||
SampleID: "test-1",
|
||||
Conversation: map[string]json.RawMessage{
|
||||
"session_1": json.RawMessage(`[
|
||||
{"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message"},
|
||||
{"speaker":"B","dia_id":"D1:2","text":"another message for testing purposes"}
|
||||
]`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First ingestion
|
||||
result1, err := IngestSeahorse(ctx, samples, dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("first ingest failed: %v", err)
|
||||
}
|
||||
convCount1 := len(result1.ConvMap)
|
||||
result1.Engine.Close()
|
||||
|
||||
// Second ingestion on same DB — should reuse existing data
|
||||
result2, err := IngestSeahorse(ctx, samples, dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("second ingest failed: %v", err)
|
||||
}
|
||||
defer result2.Engine.Close()
|
||||
|
||||
// ConvMap should have same number of entries (no duplicates)
|
||||
if len(result2.ConvMap) != convCount1 {
|
||||
t.Errorf("second ingest convMap has %d entries, want %d (same as first)",
|
||||
len(result2.ConvMap), convCount1)
|
||||
}
|
||||
|
||||
// Verify conversation IDs are the same (reused, not new ones)
|
||||
for id, cid1 := range result1.ConvMap {
|
||||
cid2, ok := result2.ConvMap[id]
|
||||
if !ok {
|
||||
t.Errorf("sample %s missing from second ConvMap", id)
|
||||
continue
|
||||
}
|
||||
if cid2 != cid1 {
|
||||
t.Errorf("sample %s: second ingest got convID %d, want %d (reused)", id, cid2, cid1)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no duplicate messages by counting
|
||||
store := result2.Engine.GetRetrieval().Store()
|
||||
for _, convID := range result2.ConvMap {
|
||||
msgs, err := store.SearchMessages(ctx, seahorse.SearchInput{
|
||||
Pattern: "test",
|
||||
ConversationID: convID,
|
||||
Limit: 100,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("search failed: %v", err)
|
||||
}
|
||||
// Should find exactly 1 message containing "test" (the first turn)
|
||||
if len(msgs) > 2 {
|
||||
t.Errorf("found %d messages for 'test' in conv %d, expected ≤2 (no duplicates)", len(msgs), convID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
)
|
||||
|
||||
// LegacyStore wraps session.SessionManager for legacy baseline.
|
||||
type LegacyStore struct {
|
||||
sm *session.SessionManager
|
||||
}
|
||||
|
||||
// NewLegacyStore creates a new in-memory session manager.
|
||||
func NewLegacyStore() *LegacyStore {
|
||||
return &LegacyStore{
|
||||
sm: session.NewSessionManager(""),
|
||||
}
|
||||
}
|
||||
|
||||
// IngestSample loads all turns from a LOCOMO sample into the legacy session store.
|
||||
func (ls *LegacyStore) IngestSample(sample *LocomoSample) {
|
||||
sessionKey := "locomo-" + sample.SampleID
|
||||
turns := GetTurns(sample)
|
||||
for _, turn := range turns {
|
||||
content := turn.Speaker + ": " + turn.Text
|
||||
ls.sm.AddMessage(sessionKey, "user", content)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHistory returns all messages for a sample's session.
|
||||
func (ls *LegacyStore) GetHistory(sampleID string) []providers.Message {
|
||||
sessionKey := "locomo-" + sampleID
|
||||
return ls.sm.GetHistory(sessionKey)
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LocomoSample represents one conversation sample from the LOCOMO dataset.
|
||||
type LocomoSample struct {
|
||||
SampleID string `json:"sample_id"`
|
||||
Conversation map[string]json.RawMessage `json:"conversation"`
|
||||
QA []LocomoQA `json:"qa"`
|
||||
}
|
||||
|
||||
// LocomoTurn represents a single turn in a conversation.
|
||||
type LocomoTurn struct {
|
||||
Speaker string `json:"speaker"`
|
||||
DiaID string `json:"dia_id"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// LocomoQA represents a question-answer pair with evidence.
|
||||
type LocomoQA struct {
|
||||
Question string `json:"question"`
|
||||
Answer json.RawMessage `json:"answer"` // can be string or int (category 1-4)
|
||||
AdversarialAnswer string `json:"adversarial_answer"` // category 5 only
|
||||
Evidence []string `json:"evidence"`
|
||||
Category int `json:"category"` // 1=single-hop, 2=multi-hop, 3=open-ended, 5=adversarial
|
||||
}
|
||||
|
||||
// AnswerString returns the answer as a string, handling both string and int types.
|
||||
func (qa *LocomoQA) AnswerString() string {
|
||||
// Prefer answer field (category 1-4)
|
||||
if len(qa.Answer) > 0 {
|
||||
var s string
|
||||
if err := json.Unmarshal(qa.Answer, &s); err == nil {
|
||||
return s
|
||||
}
|
||||
var n json.Number
|
||||
if err := json.Unmarshal(qa.Answer, &n); err == nil {
|
||||
return n.String()
|
||||
}
|
||||
return strings.Trim(string(qa.Answer), `"`)
|
||||
}
|
||||
// Fallback to adversarial_answer (category 5)
|
||||
return qa.AdversarialAnswer
|
||||
}
|
||||
|
||||
// LoadDataset reads all JSON files from dataDir and returns parsed samples.
|
||||
func LoadDataset(dataDir string) ([]LocomoSample, error) {
|
||||
entries, err := os.ReadDir(dataDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read data dir %s: %w", dataDir, err)
|
||||
}
|
||||
|
||||
var samples []LocomoSample
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".json") {
|
||||
path := filepath.Join(dataDir, entry.Name())
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file %s: %w", path, err)
|
||||
}
|
||||
var batch []LocomoSample
|
||||
if err := json.Unmarshal(data, &batch); err != nil {
|
||||
return nil, fmt.Errorf("parse file %s: %w", path, err)
|
||||
}
|
||||
samples = append(samples, batch...)
|
||||
}
|
||||
}
|
||||
return samples, nil
|
||||
}
|
||||
|
||||
// GetSessionNames returns sorted session keys (session_1, session_2, ...) from conversation.
|
||||
func GetSessionNames(conv map[string]json.RawMessage) []string {
|
||||
var names []string
|
||||
for k := range conv {
|
||||
if strings.HasPrefix(k, "session_") && !strings.Contains(k, "_date_time") {
|
||||
names = append(names, k)
|
||||
}
|
||||
}
|
||||
sort.Slice(names, func(i, j int) bool {
|
||||
ni := sessionNum(names[i])
|
||||
nj := sessionNum(names[j])
|
||||
return ni < nj
|
||||
})
|
||||
return names
|
||||
}
|
||||
|
||||
func sessionNum(key string) int {
|
||||
// "session_1" → 1, "session_10" → 10
|
||||
parts := strings.SplitN(key, "_", 2)
|
||||
if len(parts) < 2 {
|
||||
return 0
|
||||
}
|
||||
n, _ := strconv.Atoi(parts[1])
|
||||
return n
|
||||
}
|
||||
|
||||
// GetTurns flattens all sessions' turns in chronological order.
|
||||
func GetTurns(sample *LocomoSample) []LocomoTurn {
|
||||
names := GetSessionNames(sample.Conversation)
|
||||
var all []LocomoTurn
|
||||
for _, name := range names {
|
||||
raw, ok := sample.Conversation[name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var turns []LocomoTurn
|
||||
if err := json.Unmarshal(raw, &turns); err != nil {
|
||||
log.Printf("WARNING: unmarshal failed for session %q in sample %s: %v", name, sample.SampleID, err)
|
||||
continue
|
||||
}
|
||||
all = append(all, turns...)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
// GetTurnByDiaID finds a specific turn by dia_id (e.g. "D1:3").
|
||||
func GetTurnByDiaID(sample *LocomoSample, diaID string) *LocomoTurn {
|
||||
turns := GetTurns(sample)
|
||||
for i := range turns {
|
||||
if turns[i].DiaID == diaID {
|
||||
return &turns[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSpeakers returns the two speaker names from conversation metadata.
|
||||
func GetSpeakers(conv map[string]json.RawMessage) (string, string) {
|
||||
var a, b string
|
||||
json.Unmarshal(conv["speaker_a"], &a)
|
||||
json.Unmarshal(conv["speaker_b"], &b)
|
||||
return a, b
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAnswerString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
json string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
"string answer",
|
||||
`{"question":"Q","answer":"Paris","evidence":[],"category":1}`,
|
||||
"Paris",
|
||||
},
|
||||
{
|
||||
"int answer",
|
||||
`{"question":"Q","answer":42,"evidence":[],"category":1}`,
|
||||
"42",
|
||||
},
|
||||
{
|
||||
"adversarial answer (category 5)",
|
||||
`{"question":"Q","evidence":[],"category":5,"adversarial_answer":"self-care is important"}`,
|
||||
"self-care is important",
|
||||
},
|
||||
{
|
||||
"both answer and adversarial_answer present",
|
||||
`{"question":"Q","answer":"normal","evidence":[],"category":5,"adversarial_answer":"adversarial"}`,
|
||||
"normal",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var qa LocomoQA
|
||||
if err := json.Unmarshal([]byte(tt.json), &qa); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
got := qa.AnswerString()
|
||||
if got != tt.want {
|
||||
t.Errorf("AnswerString() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSessionNames(t *testing.T) {
|
||||
conv := map[string]json.RawMessage{
|
||||
"session_2": {},
|
||||
"session_1": {},
|
||||
"session_10": {},
|
||||
"session_1_date_time": {},
|
||||
"speaker_a": {},
|
||||
}
|
||||
names := GetSessionNames(conv)
|
||||
want := []string{"session_1", "session_2", "session_10"}
|
||||
if len(names) != len(want) {
|
||||
t.Fatalf("got %v, want %v", names, want)
|
||||
}
|
||||
for i, n := range names {
|
||||
if n != want[i] {
|
||||
t.Errorf("names[%d] = %q, want %q", i, n, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
var (
|
||||
flagData string
|
||||
flagOut string
|
||||
flagMode string
|
||||
flagBudget int
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Suppress seahorse INFO logs during benchmark
|
||||
logger.SetLevel(logger.WARN)
|
||||
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "membench",
|
||||
Short: "Memory benchmark tool for picoclaw",
|
||||
}
|
||||
|
||||
ingestCmd := &cobra.Command{
|
||||
Use: "ingest",
|
||||
Short: "Load LOCOMO data into storage backends",
|
||||
RunE: runIngest,
|
||||
}
|
||||
ingestCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
|
||||
ingestCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
||||
ingestCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to ingest: legacy, seahorse, or all")
|
||||
|
||||
evalCmd := &cobra.Command{
|
||||
Use: "eval",
|
||||
Short: "Run QA evaluation against ingested data",
|
||||
RunE: runEval,
|
||||
}
|
||||
evalCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
|
||||
evalCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
||||
evalCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to evaluate: legacy, seahorse, or all")
|
||||
evalCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
|
||||
|
||||
reportCmd := &cobra.Command{
|
||||
Use: "report",
|
||||
Short: "Output comparison results from evaluation",
|
||||
RunE: runReport,
|
||||
}
|
||||
reportCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Convenience: eval + report (ingestion is done inline)",
|
||||
RunE: runAll,
|
||||
}
|
||||
runCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
|
||||
runCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
|
||||
runCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to run: legacy, seahorse, or all")
|
||||
runCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
|
||||
|
||||
rootCmd.AddCommand(ingestCmd, evalCmd, reportCmd, runCmd)
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func modesFromFlag() []string {
|
||||
switch strings.ToLower(flagMode) {
|
||||
case "all":
|
||||
return []string{"legacy", "seahorse"}
|
||||
default:
|
||||
return []string{strings.ToLower(flagMode)}
|
||||
}
|
||||
}
|
||||
|
||||
func runIngest(cmd *cobra.Command, args []string) error {
|
||||
if flagData == "" {
|
||||
return fmt.Errorf("--data is required")
|
||||
}
|
||||
modes := modesFromFlag()
|
||||
if len(modes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
samples, err := LoadDataset(flagData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load dataset: %w", err)
|
||||
}
|
||||
log.Printf("Loaded %d samples from %s", len(samples), flagData)
|
||||
|
||||
for _, mode := range modes {
|
||||
switch mode {
|
||||
case "legacy":
|
||||
legacy := NewLegacyStore()
|
||||
for i := range samples {
|
||||
legacy.IngestSample(&samples[i])
|
||||
}
|
||||
log.Printf("legacy: ingested %d samples", len(samples))
|
||||
case "seahorse":
|
||||
dbPath := filepath.Join(flagOut, "seahorse.db")
|
||||
if err := os.MkdirAll(flagOut, 0o755); err != nil {
|
||||
return fmt.Errorf("create out dir: %w", err)
|
||||
}
|
||||
_, err := IngestSeahorse(ctx, samples, dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ingest seahorse: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runEval(cmd *cobra.Command, args []string) error {
|
||||
if flagData == "" {
|
||||
return fmt.Errorf("--data is required")
|
||||
}
|
||||
modes := modesFromFlag()
|
||||
if len(modes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
samples, err := LoadDataset(flagData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load dataset: %w", err)
|
||||
}
|
||||
log.Printf("Loaded %d samples", len(samples))
|
||||
|
||||
var allResults []EvalResult
|
||||
|
||||
for _, mode := range modes {
|
||||
switch mode {
|
||||
case "legacy":
|
||||
legacy := NewLegacyStore()
|
||||
for i := range samples {
|
||||
legacy.IngestSample(&samples[i])
|
||||
}
|
||||
results := EvalLegacy(ctx, samples, legacy, flagBudget)
|
||||
allResults = append(allResults, results...)
|
||||
log.Printf("legacy: evaluated %d samples", len(results))
|
||||
case "seahorse":
|
||||
dbPath := filepath.Join(flagOut, "seahorse.db")
|
||||
ir, err := IngestSeahorse(ctx, samples, dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ingest seahorse: %w", err)
|
||||
}
|
||||
results := EvalSeahorse(ctx, samples, ir, flagBudget)
|
||||
allResults = append(allResults, results...)
|
||||
log.Printf("seahorse: evaluated %d samples", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
if err := SaveResults(allResults, flagOut); err != nil {
|
||||
return fmt.Errorf("save results: %w", err)
|
||||
}
|
||||
if err := SaveAggregated(allResults, flagOut); err != nil {
|
||||
return fmt.Errorf("save aggregated: %w", err)
|
||||
}
|
||||
|
||||
PrintComparison(allResults, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runReport(cmd *cobra.Command, args []string) error {
|
||||
entries, err := os.ReadDir(flagOut)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read out dir: %w", err)
|
||||
}
|
||||
|
||||
var allResults []EvalResult
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "eval_") && strings.HasSuffix(entry.Name(), ".json") {
|
||||
path := filepath.Join(flagOut, entry.Name())
|
||||
var r EvalResult
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
log.Printf("WARN: read %s: %v", path, err)
|
||||
continue
|
||||
}
|
||||
if err := json.Unmarshal(data, &r); err != nil {
|
||||
log.Printf("WARN: parse %s: %v", path, err)
|
||||
continue
|
||||
}
|
||||
allResults = append(allResults, r)
|
||||
}
|
||||
}
|
||||
|
||||
if len(allResults) == 0 {
|
||||
return fmt.Errorf("no eval results found in %s", flagOut)
|
||||
}
|
||||
|
||||
PrintComparison(allResults, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runAll(cmd *cobra.Command, args []string) error {
|
||||
return runEval(cmd, args)
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// diaIDRe matches valid dia_id patterns like "D1:3", "D30:5".
|
||||
var diaIDRe = regexp.MustCompile(`^D(\d+):(\d+)$`)
|
||||
|
||||
// SplitEvidenceIDs splits an evidence string that may contain multiple
|
||||
// semicolon-separated or space-separated dia_ids. Only returns valid IDs.
|
||||
// Example: "D8:6; D9:17" → ["D8:6", "D9:17"]
|
||||
// Example: "D9:1 D4:4 D4:6" → ["D9:1", "D4:4", "D4:6"]
|
||||
func SplitEvidenceIDs(evidence string) []string {
|
||||
if evidence == "" {
|
||||
return nil
|
||||
}
|
||||
// Split on semicolons first, then spaces
|
||||
parts := strings.Split(evidence, ";")
|
||||
var ids []string
|
||||
for _, part := range parts {
|
||||
for _, token := range strings.Fields(strings.TrimSpace(part)) {
|
||||
token = strings.TrimSpace(token)
|
||||
if diaIDRe.MatchString(token) {
|
||||
ids = append(ids, NormalizeDiaID(token))
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// NormalizeDiaID strips leading zeros from the number parts of a dia_id.
|
||||
// "D30:05" → "D30:5", "D10:003" → "D10:3"
|
||||
func NormalizeDiaID(id string) string {
|
||||
m := diaIDRe.FindStringSubmatch(id)
|
||||
if m == nil {
|
||||
return id
|
||||
}
|
||||
session, _ := strconv.Atoi(m[1])
|
||||
turn, _ := strconv.Atoi(m[2])
|
||||
return fmt.Sprintf("D%d:%d", session, turn)
|
||||
}
|
||||
|
||||
// stopwords is a fixed English stopword list for deterministic keyword extraction.
|
||||
var stopwords = map[string]struct{}{
|
||||
"a": {}, "an": {}, "the": {},
|
||||
"is": {}, "are": {}, "was": {}, "were": {},
|
||||
"did": {}, "does": {}, "do": {},
|
||||
"when": {}, "where": {}, "what": {}, "who": {},
|
||||
"how": {}, "why": {},
|
||||
"to": {}, "of": {}, "in": {}, "on": {}, "at": {},
|
||||
"for": {}, "and": {}, "or": {}, "but": {}, "not": {},
|
||||
"it": {}, "this": {}, "that": {}, "with": {},
|
||||
"from": {}, "by": {}, "as": {},
|
||||
"if": {}, "then": {}, "than": {}, "so": {},
|
||||
"no": {}, "yes": {},
|
||||
"all": {}, "any": {}, "each": {}, "every": {},
|
||||
"some": {}, "such": {},
|
||||
"about": {}, "into": {}, "over": {},
|
||||
"after": {}, "before": {}, "between": {},
|
||||
"through": {}, "during": {}, "until": {},
|
||||
"would": {}, "could": {}, "should": {},
|
||||
"may": {}, "might": {}, "can": {},
|
||||
"will": {}, "shall": {}, "must": {},
|
||||
"have": {}, "has": {}, "had": {},
|
||||
"been": {}, "being": {}, "be": {},
|
||||
"go": {}, "went": {}, "gone": {},
|
||||
"i": {}, "you": {}, "me": {}, "my": {}, "your": {},
|
||||
"we": {}, "they": {}, "them": {}, "our": {},
|
||||
"its": {}, "their": {}, "he": {}, "she": {},
|
||||
"his": {}, "her": {},
|
||||
}
|
||||
|
||||
// ExtractKeywords removes stopwords and punctuation, returns individual keywords.
|
||||
// Deterministic: uses fixed stopword list, no LLM.
|
||||
func ExtractKeywords(question string) []string {
|
||||
// Lowercase and split on whitespace/punctuation
|
||||
lower := strings.ToLower(question)
|
||||
words := strings.FieldsFunc(lower, func(r rune) bool {
|
||||
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
|
||||
})
|
||||
|
||||
var keywords []string
|
||||
for _, w := range words {
|
||||
if w == "" || len(w) < 2 {
|
||||
continue
|
||||
}
|
||||
if _, ok := stopwords[w]; ok {
|
||||
continue
|
||||
}
|
||||
keywords = append(keywords, w)
|
||||
if len(keywords) >= 6 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return keywords
|
||||
}
|
||||
|
||||
// TokenOverlapF1 computes token-level F1 between prediction and reference.
|
||||
// Both strings are lowercased and split on whitespace.
|
||||
// NOTE: This metric underestimates quality for multi-hop (cat 2) and
|
||||
// open-ended (cat 3) questions where the gold answer uses different phrasing
|
||||
// than the source text. LLM-Judge scoring is a v2 follow-up.
|
||||
func TokenOverlapF1(prediction, reference string) float64 {
|
||||
predTokens := tokenize(prediction)
|
||||
refTokens := tokenize(reference)
|
||||
|
||||
if len(predTokens) == 0 && len(refTokens) == 0 {
|
||||
return 1.0
|
||||
}
|
||||
if len(predTokens) == 0 || len(refTokens) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// Count matches
|
||||
refCount := map[string]int{}
|
||||
for _, t := range refTokens {
|
||||
refCount[t]++
|
||||
}
|
||||
|
||||
predCount := map[string]int{}
|
||||
for _, t := range predTokens {
|
||||
predCount[t]++
|
||||
}
|
||||
|
||||
var matches float64
|
||||
for token, pc := range predCount {
|
||||
if rc, ok := refCount[token]; ok {
|
||||
matches += float64(min(pc, rc))
|
||||
}
|
||||
}
|
||||
|
||||
precision := matches / float64(len(predTokens))
|
||||
recall := matches / float64(len(refTokens))
|
||||
|
||||
if precision+recall == 0 {
|
||||
return 0.0
|
||||
}
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
}
|
||||
|
||||
func tokenize(s string) []string {
|
||||
lower := strings.ToLower(s)
|
||||
return strings.Fields(lower)
|
||||
}
|
||||
|
||||
// RecallHitRate computes fraction of evidence IDs found in retrieved content.
|
||||
// For each evidence dia_id, looks up the turn text and checks substring match.
|
||||
// Logs a warning for turns with text < 20 chars (higher false-positive risk).
|
||||
func RecallHitRate(evidenceIDs []string, sample *LocomoSample, retrievedContent string) float64 {
|
||||
if len(evidenceIDs) == 0 {
|
||||
return 1.0 // no evidence required = perfect
|
||||
}
|
||||
|
||||
// Expand any multi-ID evidence entries (e.g. "D8:6; D9:17" or "D9:1 D4:4")
|
||||
var expanded []string
|
||||
for _, id := range evidenceIDs {
|
||||
split := SplitEvidenceIDs(id)
|
||||
if split != nil {
|
||||
expanded = append(expanded, split...)
|
||||
}
|
||||
}
|
||||
if len(expanded) == 0 {
|
||||
log.Printf("WARNING: no valid dia_ids after expanding evidence %v", evidenceIDs)
|
||||
return float64(0) / float64(len(evidenceIDs))
|
||||
}
|
||||
|
||||
// Build turn index once (avoids re-parsing JSON per ID)
|
||||
turns := GetTurns(sample)
|
||||
turnMap := make(map[string]*LocomoTurn, len(turns))
|
||||
for i := range turns {
|
||||
turnMap[turns[i].DiaID] = &turns[i]
|
||||
}
|
||||
|
||||
lowerRetrieved := strings.ToLower(retrievedContent)
|
||||
found := 0
|
||||
resolvable := 0
|
||||
for _, diaID := range expanded {
|
||||
turn, ok := turnMap[diaID]
|
||||
if !ok {
|
||||
log.Printf("WARNING: dia_id %q not found in sample %s", diaID, sample.SampleID)
|
||||
continue
|
||||
}
|
||||
resolvable++
|
||||
if len(turn.Text) < 20 {
|
||||
log.Printf("WARNING: short turn text (%d chars) for dia_id %s: %q",
|
||||
len(turn.Text), diaID, turn.Text)
|
||||
}
|
||||
if strings.Contains(lowerRetrieved, strings.ToLower(turn.Text)) {
|
||||
found++
|
||||
}
|
||||
}
|
||||
if resolvable == 0 {
|
||||
return 0.0 // no resolvable evidence = can't evaluate
|
||||
}
|
||||
return float64(found) / float64(resolvable)
|
||||
}
|
||||
|
||||
// BudgetTruncate truncates messages to fit within a token budget.
|
||||
// Returns the truncated messages and total token count.
|
||||
func BudgetTruncate(messages []string, budgetTokens int) ([]string, int) {
|
||||
var result []string
|
||||
total := 0
|
||||
// Walk from the front (best first) and keep until budget exhausted.
|
||||
for i := 0; i < len(messages); i++ {
|
||||
tokens := len(messages[i]) / 4
|
||||
if total+tokens > budgetTokens && len(result) > 0 {
|
||||
break
|
||||
}
|
||||
result = append(result, messages[i])
|
||||
total += tokens
|
||||
}
|
||||
return result, total
|
||||
}
|
||||
|
||||
// StringListToContent joins a list of strings into a single content string.
|
||||
func StringListToContent(parts []string) string {
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSplitEvidenceIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{"D1:3", []string{"D1:3"}},
|
||||
{"D8:6; D9:17", []string{"D8:6", "D9:17"}},
|
||||
{"D9:1 D4:4 D4:6", []string{"D9:1", "D4:4", "D4:6"}},
|
||||
{"D22:1 D22:2 D9:10 D9:11", []string{"D22:1", "D22:2", "D9:10", "D9:11"}},
|
||||
{"D21:18 D21:22 D11:15 D11:19", []string{"D21:18", "D21:22", "D11:15", "D11:19"}},
|
||||
{"D30:05", []string{"D30:5"}},
|
||||
{"D", nil},
|
||||
{"D:", nil},
|
||||
{"", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got := SplitEvidenceIDs(tt.input)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("SplitEvidenceIDs(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeDiaID(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"D1:3", "D1:3"},
|
||||
{"D30:05", "D30:5"},
|
||||
{"D10:003", "D10:3"},
|
||||
{"D1:0", "D1:0"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := NormalizeDiaID(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("NormalizeDiaID(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenOverlapF1(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prediction string
|
||||
reference string
|
||||
want float64
|
||||
}{
|
||||
{"exact match", "hello world", "hello world", 1.0},
|
||||
{"no overlap", "foo bar", "baz qux", 0.0},
|
||||
{"empty both", "", "", 1.0},
|
||||
{"empty prediction", "", "hello", 0.0},
|
||||
{"empty reference", "hello", "", 0.0},
|
||||
{"partial overlap", "the cat sat on the mat", "the cat on the floor", 8.0 / 11.0},
|
||||
{"case insensitive", "Hello World", "hello world", 1.0},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := TokenOverlapF1(tt.prediction, tt.reference)
|
||||
if math.Abs(got-tt.want) > 1e-9 {
|
||||
t.Errorf("TokenOverlapF1(%q, %q) = %.4f, want %.4f",
|
||||
tt.prediction, tt.reference, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBudgetTruncate(t *testing.T) {
|
||||
t.Run("within budget returns all", func(t *testing.T) {
|
||||
msgs := []string{"short", "message", "here"}
|
||||
result, total := BudgetTruncate(msgs, 1000)
|
||||
if len(result) != 3 {
|
||||
t.Errorf("expected 3 messages, got %d", len(result))
|
||||
}
|
||||
if total == 0 {
|
||||
t.Error("expected non-zero token count")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("over budget keeps best first", func(t *testing.T) {
|
||||
msgs := []string{
|
||||
"best message that is quite long and takes up tokens",
|
||||
"good message also fairly long content",
|
||||
"worst short",
|
||||
}
|
||||
result, _ := BudgetTruncate(msgs, 5) // very small budget
|
||||
if len(result) == 0 {
|
||||
t.Fatal("expected at least one message")
|
||||
}
|
||||
// Best-ranked (first) should be kept
|
||||
if result[0] != "best message that is quite long and takes up tokens" {
|
||||
t.Errorf("expected best message kept first, got %q", result[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("over budget keeps best ranked first", func(t *testing.T) {
|
||||
// Messages are sorted by bm25 rank ascending (best/most-negative first).
|
||||
// When budget is insufficient, BudgetTruncate must keep the front
|
||||
// (best-ranked) messages, not the tail (worst-ranked).
|
||||
msgs := []string{
|
||||
"best ranked message with some content here",
|
||||
"second best message also has content",
|
||||
"third message here too",
|
||||
"worst ranked short",
|
||||
}
|
||||
// Budget only fits ~1 message (~10 tokens per message, budget=12)
|
||||
result, _ := BudgetTruncate(msgs, 12)
|
||||
if len(result) == 0 {
|
||||
t.Fatal("expected at least one message")
|
||||
}
|
||||
if result[0] != "best ranked message with some content here" {
|
||||
t.Errorf("expected best-ranked (first) message kept, got %q", result[0])
|
||||
}
|
||||
// Worst-ranked (last) must NOT appear
|
||||
for _, m := range result {
|
||||
if m == "worst ranked short" {
|
||||
t.Error("worst-ranked message should have been truncated")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves original order", func(t *testing.T) {
|
||||
msgs := []string{"alpha", "beta", "gamma"}
|
||||
result, _ := BudgetTruncate(msgs, 100)
|
||||
for i, got := range result {
|
||||
if got != msgs[i] {
|
||||
t.Errorf("result[%d] = %q, want %q", i, got, msgs[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty input", func(t *testing.T) {
|
||||
result, total := BudgetTruncate(nil, 100)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 messages, got %d", len(result))
|
||||
}
|
||||
if total != 0 {
|
||||
t.Errorf("expected 0 tokens, got %d", total)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRecallHitRate(t *testing.T) {
|
||||
// Build a sample with known turns
|
||||
sample := &LocomoSample{
|
||||
SampleID: "test-sample",
|
||||
Conversation: map[string]json.RawMessage{
|
||||
"session_1": json.RawMessage(`[
|
||||
{"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message with enough length"},
|
||||
{"speaker":"B","dia_id":"D1:2","text":"another message for testing recall computation purposes here"},
|
||||
{"speaker":"A","dia_id":"D1:3","text":"third turn with some more content to test"}
|
||||
]`),
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("all evidence found", func(t *testing.T) {
|
||||
retrieved := "hello world this is a test message with enough length another message for testing recall computation purposes here"
|
||||
got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
|
||||
if math.Abs(got-1.0) > 1e-9 {
|
||||
t.Errorf("RecallHitRate all found = %.4f, want 1.0", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("partial evidence found", func(t *testing.T) {
|
||||
retrieved := "hello world this is a test message with enough length"
|
||||
got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
|
||||
if math.Abs(got-0.5) > 1e-9 {
|
||||
t.Errorf("RecallHitRate partial = %.4f, want 0.5", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no evidence required", func(t *testing.T) {
|
||||
got := RecallHitRate(nil, sample, "anything")
|
||||
if got != 1.0 {
|
||||
t.Errorf("RecallHitRate no evidence = %.4f, want 1.0", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing turn excluded from denominator", func(t *testing.T) {
|
||||
// D1:1 is found, D99:1 does not exist in sample
|
||||
// Should only count resolvable turns in denominator
|
||||
retrieved := "hello world this is a test message with enough length"
|
||||
got := RecallHitRate([]string{"D1:1", "D99:1"}, sample, retrieved)
|
||||
if math.Abs(got-1.0) > 1e-9 {
|
||||
t.Errorf("RecallHitRate missing turn = %.4f, want 1.0 (unresolvable excluded)", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractKeywords(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want []string
|
||||
}{
|
||||
{"simple", "What is the capital of France", []string{"capital", "france"}},
|
||||
{
|
||||
"stops removed",
|
||||
"Who is the president of the United States",
|
||||
[]string{"president", "united", "states"},
|
||||
},
|
||||
{
|
||||
"max 6 keywords",
|
||||
"one two three four five six seven eight nine ten",
|
||||
[]string{"one", "two", "three", "four", "five", "six"},
|
||||
},
|
||||
{"short words filtered", "I am a go to the store", []string{"am", "store"}},
|
||||
{"empty", "", nil},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ExtractKeywords(tt.input)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("ExtractKeywords(%q) = %v (len %d), want %v (len %d)",
|
||||
tt.input, got, len(got), tt.want, len(tt.want))
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -29,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
func NewPicoclawCommand() *cobra.Command {
|
||||
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, config.GetVersion())
|
||||
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant %s\n\n", internal.Logo, config.GetVersion())
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "picoclaw",
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestNewPicoclawCommand(t *testing.T) {
|
||||
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, config.GetVersion())
|
||||
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant %s\n\n", internal.Logo, config.GetVersion())
|
||||
|
||||
assert.Equal(t, "picoclaw", cmd.Use)
|
||||
assert.Equal(t, short, cmd.Short)
|
||||
|
||||
@@ -9,4 +9,4 @@ COPY $TARGETPLATFORM/picoclaw-launcher /usr/local/bin/picoclaw-launcher
|
||||
COPY $TARGETPLATFORM/picoclaw-launcher-tui /usr/local/bin/picoclaw-launcher-tui
|
||||
|
||||
ENTRYPOINT ["picoclaw-launcher"]
|
||||
CMD ["-public", "-no-browser"]
|
||||
CMD ["-console", "-public", "-no-browser"]
|
||||
|
||||
@@ -45,8 +45,11 @@ services:
|
||||
- launcher
|
||||
environment:
|
||||
- PICOCLAW_GATEWAY_HOST=0.0.0.0
|
||||
# Set a fixed dashboard token instead of a random one each restart.
|
||||
# If not set, a random token is generated and printed to the console on startup.
|
||||
#- PICOCLAW_LAUNCHER_TOKEN=your-secret-token-here
|
||||
ports:
|
||||
- "127.0.0.1:18800:18800"
|
||||
- "127.0.0.1:18790:18790"
|
||||
- "18800:18800"
|
||||
- "18790:18790"
|
||||
volumes:
|
||||
- ./data:/root/.picoclaw
|
||||
|
||||
@@ -28,6 +28,69 @@ The currently exposed synchronous hook points are:
|
||||
|
||||
Everything else is exposed as read-only events.
|
||||
|
||||
## Hook Actions
|
||||
|
||||
Hooks can return different actions to control the flow:
|
||||
|
||||
| Action | Applicable Stages | Effect |
|
||||
| --- | --- | --- |
|
||||
| `continue` | All interceptors | Pass through without modification |
|
||||
| `modify` | `before_llm`, `after_llm`, `before_tool`, `after_tool` | Modify request/response and continue |
|
||||
| `respond` | `before_tool` | Return a tool result directly, skip actual tool execution |
|
||||
| `deny_tool` | `before_tool` | Deny tool execution, return error message |
|
||||
| `abort_turn` | All interceptors | Abort the current turn |
|
||||
| `hard_abort` | All interceptors | Force stop the entire agent loop |
|
||||
|
||||
### The `respond` Action
|
||||
|
||||
The `respond` action is special: it allows a `before_tool` hook to provide the tool result directly, skipping the actual tool execution. This is useful for:
|
||||
|
||||
1. **Plugin tool injection**: External hooks can implement tools without registering them in the tool registry
|
||||
2. **Tool result caching**: Return cached results for repeated tool calls
|
||||
3. **Tool mocking**: Return mock results for testing purposes
|
||||
|
||||
When a hook returns `respond` with a `HookResult`, the agent loop:
|
||||
1. Skips the actual tool execution
|
||||
2. Uses the provided result as if the tool had executed
|
||||
3. Continues the turn normally with the result
|
||||
|
||||
Example (Go in-process hook):
|
||||
|
||||
```go
|
||||
func (h *MyHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *agent.ToolCallHookRequest,
|
||||
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
|
||||
if call.Tool == "my_plugin_tool" {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: "Plugin tool executed successfully",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
|
||||
}
|
||||
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
|
||||
}
|
||||
```
|
||||
|
||||
Example (Python process hook):
|
||||
|
||||
```python
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
if tool == "my_plugin_tool":
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "Plugin tool executed successfully",
|
||||
"silent": False,
|
||||
"is_error": False
|
||||
}
|
||||
}
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
## Execution Order
|
||||
|
||||
`HookManager` sorts hooks like this:
|
||||
|
||||
@@ -28,6 +28,69 @@
|
||||
|
||||
其余 lifecycle 通过事件形式只读暴露。
|
||||
|
||||
## Hook Actions
|
||||
|
||||
Hook 可以返回不同的 action 来控制流程:
|
||||
|
||||
| Action | 适用阶段 | 效果 |
|
||||
| --- | --- | --- |
|
||||
| `continue` | 所有拦截型 | 放行,不做修改 |
|
||||
| `modify` | `before_llm`, `after_llm`, `before_tool`, `after_tool` | 改写请求/响应后放行 |
|
||||
| `respond` | `before_tool` | 直接返回工具结果,跳过实际工具执行 |
|
||||
| `deny_tool` | `before_tool` | 拒绝工具执行,返回错误信息 |
|
||||
| `abort_turn` | 所有拦截型 | 中止当前 turn |
|
||||
| `hard_abort` | 所有拦截型 | 强制终止整个 agent loop |
|
||||
|
||||
### `respond` Action
|
||||
|
||||
`respond` action 是特殊的:它允许 `before_tool` hook 直接提供工具结果,跳过实际工具执行。适用于:
|
||||
|
||||
1. **插件工具注入**:外部 hook 可以实现工具,无需在 ToolRegistry 注册
|
||||
2. **工具结果缓存**:对重复调用返回缓存结果
|
||||
3. **工具模拟**:测试时返回模拟结果
|
||||
|
||||
当 hook 返回 `respond` 并携带 `HookResult` 时,agent loop 会:
|
||||
1. 跳过实际工具执行
|
||||
2. 使用提供的结果作为工具执行结果
|
||||
3. 正常继续 turn 流程
|
||||
|
||||
示例(Go 进程内 hook):
|
||||
|
||||
```go
|
||||
func (h *MyHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *agent.ToolCallHookRequest,
|
||||
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
|
||||
if call.Tool == "my_plugin_tool" {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: "Plugin tool executed successfully",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
|
||||
}
|
||||
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
|
||||
}
|
||||
```
|
||||
|
||||
示例(Python process hook):
|
||||
|
||||
```python
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
if tool == "my_plugin_tool":
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "Plugin tool executed successfully",
|
||||
"silent": False,
|
||||
"is_error": False
|
||||
}
|
||||
}
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
## 执行顺序
|
||||
|
||||
HookManager 的排序规则是:
|
||||
|
||||
@@ -0,0 +1,568 @@
|
||||
# Hook JSON-RPC Protocol Details
|
||||
|
||||
All hooks use `JSON-RPC 2.0` format, with one JSON message per line, transmitted via stdio.
|
||||
|
||||
---
|
||||
|
||||
## Basic Protocol Structure
|
||||
|
||||
### Request (PicoClaw → Hook)
|
||||
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"method":"hook.xxx","params":{...}}
|
||||
```
|
||||
|
||||
### Response (Hook → PicoClaw)
|
||||
|
||||
Success:
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"result":{...}}
|
||||
```
|
||||
|
||||
Error:
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"error message"}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. `hook.hello` (Handshake)
|
||||
|
||||
Handshake must be completed at startup, otherwise the hook process will be terminated.
|
||||
|
||||
### Request
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "hook.hello",
|
||||
"params": {
|
||||
"name": "py_review_gate",
|
||||
"version": 1,
|
||||
"modes": ["observe", "tool", "approve"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `name` | hook name (from configuration) |
|
||||
| `version` | protocol version, currently `1` |
|
||||
| `modes` | capability modes supported by the hook |
|
||||
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"ok": true,
|
||||
"name": "python-review-gate"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. `hook.before_llm`
|
||||
|
||||
Triggered before sending request to LLM. Can be used to inject tools.
|
||||
|
||||
### Request
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "hook.before_llm",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"ParentTurnID": "",
|
||||
"SessionKey": "session-1",
|
||||
"Iteration": 0,
|
||||
"TracePath": "runTurn",
|
||||
"Source": "turn.llm.request"
|
||||
},
|
||||
"model": "claude-sonnet",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "echo",
|
||||
"description": "echo text",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
}
|
||||
],
|
||||
"options": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1",
|
||||
"graceful_terminal": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `meta` | event metadata for tracing |
|
||||
| `model` | requested model name |
|
||||
| `messages` | conversation history |
|
||||
| `tools` | list of available tool definitions |
|
||||
| `options` | LLM parameters (temperature, max_tokens, etc.) |
|
||||
| `channel` | request source channel |
|
||||
| `chat_id` | session ID |
|
||||
|
||||
### Response (Tool Injection Example)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"result": {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": "claude-sonnet",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "echo",
|
||||
"description": "echo",
|
||||
"parameters": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "my_plugin_tool",
|
||||
"description": "Plugin injected tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `action` | decision action (see table below) |
|
||||
| `request` | modified request object |
|
||||
|
||||
---
|
||||
|
||||
## 3. `hook.after_llm`
|
||||
|
||||
Triggered after receiving LLM response. Can modify response content.
|
||||
|
||||
### Request
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "hook.after_llm",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"model": "claude-sonnet",
|
||||
"response": {
|
||||
"role": "assistant",
|
||||
"content": "Hi!",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc-1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "echo",
|
||||
"arguments": "{\"text\":\"hi\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"action": "continue"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. `hook.before_tool`
|
||||
|
||||
Triggered before tool execution. Can modify tool name and arguments, deny execution, or return result directly.
|
||||
|
||||
### Request
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "hook.before_tool",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"tool": "echo_text",
|
||||
"arguments": {
|
||||
"text": "hello"
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `tool` | tool name |
|
||||
| `arguments` | tool arguments |
|
||||
|
||||
### Response (Modify Arguments)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"result": {
|
||||
"action": "modify",
|
||||
"call": {
|
||||
"tool": "echo_text",
|
||||
"arguments": {
|
||||
"text": "modified hello"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response (Deny Execution)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"result": {
|
||||
"action": "deny_tool",
|
||||
"reason": "Invalid arguments"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response (Return Result Directly - respond)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"result": {
|
||||
"action": "respond",
|
||||
"call": {
|
||||
"tool": "my_plugin_tool",
|
||||
"arguments": {
|
||||
"query": "hello"
|
||||
}
|
||||
},
|
||||
"result": {
|
||||
"for_llm": "Plugin tool executed successfully",
|
||||
"for_user": "",
|
||||
"silent": false,
|
||||
"is_error": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
The `respond` action allows hooks to return tool results directly, skipping actual tool execution. Use cases:
|
||||
1. **Plugin tool injection**: External hooks can implement tools without registering in ToolRegistry
|
||||
2. **Tool result caching**: Return cached results for repeated calls
|
||||
3. **Tool mocking**: Return mock results during testing
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `action` | must be `respond` |
|
||||
| `call` | modified call information (optional) |
|
||||
| `result` | tool result to return directly |
|
||||
|
||||
---
|
||||
|
||||
## 5. `hook.after_tool`
|
||||
|
||||
Triggered after tool execution completes. Can modify the result returned to LLM.
|
||||
|
||||
### Request
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "hook.after_tool",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"tool": "echo_text",
|
||||
"arguments": {
|
||||
"text": "hello"
|
||||
},
|
||||
"result": {
|
||||
"for_llm": "echoed: hello",
|
||||
"for_user": "",
|
||||
"silent": false,
|
||||
"is_error": false,
|
||||
"async": false,
|
||||
"media": [],
|
||||
"artifact_tags": [],
|
||||
"response_handled": false
|
||||
},
|
||||
"duration": 15000000,
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `result.for_llm` | content returned to LLM |
|
||||
| `result.for_user` | content sent to user |
|
||||
| `result.silent` | whether silent (not sent to user) |
|
||||
| `result.is_error` | whether it's an error |
|
||||
| `result.async` | whether executed asynchronously |
|
||||
| `result.media` | list of media references |
|
||||
| `result.artifact_tags` | local artifact path tags |
|
||||
| `result.response_handled` | whether response has been handled |
|
||||
| `duration` | execution time (nanoseconds) |
|
||||
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"result": {
|
||||
"action": "continue"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. `hook.approve_tool`
|
||||
|
||||
Approval hook for deciding whether to allow execution of sensitive tools.
|
||||
|
||||
### Request
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"method": "hook.approve_tool",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"tool": "bash",
|
||||
"arguments": {
|
||||
"command": "rm -rf /"
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response (Approved)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"result": {
|
||||
"approved": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Response (Denied)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"result": {
|
||||
"approved": false,
|
||||
"reason": "Dangerous command, execution denied"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. `hook.event` (notification)
|
||||
|
||||
Observer event, broadcast only, no response required. `id` is `0` or absent.
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "hook.event",
|
||||
"params": {
|
||||
"Kind": "tool_exec_start",
|
||||
"Meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1"
|
||||
},
|
||||
"Payload": {
|
||||
"Tool": "echo_text",
|
||||
"Arguments": {"text": "hello"}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Common `Kind` values:
|
||||
- `turn_start` / `turn_end`
|
||||
- `llm_request` / `llm_response`
|
||||
- `tool_exec_start` / `tool_exec_end` / `tool_exec_skipped`
|
||||
- `steering_injected`
|
||||
- `interrupt_received`
|
||||
- `error`
|
||||
|
||||
---
|
||||
|
||||
## Action Options
|
||||
|
||||
| action | Applicable hooks | Effect |
|
||||
|--------|-----------------|--------|
|
||||
| `continue` | All interceptor types | Pass through without modification |
|
||||
| `modify` | `before_llm`, `before_tool`, `after_llm`, `after_tool` | Modify request/response and pass through |
|
||||
| `respond` | `before_tool` | Return tool result directly, skip actual execution. **Note: AfterTool is NOT called (design decision - respond provides final answer).** |
|
||||
| `deny_tool` | `before_tool` | Deny tool execution |
|
||||
| `abort_turn` | All interceptor types | Abort current turn, return error |
|
||||
| `hard_abort` | All interceptor types | Force stop entire agent loop |
|
||||
|
||||
---
|
||||
|
||||
## Complete Flow Example
|
||||
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"method":"hook.hello","params":{"name":"my_hook","version":1,"modes":["tool","approve"]}}
|
||||
{"jsonrpc":"2.0","id":1,"result":{"ok":true,"name":"my_hook"}}
|
||||
{"jsonrpc":"2.0","id":2,"method":"hook.before_llm","params":{"model":"claude-sonnet","messages":[{"role":"user","content":"hello"}],"tools":[]}}
|
||||
{"jsonrpc":"2.0","id":2,"result":{"action":"continue"}}
|
||||
{"jsonrpc":"2.0","id":3,"method":"hook.before_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
|
||||
{"jsonrpc":"2.0","id":3,"result":{"action":"continue"}}
|
||||
{"jsonrpc":"2.0","id":4,"method":"hook.approve_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
|
||||
{"jsonrpc":"2.0","id":4,"result":{"approved":true}}
|
||||
{"jsonrpc":"2.0","id":5,"method":"hook.after_tool","params":{"tool":"bash","arguments":{"command":"ls"},"result":{"for_llm":"file1.txt\nfile2.txt"},"duration":5000000}}
|
||||
{"jsonrpc":"2.0","id":5,"result":{"action":"continue"}}
|
||||
{"jsonrpc":"2.0","id":6,"method":"hook.after_llm","params":{"model":"claude-sonnet","response":{"role":"assistant","content":"Files listed"}}}
|
||||
{"jsonrpc":"2.0","id":6,"result":{"action":"continue"}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Plugin Tool Injection via `before_llm` and `before_tool`
|
||||
|
||||
Standard flow for plugin tool injection:
|
||||
|
||||
1. In `before_llm`, inject tool definition to let LLM know the tool is available
|
||||
2. In `before_tool`, use `respond` action to return tool execution result directly
|
||||
|
||||
### `before_llm` Inject Tool Definition
|
||||
|
||||
```python
|
||||
def handle_before_llm(params: dict) -> dict:
|
||||
tools = params.get("tools", [])
|
||||
|
||||
# Add plugin tool definition
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "my_plugin_tool",
|
||||
"description": "Plugin provided tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {"type": "string", "description": "Input content"}
|
||||
},
|
||||
"required": ["input"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": params["model"],
|
||||
"messages": params["messages"],
|
||||
"tools": tools,
|
||||
"options": params.get("options", {})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `before_tool` Return Execution Result
|
||||
|
||||
```python
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
|
||||
if tool == "my_plugin_tool":
|
||||
# Implement tool logic here
|
||||
args = params.get("arguments", {})
|
||||
input_text = args.get("input", "")
|
||||
|
||||
# Return result directly, no need to register in ToolRegistry
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": f"Plugin tool executed successfully, input: {input_text}",
|
||||
"silent": False,
|
||||
"is_error": False
|
||||
}
|
||||
}
|
||||
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
This way, external hooks can fully implement plugin tools without registering any tool implementation inside PicoClaw.
|
||||
@@ -0,0 +1,568 @@
|
||||
# Hook JSON-RPC 协议详解
|
||||
|
||||
所有 hook 使用 `JSON-RPC 2.0` 格式,每行一个 JSON 消息,通过 stdio 传输。
|
||||
|
||||
---
|
||||
|
||||
## 基础协议结构
|
||||
|
||||
### 请求(PicoClaw → Hook)
|
||||
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"method":"hook.xxx","params":{...}}
|
||||
```
|
||||
|
||||
### 响应(Hook → PicoClaw)
|
||||
|
||||
成功:
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"result":{...}}
|
||||
```
|
||||
|
||||
错误:
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"错误信息"}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. `hook.hello`(握手)
|
||||
|
||||
启动时必须完成握手,否则 hook 进程会被终止。
|
||||
|
||||
### 请求
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "hook.hello",
|
||||
"params": {
|
||||
"name": "py_review_gate",
|
||||
"version": 1,
|
||||
"modes": ["observe", "tool", "approve"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `name` | hook 名称(来自配置) |
|
||||
| `version` | 协议版本,当前为 `1` |
|
||||
| `modes` | hook 支持的能力模式 |
|
||||
|
||||
### 响应
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"ok": true,
|
||||
"name": "python-review-gate"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. `hook.before_llm`
|
||||
|
||||
在发送请求给 LLM 之前触发。可用于注入工具。
|
||||
|
||||
### 请求
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "hook.before_llm",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"ParentTurnID": "",
|
||||
"SessionKey": "session-1",
|
||||
"Iteration": 0,
|
||||
"TracePath": "runTurn",
|
||||
"Source": "turn.llm.request"
|
||||
},
|
||||
"model": "claude-sonnet",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "echo",
|
||||
"description": "echo text",
|
||||
"parameters": {"type": "object"}
|
||||
}
|
||||
}
|
||||
],
|
||||
"options": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1",
|
||||
"graceful_terminal": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `meta` | 事件元数据,用于追踪 |
|
||||
| `model` | 请求的模型名称 |
|
||||
| `messages` | 对话历史 |
|
||||
| `tools` | 可用工具定义列表 |
|
||||
| `options` | LLM 参数(temperature、max_tokens 等) |
|
||||
| `channel` | 请求来源通道 |
|
||||
| `chat_id` | 会话 ID |
|
||||
|
||||
### 响应(注入工具示例)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"result": {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": "claude-sonnet",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "echo",
|
||||
"description": "echo",
|
||||
"parameters": {}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "my_plugin_tool",
|
||||
"description": "插件注入的工具",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `action` | 决策动作(见下表) |
|
||||
| `request` | 修改后的请求对象 |
|
||||
|
||||
---
|
||||
|
||||
## 3. `hook.after_llm`
|
||||
|
||||
在收到 LLM 响应后触发。可修改响应内容。
|
||||
|
||||
### 请求
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"method": "hook.after_llm",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"model": "claude-sonnet",
|
||||
"response": {
|
||||
"role": "assistant",
|
||||
"content": "Hi!",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc-1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "echo",
|
||||
"arguments": "{\"text\":\"hi\"}"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 响应
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 3,
|
||||
"result": {
|
||||
"action": "continue"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. `hook.before_tool`
|
||||
|
||||
在执行工具前触发。可修改工具名称和参数,或拒绝执行,或直接返回结果。
|
||||
|
||||
### 请求
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"method": "hook.before_tool",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"tool": "echo_text",
|
||||
"arguments": {
|
||||
"text": "hello"
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `tool` | 工具名称 |
|
||||
| `arguments` | 工具参数 |
|
||||
|
||||
### 响应(改写参数)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"result": {
|
||||
"action": "modify",
|
||||
"call": {
|
||||
"tool": "echo_text",
|
||||
"arguments": {
|
||||
"text": "modified hello"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 响应(拒绝执行)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"result": {
|
||||
"action": "deny_tool",
|
||||
"reason": "参数不合法"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 响应(直接返回结果 - respond)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 4,
|
||||
"result": {
|
||||
"action": "respond",
|
||||
"call": {
|
||||
"tool": "my_plugin_tool",
|
||||
"arguments": {
|
||||
"query": "hello"
|
||||
}
|
||||
},
|
||||
"result": {
|
||||
"for_llm": "Plugin tool executed successfully",
|
||||
"for_user": "",
|
||||
"silent": false,
|
||||
"is_error": false
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`respond` action 允许 hook 直接返回工具结果,跳过实际工具执行。适用于:
|
||||
1. **插件工具注入**:外部 hook 可实现工具,无需在 ToolRegistry 注册
|
||||
2. **工具结果缓存**:对重复调用返回缓存结果
|
||||
3. **工具模拟**:测试时返回模拟结果
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `action` | 必须为 `respond` |
|
||||
| `call` | 修改后的调用信息(可选) |
|
||||
| `result` | 直接返回的工具结果 |
|
||||
|
||||
---
|
||||
|
||||
## 5. `hook.after_tool`
|
||||
|
||||
在工具执行完成后触发。可修改返回给 LLM 的结果。
|
||||
|
||||
### 请求
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "hook.after_tool",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"tool": "echo_text",
|
||||
"arguments": {
|
||||
"text": "hello"
|
||||
},
|
||||
"result": {
|
||||
"for_llm": "echoed: hello",
|
||||
"for_user": "",
|
||||
"silent": false,
|
||||
"is_error": false,
|
||||
"async": false,
|
||||
"media": [],
|
||||
"artifact_tags": [],
|
||||
"response_handled": false
|
||||
},
|
||||
"duration": 15000000,
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `result.for_llm` | 返回给 LLM 的内容 |
|
||||
| `result.for_user` | 发送给用户的内容 |
|
||||
| `result.silent` | 是否静默(不发送给用户) |
|
||||
| `result.is_error` | 是否为错误 |
|
||||
| `result.async` | 是否异步执行 |
|
||||
| `result.media` | 媒体引用列表 |
|
||||
| `result.artifact_tags` | 本地产物路径标签 |
|
||||
| `result.response_handled` | 是否已处理响应 |
|
||||
| `duration` | 执行耗时(纳秒) |
|
||||
|
||||
### 响应
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"result": {
|
||||
"action": "continue"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. `hook.approve_tool`
|
||||
|
||||
审批型 hook,用于决定是否允许执行敏感工具。
|
||||
|
||||
### 请求
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"method": "hook.approve_tool",
|
||||
"params": {
|
||||
"meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1",
|
||||
"SessionKey": "session-1"
|
||||
},
|
||||
"tool": "bash",
|
||||
"arguments": {
|
||||
"command": "rm -rf /"
|
||||
},
|
||||
"channel": "cli",
|
||||
"chat_id": "chat-1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 响应(批准)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"result": {
|
||||
"approved": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 响应(拒绝)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 6,
|
||||
"result": {
|
||||
"approved": false,
|
||||
"reason": "危险命令,禁止执行"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. `hook.event`(notification)
|
||||
|
||||
观察型事件,仅广播,无需响应。`id` 为 `0` 或不存在。
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"method": "hook.event",
|
||||
"params": {
|
||||
"Kind": "tool_exec_start",
|
||||
"Meta": {
|
||||
"AgentID": "agent-1",
|
||||
"TurnID": "turn-1"
|
||||
},
|
||||
"Payload": {
|
||||
"Tool": "echo_text",
|
||||
"Arguments": {"text": "hello"}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
常见 `Kind` 值:
|
||||
- `turn_start` / `turn_end`
|
||||
- `llm_request` / `llm_response`
|
||||
- `tool_exec_start` / `tool_exec_end` / `tool_exec_skipped`
|
||||
- `steering_injected`
|
||||
- `interrupt_received`
|
||||
- `error`
|
||||
|
||||
---
|
||||
|
||||
## action 可选值
|
||||
|
||||
| action | 适用 hook | 效果 |
|
||||
|--------|----------|------|
|
||||
| `continue` | 所有拦截型 | 放行,不做修改 |
|
||||
| `modify` | `before_llm`, `before_tool`, `after_llm`, `after_tool` | 改写请求/响应后放行 |
|
||||
| `respond` | `before_tool` | 直接返回工具结果,跳过实际执行 |
|
||||
| `deny_tool` | `before_tool` | 拒绝执行该工具 |
|
||||
| `abort_turn` | 所有拦截型 | 中止当前 turn,返回错误 |
|
||||
| `hard_abort` | 所有拦截型 | 强制终止整个 agent loop |
|
||||
|
||||
---
|
||||
|
||||
## 完整流程示例
|
||||
|
||||
```json
|
||||
{"jsonrpc":"2.0","id":1,"method":"hook.hello","params":{"name":"my_hook","version":1,"modes":["tool","approve"]}}
|
||||
{"jsonrpc":"2.0","id":1,"result":{"ok":true,"name":"my_hook"}}
|
||||
{"jsonrpc":"2.0","id":2,"method":"hook.before_llm","params":{"model":"claude-sonnet","messages":[{"role":"user","content":"hello"}],"tools":[]}}
|
||||
{"jsonrpc":"2.0","id":2,"result":{"action":"continue"}}
|
||||
{"jsonrpc":"2.0","id":3,"method":"hook.before_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
|
||||
{"jsonrpc":"2.0","id":3,"result":{"action":"continue"}}
|
||||
{"jsonrpc":"2.0","id":4,"method":"hook.approve_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
|
||||
{"jsonrpc":"2.0","id":4,"result":{"approved":true}}
|
||||
{"jsonrpc":"2.0","id":5,"method":"hook.after_tool","params":{"tool":"bash","arguments":{"command":"ls"},"result":{"for_llm":"file1.txt\nfile2.txt"},"duration":5000000}}
|
||||
{"jsonrpc":"2.0","id":5,"result":{"action":"continue"}}
|
||||
{"jsonrpc":"2.0","id":6,"method":"hook.after_llm","params":{"model":"claude-sonnet","response":{"role":"assistant","content":"已列出文件"}}}
|
||||
{"jsonrpc":"2.0","id":6,"result":{"action":"continue"}}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 通过 `before_llm` 和 `before_tool` 实现插件工具注入
|
||||
|
||||
插件工具注入的标准流程:
|
||||
|
||||
1. 在 `before_llm` 中注入工具定义,让 LLM 知道有这个工具可用
|
||||
2. 在 `before_tool` 中使用 `respond` action 直接返回工具执行结果
|
||||
|
||||
### `before_llm` 注入工具定义
|
||||
|
||||
```python
|
||||
def handle_before_llm(params: dict) -> dict:
|
||||
tools = params.get("tools", [])
|
||||
|
||||
# 添加插件工具定义
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "my_plugin_tool",
|
||||
"description": "插件提供的工具",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {"type": "string", "description": "输入内容"}
|
||||
},
|
||||
"required": ["input"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": params["model"],
|
||||
"messages": params["messages"],
|
||||
"tools": tools,
|
||||
"options": params.get("options", {})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `before_tool` 返回执行结果
|
||||
|
||||
```python
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
|
||||
if tool == "my_plugin_tool":
|
||||
# 在这里实现工具逻辑
|
||||
args = params.get("arguments", {})
|
||||
input_text = args.get("input", "")
|
||||
|
||||
# 直接返回结果,无需在 ToolRegistry 注册
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": f"插件工具执行成功,输入: {input_text}",
|
||||
"silent": False,
|
||||
"is_error": False
|
||||
}
|
||||
}
|
||||
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
通过这种方式,外部 hook 可以完全实现插件工具,无需在 PicoClaw 内部注册任何工具实现。
|
||||
@@ -0,0 +1,587 @@
|
||||
# Plugin Tool Injection Example
|
||||
|
||||
This document demonstrates how to use PicoClaw's hook system to implement external plugin tool injection, allowing LLM to call tools implemented by external hook processes.
|
||||
|
||||
---
|
||||
|
||||
## Core Principle
|
||||
|
||||
Through the hook system's `respond` action, external hooks can:
|
||||
|
||||
1. Inject tool **definitions** in `before_llm`, letting LLM know the tool is available
|
||||
2. Return tool **execution results** directly in `before_tool` using `respond` action, skipping ToolRegistry
|
||||
|
||||
This way, external hooks can fully implement plugin tools without registering any tools inside PicoClaw.
|
||||
|
||||
---
|
||||
|
||||
## Complete Example: Weather Query Plugin
|
||||
|
||||
Below is a complete Python hook example implementing a weather query plugin tool.
|
||||
|
||||
### 1. Hook Script Implementation
|
||||
|
||||
Save as `/tmp/weather_plugin.py`:
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
"""Weather query plugin hook example"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import signal
|
||||
from typing import Any
|
||||
|
||||
# Simulated weather data
|
||||
WEATHER_DATA = {
|
||||
"Beijing": {"temp": 15, "weather": "Sunny", "humidity": 45},
|
||||
"Shanghai": {"temp": 18, "weather": "Cloudy", "humidity": 60},
|
||||
"Guangzhou": {"temp": 25, "weather": "Sunny", "humidity": 70},
|
||||
"Shenzhen": {"temp": 26, "weather": "Cloudy", "humidity": 75},
|
||||
}
|
||||
|
||||
|
||||
def get_weather(city: str) -> dict:
|
||||
"""Get weather data (simulated)"""
|
||||
data = WEATHER_DATA.get(city)
|
||||
if data:
|
||||
return {
|
||||
"for_llm": f"{city} weather: {data['weather']}, temperature {data['temp']}°C, humidity {data['humidity']}%",
|
||||
"for_user": "",
|
||||
"silent": False,
|
||||
"is_error": False,
|
||||
}
|
||||
return {
|
||||
"for_llm": f"Weather data not found for city {city}",
|
||||
"for_user": "",
|
||||
"silent": False,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
|
||||
def handle_hello(params: dict) -> dict:
|
||||
return {"ok": True, "name": "weather-plugin"}
|
||||
|
||||
|
||||
def handle_before_llm(params: dict) -> dict:
|
||||
"""Inject weather query tool definition"""
|
||||
tools = params.get("tools", [])
|
||||
|
||||
# Add weather query tool
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Query weather information for a specified city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name, e.g.: Beijing, Shanghai, Guangzhou"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": params.get("model"),
|
||||
"messages": params.get("messages", []),
|
||||
"tools": tools,
|
||||
"options": params.get("options", {}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
"""Handle tool call, return result directly"""
|
||||
tool = params.get("tool", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
if tool == "get_weather":
|
||||
city = args.get("city", "")
|
||||
result = get_weather(city)
|
||||
|
||||
# Use respond action to return result directly, skip ToolRegistry
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": result,
|
||||
}
|
||||
|
||||
# Other tools continue normal flow
|
||||
return {"action": "continue"}
|
||||
|
||||
|
||||
def handle_request(method: str, params: dict) -> dict:
|
||||
if method == "hook.hello":
|
||||
return handle_hello(params)
|
||||
if method == "hook.before_llm":
|
||||
return handle_before_llm(params)
|
||||
if method == "hook.before_tool":
|
||||
return handle_before_tool(params)
|
||||
if method == "hook.after_llm":
|
||||
return {"action": "continue"}
|
||||
if method == "hook.after_tool":
|
||||
return {"action": "continue"}
|
||||
if method == "hook.approve_tool":
|
||||
return {"approved": True}
|
||||
raise KeyError(f"method not found: {method}")
|
||||
|
||||
|
||||
def send_response(message_id: int, result: Any | None = None, error: str | None = None) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
}
|
||||
if error is not None:
|
||||
payload["error"] = {"code": -32000, "message": error}
|
||||
else:
|
||||
payload["result"] = result if result is not None else {}
|
||||
|
||||
sys.stdout.write(json.dumps(payload, ensure_ascii=True) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
for raw_line in sys.stdin:
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
message = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
method = message.get("method")
|
||||
message_id = message.get("id", 0)
|
||||
params = message.get("params") or {}
|
||||
|
||||
if not message_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
result = handle_request(str(method or ""), params)
|
||||
send_response(int(message_id), result=result)
|
||||
except KeyError as exc:
|
||||
send_response(int(message_id), error=str(exc))
|
||||
except Exception as exc:
|
||||
send_response(int(message_id), error=f"unexpected error: {exc}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, lambda *_: raise SystemExit(0))
|
||||
signal.signal(signal.SIGTERM, lambda *_: raise SystemExit(0))
|
||||
raise SystemExit(main())
|
||||
```
|
||||
|
||||
### 2. Configure PicoClaw
|
||||
|
||||
Add hook configuration in the config file:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"enabled": true,
|
||||
"processes": {
|
||||
"weather_plugin": {
|
||||
"enabled": true,
|
||||
"priority": 100,
|
||||
"transport": "stdio",
|
||||
"command": ["python3", "/tmp/weather_plugin.py"],
|
||||
"intercept": ["before_llm", "before_tool"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Test Results
|
||||
|
||||
When user asks "What's the weather in Beijing today?":
|
||||
|
||||
1. PicoClaw sends `hook.before_llm`, hook injects `get_weather` tool definition
|
||||
2. LLM sees tool definition, decides to call `get_weather(city="Beijing")`
|
||||
3. PicoClaw sends `hook.before_tool`, hook uses `respond` action to return weather data
|
||||
4. LLM receives result, replies to user "Beijing is sunny today, temperature 15°C"
|
||||
|
||||
---
|
||||
|
||||
## Flow Diagram
|
||||
|
||||
```
|
||||
User: "What's the weather in Beijing today?"
|
||||
↓
|
||||
PicoClaw
|
||||
↓
|
||||
hook.before_llm
|
||||
↓ (inject get_weather tool definition)
|
||||
LLM request
|
||||
↓
|
||||
LLM decides to call get_weather(city="Beijing")
|
||||
↓
|
||||
hook.before_tool
|
||||
↓ (respond action returns weather data)
|
||||
Return result directly to LLM
|
||||
↓ (skip ToolRegistry)
|
||||
LLM replies: "Beijing is sunny today, temperature 15°C"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Points
|
||||
|
||||
### `before_llm` Inject Tool Definition
|
||||
|
||||
Tool definition follows OpenAI function calling format:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "tool_name",
|
||||
"description": "tool description",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param_name": {
|
||||
"type": "string",
|
||||
"description": "parameter description"
|
||||
}
|
||||
},
|
||||
"required": ["list of required parameters"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `before_tool` Use respond Action
|
||||
|
||||
`respond` action response format:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "Content returned to LLM",
|
||||
"for_user": "Optional, content sent to user",
|
||||
"silent": false,
|
||||
"is_error": false,
|
||||
"media": ["Optional, media reference list"],
|
||||
"response_handled": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Description |
|
||||
|-------|-------------|
|
||||
| `for_llm` | Required, LLM will see this content |
|
||||
| `for_user` | Optional, sent directly to user |
|
||||
| `silent` | When true, not sent to user |
|
||||
| `is_error` | When true, indicates execution failure |
|
||||
| `media` | Optional, media file references (images, files, etc.) |
|
||||
| `response_handled` | When true, indicates user request is handled, turn will end |
|
||||
|
||||
---
|
||||
|
||||
## Media File Handling
|
||||
|
||||
The `respond` action supports returning media files (images, files, etc.). There are two processing modes:
|
||||
|
||||
### 1. Automatic Delivery (`response_handled=true`)
|
||||
|
||||
When `response_handled=true`, media files are automatically sent to the user and the turn ends:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "Image sent to user",
|
||||
"for_user": "",
|
||||
"media": ["media://abc123"],
|
||||
"response_handled": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Use cases:
|
||||
- Image generation plugin directly returning results
|
||||
- File download plugin sending files to user
|
||||
|
||||
### 2. LLM Visible (`response_handled=false`)
|
||||
|
||||
When `response_handled=false`, media references are passed to the LLM, which can see the content in the next request:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "Image loaded, path: /tmp/image.png [file:/tmp/image.png]",
|
||||
"media": ["media://abc123"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
After seeing the content, the LLM can decide:
|
||||
- Use `send_file` tool to send to user
|
||||
- Analyze image content and reply to user
|
||||
- Other processing approaches
|
||||
|
||||
### Media Reference Format
|
||||
|
||||
Media references use the `media://` protocol:
|
||||
|
||||
```
|
||||
media://<store-id>
|
||||
```
|
||||
|
||||
These references are managed by PicoClaw's MediaStore and can be:
|
||||
- Sent to user via channel
|
||||
- Converted to base64 in LLM vision requests
|
||||
|
||||
### Alternative: Use Existing Tools
|
||||
|
||||
If the plugin generates files, you can return the file path and let the LLM call `send_file` or similar tools:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "Image generated, saved at /tmp/generated_image.png. Use send_file tool to send to user.",
|
||||
"for_user": "",
|
||||
"silent": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This approach:
|
||||
- More decoupled, LLM decides when to send
|
||||
- Leverages existing tool mechanisms
|
||||
- Supports batch sending, delayed sending, etc.
|
||||
|
||||
---
|
||||
|
||||
## Multi-Tool Injection Example
|
||||
|
||||
Multiple tools can be injected simultaneously:
|
||||
|
||||
```python
|
||||
def handle_before_llm(params: dict) -> dict:
|
||||
tools = params.get("tools", [])
|
||||
|
||||
# Tool 1: Weather query
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Query city weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "City name"}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
# Tool 2: Calculator
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "Perform mathematical calculations",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "Mathematical expression"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": params.get("model"),
|
||||
"messages": params.get("messages", []),
|
||||
"tools": tools,
|
||||
"options": params.get("options", {}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
if tool == "get_weather":
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": get_weather(args.get("city", "")),
|
||||
}
|
||||
|
||||
if tool == "calculate":
|
||||
# Simple calculation example
|
||||
try:
|
||||
expr = args.get("expression", "")
|
||||
result = eval(expr) # Note: needs security handling in actual use
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": f"Calculation result: {result}",
|
||||
"silent": False,
|
||||
"is_error": False,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": f"Calculation error: {e}",
|
||||
"silent": False,
|
||||
"is_error": True,
|
||||
},
|
||||
}
|
||||
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coexistence with Built-in Tools
|
||||
|
||||
Injected plugin tools coexist with PicoClaw built-in tools:
|
||||
|
||||
- Built-in tools (like `bash`, `read_file`) execute normally through ToolRegistry
|
||||
- Plugin tools return results through hook's `respond` action
|
||||
- `handle_before_tool` only handles plugin tools, other tools return `continue`
|
||||
|
||||
---
|
||||
|
||||
## Go In-Process Hook Example
|
||||
|
||||
If you need to implement plugin tool injection in Go code:
|
||||
|
||||
```go
|
||||
package myhooks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
type WeatherPluginHook struct{}
|
||||
|
||||
func (h *WeatherPluginHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *agent.LLMHookRequest,
|
||||
) (*agent.LLMHookRequest, agent.HookDecision, error) {
|
||||
// Inject tool definition
|
||||
req.Tools = append(req.Tools, agent.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: agent.FunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "Query city weather",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"city": map[string]any{
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
},
|
||||
},
|
||||
"required": []string{"city"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return req, agent.HookDecision{Action: agent.HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *WeatherPluginHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *agent.ToolCallHookRequest,
|
||||
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
|
||||
if call.Tool == "get_weather" {
|
||||
city := call.Arguments["city"].(string)
|
||||
|
||||
// Set HookResult, use respond action
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: getWeatherData(city),
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
|
||||
}
|
||||
|
||||
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func getWeatherData(city string) string {
|
||||
// Implement weather query logic
|
||||
return fmt.Sprintf("%s weather: Sunny, temperature 20°C", city)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
Through the hook system's `respond` action, external processes can:
|
||||
|
||||
1. **Inject tool definitions**: Let LLM know new tools are available
|
||||
2. **Provide tool implementation**: Return execution results directly, no need to register in ToolRegistry
|
||||
3. **Coexist with built-in tools**: Does not affect normal operation of PicoClaw's original tools
|
||||
|
||||
This provides a flexible and elegant solution for plugin development.
|
||||
|
||||
---
|
||||
|
||||
## Security Boundaries
|
||||
|
||||
### Bypassing Approval Checks
|
||||
|
||||
**Important**: The `respond` action bypasses `ApproveTool` approval checks.
|
||||
|
||||
This means:
|
||||
- A `before_tool` hook can return `respond` for **any tool name**, including sensitive tools (like `bash`)
|
||||
- The tool won't go through the approval process, directly returning the hook-provided result
|
||||
- This is designed for plugin tools but introduces security risks
|
||||
|
||||
### Security Recommendations
|
||||
|
||||
1. **Review hook configuration**: Ensure only trusted hook processes are enabled
|
||||
2. **Limit hook scope**: Add your own security checks in hook implementation
|
||||
3. **Use `deny_tool` for rejection**: Use `deny_tool` action instead of `respond` with error for denying execution
|
||||
|
||||
### Example: Hook-Internal Security Check
|
||||
|
||||
```python
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
# Security check: only handle plugin tools
|
||||
if tool in ["get_weather", "calculate"]:
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": execute_plugin_tool(tool, args),
|
||||
}
|
||||
|
||||
# Other tools continue normal flow (will go through approval)
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
This ensures the hook only affects plugin tools, not system tool approval flow.
|
||||
@@ -0,0 +1,587 @@
|
||||
# 插件工具注入示例
|
||||
|
||||
本文档展示如何利用 PicoClaw 的 hook 系统实现外部插件工具注入,让 LLM 能调用由外部 hook 进程实现的工具。
|
||||
|
||||
---
|
||||
|
||||
## 核心原理
|
||||
|
||||
通过 hook 系统的 `respond` action,外部 hook 可以:
|
||||
|
||||
1. 在 `before_llm` 中注入工具**定义**,让 LLM 知道有这个工具可用
|
||||
2. 在 `before_tool` 中使用 `respond` action 直接返回工具**执行结果**,跳过 ToolRegistry
|
||||
|
||||
这样,外部 hook 可以完全实现插件工具,无需在 PicoClaw 内部注册任何工具。
|
||||
|
||||
---
|
||||
|
||||
## 完整示例:天气查询插件
|
||||
|
||||
下面是一个完整的 Python hook 示例,实现一个天气查询插件工具。
|
||||
|
||||
### 1. Hook 脚本实现
|
||||
|
||||
保存为 `/tmp/weather_plugin.py`:
|
||||
|
||||
```python
|
||||
#!/usr/bin/env python3
|
||||
"""天气查询插件 hook 示例"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sys
|
||||
import signal
|
||||
from typing import Any
|
||||
|
||||
# 模拟天气数据
|
||||
WEATHER_DATA = {
|
||||
"北京": {"temp": 15, "weather": "晴", "humidity": 45},
|
||||
"上海": {"temp": 18, "weather": "多云", "humidity": 60},
|
||||
"广州": {"temp": 25, "weather": "晴", "humidity": 70},
|
||||
"深圳": {"temp": 26, "weather": "多云", "humidity": 75},
|
||||
}
|
||||
|
||||
|
||||
def get_weather(city: str) -> dict:
|
||||
"""获取天气数据(模拟)"""
|
||||
data = WEATHER_DATA.get(city)
|
||||
if data:
|
||||
return {
|
||||
"for_llm": f"{city}天气:{data['weather']},温度{data['temp']}°C,湿度{data['humidity']}%",
|
||||
"for_user": "",
|
||||
"silent": False,
|
||||
"is_error": False,
|
||||
}
|
||||
return {
|
||||
"for_llm": f"未找到城市 {city} 的天气数据",
|
||||
"for_user": "",
|
||||
"silent": False,
|
||||
"is_error": True,
|
||||
}
|
||||
|
||||
|
||||
def handle_hello(params: dict) -> dict:
|
||||
return {"ok": True, "name": "weather-plugin"}
|
||||
|
||||
|
||||
def handle_before_llm(params: dict) -> dict:
|
||||
"""注入天气查询工具定义"""
|
||||
tools = params.get("tools", [])
|
||||
|
||||
# 添加天气查询工具
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "查询指定城市的天气信息",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "城市名称,如:北京、上海、广州"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": params.get("model"),
|
||||
"messages": params.get("messages", []),
|
||||
"tools": tools,
|
||||
"options": params.get("options", {}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
"""处理工具调用,直接返回结果"""
|
||||
tool = params.get("tool", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
if tool == "get_weather":
|
||||
city = args.get("city", "")
|
||||
result = get_weather(city)
|
||||
|
||||
# 使用 respond action 直接返回结果,跳过 ToolRegistry
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": result,
|
||||
}
|
||||
|
||||
# 其他工具继续正常流程
|
||||
return {"action": "continue"}
|
||||
|
||||
|
||||
def handle_request(method: str, params: dict) -> dict:
|
||||
if method == "hook.hello":
|
||||
return handle_hello(params)
|
||||
if method == "hook.before_llm":
|
||||
return handle_before_llm(params)
|
||||
if method == "hook.before_tool":
|
||||
return handle_before_tool(params)
|
||||
if method == "hook.after_llm":
|
||||
return {"action": "continue"}
|
||||
if method == "hook.after_tool":
|
||||
return {"action": "continue"}
|
||||
if method == "hook.approve_tool":
|
||||
return {"approved": True}
|
||||
raise KeyError(f"method not found: {method}")
|
||||
|
||||
|
||||
def send_response(message_id: int, result: Any | None = None, error: str | None = None) -> None:
|
||||
payload: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": message_id,
|
||||
}
|
||||
if error is not None:
|
||||
payload["error"] = {"code": -32000, "message": error}
|
||||
else:
|
||||
payload["result"] = result if result is not None else {}
|
||||
|
||||
sys.stdout.write(json.dumps(payload, ensure_ascii=True) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
for raw_line in sys.stdin:
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
message = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
method = message.get("method")
|
||||
message_id = message.get("id", 0)
|
||||
params = message.get("params") or {}
|
||||
|
||||
if not message_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
result = handle_request(str(method or ""), params)
|
||||
send_response(int(message_id), result=result)
|
||||
except KeyError as exc:
|
||||
send_response(int(message_id), error=str(exc))
|
||||
except Exception as exc:
|
||||
send_response(int(message_id), error=f"unexpected error: {exc}")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
signal.signal(signal.SIGINT, lambda *_: raise SystemExit(0))
|
||||
signal.signal(signal.SIGTERM, lambda *_: raise SystemExit(0))
|
||||
raise SystemExit(main())
|
||||
```
|
||||
|
||||
### 2. 配置 PicoClaw
|
||||
|
||||
在配置文件中添加 hook 配置:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"enabled": true,
|
||||
"processes": {
|
||||
"weather_plugin": {
|
||||
"enabled": true,
|
||||
"priority": 100,
|
||||
"transport": "stdio",
|
||||
"command": ["python3", "/tmp/weather_plugin.py"],
|
||||
"intercept": ["before_llm", "before_tool"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 测试效果
|
||||
|
||||
当用户问"北京今天天气怎么样?"时:
|
||||
|
||||
1. PicoClaw 发送 `hook.before_llm`,hook 注入 `get_weather` 工具定义
|
||||
2. LLM 看到工具定义,决定调用 `get_weather(city="北京")`
|
||||
3. PicoClaw 发送 `hook.before_tool`,hook 使用 `respond` action 返回天气数据
|
||||
4. LLM 收到结果,回复用户"北京今天晴天,温度15°C"
|
||||
|
||||
---
|
||||
|
||||
## 流程图解
|
||||
|
||||
```
|
||||
用户: "北京今天天气怎么样?"
|
||||
↓
|
||||
PicoClaw
|
||||
↓
|
||||
hook.before_llm
|
||||
↓ (注入 get_weather 工具定义)
|
||||
LLM 请求
|
||||
↓
|
||||
LLM 决定调用 get_weather(city="北京")
|
||||
↓
|
||||
hook.before_tool
|
||||
↓ (respond action 返回天气数据)
|
||||
直接返回结果给 LLM
|
||||
↓ (跳过 ToolRegistry)
|
||||
LLM 回复: "北京今天晴天,温度15°C"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 关键点说明
|
||||
|
||||
### `before_llm` 注入工具定义
|
||||
|
||||
工具定义遵循 OpenAI function calling 格式:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "工具名称",
|
||||
"description": "工具描述",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"参数名": {
|
||||
"type": "string",
|
||||
"description": "参数描述"
|
||||
}
|
||||
},
|
||||
"required": ["必需参数列表"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `before_tool` 使用 respond action
|
||||
|
||||
`respond` action 的响应格式:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "返回给 LLM 的内容",
|
||||
"for_user": "可选,发送给用户的内容",
|
||||
"silent": false,
|
||||
"is_error": false,
|
||||
"media": ["可选,媒体引用列表"],
|
||||
"response_handled": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `for_llm` | 必须,LLM 会看到这个内容 |
|
||||
| `for_user` | 可选,直接发送给用户 |
|
||||
| `silent` | 为 true 时不发送给用户 |
|
||||
| `is_error` | 为 true 时表示执行失败 |
|
||||
| `media` | 可选,媒体文件引用列表(如图片、文件) |
|
||||
| `response_handled` | 为 true 时表示已处理用户请求,轮次将结束 |
|
||||
|
||||
---
|
||||
|
||||
## 媒体文件处理
|
||||
|
||||
`respond` action 支持返回媒体文件(图片、文件等)。有两种处理方式:
|
||||
|
||||
### 1. 自动发送(`response_handled=true`)
|
||||
|
||||
当 `response_handled=true` 时,媒体文件会自动发送给用户,轮次结束:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "图片已发送给用户",
|
||||
"for_user": "",
|
||||
"media": ["media://abc123"],
|
||||
"response_handled": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
适用场景:
|
||||
- 图像生成插件直接返回结果
|
||||
- 文件下载插件发送文件给用户
|
||||
|
||||
### 2. LLM 可见(`response_handled=false`)
|
||||
|
||||
当 `response_handled=false` 时,媒体引用会传递给 LLM,LLM 可以在下一轮请求中看到内容:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "图片已加载,路径:/tmp/image.png [file:/tmp/image.png]",
|
||||
"media": ["media://abc123"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
LLM 看到内容后,可以自主决定:
|
||||
- 使用 `send_file` 工具发送给用户
|
||||
- 分析图片内容并回复用户
|
||||
- 其他处理方式
|
||||
|
||||
### 媒体引用格式
|
||||
|
||||
媒体引用使用 `media://` 协议:
|
||||
|
||||
```
|
||||
media://<store-id>
|
||||
```
|
||||
|
||||
这些引用由 PicoClaw 的 MediaStore 管理,可以:
|
||||
- 通过 channel 发送给用户
|
||||
- 在 LLM vision 请求中转换为 base64
|
||||
|
||||
### 替代方案:使用现有工具
|
||||
|
||||
如果插件生成文件,可以返回文件路径让 LLM 调用 `send_file` 等工具:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": "图片已生成,保存在 /tmp/generated_image.png。使用 send_file 工具发送给用户。",
|
||||
"for_user": "",
|
||||
"silent": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
这种方式:
|
||||
- 更解耦,LLM 自主决策发送时机
|
||||
- 利用现有工具机制
|
||||
- 支持批量发送、延迟发送等场景
|
||||
|
||||
---
|
||||
|
||||
## 多工具注入示例
|
||||
|
||||
可以同时注入多个工具:
|
||||
|
||||
```python
|
||||
def handle_before_llm(params: dict) -> dict:
|
||||
tools = params.get("tools", [])
|
||||
|
||||
# 工具1:天气查询
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "查询城市天气",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "城市名称"}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
# 工具2:计算器
|
||||
tools.append({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "calculate",
|
||||
"description": "执行数学计算",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"expression": {"type": "string", "description": "数学表达式"}
|
||||
},
|
||||
"required": ["expression"]
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
"action": "modify",
|
||||
"request": {
|
||||
"model": params.get("model"),
|
||||
"messages": params.get("messages", []),
|
||||
"tools": tools,
|
||||
"options": params.get("options", {}),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
if tool == "get_weather":
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": get_weather(args.get("city", "")),
|
||||
}
|
||||
|
||||
if tool == "calculate":
|
||||
# 简单计算示例
|
||||
try:
|
||||
expr = args.get("expression", "")
|
||||
result = eval(expr) # 注意:实际使用时需要安全处理
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": f"计算结果: {result}",
|
||||
"silent": False,
|
||||
"is_error": False,
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": {
|
||||
"for_llm": f"计算错误: {e}",
|
||||
"silent": False,
|
||||
"is_error": True,
|
||||
},
|
||||
}
|
||||
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 与内置工具共存
|
||||
|
||||
注入的插件工具与 PicoClaw 内置工具共存:
|
||||
|
||||
- 内置工具(如 `bash`、`read_file`)正常通过 ToolRegistry 执行
|
||||
- 插件工具通过 hook 的 `respond` action 返回结果
|
||||
- `handle_before_tool` 中只处理插件工具,其他工具返回 `continue`
|
||||
|
||||
---
|
||||
|
||||
## Go 进程内 Hook 示例
|
||||
|
||||
如果需要在 Go 代码中实现插件工具注入:
|
||||
|
||||
```go
|
||||
package myhooks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/sipeed/picoclaw/pkg/agent"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
type WeatherPluginHook struct{}
|
||||
|
||||
func (h *WeatherPluginHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *agent.LLMHookRequest,
|
||||
) (*agent.LLMHookRequest, agent.HookDecision, error) {
|
||||
// 注入工具定义
|
||||
req.Tools = append(req.Tools, agent.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: agent.FunctionDefinition{
|
||||
Name: "get_weather",
|
||||
Description: "查询城市天气",
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"city": map[string]any{
|
||||
"type": "string",
|
||||
"description": "城市名称",
|
||||
},
|
||||
},
|
||||
"required": []string{"city"},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
return req, agent.HookDecision{Action: agent.HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *WeatherPluginHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *agent.ToolCallHookRequest,
|
||||
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
|
||||
if call.Tool == "get_weather" {
|
||||
city := call.Arguments["city"].(string)
|
||||
|
||||
// 设置 HookResult,使用 respond action
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: getWeatherData(city),
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
|
||||
}
|
||||
|
||||
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func getWeatherData(city string) string {
|
||||
// 实现天气查询逻辑
|
||||
return fmt.Sprintf("%s天气:晴,温度20°C", city)
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 总结
|
||||
|
||||
通过 hook 系统的 `respond` action,外部进程可以:
|
||||
|
||||
1. **注入工具定义**:让 LLM 知道有新工具可用
|
||||
2. **提供工具实现**:直接返回执行结果,无需注册到 ToolRegistry
|
||||
3. **与内置工具共存**:不影响 PicoClaw 原有工具的正常运行
|
||||
|
||||
这为插件开发提供了灵活、优雅的解决方案。
|
||||
|
||||
---
|
||||
|
||||
## 安全边界说明
|
||||
|
||||
### 绕过审批检查
|
||||
|
||||
**重要**:`respond` action 会绕过 `ApproveTool` 审批检查。
|
||||
|
||||
这意味着:
|
||||
- `before_tool` hook 可以为**任何工具名称**返回 `respond`,包括敏感工具(如 `bash`)
|
||||
- 工具不会经过审批流程,直接返回 hook 提供的结果
|
||||
- 这是为了支持插件工具而设计,但也带来了安全风险
|
||||
|
||||
### 安全建议
|
||||
|
||||
1. **审查 hook 配置**:确保只有可信的 hook 进程被启用
|
||||
2. **限制 hook 权限**:在 hook 实现中添加自己的安全检查
|
||||
3. **优先使用 `deny_tool`**:对于拒绝执行,使用 `deny_tool` action 而非 `respond` 返回错误
|
||||
|
||||
### 示例:hook 内置安全检查
|
||||
|
||||
```python
|
||||
def handle_before_tool(params: dict) -> dict:
|
||||
tool = params.get("tool", "")
|
||||
args = params.get("arguments", {})
|
||||
|
||||
# 安全检查:只处理插件工具
|
||||
if tool in ["get_weather", "calculate"]:
|
||||
return {
|
||||
"action": "respond",
|
||||
"result": execute_plugin_tool(tool, args),
|
||||
}
|
||||
|
||||
# 其他工具继续正常流程(会经过审批)
|
||||
return {"action": "continue"}
|
||||
```
|
||||
|
||||
这样可以确保 hook 只影响插件工具,不影响系统工具的审批流程。
|
||||
@@ -122,6 +122,7 @@ This design also enables **multi-agent support** with flexible provider selectio
|
||||
| `max_tokens_field` | string | No | Override the max tokens field name in request body (e.g., `max_completion_tokens` for o1 models) |
|
||||
| `thinking_level` | string | No | Extended thinking level: `off`, `low`, `medium`, `high`, `xhigh`, or `adaptive` |
|
||||
| `extra_body` | object | No | Additional fields to inject into every request body |
|
||||
| `custom_headers` | object | No | Additional HTTP headers to inject into every request (e.g., `{"X-Source":"coding-plan"}`). If a key matches a built-in header, the custom value overrides the built-in one (e.g., `Authorization`, `User-Agent`, `Content-Type`, `Accept`). |
|
||||
| `rpm` | int | No | Per-minute request rate limit |
|
||||
| `fallbacks` | string[] | No | Fallback model names for automatic failover |
|
||||
| `enabled` | bool | No | Whether this model entry is active (default: `true`) |
|
||||
|
||||
@@ -528,6 +528,9 @@ For example:
|
||||
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
|
||||
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
|
||||
- `PICOCLAW_TOOLS_MCP_ENABLED=true`
|
||||
- `PICOCLAW_TOOLS_MCP_MAX_INLINE_TEXT_CHARS=16384`
|
||||
|
||||
Note: Nested map-style config (for example `tools.mcp.servers.<name>.*`) is configured in `config.json` rather than
|
||||
environment variables.
|
||||
|
||||
For MCP tools, `tools.mcp.max_inline_text_chars` controls how much text result is kept inline in model context. The threshold is counted in Unicode characters (Go runes), not bytes. For example, `16384` means up to 16,384 characters inline, which may occupy more than 16 KB for multibyte text such as CJK. Above this threshold, PicoClaw saves the MCP text result as a local artifact in the agent workspace and gives the model a short note plus a structured `[file:...]` artifact path instead of injecting the full payload into context.
|
||||
|
||||
@@ -118,6 +118,7 @@
|
||||
| `max_tokens_field` | string | 否 | 覆盖请求体中 max tokens 的字段名(如 o1 模型使用 `max_completion_tokens`) |
|
||||
| `thinking_level` | string | 否 | 扩展思考级别:`off`、`low`、`medium`、`high`、`xhigh` 或 `adaptive` |
|
||||
| `extra_body` | object | 否 | 注入到每个请求体中的额外字段 |
|
||||
| `custom_headers` | object | 否 | 注入到每个请求中的额外 HTTP 请求头(例如 `{"X-Source":"coding-plan"}`)。若键名与内置请求头同名,会覆盖内置值(如 `Authorization`、`User-Agent`、`Content-Type`、`Accept`)。 |
|
||||
| `rpm` | int | 否 | 每分钟请求速率限制 |
|
||||
| `fallbacks` | string[] | 否 | 自动故障转移的备用模型名称 |
|
||||
| `enabled` | bool | 否 | 是否启用此模型条目(默认:`true`) |
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module github.com/sipeed/picoclaw
|
||||
|
||||
go 1.25.8
|
||||
go 1.25.9
|
||||
|
||||
require (
|
||||
fyne.io/systray v1.12.0
|
||||
@@ -8,6 +8,7 @@ require (
|
||||
github.com/SevereCloud/vksdk/v3 v3.3.1
|
||||
github.com/adhocore/gronx v1.19.6
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0
|
||||
github.com/atc0005/go-teams-notify/v2 v2.14.0
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12
|
||||
@@ -29,7 +30,7 @@ require (
|
||||
github.com/mymmrac/telego v1.7.0
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/openai/openai-go/v3 v3.22.0
|
||||
github.com/pion/rtp v1.8.7
|
||||
github.com/pion/rtp v1.10.1
|
||||
github.com/pion/webrtc/v3 v3.3.6
|
||||
github.com/rivo/tview v0.42.0
|
||||
github.com/rs/zerolog v1.35.0
|
||||
@@ -45,7 +46,7 @@ require (
|
||||
google.golang.org/protobuf v1.36.11
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
maunium.net/go/mautrix v0.26.4
|
||||
modernc.org/sqlite v1.47.0
|
||||
modernc.org/sqlite v1.48.0
|
||||
rsc.io/qr v0.2.0
|
||||
)
|
||||
|
||||
|
||||
@@ -21,6 +21,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
|
||||
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
|
||||
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
|
||||
github.com/atc0005/go-teams-notify/v2 v2.14.0 h1:7N+xw+COnYANLREaAveQ65rsNQ12nIZJED9nMLyscCo=
|
||||
github.com/atc0005/go-teams-notify/v2 v2.14.0/go.mod h1:EECsWM2b0Hvoz7O+QdlsvyN2KCUOFQCGj8bUBXv3A3Q=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
|
||||
@@ -207,8 +209,8 @@ github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6 h1:rh2lKw/P/EqHa7
|
||||
github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
|
||||
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
|
||||
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
|
||||
github.com/pion/rtp v1.8.7 h1:qslKkG8qxvQ7hqaxkmL7Pl0XcUm+/Er7nMnu6Vq+ZxM=
|
||||
github.com/pion/rtp v1.8.7/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU=
|
||||
github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA=
|
||||
github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM=
|
||||
github.com/pion/webrtc/v3 v3.3.6 h1:7XAh4RPtlY1Vul6/GmZrv7z+NnxKA6If0KStXBI2ZLE=
|
||||
github.com/pion/webrtc/v3 v3.3.6/go.mod h1:zyN7th4mZpV27eXybfR/cnUf3J2DRy8zw/mdjD9JTNM=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
@@ -456,8 +458,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
|
||||
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4=
|
||||
modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
||||
+11
-85
@@ -6,10 +6,8 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
||||
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||
func estimateMessageTokens(msg providers.Message) int {
|
||||
contentChars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
// SystemParts are structured system blocks used for cache-aware adapters.
|
||||
// They carry the same content as Content, but in multiple blocks.
|
||||
// We estimate them as an alternative representation, not additive.
|
||||
systemPartsChars := 0
|
||||
if len(msg.SystemParts) > 0 {
|
||||
for _, part := range msg.SystemParts {
|
||||
systemPartsChars += utf8.RuneCountInString(part.Text)
|
||||
}
|
||||
// Per-part overhead for JSON structure (type, text, cache_control).
|
||||
const perPartOverhead = 20
|
||||
systemPartsChars += len(msg.SystemParts) * perPartOverhead
|
||||
}
|
||||
|
||||
// Use the larger of the two representations to stay conservative.
|
||||
chars := contentChars
|
||||
if systemPartsChars > chars {
|
||||
chars = systemPartsChars
|
||||
}
|
||||
|
||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
chars += len(tc.ID) + len(tc.Type)
|
||||
if tc.Function != nil {
|
||||
// Count function name + arguments (the wire format for most providers).
|
||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
} else {
|
||||
// Fallback: some provider formats use top-level Name without Function.
|
||||
chars += len(tc.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
tokens := chars * 2 / 5
|
||||
|
||||
// Media items (images, files) are serialized by provider adapters into
|
||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||
// directly (not through the chars heuristic) since actual cost depends
|
||||
// on resolution and provider-specific image tokenization.
|
||||
const mediaTokensPerItem = 256
|
||||
tokens += len(msg.Media) * mediaTokensPerItem
|
||||
|
||||
return tokens
|
||||
// EstimateMessageTokens estimates the token count for a single message.
|
||||
// Delegates to the shared tokenizer package for consistency across agent and seahorse.
|
||||
func EstimateMessageTokens(msg providers.Message) int {
|
||||
return tokenizer.EstimateMessageTokens(msg)
|
||||
}
|
||||
|
||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Each tool's name, description, and
|
||||
// JSON schema parameters contribute to the context window budget.
|
||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
// EstimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Delegates to the shared tokenizer package.
|
||||
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
return tokenizer.EstimateToolDefsTokens(defs)
|
||||
}
|
||||
|
||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||
@@ -181,10 +107,10 @@ func isOverContextBudget(
|
||||
) bool {
|
||||
msgTokens := 0
|
||||
for _, m := range messages {
|
||||
msgTokens += estimateMessageTokens(m)
|
||||
msgTokens += EstimateMessageTokens(m)
|
||||
}
|
||||
|
||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
||||
toolTokens := EstimateToolDefsTokens(toolDefs)
|
||||
total := msgTokens + toolTokens + maxTokens
|
||||
|
||||
return total > contextWindow
|
||||
|
||||
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateMessageTokens(tt.msg)
|
||||
got := EstimateMessageTokens(tt.msg)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
withTCTokens := estimateMessageTokens(withTC)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
withTCTokens := EstimateMessageTokens(withTC)
|
||||
|
||||
if withTCTokens <= plainTokens {
|
||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
||||
// but may map to different token counts. The heuristic should still produce
|
||||
// reasonable estimates via RuneCountInString.
|
||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||
}
|
||||
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||
if tokens < 2000 {
|
||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
||||
ReasoningContent: strings.Repeat("thinking step ", 200),
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
reasoningTokens := EstimateMessageTokens(withReasoning)
|
||||
|
||||
if reasoningTokens <= plainTokens {
|
||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
|
||||
Media: []string{"media://img1.png", "media://img2.png"},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
mediaTokens := estimateMessageTokens(withMedia)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
mediaTokens := EstimateMessageTokens(withMedia)
|
||||
|
||||
if mediaTokens <= plainTokens {
|
||||
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
partsTokens := estimateMessageTokens(withParts)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
partsTokens := EstimateMessageTokens(withParts)
|
||||
|
||||
if partsTokens <= plainTokens {
|
||||
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
|
||||
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- estimateToolDefsTokens tests ---
|
||||
// --- EstimateToolDefsTokens tests ---
|
||||
|
||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateToolDefsTokens(tt.defs)
|
||||
got := EstimateToolDefsTokens(tt.defs)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
||||
one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := EstimateToolDefsTokens([]providers.ToolDefinition{
|
||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||
})
|
||||
|
||||
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
|
||||
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
||||
// Content + TC + overhead adds more. Should be well above 500.
|
||||
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
// Compare without reasoning to ensure it's counted.
|
||||
msgNoReasoning := msg
|
||||
msgNoReasoning.ReasoningContent = ""
|
||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
||||
tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
|
||||
|
||||
if tokens <= tokensNoReasoning {
|
||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||
|
||||
@@ -373,7 +373,7 @@ func (m *legacyContextManager) summarizeBatch(
|
||||
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
total += estimateMessageTokens(msg)
|
||||
total += EstimateMessageTokens(msg)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ type AssembleResponse struct {
|
||||
type CompactRequest struct {
|
||||
SessionKey string // session identifier
|
||||
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
|
||||
Budget int // context window budget (used for retry aggressive compaction)
|
||||
}
|
||||
|
||||
// IngestRequest is the input to Ingest.
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build !mipsle && !netbsd && !(freebsd && arm)
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
|
||||
type seahorseContextManager struct {
|
||||
engine *seahorse.Engine
|
||||
sessions session.SessionStore // for startup bootstrap
|
||||
}
|
||||
|
||||
// newSeahorseContextManager creates a seahorse-backed ContextManager.
|
||||
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
if al == nil {
|
||||
return nil, fmt.Errorf("seahorse: AgentLoop is required")
|
||||
}
|
||||
|
||||
// Resolve workspace for DB path
|
||||
// DB stores session data, so it goes in sessions/ directory
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
dbPath := agent.Workspace + "/sessions/seahorse.db"
|
||||
|
||||
// Create CompleteFn from provider
|
||||
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
|
||||
|
||||
// Create engine
|
||||
engine, err := seahorse.NewEngine(seahorse.Config{
|
||||
DBPath: dbPath,
|
||||
}, completeFn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("seahorse: create engine: %w", err)
|
||||
}
|
||||
|
||||
mgr := &seahorseContextManager{
|
||||
engine: engine,
|
||||
sessions: agent.Sessions,
|
||||
}
|
||||
|
||||
// Register seahorse tools with the agent's tool registry
|
||||
retrieval := mgr.engine.GetRetrieval()
|
||||
al.RegisterTool(seahorse.NewGrepTool(retrieval))
|
||||
al.RegisterTool(seahorse.NewExpandTool(retrieval))
|
||||
|
||||
// Bootstrap all existing sessions at startup
|
||||
if agent.Sessions != nil {
|
||||
ctx := context.Background()
|
||||
for _, sessionKey := range agent.Sessions.ListSessions() {
|
||||
mgr.bootstrapSession(ctx, sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
|
||||
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
|
||||
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
|
||||
resp, err := provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil, // no tools for summarization
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": opts.MaxTokens,
|
||||
"temperature": opts.Temperature,
|
||||
"prompt_cache_key": "seahorse",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Content, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Assemble builds budget-aware context from seahorse SQLite.
|
||||
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("seahorse assemble: nil request")
|
||||
}
|
||||
|
||||
budget := req.Budget
|
||||
if budget <= 0 {
|
||||
budget = 100000
|
||||
}
|
||||
|
||||
// Reserve space for model response (spec lines 1400-1410)
|
||||
effectiveBudget := budget - req.MaxTokens
|
||||
if effectiveBudget <= 0 {
|
||||
// MaxTokens >= budget is a configuration problem
|
||||
// Use 50% as minimum to avoid guaranteed overflow
|
||||
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
|
||||
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
|
||||
effectiveBudget = budget / 2
|
||||
}
|
||||
|
||||
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
|
||||
Budget: effectiveBudget,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("seahorse assemble: %w", err)
|
||||
}
|
||||
|
||||
history := seahorseToProviderMessages(result)
|
||||
|
||||
// Summary is already formatted as XML with system prompt addition by assembler
|
||||
return &AssembleResponse{
|
||||
History: history,
|
||||
Summary: result.Summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compact compresses conversation history via seahorse summarization.
|
||||
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
|
||||
// context shrinks below budget (spec lines ~1410).
|
||||
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
|
||||
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
|
||||
Force: req.Reason == ContextCompressReasonRetry,
|
||||
Budget: &req.Budget,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Ingest records a message into seahorse SQLite.
|
||||
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
|
||||
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := providerToSeahorseMessage(req.Message)
|
||||
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
|
||||
return err
|
||||
}
|
||||
|
||||
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
|
||||
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
|
||||
if m.sessions == nil {
|
||||
return
|
||||
}
|
||||
|
||||
history := m.sessions.GetHistory(sessionKey)
|
||||
if len(history) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert provider messages to seahorse messages
|
||||
msgs := make([]seahorse.Message, len(history))
|
||||
for i, h := range history {
|
||||
msgs[i] = providerToSeahorseMessage(h)
|
||||
}
|
||||
|
||||
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
|
||||
logger.WarnCF("seahorse", "bootstrap", map[string]any{
|
||||
"session": sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
|
||||
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
|
||||
result := seahorse.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
TokenCount: tokenizer.EstimateMessageTokens(msg),
|
||||
}
|
||||
|
||||
// Convert ToolCalls → MessageParts
|
||||
for _, tc := range msg.ToolCalls {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "tool_use",
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
// Convert tool result
|
||||
if msg.ToolCallID != "" {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "tool_result",
|
||||
ToolCallID: msg.ToolCallID,
|
||||
Text: msg.Content,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
// Convert media attachments
|
||||
for _, mediaURI := range msg.Media {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "media",
|
||||
MediaURI: mediaURI,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
|
||||
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
|
||||
messages := make([]protocoltypes.Message, 0, len(result.Messages))
|
||||
|
||||
// Convert assembled messages (which already include summary XML messages)
|
||||
for _, msg := range result.Messages {
|
||||
pm := protocoltypes.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
// Reconstruct ToolCalls from parts
|
||||
for _, part := range msg.Parts {
|
||||
if part.Type == "tool_use" {
|
||||
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
|
||||
ID: part.ToolCallID,
|
||||
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: part.Name,
|
||||
Arguments: part.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
if part.Type == "tool_result" {
|
||||
pm.ToolCallID = part.ToolCallID
|
||||
if pm.Content == "" && part.Text != "" {
|
||||
pm.Content = part.Text
|
||||
}
|
||||
}
|
||||
if part.Type == "media" && part.MediaURI != "" {
|
||||
pm.Media = append(pm.Media, part.MediaURI)
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, pm)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
|
||||
panic(fmt.Sprintf("register seahorse context manager: %v", err))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
//go:build mipsle || netbsd || (freebsd && arm)
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// newSeahorseContextManager is unavailable on platforms where modernc sqlite/libc
|
||||
// currently has no stable build path for this project.
|
||||
func newSeahorseContextManager(_ json.RawMessage, _ *AgentLoop) (ContextManager, error) {
|
||||
return nil, fmt.Errorf("seahorse context manager is unavailable on this platform")
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
|
||||
panic(fmt.Sprintf("register seahorse context manager: %v", err))
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -90,7 +92,8 @@ type processHookAfterLLMResponse struct {
|
||||
|
||||
type processHookBeforeToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"` // Result returned directly by hook (for respond action)
|
||||
}
|
||||
|
||||
type processHookAfterToolResponse struct {
|
||||
@@ -120,7 +123,9 @@ func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stderr: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
// Route hook subprocess startup through the shared isolation entry point so
|
||||
// process hooks inherit the same isolation behavior as other child processes.
|
||||
if err := isolation.Start(cmd); err != nil {
|
||||
return nil, fmt.Errorf("start process hook: %w", err)
|
||||
}
|
||||
|
||||
@@ -241,6 +246,10 @@ func (ph *ProcessHook) BeforeTool(
|
||||
if resp.Call == nil {
|
||||
resp.Call = call
|
||||
}
|
||||
// If hook returned a Result, carry it in ToolCallHookRequest
|
||||
if resp.Result != nil {
|
||||
resp.Call.HookResult = resp.Result
|
||||
}
|
||||
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,10 +7,13 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
@@ -178,6 +181,76 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_IsolationSupportsRelativeDirAndCommand(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("linux-only isolation path handling")
|
||||
}
|
||||
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
root := t.TempDir()
|
||||
t.Setenv(config.EnvHome, filepath.Join(root, "picoclaw-home"))
|
||||
binDir := filepath.Join(root, "bin")
|
||||
hookDir := filepath.Join(root, "hooks")
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.MkdirAll(hookDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
writeFakeBwrap(t, filepath.Join(binDir, "bwrap"))
|
||||
t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
linkTestBinary(t, os.Args[0], filepath.Join(hookDir, "hook-helper"))
|
||||
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Isolation.Enabled = true
|
||||
isolation.Configure(cfg)
|
||||
t.Cleanup(func() { isolation.Configure(config.DefaultConfig()) })
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
relHookDir, err := filepath.Rel(cwd, hookDir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mountErr := al.MountProcessHook(context.Background(), "ipc-relative", ProcessHookOptions{
|
||||
Command: []string{"./hook-helper", "-test.run=TestProcessHook_HelperProcess", "--"},
|
||||
Dir: relHookDir,
|
||||
Env: processHookHelperEnv("rewrite", ""),
|
||||
InterceptLLM: true,
|
||||
})
|
||||
if mountErr != nil {
|
||||
t.Fatalf("MountProcessHook failed with relative dir/command under isolation: %v", mountErr)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-relative",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked llm content, got %q", resp)
|
||||
}
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func processHookHelperCommand() []string {
|
||||
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
|
||||
}
|
||||
@@ -193,6 +266,59 @@ func processHookHelperEnv(mode, eventLog string) []string {
|
||||
return env
|
||||
}
|
||||
|
||||
func writeFakeBwrap(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
script := `#!/bin/sh
|
||||
set -eu
|
||||
workdir=
|
||||
while [ "$#" -gt 0 ]; do
|
||||
case "$1" in
|
||||
--)
|
||||
shift
|
||||
break
|
||||
;;
|
||||
--chdir)
|
||||
workdir="$2"
|
||||
shift 2
|
||||
;;
|
||||
--bind|--ro-bind)
|
||||
shift 3
|
||||
;;
|
||||
--proc|--dev)
|
||||
shift 2
|
||||
;;
|
||||
--die-with-parent|--unshare-ipc)
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
if [ -n "$workdir" ]; then
|
||||
cd "$workdir"
|
||||
fi
|
||||
exec "$@"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(script), 0o755); err != nil {
|
||||
t.Fatalf("write fake bwrap: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func linkTestBinary(t *testing.T, source, target string) {
|
||||
t.Helper()
|
||||
if err := os.Symlink(source, target); err == nil {
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(source)
|
||||
if err != nil {
|
||||
t.Fatalf("read test binary: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(target, data, 0o755); err != nil {
|
||||
t.Fatalf("create hook helper binary: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForFileContains(t *testing.T, path, substring string) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
+19
-5
@@ -25,6 +25,7 @@ type HookAction string
|
||||
const (
|
||||
HookActionContinue HookAction = "continue"
|
||||
HookActionModify HookAction = "modify"
|
||||
HookActionRespond HookAction = "respond" // Return result directly, skip tool execution. SECURITY: This bypasses ApproveTool checks, allowing hooks to return results for any tool (including sensitive ones like bash) without approval. Use with caution.
|
||||
HookActionDenyTool HookAction = "deny_tool"
|
||||
HookActionAbortTurn HookAction = "abort_turn"
|
||||
HookActionHardAbort HookAction = "hard_abort"
|
||||
@@ -127,11 +128,12 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
HookResult *tools.ToolResult `json:"hook_result,omitempty"` // Result returned directly by hook (for respond action). Media is supported - see Media handling section in docs.
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
@@ -140,6 +142,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.HookResult = cloneToolResult(r.HookResult)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
@@ -382,6 +385,10 @@ func (hm *HookManager) BeforeTool(
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionRespond:
|
||||
// Hook returns result directly, skip tool execution
|
||||
// Carry HookResult in ToolCallHookRequest and return
|
||||
return next, decision
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
@@ -793,6 +800,13 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
if len(result.ArtifactTags) > 0 {
|
||||
cloned.ArtifactTags = append([]string(nil), result.ArtifactTags...)
|
||||
}
|
||||
if len(result.Messages) > 0 {
|
||||
cloned.Messages = make([]providers.Message, len(result.Messages))
|
||||
copy(cloned.Messages, result.Messages)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -343,3 +345,517 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
// respondHook is a test hook for testing HookActionRespond functionality
|
||||
type respondHook struct {
|
||||
respondTools map[string]bool // tool names to respond to
|
||||
}
|
||||
|
||||
func (h *respondHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.respondTools[call.Tool] {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: "hook-responded: " + call.Tool,
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, HookDecision{Action: HookActionRespond}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *respondHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
// Should not be called since respond skips tool execution
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
|
||||
respondTools: map[string]bool{"echo_text": true},
|
||||
})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify response comes from hook, not tool
|
||||
expected := "hook-responded: echo_text"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
// Verify event stream has ToolExecEnd, not actual tool execution
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected tool exec end event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
if payload.Tool != "echo_text" {
|
||||
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
|
||||
}
|
||||
if payload.ForLLMLen != len(expected) {
|
||||
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
|
||||
}
|
||||
}
|
||||
|
||||
// denyToolHook tests HookActionDenyTool functionality
|
||||
type denyToolHook struct {
|
||||
denyTools map[string]bool
|
||||
}
|
||||
|
||||
func (h *denyToolHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.denyTools[call.Tool] {
|
||||
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *denyToolHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
|
||||
denyTools: map[string]bool{"echo_text": true},
|
||||
})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "Tool execution denied by hook: tool denied by hook"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
defer hm.Close()
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"test_tool": true},
|
||||
}
|
||||
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
|
||||
t.Fatalf("mount hook: %v", err)
|
||||
}
|
||||
|
||||
req := &ToolCallHookRequest{
|
||||
Tool: "test_tool",
|
||||
Arguments: map[string]any{"arg": "value"},
|
||||
}
|
||||
result, decision := hm.BeforeTool(context.Background(), req)
|
||||
|
||||
if decision.Action != HookActionRespond {
|
||||
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
|
||||
}
|
||||
|
||||
if result.HookResult == nil {
|
||||
t.Fatal("expected HookResult to be set")
|
||||
}
|
||||
if result.HookResult.ForLLM != "hook-responded: test_tool" {
|
||||
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
type respondWithMediaHook struct {
|
||||
respondTools map[string]bool
|
||||
media []string
|
||||
responseHandled bool
|
||||
forLLM string
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.respondTools[call.Tool] {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: h.forLLM,
|
||||
ForUser: "media result",
|
||||
Media: h.media,
|
||||
ResponseHandled: h.responseHandled,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, HookDecision{Action: HookActionRespond}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
type errorMediaChannel struct {
|
||||
fakeChannel
|
||||
sendErr error
|
||||
}
|
||||
|
||||
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
return nil, f.sendErr
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &respondWithMediaHook{
|
||||
respondTools: map[string]bool{"media_tool": true},
|
||||
media: []string{"media://test/image.png"},
|
||||
responseHandled: true,
|
||||
forLLM: "media sent successfully",
|
||||
}
|
||||
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
al.channelManager = newStartedTestChannelManager(t, al.bus, al.mediaStore, "discord", &errorMediaChannel{
|
||||
sendErr: errors.New("channel unavailable"),
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-media-err",
|
||||
Channel: "discord",
|
||||
ChatID: "chat1",
|
||||
UserMessage: "send media",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
|
||||
if !payload.IsError {
|
||||
t.Fatal("expected IsError=true when SendMedia fails")
|
||||
}
|
||||
|
||||
if payload.ForLLMLen < 30 {
|
||||
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &respondWithMediaHook{
|
||||
respondTools: map[string]bool{"media_tool": true},
|
||||
media: []string{"media://test/image.png"},
|
||||
responseHandled: true,
|
||||
forLLM: "media queued",
|
||||
}
|
||||
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-bus-fallback",
|
||||
Channel: "cli",
|
||||
ChatID: "chat1",
|
||||
UserMessage: "send media",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
|
||||
if payload.IsError {
|
||||
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
|
||||
}
|
||||
|
||||
if resp != "done" {
|
||||
t.Fatalf("expected response 'done', got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type multiToolProvider struct {
|
||||
mu sync.Mutex
|
||||
callCount int
|
||||
toolCalls []providers.ToolCall
|
||||
finalContent string
|
||||
}
|
||||
|
||||
func (p *multiToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.callCount++
|
||||
if p.callCount == 1 && len(p.toolCalls) > 0 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: p.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: p.finalContent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *multiToolProvider) GetDefaultModel() string {
|
||||
return "multi-tool-provider"
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
||||
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
||||
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, _, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
tool1ExecCh := make(chan struct{}, 1)
|
||||
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
|
||||
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"tool_one": true},
|
||||
}
|
||||
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"run tools",
|
||||
sessionKey,
|
||||
"cli",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := al.InterruptGraceful("stop now"); err != nil {
|
||||
t.Fatalf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
|
||||
}
|
||||
|
||||
for _, evt := range skippedEvts {
|
||||
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
||||
}
|
||||
if payload.Reason != "graceful interrupt requested" {
|
||||
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
||||
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
||||
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, _, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
|
||||
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"tool_one": true},
|
||||
}
|
||||
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"run tools",
|
||||
sessionKey,
|
||||
"cli",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
al.Steer(providers.Message{Role: "user", Content: "change direction"})
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after steering")
|
||||
}
|
||||
|
||||
for _, evt := range skippedEvts {
|
||||
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
||||
}
|
||||
if payload.Reason != "queued user steering message" {
|
||||
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterEvents(events []Event, kind EventKind) []Event {
|
||||
var result []Event
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
result = append(result, evt)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/memory"
|
||||
@@ -51,6 +52,10 @@ type AgentInstance struct {
|
||||
// LightProvider is the concrete provider instance for the configured light model.
|
||||
// It is only used when routing selects the light tier for a turn.
|
||||
LightProvider providers.LLMProvider
|
||||
// CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
|
||||
// instances. This allows each fallback model to use its own api_base and api_key
|
||||
// from model_list, instead of inheriting the primary model's provider config.
|
||||
CandidateProviders map[string]providers.LLMProvider
|
||||
}
|
||||
|
||||
// NewAgentInstance creates an agent instance from config.
|
||||
@@ -60,6 +65,12 @@ func NewAgentInstance(
|
||||
cfg *config.Config,
|
||||
provider providers.LLMProvider,
|
||||
) *AgentInstance {
|
||||
if cfg != nil {
|
||||
// Keep the subprocess isolation runtime aligned with the latest loaded config
|
||||
// before any tools or providers start spawning child processes.
|
||||
isolation.Configure(cfg)
|
||||
}
|
||||
|
||||
workspace := resolveAgentWorkspace(agentCfg, defaults)
|
||||
os.MkdirAll(workspace, 0o755)
|
||||
|
||||
@@ -175,6 +186,9 @@ func NewAgentInstance(
|
||||
// Resolve fallback candidates
|
||||
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
|
||||
|
||||
candidateProviders := make(map[string]providers.LLMProvider)
|
||||
populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
|
||||
|
||||
// Model routing setup: pre-resolve light model candidates at creation time
|
||||
// to avoid repeated model_list lookups on every incoming message.
|
||||
var router *routing.Router
|
||||
@@ -199,6 +213,7 @@ func NewAgentInstance(
|
||||
})
|
||||
lightCandidates = resolved
|
||||
lightProvider = lp
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -230,6 +245,43 @@ func NewAgentInstance(
|
||||
Router: router,
|
||||
LightCandidates: lightCandidates,
|
||||
LightProvider: lightProvider,
|
||||
CandidateProviders: candidateProviders,
|
||||
}
|
||||
}
|
||||
|
||||
// populateCandidateProvidersFromNames resolves each model name (alias or
|
||||
// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider
|
||||
// for it. This reuses the canonical config resolution path (GetModelConfig) so
|
||||
// alias handling and load-balancing stay consistent with the rest of the codebase.
|
||||
func populateCandidateProvidersFromNames(
|
||||
cfg *config.Config,
|
||||
workspace string,
|
||||
names []string,
|
||||
out map[string]providers.LLMProvider,
|
||||
) {
|
||||
if cfg == nil || len(names) == 0 {
|
||||
return
|
||||
}
|
||||
for _, name := range names {
|
||||
mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent",
|
||||
"fallback provider: no model_list entry found; will inherit primary provider credentials",
|
||||
map[string]any{"name": name, "error": err.Error()})
|
||||
continue
|
||||
}
|
||||
protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
|
||||
key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
|
||||
if _, exists := out[key]; exists {
|
||||
continue
|
||||
}
|
||||
p, _, err := providers.CreateProviderFromConfig(mc)
|
||||
if err != nil {
|
||||
logger.WarnCF("agent", "fallback provider: failed to create provider",
|
||||
map[string]any{"model": mc.Model, "error": err.Error()})
|
||||
continue
|
||||
}
|
||||
out[key] = p
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/media"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
|
||||
@@ -300,6 +301,199 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil
|
||||
// config does not panic and leaves the output map empty.
|
||||
func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already
|
||||
// present in the output map is not overwritten.
|
||||
func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
|
||||
existing := &mockProvider{}
|
||||
key := providers.ModelKey("openai", "gpt-4o")
|
||||
out := map[string]providers.LLMProvider{key: existing}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out)
|
||||
|
||||
if out[key] != existing {
|
||||
t.Fatal("existing provider entry was overwritten; expected it to be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
|
||||
// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
|
||||
// is created using the underlying model's config.
|
||||
func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
out := map[string]providers.LLMProvider{}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
|
||||
|
||||
key := providers.ModelKey("openai", "gpt-4o")
|
||||
if out[key] == nil {
|
||||
t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
|
||||
// model_list entry using full "provider/model" notation (e.g.
|
||||
// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
|
||||
func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
out := map[string]providers.LLMProvider{}
|
||||
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "gemma",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIKeys: config.SimpleSecureStrings("gemini-test-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
|
||||
|
||||
key := providers.ModelKey("gemini", "gemma-3-27b-it")
|
||||
if out[key] == nil {
|
||||
t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
|
||||
// path when the names slice is empty.
|
||||
func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
|
||||
// path when model_list is empty — no provider can be created.
|
||||
func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
|
||||
// name with no matching model_list entry is skipped and does not
|
||||
// cause a panic or leave a nil entry in the map.
|
||||
func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
|
||||
out := map[string]providers.LLMProvider{}
|
||||
cfg := &config.Config{
|
||||
ModelList: []*config.ModelConfig{
|
||||
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
|
||||
},
|
||||
}
|
||||
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
|
||||
|
||||
if len(out) != 0 {
|
||||
t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
|
||||
// mirrors the exact scenario from bug #2140: primary model on OpenRouter with
|
||||
// Gemini fallbacks. Each entry must get its own provider instance so that
|
||||
// fallback requests go to the correct API endpoint, not the primary's.
|
||||
func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "mistral-small-3.1",
|
||||
ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "mistral-small-3.1",
|
||||
Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
|
||||
APIBase: "https://openrouter.ai/api/v1",
|
||||
APIKeys: config.SimpleSecureStrings("sk-or-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemma-3-27b",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemini-images",
|
||||
Model: "gemini/gemini-2.5-flash-lite",
|
||||
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
primaryProvider := &mockProvider{}
|
||||
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider)
|
||||
|
||||
// Only fallback models need entries — the primary uses the injected provider directly.
|
||||
wantKeys := []string{
|
||||
providers.ModelKey("gemini", "gemma-3-27b-it"),
|
||||
providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
|
||||
}
|
||||
|
||||
for _, key := range wantKeys {
|
||||
p, ok := agent.CandidateProviders[key]
|
||||
if !ok {
|
||||
t.Errorf("CandidateProviders missing key %q", key)
|
||||
continue
|
||||
}
|
||||
if p == nil {
|
||||
t.Errorf("CandidateProviders[%q] is nil", key)
|
||||
}
|
||||
// Each fallback must use its own provider, not the injected primary.
|
||||
if p == primaryProvider {
|
||||
t.Errorf(
|
||||
"CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
|
||||
key,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if t.Failed() {
|
||||
t.Logf("CandidateProviders keys present: %v", func() []string {
|
||||
keys := make([]string, 0, len(agent.CandidateProviders))
|
||||
for k := range agent.CandidateProviders {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
|
||||
+240
-2
@@ -1742,6 +1742,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonProactive,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
@@ -1857,6 +1858,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
||||
ts.recordPersistedMessage(pm)
|
||||
ts.ingestMessage(turnCtx, al, pm)
|
||||
}
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
map[string]any{
|
||||
@@ -2018,7 +2020,11 @@ turnLoop:
|
||||
providerCtx,
|
||||
activeCandidates,
|
||||
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
|
||||
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
candidateProvider := activeProvider
|
||||
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
|
||||
candidateProvider = cp
|
||||
}
|
||||
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
|
||||
},
|
||||
)
|
||||
if fbErr != nil {
|
||||
@@ -2128,6 +2134,7 @@ turnLoop:
|
||||
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonRetry,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); compactErr != nil {
|
||||
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
@@ -2345,6 +2352,236 @@ turnLoop:
|
||||
toolName = toolReq.Tool
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionRespond:
|
||||
// Hook returns result directly, skip tool execution.
|
||||
// SECURITY: This bypasses ApproveTool, allowing hooks to respond
|
||||
// for any tool name without approval. This is intentional for
|
||||
// plugin tools but means a before_tool hook can override even
|
||||
// sensitive tools like bash. Hook configuration should be
|
||||
// carefully reviewed to prevent unauthorized tool execution.
|
||||
if toolReq != nil && toolReq.HookResult != nil {
|
||||
hookResult := toolReq.HookResult
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call (hook respond): %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Emit ToolExecStart event (same as normal tool execution)
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
// Send tool feedback to chat channel if enabled (same as normal tool execution)
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
|
||||
ts.channel != "" &&
|
||||
!ts.opts.SuppressToolFeedback {
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", toolName, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: feedbackMsg,
|
||||
})
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
toolDuration := time.Duration(0) // Hook execution time unknown
|
||||
|
||||
// Send ForUser content to user
|
||||
// For ResponseHandled results, send regardless of SendResponse setting,
|
||||
// same as normal tool execution path.
|
||||
shouldSendForUser := !hookResult.Silent && hookResult.ForUser != "" &&
|
||||
(ts.opts.SendResponse || hookResult.ResponseHandled)
|
||||
if shouldSendForUser {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: hookResult.ForUser,
|
||||
Metadata: map[string]string{
|
||||
"is_tool_call": "true",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle media from hook result (same as normal tool execution)
|
||||
if len(hookResult.Media) > 0 && hookResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(hookResult.Media))
|
||||
for _, ref := range hookResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
logger.WarnCF("agent", "Failed to deliver hook media",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
// Same as normal tool execution: notify LLM about delivery failure
|
||||
hookResult.IsError = true
|
||||
hookResult.ForLLM = fmt.Sprintf("failed to deliver attachment: %v", err)
|
||||
}
|
||||
} else if al.bus != nil {
|
||||
al.bus.PublishOutboundMedia(ctx, outboundMedia)
|
||||
// Same as normal tool execution: bus only queues, media not yet delivered
|
||||
hookResult.ResponseHandled = false
|
||||
}
|
||||
}
|
||||
|
||||
// Track response handling status (same as normal tool execution)
|
||||
if !hookResult.ResponseHandled {
|
||||
allResponsesHandled = false
|
||||
}
|
||||
|
||||
// Build tool message
|
||||
contentForLLM := hookResult.ContentForLLM()
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
|
||||
// Handle media for LLM vision (same as normal tool execution)
|
||||
if len(hookResult.Media) > 0 && !hookResult.ResponseHandled {
|
||||
hookResult.ArtifactTags = buildArtifactTags(al.mediaStore, hookResult.Media)
|
||||
// Recalculate contentForLLM after adding ArtifactTags
|
||||
contentForLLM = hookResult.ContentForLLM()
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
toolResultMsg.Content = contentForLLM
|
||||
toolResultMsg.Media = append(toolResultMsg.Media, hookResult.Media...)
|
||||
}
|
||||
|
||||
// Emit ToolExecEnd event (after filtering, same as normal tool execution)
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(hookResult.ForUser),
|
||||
IsError: hookResult.IsError,
|
||||
Async: hookResult.Async,
|
||||
},
|
||||
)
|
||||
|
||||
messages = append(messages, toolResultMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
ts.ingestMessage(turnCtx, al, toolResultMsg)
|
||||
}
|
||||
|
||||
// Same as normal tool execution: check for steering/interrupt/SubTurn after each tool
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
}
|
||||
|
||||
skipReason := ""
|
||||
skipMessage := ""
|
||||
if len(pendingMessages) > 0 {
|
||||
skipReason = "queued user steering message"
|
||||
skipMessage = "Skipped due to queued user message."
|
||||
} else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending {
|
||||
skipReason = "graceful interrupt requested"
|
||||
skipMessage = "Skipped due to graceful interrupt."
|
||||
}
|
||||
|
||||
if skipReason != "" {
|
||||
remaining := len(normalizedToolCalls) - i - 1
|
||||
if remaining > 0 {
|
||||
logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools after hook respond",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"completed": i + 1,
|
||||
"skipped": remaining,
|
||||
"reason": skipReason,
|
||||
})
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: skippedTC.Name,
|
||||
Reason: skipReason,
|
||||
},
|
||||
)
|
||||
skippedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: skipMessage,
|
||||
ToolCallID: skippedTC.ID,
|
||||
}
|
||||
messages = append(messages, skippedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg)
|
||||
ts.recordPersistedMessage(skippedMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Also poll for any SubTurn results that arrived during tool execution.
|
||||
if ts.pendingResults != nil {
|
||||
select {
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if ok && result != nil && result.ForLLM != "" {
|
||||
content := al.cfg.FilterSensitiveData(result.ForLLM)
|
||||
msg := providers.Message{Role: "user", Content: fmt.Sprintf("[SubTurn Result] %s", content)}
|
||||
messages = append(messages, msg)
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, msg)
|
||||
}
|
||||
default:
|
||||
// No results available
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
// If no HookResult, fall back to continue with warning
|
||||
logger.WarnCF("agent", "Hook returned respond action but no HookResult provided",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"action": "respond",
|
||||
})
|
||||
case HookActionDenyTool:
|
||||
allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
@@ -2773,7 +3010,7 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
if ts.opts.EnableSummary {
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize})
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
@@ -2849,6 +3086,7 @@ turnLoop:
|
||||
&CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -126,6 +126,8 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
|
||||
}
|
||||
|
||||
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
|
||||
mcpTool.SetWorkspace(agent.Workspace)
|
||||
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
|
||||
|
||||
if registerAsHidden {
|
||||
agent.Tools.RegisterHidden(mcpTool)
|
||||
|
||||
@@ -1839,6 +1839,164 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
|
||||
// bug #2140. It verifies that when the primary model returns a rate-limit error
|
||||
// the fallback closure routes the retry to the fallback model's own provider
|
||||
// (its own api_base), not back to the primary provider's endpoint.
|
||||
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
primaryCalls := 0
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
primaryCalls++
|
||||
// Return 429 so FallbackChain classifies this as retriable and moves on.
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limit exceeded",
|
||||
"type": "rate_limit_error",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
fallbackCalls := 0
|
||||
fallbackServer := newStrictChatCompletionTestServer(
|
||||
t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
|
||||
)
|
||||
defer fallbackServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "mistral-primary",
|
||||
ModelFallbacks: []string{"gemma-fallback"},
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "mistral-primary",
|
||||
Model: "openrouter/mistralai/mistral-small-3.1",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
{
|
||||
ModelName: "gemma-fallback",
|
||||
Model: "gemini/gemma-3-27b-it",
|
||||
APIBase: fallbackServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("fallback-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if resp != "fallback reply" {
|
||||
t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
|
||||
}
|
||||
if primaryCalls == 0 {
|
||||
t.Fatal("primary server was never called; expected at least one attempt")
|
||||
}
|
||||
if fallbackCalls != 1 {
|
||||
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
|
||||
// that when a candidate has no model_list entry it is absent from CandidateProviders
|
||||
// and the fallback closure falls back to activeProvider instead of panicking.
|
||||
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
|
||||
workspace := t.TempDir()
|
||||
|
||||
// Primary server: returns 429 on first call, succeeds on second.
|
||||
// Both the primary and the unregistered fallback share this server
|
||||
// (same api_base) so activeProvider routes both calls here.
|
||||
callCount := 0
|
||||
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if callCount == 1 {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
|
||||
})
|
||||
return
|
||||
}
|
||||
// Second call (fallback via activeProvider) succeeds.
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer primaryServer.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: workspace,
|
||||
ModelName: "primary-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 3,
|
||||
// No model_list entry for this alias — absent from CandidateProviders.
|
||||
ModelFallbacks: []string{"openrouter/fallback-model"},
|
||||
},
|
||||
},
|
||||
ModelList: []*config.ModelConfig{
|
||||
{
|
||||
ModelName: "primary-model",
|
||||
Model: "openrouter/primary-model",
|
||||
APIBase: primaryServer.URL,
|
||||
APIKeys: config.SimpleSecureStrings("primary-key"),
|
||||
Workspace: workspace,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
provider, _, err := providers.CreateProvider(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProvider() error = %v", err)
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
helper := testHelper{al: al}
|
||||
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
|
||||
Channel: "telegram",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "hi",
|
||||
Peer: bus.Peer{Kind: "direct", ID: "user1"},
|
||||
})
|
||||
|
||||
if resp != "active provider reply" {
|
||||
t.Fatalf("response = %q, want %q", resp, "active provider reply")
|
||||
}
|
||||
if callCount < 2 {
|
||||
t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
|
||||
@@ -604,6 +604,7 @@ type ephemeralSessionStoreIface interface {
|
||||
SetHistory(key string, history []providers.Message)
|
||||
TruncateHistory(key string, keepLast int)
|
||||
Save(key string) error
|
||||
ListSessions() []string
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -663,8 +664,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
|
||||
e.history = e.history[len(e.history)-keepLast:]
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
|
||||
|
||||
func (e *ephemeralSessionStore) truncateLocked() {
|
||||
if len(e.history) > maxEphemeralHistorySize {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
lark "github.com/larksuite/oapi-sdk-go/v3"
|
||||
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
|
||||
@@ -42,12 +43,18 @@ type FeishuChannel struct {
|
||||
wsClient *larkws.Client
|
||||
tokenCache *tokenCache // custom cache that supports invalidation
|
||||
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
botOpenID atomic.Value // stores string; populated lazily for @mention detection
|
||||
messageCache sync.Map // caches fetched messages (messageID -> *larkim.Message)
|
||||
|
||||
mu sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type cachedMessage struct {
|
||||
msg *larkim.Message
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
|
||||
base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom,
|
||||
channels.WithGroupTrigger(cfg.GroupTrigger),
|
||||
@@ -436,24 +443,8 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
// Append media tags to content (like Telegram does)
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
metadata := map[string]string{}
|
||||
if messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
if sender != nil && sender.TenantKey != nil {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
metadata := buildInboundMetadata(message, sender)
|
||||
|
||||
var peer bus.Peer
|
||||
if chatType == "p2p" {
|
||||
@@ -477,12 +468,25 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
|
||||
content = cleaned
|
||||
}
|
||||
|
||||
if replyTargetID(message) != "" || stringValue(message.ThreadId) != "" {
|
||||
content, mediaRefs = c.prependReplyContext(ctx, message, chatID, content, mediaRefs)
|
||||
}
|
||||
if content == "" {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
logger.InfoCF("feishu", "Feishu message received", map[string]any{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"message_id": messageID,
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
logger.InfoCF("feishu", "Feishu reply linkage", map[string]any{
|
||||
"message_id": messageID,
|
||||
"parent_id": stringValue(message.ParentId),
|
||||
"root_id": stringValue(message.RootId),
|
||||
"thread_id": stringValue(message.ThreadId),
|
||||
})
|
||||
|
||||
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,298 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
const messageCacheTTL = 30 * time.Second
|
||||
|
||||
const (
|
||||
maxReplyContextLen = 600
|
||||
)
|
||||
|
||||
func (c *FeishuChannel) prependReplyContext(
|
||||
ctx context.Context,
|
||||
message *larkim.EventMessage,
|
||||
chatID string,
|
||||
content string,
|
||||
mediaRefs []string,
|
||||
) (string, []string) {
|
||||
if message == nil {
|
||||
return content, mediaRefs
|
||||
}
|
||||
|
||||
lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
targetMessageID := c.resolveReplyTargetMessageID(lookupCtx, message)
|
||||
if targetMessageID == "" {
|
||||
logger.DebugCF("feishu", "No reply target resolved; skip reply context", map[string]any{
|
||||
"message_id": stringValue(message.MessageId),
|
||||
"parent_id": stringValue(message.ParentId),
|
||||
"root_id": stringValue(message.RootId),
|
||||
"thread_id": stringValue(message.ThreadId),
|
||||
})
|
||||
return content, mediaRefs
|
||||
}
|
||||
|
||||
repliedMessage, err := c.fetchMessageByID(lookupCtx, targetMessageID)
|
||||
if err != nil {
|
||||
logger.DebugCF("feishu", "Failed to fetch replied message context", map[string]any{
|
||||
"target_message_id": targetMessageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return content, mediaRefs
|
||||
}
|
||||
|
||||
messageType := stringValue(repliedMessage.MsgType)
|
||||
rawContent := ""
|
||||
if repliedMessage.Body != nil {
|
||||
rawContent = stringValue(repliedMessage.Body.Content)
|
||||
}
|
||||
|
||||
var repliedMediaRefs []string
|
||||
if store := c.GetMediaStore(); store != nil {
|
||||
repliedMediaRefs = c.downloadInboundMedia(lookupCtx, chatID, targetMessageID, messageType, rawContent, store)
|
||||
if messageType == larkim.MsgTypeInteractive {
|
||||
_, externalURLs := extractCardImageKeys(rawContent)
|
||||
if len(externalURLs) > 0 {
|
||||
repliedMediaRefs = append(repliedMediaRefs, externalURLs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
repliedContent := normalizeRepliedContent(messageType, rawContent, repliedMediaRefs)
|
||||
if len(repliedMediaRefs) > 0 {
|
||||
mediaRefs = append(repliedMediaRefs, mediaRefs...)
|
||||
}
|
||||
|
||||
return formatReplyContext(targetMessageID, repliedContent, content), mediaRefs
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) resolveReplyTargetMessageID(ctx context.Context, message *larkim.EventMessage) string {
|
||||
if targetID := replyTargetID(message); targetID != "" {
|
||||
logger.DebugCF("feishu", "Resolved reply target from event payload", map[string]any{
|
||||
"message_id": stringValue(message.MessageId),
|
||||
"parent_id": stringValue(message.ParentId),
|
||||
"root_id": stringValue(message.RootId),
|
||||
"target_id": targetID,
|
||||
})
|
||||
return targetID
|
||||
}
|
||||
|
||||
currentMessageID := stringValue(message.MessageId)
|
||||
if currentMessageID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if stringValue(message.ThreadId) == "" {
|
||||
logger.DebugCF("feishu", "No reply target found; message is not in a thread", map[string]any{
|
||||
"message_id": stringValue(message.MessageId),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
msg, err := c.fetchMessageByID(ctx, currentMessageID)
|
||||
if err != nil {
|
||||
logger.DebugCF("feishu", "Failed to query current message detail for reply info", map[string]any{
|
||||
"message_id": currentMessageID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
targetID := replyTargetIDFromMessage(msg)
|
||||
if targetID != "" {
|
||||
logger.DebugCF("feishu", "Resolved reply target from message detail", map[string]any{
|
||||
"message_id": currentMessageID,
|
||||
"parent_id": stringValue(msg.ParentId),
|
||||
"root_id": stringValue(msg.RootId),
|
||||
"target_id": targetID,
|
||||
})
|
||||
}
|
||||
return targetID
|
||||
}
|
||||
|
||||
func (c *FeishuChannel) fetchMessageByID(ctx context.Context, messageID string) (*larkim.Message, error) {
|
||||
if cached, ok := c.messageCache.Load(messageID); ok {
|
||||
cm := cached.(*cachedMessage)
|
||||
if time.Now().Before(cm.expiry) {
|
||||
return cm.msg, nil
|
||||
}
|
||||
c.messageCache.Delete(messageID)
|
||||
}
|
||||
|
||||
req := larkim.NewGetMessageReqBuilder().
|
||||
MessageId(messageID).
|
||||
Build()
|
||||
|
||||
resp, err := c.client.Im.V1.Message.Get(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("feishu get message: %w", err)
|
||||
}
|
||||
if !resp.Success() {
|
||||
c.invalidateTokenOnAuthError(resp.Code)
|
||||
return nil, fmt.Errorf("feishu get message api error (code=%d msg=%s)", resp.Code, resp.Msg)
|
||||
}
|
||||
if resp.Data == nil || len(resp.Data.Items) == 0 || resp.Data.Items[0] == nil {
|
||||
return nil, fmt.Errorf("feishu get message: empty response")
|
||||
}
|
||||
// Items[0] contains the target message - the Feishu API returns a list
|
||||
// but we request a single message by ID, so the list always has at most one item.
|
||||
msg := resp.Data.Items[0]
|
||||
c.messageCache.Store(messageID, &cachedMessage{msg: msg, expiry: time.Now().Add(messageCacheTTL)})
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func replyTargetID(message *larkim.EventMessage) string {
|
||||
if message == nil {
|
||||
return ""
|
||||
}
|
||||
if parentID := stringValue(message.ParentId); parentID != "" {
|
||||
return parentID
|
||||
}
|
||||
return stringValue(message.RootId)
|
||||
}
|
||||
|
||||
func replyTargetIDFromMessage(message *larkim.Message) string {
|
||||
if message == nil {
|
||||
return ""
|
||||
}
|
||||
if parentID := stringValue(message.ParentId); parentID != "" {
|
||||
return parentID
|
||||
}
|
||||
return stringValue(message.RootId)
|
||||
}
|
||||
|
||||
func buildInboundMetadata(message *larkim.EventMessage, sender *larkim.EventSender) map[string]string {
|
||||
metadata := map[string]string{}
|
||||
if message == nil {
|
||||
return metadata
|
||||
}
|
||||
|
||||
messageID := stringValue(message.MessageId)
|
||||
if messageID != "" {
|
||||
metadata["message_id"] = messageID
|
||||
}
|
||||
|
||||
messageType := stringValue(message.MessageType)
|
||||
if messageType != "" {
|
||||
metadata["message_type"] = messageType
|
||||
}
|
||||
|
||||
chatType := stringValue(message.ChatType)
|
||||
if chatType != "" {
|
||||
metadata["chat_type"] = chatType
|
||||
}
|
||||
|
||||
parentID := stringValue(message.ParentId)
|
||||
if parentID != "" {
|
||||
metadata["parent_id"] = parentID
|
||||
}
|
||||
|
||||
rootID := stringValue(message.RootId)
|
||||
if rootID != "" {
|
||||
metadata["root_id"] = rootID
|
||||
}
|
||||
|
||||
if replyTo := replyTargetID(message); replyTo != "" {
|
||||
metadata["reply_to_message_id"] = replyTo
|
||||
}
|
||||
|
||||
threadID := stringValue(message.ThreadId)
|
||||
if threadID != "" {
|
||||
metadata["thread_id"] = threadID
|
||||
}
|
||||
|
||||
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
|
||||
metadata["tenant_key"] = *sender.TenantKey
|
||||
}
|
||||
|
||||
return metadata
|
||||
}
|
||||
|
||||
func normalizeRepliedContent(messageType, rawContent string, mediaRefs []string) string {
|
||||
content := extractContent(messageType, rawContent)
|
||||
|
||||
if containsFeishuUpgradePlaceholder(rawContent) || containsFeishuUpgradePlaceholder(content) {
|
||||
content = ""
|
||||
}
|
||||
|
||||
content = appendMediaTags(content, messageType, mediaRefs)
|
||||
if strings.TrimSpace(content) != "" {
|
||||
return content
|
||||
}
|
||||
|
||||
switch messageType {
|
||||
case larkim.MsgTypeImage:
|
||||
return "[replied image]"
|
||||
case larkim.MsgTypeFile:
|
||||
return "[replied file]"
|
||||
case larkim.MsgTypeAudio:
|
||||
return "[replied audio]"
|
||||
case larkim.MsgTypeMedia:
|
||||
return "[replied video]"
|
||||
case larkim.MsgTypeInteractive:
|
||||
return "[replied interactive card]"
|
||||
default:
|
||||
return "[replied message content unavailable]"
|
||||
}
|
||||
}
|
||||
|
||||
func containsFeishuUpgradePlaceholder(s string) bool {
|
||||
upgradePrompt := "\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef"
|
||||
upgradePromptEscaped := "\\u8bf7\\u5347\\u7ea7\\u81f3\\u6700\\u65b0\\u7248\\u672c\\u5ba2\\u6237\\u7aef"
|
||||
return strings.Contains(s, upgradePrompt) || strings.Contains(s, upgradePromptEscaped)
|
||||
}
|
||||
|
||||
func formatReplyContext(parentID, repliedContent, content string) string {
|
||||
parentID = strings.TrimSpace(parentID)
|
||||
repliedContent = strings.TrimSpace(repliedContent)
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
if parentID == "" || repliedContent == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
repliedContent = utils.Truncate(repliedContent, maxReplyContextLen)
|
||||
repliedContent = sanitizeReplyContextContent(repliedContent)
|
||||
content = sanitizeReplyContextContent(content)
|
||||
header := fmt.Sprintf("[replied_message id=%q]", parentID)
|
||||
footer := "[/replied_message]"
|
||||
if content == "" {
|
||||
return header + "\n" + repliedContent + "\n" + footer
|
||||
}
|
||||
if hasLeadingCommandPrefix(content) {
|
||||
return content + "\n\n" + header + "\n" + repliedContent + "\n" + footer
|
||||
}
|
||||
return header + "\n" + repliedContent + "\n" + footer + "\n\n[current_message]\n" + content + "\n[/current_message]"
|
||||
}
|
||||
|
||||
func hasLeadingCommandPrefix(s string) bool {
|
||||
tokens := strings.Fields(strings.TrimSpace(s))
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
first := tokens[0]
|
||||
return strings.HasPrefix(first, "/") || strings.HasPrefix(first, "!")
|
||||
}
|
||||
|
||||
func sanitizeReplyContextContent(s string) string {
|
||||
tagEscaper := strings.NewReplacer(
|
||||
"[replied_message", `\[replied_message`,
|
||||
"[/replied_message]", `\[/replied_message]`,
|
||||
"[current_message]", `\[current_message]`,
|
||||
"[/current_message]", `\[/current_message]`,
|
||||
)
|
||||
return tagEscaper.Replace(s)
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
|
||||
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func TestBuildInboundMetadata(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
t.Run("includes basic and reply fields", func(t *testing.T) {
|
||||
message := &larkim.EventMessage{
|
||||
MessageId: strPtr("om_msg_1"),
|
||||
MessageType: strPtr("text"),
|
||||
ChatType: strPtr("group"),
|
||||
ParentId: strPtr("om_parent_1"),
|
||||
RootId: strPtr("om_root_1"),
|
||||
ThreadId: strPtr("omt_thread_1"),
|
||||
}
|
||||
sender := &larkim.EventSender{TenantKey: strPtr("tenant_x")}
|
||||
|
||||
got := buildInboundMetadata(message, sender)
|
||||
|
||||
if got["message_id"] != "om_msg_1" {
|
||||
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_1")
|
||||
}
|
||||
if got["message_type"] != "text" {
|
||||
t.Fatalf("message_type = %q, want %q", got["message_type"], "text")
|
||||
}
|
||||
if got["chat_type"] != "group" {
|
||||
t.Fatalf("chat_type = %q, want %q", got["chat_type"], "group")
|
||||
}
|
||||
if got["parent_id"] != "om_parent_1" {
|
||||
t.Fatalf("parent_id = %q, want %q", got["parent_id"], "om_parent_1")
|
||||
}
|
||||
if got["reply_to_message_id"] != "om_parent_1" {
|
||||
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_parent_1")
|
||||
}
|
||||
if got["root_id"] != "om_root_1" {
|
||||
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_1")
|
||||
}
|
||||
if got["thread_id"] != "omt_thread_1" {
|
||||
t.Fatalf("thread_id = %q, want %q", got["thread_id"], "omt_thread_1")
|
||||
}
|
||||
if got["tenant_key"] != "tenant_x" {
|
||||
t.Fatalf("tenant_key = %q, want %q", got["tenant_key"], "tenant_x")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back reply_to_message_id to root_id", func(t *testing.T) {
|
||||
message := &larkim.EventMessage{
|
||||
MessageId: strPtr("om_msg_3"),
|
||||
RootId: strPtr("om_root_3"),
|
||||
}
|
||||
|
||||
got := buildInboundMetadata(message, nil)
|
||||
|
||||
if got["root_id"] != "om_root_3" {
|
||||
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_3")
|
||||
}
|
||||
if got["reply_to_message_id"] != "om_root_3" {
|
||||
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_root_3")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("omits empty values", func(t *testing.T) {
|
||||
message := &larkim.EventMessage{
|
||||
MessageId: strPtr("om_msg_2"),
|
||||
}
|
||||
|
||||
got := buildInboundMetadata(message, nil)
|
||||
|
||||
if got["message_id"] != "om_msg_2" {
|
||||
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_2")
|
||||
}
|
||||
if _, ok := got["parent_id"]; ok {
|
||||
t.Fatalf("parent_id should be absent, got %q", got["parent_id"])
|
||||
}
|
||||
if _, ok := got["reply_to_message_id"]; ok {
|
||||
t.Fatalf("reply_to_message_id should be absent, got %q", got["reply_to_message_id"])
|
||||
}
|
||||
if _, ok := got["tenant_key"]; ok {
|
||||
t.Fatalf("tenant_key should be absent, got %q", got["tenant_key"])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil message returns empty map", func(t *testing.T) {
|
||||
got := buildInboundMetadata(nil, nil)
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("len(metadata) = %d, want 0", len(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFormatReplyContext(t *testing.T) {
|
||||
t.Run("formats reply context with content", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "new reply")
|
||||
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]\n\n[current_message]\nnew reply\n[/current_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns reply context when current content is empty", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "")
|
||||
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns original content when parent or replied content missing", func(t *testing.T) {
|
||||
if got := formatReplyContext("", "original", "new reply"); got != "new reply" {
|
||||
t.Fatalf("missing parent: got %q, want %q", got, "new reply")
|
||||
}
|
||||
if got := formatReplyContext("om_parent_1", "", "new reply"); got != "new reply" {
|
||||
t.Fatalf("missing replied content: got %q, want %q", got, "new reply")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("escapes reserved wrapper tags in payload", func(t *testing.T) {
|
||||
replied := "payload [replied_message id=\"x\"] x [/replied_message]"
|
||||
current := "hello [current_message]injected[/current_message]"
|
||||
got := formatReplyContext("om_parent_1", replied, current)
|
||||
|
||||
if !strings.HasPrefix(got, "[replied_message id=\"om_parent_1\"]") {
|
||||
t.Fatalf("outer replied_message wrapper missing: %q", got)
|
||||
}
|
||||
if strings.Contains(got, "\n[replied_message id=\"x\"]") {
|
||||
t.Fatalf("nested replied_message tag should be escaped: %q", got)
|
||||
}
|
||||
if strings.Contains(got, "\n[current_message]injected") {
|
||||
t.Fatalf("nested current_message tag should be escaped: %q", got)
|
||||
}
|
||||
if !strings.Contains(got, `\[replied_message id="x"]`) {
|
||||
t.Fatalf("escaped replied tag missing: %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves leading slash command prefix", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "/help")
|
||||
want := "/help\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves leading bang command prefix", func(t *testing.T) {
|
||||
got := formatReplyContext("om_parent_1", "original message", "!status now")
|
||||
want := "!status now\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
|
||||
if got != want {
|
||||
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestReplyTargetID(t *testing.T) {
|
||||
strPtr := func(s string) *string { return &s }
|
||||
|
||||
t.Run("prefer parent_id", func(t *testing.T) {
|
||||
msg := &larkim.EventMessage{ParentId: strPtr("om_parent"), RootId: strPtr("om_root")}
|
||||
if got := replyTargetID(msg); got != "om_parent" {
|
||||
t.Fatalf("replyTargetID() = %q, want %q", got, "om_parent")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fallback to root_id", func(t *testing.T) {
|
||||
msg := &larkim.EventMessage{RootId: strPtr("om_root")}
|
||||
if got := replyTargetID(msg); got != "om_root" {
|
||||
t.Fatalf("replyTargetID() = %q, want %q", got, "om_root")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty when no fields", func(t *testing.T) {
|
||||
if got := replyTargetID(&larkim.EventMessage{}); got != "" {
|
||||
t.Fatalf("replyTargetID() = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNormalizeRepliedContent(t *testing.T) {
|
||||
t.Run("filters feishu upgrade placeholder for interactive", func(t *testing.T) {
|
||||
raw := `{"text":"\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef\uff0c\u4ee5\u67e5\u770b\u5185\u5bb9"}`
|
||||
got := normalizeRepliedContent("interactive", raw, nil)
|
||||
if got != "[replied interactive card]" {
|
||||
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied interactive card]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("keeps filename and file tag for replied file", func(t *testing.T) {
|
||||
got := normalizeRepliedContent("file", `{"file_key":"file_xxx","file_name":"doc.pdf"}`, []string{"media://r1"})
|
||||
if got != "doc.pdf [file]" {
|
||||
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "doc.pdf [file]")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("falls back when file content missing", func(t *testing.T) {
|
||||
got := normalizeRepliedContent("file", `{"file_key":"file_xxx"}`, nil)
|
||||
if got != "[replied file]" {
|
||||
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied file]")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHasLeadingCommandPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{name: "slash command", input: "/help", want: true},
|
||||
{name: "bang command", input: "!status", want: true},
|
||||
{name: "leading spaces slash", input: " /ping arg", want: true},
|
||||
{name: "normal text", input: "hello /help", want: false},
|
||||
{name: "empty", input: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := hasLeadingCommandPrefix(tt.input); got != tt.want {
|
||||
t.Fatalf("hasLeadingCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -430,6 +430,19 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
|
||||
m.initChannel("vk", "VK")
|
||||
}
|
||||
|
||||
if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
|
||||
hasValidTarget := false
|
||||
for _, target := range channels.TeamsWebhook.Webhooks {
|
||||
if target.WebhookURL.String() != "" {
|
||||
hasValidTarget = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasValidTarget {
|
||||
m.initChannel("teams_webhook", "Teams Webhook")
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
|
||||
@@ -62,6 +62,13 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
|
||||
value["app_secret"] = ch.Feishu.AppSecret.String()
|
||||
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
|
||||
value["verification_token"] = ch.Feishu.VerificationToken.String()
|
||||
case "teams_webhook":
|
||||
// Expose webhook URLs for hash computation (they contain secrets)
|
||||
webhooks := make(map[string]string)
|
||||
for name, target := range ch.TeamsWebhook.Webhooks {
|
||||
webhooks[name] = target.WebhookURL.String()
|
||||
}
|
||||
value["webhooks"] = webhooks
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,4 +173,13 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
|
||||
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
|
||||
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
|
||||
}
|
||||
if newcfg.TeamsWebhook.Enabled {
|
||||
// Copy SecureString webhook URLs from old config
|
||||
for name, oldTarget := range old.TeamsWebhook.Webhooks {
|
||||
if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
|
||||
newTarget.WebhookURL = oldTarget.WebhookURL
|
||||
newcfg.TeamsWebhook.Webhooks[name] = newTarget
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package teamswebhook
|
||||
|
||||
import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func init() {
|
||||
channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
|
||||
return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package teamswebhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
|
||||
"github.com/atc0005/go-teams-notify/v2/adaptivecard"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// statusCodeRe extracts HTTP status codes from error messages like "401 Unauthorized".
|
||||
var statusCodeRe = regexp.MustCompile(`\b([45]\d{2})\b`)
|
||||
|
||||
// markdownTableRe matches a markdown table block (header + separator + rows).
|
||||
// It captures the entire table including all rows.
|
||||
var markdownTableRe = regexp.MustCompile(`(?m)^(\|[^\n]+\|)\n(\|[-:\|\s]+\|)\n((?:\|[^\n]+\|\n?)+)`)
|
||||
|
||||
// teamsMessageSender abstracts the Teams client for testability.
|
||||
type teamsMessageSender interface {
|
||||
SendWithContext(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
|
||||
}
|
||||
|
||||
// classifyTeamsError extracts HTTP status code from error message and classifies it.
|
||||
// The go-teams-notify library returns errors like "error on notification: 401 Unauthorized, ...".
|
||||
// This allows proper retry behavior: 4xx errors are permanent, 5xx are temporary.
|
||||
func classifyTeamsError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
errMsg := err.Error()
|
||||
if matches := statusCodeRe.FindStringSubmatch(errMsg); len(matches) > 1 {
|
||||
if statusCode, parseErr := strconv.Atoi(matches[1]); parseErr == nil {
|
||||
return channels.ClassifySendError(statusCode, err)
|
||||
}
|
||||
}
|
||||
// Fallback: treat as temporary network error (retryable)
|
||||
return channels.ClassifyNetError(err)
|
||||
}
|
||||
|
||||
// TeamsWebhookChannel is an output-only channel that sends messages
|
||||
// to Microsoft Teams via Power Automate workflow webhooks.
|
||||
// Multiple webhook targets can be configured and selected via ChatID.
|
||||
type TeamsWebhookChannel struct {
|
||||
*channels.BaseChannel
|
||||
config config.TeamsWebhookConfig
|
||||
client teamsMessageSender
|
||||
}
|
||||
|
||||
// NewTeamsWebhookChannel creates a new Teams webhook channel.
|
||||
func NewTeamsWebhookChannel(
|
||||
cfg config.TeamsWebhookConfig,
|
||||
bus *bus.MessageBus,
|
||||
) (*TeamsWebhookChannel, error) {
|
||||
if len(cfg.Webhooks) == 0 {
|
||||
return nil, fmt.Errorf("teams_webhook: at least one webhook target is required")
|
||||
}
|
||||
|
||||
// Require "default" webhook target
|
||||
if _, hasDefault := cfg.Webhooks["default"]; !hasDefault {
|
||||
return nil, fmt.Errorf("teams_webhook: a 'default' webhook target is required")
|
||||
}
|
||||
|
||||
// Validate all webhook targets have valid HTTPS URLs
|
||||
for name, target := range cfg.Webhooks {
|
||||
webhookURL := target.WebhookURL.String()
|
||||
if webhookURL == "" {
|
||||
return nil, fmt.Errorf("teams_webhook: webhook %q has empty webhook_url", name)
|
||||
}
|
||||
parsed, err := url.Parse(webhookURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("teams_webhook: webhook %q has invalid URL: %w", name, err)
|
||||
}
|
||||
if !strings.EqualFold(parsed.Scheme, "https") {
|
||||
return nil, fmt.Errorf("teams_webhook: webhook %q must use HTTPS (got %q)", name, parsed.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
base := channels.NewBaseChannel(
|
||||
"teams_webhook",
|
||||
cfg,
|
||||
bus,
|
||||
[]string{
|
||||
"*",
|
||||
}, // Output-only channel; "*" suppresses misleading "allows EVERYONE" audit warning
|
||||
channels.WithMaxMessageLength(24000), // Power Automate webhook payload limit is 28KB
|
||||
)
|
||||
|
||||
client := goteamsnotify.NewTeamsClient()
|
||||
|
||||
return &TeamsWebhookChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start initializes the channel. For output-only channels, this is a no-op.
|
||||
func (c *TeamsWebhookChannel) Start(ctx context.Context) error {
|
||||
targets := make([]string, 0, len(c.config.Webhooks))
|
||||
for name := range c.config.Webhooks {
|
||||
targets = append(targets, name)
|
||||
}
|
||||
sort.Strings(targets)
|
||||
logger.InfoCF("teams_webhook", "Starting Teams webhook channel (output-only)", map[string]any{
|
||||
"targets": targets,
|
||||
})
|
||||
c.SetRunning(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the channel.
|
||||
func (c *TeamsWebhookChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("teams_webhook", "Stopping Teams webhook channel")
|
||||
c.SetRunning(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send delivers a message to the specified Teams webhook target.
|
||||
// The target is selected by msg.ChatID which must match a key in the webhooks map.
|
||||
func (c *TeamsWebhookChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
|
||||
if !c.IsRunning() {
|
||||
return nil, channels.ErrNotRunning
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Look up webhook target by ChatID, fall back to "default" if empty or unknown
|
||||
targetName := msg.ChatID
|
||||
if targetName == "" {
|
||||
targetName = "default"
|
||||
}
|
||||
|
||||
target, ok := c.config.Webhooks[targetName]
|
||||
if !ok {
|
||||
// Log warning and fall back to default target
|
||||
logger.WarnCF("teams_webhook", "Unknown target, falling back to default", map[string]any{
|
||||
"requested": msg.ChatID,
|
||||
"using": "default",
|
||||
})
|
||||
target = c.config.Webhooks["default"]
|
||||
}
|
||||
|
||||
// Build an Adaptive Card for rich formatting
|
||||
card, err := c.buildAdaptiveCard(msg, target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("teams_webhook: failed to build card: %w", err)
|
||||
}
|
||||
|
||||
// Create the message with the card
|
||||
teamsMsg, err := adaptivecard.NewMessageFromCard(card)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("teams_webhook: failed to create message: %w", err)
|
||||
}
|
||||
|
||||
// Send to Teams
|
||||
if err := c.client.SendWithContext(ctx, target.WebhookURL.String(), teamsMsg); err != nil {
|
||||
// Log without raw error to avoid leaking webhook URL (embedded in net/http errors)
|
||||
logger.ErrorCF("teams_webhook", "Failed to send message to Teams webhook", map[string]any{
|
||||
"target": msg.ChatID,
|
||||
})
|
||||
// Classify error based on status code extracted from error message.
|
||||
// The go-teams-notify library includes status in errors like "401 Unauthorized".
|
||||
// Use ClassifySendError for proper retry behavior (4xx = permanent, 5xx = temporary).
|
||||
classifiedErr := classifyTeamsError(err)
|
||||
return nil, fmt.Errorf("teams_webhook: send failed: %w", classifiedErr)
|
||||
}
|
||||
|
||||
logger.DebugCF("teams_webhook", "Message sent successfully", map[string]any{
|
||||
"target": msg.ChatID,
|
||||
})
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// buildAdaptiveCard creates a formatted Adaptive Card from the outbound message.
|
||||
// It detects markdown tables and converts them to native Adaptive Card Table elements,
|
||||
// since TextBlocks only support a limited markdown subset (no tables).
|
||||
func (c *TeamsWebhookChannel) buildAdaptiveCard(
|
||||
msg bus.OutboundMessage,
|
||||
target config.TeamsWebhookTarget,
|
||||
) (adaptivecard.Card, error) {
|
||||
card := adaptivecard.NewCard()
|
||||
card.Type = adaptivecard.TypeAdaptiveCard
|
||||
|
||||
// Set full width for Teams rendering
|
||||
card.MSTeams.Width = "Full"
|
||||
|
||||
// Add title if configured on the target
|
||||
title := target.Title
|
||||
if title == "" {
|
||||
title = "PicoClaw Notification"
|
||||
}
|
||||
|
||||
titleBlock := adaptivecard.NewTextBlock(title, true)
|
||||
titleBlock.Size = adaptivecard.SizeLarge
|
||||
titleBlock.Weight = adaptivecard.WeightBolder
|
||||
titleBlock.Style = adaptivecard.TextBlockStyleHeading
|
||||
|
||||
if err := card.AddElement(false, titleBlock); err != nil {
|
||||
return card, err
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
if content == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Split content into text segments and tables
|
||||
// TextBlocks support: bold, italic, bullet/numbered lists, links
|
||||
// TextBlocks do NOT support: headers, tables, images
|
||||
segments := splitContentWithTables(content)
|
||||
|
||||
for _, seg := range segments {
|
||||
if seg.isTable {
|
||||
// Convert markdown table to Adaptive Card Table element
|
||||
tableElement, err := parseMarkdownTable(seg.content)
|
||||
if err != nil {
|
||||
// Fallback: render as preformatted text if parsing fails
|
||||
logger.WarnCF("teams_webhook", "Failed to parse markdown table, using fallback", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
block := adaptivecard.NewTextBlock("```\n"+seg.content+"\n```", true)
|
||||
block.Wrap = true
|
||||
if err := card.AddElement(false, block); err != nil {
|
||||
return card, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := card.AddElement(false, tableElement); err != nil {
|
||||
return card, err
|
||||
}
|
||||
} else {
|
||||
// Regular text content
|
||||
text := strings.TrimSpace(seg.content)
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
block := adaptivecard.NewTextBlock(text, true)
|
||||
block.Wrap = true
|
||||
if err := card.AddElement(false, block); err != nil {
|
||||
return card, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return card, nil
|
||||
}
|
||||
|
||||
// contentSegment represents either a text block or a table in the message content.
|
||||
type contentSegment struct {
|
||||
content string
|
||||
isTable bool
|
||||
}
|
||||
|
||||
// splitContentWithTables splits content into alternating text and table segments.
|
||||
func splitContentWithTables(content string) []contentSegment {
|
||||
var segments []contentSegment
|
||||
|
||||
matches := markdownTableRe.FindAllStringSubmatchIndex(content, -1)
|
||||
if len(matches) == 0 {
|
||||
// No tables found, return entire content as text
|
||||
return []contentSegment{{content: content, isTable: false}}
|
||||
}
|
||||
|
||||
lastEnd := 0
|
||||
for _, match := range matches {
|
||||
// Text before this table
|
||||
if match[0] > lastEnd {
|
||||
segments = append(segments, contentSegment{
|
||||
content: content[lastEnd:match[0]],
|
||||
isTable: false,
|
||||
})
|
||||
}
|
||||
// The table itself
|
||||
segments = append(segments, contentSegment{
|
||||
content: content[match[0]:match[1]],
|
||||
isTable: true,
|
||||
})
|
||||
lastEnd = match[1]
|
||||
}
|
||||
|
||||
// Text after the last table
|
||||
if lastEnd < len(content) {
|
||||
segments = append(segments, contentSegment{
|
||||
content: content[lastEnd:],
|
||||
isTable: false,
|
||||
})
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// parseMarkdownTable converts a markdown table string to an Adaptive Card Table element.
|
||||
func parseMarkdownTable(tableStr string) (adaptivecard.Element, error) {
|
||||
lines := strings.Split(strings.TrimSpace(tableStr), "\n")
|
||||
if len(lines) < 2 {
|
||||
return adaptivecard.Element{}, fmt.Errorf("table must have at least header and separator rows")
|
||||
}
|
||||
|
||||
// Track header content length per column for width calculation
|
||||
var headerLengths []int
|
||||
|
||||
// Parse all rows (header + data rows, skip separator)
|
||||
var allRows [][]adaptivecard.TableCell
|
||||
for i, line := range lines {
|
||||
// Skip separator row (contains only |, -, :, and spaces)
|
||||
if i == 1 && isSeparatorRow(line) {
|
||||
continue
|
||||
}
|
||||
|
||||
cells := parseTableRow(line)
|
||||
if len(cells) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var tableCells []adaptivecard.TableCell
|
||||
for _, cellText := range cells {
|
||||
trimmedText := strings.TrimSpace(cellText)
|
||||
|
||||
// Use header row (first row) to determine column widths
|
||||
if i == 0 {
|
||||
headerLengths = append(headerLengths, len(trimmedText))
|
||||
}
|
||||
|
||||
textBlock := adaptivecard.Element{
|
||||
Type: adaptivecard.TypeElementTextBlock,
|
||||
Text: trimmedText,
|
||||
Wrap: true,
|
||||
}
|
||||
cell := adaptivecard.TableCell{
|
||||
Type: adaptivecard.TypeTableCell,
|
||||
Items: []*adaptivecard.Element{&textBlock},
|
||||
}
|
||||
tableCells = append(tableCells, cell)
|
||||
}
|
||||
allRows = append(allRows, tableCells)
|
||||
}
|
||||
|
||||
if len(allRows) == 0 {
|
||||
return adaptivecard.Element{}, fmt.Errorf("no valid rows found in table")
|
||||
}
|
||||
|
||||
// Create table with first row as headers
|
||||
firstRowAsHeaders := true
|
||||
showGridLines := true
|
||||
|
||||
table, err := adaptivecard.NewTableFromTableCells(allRows, 0, firstRowAsHeaders, showGridLines)
|
||||
if err != nil {
|
||||
return adaptivecard.Element{}, fmt.Errorf("failed to create table: %w", err)
|
||||
}
|
||||
|
||||
// Set column widths based on header content length
|
||||
table.Columns = calculateColumnWidths(headerLengths)
|
||||
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// calculateColumnWidths creates TableColumnDefinition entries with widths
|
||||
// proportional to the max content length of each column.
|
||||
func calculateColumnWidths(maxLengths []int) []adaptivecard.Column {
|
||||
if len(maxLengths) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use content length as relative weight, with a minimum of 1
|
||||
columns := make([]adaptivecard.Column, len(maxLengths))
|
||||
for i, length := range maxLengths {
|
||||
weight := length
|
||||
if weight < 1 {
|
||||
weight = 1
|
||||
}
|
||||
columns[i] = adaptivecard.Column{
|
||||
Type: "TableColumnDefinition",
|
||||
Width: weight,
|
||||
}
|
||||
}
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// isSeparatorRow checks if a line is a markdown table separator (e.g., |---|---|).
|
||||
func isSeparatorRow(line string) bool {
|
||||
// Remove pipes and spaces, check if only dashes and colons remain
|
||||
cleaned := strings.ReplaceAll(line, "|", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, " ", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, "-", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, ":", "")
|
||||
return cleaned == ""
|
||||
}
|
||||
|
||||
// parseTableRow extracts cell values from a markdown table row.
|
||||
func parseTableRow(line string) []string {
|
||||
// Trim leading/trailing pipes and split by |
|
||||
line = strings.TrimSpace(line)
|
||||
line = strings.TrimPrefix(line, "|")
|
||||
line = strings.TrimSuffix(line, "|")
|
||||
|
||||
if line == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(line, "|")
|
||||
var cells []string
|
||||
for _, p := range parts {
|
||||
cells = append(cells, strings.TrimSpace(p))
|
||||
}
|
||||
return cells
|
||||
}
|
||||
@@ -0,0 +1,583 @@
|
||||
package teamswebhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// mockTeamsClient implements teamsMessageSender for testing.
|
||||
type mockTeamsClient struct {
|
||||
sendFunc func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
|
||||
}
|
||||
|
||||
func (m *mockTeamsClient) SendWithContext(
|
||||
ctx context.Context,
|
||||
webhookURL string,
|
||||
message goteamsnotify.TeamsMessage,
|
||||
) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, webhookURL, message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewTeamsWebhookChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Test missing webhooks
|
||||
_, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: nil,
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing webhooks")
|
||||
}
|
||||
|
||||
// Test missing "default" webhook
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing 'default' webhook")
|
||||
}
|
||||
|
||||
// Test empty webhook URL
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {Title: "Default"},
|
||||
},
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty webhook_url")
|
||||
}
|
||||
|
||||
// Test HTTP URL (should fail, must be HTTPS)
|
||||
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
|
||||
Title: "Default",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for HTTP webhook URL (must be HTTPS)")
|
||||
}
|
||||
|
||||
// Test valid config with HTTPS (must include "default")
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
|
||||
Title: "Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if ch.Name() != "teams_webhook" {
|
||||
t.Errorf("expected name 'teams_webhook', got %q", ch.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_StartStop(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
if ch.IsRunning() {
|
||||
t.Error("channel should not be running before Start")
|
||||
}
|
||||
|
||||
if err := ch.Start(ctx); err != nil {
|
||||
t.Fatalf("Start failed: %v", err)
|
||||
}
|
||||
|
||||
if !ch.IsRunning() {
|
||||
t.Error("channel should be running after Start")
|
||||
}
|
||||
|
||||
if err := ch.Stop(ctx); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
if ch.IsRunning() {
|
||||
t.Error("channel should not be running after Stop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
Title: "Custom Title",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
target := ch.config.Webhooks["alerts"]
|
||||
msg := bus.OutboundMessage{
|
||||
Content: "Test message content",
|
||||
ChatID: "alerts",
|
||||
}
|
||||
|
||||
card, err := ch.buildAdaptiveCard(msg, target)
|
||||
if err != nil {
|
||||
t.Fatalf("buildAdaptiveCard failed: %v", err)
|
||||
}
|
||||
|
||||
if card.Type != "AdaptiveCard" {
|
||||
t.Errorf("expected card type 'AdaptiveCard', got %q", card.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
msg := bus.OutboundMessage{Content: "test", ChatID: "default"}
|
||||
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err == nil {
|
||||
t.Error("expected error when sending while not running")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chatID string
|
||||
}{
|
||||
{"unknown target falls back to default", "unknown"},
|
||||
{"empty ChatID uses default", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
var sentURL string
|
||||
ch.client = &mockTeamsClient{
|
||||
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
|
||||
sentURL = webhookURL
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = ch.Start(ctx)
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
msg := bus.OutboundMessage{Content: "test", ChatID: tt.chatID}
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("expected success, got error: %v", err)
|
||||
}
|
||||
|
||||
if sentURL != "https://example.com/webhook-default" {
|
||||
t.Errorf("expected default webhook URL, got %q", sentURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
Title: "Default",
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
Title: "Test Alerts",
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Inject mock client
|
||||
var sentURL string
|
||||
ch.client = &mockTeamsClient{
|
||||
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
|
||||
sentURL = webhookURL
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = ch.Start(ctx)
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
msg := bus.OutboundMessage{Content: "Hello Teams!", ChatID: "alerts"}
|
||||
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if sentURL != "https://example.com/webhook-alerts" {
|
||||
t.Errorf("expected webhook URL 'https://example.com/webhook-alerts', got %q", sentURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTeamsWebhookChannel_SendError(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
|
||||
Enabled: true,
|
||||
Webhooks: map[string]config.TeamsWebhookTarget{
|
||||
"default": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
|
||||
},
|
||||
"alerts": {
|
||||
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
|
||||
},
|
||||
},
|
||||
}, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Inject mock client that returns an error
|
||||
ch.client = &mockTeamsClient{
|
||||
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
|
||||
return errors.New("error on notification: 401 Unauthorized, forbidden")
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
_ = ch.Start(ctx)
|
||||
defer ch.Stop(ctx)
|
||||
|
||||
msg := bus.OutboundMessage{Content: "test", ChatID: "alerts"}
|
||||
|
||||
_, err = ch.Send(ctx, msg)
|
||||
if err == nil {
|
||||
t.Error("expected error from failed send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitContentWithTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
wantSegs int
|
||||
wantTbl int // number of table segments
|
||||
}{
|
||||
{
|
||||
name: "no tables",
|
||||
content: "Just some text\nwith multiple lines",
|
||||
wantSegs: 1,
|
||||
wantTbl: 0,
|
||||
},
|
||||
{
|
||||
name: "single table",
|
||||
content: `| Col1 | Col2 |
|
||||
|------|------|
|
||||
| A | B |
|
||||
| C | D |`,
|
||||
wantSegs: 1,
|
||||
wantTbl: 1,
|
||||
},
|
||||
{
|
||||
name: "text before table",
|
||||
content: `Here is some text.
|
||||
|
||||
| Col1 | Col2 |
|
||||
|------|------|
|
||||
| A | B |`,
|
||||
wantSegs: 2,
|
||||
wantTbl: 1,
|
||||
},
|
||||
{
|
||||
name: "text before and after table",
|
||||
content: `Before table.
|
||||
|
||||
| Col1 | Col2 |
|
||||
|------|------|
|
||||
| A | B |
|
||||
|
||||
After table.`,
|
||||
wantSegs: 3,
|
||||
wantTbl: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple tables",
|
||||
content: `First table:
|
||||
|
||||
| A | B |
|
||||
|---|---|
|
||||
| 1 | 2 |
|
||||
|
||||
Second table:
|
||||
|
||||
| X | Y |
|
||||
|---|---|
|
||||
| 3 | 4 |`,
|
||||
wantSegs: 4,
|
||||
wantTbl: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
segs := splitContentWithTables(tt.content)
|
||||
if len(segs) != tt.wantSegs {
|
||||
t.Errorf("got %d segments, want %d", len(segs), tt.wantSegs)
|
||||
}
|
||||
tableCount := 0
|
||||
for _, s := range segs {
|
||||
if s.isTable {
|
||||
tableCount++
|
||||
}
|
||||
}
|
||||
if tableCount != tt.wantTbl {
|
||||
t.Errorf("got %d tables, want %d", tableCount, tt.wantTbl)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMarkdownTable(t *testing.T) {
|
||||
tableStr := `| Name | Value |
|
||||
|------|-------|
|
||||
| foo | 123 |
|
||||
| bar | 456 |`
|
||||
|
||||
elem, err := parseMarkdownTable(tableStr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if elem.Type != "Table" {
|
||||
t.Errorf("expected type 'Table', got %q", elem.Type)
|
||||
}
|
||||
|
||||
// Should have 3 rows (header + 2 data rows)
|
||||
if len(elem.Rows) != 3 {
|
||||
t.Errorf("expected 3 rows, got %d", len(elem.Rows))
|
||||
}
|
||||
|
||||
// Should have 2 columns with widths based on content length
|
||||
if len(elem.Columns) != 2 {
|
||||
t.Errorf("expected 2 columns, got %d", len(elem.Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMarkdownTableColumnWidths(t *testing.T) {
|
||||
// Column widths are based on HEADER row only:
|
||||
// Col1: "Description" (11 chars)
|
||||
// Col2: "X" (1 char)
|
||||
// Col3: "Amount" (6 chars)
|
||||
tableStr := `| Description | X | Amount |
|
||||
|-------------|---|--------|
|
||||
| Short | Y | 100 |
|
||||
| Longer text | Z | 50 |`
|
||||
|
||||
elem, err := parseMarkdownTable(tableStr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(elem.Columns) != 3 {
|
||||
t.Fatalf("expected 3 columns, got %d", len(elem.Columns))
|
||||
}
|
||||
|
||||
// Verify column widths are based on header content length
|
||||
w1, ok1 := elem.Columns[0].Width.(int)
|
||||
w2, ok2 := elem.Columns[1].Width.(int)
|
||||
w3, ok3 := elem.Columns[2].Width.(int)
|
||||
|
||||
if !ok1 || !ok2 || !ok3 {
|
||||
t.Fatalf("expected int widths, got types: %T, %T, %T",
|
||||
elem.Columns[0].Width, elem.Columns[1].Width, elem.Columns[2].Width)
|
||||
}
|
||||
|
||||
// Header lengths: "Description" = 11, "X" = 1, "Amount" = 6
|
||||
if w1 != 11 {
|
||||
t.Errorf("expected col1 width 11 (from 'Description'), got %d", w1)
|
||||
}
|
||||
if w2 != 1 {
|
||||
t.Errorf("expected col2 width 1 (from 'X'), got %d", w2)
|
||||
}
|
||||
if w3 != 6 {
|
||||
t.Errorf("expected col3 width 6 (from 'Amount'), got %d", w3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateColumnWidths(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxLengths []int
|
||||
wantWidths []int
|
||||
}{
|
||||
{
|
||||
name: "equal lengths",
|
||||
maxLengths: []int{10, 10, 10},
|
||||
wantWidths: []int{10, 10, 10},
|
||||
},
|
||||
{
|
||||
name: "varying lengths",
|
||||
maxLengths: []int{5, 20, 10},
|
||||
wantWidths: []int{5, 20, 10},
|
||||
},
|
||||
{
|
||||
name: "zero length gets minimum of 1",
|
||||
maxLengths: []int{0, 5, 0},
|
||||
wantWidths: []int{1, 5, 1},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
maxLengths: []int{},
|
||||
wantWidths: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cols := calculateColumnWidths(tt.maxLengths)
|
||||
|
||||
if tt.wantWidths == nil {
|
||||
if cols != nil {
|
||||
t.Errorf("expected nil, got %v", cols)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if len(cols) != len(tt.wantWidths) {
|
||||
t.Fatalf("expected %d columns, got %d", len(tt.wantWidths), len(cols))
|
||||
}
|
||||
|
||||
for i, col := range cols {
|
||||
width, ok := col.Width.(int)
|
||||
if !ok {
|
||||
t.Errorf("column %d: expected int width, got %T", i, col.Width)
|
||||
continue
|
||||
}
|
||||
if width != tt.wantWidths[i] {
|
||||
t.Errorf("column %d: expected width %d, got %d", i, tt.wantWidths[i], width)
|
||||
}
|
||||
if col.Type != "TableColumnDefinition" {
|
||||
t.Errorf("column %d: expected type 'TableColumnDefinition', got %q", i, col.Type)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTableRow(t *testing.T) {
|
||||
tests := []struct {
|
||||
line string
|
||||
want []string
|
||||
}{
|
||||
{"| A | B | C |", []string{"A", "B", "C"}},
|
||||
{"|A|B|C|", []string{"A", "B", "C"}},
|
||||
{"| foo | bar |", []string{"foo", "bar"}},
|
||||
{"", nil},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := parseTableRow(tt.line)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("parseTableRow(%q): got %v, want %v", tt.line, got, tt.want)
|
||||
continue
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("parseTableRow(%q)[%d]: got %q, want %q", tt.line, i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSeparatorRow(t *testing.T) {
|
||||
tests := []struct {
|
||||
line string
|
||||
want bool
|
||||
}{
|
||||
{"|---|---|", true},
|
||||
{"| --- | --- |", true},
|
||||
{"|:---|---:|", true},
|
||||
{"| :---: | :---: |", true},
|
||||
{"| A | B |", false},
|
||||
{"| foo | bar |", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := isSeparatorRow(tt.line)
|
||||
if got != tt.want {
|
||||
t.Errorf("isSeparatorRow(%q): got %v, want %v", tt.line, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+79
-35
@@ -24,20 +24,21 @@ var rrCounter atomic.Uint64
|
||||
// CurrentVersion is the latest config schema version
|
||||
const CurrentVersion = 2
|
||||
|
||||
// Config is the current config structure with version support
|
||||
// Config is the current config structure with version support.
|
||||
type Config struct {
|
||||
Version int `json:"version" yaml:"-"` // Config schema version for migration
|
||||
Agents AgentsConfig `json:"agents" yaml:"-"`
|
||||
Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"`
|
||||
Session SessionConfig `json:"session,omitempty" yaml:"-"`
|
||||
Channels ChannelsConfig `json:"channels" yaml:"channels"`
|
||||
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
|
||||
Gateway GatewayConfig `json:"gateway" yaml:"-"`
|
||||
Hooks HooksConfig `json:"hooks,omitempty" yaml:"-"`
|
||||
Tools ToolsConfig `json:"tools" yaml:",inline"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat" yaml:"-"`
|
||||
Devices DevicesConfig `json:"devices" yaml:"-"`
|
||||
Voice VoiceConfig `json:"voice" yaml:"-"`
|
||||
Version int `json:"version" yaml:"-"` // Config schema version for migration
|
||||
Isolation IsolationConfig `json:"isolation,omitempty" yaml:"-"`
|
||||
Agents AgentsConfig `json:"agents" yaml:"-"`
|
||||
Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"`
|
||||
Session SessionConfig `json:"session,omitempty" yaml:"-"`
|
||||
Channels ChannelsConfig `json:"channels" yaml:"channels"`
|
||||
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
|
||||
Gateway GatewayConfig `json:"gateway" yaml:"-"`
|
||||
Hooks HooksConfig `json:"hooks,omitempty" yaml:"-"`
|
||||
Tools ToolsConfig `json:"tools" yaml:",inline"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat" yaml:"-"`
|
||||
Devices DevicesConfig `json:"devices" yaml:"-"`
|
||||
Voice VoiceConfig `json:"voice" yaml:"-"`
|
||||
// BuildInfo contains build-time version information
|
||||
BuildInfo BuildInfo `json:"build_info,omitempty" yaml:"-"`
|
||||
|
||||
@@ -45,6 +46,21 @@ type Config struct {
|
||||
sensitiveCache *SensitiveDataCache
|
||||
}
|
||||
|
||||
// IsolationConfig controls subprocess isolation for commands started by PicoClaw.
|
||||
// It is applied by the isolation package rather than by sandboxing the main process.
|
||||
type IsolationConfig struct {
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
ExposePaths []ExposePath `json:"expose_paths,omitempty"`
|
||||
}
|
||||
|
||||
// ExposePath describes a host path that should remain visible inside the isolated
|
||||
// child-process environment. This is currently implemented on Linux only.
|
||||
type ExposePath struct {
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target,omitempty"`
|
||||
Mode string `json:"mode"`
|
||||
}
|
||||
|
||||
// FilterSensitiveData filters sensitive values from content before sending to LLM.
|
||||
// This prevents the LLM from seeing its own credentials.
|
||||
// Uses strings.Replacer for O(n+m) performance (computed once per SecurityConfig).
|
||||
@@ -280,23 +296,24 @@ func (d *AgentDefaults) GetModelName() string {
|
||||
}
|
||||
|
||||
type ChannelsConfig struct {
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp" yaml:"-"`
|
||||
Telegram TelegramConfig `json:"telegram" yaml:"telegram,omitempty"`
|
||||
Feishu FeishuConfig `json:"feishu" yaml:"feishu,omitempty"`
|
||||
Discord DiscordConfig `json:"discord" yaml:"discord,omitempty"`
|
||||
MaixCam MaixCamConfig `json:"maixcam" yaml:"-"`
|
||||
QQ QQConfig `json:"qq" yaml:"qq,omitempty"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk" yaml:"dingtalk,omitempty"`
|
||||
Slack SlackConfig `json:"slack" yaml:"slack,omitempty"`
|
||||
Matrix MatrixConfig `json:"matrix" yaml:"matrix,omitempty"`
|
||||
LINE LINEConfig `json:"line" yaml:"line,omitempty"`
|
||||
OneBot OneBotConfig `json:"onebot" yaml:"onebot,omitempty"`
|
||||
WeCom WeComConfig `json:"wecom" yaml:"wecom,omitempty" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
|
||||
Weixin WeixinConfig `json:"weixin" yaml:"weixin,omitempty"`
|
||||
Pico PicoConfig `json:"pico" yaml:"pico,omitempty"`
|
||||
PicoClient PicoClientConfig `json:"pico_client" yaml:"pico_client,omitempty"`
|
||||
IRC IRCConfig `json:"irc" yaml:"irc,omitempty"`
|
||||
VK VKConfig `json:"vk" yaml:"vk,omitempty"`
|
||||
WhatsApp WhatsAppConfig `json:"whatsapp" yaml:"-"`
|
||||
Telegram TelegramConfig `json:"telegram" yaml:"telegram,omitempty"`
|
||||
Feishu FeishuConfig `json:"feishu" yaml:"feishu,omitempty"`
|
||||
Discord DiscordConfig `json:"discord" yaml:"discord,omitempty"`
|
||||
MaixCam MaixCamConfig `json:"maixcam" yaml:"-"`
|
||||
QQ QQConfig `json:"qq" yaml:"qq,omitempty"`
|
||||
DingTalk DingTalkConfig `json:"dingtalk" yaml:"dingtalk,omitempty"`
|
||||
Slack SlackConfig `json:"slack" yaml:"slack,omitempty"`
|
||||
Matrix MatrixConfig `json:"matrix" yaml:"matrix,omitempty"`
|
||||
LINE LINEConfig `json:"line" yaml:"line,omitempty"`
|
||||
OneBot OneBotConfig `json:"onebot" yaml:"onebot,omitempty"`
|
||||
WeCom WeComConfig `json:"wecom" yaml:"wecom,omitempty" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
|
||||
Weixin WeixinConfig `json:"weixin" yaml:"weixin,omitempty"`
|
||||
Pico PicoConfig `json:"pico" yaml:"pico,omitempty"`
|
||||
PicoClient PicoClientConfig `json:"pico_client" yaml:"pico_client,omitempty"`
|
||||
IRC IRCConfig `json:"irc" yaml:"irc,omitempty"`
|
||||
VK VKConfig `json:"vk" yaml:"vk,omitempty"`
|
||||
TeamsWebhook TeamsWebhookConfig `json:"teams_webhook" yaml:"teams_webhook,omitempty"`
|
||||
}
|
||||
|
||||
// GroupTriggerConfig controls when the bot responds in group chats.
|
||||
@@ -566,6 +583,19 @@ func (c *VKConfig) SetToken(token string) {
|
||||
c.Token = *NewSecureString(token)
|
||||
}
|
||||
|
||||
// TeamsWebhookConfig configures the output-only Microsoft Teams webhook channel.
|
||||
// Multiple webhook targets can be configured and selected via ChatID at send time.
|
||||
type TeamsWebhookConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"-" env:"PICOCLAW_CHANNELS_TEAMS_WEBHOOK_ENABLED"`
|
||||
Webhooks map[string]TeamsWebhookTarget `json:"webhooks" yaml:"webhooks,omitempty"`
|
||||
}
|
||||
|
||||
// TeamsWebhookTarget represents a single Teams webhook destination.
|
||||
type TeamsWebhookTarget struct {
|
||||
WebhookURL SecureString `json:"webhook_url,omitzero" yaml:"webhook_url,omitempty"`
|
||||
Title string `json:"title,omitempty" yaml:"-"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
|
||||
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
|
||||
@@ -605,11 +635,12 @@ type ModelConfig struct {
|
||||
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
|
||||
|
||||
// Optional optimizations
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
RPM int `json:"rpm,omitempty"` // Requests per minute limit
|
||||
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
|
||||
RequestTimeout int `json:"request_timeout,omitempty"`
|
||||
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
|
||||
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
|
||||
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
|
||||
|
||||
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
|
||||
|
||||
@@ -943,10 +974,21 @@ type MCPServerConfig struct {
|
||||
type MCPConfig struct {
|
||||
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_MCP_"`
|
||||
Discovery ToolDiscoveryConfig ` json:"discovery"`
|
||||
// MaxInlineTextChars controls how much MCP text stays inline before it is saved as an artifact.
|
||||
MaxInlineTextChars int `json:"max_inline_text_chars,omitempty" env:"PICOCLAW_TOOLS_MCP_MAX_INLINE_TEXT_CHARS"`
|
||||
// Servers is a map of server name to server configuration
|
||||
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
|
||||
}
|
||||
|
||||
const DefaultMCPMaxInlineTextChars = 16 * 1024
|
||||
|
||||
func (c *MCPConfig) GetMaxInlineTextChars() int {
|
||||
if c.MaxInlineTextChars > 0 {
|
||||
return c.MaxInlineTextChars
|
||||
}
|
||||
return DefaultMCPMaxInlineTextChars
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
logger.Debugf("loading config from %s", path)
|
||||
|
||||
@@ -1268,6 +1310,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
isVirtual: true,
|
||||
}
|
||||
expanded = append(expanded, additionalEntry)
|
||||
@@ -1288,6 +1331,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
|
||||
RequestTimeout: m.RequestTimeout,
|
||||
ThinkingLevel: m.ThinkingLevel,
|
||||
ExtraBody: m.ExtraBody,
|
||||
CustomHeaders: m.CustomHeaders,
|
||||
APIKeys: SimpleSecureStrings(keys[0]),
|
||||
}
|
||||
|
||||
|
||||
@@ -198,6 +198,41 @@ func TestAgentConfig_FullParse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MCPMaxInlineTextChars(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Tools.MCP.GetMaxInlineTextChars() != DefaultMCPMaxInlineTextChars {
|
||||
t.Fatalf(
|
||||
"DefaultConfig().Tools.MCP.GetMaxInlineTextChars() = %d, want %d",
|
||||
cfg.Tools.MCP.GetMaxInlineTextChars(),
|
||||
DefaultMCPMaxInlineTextChars,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfig_MCPMaxInlineTextChars(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
configPath := filepath.Join(dir, "config.json")
|
||||
raw := `{
|
||||
"tools": {
|
||||
"mcp": {
|
||||
"enabled": true,
|
||||
"max_inline_text_chars": 2048
|
||||
}
|
||||
}
|
||||
}`
|
||||
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
|
||||
t.Fatalf("WriteFile(configPath): %v", err)
|
||||
}
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig() error: %v", err)
|
||||
}
|
||||
if got := cfg.Tools.MCP.GetMaxInlineTextChars(); got != 2048 {
|
||||
t.Fatalf("cfg.Tools.MCP.GetMaxInlineTextChars() = %d, want 2048", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
|
||||
jsonData := `{
|
||||
"agents": {
|
||||
@@ -817,6 +852,37 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_IsolationEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Isolation.Enabled {
|
||||
t.Fatal("DefaultConfig().Isolation.Enabled should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_UnmarshalIsolation(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
raw := []byte(`{
|
||||
"isolation": {
|
||||
"enabled": false,
|
||||
"expose_paths": [
|
||||
{"source":"/src","target":"/dst","mode":"ro"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
if err := json.Unmarshal(raw, cfg); err != nil {
|
||||
t.Fatalf("json.Unmarshal isolation config: %v", err)
|
||||
}
|
||||
if cfg.Isolation.Enabled {
|
||||
t.Fatal("Isolation.Enabled should be false after unmarshal")
|
||||
}
|
||||
if len(cfg.Isolation.ExposePaths) != 1 {
|
||||
t.Fatalf("ExposePaths len = %d, want 1", len(cfg.Isolation.ExposePaths))
|
||||
}
|
||||
if got := cfg.Isolation.ExposePaths[0]; got.Source != "/src" || got.Target != "/dst" || got.Mode != "ro" {
|
||||
t.Fatalf("ExposePaths[0] = %+v, want source=/src target=/dst mode=ro", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFlexibleStringSlice_UnmarshalText tests UnmarshalText with various comma separators
|
||||
func TestFlexibleStringSlice_UnmarshalText(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -1493,6 +1559,42 @@ func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgPath := filepath.Join(dir, "config.json")
|
||||
|
||||
cfg := &Config{
|
||||
Version: CurrentVersion,
|
||||
ModelList: []*ModelConfig{
|
||||
{
|
||||
ModelName: "test-model",
|
||||
Model: "openai/test",
|
||||
APIKeys: SimpleSecureStrings("sk-test"),
|
||||
CustomHeaders: map[string]string{"X-Source": "coding-plan", "X-Agent": "openclaw"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := SaveConfig(cfgPath, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig error: %v", err)
|
||||
}
|
||||
|
||||
loaded, err := LoadConfig(cfgPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadConfig error: %v", err)
|
||||
}
|
||||
|
||||
if loaded.ModelList[0].CustomHeaders == nil {
|
||||
t.Fatal("CustomHeaders should not be nil after round-trip")
|
||||
}
|
||||
if got := loaded.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
|
||||
t.Errorf("CustomHeaders[X-Source] = %q, want coding-plan", got)
|
||||
}
|
||||
if got := loaded.ModelList[0].CustomHeaders["X-Agent"]; got != "openclaw" {
|
||||
t.Errorf("CustomHeaders[X-Agent] = %q, want openclaw", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
|
||||
@@ -17,6 +17,11 @@ func DefaultConfig() *Config {
|
||||
|
||||
return &Config{
|
||||
Version: CurrentVersion,
|
||||
// Isolation is opt-in so existing installations keep their current behavior
|
||||
// until the user explicitly enables subprocess sandboxing.
|
||||
Isolation: IsolationConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: workspacePath,
|
||||
@@ -462,7 +467,8 @@ func DefaultConfig() *Config {
|
||||
UseBM25: true,
|
||||
UseRegex: false,
|
||||
},
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
MaxInlineTextChars: DefaultMCPMaxInlineTextChars,
|
||||
Servers: map[string]MCPServerConfig{},
|
||||
},
|
||||
AppendFile: ToolConfig{
|
||||
Enabled: true,
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/channels/pico"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/vk"
|
||||
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -31,6 +32,7 @@ type Check struct {
|
||||
type StatusResponse struct {
|
||||
Status string `json:"status"`
|
||||
Uptime string `json:"uptime"`
|
||||
PID int `json:"pid,omitempty"`
|
||||
Checks map[string]Check `json:"checks,omitempty"`
|
||||
}
|
||||
|
||||
@@ -170,6 +172,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
resp := StatusResponse{
|
||||
Status: "ok",
|
||||
Uptime: uptime.String(),
|
||||
PID: os.Getpid(),
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
# `pkg/isolation`
|
||||
|
||||
`pkg/isolation` provides process-level isolation for child processes started by `picoclaw`.
|
||||
|
||||
It does not sandbox the main `picoclaw` process itself.
|
||||
|
||||
## Scope
|
||||
|
||||
The current scope is the child-process startup path:
|
||||
|
||||
- `exec` tool
|
||||
- CLI providers such as `claude-cli` and `codex-cli`
|
||||
- process hooks
|
||||
- MCP `stdio` servers
|
||||
|
||||
## One-Sentence Model
|
||||
|
||||
- The `picoclaw` main process still runs in the host environment.
|
||||
- Every child process should enter the shared `pkg/isolation` startup path first.
|
||||
- The startup path applies platform-specific isolation according to config.
|
||||
|
||||
## Architecture
|
||||
|
||||
The implementation has four layers:
|
||||
|
||||
1. Configuration layer: reads `config.Config.Isolation` and injects it through `isolation.Configure(cfg)`.
|
||||
2. Instance layout layer: resolves `config.GetHome()`, prepares instance directories, and builds the runtime user environment.
|
||||
3. Platform backend layer: Linux uses `bwrap`; Windows uses a restricted token, low integrity, and a `Job Object`; other platforms are not implemented.
|
||||
4. Unified startup layer: `PrepareCommand(cmd)`, `Start(cmd)`, and `Run(cmd)`.
|
||||
|
||||
All integrations that spawn subprocesses should reuse these helpers instead of calling `cmd.Start` or `cmd.Run` directly.
|
||||
|
||||
## Configuration
|
||||
|
||||
Isolation lives under:
|
||||
|
||||
```json
|
||||
{
|
||||
"isolation": {
|
||||
"enabled": false,
|
||||
"expose_paths": []
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Field meanings:
|
||||
|
||||
- `enabled`: enables or disables subprocess isolation. Default: `false`.
|
||||
- `expose_paths`: explicitly exposes host paths inside the isolated environment. It only matters when `enabled=true`. This is currently supported on Linux only.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"isolation": {
|
||||
"enabled": true,
|
||||
"expose_paths": [
|
||||
{
|
||||
"source": "/opt/toolchains/go",
|
||||
"target": "/opt/toolchains/go",
|
||||
"mode": "ro"
|
||||
},
|
||||
{
|
||||
"source": "/data/shared-assets",
|
||||
"target": "/opt/picoclaw-instance-a/workspace/assets",
|
||||
"mode": "rw"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Rules for `expose_paths`:
|
||||
|
||||
- `source` is a host path.
|
||||
- `target` is the path inside the isolated environment.
|
||||
- `mode` must be `ro` or `rw`.
|
||||
- When `target` is empty, it defaults to `source`.
|
||||
- Only one final rule may exist for the same `target`.
|
||||
- Later-loaded config overrides earlier rules for the same `target`.
|
||||
|
||||
Platform note:
|
||||
|
||||
- Linux uses a real `source -> target` mount view.
|
||||
- Windows does not currently support `expose_paths`.
|
||||
|
||||
## Instance Root And Directories
|
||||
|
||||
The instance root follows `config.GetHome()`:
|
||||
|
||||
- If `PICOCLAW_HOME` is set, use it.
|
||||
- Otherwise use the default `.picoclaw` directory under the user home.
|
||||
|
||||
If `config.GetHome()` falls back to `.` while isolation is enabled, startup should fail.
|
||||
|
||||
Default instance directories include:
|
||||
|
||||
- instance root
|
||||
- `skills`
|
||||
- `logs`
|
||||
- `cache`
|
||||
- `state`
|
||||
- `runtime-user-env`
|
||||
|
||||
`workspace` is derived from `cfg.WorkspacePath()` when configured, otherwise from the default workspace rule.
|
||||
|
||||
Windows also prepares:
|
||||
|
||||
- `runtime-user-env/AppData/Roaming`
|
||||
- `runtime-user-env/AppData/Local`
|
||||
|
||||
## User Environment Redirect
|
||||
|
||||
When isolation is enabled, child processes receive a redirected per-instance user environment.
|
||||
|
||||
Linux variables:
|
||||
|
||||
- `HOME`
|
||||
- `TMPDIR`
|
||||
- `XDG_CONFIG_HOME`
|
||||
- `XDG_CACHE_HOME`
|
||||
- `XDG_STATE_HOME`
|
||||
|
||||
Windows variables:
|
||||
|
||||
- `USERPROFILE`
|
||||
- `HOME`
|
||||
- `TEMP`
|
||||
- `TMP`
|
||||
- `APPDATA`
|
||||
- `LOCALAPPDATA`
|
||||
|
||||
These paths point into `runtime-user-env` under the instance root.
|
||||
|
||||
## Platform Behavior
|
||||
|
||||
### Linux
|
||||
|
||||
The Linux backend currently depends on `bwrap` (`bubblewrap`).
|
||||
|
||||
Capabilities:
|
||||
|
||||
- minimal filesystem view
|
||||
- `ipc` namespace isolation
|
||||
- redirected child-process user environment
|
||||
- `source -> target` read-only or read-write mounts
|
||||
|
||||
Default mounts include the instance root plus the minimum runtime system paths such as `/usr`, `/bin`, `/lib`, `/lib64`, and `/etc/resolv.conf`.
|
||||
|
||||
At runtime, PicoClaw also adds the executable path, its directory, the effective working directory, and absolute path arguments when needed.
|
||||
|
||||
There is no automatic fallback when `bwrap` is missing.
|
||||
|
||||
Install examples:
|
||||
|
||||
- `apt install bubblewrap`
|
||||
- `dnf install bubblewrap`
|
||||
- `yum install bubblewrap`
|
||||
- `pacman -S bubblewrap`
|
||||
- `apk add bubblewrap`
|
||||
|
||||
If isolation must be disabled temporarily:
|
||||
|
||||
```json
|
||||
{
|
||||
"isolation": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Disabling isolation increases the risk that child processes can access or modify more host files.
|
||||
|
||||
### Windows
|
||||
|
||||
Windows isolation currently supports process-level restrictions such as restricted tokens, low integrity, job objects, and redirected user-environment directories.
|
||||
|
||||
`expose_paths` is not currently supported on Windows. If it is configured, startup should fail instead of pretending the paths were exposed.
|
||||
|
||||
The Windows backend currently uses:
|
||||
|
||||
- a restricted primary token
|
||||
- low integrity level
|
||||
- a `Job Object`
|
||||
- redirected child-process user environment
|
||||
|
||||
It does not currently implement true `source -> target` filesystem remapping.
|
||||
|
||||
### macOS And Other Platforms
|
||||
|
||||
They are not implemented yet.
|
||||
|
||||
When isolation is explicitly enabled on an unsupported platform, the higher-level runtime should surface that as an unsupported configuration instead of pretending isolation succeeded.
|
||||
|
||||
## Logging And Debugging
|
||||
|
||||
When isolation is enabled, PicoClaw logs the generated isolation plan.
|
||||
|
||||
Linux log name:
|
||||
|
||||
- `linux isolation mount plan`
|
||||
|
||||
Windows log name:
|
||||
|
||||
- `windows isolation access rules`
|
||||
|
||||
If you suspect isolation is ineffective, check whether unexpected host paths appear in those logs.
|
||||
|
||||
## Relationship To `restrict_to_workspace`
|
||||
|
||||
- `restrict_to_workspace` limits the paths an agent is normally allowed to access.
|
||||
- `pkg/isolation` limits what a child process can see and where its user environment points.
|
||||
|
||||
They complement each other and do not replace each other.
|
||||
|
||||
## Current Limits
|
||||
|
||||
- Linux isolation is implemented with `bwrap`, not a custom in-process isolation runtime.
|
||||
- Linux does not currently enable a dedicated `pid` namespace by default.
|
||||
- Windows does not yet implement full host ACL enforcement for every allowed or denied path.
|
||||
- macOS is not implemented.
|
||||
- The current design isolates child processes, not the main `picoclaw` process.
|
||||
|
||||
## Suggested Reading Order
|
||||
|
||||
If you are new to this code, read it in this order:
|
||||
|
||||
1. `pkg/config/config.go`
|
||||
2. `pkg/isolation/runtime.go`
|
||||
3. `pkg/isolation/platform_linux.go`
|
||||
4. `pkg/isolation/platform_windows.go`
|
||||
5. Call sites:
|
||||
6. `pkg/tools/shell.go`
|
||||
7. `pkg/providers/*.go`
|
||||
8. `pkg/agent/hook_process.go`
|
||||
9. `pkg/mcp/manager.go`
|
||||
|
||||
That path gives the fastest overview of the configuration model, runtime flow, and platform-specific limits.
|
||||
@@ -0,0 +1,238 @@
|
||||
# `pkg/isolation`
|
||||
|
||||
`pkg/isolation` 为 `picoclaw` 启动的子进程提供进程级隔离能力。
|
||||
|
||||
它当前不会把 `picoclaw` 主进程自身放进沙箱中运行。
|
||||
|
||||
## 生效范围
|
||||
|
||||
当前生效范围是子进程启动链路:
|
||||
|
||||
- `exec` 工具
|
||||
- `claude-cli`、`codex-cli` 等 CLI provider
|
||||
- 进程型 hooks
|
||||
- MCP `stdio` server
|
||||
|
||||
## 一句话理解
|
||||
|
||||
- `picoclaw` 主进程仍运行在宿主环境中。
|
||||
- 所有子进程都应先经过 `pkg/isolation` 的统一启动入口。
|
||||
- 入口会根据配置和平台,为子进程施加对应隔离。
|
||||
|
||||
## 架构
|
||||
|
||||
当前实现可以分为四层:
|
||||
|
||||
1. 配置层:读取 `config.Config.Isolation`,并通过 `isolation.Configure(cfg)` 注入运行时。
|
||||
2. 实例目录层:解析 `config.GetHome()`,准备实例目录,并构建运行时用户环境目录。
|
||||
3. 平台后端层:Linux 使用 `bwrap`;Windows 使用受限 token、低完整性级别和 `Job Object`;其他平台未实现。
|
||||
4. 统一启动层:`PrepareCommand(cmd)`、`Start(cmd)`、`Run(cmd)`。
|
||||
|
||||
所有启动子进程的接入点都应复用这组入口,而不是各自直接调用 `cmd.Start` 或 `cmd.Run`。
|
||||
|
||||
## 配置
|
||||
|
||||
隔离配置位于:
|
||||
|
||||
```json
|
||||
{
|
||||
"isolation": {
|
||||
"enabled": false,
|
||||
"expose_paths": []
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
- `enabled`:是否启用子进程隔离。默认值:`false`。
|
||||
- `expose_paths`:显式把宿主路径带入隔离环境。仅在 `enabled=true` 时生效。目前只在 Linux 上支持。
|
||||
|
||||
示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"isolation": {
|
||||
"enabled": true,
|
||||
"expose_paths": [
|
||||
{
|
||||
"source": "/opt/toolchains/go",
|
||||
"target": "/opt/toolchains/go",
|
||||
"mode": "ro"
|
||||
},
|
||||
{
|
||||
"source": "/data/shared-assets",
|
||||
"target": "/opt/picoclaw-instance-a/workspace/assets",
|
||||
"mode": "rw"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`expose_paths` 规则:
|
||||
|
||||
- `source`:宿主机路径。
|
||||
- `target`:隔离环境内的目标路径。
|
||||
- `mode`:只能是 `ro` 或 `rw`。
|
||||
- `target` 为空时,默认等于 `source`。
|
||||
- 同一个 `target` 最终只能保留一条规则。
|
||||
- 后加载的配置会覆盖先加载的同目标规则。
|
||||
|
||||
平台说明:
|
||||
|
||||
- Linux 会真实使用 `source -> target` 挂载视图。
|
||||
- Windows 当前不支持 `expose_paths`。
|
||||
|
||||
## 实例根与目录
|
||||
|
||||
实例根遵循 `config.GetHome()`:
|
||||
|
||||
- 如果设置了 `PICOCLAW_HOME`,使用该值。
|
||||
- 否则默认使用用户目录下的 `.picoclaw`。
|
||||
|
||||
如果 `config.GetHome()` 在隔离开启时最终回退到当前目录 `.`,启动应直接失败。
|
||||
|
||||
默认实例目录包括:
|
||||
|
||||
- 实例根本身
|
||||
- `skills`
|
||||
- `logs`
|
||||
- `cache`
|
||||
- `state`
|
||||
- `runtime-user-env`
|
||||
|
||||
`workspace` 优先使用 `cfg.WorkspacePath()` 的结果;未显式配置时才按默认规则派生。
|
||||
|
||||
Windows 还会额外准备:
|
||||
|
||||
- `runtime-user-env/AppData/Roaming`
|
||||
- `runtime-user-env/AppData/Local`
|
||||
|
||||
## 用户环境重定向
|
||||
|
||||
隔离开启后,子进程会收到重定向到实例目录下的独立用户环境。
|
||||
|
||||
Linux 注入变量:
|
||||
|
||||
- `HOME`
|
||||
- `TMPDIR`
|
||||
- `XDG_CONFIG_HOME`
|
||||
- `XDG_CACHE_HOME`
|
||||
- `XDG_STATE_HOME`
|
||||
|
||||
Windows 注入变量:
|
||||
|
||||
- `USERPROFILE`
|
||||
- `HOME`
|
||||
- `TEMP`
|
||||
- `TMP`
|
||||
- `APPDATA`
|
||||
- `LOCALAPPDATA`
|
||||
|
||||
这些路径都会指向实例根下的 `runtime-user-env`。
|
||||
|
||||
## 平台行为
|
||||
|
||||
### Linux
|
||||
|
||||
Linux 后端当前依赖 `bwrap`(`bubblewrap`)。
|
||||
|
||||
能力:
|
||||
|
||||
- 最小文件系统视图
|
||||
- `ipc namespace`
|
||||
- 子进程用户环境重定向
|
||||
- `source -> target` 只读或读写挂载
|
||||
|
||||
默认映射包括实例根,以及 `/usr`、`/bin`、`/lib`、`/lib64`、`/etc/resolv.conf` 等最小运行时系统路径。
|
||||
|
||||
运行时还会按需补充可执行文件本身、其所在目录、生效后的工作目录,以及命令行中的绝对路径参数。
|
||||
|
||||
缺少 `bwrap` 时不会自动回退。
|
||||
|
||||
安装示例:
|
||||
|
||||
- `apt install bubblewrap`
|
||||
- `dnf install bubblewrap`
|
||||
- `yum install bubblewrap`
|
||||
- `pacman -S bubblewrap`
|
||||
- `apk add bubblewrap`
|
||||
|
||||
如果需要临时关闭隔离:
|
||||
|
||||
```json
|
||||
{
|
||||
"isolation": {
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
关闭隔离后,子进程访问或修改更多宿主文件的风险会明显上升。
|
||||
|
||||
### Windows
|
||||
|
||||
Windows 隔离当前提供的是进程级限制,例如 restricted token、low integrity、job object,以及用户环境目录重定向。
|
||||
|
||||
`expose_paths` 目前不支持 Windows。如果配置了该字段,启动应直接失败,而不是假装这些路径已经被暴露进隔离环境。
|
||||
|
||||
Windows 后端当前使用:
|
||||
|
||||
- 受限 primary token
|
||||
- 低完整性级别
|
||||
- `Job Object`
|
||||
- 子进程用户环境重定向
|
||||
|
||||
它当前不会实现真正的 `source -> target` 文件系统重映射。
|
||||
|
||||
### macOS 与其他平台
|
||||
|
||||
当前尚未实现。
|
||||
|
||||
当在未支持的平台上显式开启隔离时,上层运行时应将其视为不支持的配置,而不是假装隔离成功。
|
||||
|
||||
## 日志与排障
|
||||
|
||||
隔离开启后,PicoClaw 会打印生成后的隔离计划,便于排障。
|
||||
|
||||
Linux 日志名:
|
||||
|
||||
- `linux isolation mount plan`
|
||||
|
||||
Windows 日志名:
|
||||
|
||||
- `windows isolation access rules`
|
||||
|
||||
如果你怀疑隔离未生效,先检查这些日志里是否出现了不应暴露的宿主路径。
|
||||
|
||||
## 与 `restrict_to_workspace` 的关系
|
||||
|
||||
- `restrict_to_workspace` 限制的是 agent 默认可访问的路径。
|
||||
- `pkg/isolation` 限制的是子进程运行时能看到什么文件系统,以及它的用户环境指向哪里。
|
||||
|
||||
两者互补,不互相替代。
|
||||
|
||||
## 当前限制
|
||||
|
||||
- Linux 基于 `bwrap` 实现,而不是纯内建 isolation runtime。
|
||||
- Linux 当前没有默认启用独立的 `pid namespace`。
|
||||
- Windows 还没有对所有允许/拒绝路径做完整 ACL 落地。
|
||||
- macOS 尚未实现。
|
||||
- 当前隔离的是子进程,不是 `picoclaw` 主进程自身。
|
||||
|
||||
## 建议阅读顺序
|
||||
|
||||
如果你是第一次看这部分代码,建议按这个顺序阅读:
|
||||
|
||||
1. `pkg/config/config.go`
|
||||
2. `pkg/isolation/runtime.go`
|
||||
3. `pkg/isolation/platform_linux.go`
|
||||
4. `pkg/isolation/platform_windows.go`
|
||||
5. 调用点:
|
||||
6. `pkg/tools/shell.go`
|
||||
7. `pkg/providers/*.go`
|
||||
8. `pkg/agent/hook_process.go`
|
||||
9. `pkg/mcp/manager.go`
|
||||
|
||||
这样能最快建立对配置模型、运行流程和平台边界的整体理解。
|
||||
@@ -0,0 +1,264 @@
|
||||
//go:build linux
|
||||
|
||||
package isolation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
func applyPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
|
||||
if !isolation.Enabled {
|
||||
return nil
|
||||
}
|
||||
// Bubblewrap is the only supported Linux backend right now. Fail closed when
|
||||
// it is unavailable instead of silently running the child process unisolated.
|
||||
bwrapPath, err := exec.LookPath("bwrap")
|
||||
if err != nil {
|
||||
hint := bwrapInstallHint()
|
||||
disableHint := `set "isolation.enabled": false in config.json`
|
||||
logger.WarnCF("isolation", "bubblewrap is required for Linux isolation",
|
||||
map[string]any{
|
||||
"binary": "bwrap",
|
||||
"install": hint,
|
||||
"disable_isolation": disableHint,
|
||||
"risk": "disabling isolation lets child processes run without Linux filesystem isolation",
|
||||
})
|
||||
return fmt.Errorf(
|
||||
"linux isolation requires bwrap and does not fall back automatically: %w; install bubblewrap with one of: %s; or disable isolation by setting %s; disabling isolation means child processes can run without Linux filesystem isolation and may access or modify more host files",
|
||||
err,
|
||||
hint,
|
||||
disableHint,
|
||||
)
|
||||
}
|
||||
if cmd == nil || cmd.Path == "" || len(cmd.Args) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
originalPath := cmd.Path
|
||||
originalArgs := append([]string{}, cmd.Args...)
|
||||
_, execDir, err := resolveLinuxWorkingDir(cmd.Dir, originalPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resolvedPath, err := resolveLinuxCommandPath(originalPath, execDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Start from the configured mount plan, then add only the executable, its
|
||||
// resolved path, the effective working directory, and any absolute path
|
||||
// arguments needed to preserve the original command semantics.
|
||||
plan := BuildLinuxMountPlan(root, isolation.ExposePaths)
|
||||
plan = ensureLinuxMountRule(plan, resolvedPath, resolvedPath, "ro")
|
||||
plan = ensureLinuxMountRule(plan, filepath.Dir(resolvedPath), filepath.Dir(resolvedPath), "ro")
|
||||
if resolved, resolveErr := filepath.EvalSymlinks(resolvedPath); resolveErr == nil && resolved != resolvedPath {
|
||||
plan = ensureLinuxMountRule(plan, resolved, resolved, "ro")
|
||||
plan = ensureLinuxMountRule(plan, filepath.Dir(resolved), filepath.Dir(resolved), "ro")
|
||||
}
|
||||
if execDir != "" {
|
||||
plan = ensureLinuxMountRule(plan, execDir, execDir, "rw")
|
||||
if resolved, resolveErr := filepath.EvalSymlinks(execDir); resolveErr == nil && resolved != execDir {
|
||||
plan = ensureLinuxMountRule(plan, resolved, resolved, "rw")
|
||||
}
|
||||
}
|
||||
plan = appendLinuxArgumentMounts(plan, originalArgs[1:])
|
||||
logger.DebugCF("isolation", "linux isolation mount plan",
|
||||
map[string]any{
|
||||
"root": root,
|
||||
"command": resolvedPath,
|
||||
"working_dir": execDir,
|
||||
"mounts": formatLinuxMountPlan(plan),
|
||||
})
|
||||
bwrapArgs, err := buildLinuxBwrapArgs(originalPath, resolvedPath, originalArgs, execDir, plan)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Path = bwrapPath
|
||||
cmd.Args = bwrapArgs
|
||||
cmd.Dir = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func bwrapInstallHint() string {
|
||||
return "apt install bubblewrap; dnf install bubblewrap; yum install bubblewrap; pacman -S bubblewrap; apk add bubblewrap"
|
||||
}
|
||||
|
||||
// formatLinuxMountPlan reshapes the internal plan for structured logging.
|
||||
func formatLinuxMountPlan(plan []MountRule) []map[string]string {
|
||||
formatted := make([]map[string]string, 0, len(plan))
|
||||
for _, rule := range plan {
|
||||
formatted = append(formatted, map[string]string{
|
||||
"source": rule.Source,
|
||||
"target": rule.Target,
|
||||
"mode": rule.Mode,
|
||||
})
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
|
||||
func postStartPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupPendingPlatformResources(cmd *exec.Cmd) {
|
||||
}
|
||||
|
||||
// buildLinuxBwrapArgs translates the mount plan into the bubblewrap command
|
||||
// line that re-executes the original process inside the isolated mount view.
|
||||
func buildLinuxBwrapArgs(
|
||||
originalPath string,
|
||||
resolvedPath string,
|
||||
originalArgs []string,
|
||||
execDir string,
|
||||
plan []MountRule,
|
||||
) ([]string, error) {
|
||||
bwrapArgs := []string{
|
||||
"bwrap",
|
||||
"--die-with-parent",
|
||||
"--unshare-ipc",
|
||||
"--proc", "/proc",
|
||||
"--dev", "/dev",
|
||||
}
|
||||
for _, rule := range plan {
|
||||
flag, err := linuxBindFlag(rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bwrapArgs = append(bwrapArgs, flag, rule.Source, rule.Target)
|
||||
}
|
||||
if execDir != "" {
|
||||
bwrapArgs = append(bwrapArgs, "--chdir", execDir)
|
||||
}
|
||||
execPath := originalPath
|
||||
if isRelativeCommandPath(originalPath) {
|
||||
execPath = resolvedPath
|
||||
}
|
||||
bwrapArgs = append(bwrapArgs, "--", execPath)
|
||||
if len(originalArgs) > 1 {
|
||||
bwrapArgs = append(bwrapArgs, originalArgs[1:]...)
|
||||
}
|
||||
return bwrapArgs, nil
|
||||
}
|
||||
|
||||
func resolveLinuxWorkingDir(originalDir, originalPath string) (string, string, error) {
|
||||
if originalDir != "" {
|
||||
resolved, err := filepath.Abs(originalDir)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("resolve command dir %s: %w", originalDir, err)
|
||||
}
|
||||
return resolved, resolved, nil
|
||||
}
|
||||
if !isRelativeCommandPath(originalPath) {
|
||||
return "", "", nil
|
||||
}
|
||||
wd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("resolve current working dir: %w", err)
|
||||
}
|
||||
return "", wd, nil
|
||||
}
|
||||
|
||||
func resolveLinuxCommandPath(originalPath, execDir string) (string, error) {
|
||||
if filepath.IsAbs(originalPath) || !isRelativeCommandPath(originalPath) {
|
||||
return filepath.Clean(originalPath), nil
|
||||
}
|
||||
base := execDir
|
||||
if base == "" {
|
||||
var err error
|
||||
base, err = os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve current working dir: %w", err)
|
||||
}
|
||||
}
|
||||
return filepath.Clean(filepath.Join(base, originalPath)), nil
|
||||
}
|
||||
|
||||
func appendLinuxArgumentMounts(plan []MountRule, args []string) []MountRule {
|
||||
for _, arg := range args {
|
||||
path, ok := linuxArgumentPath(arg)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
clean := filepath.Clean(path)
|
||||
if info, err := os.Stat(clean); err == nil {
|
||||
mode := "ro"
|
||||
if info.IsDir() {
|
||||
mode = "rw"
|
||||
}
|
||||
plan = ensureLinuxMountRule(plan, clean, clean, mode)
|
||||
if resolved, resolveErr := filepath.EvalSymlinks(clean); resolveErr == nil && resolved != clean {
|
||||
plan = ensureLinuxMountRule(plan, resolved, resolved, mode)
|
||||
}
|
||||
continue
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
continue
|
||||
}
|
||||
parent := filepath.Dir(clean)
|
||||
if parent == clean {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(parent); err == nil {
|
||||
plan = ensureLinuxMountRule(plan, parent, parent, "rw")
|
||||
}
|
||||
}
|
||||
return plan
|
||||
}
|
||||
|
||||
func linuxArgumentPath(arg string) (string, bool) {
|
||||
if filepath.IsAbs(arg) {
|
||||
return arg, true
|
||||
}
|
||||
idx := strings.IndexRune(arg, '=')
|
||||
if idx <= 0 || idx == len(arg)-1 {
|
||||
return "", false
|
||||
}
|
||||
value := arg[idx+1:]
|
||||
if !filepath.IsAbs(value) {
|
||||
return "", false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
|
||||
func isRelativeCommandPath(path string) bool {
|
||||
return !filepath.IsAbs(path) && strings.ContainsRune(path, filepath.Separator)
|
||||
}
|
||||
|
||||
// ensureLinuxMountRule appends a mount rule unless another rule already owns
|
||||
// the same target path.
|
||||
func ensureLinuxMountRule(plan []MountRule, source, target, mode string) []MountRule {
|
||||
cleanSource := filepath.Clean(source)
|
||||
cleanTarget := filepath.Clean(target)
|
||||
for _, rule := range plan {
|
||||
if filepath.Clean(rule.Target) == cleanTarget {
|
||||
return plan
|
||||
}
|
||||
}
|
||||
return append(plan, MountRule{Source: cleanSource, Target: cleanTarget, Mode: mode})
|
||||
}
|
||||
|
||||
// linuxBindFlag selects the correct bubblewrap bind flag based on mount mode.
|
||||
func linuxBindFlag(rule MountRule) (string, error) {
|
||||
info, err := os.Stat(rule.Source)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("stat linux mount source %s: %w", rule.Source, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
if rule.Mode == "rw" {
|
||||
return "--bind", nil
|
||||
}
|
||||
return "--ro-bind", nil
|
||||
}
|
||||
if rule.Mode == "rw" {
|
||||
return "--bind", nil
|
||||
}
|
||||
return "--ro-bind", nil
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
//go:build linux
|
||||
|
||||
package isolation
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestBuildLinuxBwrapArgs_IncludesNamespaceFlagsAndExec(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
binaryDir := filepath.Join(root, "bin")
|
||||
if err := os.MkdirAll(binaryDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
binaryPath := filepath.Join(binaryDir, "tool")
|
||||
if err := os.WriteFile(binaryPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
plan := BuildLinuxMountPlan(root, []config.ExposePath{{Source: binaryDir, Target: binaryDir, Mode: "ro"}})
|
||||
args, err := buildLinuxBwrapArgs(binaryPath, binaryPath, []string{binaryPath, "--flag"}, root, plan)
|
||||
if err != nil {
|
||||
t.Fatalf("buildLinuxBwrapArgs() error = %v", err)
|
||||
}
|
||||
hasNet := false
|
||||
hasIPC := false
|
||||
hasExec := false
|
||||
for i := range args {
|
||||
switch args[i] {
|
||||
case "--unshare-net":
|
||||
hasNet = true
|
||||
case "--unshare-ipc":
|
||||
hasIPC = true
|
||||
case "--":
|
||||
if i+1 < len(args) && args[i+1] == binaryPath {
|
||||
hasExec = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasNet {
|
||||
t.Fatalf("bwrap args should not unshare net by default: %v", args)
|
||||
}
|
||||
if !hasIPC || !hasExec {
|
||||
t.Fatalf("bwrap args missing required items: %v", args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinuxWorkingDir_ResolvesRelativeDir(t *testing.T) {
|
||||
cwd := t.TempDir()
|
||||
previous, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if chdirErr := os.Chdir(previous); chdirErr != nil {
|
||||
t.Fatalf("restore cwd: %v", chdirErr)
|
||||
}
|
||||
}()
|
||||
if chdirErr := os.Chdir(cwd); chdirErr != nil {
|
||||
t.Fatal(chdirErr)
|
||||
}
|
||||
|
||||
resolvedDir, execDir, err := resolveLinuxWorkingDir("./hooks", "./hook.sh")
|
||||
if err != nil {
|
||||
t.Fatalf("resolveLinuxWorkingDir() error = %v", err)
|
||||
}
|
||||
want := filepath.Join(cwd, "hooks")
|
||||
if resolvedDir != want || execDir != want {
|
||||
t.Fatalf("resolveLinuxWorkingDir() = (%q, %q), want (%q, %q)", resolvedDir, execDir, want, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveLinuxCommandPath_UsesExecDirForRelativeCommand(t *testing.T) {
|
||||
execDir := filepath.Join(t.TempDir(), "hooks")
|
||||
got, err := resolveLinuxCommandPath("./hook.sh", execDir)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveLinuxCommandPath() error = %v", err)
|
||||
}
|
||||
want := filepath.Join(execDir, "hook.sh")
|
||||
if got != want {
|
||||
t.Fatalf("resolveLinuxCommandPath() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLinuxBwrapArgs_UsesResolvedPathForRelativeCommand(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
execDir := filepath.Join(root, "hooks")
|
||||
if err := os.MkdirAll(execDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resolvedPath := filepath.Join(execDir, "hook.sh")
|
||||
if err := os.WriteFile(resolvedPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
plan := []MountRule{
|
||||
{Source: execDir, Target: execDir, Mode: "rw"},
|
||||
{Source: resolvedPath, Target: resolvedPath, Mode: "ro"},
|
||||
}
|
||||
args, err := buildLinuxBwrapArgs("./hook.sh", resolvedPath, []string{"./hook.sh"}, execDir, plan)
|
||||
if err != nil {
|
||||
t.Fatalf("buildLinuxBwrapArgs() error = %v", err)
|
||||
}
|
||||
hasExecDir := false
|
||||
for _, arg := range args {
|
||||
if arg == execDir {
|
||||
hasExecDir = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasExecDir {
|
||||
t.Fatalf("buildLinuxBwrapArgs() missing resolved chdir: %v", args)
|
||||
}
|
||||
for i := range args {
|
||||
if args[i] == "--" {
|
||||
if i+1 >= len(args) || args[i+1] != resolvedPath {
|
||||
t.Fatalf("buildLinuxBwrapArgs() exec path = %v, want %q after --", args, resolvedPath)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("buildLinuxBwrapArgs() missing exec delimiter: %v", args)
|
||||
}
|
||||
|
||||
func TestAppendLinuxArgumentMounts_AddsAbsoluteArgumentPaths(t *testing.T) {
|
||||
root := t.TempDir()
|
||||
input := filepath.Join(root, "input.txt")
|
||||
if err := os.WriteFile(input, []byte("data"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
output := filepath.Join(root, "out", "result.txt")
|
||||
if err := os.MkdirAll(filepath.Dir(output), 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
plan := appendLinuxArgumentMounts(nil, []string{input, "--output=" + output})
|
||||
if len(plan) != 2 {
|
||||
t.Fatalf("appendLinuxArgumentMounts() len = %d, want 2", len(plan))
|
||||
}
|
||||
if plan[0].Source != input || plan[0].Mode != "ro" {
|
||||
t.Fatalf("appendLinuxArgumentMounts()[0] = %+v, want source=%q mode=ro", plan[0], input)
|
||||
}
|
||||
if plan[1].Source != filepath.Dir(output) || plan[1].Mode != "rw" {
|
||||
t.Fatalf("appendLinuxArgumentMounts()[1] = %+v, want source=%q mode=rw", plan[1], filepath.Dir(output))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//go:build !linux && !windows
|
||||
|
||||
package isolation
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func applyPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
|
||||
// Unsupported platforms currently keep the command unchanged. Callers rely on
|
||||
// Preflight and higher-level checks to surface unsupported isolation modes.
|
||||
return nil
|
||||
}
|
||||
|
||||
func postStartPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupPendingPlatformResources(cmd *exec.Cmd) {
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
//go:build windows
|
||||
|
||||
package isolation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const disableMaxPrivilege = 0x1
|
||||
|
||||
// windowsProcessResources holds native handles that must live for the lifetime
|
||||
// of an isolated child process.
|
||||
type windowsProcessResources struct {
|
||||
job windows.Handle
|
||||
token windows.Token
|
||||
}
|
||||
|
||||
var (
|
||||
windowsProcessResourcesByPID sync.Map
|
||||
windowsPendingResources sync.Map
|
||||
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
procCreateRestrictedToken = advapi32.NewProc("CreateRestrictedToken")
|
||||
)
|
||||
|
||||
func applyPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
|
||||
if !isolation.Enabled || cmd == nil {
|
||||
return nil
|
||||
}
|
||||
if cmd.SysProcAttr == nil {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||
}
|
||||
rules := BuildWindowsAccessRules(root, isolation.ExposePaths)
|
||||
logger.InfoCF("isolation", "windows isolation process constraints",
|
||||
map[string]any{
|
||||
"root": root,
|
||||
"command": cmd.Path,
|
||||
"rules": formatWindowsAccessRules(rules),
|
||||
"note": "Windows currently enforces restricted token, low integrity, and job object limits; expose_paths filesystem remapping is rejected during preflight",
|
||||
})
|
||||
// Create the restricted token before the process starts so CreateProcess uses
|
||||
// the reduced privilege set from the first instruction.
|
||||
restrictedToken, err := createRestrictedPrimaryToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create restricted primary token: %w", err)
|
||||
}
|
||||
cmd.SysProcAttr.CreationFlags |= windows.CREATE_NEW_PROCESS_GROUP | windows.CREATE_BREAKAWAY_FROM_JOB
|
||||
cmd.SysProcAttr.Token = syscall.Token(restrictedToken)
|
||||
windowsPendingResources.Store(cmd, windowsProcessResources{token: restrictedToken})
|
||||
return nil
|
||||
}
|
||||
|
||||
func postStartPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
|
||||
if !isolation.Enabled || cmd == nil || cmd.Process == nil {
|
||||
return nil
|
||||
}
|
||||
resourcesAny, _ := windowsPendingResources.LoadAndDelete(cmd)
|
||||
resources, _ := resourcesAny.(windowsProcessResources)
|
||||
// Job objects can only be attached after the process exists, so the Windows
|
||||
// backend finishes isolation in this post-start hook.
|
||||
job, err := windows.CreateJobObject(nil, nil)
|
||||
if err != nil {
|
||||
if resources.token != 0 {
|
||||
_ = resources.token.Close()
|
||||
}
|
||||
return fmt.Errorf("create windows job object: %w", err)
|
||||
}
|
||||
|
||||
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{}
|
||||
info.BasicLimitInformation.LimitFlags = windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
|
||||
if _, err := windows.SetInformationJobObject(
|
||||
job,
|
||||
windows.JobObjectExtendedLimitInformation,
|
||||
uintptr(unsafe.Pointer(&info)),
|
||||
uint32(unsafe.Sizeof(info)),
|
||||
); err != nil {
|
||||
_ = windows.CloseHandle(job)
|
||||
if resources.token != 0 {
|
||||
_ = resources.token.Close()
|
||||
}
|
||||
return fmt.Errorf("set windows job object info: %w", err)
|
||||
}
|
||||
|
||||
proc, err := windows.OpenProcess(
|
||||
windows.PROCESS_SET_QUOTA|windows.PROCESS_TERMINATE|windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.SYNCHRONIZE,
|
||||
false,
|
||||
uint32(cmd.Process.Pid),
|
||||
)
|
||||
if err != nil {
|
||||
_ = windows.CloseHandle(job)
|
||||
if resources.token != 0 {
|
||||
_ = resources.token.Close()
|
||||
}
|
||||
return fmt.Errorf("open process for job assignment: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.AssignProcessToJobObject(job, proc); err != nil {
|
||||
_ = windows.CloseHandle(proc)
|
||||
_ = windows.CloseHandle(job)
|
||||
if resources.token != 0 {
|
||||
_ = resources.token.Close()
|
||||
}
|
||||
return fmt.Errorf("assign process to job object: %w", err)
|
||||
}
|
||||
|
||||
if resources.token != 0 {
|
||||
_ = resources.token.Close()
|
||||
}
|
||||
resources.job = job
|
||||
windowsProcessResourcesByPID.Store(cmd.Process.Pid, resources)
|
||||
go reapWindowsProcessResources(cmd.Process.Pid, proc, job)
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupPendingPlatformResources(cmd *exec.Cmd) {
|
||||
if cmd == nil {
|
||||
return
|
||||
}
|
||||
resourcesAny, ok := windowsPendingResources.LoadAndDelete(cmd)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
resources, _ := resourcesAny.(windowsProcessResources)
|
||||
if resources.token != 0 {
|
||||
_ = resources.token.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func reapWindowsProcessResources(pid int, proc windows.Handle, job windows.Handle) {
|
||||
_, _ = windows.WaitForSingleObject(proc, windows.INFINITE)
|
||||
_ = windows.CloseHandle(proc)
|
||||
_ = windows.CloseHandle(job)
|
||||
windowsProcessResourcesByPID.Delete(pid)
|
||||
}
|
||||
|
||||
// createRestrictedPrimaryToken duplicates the current process token, removes
|
||||
// maximum privileges, and lowers integrity before it is assigned to a child.
|
||||
func createRestrictedPrimaryToken() (windows.Token, error) {
|
||||
var current windows.Token
|
||||
if err := windows.OpenProcessToken(
|
||||
windows.CurrentProcess(),
|
||||
windows.TOKEN_DUPLICATE|windows.TOKEN_ASSIGN_PRIMARY|windows.TOKEN_QUERY|windows.TOKEN_ADJUST_DEFAULT,
|
||||
¤t,
|
||||
); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer current.Close()
|
||||
|
||||
var restricted windows.Token
|
||||
r1, _, e1 := procCreateRestrictedToken.Call(
|
||||
uintptr(current),
|
||||
uintptr(disableMaxPrivilege),
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&restricted)),
|
||||
)
|
||||
if r1 == 0 {
|
||||
if e1 != nil && e1 != syscall.Errno(0) {
|
||||
return 0, e1
|
||||
}
|
||||
return 0, syscall.EINVAL
|
||||
}
|
||||
if err := setTokenLowIntegrity(restricted); err != nil {
|
||||
_ = restricted.Close()
|
||||
return 0, err
|
||||
}
|
||||
return restricted, nil
|
||||
}
|
||||
|
||||
// setTokenLowIntegrity lowers the token integrity level so writes to higher
|
||||
// integrity locations are blocked by the OS.
|
||||
func setTokenLowIntegrity(token windows.Token) error {
|
||||
lowSID, err := windows.CreateWellKnownSid(windows.WinLowLabelSid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create low integrity sid: %w", err)
|
||||
}
|
||||
tml := windows.Tokenmandatorylabel{
|
||||
Label: windows.SIDAndAttributes{
|
||||
Sid: lowSID,
|
||||
Attributes: windows.SE_GROUP_INTEGRITY,
|
||||
},
|
||||
}
|
||||
if err := windows.SetTokenInformation(
|
||||
token,
|
||||
windows.TokenIntegrityLevel,
|
||||
(*byte)(unsafe.Pointer(&tml)),
|
||||
tml.Size(),
|
||||
); err != nil {
|
||||
return fmt.Errorf("set token low integrity: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatWindowsAccessRules reshapes the internal rules for structured logging.
|
||||
func formatWindowsAccessRules(rules []AccessRule) []map[string]string {
|
||||
formatted := make([]map[string]string, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
formatted = append(formatted, map[string]string{
|
||||
"path": rule.Path,
|
||||
"mode": rule.Mode,
|
||||
})
|
||||
}
|
||||
return formatted
|
||||
}
|
||||
@@ -0,0 +1,443 @@
|
||||
package isolation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
// MountRule describes a source-to-target mount exposed inside the Linux
|
||||
// isolation view.
|
||||
type MountRule struct {
|
||||
Source string
|
||||
Target string
|
||||
Mode string
|
||||
}
|
||||
|
||||
// AccessRule describes the effective Windows-side access rule for a host path.
|
||||
type AccessRule struct {
|
||||
Path string
|
||||
Mode string
|
||||
}
|
||||
|
||||
// UserEnv contains the redirected per-instance user directories injected into
|
||||
// isolated child processes.
|
||||
type UserEnv struct {
|
||||
Home string
|
||||
Tmp string
|
||||
Config string
|
||||
Cache string
|
||||
State string
|
||||
AppData string
|
||||
LocalAppData string
|
||||
}
|
||||
|
||||
var (
|
||||
isolationMu sync.RWMutex
|
||||
currentIsolation = config.DefaultConfig().Isolation
|
||||
)
|
||||
|
||||
// Configure updates the process-wide isolation state used by subsequent child
|
||||
// process launches.
|
||||
func Configure(cfg *config.Config) {
|
||||
isolationMu.Lock()
|
||||
defer isolationMu.Unlock()
|
||||
if cfg == nil {
|
||||
defaults := config.DefaultConfig()
|
||||
currentIsolation = defaults.Isolation
|
||||
return
|
||||
}
|
||||
currentIsolation = cfg.Isolation
|
||||
}
|
||||
|
||||
// CurrentConfig returns the currently active isolation settings.
|
||||
func CurrentConfig() config.IsolationConfig {
|
||||
isolationMu.RLock()
|
||||
defer isolationMu.RUnlock()
|
||||
return currentIsolation
|
||||
}
|
||||
|
||||
// ResolveInstanceRoot resolves the instance root used to build the isolated
|
||||
// filesystem and redirected user environment.
|
||||
func ResolveInstanceRoot() (string, error) {
|
||||
root := filepath.Clean(config.GetHome())
|
||||
if root == "." {
|
||||
return "", fmt.Errorf("instance root resolved to current directory")
|
||||
}
|
||||
return root, nil
|
||||
}
|
||||
|
||||
// PrepareInstanceRoot creates the directories required by the isolation runtime.
|
||||
func PrepareInstanceRoot(root string) error {
|
||||
for _, dir := range InstanceDirs(root) {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("prepare instance dir %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// InstanceDirs returns the directories that must exist under the instance root
|
||||
// for isolation-aware child processes.
|
||||
func InstanceDirs(root string) []string {
|
||||
dirs := []string{
|
||||
root,
|
||||
filepath.Join(root, "skills"),
|
||||
filepath.Join(root, "logs"),
|
||||
filepath.Join(root, "cache"),
|
||||
filepath.Join(root, "state"),
|
||||
filepath.Join(root, "runtime-user-env"),
|
||||
filepath.Join(root, "runtime-user-env", "home"),
|
||||
filepath.Join(root, "runtime-user-env", "tmp"),
|
||||
filepath.Join(root, "runtime-user-env", "config"),
|
||||
filepath.Join(root, "runtime-user-env", "cache"),
|
||||
filepath.Join(root, "runtime-user-env", "state"),
|
||||
}
|
||||
dirs = append(dirs, filepath.Join(root, pkg.WorkspaceName))
|
||||
if runtime.GOOS == "windows" {
|
||||
dirs = append(dirs,
|
||||
filepath.Join(root, "runtime-user-env", "AppData", "Roaming"),
|
||||
filepath.Join(root, "runtime-user-env", "AppData", "Local"),
|
||||
)
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
// ResolveUserEnv derives the redirected user directories rooted under the
|
||||
// instance runtime area.
|
||||
func ResolveUserEnv(root string) UserEnv {
|
||||
base := filepath.Join(root, "runtime-user-env")
|
||||
return UserEnv{
|
||||
Home: filepath.Join(base, "home"),
|
||||
Tmp: filepath.Join(base, "tmp"),
|
||||
Config: filepath.Join(base, "config"),
|
||||
Cache: filepath.Join(base, "cache"),
|
||||
State: filepath.Join(base, "state"),
|
||||
AppData: filepath.Join(base, "AppData", "Roaming"),
|
||||
LocalAppData: filepath.Join(base, "AppData", "Local"),
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyUserEnv rewrites the child process environment so home, temp, and
|
||||
// platform-specific user-data directories point into the instance root.
|
||||
func ApplyUserEnv(cmd *exec.Cmd, root string) {
|
||||
userEnv := ResolveUserEnv(root)
|
||||
envMap := make(map[string]string)
|
||||
for _, item := range cmd.Environ() {
|
||||
if idx := strings.IndexRune(item, '='); idx > 0 {
|
||||
envMap[item[:idx]] = item[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
envMap["USERPROFILE"] = userEnv.Home
|
||||
envMap["HOME"] = userEnv.Home
|
||||
envMap["TEMP"] = userEnv.Tmp
|
||||
envMap["TMP"] = userEnv.Tmp
|
||||
envMap["APPDATA"] = userEnv.AppData
|
||||
envMap["LOCALAPPDATA"] = userEnv.LocalAppData
|
||||
} else {
|
||||
envMap["HOME"] = userEnv.Home
|
||||
envMap["TMPDIR"] = userEnv.Tmp
|
||||
envMap["XDG_CONFIG_HOME"] = userEnv.Config
|
||||
envMap["XDG_CACHE_HOME"] = userEnv.Cache
|
||||
envMap["XDG_STATE_HOME"] = userEnv.State
|
||||
}
|
||||
|
||||
env := make([]string, 0, len(envMap))
|
||||
for k, v := range envMap {
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
cmd.Env = env
|
||||
}
|
||||
|
||||
// ValidateExposePaths verifies the user-supplied path exposure rules before a
|
||||
// child process is started.
|
||||
func ValidateExposePaths(items []config.ExposePath) error {
|
||||
seen := map[string]struct{}{}
|
||||
for _, item := range items {
|
||||
if item.Source == "" {
|
||||
return fmt.Errorf("source is required")
|
||||
}
|
||||
if item.Mode != "ro" && item.Mode != "rw" {
|
||||
return fmt.Errorf("invalid expose_paths mode: %s", item.Mode)
|
||||
}
|
||||
|
||||
source := filepath.Clean(item.Source)
|
||||
target := item.Target
|
||||
if target == "" {
|
||||
target = source
|
||||
}
|
||||
target = filepath.Clean(target)
|
||||
|
||||
if !filepath.IsAbs(source) || !filepath.IsAbs(target) {
|
||||
return fmt.Errorf("source and target must be absolute paths")
|
||||
}
|
||||
if _, ok := seen[target]; ok {
|
||||
return fmt.Errorf("duplicate expose_path target: %s", target)
|
||||
}
|
||||
seen[target] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NormalizeExposePath fills implicit defaults and cleans path values so merge
|
||||
// and validation logic can work with canonical paths.
|
||||
func NormalizeExposePath(item config.ExposePath) config.ExposePath {
|
||||
source := filepath.Clean(item.Source)
|
||||
target := item.Target
|
||||
if target == "" {
|
||||
target = source
|
||||
}
|
||||
return config.ExposePath{
|
||||
Source: source,
|
||||
Target: filepath.Clean(target),
|
||||
Mode: item.Mode,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultExposePaths returns the minimum built-in host paths required for the
|
||||
// current platform to run isolated child processes.
|
||||
func DefaultExposePaths(root string) []config.ExposePath {
|
||||
items := []config.ExposePath{{
|
||||
Source: root,
|
||||
Target: root,
|
||||
Mode: "rw",
|
||||
}}
|
||||
if runtime.GOOS == "linux" {
|
||||
items = append(items, defaultLinuxSystemExposePaths()...)
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func defaultLinuxSystemExposePaths() []config.ExposePath {
|
||||
return existingExposePaths([]config.ExposePath{
|
||||
{Source: "/usr", Target: "/usr", Mode: "ro"},
|
||||
{Source: "/bin", Target: "/bin", Mode: "ro"},
|
||||
{Source: "/lib", Target: "/lib", Mode: "ro"},
|
||||
{Source: "/lib64", Target: "/lib64", Mode: "ro"},
|
||||
{Source: "/etc/resolv.conf", Target: "/etc/resolv.conf", Mode: "ro"},
|
||||
{Source: "/etc/hosts", Target: "/etc/hosts", Mode: "ro"},
|
||||
{Source: "/etc/nsswitch.conf", Target: "/etc/nsswitch.conf", Mode: "ro"},
|
||||
{Source: "/etc/passwd", Target: "/etc/passwd", Mode: "ro"},
|
||||
{Source: "/etc/group", Target: "/etc/group", Mode: "ro"},
|
||||
{Source: "/etc/ssl", Target: "/etc/ssl", Mode: "ro"},
|
||||
{Source: "/etc/pki", Target: "/etc/pki", Mode: "ro"},
|
||||
{Source: "/etc/ca-certificates", Target: "/etc/ca-certificates", Mode: "ro"},
|
||||
{Source: "/usr/share/ca-certificates", Target: "/usr/share/ca-certificates", Mode: "ro"},
|
||||
{Source: "/usr/local/share/ca-certificates", Target: "/usr/local/share/ca-certificates", Mode: "ro"},
|
||||
{Source: "/etc/alternatives", Target: "/etc/alternatives", Mode: "ro"},
|
||||
{Source: "/usr/share/zoneinfo", Target: "/usr/share/zoneinfo", Mode: "ro"},
|
||||
{Source: "/etc/localtime", Target: "/etc/localtime", Mode: "ro"},
|
||||
})
|
||||
}
|
||||
|
||||
// existingExposePaths keeps only the builtin host paths that exist on the
|
||||
// current machine so Linux isolation does not fail on distro-specific paths.
|
||||
func existingExposePaths(items []config.ExposePath) []config.ExposePath {
|
||||
filtered := make([]config.ExposePath, 0, len(items))
|
||||
for _, item := range items {
|
||||
if _, err := os.Stat(item.Source); err == nil {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// MergeExposePaths merges built-in rules with user overrides. Rules are keyed
|
||||
// by target path so later entries replace earlier ones for the same target.
|
||||
func MergeExposePaths(defaults []config.ExposePath, overrides []config.ExposePath) []config.ExposePath {
|
||||
merged := make([]config.ExposePath, 0, len(defaults)+len(overrides))
|
||||
indexByTarget := make(map[string]int, len(defaults)+len(overrides))
|
||||
appendOrReplace := func(item config.ExposePath) {
|
||||
normalized := NormalizeExposePath(item)
|
||||
if idx, ok := indexByTarget[normalized.Target]; ok {
|
||||
merged[idx] = normalized
|
||||
return
|
||||
}
|
||||
indexByTarget[normalized.Target] = len(merged)
|
||||
merged = append(merged, normalized)
|
||||
}
|
||||
for _, item := range defaults {
|
||||
appendOrReplace(item)
|
||||
}
|
||||
for _, item := range overrides {
|
||||
appendOrReplace(item)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// BuildLinuxMountPlan converts the merged expose-path configuration into the
|
||||
// mount rules consumed by the Linux bubblewrap backend.
|
||||
func BuildLinuxMountPlan(root string, overrides []config.ExposePath) []MountRule {
|
||||
merged := MergeExposePaths(DefaultExposePaths(root), overrides)
|
||||
plan := make([]MountRule, 0, len(merged))
|
||||
for _, item := range merged {
|
||||
plan = append(plan, MountRule{Source: item.Source, Target: item.Target, Mode: item.Mode})
|
||||
}
|
||||
return plan
|
||||
}
|
||||
|
||||
// BuildWindowsAccessRules derives the host-path access policy used by the
|
||||
// Windows restricted-token backend.
|
||||
func BuildWindowsAccessRules(root string, overrides []config.ExposePath) []AccessRule {
|
||||
merged := MergeExposePaths(nil, overrides)
|
||||
rules := make([]AccessRule, 0, len(merged)+1)
|
||||
rules = append(rules, AccessRule{Path: root, Mode: "rw"})
|
||||
for _, item := range merged {
|
||||
rules = append(rules, AccessRule{Path: item.Source, Mode: item.Mode})
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
func validateWindowsExposePaths(items []config.ExposePath) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("windows isolation does not yet support expose_paths filesystem rules")
|
||||
}
|
||||
|
||||
// IsSupported reports whether the current platform has an implemented isolation
|
||||
// backend.
|
||||
func IsSupported() bool {
|
||||
return isSupportedOn(runtime.GOOS)
|
||||
}
|
||||
|
||||
func isSupportedOn(goos string) bool {
|
||||
switch goos {
|
||||
case "linux", "windows":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Preflight validates the configured isolation state and prepares the instance
|
||||
// runtime directories before any child process is launched.
|
||||
func Preflight() error {
|
||||
isolation := CurrentConfig()
|
||||
if !isolation.Enabled {
|
||||
return nil
|
||||
}
|
||||
if !IsSupported() {
|
||||
return fmt.Errorf("subprocess isolation is not supported on %s", runtime.GOOS)
|
||||
}
|
||||
root, err := ResolveInstanceRoot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := PrepareInstanceRoot(root); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ValidateExposePaths(isolation.ExposePaths); err != nil {
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS == "linux" {
|
||||
for _, rule := range BuildLinuxMountPlan(root, isolation.ExposePaths) {
|
||||
if rule.Source == "" || rule.Target == "" {
|
||||
return fmt.Errorf("invalid linux mount rule")
|
||||
}
|
||||
}
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
if err := validateWindowsExposePaths(isolation.ExposePaths); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rule := range BuildWindowsAccessRules(root, isolation.ExposePaths) {
|
||||
if rule.Path == "" {
|
||||
return fmt.Errorf("invalid windows access rule")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start prepares isolation for the command, starts it, and applies any
|
||||
// post-start platform hooks required by the active backend.
|
||||
func Start(cmd *exec.Cmd) error {
|
||||
if err := PrepareCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
cleanupPendingPlatformResources(cmd)
|
||||
return err
|
||||
}
|
||||
isolation := CurrentConfig()
|
||||
root := ""
|
||||
if isolation.Enabled {
|
||||
var err error
|
||||
root, err = ResolveInstanceRoot()
|
||||
if err != nil {
|
||||
terminateStartedCommand(cmd)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := postStartPlatformIsolation(cmd, isolation, root); err != nil {
|
||||
terminateStartedCommand(cmd)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run is the Start-and-Wait helper that keeps the same isolation behavior as
|
||||
// Start while returning the command's final exit status.
|
||||
func Run(cmd *exec.Cmd) error {
|
||||
if err := PrepareCommand(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
cleanupPendingPlatformResources(cmd)
|
||||
return err
|
||||
}
|
||||
isolation := CurrentConfig()
|
||||
root := ""
|
||||
if isolation.Enabled {
|
||||
var err error
|
||||
root, err = ResolveInstanceRoot()
|
||||
if err != nil {
|
||||
terminateStartedCommand(cmd)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := postStartPlatformIsolation(cmd, isolation, root); err != nil {
|
||||
terminateStartedCommand(cmd)
|
||||
return err
|
||||
}
|
||||
return cmd.Wait()
|
||||
}
|
||||
|
||||
func terminateStartedCommand(cmd *exec.Cmd) {
|
||||
cleanupPendingPlatformResources(cmd)
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
_ = cmd.Process.Kill()
|
||||
_ = cmd.Wait()
|
||||
}
|
||||
|
||||
// PrepareCommand mutates the command in-place so it inherits the configured
|
||||
// isolated environment before being started by the caller.
|
||||
func PrepareCommand(cmd *exec.Cmd) error {
|
||||
isolation := CurrentConfig()
|
||||
if err := Preflight(); err != nil {
|
||||
return err
|
||||
}
|
||||
if isolation.Enabled {
|
||||
root, err := ResolveInstanceRoot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ApplyUserEnv(cmd, root)
|
||||
if err := applyPlatformIsolation(cmd, isolation, root); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package isolation
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestResolveInstanceRoot_UsesPicoclawHome(t *testing.T) {
|
||||
t.Setenv(config.EnvHome, "/custom/picoclaw/home")
|
||||
root, err := ResolveInstanceRoot()
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveInstanceRoot() error = %v", err)
|
||||
}
|
||||
if root != "/custom/picoclaw/home" {
|
||||
t.Fatalf("ResolveInstanceRoot() = %q, want %q", root, "/custom/picoclaw/home")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareInstanceRoot_CreatesDirectories(t *testing.T) {
|
||||
root := filepath.Join(t.TempDir(), "instance")
|
||||
if err := PrepareInstanceRoot(root); err != nil {
|
||||
t.Fatalf("PrepareInstanceRoot() error = %v", err)
|
||||
}
|
||||
for _, dir := range InstanceDirs(root) {
|
||||
if info, err := os.Stat(dir); err != nil {
|
||||
t.Fatalf("os.Stat(%q): %v", dir, err)
|
||||
} else if !info.IsDir() {
|
||||
t.Fatalf("%q is not a directory", dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstanceDirs_UsesInstanceWorkspaceNotGlobalState(t *testing.T) {
|
||||
root := filepath.Join(t.TempDir(), "instance")
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Isolation.Enabled = true
|
||||
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "external-workspace")
|
||||
Configure(cfg)
|
||||
t.Cleanup(func() { Configure(config.DefaultConfig()) })
|
||||
|
||||
dirs := InstanceDirs(root)
|
||||
wantWorkspace := filepath.Join(root, pkg.WorkspaceName)
|
||||
found := false
|
||||
for _, dir := range dirs {
|
||||
if dir == wantWorkspace {
|
||||
found = true
|
||||
}
|
||||
if dir == cfg.WorkspacePath() {
|
||||
t.Fatalf("InstanceDirs() should not depend on process-wide workspace state: %q", dir)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("InstanceDirs() missing instance workspace dir %q", wantWorkspace)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSupportedOn(t *testing.T) {
|
||||
tests := []struct {
|
||||
goos string
|
||||
want bool
|
||||
}{
|
||||
{goos: "linux", want: true},
|
||||
{goos: "windows", want: true},
|
||||
{goos: "darwin", want: false},
|
||||
{goos: "freebsd", want: false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := isSupportedOn(tt.goos); got != tt.want {
|
||||
t.Fatalf("isSupportedOn(%q) = %v, want %v", tt.goos, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateExposePaths(t *testing.T) {
|
||||
err := ValidateExposePaths([]config.ExposePath{{Source: "/src", Target: "/dst", Mode: "ro"}})
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateExposePaths() error = %v", err)
|
||||
}
|
||||
|
||||
err = ValidateExposePaths([]config.ExposePath{{Source: "/src", Target: "/dst", Mode: "bad"}})
|
||||
if err == nil {
|
||||
t.Fatal("ValidateExposePaths() expected invalid mode error")
|
||||
}
|
||||
|
||||
err = ValidateExposePaths(
|
||||
[]config.ExposePath{
|
||||
{Source: "/src", Target: "/dst", Mode: "ro"},
|
||||
{Source: "/other", Target: "/dst", Mode: "rw"},
|
||||
},
|
||||
)
|
||||
if err == nil {
|
||||
t.Fatal("ValidateExposePaths() expected duplicate target error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeExposePaths_OverrideByTarget(t *testing.T) {
|
||||
merged := MergeExposePaths(
|
||||
[]config.ExposePath{{Source: "/src-a", Target: "/dst", Mode: "ro"}},
|
||||
[]config.ExposePath{{Source: "/src-b", Target: "/dst", Mode: "rw"}},
|
||||
)
|
||||
if len(merged) != 1 {
|
||||
t.Fatalf("MergeExposePaths len = %d, want 1", len(merged))
|
||||
}
|
||||
if got := merged[0]; got.Source != "/src-b" || got.Target != "/dst" || got.Mode != "rw" {
|
||||
t.Fatalf("merged[0] = %+v, want source=/src-b target=/dst mode=rw", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLinuxMountPlan(t *testing.T) {
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("linux-only default mount set")
|
||||
}
|
||||
plan := BuildLinuxMountPlan("/rootdir", []config.ExposePath{{Source: "/src", Target: "/dst", Mode: "ro"}})
|
||||
if len(plan) == 0 {
|
||||
t.Fatal("BuildLinuxMountPlan returned empty plan")
|
||||
}
|
||||
foundRoot := false
|
||||
foundOverride := false
|
||||
for _, rule := range plan {
|
||||
if rule.Source == "/rootdir" && rule.Target == "/rootdir" && rule.Mode == "rw" {
|
||||
foundRoot = true
|
||||
}
|
||||
if rule.Source == "/src" && rule.Target == "/dst" && rule.Mode == "ro" {
|
||||
foundOverride = true
|
||||
}
|
||||
}
|
||||
if !foundRoot {
|
||||
t.Fatal("BuildLinuxMountPlan missing root mapping")
|
||||
}
|
||||
if !foundOverride {
|
||||
t.Fatal("BuildLinuxMountPlan missing override mapping")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildWindowsAccessRules(t *testing.T) {
|
||||
rules := BuildWindowsAccessRules(
|
||||
`C:\picoclaw`,
|
||||
[]config.ExposePath{{Source: `D:\data`, Target: `C:\mapped`, Mode: "ro"}},
|
||||
)
|
||||
if len(rules) == 0 {
|
||||
t.Fatal("BuildWindowsAccessRules returned empty rules")
|
||||
}
|
||||
foundRoot := false
|
||||
foundOverride := false
|
||||
for _, rule := range rules {
|
||||
if rule.Path == `C:\picoclaw` && rule.Mode == "rw" {
|
||||
foundRoot = true
|
||||
}
|
||||
if rule.Path == `D:\data` && rule.Mode == "ro" {
|
||||
foundOverride = true
|
||||
}
|
||||
}
|
||||
if !foundRoot {
|
||||
t.Fatal("BuildWindowsAccessRules missing root rule")
|
||||
}
|
||||
if !foundOverride {
|
||||
t.Fatal("BuildWindowsAccessRules missing override rule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWindowsExposePaths(t *testing.T) {
|
||||
if err := validateWindowsExposePaths(nil); err != nil {
|
||||
t.Fatalf("validateWindowsExposePaths(nil) error = %v", err)
|
||||
}
|
||||
err := validateWindowsExposePaths([]config.ExposePath{{Source: `D:\data`, Target: `D:\data`, Mode: "ro"}})
|
||||
if err == nil {
|
||||
t.Fatal("validateWindowsExposePaths() expected error for expose_paths")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLinuxSystemExposePaths(t *testing.T) {
|
||||
paths := defaultLinuxSystemExposePaths()
|
||||
needed := map[string]bool{}
|
||||
for _, path := range []string{"/etc/hosts", "/etc/nsswitch.conf", "/etc/ssl", "/usr/share/zoneinfo", "/etc/localtime"} {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
needed[path] = false
|
||||
}
|
||||
}
|
||||
for _, item := range paths {
|
||||
if _, ok := needed[item.Source]; ok {
|
||||
needed[item.Source] = true
|
||||
}
|
||||
}
|
||||
for path, found := range needed {
|
||||
if !found {
|
||||
t.Fatalf("defaultLinuxSystemExposePaths missing %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingExposePaths_SkipsMissingPaths(t *testing.T) {
|
||||
existing := filepath.Join(t.TempDir(), "existing")
|
||||
if err := os.MkdirAll(existing, 0o755); err != nil {
|
||||
t.Fatalf("os.MkdirAll() error = %v", err)
|
||||
}
|
||||
filtered := existingExposePaths([]config.ExposePath{
|
||||
{Source: existing, Target: existing, Mode: "ro"},
|
||||
{Source: filepath.Join(t.TempDir(), "missing"), Target: "/missing", Mode: "ro"},
|
||||
})
|
||||
if len(filtered) != 1 {
|
||||
t.Fatalf("existingExposePaths() len = %d, want 1", len(filtered))
|
||||
}
|
||||
if got := filtered[0]; got.Source != existing {
|
||||
t.Fatalf("existingExposePaths()[0] = %+v, want source=%q", got, existing)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareCommand_AppliesUserEnv(t *testing.T) {
|
||||
t.Setenv(config.EnvHome, filepath.Join(t.TempDir(), "home"))
|
||||
if runtime.GOOS == "linux" {
|
||||
binDir := filepath.Join(t.TempDir(), "bin")
|
||||
if err := os.MkdirAll(binDir, 0o755); err != nil {
|
||||
t.Fatalf("os.MkdirAll() error = %v", err)
|
||||
}
|
||||
fakeBwrap := filepath.Join(binDir, "bwrap")
|
||||
if err := os.WriteFile(fakeBwrap, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
|
||||
t.Fatalf("os.WriteFile() error = %v", err)
|
||||
}
|
||||
t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH"))
|
||||
}
|
||||
cfg := config.DefaultConfig()
|
||||
cfg.Isolation.Enabled = true
|
||||
Configure(cfg)
|
||||
t.Cleanup(func() { Configure(config.DefaultConfig()) })
|
||||
cmd := exec.Command("sh", "-c", "true")
|
||||
if err := PrepareCommand(cmd); err != nil {
|
||||
t.Fatalf("PrepareCommand() error = %v", err)
|
||||
}
|
||||
hasHome := false
|
||||
for _, env := range cmd.Env {
|
||||
if len(env) > 5 && env[:5] == "HOME=" {
|
||||
hasHome = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if runtime.GOOS != "windows" && !hasHome {
|
||||
t.Fatal("PrepareCommand() did not inject HOME")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
)
|
||||
|
||||
var isolatedCommandTerminateDuration = 5 * time.Second
|
||||
|
||||
// isolatedCommandTransport mirrors the SDK command transport but routes
|
||||
// process startup through pkg/isolation so Windows post-start hooks run too.
|
||||
type isolatedCommandTransport struct {
|
||||
Command *exec.Cmd
|
||||
TerminateDuration time.Duration
|
||||
}
|
||||
|
||||
func (t *isolatedCommandTransport) Connect(ctx context.Context) (sdkmcp.Connection, error) {
|
||||
stdout, err := t.Command.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stdout = io.NopCloser(stdout)
|
||||
stdin, err := t.Command.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := isolation.Start(t.Command); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
td := t.TerminateDuration
|
||||
if td <= 0 {
|
||||
td = isolatedCommandTerminateDuration
|
||||
}
|
||||
return newIsolatedIOConn(&isolatedPipeRWC{cmd: t.Command, stdout: stdout, stdin: stdin, terminateDuration: td}), nil
|
||||
}
|
||||
|
||||
type isolatedPipeRWC struct {
|
||||
cmd *exec.Cmd
|
||||
stdout io.ReadCloser
|
||||
stdin io.WriteCloser
|
||||
terminateDuration time.Duration
|
||||
}
|
||||
|
||||
func (s *isolatedPipeRWC) Read(p []byte) (n int, err error) {
|
||||
return s.stdout.Read(p)
|
||||
}
|
||||
|
||||
func (s *isolatedPipeRWC) Write(p []byte) (n int, err error) {
|
||||
return s.stdin.Write(p)
|
||||
}
|
||||
|
||||
func (s *isolatedPipeRWC) Close() error {
|
||||
if err := s.stdin.Close(); err != nil {
|
||||
return fmt.Errorf("closing stdin: %v", err)
|
||||
}
|
||||
resChan := make(chan error, 1)
|
||||
go func() {
|
||||
resChan <- s.cmd.Wait()
|
||||
}()
|
||||
wait := func() (error, bool) {
|
||||
select {
|
||||
case err := <-resChan:
|
||||
return err, true
|
||||
case <-time.After(s.terminateDuration):
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
if err, ok := wait(); ok {
|
||||
return err
|
||||
}
|
||||
if err := s.cmd.Process.Signal(syscall.SIGTERM); err == nil {
|
||||
if err, ok := wait(); ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := s.cmd.Process.Kill(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err, ok := wait(); ok {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf("unresponsive subprocess")
|
||||
}
|
||||
|
||||
type isolatedIOConn struct {
|
||||
writeMu sync.Mutex
|
||||
rwc io.ReadWriteCloser
|
||||
incoming <-chan isolatedMsgOrErr
|
||||
queue []jsonrpc.Message
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
closeErr error
|
||||
}
|
||||
|
||||
type isolatedMsgOrErr struct {
|
||||
msg json.RawMessage
|
||||
err error
|
||||
}
|
||||
|
||||
func newIsolatedIOConn(rwc io.ReadWriteCloser) *isolatedIOConn {
|
||||
incoming := make(chan isolatedMsgOrErr)
|
||||
closed := make(chan struct{})
|
||||
go func() {
|
||||
dec := json.NewDecoder(rwc)
|
||||
for {
|
||||
var raw json.RawMessage
|
||||
err := dec.Decode(&raw)
|
||||
if err == nil {
|
||||
var tr [1]byte
|
||||
if n, readErr := dec.Buffered().Read(tr[:]); n > 0 {
|
||||
if tr[0] != '\n' && tr[0] != '\r' {
|
||||
err = fmt.Errorf("invalid trailing data at the end of stream")
|
||||
}
|
||||
} else if readErr != nil && readErr != io.EOF {
|
||||
err = readErr
|
||||
}
|
||||
}
|
||||
select {
|
||||
case incoming <- isolatedMsgOrErr{msg: raw, err: err}:
|
||||
case <-closed:
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &isolatedIOConn{rwc: rwc, incoming: incoming, closed: closed}
|
||||
}
|
||||
|
||||
func (c *isolatedIOConn) SessionID() string { return "" }
|
||||
|
||||
func (c *isolatedIOConn) Read(ctx context.Context) (jsonrpc.Message, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
if len(c.queue) > 0 {
|
||||
next := c.queue[0]
|
||||
c.queue = c.queue[1:]
|
||||
return next, nil
|
||||
}
|
||||
var raw json.RawMessage
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case v := <-c.incoming:
|
||||
if v.err != nil {
|
||||
return nil, v.err
|
||||
}
|
||||
raw = v.msg
|
||||
case <-c.closed:
|
||||
return nil, io.EOF
|
||||
}
|
||||
msgs, err := readIsolatedBatch(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.queue = msgs[1:]
|
||||
return msgs[0], nil
|
||||
}
|
||||
|
||||
func readIsolatedBatch(data []byte) ([]jsonrpc.Message, error) {
|
||||
var rawBatch []json.RawMessage
|
||||
if err := json.Unmarshal(data, &rawBatch); err == nil {
|
||||
if len(rawBatch) == 0 {
|
||||
return nil, fmt.Errorf("empty batch")
|
||||
}
|
||||
msgs := make([]jsonrpc.Message, 0, len(rawBatch))
|
||||
for _, raw := range rawBatch {
|
||||
msg, err := jsonrpc.DecodeMessage(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgs = append(msgs, msg)
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
msg, err := jsonrpc.DecodeMessage(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []jsonrpc.Message{msg}, nil
|
||||
}
|
||||
|
||||
func (c *isolatedIOConn) Write(ctx context.Context, msg jsonrpc.Message) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
data, err := jsonrpc.EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling message: %v", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
_, err = c.rwc.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *isolatedIOConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
c.closeErr = c.rwc.Close()
|
||||
close(c.closed)
|
||||
})
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
var (
|
||||
_ sdkmcp.Transport = (*isolatedCommandTransport)(nil)
|
||||
_ sdkmcp.Connection = (*isolatedIOConn)(nil)
|
||||
)
|
||||
+1
-2
@@ -365,8 +365,7 @@ func (m *Manager) ConnectServer(
|
||||
env = append(env, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
cmd.Env = env
|
||||
|
||||
transport = &mcp.CommandTransport{Command: cmd}
|
||||
transport = &isolatedCommandTransport{Command: cmd}
|
||||
default:
|
||||
return fmt.Errorf(
|
||||
"unsupported transport type: %s (supported: stdio, sse, http)",
|
||||
|
||||
@@ -455,6 +455,33 @@ func (s *JSONLStore) rewriteJSONL(
|
||||
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
// ListSessions returns all known session keys by reading .meta.json files.
|
||||
func (s *JSONLStore) ListSessions() []string {
|
||||
entries, err := os.ReadDir(s.dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
|
||||
continue
|
||||
}
|
||||
// Read the meta file to get the original key
|
||||
data, err := os.ReadFile(filepath.Join(s.dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var meta sessionMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
continue
|
||||
}
|
||||
if meta.Key != "" {
|
||||
keys = append(keys, meta.Key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *JSONLStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ type Store interface {
|
||||
// data. Backends that do not accumulate dead data may return nil.
|
||||
Compact(ctx context.Context, sessionKey string) error
|
||||
|
||||
// ListSessions returns all known session keys.
|
||||
ListSessions() []string
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
|
||||
+37
-2
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -16,6 +17,8 @@ import (
|
||||
|
||||
const pidFileName = ".picoclaw.pid"
|
||||
|
||||
var errInvalidPidFile = errors.New("invalid pid file")
|
||||
|
||||
// PidFileData is the JSON structure stored in the PID file.
|
||||
type PidFileData struct {
|
||||
PID int `json:"pid"`
|
||||
@@ -109,6 +112,14 @@ func ReadPidFileWithCheck(homePath string) *PidFileData {
|
||||
pidPath := pidFilePath(homePath)
|
||||
data, err := readPidFileUnlocked(pidPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, errInvalidPidFile) {
|
||||
logger.Warnf("invalid pid file, remove it: %s (%v)", pidPath, err)
|
||||
_ = os.Remove(pidPath)
|
||||
return nil
|
||||
}
|
||||
logger.Debugf("failed to read pid file: %s", err)
|
||||
return nil
|
||||
}
|
||||
@@ -140,6 +151,30 @@ func RemovePidFile(homePath string) {
|
||||
os.Remove(pidPath)
|
||||
}
|
||||
|
||||
// RemovePidFileIfPID deletes the PID file only when the recorded PID matches
|
||||
// expectedPID. It returns true when the file is removed successfully.
|
||||
func RemovePidFileIfPID(homePath string, expectedPID int) bool {
|
||||
if expectedPID <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
pidMu.Lock()
|
||||
defer pidMu.Unlock()
|
||||
|
||||
pidPath := pidFilePath(homePath)
|
||||
data, err := readPidFileUnlocked(pidPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if data.PID != expectedPID {
|
||||
return false
|
||||
}
|
||||
if err := os.Remove(pidPath); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// readPidFileUnlocked reads the PID file without acquiring the lock.
|
||||
// Caller must hold pidMu.
|
||||
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
|
||||
@@ -150,12 +185,12 @@ func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
|
||||
|
||||
var data PidFileData
|
||||
if err := json.Unmarshal(raw, &data); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("%w: %v", errInvalidPidFile, err)
|
||||
}
|
||||
|
||||
// Validate PID is a positive integer.
|
||||
if data.PID <= 0 {
|
||||
return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
|
||||
return nil, fmt.Errorf("%w: pid=%d", errInvalidPidFile, data.PID)
|
||||
}
|
||||
|
||||
return &data, nil
|
||||
|
||||
@@ -191,6 +191,22 @@ func TestReadPidFileWithCheckStalePID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileWithCheckInvalidFile auto-cleans malformed PID file.
|
||||
func TestReadPidFileWithCheckInvalidFile(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, []byte("not json"), 0o600)
|
||||
|
||||
data := ReadPidFileWithCheck(dir)
|
||||
if data != nil {
|
||||
t.Error("expected nil for malformed pid file")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Error("malformed PID file should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemovePidFile removes the PID file for the current process.
|
||||
func TestRemovePidFile(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
@@ -228,6 +244,40 @@ func TestRemovePidFileNonexistent(t *testing.T) {
|
||||
RemovePidFile(dir)
|
||||
}
|
||||
|
||||
func TestRemovePidFileIfPID(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(other, "", " ")
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, raw, 0o600)
|
||||
|
||||
removed := RemovePidFileIfPID(dir, 99999999)
|
||||
if !removed {
|
||||
t.Fatal("expected RemovePidFileIfPID to remove matching pid file")
|
||||
}
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
t.Error("PID file should be removed for matching expected PID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovePidFileIfPIDMismatch(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
|
||||
raw, _ := json.MarshalIndent(other, "", " ")
|
||||
path := filepath.Join(dir, pidFileName)
|
||||
os.WriteFile(path, raw, 0o600)
|
||||
|
||||
removed := RemovePidFileIfPID(dir, 88888888)
|
||||
if removed {
|
||||
t.Fatal("expected RemovePidFileIfPID to keep non-matching pid file")
|
||||
}
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
t.Error("PID file should NOT be removed for mismatching expected PID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
|
||||
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
|
||||
dir := tmpDir(t)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package pid
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
@@ -18,5 +19,11 @@ func isProcessRunning(pid int) bool {
|
||||
return false
|
||||
}
|
||||
// Signal(nil) does not kill the process but checks existence on Unix.
|
||||
return p.Signal(syscall.Signal(0)) == nil
|
||||
err = p.Signal(syscall.Signal(0))
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
var errno syscall.Errno
|
||||
// EPERM means the process exists but we are not allowed to signal it.
|
||||
return errors.As(err, &errno) && errno == syscall.EPERM
|
||||
}
|
||||
|
||||
@@ -23,19 +23,19 @@ func isProcessRunning(pid int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
handle, _, err := procOpenProcess.Call(
|
||||
handle, _, _ := procOpenProcess.Call(
|
||||
uintptr(processQueryLimitedInformation),
|
||||
0,
|
||||
uintptr(pid),
|
||||
)
|
||||
if handle == 0 || err != nil {
|
||||
if handle == 0 {
|
||||
return false
|
||||
}
|
||||
defer procCloseHandle.Call(handle)
|
||||
|
||||
var exitCode uint32
|
||||
ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
|
||||
if ret == 0 || err != nil {
|
||||
ret, _, _ := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
|
||||
if ret == 0 {
|
||||
return false
|
||||
}
|
||||
return exitCode == stillActive
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
)
|
||||
|
||||
// ClaudeCliProvider implements LLMProvider using the claude CLI as a subprocess.
|
||||
@@ -49,7 +51,9 @@ func (p *ClaudeCliProvider) Chat(
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
// Execute the CLI through the shared isolation wrapper so external provider
|
||||
// processes honor the configured isolation policy.
|
||||
if err := isolation.Run(cmd); err != nil {
|
||||
stderrStr := strings.TrimSpace(stderr.String())
|
||||
stdoutStr := strings.TrimSpace(stdout.String())
|
||||
switch {
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/isolation"
|
||||
)
|
||||
|
||||
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
|
||||
@@ -56,7 +58,9 @@ func (p *CodexCliProvider) Chat(
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
// Execute the CLI through the shared isolation wrapper so external provider
|
||||
// processes honor the configured isolation policy.
|
||||
err := isolation.Run(cmd)
|
||||
|
||||
// Parse JSONL from stdout even if exit code is non-zero,
|
||||
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
|
||||
|
||||
@@ -262,6 +262,22 @@ func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_StringJSON_NewlineEscape(t *testing.T) {
|
||||
raw := json.RawMessage(`"{\"content\":\"line1\\nline2\"}"`)
|
||||
args := DecodeToolCallArguments(raw, "write_file")
|
||||
if args["content"] != "line1\nline2" {
|
||||
t.Errorf("content = %q, want newline-expanded string", args["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_StringJSON_LiteralBackslashN(t *testing.T) {
|
||||
raw := json.RawMessage(`"{\"content\":\"line1\\\\nline2\"}"`)
|
||||
args := DecodeToolCallArguments(raw, "write_file")
|
||||
if args["content"] != `line1\nline2` {
|
||||
t.Errorf("content = %q, want literal backslash-n", args["content"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
|
||||
args := DecodeToolCallArguments(nil, "test")
|
||||
if len(args) != 0 {
|
||||
|
||||
@@ -160,6 +160,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "azure", "azure-openai":
|
||||
@@ -238,6 +239,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "minimax":
|
||||
@@ -264,6 +266,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
extraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic":
|
||||
@@ -291,6 +294,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
|
||||
userAgent,
|
||||
cfg.RequestTimeout,
|
||||
cfg.ExtraBody,
|
||||
cfg.CustomHeaders,
|
||||
), modelID, nil
|
||||
|
||||
case "anthropic-messages":
|
||||
|
||||
@@ -846,6 +846,49 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateProviderFromConfig_CustomHeaders(t *testing.T) {
|
||||
var gotSource, gotAuth string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSource = r.Header.Get("X-Source")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := &config.ModelConfig{
|
||||
ModelName: "test-headers",
|
||||
Model: "openai/gpt-4o",
|
||||
APIBase: server.URL,
|
||||
CustomHeaders: map[string]string{"X-Source": "coding-plan", "Authorization": "Token config-auth"},
|
||||
}
|
||||
cfg.SetAPIKey("test-key")
|
||||
|
||||
provider, modelID, err := CreateProviderFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProviderFromConfig() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = provider.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
modelID,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if gotSource != "coding-plan" {
|
||||
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
|
||||
}
|
||||
if gotAuth != "Token config-auth" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token config-auth")
|
||||
}
|
||||
}
|
||||
|
||||
// openaiCompatResponse is the JSON response used by OpenAI-compatible providers.
|
||||
const openaiCompatResponse = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
|
||||
|
||||
|
||||
@@ -24,13 +24,14 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil)
|
||||
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil, nil)
|
||||
}
|
||||
|
||||
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
apiKey, apiBase, proxy, maxTokensField, userAgent string,
|
||||
requestTimeoutSeconds int,
|
||||
extraBody map[string]any,
|
||||
customHeaders map[string]string,
|
||||
) *HTTPProvider {
|
||||
return &HTTPProvider{
|
||||
delegate: openai_compat.NewProvider(
|
||||
@@ -40,6 +41,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
|
||||
openai_compat.WithMaxTokensField(maxTokensField),
|
||||
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
|
||||
openai_compat.WithExtraBody(extraBody),
|
||||
openai_compat.WithCustomHeaders(customHeaders),
|
||||
openai_compat.WithUserAgent(userAgent),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ type Provider struct {
|
||||
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
|
||||
httpClient *http.Client
|
||||
extraBody map[string]any // Additional fields to inject into request body
|
||||
customHeaders map[string]string
|
||||
userAgent string
|
||||
}
|
||||
|
||||
@@ -87,6 +88,12 @@ func WithExtraBody(extraBody map[string]any) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithCustomHeaders(customHeaders map[string]string) Option {
|
||||
return func(p *Provider) {
|
||||
p.customHeaders = customHeaders
|
||||
}
|
||||
}
|
||||
|
||||
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
|
||||
p := &Provider{
|
||||
apiKey: apiKey,
|
||||
@@ -181,6 +188,15 @@ func (p *Provider) buildRequestBody(
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func (p *Provider) applyCustomHeaders(req *http.Request) {
|
||||
for k, v := range p.customHeaders {
|
||||
if strings.TrimSpace(k) == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) Chat(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
@@ -211,6 +227,7 @@ func (p *Provider) Chat(
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
p.applyCustomHeaders(req)
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -254,9 +271,13 @@ func (p *Provider) ChatStream(
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
if p.userAgent != "" {
|
||||
req.Header.Set("User-Agent", p.userAgent)
|
||||
}
|
||||
if p.apiKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||
}
|
||||
p.applyCustomHeaders(req)
|
||||
|
||||
// Use a client without Timeout for streaming — the http.Client.Timeout covers
|
||||
// the entire request lifecycle including body reads, which would kill long streams.
|
||||
|
||||
@@ -710,6 +710,111 @@ func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChat_CustomHeadersInjected(t *testing.T) {
|
||||
var gotSource, gotAuth, gotUserAgent string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSource = r.Header.Get("X-Source")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
resp := map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{"content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider(
|
||||
"key",
|
||||
server.URL,
|
||||
"",
|
||||
WithUserAgent("PicoClaw/Test"),
|
||||
WithCustomHeaders(map[string]string{
|
||||
"X-Source": "coding-plan",
|
||||
"Authorization": "Token custom-auth",
|
||||
"User-Agent": "Custom-UA/1.0",
|
||||
}),
|
||||
)
|
||||
|
||||
_, err := p.Chat(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
if gotSource != "coding-plan" {
|
||||
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
|
||||
}
|
||||
if gotAuth != "Token custom-auth" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token custom-auth")
|
||||
}
|
||||
if gotUserAgent != "Custom-UA/1.0" {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/1.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderChatStream_CustomHeadersInjected(t *testing.T) {
|
||||
var gotSource, gotAuth, gotUserAgent string
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotSource = r.Header.Get("X-Source")
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotUserAgent = r.Header.Get("User-Agent")
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":\"stop\"}]}\n\n"))
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
p := NewProvider(
|
||||
"key",
|
||||
server.URL,
|
||||
"",
|
||||
WithUserAgent("PicoClaw/Test"),
|
||||
WithCustomHeaders(map[string]string{
|
||||
"X-Source": "coding-plan",
|
||||
"Authorization": "Token stream-auth",
|
||||
"User-Agent": "Custom-UA/Stream",
|
||||
}),
|
||||
)
|
||||
|
||||
out, err := p.ChatStream(
|
||||
t.Context(),
|
||||
[]Message{{Role: "user", Content: "hi"}},
|
||||
nil,
|
||||
"gpt-4o",
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("ChatStream() error = %v", err)
|
||||
}
|
||||
if out.Content != "ok" {
|
||||
t.Fatalf("Content = %q, want %q", out.Content, "ok")
|
||||
}
|
||||
if gotSource != "coding-plan" {
|
||||
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
|
||||
}
|
||||
if gotAuth != "Token stream-auth" {
|
||||
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token stream-auth")
|
||||
}
|
||||
if gotUserAgent != "Custom-UA/Stream" {
|
||||
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/Stream")
|
||||
}
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
|
||||
@@ -23,6 +23,12 @@ func buildCLIToolsPrompt(tools []ToolDefinition) string {
|
||||
)
|
||||
sb.WriteString("\n```\n\n")
|
||||
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
|
||||
sb.WriteString("Escaping rules (what to type in `function.arguments`):\n")
|
||||
sb.WriteString("- Use `\\n` to represent a real newline character.\n")
|
||||
sb.WriteString("- Use `\\\\n` to represent a literal backslash+n sequence (`\\n`).\n")
|
||||
sb.WriteString(
|
||||
"- `function.arguments` is a JSON-encoded string, so quotes/backslashes must be escaped in the outer payload.\n\n",
|
||||
)
|
||||
sb.WriteString("### Tool Definitions:\n\n")
|
||||
|
||||
for _, tool := range tools {
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"tool_name": "Bash",
|
||||
"tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}",
|
||||
"error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1",
|
||||
"timestamp": "2026-04-04T02:38:32.067Z",
|
||||
"retry_count": 6
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// CompactUntilUnder iteration cap
|
||||
// =============================================================================
|
||||
|
||||
func TestCompactUntilUnderIterationCap(t *testing.T) {
|
||||
// Setup: create a conversation with so many tokens that compaction
|
||||
// will never reach the budget. The iteration cap prevents infinite loops.
|
||||
//
|
||||
// We use a mock CompleteFn that always returns the same content,
|
||||
// and a budget of 0 which tokens can never reach.
|
||||
// Without the cap, this would loop forever.
|
||||
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
s := &Store{db: db}
|
||||
|
||||
conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add many messages to ensure there's plenty to compact
|
||||
for i := 0; i < 40; i++ {
|
||||
m, _ := s.AddMessage(context.Background(), convID, "user",
|
||||
"this is a long message with lots of tokens to push context over budget", 100)
|
||||
s.AppendContextMessage(context.Background(), convID, m.ID)
|
||||
}
|
||||
|
||||
// A completeFn that always succeeds but returns non-reducing content
|
||||
mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "Summary that doesn't reduce tokens much.", nil
|
||||
}
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, mockComplete)
|
||||
defer cancel()
|
||||
|
||||
// Use budget=1 so tokens can never reach budget
|
||||
// (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable)
|
||||
// The function should stop after maxCompactIterations, not loop forever
|
||||
ce.config = Config{} // ensure defaults
|
||||
|
||||
result, err := ce.CompactUntilUnder(context.Background(), convID, 1)
|
||||
if err != nil {
|
||||
// Should not error — should stop gracefully
|
||||
t.Fatalf("CompactUntilUnder with budget=0: %v", err)
|
||||
}
|
||||
|
||||
// The function should have completed within reasonable time
|
||||
// If it exceeded the cap, it would still return (not hang)
|
||||
_ = result
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Bug 1: formatMessagesForSummary ignores Parts
|
||||
// - formatMessagesForSummary only reads m.Content, empty for Part-based messages
|
||||
// - truncateSummary has same issue
|
||||
// =============================================================================
|
||||
|
||||
func TestFormatMessagesForSummaryIncludesParts(t *testing.T) {
|
||||
ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
messages := []Message{
|
||||
{ID: 1, Role: "user", Content: "hello world", CreatedAt: ts},
|
||||
{
|
||||
ID: 2,
|
||||
Role: "assistant",
|
||||
Content: "", // empty — real content is in Parts
|
||||
Parts: []MessagePart{
|
||||
{Type: "text", Text: "I will run a command"},
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: ts.Add(time.Minute),
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Role: "tool",
|
||||
Content: "", // empty — real content is in Parts
|
||||
Parts: []MessagePart{
|
||||
{Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: ts.Add(2 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
result := formatMessagesForSummary(messages)
|
||||
|
||||
// Must contain the plain text message
|
||||
if !contains(result, "hello world") {
|
||||
t.Error("formatMessagesForSummary: missing plain text content")
|
||||
}
|
||||
|
||||
// Must contain tool_use info (not blank)
|
||||
if !contains(result, "bash") || !contains(result, "ls -la") {
|
||||
t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
|
||||
// Must contain tool_result info (not blank)
|
||||
if !contains(result, "file1.txt") {
|
||||
t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSummaryIncludesParts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()},
|
||||
{
|
||||
ID: 2,
|
||||
Role: "assistant",
|
||||
Content: "", // empty
|
||||
Parts: []MessagePart{
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Role: "tool",
|
||||
Content: "", // empty
|
||||
Parts: []MessagePart{
|
||||
{Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
result := truncateSummary(messages)
|
||||
|
||||
// Must contain plain text
|
||||
if !contains(result, "run the tests") {
|
||||
t.Error("truncateSummary: missing plain text content")
|
||||
}
|
||||
|
||||
// Must contain tool info from Parts (not blank)
|
||||
if !contains(result, "bash") || !contains(result, "go test") {
|
||||
t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
|
||||
// Must contain tool_result from Parts
|
||||
if !contains(result, "PASS") {
|
||||
t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Bug 2: SearchMessages cannot find Part-based messages
|
||||
// - FTS5 indexes empty content, LIKE queries empty content
|
||||
// =============================================================================
|
||||
|
||||
func TestSearchMessagesFindsPartBasedMessages(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add a plain message (searchable)
|
||||
s.AddMessage(ctx, convID, "user", "list the files please", 5)
|
||||
|
||||
// Add a Part-based message (tool_use) — currently NOT searchable
|
||||
parts := []MessagePart{
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"},
|
||||
}
|
||||
s.AddMessageWithParts(ctx, convID, "assistant", parts, 10)
|
||||
|
||||
// Add a Part-based message (tool_result) — currently NOT searchable
|
||||
resultParts := []MessagePart{
|
||||
{Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"},
|
||||
}
|
||||
s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10)
|
||||
|
||||
// Search for "grep" — should find the tool_use message
|
||||
results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchMessages: %v", err)
|
||||
}
|
||||
if len(results) == 0 {
|
||||
t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search")
|
||||
}
|
||||
|
||||
// Search for "TODO fix" — should find the tool_result message
|
||||
results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchMessages: %v", err)
|
||||
}
|
||||
if len(results2) == 0 {
|
||||
t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// SQL statements for FTS5 tables with trigram tokenizer.
|
||||
const (
|
||||
sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5(
|
||||
summary_id,
|
||||
content,
|
||||
tokenize="trigram"
|
||||
)`
|
||||
sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
|
||||
message_id,
|
||||
content,
|
||||
tokenize="trigram"
|
||||
)`
|
||||
sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)`
|
||||
sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")`
|
||||
sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check`
|
||||
sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check`
|
||||
)
|
||||
|
||||
// runSchema creates or upgrades the database schema.
|
||||
// All schemas are idempotent (safe to run multiple times).
|
||||
func runSchema(db *sql.DB) error {
|
||||
// Check FTS5 support before creating tables
|
||||
if err := checkFTS5Support(db); err != nil {
|
||||
return fmt.Errorf("FTS5 check: %w", err)
|
||||
}
|
||||
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS conversations (
|
||||
conversation_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_key TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS messages (
|
||||
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS message_parts (
|
||||
part_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(message_id),
|
||||
type TEXT NOT NULL,
|
||||
text TEXT,
|
||||
name TEXT,
|
||||
arguments TEXT,
|
||||
tool_call_id TEXT,
|
||||
media_uri TEXT,
|
||||
mime_type TEXT,
|
||||
ordinal INTEGER NOT NULL DEFAULT 0
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS summaries (
|
||||
summary_id TEXT PRIMARY KEY,
|
||||
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||
kind TEXT NOT NULL,
|
||||
depth INTEGER NOT NULL DEFAULT 0,
|
||||
content TEXT NOT NULL,
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
earliest_at TEXT,
|
||||
latest_at TEXT,
|
||||
descendant_count INTEGER NOT NULL DEFAULT 0,
|
||||
descendant_token_count INTEGER NOT NULL DEFAULT 0,
|
||||
source_message_token_count INTEGER NOT NULL DEFAULT 0,
|
||||
model TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS summary_parents (
|
||||
summary_id TEXT NOT NULL,
|
||||
parent_summary_id TEXT NOT NULL,
|
||||
PRIMARY KEY (summary_id, parent_summary_id)
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS summary_messages (
|
||||
summary_id TEXT NOT NULL,
|
||||
message_id INTEGER NOT NULL,
|
||||
ordinal INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (summary_id, message_id)
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS context_items (
|
||||
conversation_id INTEGER NOT NULL,
|
||||
ordinal INTEGER NOT NULL,
|
||||
item_type TEXT NOT NULL,
|
||||
summary_id TEXT,
|
||||
message_id INTEGER,
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (conversation_id, ordinal)
|
||||
)`,
|
||||
|
||||
// FTS5 virtual table with trigram tokenizer for CJK support
|
||||
sqlCreateSummariesFTS,
|
||||
|
||||
// FTS5 virtual table for message search with trigram tokenizer
|
||||
sqlCreateMessagesFTS,
|
||||
|
||||
// Indexes for common query patterns
|
||||
`CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`,
|
||||
|
||||
// FTS5 triggers to keep summaries_fts in sync with summaries table
|
||||
`CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN
|
||||
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN
|
||||
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN
|
||||
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
|
||||
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
|
||||
END`,
|
||||
|
||||
// FTS5 triggers to keep messages_fts in sync with messages table
|
||||
`CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
|
||||
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
|
||||
DELETE FROM messages_fts WHERE message_id = old.message_id;
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
|
||||
DELETE FROM messages_fts WHERE message_id = old.message_id;
|
||||
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
|
||||
END`,
|
||||
}
|
||||
|
||||
for _, s := range stmts {
|
||||
if _, err := db.Exec(s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled.
|
||||
// This is required for full-text search with CJK (Chinese, Japanese, Korean) support.
|
||||
func checkFTS5Support(db *sql.DB) error {
|
||||
// Check if FTS5 is compiled in
|
||||
var fts5Enabled int
|
||||
err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled)
|
||||
if err != nil {
|
||||
// sqlite_compileoption_used might not exist in older SQLite
|
||||
// Try a different approach: create a test FTS5 table
|
||||
_, testErr := db.Exec(sqlCheckFTS5Available)
|
||||
if testErr != nil {
|
||||
return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr)
|
||||
}
|
||||
db.Exec(sqlDropFTS5Check)
|
||||
} else if fts5Enabled == 0 {
|
||||
return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)")
|
||||
}
|
||||
|
||||
// Check if trigram tokenizer is available by trying to create a test table
|
||||
// Not all SQLite builds include the trigram tokenizer
|
||||
_, err = db.Exec(sqlCheckTrigramAvailable)
|
||||
if err != nil {
|
||||
logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited",
|
||||
map[string]any{"error": err.Error()})
|
||||
// Trigram is not strictly required, just better for CJK
|
||||
// Don't return error, just log warning
|
||||
} else {
|
||||
db.Exec(sqlDropTrigramCheck)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
var testDBCounter uint64
|
||||
|
||||
func openTestDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
|
||||
n := atomic.AddUint64(&testDBCounter, 1)
|
||||
testName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
// Use a shared in-memory database so concurrent goroutines/connections in tests
|
||||
// observe the same schema/data.
|
||||
dsn := fmt.Sprintf("file:seahorse_test_%s_%d?mode=memory&cache=shared", testName, n)
|
||||
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("open test db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { db.Close() })
|
||||
return db
|
||||
}
|
||||
|
||||
func TestRunMigrations(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("runSchema: %v", err)
|
||||
}
|
||||
|
||||
// Verify all tables exist
|
||||
tables := []string{
|
||||
"conversations",
|
||||
"messages",
|
||||
"message_parts",
|
||||
"summaries",
|
||||
"summary_parents",
|
||||
"summary_messages",
|
||||
"context_items",
|
||||
}
|
||||
for _, tbl := range tables {
|
||||
var name string
|
||||
err := db.QueryRow(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl,
|
||||
).Scan(&name)
|
||||
if err != nil {
|
||||
t.Errorf("table %q not found: %v", tbl, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify FTS5 virtual table exists
|
||||
var ftsName string
|
||||
err := db.QueryRow(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'",
|
||||
).Scan(&ftsName)
|
||||
if err != nil {
|
||||
t.Errorf("FTS5 table summaries_fts not found: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMigrationsIdempotent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
// Run migrations twice — should succeed both times
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("first migration: %v", err)
|
||||
}
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("second migration (idempotent): %v", err)
|
||||
}
|
||||
|
||||
// Verify we can still insert data after double migration
|
||||
res, err := db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"test-session",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert after double migration: %v", err)
|
||||
}
|
||||
id, _ := res.LastInsertId()
|
||||
if id == 0 {
|
||||
t.Error("expected non-zero conversation id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationConversationUnique(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
|
||||
// Insert first
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"unique-key",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("first insert: %v", err)
|
||||
}
|
||||
|
||||
// Duplicate should fail
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"unique-key",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("expected unique constraint violation for duplicate session_key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationSummaryFTSInsert(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
|
||||
// Insert a conversation first
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"fts-test",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert conversation: %v", err)
|
||||
}
|
||||
|
||||
// Insert a summary
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
|
||||
VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`)
|
||||
if err != nil {
|
||||
t.Fatalf("insert summary: %v", err)
|
||||
}
|
||||
|
||||
// FTS should find it — trigram tokenizer requires >= 3 chars
|
||||
rows, err := db.Query(
|
||||
"SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?",
|
||||
"你好世",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("FTS query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var found string
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&found); err != nil {
|
||||
t.Fatalf("scan: %v", err)
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err: %v", err)
|
||||
}
|
||||
if found != "sum_test1" {
|
||||
t.Errorf("FTS: expected 'sum_test1', got %q", found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationSummaryParentsPK(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
|
||||
// Insert two summaries
|
||||
for _, id := range []string{"sum_a", "sum_b"} {
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
|
||||
VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id)
|
||||
if err != nil {
|
||||
t.Fatalf("insert summary %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Link child to parent
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
|
||||
if err != nil {
|
||||
t.Fatalf("link: %v", err)
|
||||
}
|
||||
|
||||
// Duplicate link should fail (composite PK)
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
|
||||
if err == nil {
|
||||
t.Error("expected unique constraint violation for duplicate summary_parents link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFTS5SQLConstants(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
// Verify FTS5 check SQL executes without error
|
||||
_, err := db.Exec(sqlCheckFTS5Available)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCheckFTS5Available failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify trigram check SQL executes without error
|
||||
_, err = db.Exec(sqlCheckTrigramAvailable)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCheckTrigramAvailable failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify summaries_fts SQL executes without error
|
||||
_, err = db.Exec(sqlCreateSummariesFTS)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCreateSummariesFTS failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify messages_fts SQL executes without error
|
||||
_, err = db.Exec(sqlCreateMessagesFTS)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCreateMessagesFTS failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// escapeXML escapes special characters for safe inclusion in XML content.
|
||||
func escapeXML(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
s = strings.ReplaceAll(s, "\"", """)
|
||||
s = strings.ReplaceAll(s, "'", "'")
|
||||
return s
|
||||
}
|
||||
|
||||
// resolvedItem is a context item resolved to its full content with token count.
|
||||
type resolvedItem struct {
|
||||
ordinal int
|
||||
itemType string // "message" or "summary"
|
||||
message *Message
|
||||
summary *Summary
|
||||
tokenCount int
|
||||
}
|
||||
|
||||
// Assemble builds budget-constrained context from summaries + messages.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Fetch context_items, resolve to full content
|
||||
// 2. Split into evictable prefix + protected fresh tail
|
||||
// 3. If evictable fits in remaining budget → include all
|
||||
// 4. Else walk evictable from newest to oldest, keep while fits
|
||||
func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
|
||||
items, err := a.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get context items: %w", err)
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return &AssembleResult{}, nil
|
||||
}
|
||||
|
||||
// Resolve all items
|
||||
resolved := make([]resolvedItem, len(items))
|
||||
for i, item := range items {
|
||||
r, err := a.resolveItem(ctx, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolved[i] = r
|
||||
}
|
||||
|
||||
// Split into evictable prefix and protected fresh tail
|
||||
tailStart := len(resolved) - FreshTailCount
|
||||
if tailStart < 0 {
|
||||
tailStart = 0
|
||||
}
|
||||
evictable := resolved[:tailStart]
|
||||
freshTail := resolved[tailStart:]
|
||||
|
||||
// Calculate fresh tail tokens
|
||||
freshTailTokens := 0
|
||||
for _, r := range freshTail {
|
||||
freshTailTokens += r.tokenCount
|
||||
}
|
||||
|
||||
// Budget-aware selection of evictable items
|
||||
remainingBudget := input.Budget - freshTailTokens
|
||||
if remainingBudget < 0 {
|
||||
// Fresh tail alone exceeds budget - we keep it anyway (design decision)
|
||||
// Log for debugging retry/overflow issues
|
||||
logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
|
||||
"budget": input.Budget,
|
||||
"fresh_tail_tokens": freshTailTokens,
|
||||
"fresh_tail_count": len(freshTail),
|
||||
"over_budget_by": freshTailTokens - input.Budget,
|
||||
})
|
||||
remainingBudget = 0
|
||||
}
|
||||
|
||||
var selected []resolvedItem
|
||||
evictableTokens := 0
|
||||
for _, r := range evictable {
|
||||
evictableTokens += r.tokenCount
|
||||
}
|
||||
|
||||
if evictableTokens <= remainingBudget {
|
||||
// All evictable fit
|
||||
selected = append(selected, evictable...)
|
||||
} else {
|
||||
// Walk from newest to oldest, keep while fits
|
||||
var kept []resolvedItem
|
||||
accum := 0
|
||||
for i := len(evictable) - 1; i >= 0; i-- {
|
||||
if accum+evictable[i].tokenCount <= remainingBudget {
|
||||
kept = append(kept, evictable[i])
|
||||
accum += evictable[i].tokenCount
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Reverse to restore chronological order
|
||||
for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
|
||||
kept[i], kept[j] = kept[j], kept[i]
|
||||
}
|
||||
selected = append(selected, kept...)
|
||||
}
|
||||
|
||||
// Combine: selected evictable + fresh tail
|
||||
final := append(selected, freshTail...)
|
||||
|
||||
// Build result
|
||||
var messages []Message
|
||||
var summaries []Summary
|
||||
var sourceIDs []string
|
||||
totalTokens := 0
|
||||
maxDepth := 0
|
||||
condensedCount := 0
|
||||
|
||||
for _, r := range final {
|
||||
totalTokens += r.tokenCount
|
||||
if r.itemType == "message" && r.message != nil {
|
||||
messages = append(messages, *r.message)
|
||||
sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
|
||||
} else if r.itemType == "summary" && r.summary != nil {
|
||||
summaries = append(summaries, *r.summary)
|
||||
if r.summary.Depth > maxDepth {
|
||||
maxDepth = r.summary.Depth
|
||||
}
|
||||
if r.summary.Kind == SummaryKindCondensed {
|
||||
condensedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build depth-aware system prompt addition
|
||||
systemPromptAddition := ""
|
||||
if len(summaries) > 0 {
|
||||
if maxDepth >= 2 || condensedCount >= 2 {
|
||||
systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
|
||||
"- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
|
||||
"- When uncertain, use expand to recover original detail before making claims.\n" +
|
||||
"- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
|
||||
} else {
|
||||
systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
|
||||
}
|
||||
}
|
||||
|
||||
// Build Summary field: all XML summaries + system prompt addition
|
||||
var summaryParts []string
|
||||
for _, sum := range summaries {
|
||||
if sum.Content == "" {
|
||||
continue
|
||||
}
|
||||
// Load parent IDs for XML formatting
|
||||
parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
|
||||
if err != nil {
|
||||
logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
|
||||
"summary_id": sum.SummaryID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
var parentIDs []string
|
||||
for _, ps := range parentSummaries {
|
||||
parentIDs = append(parentIDs, ps.SummaryID)
|
||||
}
|
||||
summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
|
||||
}
|
||||
summary := strings.Join(summaryParts, "\n\n")
|
||||
if systemPromptAddition != "" {
|
||||
if summary != "" {
|
||||
summary += "\n\n"
|
||||
}
|
||||
summary += systemPromptAddition
|
||||
}
|
||||
|
||||
return &AssembleResult{
|
||||
Messages: messages,
|
||||
Summary: summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveItem loads the full message or summary for a context item.
|
||||
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
|
||||
if item.ItemType == "message" {
|
||||
msg, err := a.store.GetMessageByID(ctx, item.MessageID)
|
||||
if err != nil {
|
||||
return resolvedItem{}, err
|
||||
}
|
||||
tokens := item.TokenCount
|
||||
if tokens == 0 {
|
||||
tokens = msg.TokenCount
|
||||
}
|
||||
return resolvedItem{
|
||||
ordinal: item.Ordinal,
|
||||
itemType: "message",
|
||||
message: msg,
|
||||
tokenCount: tokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if item.ItemType == "summary" {
|
||||
sum, err := a.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
return resolvedItem{}, err
|
||||
}
|
||||
tokens := item.TokenCount
|
||||
if tokens == 0 {
|
||||
tokens = sum.TokenCount
|
||||
}
|
||||
return resolvedItem{
|
||||
ordinal: item.Ordinal,
|
||||
itemType: "summary",
|
||||
summary: sum,
|
||||
tokenCount: tokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return resolvedItem{
|
||||
ordinal: item.Ordinal,
|
||||
itemType: item.ItemType,
|
||||
tokenCount: item.TokenCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FormatSummaryXML formats a summary as XML for LLM context.
|
||||
// This is exported so context managers can format summaries consistently.
|
||||
func FormatSummaryXML(s *Summary, parentIDs []string) string {
|
||||
// Build time attributes if available
|
||||
var attrs string
|
||||
if s.EarliestAt != nil {
|
||||
attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
|
||||
}
|
||||
if s.LatestAt != nil {
|
||||
attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
var parentsSection string
|
||||
if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
|
||||
parents := "<parents>\n"
|
||||
for _, pid := range parentIDs {
|
||||
parents += fmt.Sprintf(" <summary_ref id=\"%s\" />\n", pid)
|
||||
}
|
||||
parents += " </parents>\n"
|
||||
parentsSection = parents
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"<summary id=\"%s\" kind=\"%s\" depth=\"%d\" descendant_count=\"%d\"%s>\n <content>\n %s\n </content>\n%s</summary>",
|
||||
s.SummaryID,
|
||||
string(s.Kind),
|
||||
s.Depth,
|
||||
s.DescendantCount,
|
||||
attrs,
|
||||
escapeXML(s.Content),
|
||||
parentsSection,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,536 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Assembler Tests ---
|
||||
|
||||
// helper: create a store with messages and summaries for assembly tests
|
||||
func setupAssemblerStore(t *testing.T) (*Store, int64) {
|
||||
t.Helper()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
conv, err := s.GetOrCreateConversation(ctx, "test:assemble")
|
||||
if err != nil {
|
||||
t.Fatalf("create conversation: %v", err)
|
||||
}
|
||||
|
||||
return s, conv.ConversationID
|
||||
}
|
||||
|
||||
func TestAssemblerAssembleEmpty(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 0 {
|
||||
t.Errorf("Messages = %d, want 0", len(result.Messages))
|
||||
}
|
||||
if result.Summary != "" {
|
||||
t.Errorf("Summary = %q, want empty", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerAssembleMessagesOnly(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create messages
|
||||
msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
|
||||
msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5)
|
||||
|
||||
// Create context items
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("Messages = %d, want 2", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Content != "hello" {
|
||||
t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content)
|
||||
}
|
||||
if result.Messages[1].Content != "world" {
|
||||
t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content)
|
||||
}
|
||||
// No summaries, so Summary should be empty
|
||||
if result.Summary != "" {
|
||||
t.Errorf("Summary = %q, want empty", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerAssembleWithSummary(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a summary
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "summary of early messages",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
// Create recent messages
|
||||
msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5)
|
||||
msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5)
|
||||
|
||||
// Context: summary + recent messages
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
|
||||
{Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Messages = 2 raw messages (summaries are in Summary field, not Messages)
|
||||
if len(result.Messages) != 2 {
|
||||
t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages))
|
||||
}
|
||||
// Summary should contain XML with summary content
|
||||
if result.Summary == "" {
|
||||
t.Error("Summary should not be empty when summary exists")
|
||||
}
|
||||
if !strings.Contains(result.Summary, summary.Content) {
|
||||
t.Errorf("Summary should contain summary content %q", summary.Content)
|
||||
}
|
||||
if !strings.Contains(result.Summary, "<summary") {
|
||||
t.Error("Summary should contain <summary XML tag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerBudgetEvictsOldest(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 40 messages, each with 10 tokens = 400 total
|
||||
msgs := make([]*Message, 40)
|
||||
for i := 0; i < 40; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
|
||||
msgs[i] = m
|
||||
}
|
||||
|
||||
// Context items for all messages
|
||||
items := make([]ContextItem, 40)
|
||||
for i := 0; i < 40; i++ {
|
||||
items[i] = ContextItem{
|
||||
Ordinal: (i + 1) * 100,
|
||||
ItemType: "message",
|
||||
MessageID: msgs[i].ID,
|
||||
TokenCount: 10,
|
||||
}
|
||||
}
|
||||
s.UpsertContextItems(ctx, convID, items)
|
||||
|
||||
// Budget of 200 tokens with FreshTailCount=32
|
||||
// Fresh tail = last 32 messages (320 tokens, over budget, but always included)
|
||||
// Evictable = first 8 messages (80 tokens)
|
||||
// Budget after tail: max(0, 200-320) = 0 → no evictable items included
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Should only include the 32-item fresh tail
|
||||
if len(result.Messages) != 32 {
|
||||
t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages))
|
||||
}
|
||||
// Should be the LAST 32 messages
|
||||
if result.Messages[0].ID != msgs[8].ID {
|
||||
t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerBudgetFitsAll(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msgs := make([]*Message, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
|
||||
msgs[i] = m
|
||||
}
|
||||
|
||||
items := make([]ContextItem, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
items[i] = ContextItem{
|
||||
Ordinal: (i + 1) * 100,
|
||||
ItemType: "message",
|
||||
MessageID: msgs[i].ID,
|
||||
TokenCount: 10,
|
||||
}
|
||||
}
|
||||
s.UpsertContextItems(ctx, convID, items)
|
||||
|
||||
// Budget = 100, total = 50, FreshTailCount=32 → all items in tail
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 5 {
|
||||
t.Errorf("Messages = %d, want 5", len(result.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLFormat(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "test summary content",
|
||||
TokenCount: 20,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Messages should only contain raw messages (no XML summary in Messages)
|
||||
if len(result.Messages) != 1 {
|
||||
t.Errorf("Messages = %d, want 1 (raw message only)", len(result.Messages))
|
||||
}
|
||||
// Summary should contain XML with summary content
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
if !contains(result.Summary, "<summary") {
|
||||
t.Errorf("Summary missing <summary tag: %q", result.Summary)
|
||||
}
|
||||
if !contains(result.Summary, summary.SummaryID) {
|
||||
t.Errorf("Summary missing summary ID: %q", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLEscaping(t *testing.T) {
|
||||
// Summary content with special XML characters should be properly escaped
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create summary with content containing XML special characters
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: `User said: "hello" & asked about <tags>`,
|
||||
TokenCount: 20,
|
||||
})
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Summary field should contain XML with escaped special characters
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
|
||||
// Check that special characters are escaped
|
||||
if strings.Contains(result.Summary, "<tags>") {
|
||||
t.Errorf("BUG: unescaped < in summary content: %q", result.Summary)
|
||||
}
|
||||
if strings.Contains(result.Summary, `"hello"`) {
|
||||
t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary)
|
||||
}
|
||||
// & should be escaped as &
|
||||
if strings.Contains(result.Summary, " & ") {
|
||||
t.Errorf("BUG: unescaped & in summary content: %q", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLWithParents(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a leaf and a condensed summary (condensed has parent)
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 20,
|
||||
})
|
||||
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: 1,
|
||||
Content: "condensed content",
|
||||
TokenCount: 15,
|
||||
ParentIDs: []string{leaf.SummaryID},
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Summary field should contain XML with parent information
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
xmlContent := result.Summary
|
||||
|
||||
// Should contain <parents> section with parent ID
|
||||
if !contains(xmlContent, "<parents>") {
|
||||
t.Errorf("condensed summary XML missing <parents> section: %q", xmlContent)
|
||||
}
|
||||
if !contains(xmlContent, leaf.SummaryID) {
|
||||
t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent)
|
||||
}
|
||||
|
||||
// Should contain kind="condensed"
|
||||
if !contains(xmlContent, `kind="condensed"`) {
|
||||
t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a leaf summary with specific descendant count
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 20,
|
||||
DescendantCount: 8,
|
||||
DescendantTokenCount: 1200,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
xmlContent := result.Summary
|
||||
|
||||
// Should contain descendant_count="8"
|
||||
if !contains(xmlContent, `descendant_count="8"`) {
|
||||
t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerLeafSummaryNoParents(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Leaf summary has no parents
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 20,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
xmlContent := result.Summary
|
||||
|
||||
// Leaf summary should NOT have <parents> section
|
||||
if contains(xmlContent, "<parents>") {
|
||||
t.Errorf("leaf summary XML should not have <parents> section: %q", xmlContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerDepthAwarePrompt(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a condensed summary (depth >= 2) to trigger full guidance
|
||||
now := time.Now().UTC()
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf summary",
|
||||
TokenCount: 20,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: 2,
|
||||
Content: "condensed summary",
|
||||
TokenCount: 15,
|
||||
ParentIDs: []string{leaf.SummaryID},
|
||||
DescendantCount: 1,
|
||||
DescendantTokenCount: 20,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Should have a depth-aware prompt in Summary field
|
||||
if result.Summary == "" {
|
||||
t.Error("expected non-empty Summary when depth >= 2")
|
||||
}
|
||||
// SystemPromptAddition is embedded in Summary field
|
||||
if !strings.Contains(result.Summary, "multi-level summarization") {
|
||||
t.Error("Summary should contain system prompt addition about multi-level summarization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) {
|
||||
// Spec: condensed summaries use <summary_ref id="parentId" /> not <parent>parentId</parent>
|
||||
now := time.Now().UTC()
|
||||
s := Summary{
|
||||
SummaryID: "sum_condensed1",
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: 1,
|
||||
Content: "condensed content",
|
||||
TokenCount: 50,
|
||||
DescendantCount: 2,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
}
|
||||
parentIDs := []string{"sum_leaf1", "sum_leaf2"}
|
||||
|
||||
xml := FormatSummaryXML(&s, parentIDs)
|
||||
|
||||
// Must use <summary_ref id="..." /> per spec
|
||||
if !contains(xml, `<summary_ref id="sum_leaf1" />`) {
|
||||
t.Errorf("expected <summary_ref id=\"sum_leaf1\" />, got: %s", xml)
|
||||
}
|
||||
if !contains(xml, `<summary_ref id="sum_leaf2" />`) {
|
||||
t.Errorf("expected <summary_ref id=\"sum_leaf2\" />, got: %s", xml)
|
||||
}
|
||||
// Must NOT use old <parent> tag
|
||||
if contains(xml, "<parent>") {
|
||||
t.Errorf("should not use <parent> tag, got: %s", xml)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) {
|
||||
// Spec: summary XML includes earliest_at and latest_at attributes
|
||||
earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC)
|
||||
latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC)
|
||||
s := Summary{
|
||||
SummaryID: "sum_leaf1",
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 30,
|
||||
DescendantCount: 0,
|
||||
EarliestAt: &earliest,
|
||||
LatestAt: &latest,
|
||||
}
|
||||
|
||||
xml := FormatSummaryXML(&s, nil)
|
||||
|
||||
if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) {
|
||||
t.Errorf("missing earliest_at attribute, got: %s", xml)
|
||||
}
|
||||
if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) {
|
||||
t.Errorf("missing latest_at attribute, got: %s", xml)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) {
|
||||
// When EarliestAt/LatestAt are nil, attributes should be omitted
|
||||
s := Summary{
|
||||
SummaryID: "sum_leaf1",
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 30,
|
||||
DescendantCount: 0,
|
||||
}
|
||||
|
||||
xml := FormatSummaryXML(&s, nil)
|
||||
|
||||
if contains(xml, "earliest_at=") {
|
||||
t.Errorf("should not have earliest_at when nil, got: %s", xml)
|
||||
}
|
||||
if contains(xml, "latest_at=") {
|
||||
t.Errorf("should not have latest_at when nil, got: %s", xml)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,336 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// newBenchStore creates a test store for benchmarks.
|
||||
func newBenchStore(b *testing.B) (*Store, func()) {
|
||||
b.Helper()
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
b.Fatalf("open test db: %v", err)
|
||||
}
|
||||
if err := runSchema(db); err != nil {
|
||||
db.Close()
|
||||
b.Fatalf("migration: %v", err)
|
||||
}
|
||||
return &Store{db: db}, func() { db.Close() }
|
||||
}
|
||||
|
||||
// --- Ingest benchmarks ---
|
||||
|
||||
func BenchmarkIngest_SingleMessage(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest")
|
||||
convID := conv.ConversationID
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.AddMessage(ctx, convID, "user", "Test message content", 15)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIngest_BatchMessages(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i))
|
||||
convID := conv.ConversationID
|
||||
|
||||
for j := 0; j < 10; j++ {
|
||||
added, err := s.AddMessage(ctx, convID, "user",
|
||||
fmt.Sprintf("Message %d in batch", j), 10)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
s.AppendContextMessage(ctx, convID, added.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Assemble benchmarks ---
|
||||
|
||||
func BenchmarkAssemble_MessagesOnly(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add 100 messages
|
||||
for i := 0; i < 100; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user",
|
||||
fmt.Sprintf("Message content %d with some text", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
a := &Assembler{store: s}
|
||||
input := AssembleInput{Budget: 50000}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := a.Assemble(ctx, convID, input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAssemble_WithSummaries(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums")
|
||||
convID := conv.ConversationID
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Add 10 leaf summaries
|
||||
for i := 0; i < 10; i++ {
|
||||
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("Leaf summary %d", i),
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||
}
|
||||
|
||||
// Add 20 fresh messages
|
||||
for i := 0; i < 20; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
a := &Assembler{store: s}
|
||||
input := AssembleInput{Budget: 10000}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := a.Assemble(ctx, convID, input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAssemble_BudgetEviction(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict")
|
||||
convID := conv.ConversationID
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Add 50 leaf summaries (more than budget can hold)
|
||||
for i := 0; i < 50; i++ {
|
||||
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("Summary %d", i),
|
||||
TokenCount: 300,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
a := &Assembler{store: s}
|
||||
input := AssembleInput{Budget: 5000} // Force eviction
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := a.Assemble(ctx, convID, input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Search (FTS5) benchmarks ---
|
||||
|
||||
// benchSeedSummaries adds n summaries to a conversation for search benchmarks.
|
||||
func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) {
|
||||
b.Helper()
|
||||
now := time.Now().UTC()
|
||||
for i := 0; i < n; i++ {
|
||||
sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf(contentTpl, i),
|
||||
TokenCount: 200,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("create summary: %v", err)
|
||||
}
|
||||
s.AppendContextSummary(context.Background(), convID, sum.SummaryID)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSearchSummaries_FTS5(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts")
|
||||
convID := conv.ConversationID
|
||||
|
||||
benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.SearchSummaries(ctx, SearchInput{
|
||||
Pattern: "database",
|
||||
Mode: "full_text",
|
||||
ConversationID: convID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSearchSummaries_Like(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like")
|
||||
convID := conv.ConversationID
|
||||
|
||||
benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.SearchSummaries(ctx, SearchInput{
|
||||
Pattern: "config",
|
||||
Mode: "like",
|
||||
ConversationID: convID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSearchMessages_FTS5(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add 500 messages
|
||||
for i := 0; i < 500; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user",
|
||||
fmt.Sprintf("User message about API and database integration %d", i), 20)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.SearchMessages(ctx, SearchInput{
|
||||
Pattern: "API database",
|
||||
Mode: "full_text",
|
||||
ConversationID: convID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Bootstrap benchmarks ---
|
||||
|
||||
func BenchmarkBootstrap_Empty(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i))
|
||||
convID := conv.ConversationID
|
||||
_ = convID // Bootstrap with empty history
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBootstrap_100Messages(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare 100 messages
|
||||
msgs := make([]Message, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
msgs[i] = Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Bootstrap message %d", i),
|
||||
TokenCount: 15,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i))
|
||||
convID := conv.ConversationID
|
||||
|
||||
for _, m := range msgs {
|
||||
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
|
||||
s.AppendContextMessage(ctx, convID, added.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBootstrap_500Messages(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
msgs := make([]Message, 500)
|
||||
for i := 0; i < 500; i++ {
|
||||
msgs[i] = Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Bootstrap message %d", i),
|
||||
TokenCount: 15,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i))
|
||||
convID := conv.ConversationID
|
||||
|
||||
for _, m := range msgs {
|
||||
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
|
||||
s.AppendContextMessage(ctx, convID, added.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,898 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// CompactInput controls compaction behavior.
|
||||
type CompactInput struct {
|
||||
Budget *int // Token budget override
|
||||
Force bool // Force compaction even if below threshold
|
||||
}
|
||||
|
||||
// CompactResult describes what was compacted.
|
||||
type CompactResult struct {
|
||||
SummariesCreated []string `json:"summariesCreated"`
|
||||
TokensSaved int `json:"tokensSaved"`
|
||||
LeafSummaries int `json:"leafSummaries"`
|
||||
CondensedSummaries int `json:"condensedSummaries"`
|
||||
}
|
||||
|
||||
// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow.
|
||||
func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) {
|
||||
tokens, err := e.store.GetContextTokenCount(ctx, convID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get token count: %w", err)
|
||||
}
|
||||
threshold := int(float64(contextWindow) * ContextThreshold)
|
||||
return tokens >= threshold, nil
|
||||
}
|
||||
|
||||
// Close cancels the shutdown context, stopping async goroutines.
|
||||
func (e *CompactionEngine) Close() {
|
||||
if e.shutdownCancel != nil {
|
||||
e.shutdownCancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Compact runs leaf compaction (sync) and optionally condensed compaction.
|
||||
func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) {
|
||||
result := &CompactResult{}
|
||||
|
||||
// Phase 1: leaf compaction (synchronous, every turn)
|
||||
summaryID, err := e.compactLeaf(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compact leaf: %w", err)
|
||||
}
|
||||
if summaryID != nil {
|
||||
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
|
||||
result.LeafSummaries++
|
||||
logger.InfoCF("seahorse", "compact: leaf", map[string]any{
|
||||
"conv_id": convID,
|
||||
"summary_id": *summaryID,
|
||||
})
|
||||
}
|
||||
|
||||
// Phase 2: condensed compaction if over threshold
|
||||
tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
var budget int
|
||||
if input.Budget != nil {
|
||||
budget = *input.Budget
|
||||
if budget == 0 {
|
||||
logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{
|
||||
"conv_id": convID,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
budget = int(float64(tokensBefore) * ContextThreshold)
|
||||
}
|
||||
|
||||
if input.Force || (tokensBefore > budget && budget > 0) {
|
||||
// Launch async condensed compaction with dedup
|
||||
if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded {
|
||||
go func() {
|
||||
defer e.condensing.Delete(convID)
|
||||
e.runCondensedLoop(e.shutdownCtx, convID)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
if tokensAfter < tokensBefore {
|
||||
result.TokensSaved = tokensBefore - tokensAfter
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CompactUntilUnder aggressively compacts until context is under budget.
|
||||
func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) {
|
||||
result := &CompactResult{}
|
||||
prevTokens := 0
|
||||
logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget})
|
||||
|
||||
for iter := 0; iter < MaxCompactIterations; iter++ {
|
||||
tokens, err := e.store.GetContextTokenCount(ctx, convID)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("get tokens: %w", err)
|
||||
}
|
||||
if tokens <= budget {
|
||||
logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{
|
||||
"conv_id": convID,
|
||||
"budget": budget,
|
||||
"tokens": tokens,
|
||||
"leaf": result.LeafSummaries,
|
||||
"condensed": result.CondensedSummaries,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Try leaf first
|
||||
summaryID, err := e.compactLeaf(ctx, convID, true)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if summaryID != nil {
|
||||
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
|
||||
result.LeafSummaries++
|
||||
logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{
|
||||
"conv_id": convID,
|
||||
"summary_id": *summaryID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Try condensed with forced fanout
|
||||
condensedID, err := e.compactCondensed(ctx, convID)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if condensedID != nil {
|
||||
result.SummariesCreated = append(result.SummariesCreated, *condensedID)
|
||||
result.CondensedSummaries++
|
||||
logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{
|
||||
"conv_id": convID,
|
||||
"summary_id": *condensedID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// No progress
|
||||
newTokens, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
if newTokens >= prevTokens {
|
||||
logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{
|
||||
"conv_id": convID,
|
||||
"tokens": newTokens,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
prevTokens = newTokens
|
||||
}
|
||||
|
||||
// Safety cap exceeded — see MaxCompactIterations doc for rationale.
|
||||
logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{
|
||||
"conv_id": convID,
|
||||
"budget": budget,
|
||||
"iterations": MaxCompactIterations,
|
||||
"tokens": prevTokens,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// compactLeaf compresses the oldest contiguous message chunk into a leaf summary.
|
||||
// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder).
|
||||
func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) {
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find oldest contiguous message chunk outside fresh tail
|
||||
msgCount := 0
|
||||
msgTokens := 0
|
||||
for _, item := range items {
|
||||
if item.ItemType == "message" {
|
||||
msgCount++
|
||||
msgTokens += item.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger if either message count or token threshold is met
|
||||
if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Calculate fresh tail boundary (bypass when forced)
|
||||
useForce := len(force) > 0 && force[0]
|
||||
tailStartIdx := len(items) - FreshTailCount
|
||||
if useForce {
|
||||
tailStartIdx = len(items) // allow compacting everything
|
||||
}
|
||||
if tailStartIdx < 0 {
|
||||
tailStartIdx = 0
|
||||
}
|
||||
|
||||
// Find oldest contiguous message chunk, accumulating up to LeafChunkTokens
|
||||
var chunk []ContextItem
|
||||
chunkStart := -1
|
||||
chunkEnd := -1
|
||||
accumTokens := 0
|
||||
for i := 0; i < tailStartIdx; i++ {
|
||||
if items[i].ItemType == "message" {
|
||||
if chunkStart == -1 {
|
||||
chunkStart = i
|
||||
}
|
||||
chunkEnd = i
|
||||
accumTokens += items[i].TokenCount
|
||||
// Stop accumulating once we reach the token budget
|
||||
if accumTokens >= LeafChunkTokens {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// Non-message breaks the chunk
|
||||
if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout {
|
||||
break
|
||||
}
|
||||
chunkStart = -1
|
||||
chunkEnd = -1
|
||||
accumTokens = 0
|
||||
}
|
||||
}
|
||||
|
||||
if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chunk = items[chunkStart : chunkEnd+1]
|
||||
|
||||
// Collect messages for the chunk
|
||||
var messages []Message
|
||||
for _, item := range chunk {
|
||||
msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID)
|
||||
if innerErr != nil {
|
||||
return nil, innerErr
|
||||
}
|
||||
messages = append(messages, *msg)
|
||||
}
|
||||
|
||||
// Get prior summaries for context
|
||||
priorSummary := ""
|
||||
priorCount := 0
|
||||
for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- {
|
||||
if items[i].ItemType == "summary" {
|
||||
sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID)
|
||||
if innerErr2 == nil {
|
||||
priorSummary = sum.Content + "\n" + priorSummary
|
||||
priorCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate summary
|
||||
content, err := e.generateLeafSummary(ctx, messages, priorSummary)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create summary in store
|
||||
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
|
||||
|
||||
var earliestAt, latestAt *time.Time
|
||||
if len(messages) > 0 {
|
||||
earliestAt = &messages[0].CreatedAt
|
||||
latestAt = &messages[len(messages)-1].CreatedAt
|
||||
}
|
||||
|
||||
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: content,
|
||||
TokenCount: tokenCount,
|
||||
EarliestAt: earliestAt,
|
||||
LatestAt: latestAt,
|
||||
SourceMessageTokens: sumMessageTokens(messages),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Link to source messages
|
||||
msgIDs := make([]int64, len(messages))
|
||||
for i, m := range messages {
|
||||
msgIDs[i] = m.ID
|
||||
}
|
||||
if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Replace context range with summary
|
||||
if err := e.store.ReplaceContextRangeWithSummary(
|
||||
ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &summary.SummaryID, nil
|
||||
}
|
||||
|
||||
// compactCondensed compresses multiple summaries into one higher-level summary.
|
||||
func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) {
|
||||
// Try ordinal-aware selection first (respects consecutive ordering)
|
||||
var candidates []Summary
|
||||
|
||||
depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, depth := range depths {
|
||||
var chunkAtDepth []Summary
|
||||
var err2 error
|
||||
chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth)
|
||||
if err2 != nil {
|
||||
continue
|
||||
}
|
||||
if len(chunkAtDepth) > 0 {
|
||||
candidates = chunkAtDepth
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to depth-grouping selection
|
||||
if len(candidates) == 0 {
|
||||
candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Generate condensed summary
|
||||
content, err := e.generateCondensedSummary(ctx, candidates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Merge metadata
|
||||
maxDepth := 0
|
||||
descendantCount := 0
|
||||
descendantTokenCount := 0
|
||||
sourceMessageTokens := 0
|
||||
var earliestAt, latestAt *time.Time
|
||||
|
||||
parentIDs := make([]string, len(candidates))
|
||||
for i, c := range candidates {
|
||||
parentIDs[i] = c.SummaryID
|
||||
if c.Depth > maxDepth {
|
||||
maxDepth = c.Depth
|
||||
}
|
||||
descendantCount += c.DescendantCount + 1
|
||||
descendantTokenCount += c.TokenCount + c.DescendantTokenCount
|
||||
sourceMessageTokens += c.SourceMessageTokenCount
|
||||
if c.EarliestAt != nil {
|
||||
if earliestAt == nil || c.EarliestAt.Before(*earliestAt) {
|
||||
earliestAt = c.EarliestAt
|
||||
}
|
||||
}
|
||||
if c.LatestAt != nil {
|
||||
if latestAt == nil || c.LatestAt.After(*latestAt) {
|
||||
latestAt = c.LatestAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
|
||||
|
||||
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: maxDepth + 1,
|
||||
Content: content,
|
||||
TokenCount: tokenCount,
|
||||
EarliestAt: earliestAt,
|
||||
LatestAt: latestAt,
|
||||
DescendantCount: descendantCount,
|
||||
DescendantTokenCount: descendantTokenCount,
|
||||
SourceMessageTokens: sourceMessageTokens,
|
||||
ParentIDs: parentIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find the ordinal range for the candidate summaries in context
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidateSet := make(map[string]bool)
|
||||
for _, c := range candidates {
|
||||
candidateSet[c.SummaryID] = true
|
||||
}
|
||||
|
||||
startOrd := -1
|
||||
endOrd := -1
|
||||
hasNonCandidate := false
|
||||
for _, item := range items {
|
||||
if item.ItemType == "summary" && candidateSet[item.SummaryID] {
|
||||
if startOrd == -1 {
|
||||
startOrd, endOrd = item.Ordinal, item.Ordinal
|
||||
} else {
|
||||
// Check for non-candidate items between endOrd and current ordinal
|
||||
for _, it := range items {
|
||||
if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal {
|
||||
if it.ItemType != "summary" || !candidateSet[it.SummaryID] {
|
||||
hasNonCandidate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasNonCandidate {
|
||||
break
|
||||
}
|
||||
if item.Ordinal < startOrd {
|
||||
startOrd = item.Ordinal
|
||||
}
|
||||
if item.Ordinal > endOrd {
|
||||
endOrd = item.Ordinal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if startOrd == -1 || endOrd == -1 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Collect candidate summary IDs
|
||||
candidateIDs := make([]string, 0, len(candidates))
|
||||
for _, c := range candidates {
|
||||
candidateIDs = append(candidateIDs, c.SummaryID)
|
||||
}
|
||||
|
||||
if hasNonCandidate {
|
||||
// Use safe per-item deletion to avoid deleting non-candidate items
|
||||
if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Candidates are consecutive, use efficient range deletion
|
||||
if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &summary.SummaryID, nil
|
||||
}
|
||||
|
||||
// selectShallowestCondensationCandidate finds the shallowest consecutive summary group.
|
||||
func (e *CompactionEngine) selectShallowestCondensationCandidate(
|
||||
ctx context.Context, convID int64, forced bool,
|
||||
) ([]Summary, error) {
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Group by depth, find consecutive runs
|
||||
tailStartIdx := len(items) - FreshTailCount
|
||||
if tailStartIdx < 0 {
|
||||
tailStartIdx = 0
|
||||
}
|
||||
|
||||
minFanout := CondensedMinFanout
|
||||
if forced {
|
||||
minFanout = CondensedMinFanoutHard
|
||||
}
|
||||
|
||||
// Track depth groups
|
||||
depthGroups := make(map[int][]ContextItem)
|
||||
for i := 0; i < tailStartIdx; i++ {
|
||||
item := items[i]
|
||||
if item.ItemType != "summary" {
|
||||
continue
|
||||
}
|
||||
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item)
|
||||
}
|
||||
|
||||
// Find shallowest depth with enough candidates
|
||||
// Collect all depths and sort to handle non-consecutive depths
|
||||
var depths []int
|
||||
for depth := range depthGroups {
|
||||
depths = append(depths, depth)
|
||||
}
|
||||
sort.Ints(depths)
|
||||
|
||||
for _, depth := range depths {
|
||||
group := depthGroups[depth]
|
||||
if len(group) >= minFanout {
|
||||
// Load summaries
|
||||
var result []Summary
|
||||
for _, item := range group[:minFanout] {
|
||||
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, *sum)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive
|
||||
// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or
|
||||
// token overflow. Returns contiguous chunk of summaries.
|
||||
func (e *CompactionEngine) selectOldestChunkAtDepth(
|
||||
ctx context.Context, convID int64, targetDepth int,
|
||||
) ([]Summary, error) {
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tailStartIdx := len(items) - FreshTailCount
|
||||
if tailStartIdx < 0 {
|
||||
tailStartIdx = 0
|
||||
}
|
||||
|
||||
var chunk []Summary
|
||||
accumTokens := 0
|
||||
|
||||
for i := 0; i < tailStartIdx; i++ {
|
||||
item := items[i]
|
||||
if item.ItemType != "summary" {
|
||||
// Non-summary breaks the chunk
|
||||
break
|
||||
}
|
||||
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if sum.Depth != targetDepth {
|
||||
// Different depth breaks the chunk
|
||||
break
|
||||
}
|
||||
if accumTokens+sum.TokenCount > LeafChunkTokens {
|
||||
// Token overflow stops collection
|
||||
break
|
||||
}
|
||||
chunk = append(chunk, *sum)
|
||||
accumTokens += sum.TokenCount
|
||||
}
|
||||
|
||||
// Min tokens check: spec line 808
|
||||
// chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000
|
||||
minTokens := CondensedTargetTokens // 2000
|
||||
if accumTokens < minTokens {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation.
|
||||
// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation.
|
||||
func (e *CompactionEngine) generateLeafSummary(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
previousSummary string,
|
||||
) (string, error) {
|
||||
if e.complete == nil {
|
||||
return truncateSummary(messages), nil
|
||||
}
|
||||
|
||||
sourceText := formatMessagesForSummary(messages)
|
||||
inputTokens := sumMessageTokens(messages)
|
||||
targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35))
|
||||
|
||||
// Level 1: normal prompt
|
||||
prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens)
|
||||
content, err := e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: LeafTargetTokens * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content == "" {
|
||||
// Retry with temperature=0
|
||||
content, err = e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: LeafTargetTokens * 2,
|
||||
Temperature: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if level 1 succeeded
|
||||
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 2: aggressive prompt
|
||||
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
|
||||
aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget)
|
||||
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||
MaxTokens: aggressiveTarget * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content == "" {
|
||||
// Retry with temperature=0
|
||||
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||
MaxTokens: aggressiveTarget * 2,
|
||||
Temperature: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 3: deterministic truncation
|
||||
return truncateSummary(messages), nil
|
||||
}
|
||||
|
||||
// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation.
|
||||
func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) {
|
||||
if e.complete == nil {
|
||||
return truncateCondensedSummaries(summaries), nil
|
||||
}
|
||||
|
||||
sourceText := formatSummariesForCondensation(summaries)
|
||||
inputTokens := sumSummaryTokens(summaries)
|
||||
targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35))
|
||||
|
||||
// Level 1: normal prompt
|
||||
prompt := buildCondensedSummaryPrompt(sourceText, targetTokens)
|
||||
content, err := e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: CondensedTargetTokens * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content == "" {
|
||||
content, err = e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: CondensedTargetTokens * 2,
|
||||
Temperature: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if content != "" {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 2: aggressive prompt
|
||||
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
|
||||
aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget)
|
||||
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||
MaxTokens: aggressiveTarget * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content != "" {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 3: deterministic fallback
|
||||
return truncateCondensedSummaries(summaries), nil
|
||||
}
|
||||
|
||||
// runCondensedLoop runs condensed compaction in a loop until:
|
||||
// a) context tokens <= threshold (success), OR
|
||||
// b) No candidate found (nothing to condense), OR
|
||||
// c) tokensAfter >= tokensBefore (no progress this iteration), OR
|
||||
// d) tokensAfter >= previousTokens (no improvement over last iteration)
|
||||
func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) {
|
||||
var prevTokens int
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
tokensBefore, err := e.store.GetContextTokenCount(ctx, convID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
condensedID, err := e.compactCondensed(ctx, convID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if condensedID == nil {
|
||||
// No candidate found
|
||||
logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID})
|
||||
return
|
||||
}
|
||||
|
||||
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
|
||||
if tokensAfter >= tokensBefore {
|
||||
// No progress this iteration
|
||||
logger.DebugCF(
|
||||
"seahorse",
|
||||
"condensed: no progress",
|
||||
map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter},
|
||||
)
|
||||
return
|
||||
}
|
||||
if tokensAfter >= prevTokens && prevTokens > 0 {
|
||||
// No improvement over last iteration
|
||||
logger.DebugCF(
|
||||
"seahorse",
|
||||
"condensed: no improvement",
|
||||
map[string]any{"conv_id": convID, "tokens": tokensAfter},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
prevTokens = tokensAfter
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
func formatMessagesForSummary(messages []Message) string {
|
||||
var result string
|
||||
for _, m := range messages {
|
||||
ts := m.CreatedAt.Format("2006-01-02 15:04 MST")
|
||||
content := m.Content
|
||||
if content == "" && len(m.Parts) > 0 {
|
||||
content = partsToReadableContent(m.Parts)
|
||||
}
|
||||
result += fmt.Sprintf("[%s]\n%s\n\n", ts, content)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func formatSummariesForCondensation(summaries []Summary) string {
|
||||
var result string
|
||||
for _, s := range summaries {
|
||||
earliest := ""
|
||||
if s.EarliestAt != nil {
|
||||
earliest = s.EarliestAt.Format("2006-01-02")
|
||||
}
|
||||
latest := ""
|
||||
if s.LatestAt != nil {
|
||||
latest = s.LatestAt.Format("2006-01-02")
|
||||
}
|
||||
result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
|
||||
prev := "(none)"
|
||||
if previousSummary != "" {
|
||||
prev = previousSummary
|
||||
}
|
||||
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
|
||||
Treat this as incremental memory compaction input, not a full-conversation summary.
|
||||
|
||||
Normal summary policy:
|
||||
- Preserve key decisions, rationale, constraints, and active tasks.
|
||||
- Keep essential technical details needed to continue work safely.
|
||||
- Remove obvious repetition and conversational filler.
|
||||
|
||||
Output requirements:
|
||||
- Plain text only.
|
||||
- No preamble, headings, or markdown formatting.
|
||||
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
|
||||
- If no file operations appear, include exactly: "Files: none".
|
||||
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
|
||||
- Target length: about %d tokens or less.
|
||||
|
||||
<previous_context>
|
||||
%s
|
||||
</previous_context>
|
||||
|
||||
<conversation_segment>
|
||||
%s
|
||||
</conversation_segment>`, targetTokens, prev, sourceText)
|
||||
}
|
||||
|
||||
func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string {
|
||||
return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary.
|
||||
Preserve all important decisions, constraints, and outcomes.
|
||||
Merge overlapping topics. Keep technical details intact.
|
||||
|
||||
Output requirements:
|
||||
- Plain text only.
|
||||
- No preamble, headings, or markdown formatting.
|
||||
- End with exactly: "Expand for details about: <comma-separated list>".
|
||||
- Target length: about %d tokens or less.
|
||||
|
||||
<summaries>
|
||||
%s
|
||||
</summaries>`, targetTokens, sourceText)
|
||||
}
|
||||
|
||||
func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
|
||||
prev := "(none)"
|
||||
if previousSummary != "" {
|
||||
prev = previousSummary
|
||||
}
|
||||
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
|
||||
Aggressive summary policy:
|
||||
- Keep only durable facts and current task state.
|
||||
- Remove examples, repetition, and low-value narrative details.
|
||||
- Preserve explicit TODOs, blockers, decisions, and constraints.
|
||||
|
||||
Output requirements:
|
||||
- Plain text only.
|
||||
- No preamble, headings, or markdown formatting.
|
||||
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
|
||||
- If no file operations appear, include exactly: "Files: none".
|
||||
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
|
||||
- Target length: about %d tokens or less.
|
||||
|
||||
<previous_context>
|
||||
%s
|
||||
</previous_context>
|
||||
|
||||
<conversation_segment>
|
||||
%s
|
||||
</conversation_segment>`, targetTokens, prev, sourceText)
|
||||
}
|
||||
|
||||
func truncateSummary(messages []Message) string {
|
||||
content := ""
|
||||
for _, m := range messages {
|
||||
c := m.Content
|
||||
if c == "" && len(m.Parts) > 0 {
|
||||
c = partsToReadableContent(m.Parts)
|
||||
}
|
||||
content += c + "\n"
|
||||
}
|
||||
if len(content) > 2048 {
|
||||
content = content[:2048]
|
||||
}
|
||||
content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages))
|
||||
return content
|
||||
}
|
||||
|
||||
func truncateCondensedSummaries(summaries []Summary) string {
|
||||
content := ""
|
||||
for _, s := range summaries {
|
||||
content += s.Content + "\n"
|
||||
}
|
||||
if len(content) > 2048 {
|
||||
content = content[:2048]
|
||||
}
|
||||
content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries))
|
||||
return content
|
||||
}
|
||||
|
||||
func sumMessageTokens(messages []Message) int {
|
||||
total := 0
|
||||
for _, m := range messages {
|
||||
total += m.TokenCount
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func sumSummaryTokens(summaries []Summary) int {
|
||||
total := 0
|
||||
for _, s := range summaries {
|
||||
total += s.TokenCount
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -0,0 +1,974 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Test Helpers ---
|
||||
|
||||
// waitForCondensed blocks until the async condensed goroutine for convID finishes.
|
||||
// Returns false if timeout is reached.
|
||||
func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, exists := ce.condensing.Load(convID); !exists {
|
||||
return true
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// --- Compaction Tests ---
|
||||
|
||||
func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) {
|
||||
t.Helper()
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
s := &Store{db: db}
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:compact")
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
ce := &CompactionEngine{
|
||||
store: s,
|
||||
config: Config{},
|
||||
complete: mockCompleteFn,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
}
|
||||
convID := conv.ConversationID
|
||||
// Ensure async goroutines are stopped before database is closed.
|
||||
// Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close().
|
||||
t.Cleanup(func() {
|
||||
shutdownCancel()
|
||||
// Wait for async condensed goroutine to finish (poll condensing map)
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, exists := ce.condensing.Load(convID); !exists {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
return ce, s, conv.ConversationID
|
||||
}
|
||||
|
||||
// newTestCompactionEngineWithStore creates a CompactionEngine with existing store.
|
||||
// Note: Caller is responsible for calling shutdownCancel when test ends.
|
||||
func newTestCompactionEngineWithStore(
|
||||
s *Store, complete CompleteFn,
|
||||
) (ce *CompactionEngine, shutdownCancel context.CancelFunc) {
|
||||
shutdownCtx, cancel := context.WithCancel(context.Background())
|
||||
return &CompactionEngine{
|
||||
store: s,
|
||||
config: Config{},
|
||||
complete: complete,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: cancel,
|
||||
}, cancel
|
||||
}
|
||||
|
||||
// mockCompleteFn returns a simple summary for testing
|
||||
var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "Mock summary of the conversation segment.", nil
|
||||
}
|
||||
|
||||
func TestNeedsCompaction(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty context — no compaction needed
|
||||
needed, err := ce.NeedsCompaction(ctx, convID, 10000)
|
||||
if err != nil {
|
||||
t.Fatalf("NeedsCompaction: %v", err)
|
||||
}
|
||||
if needed {
|
||||
t.Error("expected no compaction for empty context")
|
||||
}
|
||||
|
||||
// Add messages to context, total tokens = 8000
|
||||
for i := 0; i < 8; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction
|
||||
needed, err = ce.NeedsCompaction(ctx, convID, 10000)
|
||||
if err != nil {
|
||||
t.Fatalf("NeedsCompaction: %v", err)
|
||||
}
|
||||
if !needed {
|
||||
t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)")
|
||||
}
|
||||
|
||||
// Below threshold: 5000 / 10000 → no compaction
|
||||
s.UpsertContextItems(ctx, convID, nil) // clear
|
||||
for i := 0; i < 5; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "test", 1000)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
needed, _ = ce.NeedsCompaction(ctx, convID, 10000)
|
||||
if needed {
|
||||
t.Error("expected no compaction at 5000/10000 tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeaf(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create enough messages to trigger leaf compaction:
|
||||
// Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous
|
||||
for i := 0; i < 40; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Compact
|
||||
result, err := ce.Compact(ctx, convID, CompactInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Should have created at least one leaf summary
|
||||
if result.LeafSummaries == 0 {
|
||||
t.Error("expected at least 1 leaf summary")
|
||||
}
|
||||
|
||||
// Context should now contain a summary item
|
||||
items, _ := s.GetContextItems(ctx, convID)
|
||||
foundSummary := false
|
||||
for _, item := range items {
|
||||
if item.ItemType == "summary" {
|
||||
foundSummary = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundSummary {
|
||||
t.Error("expected a summary in context_items after leaf compaction")
|
||||
}
|
||||
|
||||
// Some messages should have been replaced
|
||||
if len(result.SummariesCreated) == 0 {
|
||||
t.Error("expected at least 1 summary created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeafNoCandidate(t *testing.T) {
|
||||
ce, _, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Too few messages to trigger leaf compaction
|
||||
m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10)
|
||||
ce.store.AppendContextMessage(ctx, convID, m.ID)
|
||||
|
||||
result, err := ce.Compact(ctx, convID, CompactInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result even with no candidate")
|
||||
}
|
||||
if result.LeafSummaries != 0 {
|
||||
t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensed(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create enough leaf summaries and fresh messages to enable condensation
|
||||
leafIDs := make([]string, CondensedMinFanout)
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, err := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf summary content " + time.Now().String(),
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSummary %d: %v", i, err)
|
||||
}
|
||||
leafIDs[i] = summary.SummaryID
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add enough fresh messages to have a fresh tail (>= FreshTailCount)
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Compact with force to trigger condensation
|
||||
_, err := ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
// Wait for async condensed goroutine to complete
|
||||
if !waitForCondensed(ce, convID, 2*time.Second) {
|
||||
t.Fatal("timeout waiting for condensed compaction")
|
||||
}
|
||||
|
||||
// Should have created a condensed summary in the DB
|
||||
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||
foundCondensed := false
|
||||
for _, sum := range summaries {
|
||||
if sum.Kind == SummaryKindCondensed {
|
||||
foundCondensed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundCondensed {
|
||||
t.Error("expected at least 1 condensed summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) {
|
||||
// Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed
|
||||
// from context_items between candidate selection and ordinal range scan.
|
||||
// Use a slow CompleteFn with barrier sync to control timing.
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Create leaf summaries with enough tokens for condensation
|
||||
var leafIDs []string
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
sum, err := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf summary %d", i),
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSummary: %v", err)
|
||||
}
|
||||
leafIDs = append(leafIDs, sum.SummaryID)
|
||||
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail so leaf summaries are in evictable range
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Barrier: CompleteFn waits until test removes context_items, then returns
|
||||
var barrier1, barrier2 sync.WaitGroup
|
||||
barrier1.Add(1) // CompleteFn signals when called
|
||||
barrier2.Add(1) // test signals when context_items removed
|
||||
|
||||
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
barrier1.Done() // signal: LLM called, candidates selected
|
||||
barrier2.Wait() // wait: test removes context_items
|
||||
return "Condensed summary.", nil
|
||||
}
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
// Run compactCondensed in background
|
||||
type compactResult struct {
|
||||
summaryID *string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan compactResult, 1)
|
||||
go func() {
|
||||
sid, err := ce.compactCondensed(context.Background(), convID)
|
||||
resultCh <- compactResult{summaryID: sid, err: err}
|
||||
}()
|
||||
|
||||
// Wait for CompleteFn to be called (candidates selected)
|
||||
barrier1.Wait()
|
||||
|
||||
// Remove leaf summaries from context_items (simulating concurrent replacement)
|
||||
items, _ := s.GetContextItems(ctx, convID)
|
||||
var preserved []ContextItem
|
||||
for _, item := range items {
|
||||
isLeaf := false
|
||||
for _, lid := range leafIDs {
|
||||
if item.SummaryID == lid {
|
||||
isLeaf = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLeaf {
|
||||
preserved = append(preserved, item)
|
||||
}
|
||||
}
|
||||
s.UpsertContextItems(ctx, convID, preserved)
|
||||
|
||||
// Let CompleteFn return
|
||||
barrier2.Done()
|
||||
|
||||
// Get result
|
||||
res := <-resultCh
|
||||
if res.err != nil {
|
||||
t.Fatalf("compactCondensed: %v", res.err)
|
||||
}
|
||||
|
||||
// With the bug: returns non-nil summaryID even though context_items has no matching ordinals
|
||||
// The fix: should return nil when startOrd == -1
|
||||
if res.summaryID != nil {
|
||||
t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID)
|
||||
|
||||
// Verify the orphan exists in DB
|
||||
summary, _ := s.GetSummary(context.Background(), *res.summaryID)
|
||||
if summary != nil && summary.Kind == SummaryKindCondensed {
|
||||
// Check it's NOT in context_items (orphan)
|
||||
items2, _ := s.GetContextItems(context.Background(), convID)
|
||||
found := false
|
||||
for _, item := range items2 {
|
||||
if item.SummaryID == *res.summaryID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("condensed summary exists in DB but not in context_items — orphan confirmed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactUntilUnder(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create many leaf summaries to ensure we can condense
|
||||
for i := 0; i < 8; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf summary for condensation test",
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Force compact until under budget
|
||||
result, err := ce.CompactUntilUnder(ctx, convID, 2000)
|
||||
if err != nil {
|
||||
t.Fatalf("CompactUntilUnder: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectShallowestCondensationCandidate(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create enough leaf summaries + fresh messages for candidates
|
||||
for i := 0; i < LeafMinFanout; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf",
|
||||
TokenCount: 100,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail messages so summaries are in evictable range
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||
if err != nil {
|
||||
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
|
||||
}
|
||||
|
||||
// Should find leaf summaries at depth 0
|
||||
if len(candidates) < CondensedMinFanout {
|
||||
t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) {
|
||||
ce, _, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||
if err != nil {
|
||||
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
|
||||
}
|
||||
if len(candidates) != 0 {
|
||||
t.Errorf("candidates = %d, want 0 for empty context", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) {
|
||||
// Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth
|
||||
// rather than just grouping by depth without regard to order
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create interleaved summaries at depth 0 with a message in between:
|
||||
// sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300)
|
||||
|
||||
for i := 0; i < LeafMinFanout+2; i++ {
|
||||
now := time.Now().UTC()
|
||||
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf summary %d", i),
|
||||
TokenCount: 100,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
}
|
||||
|
||||
// Insert a message between first two summaries to break contiguity
|
||||
// for selectShallowestCondensationCandidate but would still find all 3
|
||||
// but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3)
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5)
|
||||
s.AppendContextMessage(ctx, convID, msg.ID)
|
||||
|
||||
// Run compactCondensed
|
||||
result, err := ce.compactCondensed(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("compactCondensed: %v", err)
|
||||
}
|
||||
|
||||
// The result should have merged the two summaries at the start
|
||||
// (skipping the message in between), This proves ordinal-aware selection works.
|
||||
|
||||
_ = result // verify summary was created
|
||||
|
||||
if result != nil {
|
||||
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||
found := false
|
||||
for _, sum := range summaries {
|
||||
if sum.Kind == SummaryKindCondensed {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected condensed summary to be created via ordinal-aware selection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens)
|
||||
for i := 0; i < 5; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf summary %d", i),
|
||||
TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000)
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||
}
|
||||
if len(chunk) < 2 {
|
||||
t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk))
|
||||
}
|
||||
for _, s := range chunk {
|
||||
if s.Depth != 0 {
|
||||
t.Errorf("got depth %d, want 0", s.Depth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 3 summaries, then a message, then 3 more summaries
|
||||
for i := 0; i < 3; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf %d", i),
|
||||
TokenCount: 100,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "break", 10)
|
||||
s.AppendContextMessage(ctx, convID, msg.ID)
|
||||
for i := 0; i < 3; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf-after %d", i),
|
||||
TokenCount: 100,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if len(chunk) > 3 {
|
||||
t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create summaries with very low token counts (total < 2000)
|
||||
for i := 0; i < 5; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("tiny summary %d", i),
|
||||
TokenCount: 50, // very small
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail to protect from compaction
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Should return nil because total tokens (250) < 2000 minimum
|
||||
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||
}
|
||||
if len(chunk) > 0 {
|
||||
t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create summaries with enough tokens (total >= 2000)
|
||||
for i := 0; i < 5; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf(
|
||||
"substantial summary with enough content to meet minimum token threshold for condensation candidate %d",
|
||||
i,
|
||||
),
|
||||
TokenCount: 500, // 5 × 500 = 2500 >= 2000
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Should return chunk because total tokens (2500) >= 2000
|
||||
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||
}
|
||||
if len(chunk) == 0 {
|
||||
t.Error("expected non-empty chunk when tokens >= 2000")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLeafSummary(t *testing.T) {
|
||||
ce, _, _ := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msgs := []Message{
|
||||
{Role: "user", Content: "hello world", TokenCount: 5},
|
||||
{Role: "assistant", Content: "hi there", TokenCount: 5},
|
||||
}
|
||||
|
||||
content, err := ce.generateLeafSummary(ctx, msgs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("generateLeafSummary: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty summary content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) {
|
||||
// Level 1 returns summary that's too large (tokens >= input), should escalate to level 2
|
||||
var calls []string
|
||||
escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
if contains(prompt, "Aggressive summary policy") {
|
||||
calls = append(calls, "aggressive")
|
||||
return "Short aggressive summary.", nil
|
||||
}
|
||||
calls = append(calls, "normal")
|
||||
// Return a very long summary to trigger escalation
|
||||
longContent := make([]byte, 5000)
|
||||
for i := range longContent {
|
||||
longContent[i] = 'x'
|
||||
}
|
||||
return string(longContent), nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ce, _ := newTestCompactionEngineWithStore(s, escalateComplete)
|
||||
|
||||
msgs := []Message{
|
||||
{Role: "user", Content: "hello world", TokenCount: 10},
|
||||
{Role: "assistant", Content: "response", TokenCount: 10},
|
||||
}
|
||||
|
||||
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("generateLeafSummary: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty summary content")
|
||||
}
|
||||
// Should have called both normal and aggressive
|
||||
foundNormal := false
|
||||
foundAggressive := false
|
||||
for _, c := range calls {
|
||||
if c == "normal" {
|
||||
foundNormal = true
|
||||
}
|
||||
if c == "aggressive" {
|
||||
foundAggressive = true
|
||||
}
|
||||
}
|
||||
if !foundNormal {
|
||||
t.Error("expected normal LLM call")
|
||||
}
|
||||
if !foundAggressive {
|
||||
t.Error("expected aggressive LLM call (level 2 escalation)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) {
|
||||
// Both normal and aggressive return empty, should escalate to level 3 truncation
|
||||
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
|
||||
|
||||
msgs := []Message{
|
||||
{Role: "user", Content: "hello world from test", TokenCount: 10},
|
||||
{Role: "assistant", Content: "response text here", TokenCount: 10},
|
||||
}
|
||||
|
||||
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("generateLeafSummary: %v", err)
|
||||
}
|
||||
// Level 3 truncation should have produced something
|
||||
if content == "" {
|
||||
t.Error("expected non-empty content from level 3 truncation fallback")
|
||||
}
|
||||
if !contains(content, "Truncated from") {
|
||||
t.Errorf("expected truncation marker in content: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCondensedSummary(t *testing.T) {
|
||||
ce, _, _ := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
summaries := []Summary{
|
||||
{SummaryID: "sum_a", Content: "first summary", TokenCount: 100},
|
||||
{SummaryID: "sum_b", Content: "second summary", TokenCount: 100},
|
||||
}
|
||||
|
||||
content, err := ce.generateCondensedSummary(ctx, summaries)
|
||||
if err != nil {
|
||||
t.Fatalf("generateCondensedSummary: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty condensed summary content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCondensedSummaryEscalation(t *testing.T) {
|
||||
// When LLM returns empty, should fall back to deterministic concatenation
|
||||
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
|
||||
|
||||
summaries := []Summary{
|
||||
{SummaryID: "sum_a", Content: "first summary text", TokenCount: 50},
|
||||
{SummaryID: "sum_b", Content: "second summary text", TokenCount: 50},
|
||||
}
|
||||
|
||||
content, err := ce.generateCondensedSummary(context.Background(), summaries)
|
||||
if err != nil {
|
||||
t.Fatalf("generateCondensedSummary: %v", err)
|
||||
}
|
||||
// Should fall back to concatenation
|
||||
if content == "" {
|
||||
t.Error("expected non-empty content from fallback")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Async Condensed Compaction (Phase 2) ---
|
||||
|
||||
func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) {
|
||||
// Use a slow CompleteFn to verify Compact returns before condensed finishes
|
||||
var callCount int32
|
||||
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
time.Sleep(500 * time.Millisecond) // simulate slow LLM
|
||||
return "Slow condensed summary.", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:async")
|
||||
convID := conv.ConversationID
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
// Create enough leaf summaries for condensation + fresh tail
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf for async test",
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Compact with force — should return quickly, condensed runs async
|
||||
start := time.Now()
|
||||
result, err := ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Should return well before the 500ms LLM call
|
||||
if elapsed > 200*time.Millisecond {
|
||||
t.Errorf("Compact took %v, should return before async condensed finishes", elapsed)
|
||||
}
|
||||
|
||||
// Wait for async to complete
|
||||
time.Sleep(800 * time.Millisecond)
|
||||
|
||||
// Verify condensed summary was created by background goroutine
|
||||
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||
foundCondensed := false
|
||||
for _, sum := range summaries {
|
||||
if sum.Kind == SummaryKindCondensed {
|
||||
foundCondensed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundCondensed {
|
||||
t.Error("expected at least one condensed summary from async Phase 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactAsyncDedup(t *testing.T) {
|
||||
var callCount int32
|
||||
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
return "Slow condensed summary.", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:dedup")
|
||||
convID := conv.ConversationID
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
waitForCondensed(ce, convID, 2*time.Second)
|
||||
})
|
||||
|
||||
// Create conditions for condensed compaction
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf for dedup",
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Call Compact twice rapidly
|
||||
ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
|
||||
// Wait for async to finish
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// LLM should only be called once for condensed (dedup)
|
||||
// callCount may be 0 if no leaf was created (only condensed in goroutine)
|
||||
// The key is that we don't get 2+ condensed calls
|
||||
if atomic.LoadInt32(&callCount) > 1 {
|
||||
t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeafForceBypassesFreshTail(t *testing.T) {
|
||||
// Spec: compactLeaf with force=true should bypass FreshTailCount protection
|
||||
// so CompactUntilUnder can compress messages inside the fresh tail
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create exactly FreshTailCount+4 messages (36 total)
|
||||
// Without force: all messages are in fresh tail → no candidate
|
||||
// With force: should compact the oldest messages
|
||||
total := FreshTailCount + 4
|
||||
for i := 0; i < total; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Without force: should return nil (all in fresh tail)
|
||||
summaryID, err := ce.compactLeaf(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("compactLeaf no-force: %v", err)
|
||||
}
|
||||
if summaryID != nil {
|
||||
t.Error("expected nil without force (all messages in fresh tail)")
|
||||
}
|
||||
|
||||
// With force: should compact despite fresh tail protection
|
||||
summaryID, err = ce.compactLeaf(ctx, convID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("compactLeaf force: %v", err)
|
||||
}
|
||||
if summaryID == nil {
|
||||
t.Error("expected summary with force=true (bypasses fresh tail)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) {
|
||||
// Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping
|
||||
// It should NOT take the entire contiguous chunk regardless of token count
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create messages totaling far more than LeafChunkTokens (20000)
|
||||
// Each message is ~500 tokens, create 80 messages = 40000 tokens
|
||||
for i := 0; i < 80; i++ {
|
||||
m, _ := s.AddMessage(
|
||||
ctx,
|
||||
convID,
|
||||
"user",
|
||||
fmt.Sprintf(
|
||||
"message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn",
|
||||
i,
|
||||
),
|
||||
500,
|
||||
)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
summaryID, err := ce.compactLeaf(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("compactLeaf: %v", err)
|
||||
}
|
||||
if summaryID == nil {
|
||||
t.Fatal("expected a summary to be created")
|
||||
}
|
||||
|
||||
// The source messages that were compacted should total roughly LeafChunkTokens (20000),
|
||||
// not the entire 40000 tokens worth of messages
|
||||
summary, _ := s.GetSummary(ctx, *summaryID)
|
||||
if summary == nil {
|
||||
t.Fatal("summary not found")
|
||||
}
|
||||
|
||||
// Source message tokens should be roughly <= LeafChunkTokens (20000)
|
||||
// Spec says: "Stop when accumulated tokens >= LeafChunkTokens"
|
||||
if summary.SourceMessageTokenCount > LeafChunkTokens {
|
||||
t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)",
|
||||
summary.SourceMessageTokenCount, LeafChunkTokens)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package seahorse
|
||||
|
||||
// Short-term memory configuration constants — all are experience-based defaults.
|
||||
|
||||
const (
|
||||
// OrdinalStep is the gap between ordinals in context_items.
|
||||
// Insert at midpoint; resequence only when precision exhausted.
|
||||
OrdinalStep = 100
|
||||
|
||||
// ContextThreshold is the compaction trigger for the context window.
|
||||
ContextThreshold float64 = 0.75 // Compact at 75% of context window
|
||||
FreshTailCount int = 32 // Recent messages protected from compaction
|
||||
|
||||
// LeafMinFanout is the fanout parameter.
|
||||
LeafMinFanout int = 8 // Min messages per leaf summary
|
||||
CondensedMinFanout int = 4 // Min summaries per condensed
|
||||
CondensedMinFanoutHard int = 2 // Min for forced compaction
|
||||
|
||||
// LeafChunkTokens is the token target.
|
||||
LeafChunkTokens int = 20000 // Max tokens per leaf chunk
|
||||
LeafTargetTokens int = 1200 // Target tokens for leaf summaries
|
||||
CondensedTargetTokens int = 2000 // Target tokens for condensed summaries
|
||||
MaxExpandTokens int = 4000 // Token cap for expansion queries
|
||||
|
||||
// MaxCompactIterations caps CompactUntilUnder to prevent infinite loops.
|
||||
// Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction.
|
||||
// With a 200k token context window and 75% threshold, ~20 iterations is enough
|
||||
// for any realistic scenario. If exceeded, the issue is logged as a warning.
|
||||
MaxCompactIterations int = 20
|
||||
)
|
||||
@@ -0,0 +1,568 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// Config holds engine configuration.
|
||||
type Config struct {
|
||||
DBPath string `json:"dbPath"`
|
||||
IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"`
|
||||
StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteFn is the LLM completion function type.
|
||||
type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error)
|
||||
|
||||
// CompleteOptions holds LLM completion parameters.
|
||||
type CompleteOptions struct {
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
}
|
||||
|
||||
// IngestResult is the result of message ingestion.
|
||||
type IngestResult struct {
|
||||
MessageCount int `json:"messageCount"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// AssembleInput controls context assembly.
|
||||
type AssembleInput struct {
|
||||
Budget int `json:"budget"`
|
||||
Query string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
// AssembleResult contains assembled context.
|
||||
type AssembleResult struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Summary string `json:"summary"` // formatted XML summaries + system prompt addition
|
||||
}
|
||||
|
||||
const numSessionShards = 256
|
||||
|
||||
// Engine is the main short-term memory engine.
|
||||
type Engine struct {
|
||||
store *Store
|
||||
compaction *CompactionEngine
|
||||
compactionMu sync.Mutex
|
||||
assembler *Assembler
|
||||
assemblerMu sync.Mutex
|
||||
retrieval *RetrievalEngine
|
||||
config Config
|
||||
complete CompleteFn
|
||||
ignorePatterns []*regexp.Regexp
|
||||
statelessPatterns []*regexp.Regexp
|
||||
sessionShards [numSessionShards]struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
}
|
||||
|
||||
// CompactionEngine handles LLM-based summarization (defined in short_compaction.go).
|
||||
type CompactionEngine struct {
|
||||
store *Store
|
||||
config Config
|
||||
complete CompleteFn
|
||||
condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Assembler handles budget-aware context assembly (defined in short_assembler.go).
|
||||
type Assembler struct {
|
||||
store *Store
|
||||
config Config
|
||||
}
|
||||
|
||||
// RetrievalEngine handles search and expansion (defined in short_retrieval.go).
|
||||
type RetrievalEngine struct {
|
||||
store *Store
|
||||
config Config
|
||||
}
|
||||
|
||||
// Store returns the underlying store for direct access.
|
||||
func (r *RetrievalEngine) Store() *Store {
|
||||
return r.store
|
||||
}
|
||||
|
||||
// NewEngine creates a new short-term memory engine.
|
||||
func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) {
|
||||
dir := filepath.Dir(config.DBPath)
|
||||
if dir != "" && dir != "." {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create db directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", config.DBPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
|
||||
// Configure SQLite for concurrent access
|
||||
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("enable WAL: %w", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("set busy_timeout: %w", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("set synchronous: %w", err)
|
||||
}
|
||||
|
||||
if err := runSchema(db); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("migrations: %w", err)
|
||||
}
|
||||
|
||||
store := &Store{db: db}
|
||||
|
||||
// Prepend hardcoded ignore patterns (spec lines 1326-1328)
|
||||
ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns))
|
||||
ignorePatterns = append(ignorePatterns, "heartbeat")
|
||||
ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...)
|
||||
|
||||
retrieval := &RetrievalEngine{store: store, config: config}
|
||||
|
||||
return &Engine{
|
||||
store: store,
|
||||
compaction: nil,
|
||||
assembler: nil,
|
||||
retrieval: retrieval,
|
||||
config: config,
|
||||
complete: completeFn,
|
||||
ignorePatterns: compileSessionPatterns(ignorePatterns),
|
||||
statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// compileSessionPattern converts a glob pattern to a compiled regex.
|
||||
// Pattern rules:
|
||||
// - * matches any sequence of non-colon characters ([^:]*)
|
||||
// - ** matches any sequence of characters including colons (.*)
|
||||
// - All other characters are treated literally
|
||||
// - Pattern is anchored (^...$)
|
||||
func compileSessionPattern(pattern string) *regexp.Regexp {
|
||||
var b strings.Builder
|
||||
b.WriteByte('^')
|
||||
|
||||
i := 0
|
||||
for i < len(pattern) {
|
||||
if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' {
|
||||
b.WriteString(".*")
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if pattern[i] == '*' {
|
||||
b.WriteString("[^:]*")
|
||||
i++
|
||||
continue
|
||||
}
|
||||
b.WriteString(regexp.QuoteMeta(string(pattern[i])))
|
||||
i++
|
||||
}
|
||||
|
||||
b.WriteByte('$')
|
||||
return regexp.MustCompile(b.String())
|
||||
}
|
||||
|
||||
// compileSessionPatterns compiles multiple glob patterns into regex patterns.
|
||||
func compileSessionPatterns(patterns []string) []*regexp.Regexp {
|
||||
result := make([]*regexp.Regexp, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, compileSessionPattern(p))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// shouldIgnoreSession returns true if the session key matches any ignore pattern.
|
||||
func (e *Engine) shouldIgnoreSession(sessionKey string) bool {
|
||||
for _, p := range e.ignorePatterns {
|
||||
if p.MatchString(sessionKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isStatelessSession returns true if the session key matches any stateless pattern.
|
||||
func (e *Engine) isStatelessSession(sessionKey string) bool {
|
||||
for _, p := range e.statelessPatterns {
|
||||
if p.MatchString(sessionKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a 32-bit hash for session key sharding.
|
||||
func fnv32(key string) uint32 {
|
||||
h := uint32(2166136261)
|
||||
for _, c := range key {
|
||||
h ^= uint32(c)
|
||||
h *= 16777619
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// getSessionMutex returns the sharded mutex for a session key.
|
||||
func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex {
|
||||
h := fnv32(sessionKey)
|
||||
shard := h % numSessionShards
|
||||
return &e.sessionShards[shard].mu
|
||||
}
|
||||
|
||||
// Ingest adds messages to a conversation identified by sessionKey.
|
||||
func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) {
|
||||
return nil, nil
|
||||
}
|
||||
if e.isStatelessSession(sessionKey) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mu := e.getSessionMutex(sessionKey)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
var totalTokens int
|
||||
var msgIDs []int64
|
||||
for _, msg := range messages {
|
||||
var added *Message
|
||||
var err error
|
||||
if len(msg.Parts) > 0 {
|
||||
added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount)
|
||||
} else {
|
||||
added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add message: %w", err)
|
||||
}
|
||||
totalTokens += msg.TokenCount
|
||||
msgIDs = append(msgIDs, added.ID)
|
||||
}
|
||||
|
||||
// Append to context_items using actual inserted IDs
|
||||
if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil {
|
||||
return nil, fmt.Errorf("append context: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("seahorse", "ingest", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"messages": len(messages),
|
||||
"tokens": totalTokens,
|
||||
})
|
||||
return &IngestResult{
|
||||
MessageCount: len(messages),
|
||||
TokenCount: totalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases resources.
|
||||
func (e *Engine) Close() error {
|
||||
// Signal compaction goroutines to stop
|
||||
if e.compaction != nil {
|
||||
e.compaction.Close()
|
||||
}
|
||||
if e.store != nil && e.store.db != nil {
|
||||
return e.store.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRetrieval returns the retrieval engine for tool implementations.
|
||||
func (e *Engine) GetRetrieval() *RetrievalEngine {
|
||||
return e.retrieval
|
||||
}
|
||||
|
||||
// Assemble builds budget-constrained context for a session.
|
||||
func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
e.initAssemblerOnce()
|
||||
return e.assembler.Assemble(ctx, conv.ConversationID, input)
|
||||
}
|
||||
|
||||
// Compact compresses conversation history for a session.
|
||||
func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
|
||||
return &CompactResult{}, nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
e.initCompactionOnce()
|
||||
return e.compaction.Compact(ctx, conv.ConversationID, input)
|
||||
}
|
||||
|
||||
// CompactUntilUnder aggressively compacts until context is under budget.
|
||||
// Used for emergency compaction after LLM overflow (retry reason).
|
||||
func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
|
||||
return &CompactResult{}, nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
e.initCompactionOnce()
|
||||
return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget)
|
||||
}
|
||||
|
||||
// initCompactionOnce lazily initializes the compaction engine.
|
||||
func (e *Engine) initCompactionOnce() {
|
||||
if e.compaction == nil {
|
||||
e.compactionMu.Lock()
|
||||
defer e.compactionMu.Unlock()
|
||||
if e.compaction == nil {
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
e.compaction = &CompactionEngine{
|
||||
store: e.store,
|
||||
config: e.config,
|
||||
complete: e.complete,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initAssemblerOnce lazily initializes the assembler.
|
||||
func (e *Engine) initAssemblerOnce() {
|
||||
if e.assembler == nil {
|
||||
e.assemblerMu.Lock()
|
||||
defer e.assemblerMu.Unlock()
|
||||
if e.assembler == nil {
|
||||
e.assembler = &Assembler{store: e.store, config: e.config}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IngestMessages is an alias for Ingest.
|
||||
func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
|
||||
return e.Ingest(ctx, sessionKey, messages)
|
||||
}
|
||||
|
||||
// Bootstrap reconciles a session's messages with the database.
|
||||
// Called once at startup for each known session.
|
||||
// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta.
|
||||
// Simple approach: find longest matching prefix and append delta.
|
||||
// If any mismatch is detected, clear and rebuild.
|
||||
func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error {
|
||||
if e.shouldIgnoreSession(sessionKey) {
|
||||
return nil
|
||||
}
|
||||
if e.isStatelessSession(sessionKey) {
|
||||
return nil
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: get conversation: %w", err)
|
||||
}
|
||||
|
||||
// Get messages already in DB
|
||||
dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: get messages: %w", err)
|
||||
}
|
||||
|
||||
// Fast path: DB has same count and exact match → no-op
|
||||
if len(dbMsgs) == len(messages) {
|
||||
matched := true
|
||||
for i := 0; i < len(messages); i++ {
|
||||
if !messageMatches(dbMsgs[i], messages[i]) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
return nil // DB is up to date
|
||||
}
|
||||
}
|
||||
|
||||
// Find longest matching prefix from the start
|
||||
anchor := -1
|
||||
compareLen := len(dbMsgs)
|
||||
if compareLen > len(messages) {
|
||||
compareLen = len(messages)
|
||||
}
|
||||
|
||||
for i := 0; i < compareLen; i++ {
|
||||
if messageMatches(dbMsgs[i], messages[i]) {
|
||||
anchor = i
|
||||
} else {
|
||||
// Mismatch detected - log details and rebuild
|
||||
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"index": i,
|
||||
"db_role": dbMsgs[i].Role,
|
||||
"db_content": truncate(dbMsgs[i].Content, 50),
|
||||
"db_parts": len(dbMsgs[i].Parts),
|
||||
"msg_role": messages[i].Role,
|
||||
"msg_content": truncate(messages[i].Content, 50),
|
||||
"msg_parts": len(messages[i].Parts),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest
|
||||
// Note: anchor can be -1 if first message didn't match (history completely changed)
|
||||
if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 {
|
||||
anchorID := dbMsgs[anchor].ID
|
||||
logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"db_count": len(dbMsgs),
|
||||
"anchor": anchor,
|
||||
"anchor_id": anchorID,
|
||||
"msg_count": len(messages),
|
||||
"delta_start": anchor + 1,
|
||||
})
|
||||
|
||||
// Delete messages after anchor (also clears context_items)
|
||||
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil {
|
||||
return fmt.Errorf("bootstrap: delete messages: %w", err)
|
||||
}
|
||||
|
||||
// Re-ingest from anchor+1 to end
|
||||
delta := messages[anchor+1:]
|
||||
if len(delta) > 0 {
|
||||
_, err := e.Ingest(ctx, sessionKey, delta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: re-ingest: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Normal case: append delta after anchor
|
||||
if anchor >= 0 && anchor < len(messages)-1 {
|
||||
delta := messages[anchor+1:]
|
||||
if len(delta) > 0 {
|
||||
_, err := e.Ingest(ctx, sessionKey, delta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: ingest delta: %w", err)
|
||||
}
|
||||
}
|
||||
} else if anchor == -1 && len(dbMsgs) > 0 {
|
||||
// First message changed (history completely different) - rebuild from scratch
|
||||
logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"db_count": len(dbMsgs),
|
||||
"msg_count": len(messages),
|
||||
})
|
||||
// Delete all existing messages
|
||||
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil {
|
||||
return fmt.Errorf("bootstrap: delete all messages: %w", err)
|
||||
}
|
||||
// Re-ingest everything
|
||||
if len(messages) > 0 {
|
||||
_, err := e.Ingest(ctx, sessionKey, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: re-ingest all: %w", err)
|
||||
}
|
||||
}
|
||||
} else if anchor == -1 && len(dbMsgs) == 0 {
|
||||
// DB is empty, ingest everything
|
||||
_, err := e.Ingest(ctx, sessionKey, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: ingest all: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// truncate shortens a string for logging.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// messageMatches compares two messages using (role, content) or (role, parts).
|
||||
// TokenCount is NOT compared because it may be re-estimated differently
|
||||
// during bootstrap (e.g., via tokenizer.EstimateMessageTokens).
|
||||
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
|
||||
// since AddMessageWithParts stores empty Content in DB.
|
||||
func messageMatches(a, b Message) bool {
|
||||
if a.Role != b.Role {
|
||||
return false
|
||||
}
|
||||
// If either message has Parts, compare Parts
|
||||
if len(a.Parts) > 0 || len(b.Parts) > 0 {
|
||||
return partsMatch(a.Parts, b.Parts)
|
||||
}
|
||||
// Simple text messages: compare Content
|
||||
return a.Content == b.Content
|
||||
}
|
||||
|
||||
// partsMatch compares two slices of MessagePart for equality.
|
||||
func partsMatch(a, b []MessagePart) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].Type != b[i].Type {
|
||||
return false
|
||||
}
|
||||
switch a[i].Type {
|
||||
case "text":
|
||||
if a[i].Text != b[i].Text {
|
||||
return false
|
||||
}
|
||||
case "tool_use":
|
||||
if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID {
|
||||
return false
|
||||
}
|
||||
case "tool_result":
|
||||
if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text {
|
||||
return false
|
||||
}
|
||||
case "media":
|
||||
if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,212 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m".
|
||||
// Returns the duration and nil error, or zero and error if invalid.
|
||||
func ParseLastDuration(s string) (time.Duration, error) {
|
||||
if s == "" {
|
||||
return 0, fmt.Errorf("empty duration")
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`^(\d+)([hdwm])$`)
|
||||
matches := re.FindStringSubmatch(s)
|
||||
if matches == nil {
|
||||
return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s)
|
||||
}
|
||||
|
||||
value, _ := strconv.Atoi(matches[1])
|
||||
unit := matches[2]
|
||||
|
||||
switch unit {
|
||||
case "h":
|
||||
return time.Duration(value) * time.Hour, nil
|
||||
case "d":
|
||||
return time.Duration(value) * 24 * time.Hour, nil
|
||||
case "w":
|
||||
return time.Duration(value) * 7 * 24 * time.Hour, nil
|
||||
case "m":
|
||||
return time.Duration(value) * 30 * 24 * time.Hour, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown unit: %q", unit)
|
||||
}
|
||||
}
|
||||
|
||||
// GrepInput controls search across summaries and messages.
|
||||
type GrepInput struct {
|
||||
Pattern string `json:"pattern"`
|
||||
Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message"
|
||||
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
|
||||
AllConversations bool `json:"allConversations,omitempty"`
|
||||
Since *time.Time `json:"since,omitempty"`
|
||||
Before *time.Time `json:"before,omitempty"`
|
||||
Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m"
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// GrepResult contains search results.
|
||||
type GrepResult struct {
|
||||
Success bool `json:"success"`
|
||||
Summaries []GrepSummaryResult `json:"summaries"`
|
||||
Messages []GrepMessageResult `json:"messages"`
|
||||
TotalSummaries int `json:"totalSummaries"`
|
||||
TotalMessages int `json:"totalMessages"`
|
||||
Hint string `json:"hint,omitempty"`
|
||||
}
|
||||
|
||||
// GrepSummaryResult is a summary match from grep.
|
||||
type GrepSummaryResult struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Depth int `json:"depth"`
|
||||
Kind SummaryKind `json:"kind"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
// Rank is the bm25 relevance score (negative value, lower = better match).
|
||||
// Examples: -5.0 = excellent match, -2.0 = good match, -0.5 = partial match.
|
||||
Rank float64 `json:"rank,omitempty"`
|
||||
}
|
||||
|
||||
// GrepMessageResult is a message match from grep.
|
||||
type GrepMessageResult struct {
|
||||
ID int64 `json:"id,string"`
|
||||
Snippet string `json:"snippet"`
|
||||
Role string `json:"role"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Rank float64 `json:"rank,omitempty"` // Relevance score (more negative = better match)
|
||||
}
|
||||
|
||||
// ExpandMessagesResult contains expanded messages.
|
||||
type ExpandMessagesResult struct {
|
||||
Messages []Message `json:"messages"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// Grep searches summaries and messages for matching content.
|
||||
func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) {
|
||||
if input.Pattern == "" {
|
||||
return nil, fmt.Errorf("grep: pattern is required")
|
||||
}
|
||||
|
||||
limit := input.Limit
|
||||
if limit == 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
// Handle Last parameter: convert to Since
|
||||
since := input.Since
|
||||
if input.Last != "" {
|
||||
dur, err := ParseLastDuration(input.Last)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("grep: invalid last: %w", err)
|
||||
}
|
||||
t := time.Now().UTC().Add(-dur)
|
||||
since = &t
|
||||
}
|
||||
|
||||
// Auto-detect mode: use LIKE if pattern contains %, otherwise full-text
|
||||
mode := ""
|
||||
if strings.Contains(input.Pattern, "%") {
|
||||
mode = "like"
|
||||
}
|
||||
|
||||
searchInput := SearchInput{
|
||||
Pattern: input.Pattern,
|
||||
Mode: mode,
|
||||
Role: input.Role,
|
||||
AllConversations: input.AllConversations,
|
||||
Since: since,
|
||||
Before: input.Before,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
result := &GrepResult{
|
||||
Success: true,
|
||||
Summaries: make([]GrepSummaryResult, 0),
|
||||
Messages: make([]GrepMessageResult, 0),
|
||||
TotalSummaries: 0,
|
||||
TotalMessages: 0,
|
||||
}
|
||||
|
||||
// Determine scope
|
||||
scope := input.Scope
|
||||
if scope == "" {
|
||||
scope = "both"
|
||||
}
|
||||
|
||||
// Search summaries if requested
|
||||
if scope == "both" || scope == "summary" {
|
||||
sumResults, err := r.store.SearchSummaries(ctx, searchInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search summaries: %w", err)
|
||||
}
|
||||
for _, sr := range sumResults {
|
||||
if sr.SummaryID != "" {
|
||||
result.Summaries = append(result.Summaries, GrepSummaryResult{
|
||||
ID: sr.SummaryID,
|
||||
Content: sr.Content,
|
||||
Depth: sr.Depth,
|
||||
Kind: sr.Kind,
|
||||
ConversationID: sr.ConversationID,
|
||||
Rank: sr.Rank,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(sumResults) > 0 {
|
||||
result.TotalSummaries = sumResults[0].TotalCount
|
||||
}
|
||||
}
|
||||
|
||||
// Search messages if requested
|
||||
if scope == "both" || scope == "message" {
|
||||
msgResults, err := r.store.SearchMessages(ctx, searchInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search messages: %w", err)
|
||||
}
|
||||
for _, sr := range msgResults {
|
||||
if sr.MessageID > 0 {
|
||||
result.Messages = append(result.Messages, GrepMessageResult{
|
||||
ID: sr.MessageID,
|
||||
Snippet: sr.Snippet,
|
||||
Role: sr.Role,
|
||||
ConversationID: sr.ConversationID,
|
||||
Rank: sr.Rank,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(msgResults) > 0 {
|
||||
result.TotalMessages = msgResults[0].TotalCount
|
||||
}
|
||||
}
|
||||
|
||||
// Add hint if no results
|
||||
if len(result.Summaries) == 0 && len(result.Messages) == 0 {
|
||||
result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true"
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExpandMessages retrieves full message content by IDs.
|
||||
func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) {
|
||||
result := &ExpandMessagesResult{
|
||||
Messages: make([]Message, 0, len(messageIDs)),
|
||||
}
|
||||
|
||||
for _, msgID := range messageIDs {
|
||||
msg, err := r.store.GetMessageByID(ctx, msgID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result.Messages = append(result.Messages, *msg)
|
||||
result.TokenCount += msg.TokenCount
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Retrieval Tests ---
|
||||
|
||||
func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) {
|
||||
t.Helper()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval")
|
||||
return &RetrievalEngine{store: s}, s, conv.ConversationID
|
||||
}
|
||||
|
||||
func TestRetrievalGrepSummaries(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "数据库连接配置说明",
|
||||
TokenCount: 50,
|
||||
})
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "API endpoint documentation",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
// FTS5 search (trigram, needs >= 3 chars)
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "数据库连",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Summaries) == 0 {
|
||||
t.Error("expected at least 1 FTS result")
|
||||
}
|
||||
|
||||
// LIKE search with wildcard
|
||||
results, err = r.Grep(ctx, GrepInput{
|
||||
Pattern: "%endpoint%",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep LIKE: %v", err)
|
||||
}
|
||||
if len(results.Summaries) == 0 {
|
||||
t.Error("expected at least 1 LIKE result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalGrepMessages(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s.AddMessage(ctx, convID, "user", "find this message about testing", 5)
|
||||
s.AddMessage(ctx, convID, "user", "unrelated content here", 5)
|
||||
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "testing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected at least 1 result for 'testing'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalExpandMessages(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10)
|
||||
|
||||
result, err := r.ExpandMessages(ctx, []int64{msg.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("ExpandMessages: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 1 {
|
||||
t.Errorf("Messages = %d, want 1", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Content != "expand this message" {
|
||||
t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalExpandMultipleMessages(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10)
|
||||
msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10)
|
||||
msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10)
|
||||
|
||||
result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("ExpandMessages: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 3 {
|
||||
t.Errorf("Messages = %d, want 3", len(result.Messages))
|
||||
}
|
||||
if result.TokenCount != 30 {
|
||||
t.Errorf("TokenCount = %d, want 30", result.TokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalGrepWithTimeFilter(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
before := now.Add(-2 * time.Hour)
|
||||
|
||||
// Create messages at different times
|
||||
s.AddMessage(ctx, convID, "user", "old message about auth", 5)
|
||||
s.AddMessage(ctx, convID, "user", "recent message about auth", 5)
|
||||
|
||||
// Search with time filter
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "auth",
|
||||
Since: &before,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
_ = results // Just verify no error
|
||||
}
|
||||
|
||||
func TestRetrievalGrepAllConversations(t *testing.T) {
|
||||
r, s, _ := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create another conversation
|
||||
conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2")
|
||||
|
||||
// Add messages to both
|
||||
s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5)
|
||||
|
||||
// Search all conversations
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "xyz",
|
||||
AllConversations: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected to find message in other conversation")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Last Duration Parsing Tests ---
|
||||
|
||||
func TestParseLastDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantDur time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{"6h", 6 * time.Hour, false},
|
||||
{"1d", 24 * time.Hour, false},
|
||||
{"7d", 7 * 24 * time.Hour, false},
|
||||
{"2w", 14 * 24 * time.Hour, false},
|
||||
{"1m", 30 * 24 * time.Hour, false}, // month = 30 days
|
||||
{"3m", 90 * 24 * time.Hour, false},
|
||||
{"", 0, true},
|
||||
{"invalid", 0, true},
|
||||
{"5x", 0, true}, // unknown unit
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got, err := ParseLastDuration(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tt.wantDur {
|
||||
t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Role Filter Tests ---
|
||||
|
||||
func TestRetrievalGrepRoleFilter(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s.AddMessage(ctx, convID, "user", "user message about alpha", 5)
|
||||
s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5)
|
||||
s.AddMessage(ctx, convID, "user", "another user message", 5)
|
||||
|
||||
// Search all roles
|
||||
allResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "alpha",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(allResults.Messages) != 2 {
|
||||
t.Errorf("expected 2 messages, got %d", len(allResults.Messages))
|
||||
}
|
||||
|
||||
// Search user only
|
||||
userResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "alpha",
|
||||
Role: "user",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(userResults.Messages) != 1 {
|
||||
t.Errorf("expected 1 user message, got %d", len(userResults.Messages))
|
||||
}
|
||||
if userResults.Messages[0].Role != "user" {
|
||||
t.Errorf("expected role=user, got %s", userResults.Messages[0].Role)
|
||||
}
|
||||
|
||||
// Search assistant only
|
||||
assistantResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "alpha",
|
||||
Role: "assistant",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(assistantResults.Messages) != 1 {
|
||||
t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Last Parameter Tests ---
|
||||
|
||||
func TestRetrievalGrepWithLast(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add messages (we can't control timestamps in SQLite easily,
|
||||
// but we can verify the parameter is parsed correctly)
|
||||
s.AddMessage(ctx, convID, "user", "recent message about testing", 5)
|
||||
|
||||
// Test that Last parameter is converted to Since
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "testing",
|
||||
Last: "1d", // last 1 day
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
// Should still find the message since it's recent
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected to find recent message")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when
|
||||
// searching both summaries and messages (summaries don't have role column).
|
||||
func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a summary (no role column)
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "summary about testing",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
// Add messages with different roles
|
||||
s.AddMessage(ctx, convID, "user", "user message about testing", 5)
|
||||
s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5)
|
||||
|
||||
// Search with role filter and scope=both (default), using LIKE mode (%)
|
||||
// This should NOT error even though summaries don't have role column
|
||||
bothResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "%testing%", // LIKE mode to trigger the bug
|
||||
Role: "user",
|
||||
Scope: "both",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep with role and scope=both: %v", err)
|
||||
}
|
||||
|
||||
// Should only return user messages, not summaries or assistant messages
|
||||
if len(bothResults.Messages) != 1 {
|
||||
t.Errorf("expected 1 user message, got %d", len(bothResults.Messages))
|
||||
}
|
||||
if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" {
|
||||
t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role)
|
||||
}
|
||||
|
||||
// Summaries should be empty since they don't have roles to filter
|
||||
// (or we could return all summaries - either is acceptable)
|
||||
}
|
||||
|
||||
// TestRetrievalGrepTotalCounts tests that grep returns total counts.
|
||||
func TestRetrievalGrepTotalCounts(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 3 summaries
|
||||
for i := 0; i < 3; i++ {
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("summary about testing %d", i),
|
||||
TokenCount: 50,
|
||||
})
|
||||
}
|
||||
|
||||
// Add 5 messages
|
||||
for i := 0; i < 5; i++ {
|
||||
s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5)
|
||||
}
|
||||
|
||||
// Search with limit smaller than total
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "%testing%", // LIKE mode
|
||||
Scope: "both",
|
||||
Limit: 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
|
||||
// Should return limited results
|
||||
if len(results.Summaries) > 2 {
|
||||
t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries))
|
||||
}
|
||||
if len(results.Messages) > 2 {
|
||||
t.Errorf("expected at most 2 messages, got %d", len(results.Messages))
|
||||
}
|
||||
|
||||
// But total counts should reflect all matches
|
||||
if results.TotalSummaries != 3 {
|
||||
t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries)
|
||||
}
|
||||
if results.TotalMessages != 5 {
|
||||
t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user