Compare commits

..

24 Commits

Author SHA1 Message Date
lxowalle 51eecde01e Feat/support isolation (#2423)
* * completed

* * optimzie

* * fix format

* * fix pr check

* try to fix ci

* * Indicates that Windows does not support expos_paths, adding more mount paths for the Linux platform.

* fix isolation startup lifecycle and MCP transport wrapping

* fix isolation startup cleanup and optional Linux mounts

* fix isolation path handling for relative hooks

Preserve relative command and working-directory semantics when Linux isolation wraps subprocesses, and restore absolute argv path exposure to avoid startup regressions. Add hook coverage and docs updates so isolation-enabled process hooks keep working as configured.

* * fix ci
2026-04-08 18:15:42 +08:00
ywj 8b3e502690 fix(feishu): enrich reply context for card and file replies (#2144)
* fix(feishu): enrich reply context for card and file replies

* refactor(feishu): extract reply functions to feishu_reply.go

- Move reply-related functions to new feishu_reply.go
- Move corresponding tests to feishu_reply_test.go
- Extract magic number 600 to maxReplyContextLen constant
- Unify replyTargetID/replyTargetFromMessage (prefer parent_id, fallback root_id)
- Add source comment for containsFeishuUpgradePlaceholder

* fix(feishu): skip API fallback for non-thread messages, prepend replied media refs

- resolveReplyTargetMessageID: only call fetchMessageByID fallback when
  ThreadId is set, avoiding unnecessary API calls for non-reply messages
- prependReplyContext: prepend replied media refs before current media refs
  to maintain correct ordering

* fix(feishu): add message cache for fetchMessageByID to avoid repeated downloads

- Add messageCache (sync.Map) to FeishuChannel struct
- Cache fetched messages with 30s TTL to avoid re-downloading attachments
  when multiple users reply to the same parent message in a thread
- Cleanup expired entries on read access (no background goroutine needed)

* fix(feishu): early-return for non-reply messages, add cache and fetchMessageByID comment

* fix: remove duplicate test and fix gci import order

* fix(feishu): remove duplicate prependReplyContext call
2026-04-08 14:26:17 +08:00
wenjie 7d16764674 fix(gateway): validate PID ownership and clean stale pid files (#2422)
* fix(gateway): validate PID ownership and clean stale pid files

- include `pid` in health responses for runtime PID verification
- add `RemovePidFileIfPID` to safely delete PID files only on PID match
- sanitize gateway PID data via process-command checks with health fallback
- ignore and remove stale/non-gateway PID files before gateway operations
- refuse stop/restart actions when the attached process is not a gateway
- update gateway and websocket tests to cover PID validation and safety paths

* test(seahorse): use shared in-memory SQLite DB in tests to fix async compaction failures

* test: remove unused sendMediaErr field from hook test mock
2026-04-08 14:23:21 +08:00
Harmoon ee29aaa871 Enhance hooks with respond action and comprehensive documentation (#2215)
* feat(hooks): add respond action for tool execution bypass

Add a new HookActionRespond that allows hooks to return tool results directly, skipping actual tool execution. This enables plugin tool injection, caching, and mocking capabilities.

- Add HookActionRespond constant and support in HookManager
- Extend ToolCallHookRequest with HookResult field
- Implement respond action handling in process hooks and agent loop
- Add comprehensive tests for respond and deny_tool actions
- Update documentation with hook actions table and examples

* docs(hooks): add JSON-RPC protocol and plugin tool injection documentation

Add comprehensive documentation for hook JSON-RPC protocol and plugin tool injection capabilities:

- Add "Hook Actions" section to README.zh.md explaining respond action for tool execution bypass
- Create hook-json-protocol.md/.zh.md detailing JSON-RPC 2.0 protocol for all hook methods
- Create plugin-tool-injection.md/.zh.md with complete examples for external tool implementation
- Document how hooks can inject tool definitions and return results via respond action
- Include Python and Go examples for weather query plugin implementation

* feat(agent): emit tool events and feedback for hook results

Add ToolExecStart event emission and tool feedback for hook results to ensure consistent behavior between normal tool execution and hook bypass scenarios. This maintains parity in event tracking and user feedback when tools are executed via hooks.

* style(agent): format whitespace in hook structs and constants

Remove trailing whitespace and standardize spacing in JSON struct tags, constants, and test data for improved code consistency.

* feat(hooks): add media support for plugin tool injection

Extend the hook respond action to support media file handling:
- Add `media` field for returning images and files from hooks
- Add `response_handled` field to control turn completion behavior
- When response_handled=true, media is automatically delivered to user
- When response_handled=false, media is passed to LLM for vision requests

This enables plugins to directly return generated images, downloaded
files, and other media content either to users or for LLM analysis.

* docs(hooks): document security implications of respond action

Add security boundary documentation explaining that the respond action
bypasses ApproveTool checks, allowing hooks to return results for any
tool without approval. Include recommendations for secure hook
implementation and code comments marking the security considerations.

Changes:
- Add "Security Boundaries" section to plugin-tool-injection docs
- Document bypass of approval checks and associated risks
- Provide security recommendations and example code
- Add inline security comments in hooks.go and loop.go

* refactor(agent): improve completeness of tool result cloning and hook processing

Extend cloneToolResult to properly copy ArtifactTags and Messages fields,
ensuring deep copies of all ToolResult data. Consolidate event emission
and user message handling to match the normal tool execution flow.

* fix(agent): align hook respond path with normal tool execution flow

The hook respond code path was missing several critical behaviors that
existed in normal tool execution:

- Add logging for tool calls with arguments preview
- Add is_tool_call metadata to user-facing messages
- Handle attachment delivery failures by setting error state and
  notifying LLM
- Set ResponseHandled=false when using bus for media delivery
- Check for steering messages and graceful interrupts after tool
  execution, skipping remaining tools when appropriate
- Poll for SubTurn results that arrived during tool execution

This ensures consistent behavior between hook-responded tool calls and
normally executed tool calls.

* test(agent): add tests for hook respond media error handling

Add comprehensive tests for the hook respond code path when media
delivery fails. Tests cover error media channel scenarios and verify
proper error state handling.

Also document that AfterTool is not called when using respond action,
as it provides the final answer directly (design decision).
2026-04-08 11:47:02 +08:00
wenjie 330de0c382 fix(agent): disable seahorse context manager on freebsd/arm (#2417)
* fix(agent): disable seahorse context manager on freebsd/arm

Exclude freebsd/arm from the seahorse-enabled build and route it to the
unsupported stub implementation.

This avoids freebsd/arm build failures caused by modernc sqlite/libc while
keeping picoclaw buildable on that target.

* build: bump Go version from 1.25.8 to 1.25.9

* ci: install and run govulncheck directly in PR workflow
2026-04-08 10:57:22 +08:00
corevibe555 6ce0306c66 fix: use per candidate provider for model_fallbacks (#2143)
* fix: use per-candidate provider for model_fallbacks

Each fallback model now uses its own api_base and api_key from
model_list instead of inheriting the primary model's provider config.

Previously, a single LLMProvider was created from the primary model's
ModelConfig and reused for all fallback candidates — only the model ID
string was swapped. This caused all fallback requests to be routed to
the primary provider's endpoint, making cross-provider fallback chains
non-functional (e.g., OpenRouter primary with Gemini fallback would
send the Gemini request to OpenRouter's API).

Fix: pre-create a per-candidate LLMProvider at agent initialization
time by looking up each candidate's ModelConfig from model_list. The
fallback run closure now selects the correct provider per candidate
via CandidateProviders map, falling back to agent.Provider when no
override is found.

Fixes #2140

Made-with: Cursor

test: add test for instance.go

fix: fix test

refactor: optimize

fix: fix Golang lint issues

chore: comment cleanup

* refactor: use resolvedModelConfig() instead of buildModelIndex()

* fix
2026-04-07 20:07:56 +08:00
Andy Lo-A-Foe 1fc2710999 feat(channels): add teams_webhook output-only channel (#2244)
Add Microsoft Teams webhook integration via Power Automate workflows.

Features:
- Output-only channel for sending notifications to Teams
- Multiple webhook targets with named configuration
- Required "default" target with automatic fallback
- Rich Adaptive Card formatting with full-width rendering
- Markdown table conversion to native Adaptive Card Tables
- Column widths based on header content length
- HTTPS-only webhook URL validation
- Proper error classification for retry behavior

Configuration:
- channels.teams_webhook.enabled: bool
- channels.teams_webhook.webhooks: map of named targets
  - Each target has webhook_url (SecureString) and optional title

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-04-07 19:24:27 +08:00
Guoguo 6a8552a664 fix(web): derive WebSocket URL from browser location instead of backend (#2405)
The frontend previously used ws_url returned by /api/pico/token, which
is built from the launcher's own port. Behind a reverse proxy this can
produce incorrect URLs (e.g. ws://localhost:18800 instead of the
proxy's public address).

Since the launcher already proxies /pico/ws on the same port, the
frontend can simply use window.location.host to construct the
WebSocket URL, which is always correct regardless of proxy layers.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 17:37:01 +08:00
wenjie 7bf6cbe1fa fix(gateway): harden PID liveness handling and websocket proxy state (#2403)
- treat `EPERM` from `signal(0)` as “process exists” on Unix
- classify malformed PID files as invalid and auto-remove them during read
- keep cached `pidData` only for transient races and downgrade `running` to `stopped` when the tracked process is gone
- refresh PID data on WebSocket proxy requests and reject stale cached gateway state
- add regression tests for invalid PID files, status downgrade, on-demand PID loading, and stale proxy rejection
2026-04-07 16:34:42 +08:00
LC 38a498e202 feat(provider): support custom headers injection for HTTP providers (#2402)
* feat(provider): support custom headers injection for HTTP providers

* fix(provider): resolve lint problem

* fix(provider): align stream user-agent and header precedence docs
2026-04-07 16:05:21 +08:00
eturn 778f939302 fix [BUG] WebUI cannot connect to the gateway started by WebUI (#2267)
#2213
2026-04-07 15:46:45 +08:00
BeaconCat 84edc462d6 assets: update WeChat QR code image (#2385)
Co-authored-by: BeaconCat <BeaconCat@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-07 14:09:11 +08:00
Liu Yuan f0e6b7aa37 fix(seahorse): correct bm25 rank semantics in comments (#2360)
SQLite FTS5 bm25() returns negative values where numerically smaller
(more negative) indicates a better match. The official docs state:

  "The better the match, the numerically smaller the value returned."

Two comments incorrectly stated "closer to 0 = better match" and
"lower = better match". Updated all rank descriptions to use the
unambiguous "more negative = higher relevance" phrasing.

This matters because these comments are used as tool prompt hints
for LLM agents, and incorrect semantics could lead to wrong ranking
decisions.
2026-04-07 12:32:28 +08:00
wenjie 661ce5e311 fix(build): gate seahorse context manager on unsupported platforms (#2384)
- add build tags to exclude context_seahorse.go on mipsle and netbsd
- add context_seahorse_unsupported.go to keep registration and return a clear runtime error
- remove unused indirect dependency github.com/reiver/go-porterstemmer from go.mod and go.sum
2026-04-07 11:49:35 +08:00
dependabot[bot] c3e7396a3d build(deps): bump github.com/pion/rtp from 1.8.7 to 1.10.1 (#2290)
Bumps [github.com/pion/rtp](https://github.com/pion/rtp) from 1.8.7 to 1.10.1.
- [Release notes](https://github.com/pion/rtp/releases)
- [Commits](https://github.com/pion/rtp/compare/v1.8.7...v1.10.1)

---
updated-dependencies:
- dependency-name: github.com/pion/rtp
  dependency-version: 1.10.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-07 09:58:20 +08:00
dependabot[bot] 29277d4b3b build(deps): bump modernc.org/sqlite from 1.47.0 to 1.48.0 (#2289)
Bumps [modernc.org/sqlite](https://gitlab.com/cznic/sqlite) from 1.47.0 to 1.48.0.
- [Changelog](https://gitlab.com/cznic/sqlite/blob/master/CHANGELOG.md)
- [Commits](https://gitlab.com/cznic/sqlite/compare/v1.47.0...v1.48.0)

---
updated-dependencies:
- dependency-name: modernc.org/sqlite
  dependency-version: 1.48.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-04-07 09:56:46 +08:00
Guoguo 9ec27835cf fix(docker): add -console flag and open network for launcher (#2314)
- Add -console to Dockerfile CMD so launcher outputs logs to stdout,
  making docker logs work as expected
- Remove 127.0.0.1 bind from ports to allow public network access
- Add commented PICOCLAW_LAUNCHER_TOKEN env var example

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-07 09:34:54 +08:00
Liu Yuan 1175f4a62b feat(membench): add LOCOMO memory benchmark tool (#2353)
Benchmark tool comparing legacy session manager vs seahorse short memory
retrieval on the LOCOMO long-term conversational memory dataset.

- cmd/membench/: CLI with ingest/eval/report/run subcommands
- Mode A (legacy): recency-biased budget truncation baseline
- Mode B (seahorse): per-keyword trigram FTS5 search + expand
- Metrics: Token-Overlap F1 and Recall Hit Rate
- `make mem` builds, downloads data, runs benchmark end-to-end
2026-04-06 17:26:43 +08:00
Liu Yuan 15a70ac45c feat(seahorse): implement short-term memory engine (LCM) (#2285)
* feat(seahorse): implement short-term memory engine of seahorse

Add pkg/seahorse/ module implementing a SQLite-backed DAG-based summary
hierarchy for context management, ported from lossless-claw's LCM design:

- types.go + short_constants.go: core types (Message, Summary, Conversation,
  ContextItem) and configuration constants (fanout, token targets, thresholds)
- migration.go: idempotent DB schema with FTS5 trigram tokenizer for CJK
- store.go: full SQLite CRUD (conversations, messages, summaries DAG,
  context_items with ordinal gap numbering, FTS5 search)
- short_engine.go: Engine lifecycle (NewEngine, Ingest, Assemble, Compact),
  session pattern filtering (ignore/stateless glob→regex compilation),
  per-session mutex via sync.Map
- short_assembler.go: budget-aware context assembly with fresh tail protection
  (32 messages), oldest-first eviction, summary XML formatting, RebuildContextItems
- short_compaction.go: leaf compaction (messages→summary) and condensed
  compaction (summaries→higher-level summary), 3-level LLM escalation,
  CompactUntilUnder for emergency overflow
- short_retrieval.go: lookupByID, FTS5/LIKE search, recursive expand with
  token cap
- context_seahorse.go: agent.ContextManager adapter, registered as "seahorse",
  provider↔seahorse message type conversion (ToolCalls, tool_result)

* fix(seahorse): correct 3 adapter bugs in context management

- TokenCount: use full message (Content+ToolCalls+Media) instead of Content-only
- Empty Content: rebuild Content from tool_result Parts when stored empty
- Duplicate summaries: summaries only in Summary field, not in History messages
- Grep: fix SearchResult.Snippet→Content for summaries
- Schema: fix FTS5 SQL uses VIRTUAL TABLE not TEMP TABLE
- TestFTS5SQLConstants: verify FTS5 SQL syntax correctness
- Test: fix flaky TestCompactLeaf

* fix(agent): ingest steering messages into seahorse SQLite

Steering messages were only persisted to session JSONL but not ingested
into seahorse SQLite, causing them to be missing from context assembly.

Added `ts.ingestMessage(turnCtx, al, pm)` call in the steering message
injection block alongside the existing JSONL persistence.

Test: TestSeahorseSteeringMessageIngested verifies steering messages
appear in seahorse SQLite DB after being processed.

* fix(seahorse): address 3 blocking bugs from code review

- Fix resequenceContextItemsTx scan error handling (store.go:850)
  Changed `return err` to `return scanErr` to properly propagate scan errors
  instead of returning nil (which silently corrupts data)

- Fix sql.NullString for INTEGER column (store.go:847)
  Changed `mid` from sql.NullString to sql.NullInt64 since message_id
  is INTEGER in schema. Removed unnecessary strconv.ParseInt call.

- Fix compactCondensed fallback deleting non-candidate items
  Added ReplaceContextItemsWithSummary method for per-item deletion
  when candidates are not contiguous in ordinal space.
  Optimized to use range deletion when candidates are consecutive.

* fix(seahorse): pass Budget to Compact for correct condensed threshold

Issue #4 from PR review: When Budget was not passed to seahorse.Compact,
it defaulted to `tokensBefore * 0.75`, making `tokensBefore > budget`
always true and causing condensed compaction to trigger unnecessarily.

Changes:
- context_seahorse.go: Forward Budget from CompactRequest to CompactInput
- loop.go: Pass Budget (ContextWindow) in all 3 Compact calls
- Add test verifying condensed is skipped when tokens < threshold
- Fix lint issues in store.go and store_test.go

* fix(seahorse): add mutex for assembler lazy initialization

Issue #5 from PR review: The check-then-create pattern for e.assembler
was a data race when multiple goroutines called Assemble() concurrently:
    if e.assembler == nil {
        e.assembler = &Assembler{...}
    }

Changes:
- Add assemblerMu sync.Mutex to Engine struct
- Add initAssemblerOnce() using double-checked locking (same pattern as initCompactionOnce)
- Add TestAssemblerLazyInitRace to verify thread-safety

* fix(seahorse): handle non-consecutive depths in selectShallowestCondensationCandidate

Issue #8 from PR review: the loop iterated depth 0, 1, 2... assuming
consecutive keys, but break when key was missing caused deeper depths
to never be checked.

Fix: collect all existing depth keys, sort, then iterate in order.

* fix(seahorse): wrap DeleteMessagesAfterID and appendContextItems in transactions

- DeleteMessagesAfterID: wrap all DELETE operations in a transaction for
  atomicity, remove redundant manual FTS delete (handled by trigger)
- appendContextItems: use transaction to fix read-then-write race condition
- Add GetMaxOrdinalTx and resolveItemTokenCountTx for transaction-scoped queries
- Remove unused resolveItemTokenCount function

Fixes PR review issues 6 and 7.

* fix(seahorse): derive readable content from Parts and cap CompactUntilUnder iterations

- Derive readable content from MessageParts in AddMessageWithParts so
  FTS5 indexing and summary formatting can access tool call information
- formatMessagesForSummary and truncateSummary now fall back to Parts
  when Content is empty, fixing blank summaries for Part-based messages
- Add MaxCompactIterations (20) to prevent CompactUntilUnder infinite
  loops; exceeded iterations are logged as warnings
2026-04-05 09:05:16 +08:00
LC 71337b6f52 fix(tool): clarify write_file nested-JSON escape semantics and add tests (#2320)
* fix(tool): clarify write_file nested-JSON escape semantics and add tests

* fix(tool): improve formatting of escaping rules in CLI tool prompt

* fix(tool): align escape notation with function.arguments layer
2026-04-04 17:56:49 +08:00
Mauro 84e42d6904 Merge pull request #2316 from zeroznet/fix/help-banner-double-v
fix: avoid duplicate v in CLI help banner
2026-04-03 23:14:07 +02:00
Robert Bopko e8d92e4a36 test: update root help banner expectation 2026-04-03 21:59:57 +02:00
Robert Bopko cbd0798a56 fix: avoid duplicate v in CLI help banner 2026-04-03 19:58:52 +02:00
Mauro d8c5183d9a feat(mcp): store oversized text results as artifacts (#2308)
* feat(mcp): store oversized text results as artifacts

* feat(mcp): fix doc

* fix(mcp): preserve raw MCP payload in text artifacts

* fix(mcp): avoid leaking large text when artifact persistence fails

* chore(mcp): clarify inline text limit and cover artifact edge cases
2026-04-04 01:30:36 +08:00
130 changed files with 22044 additions and 263 deletions
+4 -3
View File
@@ -41,10 +41,11 @@ jobs:
with:
go-version-file: go.mod
- name: Install govulncheck
run: go install golang.org/x/vuln/cmd/govulncheck@v1.1.4
- name: Run Govulncheck
uses: golang/govulncheck-action@v1
with:
go-package: ./...
run: govulncheck -C . -format text ./...
test:
name: Tests
+2
View File
@@ -67,3 +67,5 @@ web/backend/dist/*
.claude/
docker/data
.omc/
+1
View File
@@ -12,6 +12,7 @@ linters:
- exhaustruct
- funcorder
- gochecknoglobals
- gosmopolitan # Project legitimately uses CJK text in tests (FTS5, token counting)
- godot
- intrange
- ireturn
+19
View File
@@ -349,6 +349,25 @@ build-macos-app:build-launcher
@./scripts/build-macos-app.sh $(PLATFORM)-$(ARCH)
@echo "macOS .app bundle created: $(BUILD_DIR)/PicoClaw.app"
## mem: Build membench, download LOCOMO data (if needed), run benchmark, and show results
mem:
@echo "Building membench..."
@mkdir -p $(BUILD_DIR)
@$(GO) build -o $(BUILD_DIR)/membench ./cmd/membench
@echo "Build complete: $(BUILD_DIR)/membench"
@if [ ! -f $(BUILD_DIR)/memdata/locomo10.json ]; then \
echo "Downloading LOCOMO dataset..."; \
mkdir -p $(BUILD_DIR)/memdata; \
curl -sfL "https://raw.githubusercontent.com/snap-research/locomo/main/data/locomo10.json" \
-o $(BUILD_DIR)/memdata/locomo10.json && [ -s $(BUILD_DIR)/memdata/locomo10.json ] || { echo "Error: LOCOMO download failed"; exit 1; }; \
echo "Download complete"; \
else \
echo "LOCOMO dataset already exists, skipping download"; \
fi
@echo "Running benchmark..."
@rm -rf $(BUILD_DIR)/memout
@$(BUILD_DIR)/membench run --data $(BUILD_DIR)/memdata --out $(BUILD_DIR)/memout --budget 4000
## help: Show this help message
help:
@echo "picoclaw Makefile"
BIN
View File
Binary file not shown.

Before

Width:  |  Height:  |  Size: 365 KiB

After

Width:  |  Height:  |  Size: 362 KiB

+366
View File
@@ -0,0 +1,366 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strings"
"github.com/sipeed/picoclaw/pkg/seahorse"
)
// EvalResult holds per-sample evaluation results for one mode.
type EvalResult struct {
Mode string `json:"mode"`
SampleID string `json:"sampleId"`
QAResults []QAResult `json:"qaResults"`
Agg AggMetrics `json:"aggregated"`
}
// QAResult holds metrics for a single QA pair.
type QAResult struct {
Question string `json:"question"`
Category int `json:"category"`
GoldAnswer string `json:"goldAnswer"`
TokenF1 float64 `json:"tokenF1"`
HitRate float64 `json:"hitRate"`
}
// AggMetrics holds aggregated evaluation metrics.
type AggMetrics struct {
OverallF1 float64 `json:"overallF1"`
OverallHitRate float64 `json:"overallHitRate"`
ByCategory map[int]*CatMetrics `json:"byCategory"`
TotalQuestions int `json:"totalQuestions"`
}
// CatMetrics holds metrics for a single category.
type CatMetrics struct {
F1 float64 `json:"f1"`
HitRate float64 `json:"hitRate"`
QuestionCount int `json:"questionCount"`
}
// EvalLegacy evaluates using legacy session store (raw history + budget truncation).
func EvalLegacy(
ctx context.Context,
samples []LocomoSample,
legacy *LegacyStore,
budgetTokens int,
) []EvalResult {
results := make([]EvalResult, 0, len(samples))
for si := range samples {
sample := &samples[si]
history := legacy.GetHistory(sample.SampleID)
// Convert messages to content strings
allContent := make([]string, 0, len(history))
for _, msg := range history {
allContent = append(allContent, msg.Content)
}
qaResults := make([]QAResult, 0, len(sample.QA))
for qi := range sample.QA {
qa := &sample.QA[qi]
// Budget truncate the full history
truncated, _ := BudgetTruncate(allContent, budgetTokens)
context := StringListToContent(truncated)
f1 := TokenOverlapF1(context, qa.AnswerString())
hitRate := RecallHitRate(qa.Evidence, sample, context)
qaResults = append(qaResults, QAResult{
Question: qa.Question,
Category: qa.Category,
GoldAnswer: qa.AnswerString(),
TokenF1: f1,
HitRate: hitRate,
})
}
results = append(results, EvalResult{
Mode: "legacy",
SampleID: sample.SampleID,
QAResults: qaResults,
Agg: aggregateMetrics(qaResults),
})
}
return results
}
// EvalSeahorse evaluates using seahorse short memory (per-keyword search + expand).
func EvalSeahorse(
ctx context.Context,
samples []LocomoSample,
ir *SeahorseIngestResult,
budgetTokens int,
) []EvalResult {
store := ir.Engine.GetRetrieval().Store()
retrieval := ir.Engine.GetRetrieval()
results := make([]EvalResult, 0, len(samples))
for si := range samples {
sample := &samples[si]
convID, ok := ir.ConvMap[sample.SampleID]
if !ok {
log.Printf("WARN: no conversation ID for sample %s", sample.SampleID)
continue
}
qaResults := make([]QAResult, 0, len(sample.QA))
for qi := range sample.QA {
qa := &sample.QA[qi]
keywords := ExtractKeywords(qa.Question)
// Search each keyword individually and union results,
// tracking best BM25 rank per message for relevance sorting.
bestRank := map[int64]float64{}
for _, kw := range keywords {
searchResults, err := store.SearchMessages(ctx, seahorse.SearchInput{
Pattern: kw,
ConversationID: convID,
Limit: 20,
})
if err != nil {
log.Printf("WARN: search failed for keyword %q: %v", kw, err)
continue
}
for _, sr := range searchResults {
if sr.MessageID > 0 {
if prev, ok := bestRank[sr.MessageID]; !ok || sr.Rank < prev {
bestRank[sr.MessageID] = sr.Rank
}
}
}
}
// Sort messageIDs by rank ascending (best/most-negative first).
// BudgetTruncate walks from the front, keeping best-ranked messages.
// Note: SQLite FTS5 bm25() returns negative values where more
// negative = better match.
messageIDs := make([]int64, 0, len(bestRank))
for id := range bestRank {
messageIDs = append(messageIDs, id)
}
sort.Slice(messageIDs, func(i, j int) bool {
return bestRank[messageIDs[i]] < bestRank[messageIDs[j]]
})
// Expand messages to get full content
var contentParts []string
if len(messageIDs) > 0 {
expandResult, err := retrieval.ExpandMessages(ctx, messageIDs)
if err != nil {
log.Printf("WARN: expand failed for sample %s: %v", sample.SampleID, err)
} else {
for _, msg := range expandResult.Messages {
contentParts = append(contentParts, msg.Content)
}
}
}
if len(contentParts) == 0 {
qaResults = append(qaResults, QAResult{
Question: qa.Question,
Category: qa.Category,
GoldAnswer: qa.AnswerString(),
TokenF1: 0.0,
HitRate: 0.0,
})
continue
}
// Budget truncate (drop worst-ranked)
truncated, _ := BudgetTruncate(contentParts, budgetTokens)
context := StringListToContent(truncated)
f1 := TokenOverlapF1(context, qa.AnswerString())
hitRate := RecallHitRate(qa.Evidence, sample, context)
qaResults = append(qaResults, QAResult{
Question: qa.Question,
Category: qa.Category,
GoldAnswer: qa.AnswerString(),
TokenF1: f1,
HitRate: hitRate,
})
}
results = append(results, EvalResult{
Mode: "seahorse",
SampleID: sample.SampleID,
QAResults: qaResults,
Agg: aggregateMetrics(qaResults),
})
}
return results
}
// aggregateMetrics computes overall and per-category metrics.
func aggregateMetrics(qaResults []QAResult) AggMetrics {
byCat := map[int]*CatMetrics{}
totalF1 := 0.0
totalHitRate := 0.0
for _, qr := range qaResults {
totalF1 += qr.TokenF1
totalHitRate += qr.HitRate
cat, ok := byCat[qr.Category]
if !ok {
cat = &CatMetrics{}
byCat[qr.Category] = cat
}
cat.F1 += qr.TokenF1
cat.HitRate += qr.HitRate
cat.QuestionCount++
}
n := len(qaResults)
if n == 0 {
n = 1
}
agg := AggMetrics{
OverallF1: totalF1 / float64(n),
OverallHitRate: totalHitRate / float64(n),
ByCategory: byCat,
TotalQuestions: len(qaResults),
}
for _, cat := range agg.ByCategory {
if cat.QuestionCount > 0 {
cat.F1 /= float64(cat.QuestionCount)
cat.HitRate /= float64(cat.QuestionCount)
}
}
return agg
}
// SaveResults writes per-sample eval results to JSON files.
func SaveResults(results []EvalResult, outDir string) error {
if err := os.MkdirAll(outDir, 0o755); err != nil {
return fmt.Errorf("create output dir: %w", err)
}
for _, r := range results {
path := filepath.Join(outDir, fmt.Sprintf("eval_%s_%s.json", r.Mode, r.SampleID))
data, err := json.MarshalIndent(r, "", " ")
if err != nil {
return fmt.Errorf("marshal result: %w", err)
}
if err := os.WriteFile(path, data, 0o644); err != nil {
return fmt.Errorf("write result: %w", err)
}
}
return nil
}
// SaveAggregated writes a combined results.json with all modes.
func SaveAggregated(results []EvalResult, outDir string) error {
byMode := map[string][]EvalResult{}
for _, r := range results {
byMode[r.Mode] = append(byMode[r.Mode], r)
}
aggMap := map[string]AggMetrics{}
for mode, modeResults := range byMode {
aggMap[mode] = computeModeAgg(modeResults)
}
data, err := json.MarshalIndent(aggMap, "", " ")
if err != nil {
return err
}
return os.WriteFile(filepath.Join(outDir, "results.json"), data, 0o644)
}
// computeModeAgg aggregates results for a single mode using weighted averaging
// (weighted by question count per sample). All modes must have the same Mode field.
func computeModeAgg(results []EvalResult) AggMetrics {
agg := AggMetrics{ByCategory: map[int]*CatMetrics{}}
for _, r := range results {
agg.OverallF1 += r.Agg.OverallF1 * float64(r.Agg.TotalQuestions)
agg.OverallHitRate += r.Agg.OverallHitRate * float64(r.Agg.TotalQuestions)
agg.TotalQuestions += r.Agg.TotalQuestions
for cat, cm := range r.Agg.ByCategory {
existing, ok := agg.ByCategory[cat]
if !ok {
existing = &CatMetrics{}
agg.ByCategory[cat] = existing
}
existing.F1 += cm.F1 * float64(cm.QuestionCount)
existing.HitRate += cm.HitRate * float64(cm.QuestionCount)
existing.QuestionCount += cm.QuestionCount
}
}
if agg.TotalQuestions > 0 {
agg.OverallF1 /= float64(agg.TotalQuestions)
agg.OverallHitRate /= float64(agg.TotalQuestions)
}
for _, cat := range agg.ByCategory {
if cat.QuestionCount > 0 {
cat.F1 /= float64(cat.QuestionCount)
cat.HitRate /= float64(cat.QuestionCount)
}
}
return agg
}
// printSection prints a single comparison table section.
func printSection(title string, results []EvalResult) {
fmt.Printf("\n--- %s ---\n", title)
byMode := map[string][]EvalResult{}
for _, r := range results {
byMode[r.Mode] = append(byMode[r.Mode], r)
}
modes := map[string]AggMetrics{}
for mode, modeResults := range byMode {
modes[mode] = computeModeAgg(modeResults)
}
modeKeys := make([]string, 0, len(modes))
for k := range modes {
modeKeys = append(modeKeys, k)
}
sort.Strings(modeKeys)
// Collect all category keys across modes
catSet := map[int]bool{}
for _, agg := range modes {
for cat := range agg.ByCategory {
catSet[cat] = true
}
}
cats := make([]int, 0, len(catSet))
for cat := range catSet {
cats = append(cats, cat)
}
sort.Ints(cats)
fmt.Printf("%-10s %-8s %-8s", "Mode", "HitRate", "F1")
for _, cat := range cats {
fmt.Printf(" %-7s", fmt.Sprintf("C%d", cat))
}
fmt.Println()
fmt.Println(strings.Repeat("-", 10+8+8+7*len(cats)+8))
for _, mode := range modeKeys {
agg := modes[mode]
fmt.Printf("%-10s %-8.4f %-8.4f", mode, agg.OverallHitRate, agg.OverallF1)
for _, cat := range cats {
if cm, ok := agg.ByCategory[cat]; ok {
fmt.Printf(" %-7.4f", cm.HitRate)
} else {
fmt.Printf(" %-7s", "N/A")
}
}
fmt.Println()
}
}
// PrintComparison outputs a human-readable comparison table to stdout.
func PrintComparison(results []EvalResult, llmResults []EvalResult) {
printSection("No LLM generation", results)
if len(llmResults) > 0 {
printSection("With LLM", llmResults)
}
}
+104
View File
@@ -0,0 +1,104 @@
package main
import (
"math"
"testing"
)
func TestComputeModeAggAllCategories(t *testing.T) {
results := []EvalResult{
{
Mode: "test",
SampleID: "s1",
QAResults: []QAResult{
{Category: 1, TokenF1: 0.5, HitRate: 0.8},
{Category: 2, TokenF1: 0.3, HitRate: 0.6},
{Category: 3, TokenF1: 0.1, HitRate: 0.4},
{Category: 4, TokenF1: 0.7, HitRate: 0.9},
{Category: 5, TokenF1: 0.2, HitRate: 0.1},
},
},
}
for i := range results {
results[i].Agg = aggregateMetrics(results[i].QAResults)
}
got := computeModeAgg(results)
// Should have all 5 categories
for cat := 1; cat <= 5; cat++ {
cm, ok := got.ByCategory[cat]
if !ok {
t.Errorf("ByCategory missing category %d", cat)
continue
}
if cm.QuestionCount != 1 {
t.Errorf("ByCategory[%d].QuestionCount = %d, want 1", cat, cm.QuestionCount)
}
}
// Verify specific F1 values per category
wantF1 := map[int]float64{1: 0.5, 2: 0.3, 3: 0.1, 4: 0.7, 5: 0.2}
for cat, want := range wantF1 {
if cm, ok := got.ByCategory[cat]; ok {
if math.Abs(cm.F1-want) > 1e-9 {
t.Errorf("ByCategory[%d].F1 = %.4f, want %.4f", cat, cm.F1, want)
}
}
}
}
func TestComputeModeAgg(t *testing.T) {
// Two samples with different question counts:
// sample-a: 2 questions, F1 = [0.4, 0.6] → avg 0.5
// sample-b: 8 questions, F1 = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] → avg 0.1
//
// Unweighted (PrintComparison bug): (0.5 + 0.1) / 2 = 0.3
// Weighted (correct): (0.4+0.6 + 0.1*8) / 10 = 1.8 / 10 = 0.18
results := []EvalResult{
{
Mode: "test",
SampleID: "sample-a",
QAResults: []QAResult{
{TokenF1: 0.4, HitRate: 0.5},
{TokenF1: 0.6, HitRate: 0.7},
},
},
{
Mode: "test",
SampleID: "sample-b",
QAResults: []QAResult{
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
{TokenF1: 0.1, HitRate: 0.2},
},
},
}
// Compute per-sample aggregates
for i := range results {
results[i].Agg = aggregateMetrics(results[i].QAResults)
}
got := computeModeAgg(results)
// Weighted: (0.4+0.6+0.1*8) / 10 = 1.8/10 = 0.18
wantF1 := 0.18
if math.Abs(got.OverallF1-wantF1) > 1e-9 {
t.Errorf("OverallF1 = %.6f, want %.6f (weighted average)", got.OverallF1, wantF1)
}
// Weighted: (0.5+0.7+0.2*8) / 10 = 2.8/10 = 0.28
wantRecall := 0.28
if math.Abs(got.OverallHitRate-wantRecall) > 1e-9 {
t.Errorf("OverallHitRate = %.6f, want %.6f (weighted average)", got.OverallHitRate, wantRecall)
}
if got.TotalQuestions != 10 {
t.Errorf("TotalQuestions = %d, want 10", got.TotalQuestions)
}
}
+85
View File
@@ -0,0 +1,85 @@
package main
import (
"context"
"fmt"
"log"
"github.com/sipeed/picoclaw/pkg/seahorse"
)
// ConvMap stores the mapping from sampleID to seahorse ConversationID.
type ConvMap map[string]int64
// SeahorseIngestResult holds the results of ingesting into seahorse.
type SeahorseIngestResult struct {
Engine *seahorse.Engine
ConvMap ConvMap // sampleID → conversationID
}
// IngestSeahorse loads all LOCOMO samples into a seahorse Engine.
// Returns the engine and a mapping from sampleID to conversationID for scoped retrieval.
func IngestSeahorse(ctx context.Context, samples []LocomoSample, dbPath string) (*SeahorseIngestResult, error) {
noopFn := func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
return "", nil
}
engine, err := seahorse.NewEngine(seahorse.Config{
DBPath: dbPath,
}, noopFn)
if err != nil {
return nil, fmt.Errorf("create seahorse engine: %w", err)
}
store := engine.GetRetrieval().Store()
convMap := make(ConvMap)
for si := range samples {
sample := &samples[si]
sessionKey := "locomo-" + sample.SampleID
// Check if conversation already exists (idempotent)
existing, _ := store.GetConversationBySessionKey(ctx, sessionKey)
if existing != nil {
convMap[sample.SampleID] = existing.ConversationID
log.Printf("Skipping existing sample %s: convID=%d", sample.SampleID, existing.ConversationID)
continue
}
turns := GetTurns(sample)
// Convert turns to seahorse messages
msgs := make([]seahorse.Message, 0, len(turns))
for _, turn := range turns {
content := turn.Speaker + ": " + turn.Text
msgs = append(msgs, seahorse.Message{
Role: "user",
Content: content,
TokenCount: len(turn.Text) / 4,
})
}
// Ingest all turns for this sample
_, err := engine.Ingest(ctx, sessionKey, msgs)
if err != nil {
return nil, fmt.Errorf("ingest sample %s: %w", sample.SampleID, err)
}
// Get the conversation ID for scoped retrieval
conv, err := store.GetConversationBySessionKey(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation for %s: %w", sample.SampleID, err)
}
if conv == nil {
return nil, fmt.Errorf("conversation not found for %s after ingest", sample.SampleID)
}
convMap[sample.SampleID] = conv.ConversationID
log.Printf("Ingested sample %s: %d turns, convID=%d", sample.SampleID, len(turns), conv.ConversationID)
}
log.Printf("Seahorse ingestion complete: %d samples, %d conversations", len(samples), len(convMap))
return &SeahorseIngestResult{
Engine: engine,
ConvMap: convMap,
}, nil
}
+79
View File
@@ -0,0 +1,79 @@
package main
import (
"context"
"encoding/json"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/seahorse"
)
func TestIngestSeahorseIdempotent(t *testing.T) {
ctx := context.Background()
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "test.db")
// Minimal test data
samples := []LocomoSample{
{
SampleID: "test-1",
Conversation: map[string]json.RawMessage{
"session_1": json.RawMessage(`[
{"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message"},
{"speaker":"B","dia_id":"D1:2","text":"another message for testing purposes"}
]`),
},
},
}
// First ingestion
result1, err := IngestSeahorse(ctx, samples, dbPath)
if err != nil {
t.Fatalf("first ingest failed: %v", err)
}
convCount1 := len(result1.ConvMap)
result1.Engine.Close()
// Second ingestion on same DB — should reuse existing data
result2, err := IngestSeahorse(ctx, samples, dbPath)
if err != nil {
t.Fatalf("second ingest failed: %v", err)
}
defer result2.Engine.Close()
// ConvMap should have same number of entries (no duplicates)
if len(result2.ConvMap) != convCount1 {
t.Errorf("second ingest convMap has %d entries, want %d (same as first)",
len(result2.ConvMap), convCount1)
}
// Verify conversation IDs are the same (reused, not new ones)
for id, cid1 := range result1.ConvMap {
cid2, ok := result2.ConvMap[id]
if !ok {
t.Errorf("sample %s missing from second ConvMap", id)
continue
}
if cid2 != cid1 {
t.Errorf("sample %s: second ingest got convID %d, want %d (reused)", id, cid2, cid1)
}
}
// Verify no duplicate messages by counting
store := result2.Engine.GetRetrieval().Store()
for _, convID := range result2.ConvMap {
msgs, err := store.SearchMessages(ctx, seahorse.SearchInput{
Pattern: "test",
ConversationID: convID,
Limit: 100,
})
if err != nil {
t.Fatalf("search failed: %v", err)
}
// Should find exactly 1 message containing "test" (the first turn)
if len(msgs) > 2 {
t.Errorf("found %d messages for 'test' in conv %d, expected ≤2 (no duplicates)", len(msgs), convID)
}
}
}
+34
View File
@@ -0,0 +1,34 @@
package main
import (
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
)
// LegacyStore wraps session.SessionManager for legacy baseline.
type LegacyStore struct {
sm *session.SessionManager
}
// NewLegacyStore creates a new in-memory session manager.
func NewLegacyStore() *LegacyStore {
return &LegacyStore{
sm: session.NewSessionManager(""),
}
}
// IngestSample loads all turns from a LOCOMO sample into the legacy session store.
func (ls *LegacyStore) IngestSample(sample *LocomoSample) {
sessionKey := "locomo-" + sample.SampleID
turns := GetTurns(sample)
for _, turn := range turns {
content := turn.Speaker + ": " + turn.Text
ls.sm.AddMessage(sessionKey, "user", content)
}
}
// GetHistory returns all messages for a sample's session.
func (ls *LegacyStore) GetHistory(sampleID string) []providers.Message {
sessionKey := "locomo-" + sampleID
return ls.sm.GetHistory(sessionKey)
}
+142
View File
@@ -0,0 +1,142 @@
package main
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
)
// LocomoSample represents one conversation sample from the LOCOMO dataset.
type LocomoSample struct {
SampleID string `json:"sample_id"`
Conversation map[string]json.RawMessage `json:"conversation"`
QA []LocomoQA `json:"qa"`
}
// LocomoTurn represents a single turn in a conversation.
type LocomoTurn struct {
Speaker string `json:"speaker"`
DiaID string `json:"dia_id"`
Text string `json:"text"`
}
// LocomoQA represents a question-answer pair with evidence.
type LocomoQA struct {
Question string `json:"question"`
Answer json.RawMessage `json:"answer"` // can be string or int (category 1-4)
AdversarialAnswer string `json:"adversarial_answer"` // category 5 only
Evidence []string `json:"evidence"`
Category int `json:"category"` // 1=single-hop, 2=multi-hop, 3=open-ended, 5=adversarial
}
// AnswerString returns the answer as a string, handling both string and int types.
func (qa *LocomoQA) AnswerString() string {
// Prefer answer field (category 1-4)
if len(qa.Answer) > 0 {
var s string
if err := json.Unmarshal(qa.Answer, &s); err == nil {
return s
}
var n json.Number
if err := json.Unmarshal(qa.Answer, &n); err == nil {
return n.String()
}
return strings.Trim(string(qa.Answer), `"`)
}
// Fallback to adversarial_answer (category 5)
return qa.AdversarialAnswer
}
// LoadDataset reads all JSON files from dataDir and returns parsed samples.
func LoadDataset(dataDir string) ([]LocomoSample, error) {
entries, err := os.ReadDir(dataDir)
if err != nil {
return nil, fmt.Errorf("read data dir %s: %w", dataDir, err)
}
var samples []LocomoSample
for _, entry := range entries {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".json") {
path := filepath.Join(dataDir, entry.Name())
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read file %s: %w", path, err)
}
var batch []LocomoSample
if err := json.Unmarshal(data, &batch); err != nil {
return nil, fmt.Errorf("parse file %s: %w", path, err)
}
samples = append(samples, batch...)
}
}
return samples, nil
}
// GetSessionNames returns sorted session keys (session_1, session_2, ...) from conversation.
func GetSessionNames(conv map[string]json.RawMessage) []string {
var names []string
for k := range conv {
if strings.HasPrefix(k, "session_") && !strings.Contains(k, "_date_time") {
names = append(names, k)
}
}
sort.Slice(names, func(i, j int) bool {
ni := sessionNum(names[i])
nj := sessionNum(names[j])
return ni < nj
})
return names
}
func sessionNum(key string) int {
// "session_1" → 1, "session_10" → 10
parts := strings.SplitN(key, "_", 2)
if len(parts) < 2 {
return 0
}
n, _ := strconv.Atoi(parts[1])
return n
}
// GetTurns flattens all sessions' turns in chronological order.
func GetTurns(sample *LocomoSample) []LocomoTurn {
names := GetSessionNames(sample.Conversation)
var all []LocomoTurn
for _, name := range names {
raw, ok := sample.Conversation[name]
if !ok {
continue
}
var turns []LocomoTurn
if err := json.Unmarshal(raw, &turns); err != nil {
log.Printf("WARNING: unmarshal failed for session %q in sample %s: %v", name, sample.SampleID, err)
continue
}
all = append(all, turns...)
}
return all
}
// GetTurnByDiaID finds a specific turn by dia_id (e.g. "D1:3").
func GetTurnByDiaID(sample *LocomoSample, diaID string) *LocomoTurn {
turns := GetTurns(sample)
for i := range turns {
if turns[i].DiaID == diaID {
return &turns[i]
}
}
return nil
}
// GetSpeakers returns the two speaker names from conversation metadata.
func GetSpeakers(conv map[string]json.RawMessage) (string, string) {
var a, b string
json.Unmarshal(conv["speaker_a"], &a)
json.Unmarshal(conv["speaker_b"], &b)
return a, b
}
+67
View File
@@ -0,0 +1,67 @@
package main
import (
"encoding/json"
"testing"
)
func TestAnswerString(t *testing.T) {
tests := []struct {
name string
json string
want string
}{
{
"string answer",
`{"question":"Q","answer":"Paris","evidence":[],"category":1}`,
"Paris",
},
{
"int answer",
`{"question":"Q","answer":42,"evidence":[],"category":1}`,
"42",
},
{
"adversarial answer (category 5)",
`{"question":"Q","evidence":[],"category":5,"adversarial_answer":"self-care is important"}`,
"self-care is important",
},
{
"both answer and adversarial_answer present",
`{"question":"Q","answer":"normal","evidence":[],"category":5,"adversarial_answer":"adversarial"}`,
"normal",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var qa LocomoQA
if err := json.Unmarshal([]byte(tt.json), &qa); err != nil {
t.Fatalf("unmarshal: %v", err)
}
got := qa.AnswerString()
if got != tt.want {
t.Errorf("AnswerString() = %q, want %q", got, tt.want)
}
})
}
}
func TestGetSessionNames(t *testing.T) {
conv := map[string]json.RawMessage{
"session_2": {},
"session_1": {},
"session_10": {},
"session_1_date_time": {},
"speaker_a": {},
}
names := GetSessionNames(conv)
want := []string{"session_1", "session_2", "session_10"}
if len(names) != len(want) {
t.Fatalf("got %v, want %v", names, want)
}
for i, n := range names {
if n != want[i] {
t.Errorf("names[%d] = %q, want %q", i, n, want[i])
}
}
}
+208
View File
@@ -0,0 +1,208 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/spf13/cobra"
"github.com/sipeed/picoclaw/pkg/logger"
)
var (
flagData string
flagOut string
flagMode string
flagBudget int
)
func main() {
// Suppress seahorse INFO logs during benchmark
logger.SetLevel(logger.WARN)
rootCmd := &cobra.Command{
Use: "membench",
Short: "Memory benchmark tool for picoclaw",
}
ingestCmd := &cobra.Command{
Use: "ingest",
Short: "Load LOCOMO data into storage backends",
RunE: runIngest,
}
ingestCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
ingestCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
ingestCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to ingest: legacy, seahorse, or all")
evalCmd := &cobra.Command{
Use: "eval",
Short: "Run QA evaluation against ingested data",
RunE: runEval,
}
evalCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
evalCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
evalCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to evaluate: legacy, seahorse, or all")
evalCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
reportCmd := &cobra.Command{
Use: "report",
Short: "Output comparison results from evaluation",
RunE: runReport,
}
reportCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
runCmd := &cobra.Command{
Use: "run",
Short: "Convenience: eval + report (ingestion is done inline)",
RunE: runAll,
}
runCmd.Flags().StringVar(&flagData, "data", "", "LOCOMO dataset directory (required)")
runCmd.Flags().StringVar(&flagOut, "out", "./bench-out", "output working directory")
runCmd.Flags().StringVar(&flagMode, "mode", "all", "modes to run: legacy, seahorse, or all")
runCmd.Flags().IntVar(&flagBudget, "budget", 4000, "token budget for retrieval")
rootCmd.AddCommand(ingestCmd, evalCmd, reportCmd, runCmd)
if err := rootCmd.Execute(); err != nil {
os.Exit(1)
}
}
func modesFromFlag() []string {
switch strings.ToLower(flagMode) {
case "all":
return []string{"legacy", "seahorse"}
default:
return []string{strings.ToLower(flagMode)}
}
}
func runIngest(cmd *cobra.Command, args []string) error {
if flagData == "" {
return fmt.Errorf("--data is required")
}
modes := modesFromFlag()
if len(modes) == 0 {
return nil
}
ctx := context.Background()
samples, err := LoadDataset(flagData)
if err != nil {
return fmt.Errorf("load dataset: %w", err)
}
log.Printf("Loaded %d samples from %s", len(samples), flagData)
for _, mode := range modes {
switch mode {
case "legacy":
legacy := NewLegacyStore()
for i := range samples {
legacy.IngestSample(&samples[i])
}
log.Printf("legacy: ingested %d samples", len(samples))
case "seahorse":
dbPath := filepath.Join(flagOut, "seahorse.db")
if err := os.MkdirAll(flagOut, 0o755); err != nil {
return fmt.Errorf("create out dir: %w", err)
}
_, err := IngestSeahorse(ctx, samples, dbPath)
if err != nil {
return fmt.Errorf("ingest seahorse: %w", err)
}
}
}
return nil
}
func runEval(cmd *cobra.Command, args []string) error {
if flagData == "" {
return fmt.Errorf("--data is required")
}
modes := modesFromFlag()
if len(modes) == 0 {
return nil
}
ctx := context.Background()
samples, err := LoadDataset(flagData)
if err != nil {
return fmt.Errorf("load dataset: %w", err)
}
log.Printf("Loaded %d samples", len(samples))
var allResults []EvalResult
for _, mode := range modes {
switch mode {
case "legacy":
legacy := NewLegacyStore()
for i := range samples {
legacy.IngestSample(&samples[i])
}
results := EvalLegacy(ctx, samples, legacy, flagBudget)
allResults = append(allResults, results...)
log.Printf("legacy: evaluated %d samples", len(results))
case "seahorse":
dbPath := filepath.Join(flagOut, "seahorse.db")
ir, err := IngestSeahorse(ctx, samples, dbPath)
if err != nil {
return fmt.Errorf("ingest seahorse: %w", err)
}
results := EvalSeahorse(ctx, samples, ir, flagBudget)
allResults = append(allResults, results...)
log.Printf("seahorse: evaluated %d samples", len(results))
}
}
if err := SaveResults(allResults, flagOut); err != nil {
return fmt.Errorf("save results: %w", err)
}
if err := SaveAggregated(allResults, flagOut); err != nil {
return fmt.Errorf("save aggregated: %w", err)
}
PrintComparison(allResults, nil)
return nil
}
func runReport(cmd *cobra.Command, args []string) error {
entries, err := os.ReadDir(flagOut)
if err != nil {
return fmt.Errorf("read out dir: %w", err)
}
var allResults []EvalResult
for _, entry := range entries {
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "eval_") && strings.HasSuffix(entry.Name(), ".json") {
path := filepath.Join(flagOut, entry.Name())
var r EvalResult
data, err := os.ReadFile(path)
if err != nil {
log.Printf("WARN: read %s: %v", path, err)
continue
}
if err := json.Unmarshal(data, &r); err != nil {
log.Printf("WARN: parse %s: %v", path, err)
continue
}
allResults = append(allResults, r)
}
}
if len(allResults) == 0 {
return fmt.Errorf("no eval results found in %s", flagOut)
}
PrintComparison(allResults, nil)
return nil
}
func runAll(cmd *cobra.Command, args []string) error {
return runEval(cmd, args)
}
+227
View File
@@ -0,0 +1,227 @@
package main
import (
"fmt"
"log"
"regexp"
"strconv"
"strings"
"unicode"
)
// diaIDRe matches valid dia_id patterns like "D1:3", "D30:5".
var diaIDRe = regexp.MustCompile(`^D(\d+):(\d+)$`)
// SplitEvidenceIDs splits an evidence string that may contain multiple
// semicolon-separated or space-separated dia_ids. Only returns valid IDs.
// Example: "D8:6; D9:17" → ["D8:6", "D9:17"]
// Example: "D9:1 D4:4 D4:6" → ["D9:1", "D4:4", "D4:6"]
func SplitEvidenceIDs(evidence string) []string {
if evidence == "" {
return nil
}
// Split on semicolons first, then spaces
parts := strings.Split(evidence, ";")
var ids []string
for _, part := range parts {
for _, token := range strings.Fields(strings.TrimSpace(part)) {
token = strings.TrimSpace(token)
if diaIDRe.MatchString(token) {
ids = append(ids, NormalizeDiaID(token))
}
}
}
if len(ids) == 0 {
return nil
}
return ids
}
// NormalizeDiaID strips leading zeros from the number parts of a dia_id.
// "D30:05" → "D30:5", "D10:003" → "D10:3"
func NormalizeDiaID(id string) string {
m := diaIDRe.FindStringSubmatch(id)
if m == nil {
return id
}
session, _ := strconv.Atoi(m[1])
turn, _ := strconv.Atoi(m[2])
return fmt.Sprintf("D%d:%d", session, turn)
}
// stopwords is a fixed English stopword list for deterministic keyword extraction.
var stopwords = map[string]struct{}{
"a": {}, "an": {}, "the": {},
"is": {}, "are": {}, "was": {}, "were": {},
"did": {}, "does": {}, "do": {},
"when": {}, "where": {}, "what": {}, "who": {},
"how": {}, "why": {},
"to": {}, "of": {}, "in": {}, "on": {}, "at": {},
"for": {}, "and": {}, "or": {}, "but": {}, "not": {},
"it": {}, "this": {}, "that": {}, "with": {},
"from": {}, "by": {}, "as": {},
"if": {}, "then": {}, "than": {}, "so": {},
"no": {}, "yes": {},
"all": {}, "any": {}, "each": {}, "every": {},
"some": {}, "such": {},
"about": {}, "into": {}, "over": {},
"after": {}, "before": {}, "between": {},
"through": {}, "during": {}, "until": {},
"would": {}, "could": {}, "should": {},
"may": {}, "might": {}, "can": {},
"will": {}, "shall": {}, "must": {},
"have": {}, "has": {}, "had": {},
"been": {}, "being": {}, "be": {},
"go": {}, "went": {}, "gone": {},
"i": {}, "you": {}, "me": {}, "my": {}, "your": {},
"we": {}, "they": {}, "them": {}, "our": {},
"its": {}, "their": {}, "he": {}, "she": {},
"his": {}, "her": {},
}
// ExtractKeywords removes stopwords and punctuation, returns individual keywords.
// Deterministic: uses fixed stopword list, no LLM.
func ExtractKeywords(question string) []string {
// Lowercase and split on whitespace/punctuation
lower := strings.ToLower(question)
words := strings.FieldsFunc(lower, func(r rune) bool {
return !unicode.IsLetter(r) && !unicode.IsDigit(r)
})
var keywords []string
for _, w := range words {
if w == "" || len(w) < 2 {
continue
}
if _, ok := stopwords[w]; ok {
continue
}
keywords = append(keywords, w)
if len(keywords) >= 6 {
break
}
}
return keywords
}
// TokenOverlapF1 computes token-level F1 between prediction and reference.
// Both strings are lowercased and split on whitespace.
// NOTE: This metric underestimates quality for multi-hop (cat 2) and
// open-ended (cat 3) questions where the gold answer uses different phrasing
// than the source text. LLM-Judge scoring is a v2 follow-up.
func TokenOverlapF1(prediction, reference string) float64 {
predTokens := tokenize(prediction)
refTokens := tokenize(reference)
if len(predTokens) == 0 && len(refTokens) == 0 {
return 1.0
}
if len(predTokens) == 0 || len(refTokens) == 0 {
return 0.0
}
// Count matches
refCount := map[string]int{}
for _, t := range refTokens {
refCount[t]++
}
predCount := map[string]int{}
for _, t := range predTokens {
predCount[t]++
}
var matches float64
for token, pc := range predCount {
if rc, ok := refCount[token]; ok {
matches += float64(min(pc, rc))
}
}
precision := matches / float64(len(predTokens))
recall := matches / float64(len(refTokens))
if precision+recall == 0 {
return 0.0
}
return 2 * precision * recall / (precision + recall)
}
func tokenize(s string) []string {
lower := strings.ToLower(s)
return strings.Fields(lower)
}
// RecallHitRate computes fraction of evidence IDs found in retrieved content.
// For each evidence dia_id, looks up the turn text and checks substring match.
// Logs a warning for turns with text < 20 chars (higher false-positive risk).
func RecallHitRate(evidenceIDs []string, sample *LocomoSample, retrievedContent string) float64 {
if len(evidenceIDs) == 0 {
return 1.0 // no evidence required = perfect
}
// Expand any multi-ID evidence entries (e.g. "D8:6; D9:17" or "D9:1 D4:4")
var expanded []string
for _, id := range evidenceIDs {
split := SplitEvidenceIDs(id)
if split != nil {
expanded = append(expanded, split...)
}
}
if len(expanded) == 0 {
log.Printf("WARNING: no valid dia_ids after expanding evidence %v", evidenceIDs)
return float64(0) / float64(len(evidenceIDs))
}
// Build turn index once (avoids re-parsing JSON per ID)
turns := GetTurns(sample)
turnMap := make(map[string]*LocomoTurn, len(turns))
for i := range turns {
turnMap[turns[i].DiaID] = &turns[i]
}
lowerRetrieved := strings.ToLower(retrievedContent)
found := 0
resolvable := 0
for _, diaID := range expanded {
turn, ok := turnMap[diaID]
if !ok {
log.Printf("WARNING: dia_id %q not found in sample %s", diaID, sample.SampleID)
continue
}
resolvable++
if len(turn.Text) < 20 {
log.Printf("WARNING: short turn text (%d chars) for dia_id %s: %q",
len(turn.Text), diaID, turn.Text)
}
if strings.Contains(lowerRetrieved, strings.ToLower(turn.Text)) {
found++
}
}
if resolvable == 0 {
return 0.0 // no resolvable evidence = can't evaluate
}
return float64(found) / float64(resolvable)
}
// BudgetTruncate truncates messages to fit within a token budget.
// Returns the truncated messages and total token count.
func BudgetTruncate(messages []string, budgetTokens int) ([]string, int) {
var result []string
total := 0
// Walk from the front (best first) and keep until budget exhausted.
for i := 0; i < len(messages); i++ {
tokens := len(messages[i]) / 4
if total+tokens > budgetTokens && len(result) > 0 {
break
}
result = append(result, messages[i])
total += tokens
}
return result, total
}
// StringListToContent joins a list of strings into a single content string.
func StringListToContent(parts []string) string {
return strings.Join(parts, "\n")
}
+239
View File
@@ -0,0 +1,239 @@
package main
import (
"encoding/json"
"math"
"testing"
)
func TestSplitEvidenceIDs(t *testing.T) {
tests := []struct {
input string
want []string
}{
{"D1:3", []string{"D1:3"}},
{"D8:6; D9:17", []string{"D8:6", "D9:17"}},
{"D9:1 D4:4 D4:6", []string{"D9:1", "D4:4", "D4:6"}},
{"D22:1 D22:2 D9:10 D9:11", []string{"D22:1", "D22:2", "D9:10", "D9:11"}},
{"D21:18 D21:22 D11:15 D11:19", []string{"D21:18", "D21:22", "D11:15", "D11:19"}},
{"D30:05", []string{"D30:5"}},
{"D", nil},
{"D:", nil},
{"", nil},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := SplitEvidenceIDs(tt.input)
if len(got) != len(tt.want) {
t.Fatalf("SplitEvidenceIDs(%q) = %v, want %v", tt.input, got, tt.want)
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
}
}
})
}
}
func TestNormalizeDiaID(t *testing.T) {
tests := []struct {
input string
want string
}{
{"D1:3", "D1:3"},
{"D30:05", "D30:5"},
{"D10:003", "D10:3"},
{"D1:0", "D1:0"},
}
for _, tt := range tests {
got := NormalizeDiaID(tt.input)
if got != tt.want {
t.Errorf("NormalizeDiaID(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestTokenOverlapF1(t *testing.T) {
tests := []struct {
name string
prediction string
reference string
want float64
}{
{"exact match", "hello world", "hello world", 1.0},
{"no overlap", "foo bar", "baz qux", 0.0},
{"empty both", "", "", 1.0},
{"empty prediction", "", "hello", 0.0},
{"empty reference", "hello", "", 0.0},
{"partial overlap", "the cat sat on the mat", "the cat on the floor", 8.0 / 11.0},
{"case insensitive", "Hello World", "hello world", 1.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := TokenOverlapF1(tt.prediction, tt.reference)
if math.Abs(got-tt.want) > 1e-9 {
t.Errorf("TokenOverlapF1(%q, %q) = %.4f, want %.4f",
tt.prediction, tt.reference, got, tt.want)
}
})
}
}
func TestBudgetTruncate(t *testing.T) {
t.Run("within budget returns all", func(t *testing.T) {
msgs := []string{"short", "message", "here"}
result, total := BudgetTruncate(msgs, 1000)
if len(result) != 3 {
t.Errorf("expected 3 messages, got %d", len(result))
}
if total == 0 {
t.Error("expected non-zero token count")
}
})
t.Run("over budget keeps best first", func(t *testing.T) {
msgs := []string{
"best message that is quite long and takes up tokens",
"good message also fairly long content",
"worst short",
}
result, _ := BudgetTruncate(msgs, 5) // very small budget
if len(result) == 0 {
t.Fatal("expected at least one message")
}
// Best-ranked (first) should be kept
if result[0] != "best message that is quite long and takes up tokens" {
t.Errorf("expected best message kept first, got %q", result[0])
}
})
t.Run("over budget keeps best ranked first", func(t *testing.T) {
// Messages are sorted by bm25 rank ascending (best/most-negative first).
// When budget is insufficient, BudgetTruncate must keep the front
// (best-ranked) messages, not the tail (worst-ranked).
msgs := []string{
"best ranked message with some content here",
"second best message also has content",
"third message here too",
"worst ranked short",
}
// Budget only fits ~1 message (~10 tokens per message, budget=12)
result, _ := BudgetTruncate(msgs, 12)
if len(result) == 0 {
t.Fatal("expected at least one message")
}
if result[0] != "best ranked message with some content here" {
t.Errorf("expected best-ranked (first) message kept, got %q", result[0])
}
// Worst-ranked (last) must NOT appear
for _, m := range result {
if m == "worst ranked short" {
t.Error("worst-ranked message should have been truncated")
}
}
})
t.Run("preserves original order", func(t *testing.T) {
msgs := []string{"alpha", "beta", "gamma"}
result, _ := BudgetTruncate(msgs, 100)
for i, got := range result {
if got != msgs[i] {
t.Errorf("result[%d] = %q, want %q", i, got, msgs[i])
}
}
})
t.Run("empty input", func(t *testing.T) {
result, total := BudgetTruncate(nil, 100)
if len(result) != 0 {
t.Errorf("expected 0 messages, got %d", len(result))
}
if total != 0 {
t.Errorf("expected 0 tokens, got %d", total)
}
})
}
func TestRecallHitRate(t *testing.T) {
// Build a sample with known turns
sample := &LocomoSample{
SampleID: "test-sample",
Conversation: map[string]json.RawMessage{
"session_1": json.RawMessage(`[
{"speaker":"A","dia_id":"D1:1","text":"hello world this is a test message with enough length"},
{"speaker":"B","dia_id":"D1:2","text":"another message for testing recall computation purposes here"},
{"speaker":"A","dia_id":"D1:3","text":"third turn with some more content to test"}
]`),
},
}
t.Run("all evidence found", func(t *testing.T) {
retrieved := "hello world this is a test message with enough length another message for testing recall computation purposes here"
got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
if math.Abs(got-1.0) > 1e-9 {
t.Errorf("RecallHitRate all found = %.4f, want 1.0", got)
}
})
t.Run("partial evidence found", func(t *testing.T) {
retrieved := "hello world this is a test message with enough length"
got := RecallHitRate([]string{"D1:1", "D1:2"}, sample, retrieved)
if math.Abs(got-0.5) > 1e-9 {
t.Errorf("RecallHitRate partial = %.4f, want 0.5", got)
}
})
t.Run("no evidence required", func(t *testing.T) {
got := RecallHitRate(nil, sample, "anything")
if got != 1.0 {
t.Errorf("RecallHitRate no evidence = %.4f, want 1.0", got)
}
})
t.Run("missing turn excluded from denominator", func(t *testing.T) {
// D1:1 is found, D99:1 does not exist in sample
// Should only count resolvable turns in denominator
retrieved := "hello world this is a test message with enough length"
got := RecallHitRate([]string{"D1:1", "D99:1"}, sample, retrieved)
if math.Abs(got-1.0) > 1e-9 {
t.Errorf("RecallHitRate missing turn = %.4f, want 1.0 (unresolvable excluded)", got)
}
})
}
func TestExtractKeywords(t *testing.T) {
tests := []struct {
name string
input string
want []string
}{
{"simple", "What is the capital of France", []string{"capital", "france"}},
{
"stops removed",
"Who is the president of the United States",
[]string{"president", "united", "states"},
},
{
"max 6 keywords",
"one two three four five six seven eight nine ten",
[]string{"one", "two", "three", "four", "five", "six"},
},
{"short words filtered", "I am a go to the store", []string{"am", "store"}},
{"empty", "", nil},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ExtractKeywords(tt.input)
if len(got) != len(tt.want) {
t.Fatalf("ExtractKeywords(%q) = %v (len %d), want %v (len %d)",
tt.input, got, len(got), tt.want, len(tt.want))
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("[%d] = %q, want %q", i, got[i], tt.want[i])
}
}
})
}
}
+1 -1
View File
@@ -29,7 +29,7 @@ import (
)
func NewPicoclawCommand() *cobra.Command {
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, config.GetVersion())
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant %s\n\n", internal.Logo, config.GetVersion())
cmd := &cobra.Command{
Use: "picoclaw",
+1 -1
View File
@@ -17,7 +17,7 @@ func TestNewPicoclawCommand(t *testing.T) {
require.NotNil(t, cmd)
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant v%s\n\n", internal.Logo, config.GetVersion())
short := fmt.Sprintf("%s picoclaw - Personal AI Assistant %s\n\n", internal.Logo, config.GetVersion())
assert.Equal(t, "picoclaw", cmd.Use)
assert.Equal(t, short, cmd.Short)
+1 -1
View File
@@ -9,4 +9,4 @@ COPY $TARGETPLATFORM/picoclaw-launcher /usr/local/bin/picoclaw-launcher
COPY $TARGETPLATFORM/picoclaw-launcher-tui /usr/local/bin/picoclaw-launcher-tui
ENTRYPOINT ["picoclaw-launcher"]
CMD ["-public", "-no-browser"]
CMD ["-console", "-public", "-no-browser"]
+5 -2
View File
@@ -45,8 +45,11 @@ services:
- launcher
environment:
- PICOCLAW_GATEWAY_HOST=0.0.0.0
# Set a fixed dashboard token instead of a random one each restart.
# If not set, a random token is generated and printed to the console on startup.
#- PICOCLAW_LAUNCHER_TOKEN=your-secret-token-here
ports:
- "127.0.0.1:18800:18800"
- "127.0.0.1:18790:18790"
- "18800:18800"
- "18790:18790"
volumes:
- ./data:/root/.picoclaw
+63
View File
@@ -28,6 +28,69 @@ The currently exposed synchronous hook points are:
Everything else is exposed as read-only events.
## Hook Actions
Hooks can return different actions to control the flow:
| Action | Applicable Stages | Effect |
| --- | --- | --- |
| `continue` | All interceptors | Pass through without modification |
| `modify` | `before_llm`, `after_llm`, `before_tool`, `after_tool` | Modify request/response and continue |
| `respond` | `before_tool` | Return a tool result directly, skip actual tool execution |
| `deny_tool` | `before_tool` | Deny tool execution, return error message |
| `abort_turn` | All interceptors | Abort the current turn |
| `hard_abort` | All interceptors | Force stop the entire agent loop |
### The `respond` Action
The `respond` action is special: it allows a `before_tool` hook to provide the tool result directly, skipping the actual tool execution. This is useful for:
1. **Plugin tool injection**: External hooks can implement tools without registering them in the tool registry
2. **Tool result caching**: Return cached results for repeated tool calls
3. **Tool mocking**: Return mock results for testing purposes
When a hook returns `respond` with a `HookResult`, the agent loop:
1. Skips the actual tool execution
2. Uses the provided result as if the tool had executed
3. Continues the turn normally with the result
Example (Go in-process hook):
```go
func (h *MyHook) BeforeTool(
ctx context.Context,
call *agent.ToolCallHookRequest,
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
if call.Tool == "my_plugin_tool" {
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: "Plugin tool executed successfully",
Silent: false,
IsError: false,
}
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
}
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
}
```
Example (Python process hook):
```python
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
if tool == "my_plugin_tool":
return {
"action": "respond",
"result": {
"for_llm": "Plugin tool executed successfully",
"silent": False,
"is_error": False
}
}
return {"action": "continue"}
```
## Execution Order
`HookManager` sorts hooks like this:
+63
View File
@@ -28,6 +28,69 @@
其余 lifecycle 通过事件形式只读暴露。
## Hook Actions
Hook 可以返回不同的 action 来控制流程:
| Action | 适用阶段 | 效果 |
| --- | --- | --- |
| `continue` | 所有拦截型 | 放行,不做修改 |
| `modify` | `before_llm`, `after_llm`, `before_tool`, `after_tool` | 改写请求/响应后放行 |
| `respond` | `before_tool` | 直接返回工具结果,跳过实际工具执行 |
| `deny_tool` | `before_tool` | 拒绝工具执行,返回错误信息 |
| `abort_turn` | 所有拦截型 | 中止当前 turn |
| `hard_abort` | 所有拦截型 | 强制终止整个 agent loop |
### `respond` Action
`respond` action 是特殊的:它允许 `before_tool` hook 直接提供工具结果,跳过实际工具执行。适用于:
1. **插件工具注入**:外部 hook 可以实现工具,无需在 ToolRegistry 注册
2. **工具结果缓存**:对重复调用返回缓存结果
3. **工具模拟**:测试时返回模拟结果
当 hook 返回 `respond` 并携带 `HookResult` 时,agent loop 会:
1. 跳过实际工具执行
2. 使用提供的结果作为工具执行结果
3. 正常继续 turn 流程
示例(Go 进程内 hook):
```go
func (h *MyHook) BeforeTool(
ctx context.Context,
call *agent.ToolCallHookRequest,
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
if call.Tool == "my_plugin_tool" {
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: "Plugin tool executed successfully",
Silent: false,
IsError: false,
}
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
}
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
}
```
示例(Python process hook):
```python
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
if tool == "my_plugin_tool":
return {
"action": "respond",
"result": {
"for_llm": "Plugin tool executed successfully",
"silent": False,
"is_error": False
}
}
return {"action": "continue"}
```
## 执行顺序
HookManager 的排序规则是:
+568
View File
@@ -0,0 +1,568 @@
# Hook JSON-RPC Protocol Details
All hooks use `JSON-RPC 2.0` format, with one JSON message per line, transmitted via stdio.
---
## Basic Protocol Structure
### Request (PicoClaw → Hook)
```json
{"jsonrpc":"2.0","id":1,"method":"hook.xxx","params":{...}}
```
### Response (Hook → PicoClaw)
Success:
```json
{"jsonrpc":"2.0","id":1,"result":{...}}
```
Error:
```json
{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"error message"}}
```
---
## 1. `hook.hello` (Handshake)
Handshake must be completed at startup, otherwise the hook process will be terminated.
### Request
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "hook.hello",
"params": {
"name": "py_review_gate",
"version": 1,
"modes": ["observe", "tool", "approve"]
}
}
```
| Field | Description |
|-------|-------------|
| `name` | hook name (from configuration) |
| `version` | protocol version, currently `1` |
| `modes` | capability modes supported by the hook |
### Response
```json
{
"jsonrpc": "2.0",
"id": 1,
"result": {
"ok": true,
"name": "python-review-gate"
}
}
```
---
## 2. `hook.before_llm`
Triggered before sending request to LLM. Can be used to inject tools.
### Request
```json
{
"jsonrpc": "2.0",
"id": 2,
"method": "hook.before_llm",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"ParentTurnID": "",
"SessionKey": "session-1",
"Iteration": 0,
"TracePath": "runTurn",
"Source": "turn.llm.request"
},
"model": "claude-sonnet",
"messages": [
{"role": "user", "content": "hello"}
],
"tools": [
{
"type": "function",
"function": {
"name": "echo",
"description": "echo text",
"parameters": {"type": "object"}
}
}
],
"options": {
"temperature": 0.7
},
"channel": "cli",
"chat_id": "chat-1",
"graceful_terminal": false
}
}
```
| Field | Description |
|-------|-------------|
| `meta` | event metadata for tracing |
| `model` | requested model name |
| `messages` | conversation history |
| `tools` | list of available tool definitions |
| `options` | LLM parameters (temperature, max_tokens, etc.) |
| `channel` | request source channel |
| `chat_id` | session ID |
### Response (Tool Injection Example)
```json
{
"jsonrpc": "2.0",
"id": 2,
"result": {
"action": "modify",
"request": {
"model": "claude-sonnet",
"messages": [{"role": "user", "content": "hello"}],
"tools": [
{
"type": "function",
"function": {
"name": "echo",
"description": "echo",
"parameters": {}
}
},
{
"type": "function",
"function": {
"name": "my_plugin_tool",
"description": "Plugin injected tool",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"}
}
}
}
}
]
}
}
}
```
| Field | Description |
|-------|-------------|
| `action` | decision action (see table below) |
| `request` | modified request object |
---
## 3. `hook.after_llm`
Triggered after receiving LLM response. Can modify response content.
### Request
```json
{
"jsonrpc": "2.0",
"id": 3,
"method": "hook.after_llm",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"model": "claude-sonnet",
"response": {
"role": "assistant",
"content": "Hi!",
"tool_calls": [
{
"id": "tc-1",
"type": "function",
"function": {
"name": "echo",
"arguments": "{\"text\":\"hi\"}"
}
}
]
},
"channel": "cli",
"chat_id": "chat-1"
}
}
```
### Response
```json
{
"jsonrpc": "2.0",
"id": 3,
"result": {
"action": "continue"
}
}
```
---
## 4. `hook.before_tool`
Triggered before tool execution. Can modify tool name and arguments, deny execution, or return result directly.
### Request
```json
{
"jsonrpc": "2.0",
"id": 4,
"method": "hook.before_tool",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"tool": "echo_text",
"arguments": {
"text": "hello"
},
"channel": "cli",
"chat_id": "chat-1"
}
}
```
| Field | Description |
|-------|-------------|
| `tool` | tool name |
| `arguments` | tool arguments |
### Response (Modify Arguments)
```json
{
"jsonrpc": "2.0",
"id": 4,
"result": {
"action": "modify",
"call": {
"tool": "echo_text",
"arguments": {
"text": "modified hello"
}
}
}
}
```
### Response (Deny Execution)
```json
{
"jsonrpc": "2.0",
"id": 4,
"result": {
"action": "deny_tool",
"reason": "Invalid arguments"
}
}
```
### Response (Return Result Directly - respond)
```json
{
"jsonrpc": "2.0",
"id": 4,
"result": {
"action": "respond",
"call": {
"tool": "my_plugin_tool",
"arguments": {
"query": "hello"
}
},
"result": {
"for_llm": "Plugin tool executed successfully",
"for_user": "",
"silent": false,
"is_error": false
}
}
}
```
The `respond` action allows hooks to return tool results directly, skipping actual tool execution. Use cases:
1. **Plugin tool injection**: External hooks can implement tools without registering in ToolRegistry
2. **Tool result caching**: Return cached results for repeated calls
3. **Tool mocking**: Return mock results during testing
| Field | Description |
|-------|-------------|
| `action` | must be `respond` |
| `call` | modified call information (optional) |
| `result` | tool result to return directly |
---
## 5. `hook.after_tool`
Triggered after tool execution completes. Can modify the result returned to LLM.
### Request
```json
{
"jsonrpc": "2.0",
"id": 5,
"method": "hook.after_tool",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"tool": "echo_text",
"arguments": {
"text": "hello"
},
"result": {
"for_llm": "echoed: hello",
"for_user": "",
"silent": false,
"is_error": false,
"async": false,
"media": [],
"artifact_tags": [],
"response_handled": false
},
"duration": 15000000,
"channel": "cli",
"chat_id": "chat-1"
}
}
```
| Field | Description |
|-------|-------------|
| `result.for_llm` | content returned to LLM |
| `result.for_user` | content sent to user |
| `result.silent` | whether silent (not sent to user) |
| `result.is_error` | whether it's an error |
| `result.async` | whether executed asynchronously |
| `result.media` | list of media references |
| `result.artifact_tags` | local artifact path tags |
| `result.response_handled` | whether response has been handled |
| `duration` | execution time (nanoseconds) |
### Response
```json
{
"jsonrpc": "2.0",
"id": 5,
"result": {
"action": "continue"
}
}
```
---
## 6. `hook.approve_tool`
Approval hook for deciding whether to allow execution of sensitive tools.
### Request
```json
{
"jsonrpc": "2.0",
"id": 6,
"method": "hook.approve_tool",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"tool": "bash",
"arguments": {
"command": "rm -rf /"
},
"channel": "cli",
"chat_id": "chat-1"
}
}
```
### Response (Approved)
```json
{
"jsonrpc": "2.0",
"id": 6,
"result": {
"approved": true
}
}
```
### Response (Denied)
```json
{
"jsonrpc": "2.0",
"id": 6,
"result": {
"approved": false,
"reason": "Dangerous command, execution denied"
}
}
```
---
## 7. `hook.event` (notification)
Observer event, broadcast only, no response required. `id` is `0` or absent.
```json
{
"jsonrpc": "2.0",
"method": "hook.event",
"params": {
"Kind": "tool_exec_start",
"Meta": {
"AgentID": "agent-1",
"TurnID": "turn-1"
},
"Payload": {
"Tool": "echo_text",
"Arguments": {"text": "hello"}
}
}
}
```
Common `Kind` values:
- `turn_start` / `turn_end`
- `llm_request` / `llm_response`
- `tool_exec_start` / `tool_exec_end` / `tool_exec_skipped`
- `steering_injected`
- `interrupt_received`
- `error`
---
## Action Options
| action | Applicable hooks | Effect |
|--------|-----------------|--------|
| `continue` | All interceptor types | Pass through without modification |
| `modify` | `before_llm`, `before_tool`, `after_llm`, `after_tool` | Modify request/response and pass through |
| `respond` | `before_tool` | Return tool result directly, skip actual execution. **Note: AfterTool is NOT called (design decision - respond provides final answer).** |
| `deny_tool` | `before_tool` | Deny tool execution |
| `abort_turn` | All interceptor types | Abort current turn, return error |
| `hard_abort` | All interceptor types | Force stop entire agent loop |
---
## Complete Flow Example
```json
{"jsonrpc":"2.0","id":1,"method":"hook.hello","params":{"name":"my_hook","version":1,"modes":["tool","approve"]}}
{"jsonrpc":"2.0","id":1,"result":{"ok":true,"name":"my_hook"}}
{"jsonrpc":"2.0","id":2,"method":"hook.before_llm","params":{"model":"claude-sonnet","messages":[{"role":"user","content":"hello"}],"tools":[]}}
{"jsonrpc":"2.0","id":2,"result":{"action":"continue"}}
{"jsonrpc":"2.0","id":3,"method":"hook.before_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
{"jsonrpc":"2.0","id":3,"result":{"action":"continue"}}
{"jsonrpc":"2.0","id":4,"method":"hook.approve_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
{"jsonrpc":"2.0","id":4,"result":{"approved":true}}
{"jsonrpc":"2.0","id":5,"method":"hook.after_tool","params":{"tool":"bash","arguments":{"command":"ls"},"result":{"for_llm":"file1.txt\nfile2.txt"},"duration":5000000}}
{"jsonrpc":"2.0","id":5,"result":{"action":"continue"}}
{"jsonrpc":"2.0","id":6,"method":"hook.after_llm","params":{"model":"claude-sonnet","response":{"role":"assistant","content":"Files listed"}}}
{"jsonrpc":"2.0","id":6,"result":{"action":"continue"}}
```
---
## Plugin Tool Injection via `before_llm` and `before_tool`
Standard flow for plugin tool injection:
1. In `before_llm`, inject tool definition to let LLM know the tool is available
2. In `before_tool`, use `respond` action to return tool execution result directly
### `before_llm` Inject Tool Definition
```python
def handle_before_llm(params: dict) -> dict:
tools = params.get("tools", [])
# Add plugin tool definition
tools.append({
"type": "function",
"function": {
"name": "my_plugin_tool",
"description": "Plugin provided tool",
"parameters": {
"type": "object",
"properties": {
"input": {"type": "string", "description": "Input content"}
},
"required": ["input"]
}
}
})
return {
"action": "modify",
"request": {
"model": params["model"],
"messages": params["messages"],
"tools": tools,
"options": params.get("options", {})
}
}
```
### `before_tool` Return Execution Result
```python
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
if tool == "my_plugin_tool":
# Implement tool logic here
args = params.get("arguments", {})
input_text = args.get("input", "")
# Return result directly, no need to register in ToolRegistry
return {
"action": "respond",
"result": {
"for_llm": f"Plugin tool executed successfully, input: {input_text}",
"silent": False,
"is_error": False
}
}
return {"action": "continue"}
```
This way, external hooks can fully implement plugin tools without registering any tool implementation inside PicoClaw.
+568
View File
@@ -0,0 +1,568 @@
# Hook JSON-RPC 协议详解
所有 hook 使用 `JSON-RPC 2.0` 格式,每行一个 JSON 消息,通过 stdio 传输。
---
## 基础协议结构
### 请求(PicoClaw → Hook
```json
{"jsonrpc":"2.0","id":1,"method":"hook.xxx","params":{...}}
```
### 响应(Hook → PicoClaw
成功:
```json
{"jsonrpc":"2.0","id":1,"result":{...}}
```
错误:
```json
{"jsonrpc":"2.0","id":1,"error":{"code":-32000,"message":"错误信息"}}
```
---
## 1. `hook.hello`(握手)
启动时必须完成握手,否则 hook 进程会被终止。
### 请求
```json
{
"jsonrpc": "2.0",
"id": 1,
"method": "hook.hello",
"params": {
"name": "py_review_gate",
"version": 1,
"modes": ["observe", "tool", "approve"]
}
}
```
| 字段 | 说明 |
|------|------|
| `name` | hook 名称(来自配置) |
| `version` | 协议版本,当前为 `1` |
| `modes` | hook 支持的能力模式 |
### 响应
```json
{
"jsonrpc": "2.0",
"id": 1,
"result": {
"ok": true,
"name": "python-review-gate"
}
}
```
---
## 2. `hook.before_llm`
在发送请求给 LLM 之前触发。可用于注入工具。
### 请求
```json
{
"jsonrpc": "2.0",
"id": 2,
"method": "hook.before_llm",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"ParentTurnID": "",
"SessionKey": "session-1",
"Iteration": 0,
"TracePath": "runTurn",
"Source": "turn.llm.request"
},
"model": "claude-sonnet",
"messages": [
{"role": "user", "content": "hello"}
],
"tools": [
{
"type": "function",
"function": {
"name": "echo",
"description": "echo text",
"parameters": {"type": "object"}
}
}
],
"options": {
"temperature": 0.7
},
"channel": "cli",
"chat_id": "chat-1",
"graceful_terminal": false
}
}
```
| 字段 | 说明 |
|------|------|
| `meta` | 事件元数据,用于追踪 |
| `model` | 请求的模型名称 |
| `messages` | 对话历史 |
| `tools` | 可用工具定义列表 |
| `options` | LLM 参数(temperature、max_tokens 等) |
| `channel` | 请求来源通道 |
| `chat_id` | 会话 ID |
### 响应(注入工具示例)
```json
{
"jsonrpc": "2.0",
"id": 2,
"result": {
"action": "modify",
"request": {
"model": "claude-sonnet",
"messages": [{"role": "user", "content": "hello"}],
"tools": [
{
"type": "function",
"function": {
"name": "echo",
"description": "echo",
"parameters": {}
}
},
{
"type": "function",
"function": {
"name": "my_plugin_tool",
"description": "插件注入的工具",
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string"}
}
}
}
}
]
}
}
}
```
| 字段 | 说明 |
|------|------|
| `action` | 决策动作(见下表) |
| `request` | 修改后的请求对象 |
---
## 3. `hook.after_llm`
在收到 LLM 响应后触发。可修改响应内容。
### 请求
```json
{
"jsonrpc": "2.0",
"id": 3,
"method": "hook.after_llm",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"model": "claude-sonnet",
"response": {
"role": "assistant",
"content": "Hi!",
"tool_calls": [
{
"id": "tc-1",
"type": "function",
"function": {
"name": "echo",
"arguments": "{\"text\":\"hi\"}"
}
}
]
},
"channel": "cli",
"chat_id": "chat-1"
}
}
```
### 响应
```json
{
"jsonrpc": "2.0",
"id": 3,
"result": {
"action": "continue"
}
}
```
---
## 4. `hook.before_tool`
在执行工具前触发。可修改工具名称和参数,或拒绝执行,或直接返回结果。
### 请求
```json
{
"jsonrpc": "2.0",
"id": 4,
"method": "hook.before_tool",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"tool": "echo_text",
"arguments": {
"text": "hello"
},
"channel": "cli",
"chat_id": "chat-1"
}
}
```
| 字段 | 说明 |
|------|------|
| `tool` | 工具名称 |
| `arguments` | 工具参数 |
### 响应(改写参数)
```json
{
"jsonrpc": "2.0",
"id": 4,
"result": {
"action": "modify",
"call": {
"tool": "echo_text",
"arguments": {
"text": "modified hello"
}
}
}
}
```
### 响应(拒绝执行)
```json
{
"jsonrpc": "2.0",
"id": 4,
"result": {
"action": "deny_tool",
"reason": "参数不合法"
}
}
```
### 响应(直接返回结果 - respond)
```json
{
"jsonrpc": "2.0",
"id": 4,
"result": {
"action": "respond",
"call": {
"tool": "my_plugin_tool",
"arguments": {
"query": "hello"
}
},
"result": {
"for_llm": "Plugin tool executed successfully",
"for_user": "",
"silent": false,
"is_error": false
}
}
}
```
`respond` action 允许 hook 直接返回工具结果,跳过实际工具执行。适用于:
1. **插件工具注入**:外部 hook 可实现工具,无需在 ToolRegistry 注册
2. **工具结果缓存**:对重复调用返回缓存结果
3. **工具模拟**:测试时返回模拟结果
| 字段 | 说明 |
|------|------|
| `action` | 必须为 `respond` |
| `call` | 修改后的调用信息(可选) |
| `result` | 直接返回的工具结果 |
---
## 5. `hook.after_tool`
在工具执行完成后触发。可修改返回给 LLM 的结果。
### 请求
```json
{
"jsonrpc": "2.0",
"id": 5,
"method": "hook.after_tool",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"tool": "echo_text",
"arguments": {
"text": "hello"
},
"result": {
"for_llm": "echoed: hello",
"for_user": "",
"silent": false,
"is_error": false,
"async": false,
"media": [],
"artifact_tags": [],
"response_handled": false
},
"duration": 15000000,
"channel": "cli",
"chat_id": "chat-1"
}
}
```
| 字段 | 说明 |
|------|------|
| `result.for_llm` | 返回给 LLM 的内容 |
| `result.for_user` | 发送给用户的内容 |
| `result.silent` | 是否静默(不发送给用户) |
| `result.is_error` | 是否为错误 |
| `result.async` | 是否异步执行 |
| `result.media` | 媒体引用列表 |
| `result.artifact_tags` | 本地产物路径标签 |
| `result.response_handled` | 是否已处理响应 |
| `duration` | 执行耗时(纳秒) |
### 响应
```json
{
"jsonrpc": "2.0",
"id": 5,
"result": {
"action": "continue"
}
}
```
---
## 6. `hook.approve_tool`
审批型 hook,用于决定是否允许执行敏感工具。
### 请求
```json
{
"jsonrpc": "2.0",
"id": 6,
"method": "hook.approve_tool",
"params": {
"meta": {
"AgentID": "agent-1",
"TurnID": "turn-1",
"SessionKey": "session-1"
},
"tool": "bash",
"arguments": {
"command": "rm -rf /"
},
"channel": "cli",
"chat_id": "chat-1"
}
}
```
### 响应(批准)
```json
{
"jsonrpc": "2.0",
"id": 6,
"result": {
"approved": true
}
}
```
### 响应(拒绝)
```json
{
"jsonrpc": "2.0",
"id": 6,
"result": {
"approved": false,
"reason": "危险命令,禁止执行"
}
}
```
---
## 7. `hook.event`notification
观察型事件,仅广播,无需响应。`id``0` 或不存在。
```json
{
"jsonrpc": "2.0",
"method": "hook.event",
"params": {
"Kind": "tool_exec_start",
"Meta": {
"AgentID": "agent-1",
"TurnID": "turn-1"
},
"Payload": {
"Tool": "echo_text",
"Arguments": {"text": "hello"}
}
}
}
```
常见 `Kind` 值:
- `turn_start` / `turn_end`
- `llm_request` / `llm_response`
- `tool_exec_start` / `tool_exec_end` / `tool_exec_skipped`
- `steering_injected`
- `interrupt_received`
- `error`
---
## action 可选值
| action | 适用 hook | 效果 |
|--------|----------|------|
| `continue` | 所有拦截型 | 放行,不做修改 |
| `modify` | `before_llm`, `before_tool`, `after_llm`, `after_tool` | 改写请求/响应后放行 |
| `respond` | `before_tool` | 直接返回工具结果,跳过实际执行 |
| `deny_tool` | `before_tool` | 拒绝执行该工具 |
| `abort_turn` | 所有拦截型 | 中止当前 turn,返回错误 |
| `hard_abort` | 所有拦截型 | 强制终止整个 agent loop |
---
## 完整流程示例
```json
{"jsonrpc":"2.0","id":1,"method":"hook.hello","params":{"name":"my_hook","version":1,"modes":["tool","approve"]}}
{"jsonrpc":"2.0","id":1,"result":{"ok":true,"name":"my_hook"}}
{"jsonrpc":"2.0","id":2,"method":"hook.before_llm","params":{"model":"claude-sonnet","messages":[{"role":"user","content":"hello"}],"tools":[]}}
{"jsonrpc":"2.0","id":2,"result":{"action":"continue"}}
{"jsonrpc":"2.0","id":3,"method":"hook.before_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
{"jsonrpc":"2.0","id":3,"result":{"action":"continue"}}
{"jsonrpc":"2.0","id":4,"method":"hook.approve_tool","params":{"tool":"bash","arguments":{"command":"ls"}}}
{"jsonrpc":"2.0","id":4,"result":{"approved":true}}
{"jsonrpc":"2.0","id":5,"method":"hook.after_tool","params":{"tool":"bash","arguments":{"command":"ls"},"result":{"for_llm":"file1.txt\nfile2.txt"},"duration":5000000}}
{"jsonrpc":"2.0","id":5,"result":{"action":"continue"}}
{"jsonrpc":"2.0","id":6,"method":"hook.after_llm","params":{"model":"claude-sonnet","response":{"role":"assistant","content":"已列出文件"}}}
{"jsonrpc":"2.0","id":6,"result":{"action":"continue"}}
```
---
## 通过 `before_llm` 和 `before_tool` 实现插件工具注入
插件工具注入的标准流程:
1.`before_llm` 中注入工具定义,让 LLM 知道有这个工具可用
2.`before_tool` 中使用 `respond` action 直接返回工具执行结果
### `before_llm` 注入工具定义
```python
def handle_before_llm(params: dict) -> dict:
tools = params.get("tools", [])
# 添加插件工具定义
tools.append({
"type": "function",
"function": {
"name": "my_plugin_tool",
"description": "插件提供的工具",
"parameters": {
"type": "object",
"properties": {
"input": {"type": "string", "description": "输入内容"}
},
"required": ["input"]
}
}
})
return {
"action": "modify",
"request": {
"model": params["model"],
"messages": params["messages"],
"tools": tools,
"options": params.get("options", {})
}
}
```
### `before_tool` 返回执行结果
```python
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
if tool == "my_plugin_tool":
# 在这里实现工具逻辑
args = params.get("arguments", {})
input_text = args.get("input", "")
# 直接返回结果,无需在 ToolRegistry 注册
return {
"action": "respond",
"result": {
"for_llm": f"插件工具执行成功,输入: {input_text}",
"silent": False,
"is_error": False
}
}
return {"action": "continue"}
```
通过这种方式,外部 hook 可以完全实现插件工具,无需在 PicoClaw 内部注册任何工具实现。
+587
View File
@@ -0,0 +1,587 @@
# Plugin Tool Injection Example
This document demonstrates how to use PicoClaw's hook system to implement external plugin tool injection, allowing LLM to call tools implemented by external hook processes.
---
## Core Principle
Through the hook system's `respond` action, external hooks can:
1. Inject tool **definitions** in `before_llm`, letting LLM know the tool is available
2. Return tool **execution results** directly in `before_tool` using `respond` action, skipping ToolRegistry
This way, external hooks can fully implement plugin tools without registering any tools inside PicoClaw.
---
## Complete Example: Weather Query Plugin
Below is a complete Python hook example implementing a weather query plugin tool.
### 1. Hook Script Implementation
Save as `/tmp/weather_plugin.py`:
```python
#!/usr/bin/env python3
"""Weather query plugin hook example"""
from __future__ import annotations
import json
import sys
import signal
from typing import Any
# Simulated weather data
WEATHER_DATA = {
"Beijing": {"temp": 15, "weather": "Sunny", "humidity": 45},
"Shanghai": {"temp": 18, "weather": "Cloudy", "humidity": 60},
"Guangzhou": {"temp": 25, "weather": "Sunny", "humidity": 70},
"Shenzhen": {"temp": 26, "weather": "Cloudy", "humidity": 75},
}
def get_weather(city: str) -> dict:
"""Get weather data (simulated)"""
data = WEATHER_DATA.get(city)
if data:
return {
"for_llm": f"{city} weather: {data['weather']}, temperature {data['temp']}°C, humidity {data['humidity']}%",
"for_user": "",
"silent": False,
"is_error": False,
}
return {
"for_llm": f"Weather data not found for city {city}",
"for_user": "",
"silent": False,
"is_error": True,
}
def handle_hello(params: dict) -> dict:
return {"ok": True, "name": "weather-plugin"}
def handle_before_llm(params: dict) -> dict:
"""Inject weather query tool definition"""
tools = params.get("tools", [])
# Add weather query tool
tools.append({
"type": "function",
"function": {
"name": "get_weather",
"description": "Query weather information for a specified city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name, e.g.: Beijing, Shanghai, Guangzhou"
}
},
"required": ["city"]
}
}
})
return {
"action": "modify",
"request": {
"model": params.get("model"),
"messages": params.get("messages", []),
"tools": tools,
"options": params.get("options", {}),
}
}
def handle_before_tool(params: dict) -> dict:
"""Handle tool call, return result directly"""
tool = params.get("tool", "")
args = params.get("arguments", {})
if tool == "get_weather":
city = args.get("city", "")
result = get_weather(city)
# Use respond action to return result directly, skip ToolRegistry
return {
"action": "respond",
"result": result,
}
# Other tools continue normal flow
return {"action": "continue"}
def handle_request(method: str, params: dict) -> dict:
if method == "hook.hello":
return handle_hello(params)
if method == "hook.before_llm":
return handle_before_llm(params)
if method == "hook.before_tool":
return handle_before_tool(params)
if method == "hook.after_llm":
return {"action": "continue"}
if method == "hook.after_tool":
return {"action": "continue"}
if method == "hook.approve_tool":
return {"approved": True}
raise KeyError(f"method not found: {method}")
def send_response(message_id: int, result: Any | None = None, error: str | None = None) -> None:
payload: dict[str, Any] = {
"jsonrpc": "2.0",
"id": message_id,
}
if error is not None:
payload["error"] = {"code": -32000, "message": error}
else:
payload["result"] = result if result is not None else {}
sys.stdout.write(json.dumps(payload, ensure_ascii=True) + "\n")
sys.stdout.flush()
def main() -> int:
for raw_line in sys.stdin:
line = raw_line.strip()
if not line:
continue
try:
message = json.loads(line)
except json.JSONDecodeError:
continue
method = message.get("method")
message_id = message.get("id", 0)
params = message.get("params") or {}
if not message_id:
continue
try:
result = handle_request(str(method or ""), params)
send_response(int(message_id), result=result)
except KeyError as exc:
send_response(int(message_id), error=str(exc))
except Exception as exc:
send_response(int(message_id), error=f"unexpected error: {exc}")
return 0
if __name__ == "__main__":
signal.signal(signal.SIGINT, lambda *_: raise SystemExit(0))
signal.signal(signal.SIGTERM, lambda *_: raise SystemExit(0))
raise SystemExit(main())
```
### 2. Configure PicoClaw
Add hook configuration in the config file:
```json
{
"hooks": {
"enabled": true,
"processes": {
"weather_plugin": {
"enabled": true,
"priority": 100,
"transport": "stdio",
"command": ["python3", "/tmp/weather_plugin.py"],
"intercept": ["before_llm", "before_tool"]
}
}
}
}
```
### 3. Test Results
When user asks "What's the weather in Beijing today?":
1. PicoClaw sends `hook.before_llm`, hook injects `get_weather` tool definition
2. LLM sees tool definition, decides to call `get_weather(city="Beijing")`
3. PicoClaw sends `hook.before_tool`, hook uses `respond` action to return weather data
4. LLM receives result, replies to user "Beijing is sunny today, temperature 15°C"
---
## Flow Diagram
```
User: "What's the weather in Beijing today?"
PicoClaw
hook.before_llm
↓ (inject get_weather tool definition)
LLM request
LLM decides to call get_weather(city="Beijing")
hook.before_tool
↓ (respond action returns weather data)
Return result directly to LLM
↓ (skip ToolRegistry)
LLM replies: "Beijing is sunny today, temperature 15°C"
```
---
## Key Points
### `before_llm` Inject Tool Definition
Tool definition follows OpenAI function calling format:
```json
{
"type": "function",
"function": {
"name": "tool_name",
"description": "tool description",
"parameters": {
"type": "object",
"properties": {
"param_name": {
"type": "string",
"description": "parameter description"
}
},
"required": ["list of required parameters"]
}
}
}
```
### `before_tool` Use respond Action
`respond` action response format:
```json
{
"action": "respond",
"result": {
"for_llm": "Content returned to LLM",
"for_user": "Optional, content sent to user",
"silent": false,
"is_error": false,
"media": ["Optional, media reference list"],
"response_handled": false
}
}
```
| Field | Description |
|-------|-------------|
| `for_llm` | Required, LLM will see this content |
| `for_user` | Optional, sent directly to user |
| `silent` | When true, not sent to user |
| `is_error` | When true, indicates execution failure |
| `media` | Optional, media file references (images, files, etc.) |
| `response_handled` | When true, indicates user request is handled, turn will end |
---
## Media File Handling
The `respond` action supports returning media files (images, files, etc.). There are two processing modes:
### 1. Automatic Delivery (`response_handled=true`)
When `response_handled=true`, media files are automatically sent to the user and the turn ends:
```json
{
"action": "respond",
"result": {
"for_llm": "Image sent to user",
"for_user": "",
"media": ["media://abc123"],
"response_handled": true
}
}
```
Use cases:
- Image generation plugin directly returning results
- File download plugin sending files to user
### 2. LLM Visible (`response_handled=false`)
When `response_handled=false`, media references are passed to the LLM, which can see the content in the next request:
```json
{
"action": "respond",
"result": {
"for_llm": "Image loaded, path: /tmp/image.png [file:/tmp/image.png]",
"media": ["media://abc123"]
}
}
```
After seeing the content, the LLM can decide:
- Use `send_file` tool to send to user
- Analyze image content and reply to user
- Other processing approaches
### Media Reference Format
Media references use the `media://` protocol:
```
media://<store-id>
```
These references are managed by PicoClaw's MediaStore and can be:
- Sent to user via channel
- Converted to base64 in LLM vision requests
### Alternative: Use Existing Tools
If the plugin generates files, you can return the file path and let the LLM call `send_file` or similar tools:
```json
{
"action": "respond",
"result": {
"for_llm": "Image generated, saved at /tmp/generated_image.png. Use send_file tool to send to user.",
"for_user": "",
"silent": false
}
}
```
This approach:
- More decoupled, LLM decides when to send
- Leverages existing tool mechanisms
- Supports batch sending, delayed sending, etc.
---
## Multi-Tool Injection Example
Multiple tools can be injected simultaneously:
```python
def handle_before_llm(params: dict) -> dict:
tools = params.get("tools", [])
# Tool 1: Weather query
tools.append({
"type": "function",
"function": {
"name": "get_weather",
"description": "Query city weather",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "City name"}
},
"required": ["city"]
}
}
})
# Tool 2: Calculator
tools.append({
"type": "function",
"function": {
"name": "calculate",
"description": "Perform mathematical calculations",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "Mathematical expression"}
},
"required": ["expression"]
}
}
})
return {
"action": "modify",
"request": {
"model": params.get("model"),
"messages": params.get("messages", []),
"tools": tools,
"options": params.get("options", {}),
}
}
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
args = params.get("arguments", {})
if tool == "get_weather":
return {
"action": "respond",
"result": get_weather(args.get("city", "")),
}
if tool == "calculate":
# Simple calculation example
try:
expr = args.get("expression", "")
result = eval(expr) # Note: needs security handling in actual use
return {
"action": "respond",
"result": {
"for_llm": f"Calculation result: {result}",
"silent": False,
"is_error": False,
},
}
except Exception as e:
return {
"action": "respond",
"result": {
"for_llm": f"Calculation error: {e}",
"silent": False,
"is_error": True,
},
}
return {"action": "continue"}
```
---
## Coexistence with Built-in Tools
Injected plugin tools coexist with PicoClaw built-in tools:
- Built-in tools (like `bash`, `read_file`) execute normally through ToolRegistry
- Plugin tools return results through hook's `respond` action
- `handle_before_tool` only handles plugin tools, other tools return `continue`
---
## Go In-Process Hook Example
If you need to implement plugin tool injection in Go code:
```go
package myhooks
import (
"context"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/tools"
)
type WeatherPluginHook struct{}
func (h *WeatherPluginHook) BeforeLLM(
ctx context.Context,
req *agent.LLMHookRequest,
) (*agent.LLMHookRequest, agent.HookDecision, error) {
// Inject tool definition
req.Tools = append(req.Tools, agent.ToolDefinition{
Type: "function",
Function: agent.FunctionDefinition{
Name: "get_weather",
Description: "Query city weather",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"city": map[string]any{
"type": "string",
"description": "City name",
},
},
"required": []string{"city"},
},
},
})
return req, agent.HookDecision{Action: agent.HookActionContinue}, nil
}
func (h *WeatherPluginHook) BeforeTool(
ctx context.Context,
call *agent.ToolCallHookRequest,
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
if call.Tool == "get_weather" {
city := call.Arguments["city"].(string)
// Set HookResult, use respond action
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: getWeatherData(city),
Silent: false,
IsError: false,
}
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
}
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
}
func getWeatherData(city string) string {
// Implement weather query logic
return fmt.Sprintf("%s weather: Sunny, temperature 20°C", city)
}
```
---
## Summary
Through the hook system's `respond` action, external processes can:
1. **Inject tool definitions**: Let LLM know new tools are available
2. **Provide tool implementation**: Return execution results directly, no need to register in ToolRegistry
3. **Coexist with built-in tools**: Does not affect normal operation of PicoClaw's original tools
This provides a flexible and elegant solution for plugin development.
---
## Security Boundaries
### Bypassing Approval Checks
**Important**: The `respond` action bypasses `ApproveTool` approval checks.
This means:
- A `before_tool` hook can return `respond` for **any tool name**, including sensitive tools (like `bash`)
- The tool won't go through the approval process, directly returning the hook-provided result
- This is designed for plugin tools but introduces security risks
### Security Recommendations
1. **Review hook configuration**: Ensure only trusted hook processes are enabled
2. **Limit hook scope**: Add your own security checks in hook implementation
3. **Use `deny_tool` for rejection**: Use `deny_tool` action instead of `respond` with error for denying execution
### Example: Hook-Internal Security Check
```python
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
args = params.get("arguments", {})
# Security check: only handle plugin tools
if tool in ["get_weather", "calculate"]:
return {
"action": "respond",
"result": execute_plugin_tool(tool, args),
}
# Other tools continue normal flow (will go through approval)
return {"action": "continue"}
```
This ensures the hook only affects plugin tools, not system tool approval flow.
+587
View File
@@ -0,0 +1,587 @@
# 插件工具注入示例
本文档展示如何利用 PicoClaw 的 hook 系统实现外部插件工具注入,让 LLM 能调用由外部 hook 进程实现的工具。
---
## 核心原理
通过 hook 系统的 `respond` action,外部 hook 可以:
1.`before_llm` 中注入工具**定义**,让 LLM 知道有这个工具可用
2.`before_tool` 中使用 `respond` action 直接返回工具**执行结果**,跳过 ToolRegistry
这样,外部 hook 可以完全实现插件工具,无需在 PicoClaw 内部注册任何工具。
---
## 完整示例:天气查询插件
下面是一个完整的 Python hook 示例,实现一个天气查询插件工具。
### 1. Hook 脚本实现
保存为 `/tmp/weather_plugin.py`
```python
#!/usr/bin/env python3
"""天气查询插件 hook 示例"""
from __future__ import annotations
import json
import sys
import signal
from typing import Any
# 模拟天气数据
WEATHER_DATA = {
"北京": {"temp": 15, "weather": "", "humidity": 45},
"上海": {"temp": 18, "weather": "多云", "humidity": 60},
"广州": {"temp": 25, "weather": "", "humidity": 70},
"深圳": {"temp": 26, "weather": "多云", "humidity": 75},
}
def get_weather(city: str) -> dict:
"""获取天气数据(模拟)"""
data = WEATHER_DATA.get(city)
if data:
return {
"for_llm": f"{city}天气:{data['weather']},温度{data['temp']}°C,湿度{data['humidity']}%",
"for_user": "",
"silent": False,
"is_error": False,
}
return {
"for_llm": f"未找到城市 {city} 的天气数据",
"for_user": "",
"silent": False,
"is_error": True,
}
def handle_hello(params: dict) -> dict:
return {"ok": True, "name": "weather-plugin"}
def handle_before_llm(params: dict) -> dict:
"""注入天气查询工具定义"""
tools = params.get("tools", [])
# 添加天气查询工具
tools.append({
"type": "function",
"function": {
"name": "get_weather",
"description": "查询指定城市的天气信息",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "城市名称,如:北京、上海、广州"
}
},
"required": ["city"]
}
}
})
return {
"action": "modify",
"request": {
"model": params.get("model"),
"messages": params.get("messages", []),
"tools": tools,
"options": params.get("options", {}),
}
}
def handle_before_tool(params: dict) -> dict:
"""处理工具调用,直接返回结果"""
tool = params.get("tool", "")
args = params.get("arguments", {})
if tool == "get_weather":
city = args.get("city", "")
result = get_weather(city)
# 使用 respond action 直接返回结果,跳过 ToolRegistry
return {
"action": "respond",
"result": result,
}
# 其他工具继续正常流程
return {"action": "continue"}
def handle_request(method: str, params: dict) -> dict:
if method == "hook.hello":
return handle_hello(params)
if method == "hook.before_llm":
return handle_before_llm(params)
if method == "hook.before_tool":
return handle_before_tool(params)
if method == "hook.after_llm":
return {"action": "continue"}
if method == "hook.after_tool":
return {"action": "continue"}
if method == "hook.approve_tool":
return {"approved": True}
raise KeyError(f"method not found: {method}")
def send_response(message_id: int, result: Any | None = None, error: str | None = None) -> None:
payload: dict[str, Any] = {
"jsonrpc": "2.0",
"id": message_id,
}
if error is not None:
payload["error"] = {"code": -32000, "message": error}
else:
payload["result"] = result if result is not None else {}
sys.stdout.write(json.dumps(payload, ensure_ascii=True) + "\n")
sys.stdout.flush()
def main() -> int:
for raw_line in sys.stdin:
line = raw_line.strip()
if not line:
continue
try:
message = json.loads(line)
except json.JSONDecodeError:
continue
method = message.get("method")
message_id = message.get("id", 0)
params = message.get("params") or {}
if not message_id:
continue
try:
result = handle_request(str(method or ""), params)
send_response(int(message_id), result=result)
except KeyError as exc:
send_response(int(message_id), error=str(exc))
except Exception as exc:
send_response(int(message_id), error=f"unexpected error: {exc}")
return 0
if __name__ == "__main__":
signal.signal(signal.SIGINT, lambda *_: raise SystemExit(0))
signal.signal(signal.SIGTERM, lambda *_: raise SystemExit(0))
raise SystemExit(main())
```
### 2. 配置 PicoClaw
在配置文件中添加 hook 配置:
```json
{
"hooks": {
"enabled": true,
"processes": {
"weather_plugin": {
"enabled": true,
"priority": 100,
"transport": "stdio",
"command": ["python3", "/tmp/weather_plugin.py"],
"intercept": ["before_llm", "before_tool"]
}
}
}
}
```
### 3. 测试效果
当用户问"北京今天天气怎么样?"时:
1. PicoClaw 发送 `hook.before_llm`hook 注入 `get_weather` 工具定义
2. LLM 看到工具定义,决定调用 `get_weather(city="北京")`
3. PicoClaw 发送 `hook.before_tool`hook 使用 `respond` action 返回天气数据
4. LLM 收到结果,回复用户"北京今天晴天,温度15°C"
---
## 流程图解
```
用户: "北京今天天气怎么样?"
PicoClaw
hook.before_llm
↓ (注入 get_weather 工具定义)
LLM 请求
LLM 决定调用 get_weather(city="北京")
hook.before_tool
↓ (respond action 返回天气数据)
直接返回结果给 LLM
↓ (跳过 ToolRegistry)
LLM 回复: "北京今天晴天,温度15°C"
```
---
## 关键点说明
### `before_llm` 注入工具定义
工具定义遵循 OpenAI function calling 格式:
```json
{
"type": "function",
"function": {
"name": "工具名称",
"description": "工具描述",
"parameters": {
"type": "object",
"properties": {
"参数名": {
"type": "string",
"description": "参数描述"
}
},
"required": ["必需参数列表"]
}
}
}
```
### `before_tool` 使用 respond action
`respond` action 的响应格式:
```json
{
"action": "respond",
"result": {
"for_llm": "返回给 LLM 的内容",
"for_user": "可选,发送给用户的内容",
"silent": false,
"is_error": false,
"media": ["可选,媒体引用列表"],
"response_handled": false
}
}
```
| 字段 | 说明 |
|------|------|
| `for_llm` | 必须,LLM 会看到这个内容 |
| `for_user` | 可选,直接发送给用户 |
| `silent` | 为 true 时不发送给用户 |
| `is_error` | 为 true 时表示执行失败 |
| `media` | 可选,媒体文件引用列表(如图片、文件) |
| `response_handled` | 为 true 时表示已处理用户请求,轮次将结束 |
---
## 媒体文件处理
`respond` action 支持返回媒体文件(图片、文件等)。有两种处理方式:
### 1. 自动发送(`response_handled=true`
`response_handled=true` 时,媒体文件会自动发送给用户,轮次结束:
```json
{
"action": "respond",
"result": {
"for_llm": "图片已发送给用户",
"for_user": "",
"media": ["media://abc123"],
"response_handled": true
}
}
```
适用场景:
- 图像生成插件直接返回结果
- 文件下载插件发送文件给用户
### 2. LLM 可见(`response_handled=false`
`response_handled=false` 时,媒体引用会传递给 LLM,LLM 可以在下一轮请求中看到内容:
```json
{
"action": "respond",
"result": {
"for_llm": "图片已加载,路径:/tmp/image.png [file:/tmp/image.png]",
"media": ["media://abc123"]
}
}
```
LLM 看到内容后,可以自主决定:
- 使用 `send_file` 工具发送给用户
- 分析图片内容并回复用户
- 其他处理方式
### 媒体引用格式
媒体引用使用 `media://` 协议:
```
media://<store-id>
```
这些引用由 PicoClaw 的 MediaStore 管理,可以:
- 通过 channel 发送给用户
- 在 LLM vision 请求中转换为 base64
### 替代方案:使用现有工具
如果插件生成文件,可以返回文件路径让 LLM 调用 `send_file` 等工具:
```json
{
"action": "respond",
"result": {
"for_llm": "图片已生成,保存在 /tmp/generated_image.png。使用 send_file 工具发送给用户。",
"for_user": "",
"silent": false
}
}
```
这种方式:
- 更解耦,LLM 自主决策发送时机
- 利用现有工具机制
- 支持批量发送、延迟发送等场景
---
## 多工具注入示例
可以同时注入多个工具:
```python
def handle_before_llm(params: dict) -> dict:
tools = params.get("tools", [])
# 工具1:天气查询
tools.append({
"type": "function",
"function": {
"name": "get_weather",
"description": "查询城市天气",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名称"}
},
"required": ["city"]
}
}
})
# 工具2:计算器
tools.append({
"type": "function",
"function": {
"name": "calculate",
"description": "执行数学计算",
"parameters": {
"type": "object",
"properties": {
"expression": {"type": "string", "description": "数学表达式"}
},
"required": ["expression"]
}
}
})
return {
"action": "modify",
"request": {
"model": params.get("model"),
"messages": params.get("messages", []),
"tools": tools,
"options": params.get("options", {}),
}
}
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
args = params.get("arguments", {})
if tool == "get_weather":
return {
"action": "respond",
"result": get_weather(args.get("city", "")),
}
if tool == "calculate":
# 简单计算示例
try:
expr = args.get("expression", "")
result = eval(expr) # 注意:实际使用时需要安全处理
return {
"action": "respond",
"result": {
"for_llm": f"计算结果: {result}",
"silent": False,
"is_error": False,
},
}
except Exception as e:
return {
"action": "respond",
"result": {
"for_llm": f"计算错误: {e}",
"silent": False,
"is_error": True,
},
}
return {"action": "continue"}
```
---
## 与内置工具共存
注入的插件工具与 PicoClaw 内置工具共存:
- 内置工具(如 `bash``read_file`)正常通过 ToolRegistry 执行
- 插件工具通过 hook 的 `respond` action 返回结果
- `handle_before_tool` 中只处理插件工具,其他工具返回 `continue`
---
## Go 进程内 Hook 示例
如果需要在 Go 代码中实现插件工具注入:
```go
package myhooks
import (
"context"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/tools"
)
type WeatherPluginHook struct{}
func (h *WeatherPluginHook) BeforeLLM(
ctx context.Context,
req *agent.LLMHookRequest,
) (*agent.LLMHookRequest, agent.HookDecision, error) {
// 注入工具定义
req.Tools = append(req.Tools, agent.ToolDefinition{
Type: "function",
Function: agent.FunctionDefinition{
Name: "get_weather",
Description: "查询城市天气",
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"city": map[string]any{
"type": "string",
"description": "城市名称",
},
},
"required": []string{"city"},
},
},
})
return req, agent.HookDecision{Action: agent.HookActionContinue}, nil
}
func (h *WeatherPluginHook) BeforeTool(
ctx context.Context,
call *agent.ToolCallHookRequest,
) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
if call.Tool == "get_weather" {
city := call.Arguments["city"].(string)
// 设置 HookResult,使用 respond action
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: getWeatherData(city),
Silent: false,
IsError: false,
}
return next, agent.HookDecision{Action: agent.HookActionRespond}, nil
}
return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
}
func getWeatherData(city string) string {
// 实现天气查询逻辑
return fmt.Sprintf("%s天气:晴,温度20°C", city)
}
```
---
## 总结
通过 hook 系统的 `respond` action,外部进程可以:
1. **注入工具定义**:让 LLM 知道有新工具可用
2. **提供工具实现**:直接返回执行结果,无需注册到 ToolRegistry
3. **与内置工具共存**:不影响 PicoClaw 原有工具的正常运行
这为插件开发提供了灵活、优雅的解决方案。
---
## 安全边界说明
### 绕过审批检查
**重要**`respond` action 会绕过 `ApproveTool` 审批检查。
这意味着:
- `before_tool` hook 可以为**任何工具名称**返回 `respond`,包括敏感工具(如 `bash`
- 工具不会经过审批流程,直接返回 hook 提供的结果
- 这是为了支持插件工具而设计,但也带来了安全风险
### 安全建议
1. **审查 hook 配置**:确保只有可信的 hook 进程被启用
2. **限制 hook 权限**:在 hook 实现中添加自己的安全检查
3. **优先使用 `deny_tool`**:对于拒绝执行,使用 `deny_tool` action 而非 `respond` 返回错误
### 示例:hook 内置安全检查
```python
def handle_before_tool(params: dict) -> dict:
tool = params.get("tool", "")
args = params.get("arguments", {})
# 安全检查:只处理插件工具
if tool in ["get_weather", "calculate"]:
return {
"action": "respond",
"result": execute_plugin_tool(tool, args),
}
# 其他工具继续正常流程(会经过审批)
return {"action": "continue"}
```
这样可以确保 hook 只影响插件工具,不影响系统工具的审批流程。
+1
View File
@@ -122,6 +122,7 @@ This design also enables **multi-agent support** with flexible provider selectio
| `max_tokens_field` | string | No | Override the max tokens field name in request body (e.g., `max_completion_tokens` for o1 models) |
| `thinking_level` | string | No | Extended thinking level: `off`, `low`, `medium`, `high`, `xhigh`, or `adaptive` |
| `extra_body` | object | No | Additional fields to inject into every request body |
| `custom_headers` | object | No | Additional HTTP headers to inject into every request (e.g., `{"X-Source":"coding-plan"}`). If a key matches a built-in header, the custom value overrides the built-in one (e.g., `Authorization`, `User-Agent`, `Content-Type`, `Accept`). |
| `rpm` | int | No | Per-minute request rate limit |
| `fallbacks` | string[] | No | Fallback model names for automatic failover |
| `enabled` | bool | No | Whether this model entry is active (default: `true`) |
+3
View File
@@ -528,6 +528,9 @@ For example:
- `PICOCLAW_TOOLS_EXEC_ENABLE_DENY_PATTERNS=false`
- `PICOCLAW_TOOLS_CRON_EXEC_TIMEOUT_MINUTES=10`
- `PICOCLAW_TOOLS_MCP_ENABLED=true`
- `PICOCLAW_TOOLS_MCP_MAX_INLINE_TEXT_CHARS=16384`
Note: Nested map-style config (for example `tools.mcp.servers.<name>.*`) is configured in `config.json` rather than
environment variables.
For MCP tools, `tools.mcp.max_inline_text_chars` controls how much text result is kept inline in model context. The threshold is counted in Unicode characters (Go runes), not bytes. For example, `16384` means up to 16,384 characters inline, which may occupy more than 16 KB for multibyte text such as CJK. Above this threshold, PicoClaw saves the MCP text result as a local artifact in the agent workspace and gives the model a short note plus a structured `[file:...]` artifact path instead of injecting the full payload into context.
+1
View File
@@ -118,6 +118,7 @@
| `max_tokens_field` | string | 否 | 覆盖请求体中 max tokens 的字段名(如 o1 模型使用 `max_completion_tokens` |
| `thinking_level` | string | 否 | 扩展思考级别:`off``low``medium``high``xhigh``adaptive` |
| `extra_body` | object | 否 | 注入到每个请求体中的额外字段 |
| `custom_headers` | object | 否 | 注入到每个请求中的额外 HTTP 请求头(例如 `{"X-Source":"coding-plan"}`)。若键名与内置请求头同名,会覆盖内置值(如 `Authorization``User-Agent``Content-Type``Accept`)。 |
| `rpm` | int | 否 | 每分钟请求速率限制 |
| `fallbacks` | string[] | 否 | 自动故障转移的备用模型名称 |
| `enabled` | bool | 否 | 是否启用此模型条目(默认:`true` |
+4 -3
View File
@@ -1,6 +1,6 @@
module github.com/sipeed/picoclaw
go 1.25.8
go 1.25.9
require (
fyne.io/systray v1.12.0
@@ -8,6 +8,7 @@ require (
github.com/SevereCloud/vksdk/v3 v3.3.1
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.26.0
github.com/atc0005/go-teams-notify/v2 v2.14.0
github.com/atotto/clipboard v0.1.4
github.com/aws/aws-sdk-go-v2 v1.41.5
github.com/aws/aws-sdk-go-v2/config v1.32.12
@@ -29,7 +30,7 @@ require (
github.com/mymmrac/telego v1.7.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.22.0
github.com/pion/rtp v1.8.7
github.com/pion/rtp v1.10.1
github.com/pion/webrtc/v3 v3.3.6
github.com/rivo/tview v0.42.0
github.com/rs/zerolog v1.35.0
@@ -45,7 +46,7 @@ require (
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mautrix v0.26.4
modernc.org/sqlite v1.47.0
modernc.org/sqlite v1.48.0
rsc.io/qr v0.2.0
)
+6 -4
View File
@@ -21,6 +21,8 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.26.0 h1:oUTzFaUpAevfuELAP1sjL6CQJ9HHAfT7CoSYSac11PY=
github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSvaLqPXnHTTBKH6RVZ7q5Q=
github.com/atc0005/go-teams-notify/v2 v2.14.0 h1:7N+xw+COnYANLREaAveQ65rsNQ12nIZJED9nMLyscCo=
github.com/atc0005/go-teams-notify/v2 v2.14.0/go.mod h1:EECsWM2b0Hvoz7O+QdlsvyN2KCUOFQCGj8bUBXv3A3Q=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
@@ -207,8 +209,8 @@ github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6 h1:rh2lKw/P/EqHa7
github.com/petermattis/goid v0.0.0-20260226131333-17d1149c6ac6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/rtp v1.8.7 h1:qslKkG8qxvQ7hqaxkmL7Pl0XcUm+/Er7nMnu6Vq+ZxM=
github.com/pion/rtp v1.8.7/go.mod h1:pBGHaFt/yW7bf1jjWAoUjpSNoDnw98KTMg+jWWvziqU=
github.com/pion/rtp v1.10.1 h1:xP1prZcCTUuhO2c83XtxyOHJteISg6o8iPsE2acaMtA=
github.com/pion/rtp v1.10.1/go.mod h1:rF5nS1GqbR7H/TCpKwylzeq6yDM+MM6k+On5EgeThEM=
github.com/pion/webrtc/v3 v3.3.6 h1:7XAh4RPtlY1Vul6/GmZrv7z+NnxKA6If0KStXBI2ZLE=
github.com/pion/webrtc/v3 v3.3.6/go.mod h1:zyN7th4mZpV27eXybfR/cnUf3J2DRy8zw/mdjD9JTNM=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
@@ -456,8 +458,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/sqlite v1.48.0 h1:ElZyLop3Q2mHYk5IFPPXADejZrlHu7APbpB0sF78bq4=
modernc.org/sqlite v1.48.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
+11 -85
View File
@@ -6,10 +6,8 @@
package agent
import (
"encoding/json"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// parseTurnBoundaries returns the starting index of each Turn in the history.
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
return 0
}
// estimateMessageTokens estimates the token count for a single message,
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
func estimateMessageTokens(msg providers.Message) int {
contentChars := utf8.RuneCountInString(msg.Content)
// SystemParts are structured system blocks used for cache-aware adapters.
// They carry the same content as Content, but in multiple blocks.
// We estimate them as an alternative representation, not additive.
systemPartsChars := 0
if len(msg.SystemParts) > 0 {
for _, part := range msg.SystemParts {
systemPartsChars += utf8.RuneCountInString(part.Text)
}
// Per-part overhead for JSON structure (type, text, cache_control).
const perPartOverhead = 20
systemPartsChars += len(msg.SystemParts) * perPartOverhead
}
// Use the larger of the two representations to stay conservative.
chars := contentChars
if systemPartsChars > chars {
chars = systemPartsChars
}
chars += utf8.RuneCountInString(msg.ReasoningContent)
for _, tc := range msg.ToolCalls {
chars += len(tc.ID) + len(tc.Type)
if tc.Function != nil {
// Count function name + arguments (the wire format for most providers).
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
} else {
// Fallback: some provider formats use top-level Name without Function.
chars += len(tc.Name)
}
}
if msg.ToolCallID != "" {
chars += len(msg.ToolCallID)
}
// Per-message overhead for role label, JSON structure, separators.
const messageOverhead = 12
chars += messageOverhead
tokens := chars * 2 / 5
// Media items (images, files) are serialized by provider adapters into
// multipart or image_url payloads. Add a fixed per-item token estimate
// directly (not through the chars heuristic) since actual cost depends
// on resolution and provider-specific image tokenization.
const mediaTokensPerItem = 256
tokens += len(msg.Media) * mediaTokensPerItem
return tokens
// EstimateMessageTokens estimates the token count for a single message.
// Delegates to the shared tokenizer package for consistency across agent and seahorse.
func EstimateMessageTokens(msg providers.Message) int {
return tokenizer.EstimateMessageTokens(msg)
}
// estimateToolDefsTokens estimates the total token cost of tool definitions
// as they appear in the LLM request. Each tool's name, description, and
// JSON schema parameters contribute to the context window budget.
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
if len(defs) == 0 {
return 0
}
totalChars := 0
for _, d := range defs {
totalChars += len(d.Function.Name) + len(d.Function.Description)
if d.Function.Parameters != nil {
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
totalChars += len(paramJSON)
}
}
// Per-tool overhead: type field, JSON structure, separators.
totalChars += 20
}
return totalChars * 2 / 5
// EstimateToolDefsTokens estimates the total token cost of tool definitions
// as they appear in the LLM request. Delegates to the shared tokenizer package.
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
return tokenizer.EstimateToolDefsTokens(defs)
}
// isOverContextBudget checks whether the assembled messages plus tool definitions
@@ -181,10 +107,10 @@ func isOverContextBudget(
) bool {
msgTokens := 0
for _, m := range messages {
msgTokens += estimateMessageTokens(m)
msgTokens += EstimateMessageTokens(m)
}
toolTokens := estimateToolDefsTokens(toolDefs)
toolTokens := EstimateToolDefsTokens(toolDefs)
total := msgTokens + toolTokens + maxTokens
return total > contextWindow
+19 -19
View File
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := estimateMessageTokens(tt.msg)
got := EstimateMessageTokens(tt.msg)
if got < tt.want {
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
}
})
}
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
},
}
plainTokens := estimateMessageTokens(plain)
withTCTokens := estimateMessageTokens(withTC)
plainTokens := EstimateMessageTokens(plain)
withTCTokens := EstimateMessageTokens(withTC)
if withTCTokens <= plainTokens {
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
// but may map to different token counts. The heuristic should still produce
// reasonable estimates via RuneCountInString.
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
tokens := estimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
if tokens <= 0 {
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
}
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
},
}
tokens := estimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
if tokens < 2000 {
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
ReasoningContent: strings.Repeat("thinking step ", 200),
}
plainTokens := estimateMessageTokens(plain)
reasoningTokens := estimateMessageTokens(withReasoning)
plainTokens := EstimateMessageTokens(plain)
reasoningTokens := EstimateMessageTokens(withReasoning)
if reasoningTokens <= plainTokens {
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
Media: []string{"media://img1.png", "media://img2.png"},
}
plainTokens := estimateMessageTokens(plain)
mediaTokens := estimateMessageTokens(withMedia)
plainTokens := EstimateMessageTokens(plain)
mediaTokens := EstimateMessageTokens(withMedia)
if mediaTokens <= plainTokens {
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
},
}
plainTokens := estimateMessageTokens(plain)
partsTokens := estimateMessageTokens(withParts)
plainTokens := EstimateMessageTokens(plain)
partsTokens := EstimateMessageTokens(withParts)
if partsTokens <= plainTokens {
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
}
}
// --- estimateToolDefsTokens tests ---
// --- EstimateToolDefsTokens tests ---
func TestEstimateToolDefsTokens(t *testing.T) {
tests := []struct {
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := estimateToolDefsTokens(tt.defs)
got := EstimateToolDefsTokens(tt.defs)
if got < tt.want {
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
}
})
}
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
}
}
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
three := estimateToolDefsTokens([]providers.ToolDefinition{
one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
three := EstimateToolDefsTokens([]providers.ToolDefinition{
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
})
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
},
}
tokens := estimateMessageTokens(msg)
tokens := EstimateMessageTokens(msg)
// ReasoningContent alone is ~1700 chars → ~680 tokens.
// Content + TC + overhead adds more. Should be well above 500.
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
// Compare without reasoning to ensure it's counted.
msgNoReasoning := msg
msgNoReasoning.ReasoningContent = ""
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
if tokens <= tokensNoReasoning {
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
+1 -1
View File
@@ -373,7 +373,7 @@ func (m *legacyContextManager) summarizeBatch(
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
total := 0
for _, msg := range messages {
total += estimateMessageTokens(msg)
total += EstimateMessageTokens(msg)
}
return total
}
+1
View File
@@ -43,6 +43,7 @@ type AssembleResponse struct {
type CompactRequest struct {
SessionKey string // session identifier
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
Budget int // context window budget (used for retry aggressive compaction)
}
// IngestRequest is the input to Ingest.
+269
View File
@@ -0,0 +1,269 @@
//go:build !mipsle && !netbsd && !(freebsd && arm)
package agent
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
"github.com/sipeed/picoclaw/pkg/seahorse"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
type seahorseContextManager struct {
engine *seahorse.Engine
sessions session.SessionStore // for startup bootstrap
}
// newSeahorseContextManager creates a seahorse-backed ContextManager.
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
if al == nil {
return nil, fmt.Errorf("seahorse: AgentLoop is required")
}
// Resolve workspace for DB path
// DB stores session data, so it goes in sessions/ directory
agent := al.registry.GetDefaultAgent()
dbPath := agent.Workspace + "/sessions/seahorse.db"
// Create CompleteFn from provider
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
// Create engine
engine, err := seahorse.NewEngine(seahorse.Config{
DBPath: dbPath,
}, completeFn)
if err != nil {
return nil, fmt.Errorf("seahorse: create engine: %w", err)
}
mgr := &seahorseContextManager{
engine: engine,
sessions: agent.Sessions,
}
// Register seahorse tools with the agent's tool registry
retrieval := mgr.engine.GetRetrieval()
al.RegisterTool(seahorse.NewGrepTool(retrieval))
al.RegisterTool(seahorse.NewExpandTool(retrieval))
// Bootstrap all existing sessions at startup
if agent.Sessions != nil {
ctx := context.Background()
for _, sessionKey := range agent.Sessions.ListSessions() {
mgr.bootstrapSession(ctx, sessionKey)
}
}
return mgr, nil
}
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
resp, err := provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil, // no tools for summarization
model,
map[string]any{
"max_tokens": opts.MaxTokens,
"temperature": opts.Temperature,
"prompt_cache_key": "seahorse",
},
)
if err != nil {
return "", err
}
return resp.Content, nil
}
}
// Assemble builds budget-aware context from seahorse SQLite.
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
if req == nil {
return nil, fmt.Errorf("seahorse assemble: nil request")
}
budget := req.Budget
if budget <= 0 {
budget = 100000
}
// Reserve space for model response (spec lines 1400-1410)
effectiveBudget := budget - req.MaxTokens
if effectiveBudget <= 0 {
// MaxTokens >= budget is a configuration problem
// Use 50% as minimum to avoid guaranteed overflow
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
effectiveBudget = budget / 2
}
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
Budget: effectiveBudget,
})
if err != nil {
return nil, fmt.Errorf("seahorse assemble: %w", err)
}
history := seahorseToProviderMessages(result)
// Summary is already formatted as XML with system prompt addition by assembler
return &AssembleResponse{
History: history,
Summary: result.Summary,
}, nil
}
// Compact compresses conversation history via seahorse summarization.
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
if req == nil {
return nil
}
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
// context shrinks below budget (spec lines ~1410).
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
return err
}
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
Force: req.Reason == ContextCompressReasonRetry,
Budget: &req.Budget,
})
return err
}
// Ingest records a message into seahorse SQLite.
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
if req == nil {
return nil
}
msg := providerToSeahorseMessage(req.Message)
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
return err
}
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
if m.sessions == nil {
return
}
history := m.sessions.GetHistory(sessionKey)
if len(history) == 0 {
return
}
// Convert provider messages to seahorse messages
msgs := make([]seahorse.Message, len(history))
for i, h := range history {
msgs[i] = providerToSeahorseMessage(h)
}
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
logger.WarnCF("seahorse", "bootstrap", map[string]any{
"session": sessionKey,
"error": err.Error(),
})
}
}
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
result := seahorse.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
TokenCount: tokenizer.EstimateMessageTokens(msg),
}
// Convert ToolCalls → MessageParts
for _, tc := range msg.ToolCalls {
part := seahorse.MessagePart{
Type: "tool_use",
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
ToolCallID: tc.ID,
}
result.Parts = append(result.Parts, part)
}
// Convert tool result
if msg.ToolCallID != "" {
part := seahorse.MessagePart{
Type: "tool_result",
ToolCallID: msg.ToolCallID,
Text: msg.Content,
}
result.Parts = append(result.Parts, part)
}
// Convert media attachments
for _, mediaURI := range msg.Media {
part := seahorse.MessagePart{
Type: "media",
MediaURI: mediaURI,
}
result.Parts = append(result.Parts, part)
}
return result
}
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
messages := make([]protocoltypes.Message, 0, len(result.Messages))
// Convert assembled messages (which already include summary XML messages)
for _, msg := range result.Messages {
pm := protocoltypes.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
}
// Reconstruct ToolCalls from parts
for _, part := range msg.Parts {
if part.Type == "tool_use" {
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
ID: part.ToolCallID,
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
Function: &protocoltypes.FunctionCall{
Name: part.Name,
Arguments: part.Arguments,
},
})
}
if part.Type == "tool_result" {
pm.ToolCallID = part.ToolCallID
if pm.Content == "" && part.Text != "" {
pm.Content = part.Text
}
}
if part.Type == "media" && part.MediaURI != "" {
pm.Media = append(pm.Media, part.MediaURI)
}
}
messages = append(messages, pm)
}
return messages
}
func init() {
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
panic(fmt.Sprintf("register seahorse context manager: %v", err))
}
}
File diff suppressed because it is too large Load Diff
+20
View File
@@ -0,0 +1,20 @@
//go:build mipsle || netbsd || (freebsd && arm)
package agent
import (
"encoding/json"
"fmt"
)
// newSeahorseContextManager is unavailable on platforms where modernc sqlite/libc
// currently has no stable build path for this project.
func newSeahorseContextManager(_ json.RawMessage, _ *AgentLoop) (ContextManager, error) {
return nil, fmt.Errorf("seahorse context manager is unavailable on this platform")
}
func init() {
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
panic(fmt.Sprintf("register seahorse context manager: %v", err))
}
}
+11 -2
View File
@@ -12,7 +12,9 @@ import (
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/tools"
)
const (
@@ -90,7 +92,8 @@ type processHookAfterLLMResponse struct {
type processHookBeforeToolResponse struct {
processHookDecisionResponse
Call *ToolCallHookRequest `json:"call,omitempty"`
Call *ToolCallHookRequest `json:"call,omitempty"`
Result *tools.ToolResult `json:"result,omitempty"` // Result returned directly by hook (for respond action)
}
type processHookAfterToolResponse struct {
@@ -120,7 +123,9 @@ func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (
if err != nil {
return nil, fmt.Errorf("create process hook stderr: %w", err)
}
if err := cmd.Start(); err != nil {
// Route hook subprocess startup through the shared isolation entry point so
// process hooks inherit the same isolation behavior as other child processes.
if err := isolation.Start(cmd); err != nil {
return nil, fmt.Errorf("start process hook: %w", err)
}
@@ -241,6 +246,10 @@ func (ph *ProcessHook) BeforeTool(
if resp.Call == nil {
resp.Call = call
}
// If hook returned a Result, carry it in ToolCallHookRequest
if resp.Result != nil {
resp.Call.HookResult = resp.Result
}
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
}
+126
View File
@@ -7,10 +7,13 @@ import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/providers"
)
@@ -178,6 +181,76 @@ func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
}
}
func TestAgentLoop_MountProcessHook_IsolationSupportsRelativeDirAndCommand(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("linux-only isolation path handling")
}
provider := &llmHookTestProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
root := t.TempDir()
t.Setenv(config.EnvHome, filepath.Join(root, "picoclaw-home"))
binDir := filepath.Join(root, "bin")
hookDir := filepath.Join(root, "hooks")
if err := os.MkdirAll(binDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(hookDir, 0o755); err != nil {
t.Fatal(err)
}
writeFakeBwrap(t, filepath.Join(binDir, "bwrap"))
t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH"))
linkTestBinary(t, os.Args[0], filepath.Join(hookDir, "hook-helper"))
cfg := config.DefaultConfig()
cfg.Isolation.Enabled = true
isolation.Configure(cfg)
t.Cleanup(func() { isolation.Configure(config.DefaultConfig()) })
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
relHookDir, err := filepath.Rel(cwd, hookDir)
if err != nil {
t.Fatal(err)
}
mountErr := al.MountProcessHook(context.Background(), "ipc-relative", ProcessHookOptions{
Command: []string{"./hook-helper", "-test.run=TestProcessHook_HelperProcess", "--"},
Dir: relHookDir,
Env: processHookHelperEnv("rewrite", ""),
InterceptLLM: true,
})
if mountErr != nil {
t.Fatalf("MountProcessHook failed with relative dir/command under isolation: %v", mountErr)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-relative",
Channel: "cli",
ChatID: "direct",
UserMessage: "hello",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "provider content|ipc" {
t.Fatalf("expected process-hooked llm content, got %q", resp)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "process-model" {
t.Fatalf("expected process model, got %q", lastModel)
}
}
func processHookHelperCommand() []string {
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
}
@@ -193,6 +266,59 @@ func processHookHelperEnv(mode, eventLog string) []string {
return env
}
func writeFakeBwrap(t *testing.T, path string) {
t.Helper()
script := `#!/bin/sh
set -eu
workdir=
while [ "$#" -gt 0 ]; do
case "$1" in
--)
shift
break
;;
--chdir)
workdir="$2"
shift 2
;;
--bind|--ro-bind)
shift 3
;;
--proc|--dev)
shift 2
;;
--die-with-parent|--unshare-ipc)
shift
;;
*)
shift
;;
esac
done
if [ -n "$workdir" ]; then
cd "$workdir"
fi
exec "$@"
`
if err := os.WriteFile(path, []byte(script), 0o755); err != nil {
t.Fatalf("write fake bwrap: %v", err)
}
}
func linkTestBinary(t *testing.T, source, target string) {
t.Helper()
if err := os.Symlink(source, target); err == nil {
return
}
data, err := os.ReadFile(source)
if err != nil {
t.Fatalf("read test binary: %v", err)
}
if err := os.WriteFile(target, data, 0o755); err != nil {
t.Fatalf("create hook helper binary: %v", err)
}
}
func waitForFileContains(t *testing.T, path, substring string) {
t.Helper()
+19 -5
View File
@@ -25,6 +25,7 @@ type HookAction string
const (
HookActionContinue HookAction = "continue"
HookActionModify HookAction = "modify"
HookActionRespond HookAction = "respond" // Return result directly, skip tool execution. SECURITY: This bypasses ApproveTool checks, allowing hooks to return results for any tool (including sensitive ones like bash) without approval. Use with caution.
HookActionDenyTool HookAction = "deny_tool"
HookActionAbortTurn HookAction = "abort_turn"
HookActionHardAbort HookAction = "hard_abort"
@@ -127,11 +128,12 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
}
type ToolCallHookRequest struct {
Meta EventMeta `json:"meta"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
Meta EventMeta `json:"meta"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
HookResult *tools.ToolResult `json:"hook_result,omitempty"` // Result returned directly by hook (for respond action). Media is supported - see Media handling section in docs.
}
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
@@ -140,6 +142,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
}
cloned := *r
cloned.Arguments = cloneStringAnyMap(r.Arguments)
cloned.HookResult = cloneToolResult(r.HookResult)
return &cloned
}
@@ -382,6 +385,10 @@ func (hm *HookManager) BeforeTool(
if next != nil {
current = next
}
case HookActionRespond:
// Hook returns result directly, skip tool execution
// Carry HookResult in ToolCallHookRequest and return
return next, decision
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
return current, decision
default:
@@ -793,6 +800,13 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
if len(result.Media) > 0 {
cloned.Media = append([]string(nil), result.Media...)
}
if len(result.ArtifactTags) > 0 {
cloned.ArtifactTags = append([]string(nil), result.ArtifactTags...)
}
if len(result.Messages) > 0 {
cloned.Messages = make([]providers.Message, len(result.Messages))
copy(cloned.Messages, result.Messages)
}
return &cloned
}
+516
View File
@@ -2,6 +2,7 @@ package agent
import (
"context"
"errors"
"os"
"sync"
"testing"
@@ -10,6 +11,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -343,3 +345,517 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
}
}
// respondHook is a test hook for testing HookActionRespond functionality
type respondHook struct {
respondTools map[string]bool // tool names to respond to
}
func (h *respondHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if h.respondTools[call.Tool] {
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: "hook-responded: " + call.Tool,
ForUser: "",
Silent: false,
IsError: false,
}
return next, HookDecision{Action: HookActionRespond}, nil
}
return call, HookDecision{Action: HookActionContinue}, nil
}
func (h *respondHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
// Should not be called since respond skips tool execution
return result, HookDecision{Action: HookActionContinue}, nil
}
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
provider := &toolHookProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&echoTextTool{})
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
respondTools: map[string]bool{"echo_text": true},
})); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
// Verify response comes from hook, not tool
expected := "hook-responded: echo_text"
if resp != expected {
t.Fatalf("expected %q, got %q", expected, resp)
}
// Verify event stream has ToolExecEnd, not actual tool execution
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
if !ok {
t.Fatal("expected tool exec end event")
}
payload, ok := endEvt.Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
}
if payload.Tool != "echo_text" {
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
}
if payload.ForLLMLen != len(expected) {
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
}
}
// denyToolHook tests HookActionDenyTool functionality
type denyToolHook struct {
denyTools map[string]bool
}
func (h *denyToolHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if h.denyTools[call.Tool] {
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
}
return call, HookDecision{Action: HookActionContinue}, nil
}
func (h *denyToolHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
return result, HookDecision{Action: HookActionContinue}, nil
}
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
provider := &toolHookProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&echoTextTool{})
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
denyTools: map[string]bool{"echo_text": true},
})); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
expected := "Tool execution denied by hook: tool denied by hook"
if resp != expected {
t.Fatalf("expected %q, got %q", expected, resp)
}
}
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
hm := NewHookManager(nil)
defer hm.Close()
hook := &respondHook{
respondTools: map[string]bool{"test_tool": true},
}
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
t.Fatalf("mount hook: %v", err)
}
req := &ToolCallHookRequest{
Tool: "test_tool",
Arguments: map[string]any{"arg": "value"},
}
result, decision := hm.BeforeTool(context.Background(), req)
if decision.Action != HookActionRespond {
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
}
if result.HookResult == nil {
t.Fatal("expected HookResult to be set")
}
if result.HookResult.ForLLM != "hook-responded: test_tool" {
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
}
}
type respondWithMediaHook struct {
respondTools map[string]bool
media []string
responseHandled bool
forLLM string
}
func (h *respondWithMediaHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if h.respondTools[call.Tool] {
next := call.Clone()
next.HookResult = &tools.ToolResult{
ForLLM: h.forLLM,
ForUser: "media result",
Media: h.media,
ResponseHandled: h.responseHandled,
Silent: false,
IsError: false,
}
return next, HookDecision{Action: HookActionRespond}, nil
}
return call, HookDecision{Action: HookActionContinue}, nil
}
func (h *respondWithMediaHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
return result, HookDecision{Action: HookActionContinue}, nil
}
type errorMediaChannel struct {
fakeChannel
sendErr error
}
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
return nil, f.sendErr
}
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
hook := &respondWithMediaHook{
respondTools: map[string]bool{"media_tool": true},
media: []string{"media://test/image.png"},
responseHandled: true,
forLLM: "media sent successfully",
}
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
al.channelManager = newStartedTestChannelManager(t, al.bus, al.mediaStore, "discord", &errorMediaChannel{
sendErr: errors.New("channel unavailable"),
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-media-err",
Channel: "discord",
ChatID: "chat1",
UserMessage: "send media",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
if !ok {
t.Fatal("expected ToolExecEnd event")
}
payload, ok := endEvt.Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
}
if !payload.IsError {
t.Fatal("expected IsError=true when SendMedia fails")
}
if payload.ForLLMLen < 30 {
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
}
}
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
hook := &respondWithMediaHook{
respondTools: map[string]bool{"media_tool": true},
media: []string{"media://test/image.png"},
responseHandled: true,
forLLM: "media queued",
}
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-bus-fallback",
Channel: "cli",
ChatID: "chat1",
UserMessage: "send media",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
events := collectEventStream(sub.C)
endEvt, ok := findEvent(events, EventKindToolExecEnd)
if !ok {
t.Fatal("expected ToolExecEnd event")
}
payload, ok := endEvt.Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
}
if payload.IsError {
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
}
if resp != "done" {
t.Fatalf("expected response 'done', got %q", resp)
}
}
type multiToolProvider struct {
mu sync.Mutex
callCount int
toolCalls []providers.ToolCall
finalContent string
}
func (p *multiToolProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
defer p.mu.Unlock()
p.callCount++
if p.callCount == 1 && len(p.toolCalls) > 0 {
return &providers.LLMResponse{
ToolCalls: p.toolCalls,
}, nil
}
return &providers.LLMResponse{
Content: p.finalContent,
}, nil
}
func (p *multiToolProvider) GetDefaultModel() string {
return "multi-tool-provider"
}
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, _, cleanup := newHookTestLoop(t, provider)
defer cleanup()
tool1ExecCh := make(chan struct{}, 1)
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
hook := &respondHook{
respondTools: map[string]bool{"tool_one": true},
}
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"run tools",
sessionKey,
"cli",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
time.Sleep(50 * time.Millisecond)
if err := al.InterruptGraceful("stop now"); err != nil {
t.Fatalf("InterruptGraceful failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for result")
}
events := collectEventStream(sub.C)
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
if len(skippedEvts) < 1 {
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
}
for _, evt := range skippedEvts {
payload, ok := evt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
}
if payload.Reason != "graceful interrupt requested" {
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
}
}
}
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
provider := &multiToolProvider{
toolCalls: []providers.ToolCall{
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
},
finalContent: "done",
}
al, _, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
hook := &respondHook{
respondTools: map[string]bool{"tool_one": true},
}
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
t.Fatalf("MountHook failed: %v", err)
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"run tools",
sessionKey,
"cli",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
time.Sleep(50 * time.Millisecond)
al.Steer(providers.Message{Role: "user", Content: "change direction"})
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for result")
}
events := collectEventStream(sub.C)
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
if len(skippedEvts) < 1 {
t.Fatal("expected at least one ToolExecSkipped event after steering")
}
for _, evt := range skippedEvts {
payload, ok := evt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
}
if payload.Reason != "queued user steering message" {
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
}
}
}
func filterEvents(events []Event, kind EventKind) []Event {
var result []Event
for _, evt := range events {
if evt.Kind == kind {
result = append(result, evt)
}
}
return result
}
+52
View File
@@ -9,6 +9,7 @@ import (
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/isolation"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/memory"
@@ -51,6 +52,10 @@ type AgentInstance struct {
// LightProvider is the concrete provider instance for the configured light model.
// It is only used when routing selects the light tier for a turn.
LightProvider providers.LLMProvider
// CandidateProviders maps "provider/model" keys to per-candidate LLMProvider
// instances. This allows each fallback model to use its own api_base and api_key
// from model_list, instead of inheriting the primary model's provider config.
CandidateProviders map[string]providers.LLMProvider
}
// NewAgentInstance creates an agent instance from config.
@@ -60,6 +65,12 @@ func NewAgentInstance(
cfg *config.Config,
provider providers.LLMProvider,
) *AgentInstance {
if cfg != nil {
// Keep the subprocess isolation runtime aligned with the latest loaded config
// before any tools or providers start spawning child processes.
isolation.Configure(cfg)
}
workspace := resolveAgentWorkspace(agentCfg, defaults)
os.MkdirAll(workspace, 0o755)
@@ -175,6 +186,9 @@ func NewAgentInstance(
// Resolve fallback candidates
candidates := resolveModelCandidates(cfg, defaults.Provider, model, fallbacks)
candidateProviders := make(map[string]providers.LLMProvider)
populateCandidateProvidersFromNames(cfg, workspace, fallbacks, candidateProviders)
// Model routing setup: pre-resolve light model candidates at creation time
// to avoid repeated model_list lookups on every incoming message.
var router *routing.Router
@@ -199,6 +213,7 @@ func NewAgentInstance(
})
lightCandidates = resolved
lightProvider = lp
populateCandidateProvidersFromNames(cfg, workspace, []string{rc.LightModel}, candidateProviders)
}
}
} else {
@@ -230,6 +245,43 @@ func NewAgentInstance(
Router: router,
LightCandidates: lightCandidates,
LightProvider: lightProvider,
CandidateProviders: candidateProviders,
}
}
// populateCandidateProvidersFromNames resolves each model name (alias or
// "provider/model") via resolvedModelConfig and creates a dedicated LLMProvider
// for it. This reuses the canonical config resolution path (GetModelConfig) so
// alias handling and load-balancing stay consistent with the rest of the codebase.
func populateCandidateProvidersFromNames(
cfg *config.Config,
workspace string,
names []string,
out map[string]providers.LLMProvider,
) {
if cfg == nil || len(names) == 0 {
return
}
for _, name := range names {
mc, err := resolvedModelConfig(cfg, strings.TrimSpace(name), workspace)
if err != nil {
logger.WarnCF("agent",
"fallback provider: no model_list entry found; will inherit primary provider credentials",
map[string]any{"name": name, "error": err.Error()})
continue
}
protocol, modelID := providers.ExtractProtocol(strings.TrimSpace(mc.Model))
key := providers.ModelKey(providers.NormalizeProvider(protocol), modelID)
if _, exists := out[key]; exists {
continue
}
p, _, err := providers.CreateProviderFromConfig(mc)
if err != nil {
logger.WarnCF("agent", "fallback provider: failed to create provider",
map[string]any{"model": mc.Model, "error": err.Error()})
continue
}
out[key] = p
}
}
+194
View File
@@ -9,6 +9,7 @@ import (
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestNewAgentInstance_UsesDefaultsTemperatureAndMaxTokens(t *testing.T) {
@@ -300,6 +301,199 @@ func TestNewAgentInstance_AllowsMediaTempDirForReadListAndExec(t *testing.T) {
}
}
// TestPopulateCandidateProviders_NilCfgIsNoop verifies that passing a nil
// config does not panic and leaves the output map empty.
func TestPopulateCandidateProviders_NilCfgIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
populateCandidateProvidersFromNames(nil, t.TempDir(), []string{"gpt-4o"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_SkipsExistingKeys verifies that a key already
// present in the output map is not overwritten.
func TestPopulateCandidateProviders_SkipsExistingKeys(t *testing.T) {
existing := &mockProvider{}
key := providers.ModelKey("openai", "gpt-4o")
out := map[string]providers.LLMProvider{key: existing}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("test-key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"my-gpt"}, out)
if out[key] != existing {
t.Fatal("existing provider entry was overwritten; expected it to be preserved")
}
}
// TestPopulateCandidateProviders_ResolvesAlias verifies that a model_name
// alias (e.g. "my-gpt") is resolved via GetModelConfig and the provider
// is created using the underlying model's config.
func TestPopulateCandidateProviders_ResolvesAlias(t *testing.T) {
workspace := t.TempDir()
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIBase: "https://api.openai.com/v1", Workspace: workspace},
},
}
populateCandidateProvidersFromNames(cfg, workspace, []string{"my-gpt"}, out)
key := providers.ModelKey("openai", "gpt-4o")
if out[key] == nil {
t.Fatalf("expected CandidateProviders[%q] to be populated for alias", key)
}
}
// TestPopulateCandidateProviders_ResolvesProtocolPrefix verifies that a
// model_list entry using full "provider/model" notation (e.g.
// "gemini/gemma-3-27b-it") is matched correctly when referenced by model_name.
func TestPopulateCandidateProviders_ResolvesProtocolPrefix(t *testing.T) {
workspace := t.TempDir()
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{
ModelName: "gemma",
Model: "gemini/gemma-3-27b-it",
APIKeys: config.SimpleSecureStrings("gemini-test-key"),
Workspace: workspace,
},
},
}
populateCandidateProvidersFromNames(cfg, workspace, []string{"gemma"}, out)
key := providers.ModelKey("gemini", "gemma-3-27b-it")
if out[key] == nil {
t.Fatalf("expected CandidateProviders[%q] to be populated for protocol-prefixed model", key)
}
}
// TestPopulateCandidateProviders_EmptyNamesIsNoop verifies the early-exit
// path when the names slice is empty.
func TestPopulateCandidateProviders_EmptyNamesIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), nil, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_EmptyModelListIsNoop verifies the early-exit
// path when model_list is empty — no provider can be created.
func TestPopulateCandidateProviders_EmptyModelListIsNoop(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"gpt-4o"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map, got %d entries", len(out))
}
}
// TestPopulateCandidateProviders_UnmatchedNameIsSkipped verifies that a
// name with no matching model_list entry is skipped and does not
// cause a panic or leave a nil entry in the map.
func TestPopulateCandidateProviders_UnmatchedNameIsSkipped(t *testing.T) {
out := map[string]providers.LLMProvider{}
cfg := &config.Config{
ModelList: []*config.ModelConfig{
{ModelName: "my-gpt", Model: "openai/gpt-4o", APIKeys: config.SimpleSecureStrings("key")},
},
}
populateCandidateProvidersFromNames(cfg, t.TempDir(), []string{"nonexistent-model"}, out)
if len(out) != 0 {
t.Fatalf("expected empty map for unmatched name, got %d entries", len(out))
}
}
// TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks
// mirrors the exact scenario from bug #2140: primary model on OpenRouter with
// Gemini fallbacks. Each entry must get its own provider instance so that
// fallback requests go to the correct API endpoint, not the primary's.
func TestNewAgentInstance_CandidateProvidersPopulatedForCrossProviderFallbacks(t *testing.T) {
workspace := t.TempDir()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "mistral-small-3.1",
ModelFallbacks: []string{"gemma-3-27b", "gemini-images"},
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "mistral-small-3.1",
Model: "openrouter/mistralai/mistral-small-3.1-24b-instruct:free",
APIBase: "https://openrouter.ai/api/v1",
APIKeys: config.SimpleSecureStrings("sk-or-test"),
Workspace: workspace,
},
{
ModelName: "gemma-3-27b",
Model: "gemini/gemma-3-27b-it",
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
Workspace: workspace,
},
{
ModelName: "gemini-images",
Model: "gemini/gemini-2.5-flash-lite",
APIKeys: config.SimpleSecureStrings("AIzaSy-test"),
Workspace: workspace,
},
},
}
primaryProvider := &mockProvider{}
agent := NewAgentInstance(nil, &cfg.Agents.Defaults, cfg, primaryProvider)
// Only fallback models need entries — the primary uses the injected provider directly.
wantKeys := []string{
providers.ModelKey("gemini", "gemma-3-27b-it"),
providers.ModelKey("gemini", "gemini-2.5-flash-lite"),
}
for _, key := range wantKeys {
p, ok := agent.CandidateProviders[key]
if !ok {
t.Errorf("CandidateProviders missing key %q", key)
continue
}
if p == nil {
t.Errorf("CandidateProviders[%q] is nil", key)
}
// Each fallback must use its own provider, not the injected primary.
if p == primaryProvider {
t.Errorf(
"CandidateProviders[%q] is the same instance as the primary provider; fallback would inherit primary credentials",
key,
)
}
}
if t.Failed() {
t.Logf("CandidateProviders keys present: %v", func() []string {
keys := make([]string, 0, len(agent.CandidateProviders))
for k := range agent.CandidateProviders {
keys = append(keys, k)
}
return keys
}())
}
}
func TestNewAgentInstance_ReadFileModeSelectsSchema(t *testing.T) {
workspace := t.TempDir()
+240 -2
View File
@@ -1742,6 +1742,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
SessionKey: ts.sessionKey,
Reason: ContextCompressReasonProactive,
Budget: ts.agent.ContextWindow,
}); err != nil {
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
"session_key": ts.sessionKey,
@@ -1857,6 +1858,7 @@ turnLoop:
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
ts.recordPersistedMessage(pm)
ts.ingestMessage(turnCtx, al, pm)
}
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
@@ -2018,7 +2020,11 @@ turnLoop:
providerCtx,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
return activeProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
candidateProvider := activeProvider
if cp, ok := ts.agent.CandidateProviders[providers.ModelKey(provider, model)]; ok {
candidateProvider = cp
}
return candidateProvider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
},
)
if fbErr != nil {
@@ -2128,6 +2134,7 @@ turnLoop:
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
SessionKey: ts.sessionKey,
Reason: ContextCompressReasonRetry,
Budget: ts.agent.ContextWindow,
}); compactErr != nil {
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
"session_key": ts.sessionKey,
@@ -2345,6 +2352,236 @@ turnLoop:
toolName = toolReq.Tool
toolArgs = toolReq.Arguments
}
case HookActionRespond:
// Hook returns result directly, skip tool execution.
// SECURITY: This bypasses ApproveTool, allowing hooks to respond
// for any tool name without approval. This is intentional for
// plugin tools but means a before_tool hook can override even
// sensitive tools like bash. Hook configuration should be
// carefully reviewed to prevent unauthorized tool execution.
if toolReq != nil && toolReq.HookResult != nil {
hookResult := toolReq.HookResult
argsJSON, _ := json.Marshal(toolArgs)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call (hook respond): %s(%s)", toolName, argsPreview),
map[string]any{
"agent_id": ts.agent.ID,
"tool": toolName,
"iteration": iteration,
})
// Emit ToolExecStart event (same as normal tool execution)
al.emitEvent(
EventKindToolExecStart,
ts.eventMeta("runTurn", "turn.tool.start"),
ToolExecStartPayload{
Tool: toolName,
Arguments: cloneEventArguments(toolArgs),
},
)
// Send tool feedback to chat channel if enabled (same as normal tool execution)
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
ts.channel != "" &&
!ts.opts.SuppressToolFeedback {
argsJSON, _ := json.Marshal(toolArgs)
feedbackPreview := utils.Truncate(
string(argsJSON),
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
)
feedbackMsg := fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", toolName, feedbackPreview)
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Content: feedbackMsg,
})
fbCancel()
}
toolDuration := time.Duration(0) // Hook execution time unknown
// Send ForUser content to user
// For ResponseHandled results, send regardless of SendResponse setting,
// same as normal tool execution path.
shouldSendForUser := !hookResult.Silent && hookResult.ForUser != "" &&
(ts.opts.SendResponse || hookResult.ResponseHandled)
if shouldSendForUser {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Content: hookResult.ForUser,
Metadata: map[string]string{
"is_tool_call": "true",
},
})
}
// Handle media from hook result (same as normal tool execution)
if len(hookResult.Media) > 0 && hookResult.ResponseHandled {
parts := make([]bus.MediaPart, 0, len(hookResult.Media))
for _, ref := range hookResult.Media {
part := bus.MediaPart{Ref: ref}
if al.mediaStore != nil {
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
part.Filename = meta.Filename
part.ContentType = meta.ContentType
part.Type = inferMediaType(meta.Filename, meta.ContentType)
}
}
parts = append(parts, part)
}
outboundMedia := bus.OutboundMediaMessage{
Channel: ts.channel,
ChatID: ts.chatID,
Parts: parts,
}
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
logger.WarnCF("agent", "Failed to deliver hook media",
map[string]any{
"agent_id": ts.agent.ID,
"tool": toolName,
"channel": ts.channel,
"chat_id": ts.chatID,
"error": err.Error(),
})
// Same as normal tool execution: notify LLM about delivery failure
hookResult.IsError = true
hookResult.ForLLM = fmt.Sprintf("failed to deliver attachment: %v", err)
}
} else if al.bus != nil {
al.bus.PublishOutboundMedia(ctx, outboundMedia)
// Same as normal tool execution: bus only queues, media not yet delivered
hookResult.ResponseHandled = false
}
}
// Track response handling status (same as normal tool execution)
if !hookResult.ResponseHandled {
allResponsesHandled = false
}
// Build tool message
contentForLLM := hookResult.ContentForLLM()
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
}
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
// Handle media for LLM vision (same as normal tool execution)
if len(hookResult.Media) > 0 && !hookResult.ResponseHandled {
hookResult.ArtifactTags = buildArtifactTags(al.mediaStore, hookResult.Media)
// Recalculate contentForLLM after adding ArtifactTags
contentForLLM = hookResult.ContentForLLM()
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
}
toolResultMsg.Content = contentForLLM
toolResultMsg.Media = append(toolResultMsg.Media, hookResult.Media...)
}
// Emit ToolExecEnd event (after filtering, same as normal tool execution)
al.emitEvent(
EventKindToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
ToolExecEndPayload{
Tool: toolName,
Duration: toolDuration,
ForLLMLen: len(contentForLLM),
ForUserLen: len(hookResult.ForUser),
IsError: hookResult.IsError,
Async: hookResult.Async,
},
)
messages = append(messages, toolResultMsg)
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
ts.recordPersistedMessage(toolResultMsg)
ts.ingestMessage(turnCtx, al, toolResultMsg)
}
// Same as normal tool execution: check for steering/interrupt/SubTurn after each tool
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
pendingMessages = append(pendingMessages, steerMsgs...)
}
skipReason := ""
skipMessage := ""
if len(pendingMessages) > 0 {
skipReason = "queued user steering message"
skipMessage = "Skipped due to queued user message."
} else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending {
skipReason = "graceful interrupt requested"
skipMessage = "Skipped due to graceful interrupt."
}
if skipReason != "" {
remaining := len(normalizedToolCalls) - i - 1
if remaining > 0 {
logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools after hook respond",
map[string]any{
"agent_id": ts.agent.ID,
"completed": i + 1,
"skipped": remaining,
"reason": skipReason,
})
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
al.emitEvent(
EventKindToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: skippedTC.Name,
Reason: skipReason,
},
)
skippedMsg := providers.Message{
Role: "tool",
Content: skipMessage,
ToolCallID: skippedTC.ID,
}
messages = append(messages, skippedMsg)
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg)
ts.recordPersistedMessage(skippedMsg)
}
}
}
break
}
// Also poll for any SubTurn results that arrived during tool execution.
if ts.pendingResults != nil {
select {
case result, ok := <-ts.pendingResults:
if ok && result != nil && result.ForLLM != "" {
content := al.cfg.FilterSensitiveData(result.ForLLM)
msg := providers.Message{Role: "user", Content: fmt.Sprintf("[SubTurn Result] %s", content)}
messages = append(messages, msg)
ts.agent.Sessions.AddFullMessage(ts.sessionKey, msg)
}
default:
// No results available
}
}
continue
}
// If no HookResult, fall back to continue with warning
logger.WarnCF("agent", "Hook returned respond action but no HookResult provided",
map[string]any{
"agent_id": ts.agent.ID,
"tool": toolName,
"action": "respond",
})
case HookActionDenyTool:
allResponsesHandled = false
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
@@ -2773,7 +3010,7 @@ turnLoop:
}
}
if ts.opts.EnableSummary {
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize})
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
}
ts.setPhase(TurnPhaseCompleted)
@@ -2849,6 +3086,7 @@ turnLoop:
&CompactRequest{
SessionKey: ts.sessionKey,
Reason: ContextCompressReasonSummarize,
Budget: ts.agent.ContextWindow,
},
)
}
+2
View File
@@ -126,6 +126,8 @@ func (al *AgentLoop) ensureMCPInitialized(ctx context.Context) error {
}
mcpTool := tools.NewMCPTool(mcpManager, serverName, tool)
mcpTool.SetWorkspace(agent.Workspace)
mcpTool.SetMaxInlineTextRunes(al.cfg.Tools.MCP.GetMaxInlineTextChars())
if registerAsHidden {
agent.Tools.RegisterHidden(mcpTool)
+158
View File
@@ -1839,6 +1839,164 @@ func TestProcessMessage_ModelRoutingUsesLightProvider(t *testing.T) {
}
}
// TestProcessMessage_FallbackUsesPerCandidateProvider is the loop-level test for
// bug #2140. It verifies that when the primary model returns a rate-limit error
// the fallback closure routes the retry to the fallback model's own provider
// (its own api_base), not back to the primary provider's endpoint.
func TestProcessMessage_FallbackUsesPerCandidateProvider(t *testing.T) {
workspace := t.TempDir()
primaryCalls := 0
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
primaryCalls++
// Return 429 so FallbackChain classifies this as retriable and moves on.
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusTooManyRequests)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": map[string]any{
"message": "rate limit exceeded",
"type": "rate_limit_error",
},
})
}))
defer primaryServer.Close()
fallbackCalls := 0
fallbackServer := newStrictChatCompletionTestServer(
t, "fallback", "gemma-3-27b-it", "fallback reply", &fallbackCalls,
)
defer fallbackServer.Close()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "mistral-primary",
ModelFallbacks: []string{"gemma-fallback"},
MaxTokens: 4096,
MaxToolIterations: 3,
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "mistral-primary",
Model: "openrouter/mistralai/mistral-small-3.1",
APIBase: primaryServer.URL,
APIKeys: config.SimpleSecureStrings("primary-key"),
Workspace: workspace,
},
{
ModelName: "gemma-fallback",
Model: "gemini/gemma-3-27b-it",
APIBase: fallbackServer.URL,
APIKeys: config.SimpleSecureStrings("fallback-key"),
Workspace: workspace,
},
},
}
provider, _, err := providers.CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
Peer: bus.Peer{Kind: "direct", ID: "user1"},
})
if resp != "fallback reply" {
t.Fatalf("response = %q, want %q (fallback provider)", resp, "fallback reply")
}
if primaryCalls == 0 {
t.Fatal("primary server was never called; expected at least one attempt")
}
if fallbackCalls != 1 {
t.Fatalf("fallback server calls = %d, want 1", fallbackCalls)
}
}
// TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered verifies
// that when a candidate has no model_list entry it is absent from CandidateProviders
// and the fallback closure falls back to activeProvider instead of panicking.
func TestProcessMessage_FallbackUsesActiveProviderWhenCandidateNotRegistered(t *testing.T) {
workspace := t.TempDir()
// Primary server: returns 429 on first call, succeeds on second.
// Both the primary and the unregistered fallback share this server
// (same api_base) so activeProvider routes both calls here.
callCount := 0
primaryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
if callCount == 1 {
w.WriteHeader(http.StatusTooManyRequests)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": map[string]any{"message": "rate limit", "type": "rate_limit_error"},
})
return
}
// Second call (fallback via activeProvider) succeeds.
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "active provider reply"}, "finish_reason": "stop"},
},
})
}))
defer primaryServer.Close()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: workspace,
ModelName: "primary-model",
MaxTokens: 4096,
MaxToolIterations: 3,
// No model_list entry for this alias — absent from CandidateProviders.
ModelFallbacks: []string{"openrouter/fallback-model"},
},
},
ModelList: []*config.ModelConfig{
{
ModelName: "primary-model",
Model: "openrouter/primary-model",
APIBase: primaryServer.URL,
APIKeys: config.SimpleSecureStrings("primary-key"),
Workspace: workspace,
},
},
}
provider, _, err := providers.CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider() error = %v", err)
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
resp := helper.executeAndGetResponse(t, context.Background(), bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "hi",
Peer: bus.Peer{Kind: "direct", ID: "user1"},
})
if resp != "active provider reply" {
t.Fatalf("response = %q, want %q", resp, "active provider reply")
}
if callCount < 2 {
t.Fatalf("primary server calls = %d, want >= 2 (one 429 + one success via activeProvider)", callCount)
}
}
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
+4 -2
View File
@@ -604,6 +604,7 @@ type ephemeralSessionStoreIface interface {
SetHistory(key string, history []providers.Message)
TruncateHistory(key string, keepLast int)
Save(key string) error
ListSessions() []string
Close() error
}
@@ -663,8 +664,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
e.history = e.history[len(e.history)-keepLast:]
}
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
func (e *ephemeralSessionStore) Close() error { return nil }
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
func (e *ephemeralSessionStore) Close() error { return nil }
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
func (e *ephemeralSessionStore) truncateLocked() {
if len(e.history) > maxEphemeralHistorySize {
+22 -18
View File
@@ -14,6 +14,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
lark "github.com/larksuite/oapi-sdk-go/v3"
larkcore "github.com/larksuite/oapi-sdk-go/v3/core"
@@ -42,12 +43,18 @@ type FeishuChannel struct {
wsClient *larkws.Client
tokenCache *tokenCache // custom cache that supports invalidation
botOpenID atomic.Value // stores string; populated lazily for @mention detection
botOpenID atomic.Value // stores string; populated lazily for @mention detection
messageCache sync.Map // caches fetched messages (messageID -> *larkim.Message)
mu sync.Mutex
cancel context.CancelFunc
}
type cachedMessage struct {
msg *larkim.Message
expiry time.Time
}
func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) {
base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom,
channels.WithGroupTrigger(cfg.GroupTrigger),
@@ -436,24 +443,8 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
// Append media tags to content (like Telegram does)
content = appendMediaTags(content, messageType, mediaRefs)
if content == "" {
content = "[empty message]"
}
metadata := map[string]string{}
if messageID != "" {
metadata["message_id"] = messageID
}
if messageType != "" {
metadata["message_type"] = messageType
}
chatType := stringValue(message.ChatType)
if chatType != "" {
metadata["chat_type"] = chatType
}
if sender != nil && sender.TenantKey != nil {
metadata["tenant_key"] = *sender.TenantKey
}
metadata := buildInboundMetadata(message, sender)
var peer bus.Peer
if chatType == "p2p" {
@@ -477,12 +468,25 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.
content = cleaned
}
if replyTargetID(message) != "" || stringValue(message.ThreadId) != "" {
content, mediaRefs = c.prependReplyContext(ctx, message, chatID, content, mediaRefs)
}
if content == "" {
content = "[empty message]"
}
logger.InfoCF("feishu", "Feishu message received", map[string]any{
"sender_id": senderID,
"chat_id": chatID,
"message_id": messageID,
"preview": utils.Truncate(content, 80),
})
logger.InfoCF("feishu", "Feishu reply linkage", map[string]any{
"message_id": messageID,
"parent_id": stringValue(message.ParentId),
"root_id": stringValue(message.RootId),
"thread_id": stringValue(message.ThreadId),
})
c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, mediaRefs, metadata, senderInfo)
return nil
+298
View File
@@ -0,0 +1,298 @@
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
package feishu
import (
"context"
"fmt"
"strings"
"time"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
const messageCacheTTL = 30 * time.Second
const (
maxReplyContextLen = 600
)
func (c *FeishuChannel) prependReplyContext(
ctx context.Context,
message *larkim.EventMessage,
chatID string,
content string,
mediaRefs []string,
) (string, []string) {
if message == nil {
return content, mediaRefs
}
lookupCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
targetMessageID := c.resolveReplyTargetMessageID(lookupCtx, message)
if targetMessageID == "" {
logger.DebugCF("feishu", "No reply target resolved; skip reply context", map[string]any{
"message_id": stringValue(message.MessageId),
"parent_id": stringValue(message.ParentId),
"root_id": stringValue(message.RootId),
"thread_id": stringValue(message.ThreadId),
})
return content, mediaRefs
}
repliedMessage, err := c.fetchMessageByID(lookupCtx, targetMessageID)
if err != nil {
logger.DebugCF("feishu", "Failed to fetch replied message context", map[string]any{
"target_message_id": targetMessageID,
"error": err.Error(),
})
return content, mediaRefs
}
messageType := stringValue(repliedMessage.MsgType)
rawContent := ""
if repliedMessage.Body != nil {
rawContent = stringValue(repliedMessage.Body.Content)
}
var repliedMediaRefs []string
if store := c.GetMediaStore(); store != nil {
repliedMediaRefs = c.downloadInboundMedia(lookupCtx, chatID, targetMessageID, messageType, rawContent, store)
if messageType == larkim.MsgTypeInteractive {
_, externalURLs := extractCardImageKeys(rawContent)
if len(externalURLs) > 0 {
repliedMediaRefs = append(repliedMediaRefs, externalURLs...)
}
}
}
repliedContent := normalizeRepliedContent(messageType, rawContent, repliedMediaRefs)
if len(repliedMediaRefs) > 0 {
mediaRefs = append(repliedMediaRefs, mediaRefs...)
}
return formatReplyContext(targetMessageID, repliedContent, content), mediaRefs
}
func (c *FeishuChannel) resolveReplyTargetMessageID(ctx context.Context, message *larkim.EventMessage) string {
if targetID := replyTargetID(message); targetID != "" {
logger.DebugCF("feishu", "Resolved reply target from event payload", map[string]any{
"message_id": stringValue(message.MessageId),
"parent_id": stringValue(message.ParentId),
"root_id": stringValue(message.RootId),
"target_id": targetID,
})
return targetID
}
currentMessageID := stringValue(message.MessageId)
if currentMessageID == "" {
return ""
}
if stringValue(message.ThreadId) == "" {
logger.DebugCF("feishu", "No reply target found; message is not in a thread", map[string]any{
"message_id": stringValue(message.MessageId),
})
return ""
}
msg, err := c.fetchMessageByID(ctx, currentMessageID)
if err != nil {
logger.DebugCF("feishu", "Failed to query current message detail for reply info", map[string]any{
"message_id": currentMessageID,
"error": err.Error(),
})
return ""
}
targetID := replyTargetIDFromMessage(msg)
if targetID != "" {
logger.DebugCF("feishu", "Resolved reply target from message detail", map[string]any{
"message_id": currentMessageID,
"parent_id": stringValue(msg.ParentId),
"root_id": stringValue(msg.RootId),
"target_id": targetID,
})
}
return targetID
}
func (c *FeishuChannel) fetchMessageByID(ctx context.Context, messageID string) (*larkim.Message, error) {
if cached, ok := c.messageCache.Load(messageID); ok {
cm := cached.(*cachedMessage)
if time.Now().Before(cm.expiry) {
return cm.msg, nil
}
c.messageCache.Delete(messageID)
}
req := larkim.NewGetMessageReqBuilder().
MessageId(messageID).
Build()
resp, err := c.client.Im.V1.Message.Get(ctx, req)
if err != nil {
return nil, fmt.Errorf("feishu get message: %w", err)
}
if !resp.Success() {
c.invalidateTokenOnAuthError(resp.Code)
return nil, fmt.Errorf("feishu get message api error (code=%d msg=%s)", resp.Code, resp.Msg)
}
if resp.Data == nil || len(resp.Data.Items) == 0 || resp.Data.Items[0] == nil {
return nil, fmt.Errorf("feishu get message: empty response")
}
// Items[0] contains the target message - the Feishu API returns a list
// but we request a single message by ID, so the list always has at most one item.
msg := resp.Data.Items[0]
c.messageCache.Store(messageID, &cachedMessage{msg: msg, expiry: time.Now().Add(messageCacheTTL)})
return msg, nil
}
func replyTargetID(message *larkim.EventMessage) string {
if message == nil {
return ""
}
if parentID := stringValue(message.ParentId); parentID != "" {
return parentID
}
return stringValue(message.RootId)
}
func replyTargetIDFromMessage(message *larkim.Message) string {
if message == nil {
return ""
}
if parentID := stringValue(message.ParentId); parentID != "" {
return parentID
}
return stringValue(message.RootId)
}
func buildInboundMetadata(message *larkim.EventMessage, sender *larkim.EventSender) map[string]string {
metadata := map[string]string{}
if message == nil {
return metadata
}
messageID := stringValue(message.MessageId)
if messageID != "" {
metadata["message_id"] = messageID
}
messageType := stringValue(message.MessageType)
if messageType != "" {
metadata["message_type"] = messageType
}
chatType := stringValue(message.ChatType)
if chatType != "" {
metadata["chat_type"] = chatType
}
parentID := stringValue(message.ParentId)
if parentID != "" {
metadata["parent_id"] = parentID
}
rootID := stringValue(message.RootId)
if rootID != "" {
metadata["root_id"] = rootID
}
if replyTo := replyTargetID(message); replyTo != "" {
metadata["reply_to_message_id"] = replyTo
}
threadID := stringValue(message.ThreadId)
if threadID != "" {
metadata["thread_id"] = threadID
}
if sender != nil && sender.TenantKey != nil && *sender.TenantKey != "" {
metadata["tenant_key"] = *sender.TenantKey
}
return metadata
}
func normalizeRepliedContent(messageType, rawContent string, mediaRefs []string) string {
content := extractContent(messageType, rawContent)
if containsFeishuUpgradePlaceholder(rawContent) || containsFeishuUpgradePlaceholder(content) {
content = ""
}
content = appendMediaTags(content, messageType, mediaRefs)
if strings.TrimSpace(content) != "" {
return content
}
switch messageType {
case larkim.MsgTypeImage:
return "[replied image]"
case larkim.MsgTypeFile:
return "[replied file]"
case larkim.MsgTypeAudio:
return "[replied audio]"
case larkim.MsgTypeMedia:
return "[replied video]"
case larkim.MsgTypeInteractive:
return "[replied interactive card]"
default:
return "[replied message content unavailable]"
}
}
func containsFeishuUpgradePlaceholder(s string) bool {
upgradePrompt := "\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef"
upgradePromptEscaped := "\\u8bf7\\u5347\\u7ea7\\u81f3\\u6700\\u65b0\\u7248\\u672c\\u5ba2\\u6237\\u7aef"
return strings.Contains(s, upgradePrompt) || strings.Contains(s, upgradePromptEscaped)
}
func formatReplyContext(parentID, repliedContent, content string) string {
parentID = strings.TrimSpace(parentID)
repliedContent = strings.TrimSpace(repliedContent)
content = strings.TrimSpace(content)
if parentID == "" || repliedContent == "" {
return content
}
repliedContent = utils.Truncate(repliedContent, maxReplyContextLen)
repliedContent = sanitizeReplyContextContent(repliedContent)
content = sanitizeReplyContextContent(content)
header := fmt.Sprintf("[replied_message id=%q]", parentID)
footer := "[/replied_message]"
if content == "" {
return header + "\n" + repliedContent + "\n" + footer
}
if hasLeadingCommandPrefix(content) {
return content + "\n\n" + header + "\n" + repliedContent + "\n" + footer
}
return header + "\n" + repliedContent + "\n" + footer + "\n\n[current_message]\n" + content + "\n[/current_message]"
}
func hasLeadingCommandPrefix(s string) bool {
tokens := strings.Fields(strings.TrimSpace(s))
if len(tokens) == 0 {
return false
}
first := tokens[0]
return strings.HasPrefix(first, "/") || strings.HasPrefix(first, "!")
}
func sanitizeReplyContextContent(s string) string {
tagEscaper := strings.NewReplacer(
"[replied_message", `\[replied_message`,
"[/replied_message]", `\[/replied_message]`,
"[current_message]", `\[current_message]`,
"[/current_message]", `\[/current_message]`,
)
return tagEscaper.Replace(s)
}
+229
View File
@@ -0,0 +1,229 @@
//go:build amd64 || arm64 || riscv64 || mips64 || ppc64
package feishu
import (
"strings"
"testing"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
)
func TestBuildInboundMetadata(t *testing.T) {
strPtr := func(s string) *string { return &s }
t.Run("includes basic and reply fields", func(t *testing.T) {
message := &larkim.EventMessage{
MessageId: strPtr("om_msg_1"),
MessageType: strPtr("text"),
ChatType: strPtr("group"),
ParentId: strPtr("om_parent_1"),
RootId: strPtr("om_root_1"),
ThreadId: strPtr("omt_thread_1"),
}
sender := &larkim.EventSender{TenantKey: strPtr("tenant_x")}
got := buildInboundMetadata(message, sender)
if got["message_id"] != "om_msg_1" {
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_1")
}
if got["message_type"] != "text" {
t.Fatalf("message_type = %q, want %q", got["message_type"], "text")
}
if got["chat_type"] != "group" {
t.Fatalf("chat_type = %q, want %q", got["chat_type"], "group")
}
if got["parent_id"] != "om_parent_1" {
t.Fatalf("parent_id = %q, want %q", got["parent_id"], "om_parent_1")
}
if got["reply_to_message_id"] != "om_parent_1" {
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_parent_1")
}
if got["root_id"] != "om_root_1" {
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_1")
}
if got["thread_id"] != "omt_thread_1" {
t.Fatalf("thread_id = %q, want %q", got["thread_id"], "omt_thread_1")
}
if got["tenant_key"] != "tenant_x" {
t.Fatalf("tenant_key = %q, want %q", got["tenant_key"], "tenant_x")
}
})
t.Run("falls back reply_to_message_id to root_id", func(t *testing.T) {
message := &larkim.EventMessage{
MessageId: strPtr("om_msg_3"),
RootId: strPtr("om_root_3"),
}
got := buildInboundMetadata(message, nil)
if got["root_id"] != "om_root_3" {
t.Fatalf("root_id = %q, want %q", got["root_id"], "om_root_3")
}
if got["reply_to_message_id"] != "om_root_3" {
t.Fatalf("reply_to_message_id = %q, want %q", got["reply_to_message_id"], "om_root_3")
}
})
t.Run("omits empty values", func(t *testing.T) {
message := &larkim.EventMessage{
MessageId: strPtr("om_msg_2"),
}
got := buildInboundMetadata(message, nil)
if got["message_id"] != "om_msg_2" {
t.Fatalf("message_id = %q, want %q", got["message_id"], "om_msg_2")
}
if _, ok := got["parent_id"]; ok {
t.Fatalf("parent_id should be absent, got %q", got["parent_id"])
}
if _, ok := got["reply_to_message_id"]; ok {
t.Fatalf("reply_to_message_id should be absent, got %q", got["reply_to_message_id"])
}
if _, ok := got["tenant_key"]; ok {
t.Fatalf("tenant_key should be absent, got %q", got["tenant_key"])
}
})
t.Run("nil message returns empty map", func(t *testing.T) {
got := buildInboundMetadata(nil, nil)
if len(got) != 0 {
t.Fatalf("len(metadata) = %d, want 0", len(got))
}
})
}
func TestFormatReplyContext(t *testing.T) {
t.Run("formats reply context with content", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "new reply")
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]\n\n[current_message]\nnew reply\n[/current_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
t.Run("returns reply context when current content is empty", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "")
want := "[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
t.Run("returns original content when parent or replied content missing", func(t *testing.T) {
if got := formatReplyContext("", "original", "new reply"); got != "new reply" {
t.Fatalf("missing parent: got %q, want %q", got, "new reply")
}
if got := formatReplyContext("om_parent_1", "", "new reply"); got != "new reply" {
t.Fatalf("missing replied content: got %q, want %q", got, "new reply")
}
})
t.Run("escapes reserved wrapper tags in payload", func(t *testing.T) {
replied := "payload [replied_message id=\"x\"] x [/replied_message]"
current := "hello [current_message]injected[/current_message]"
got := formatReplyContext("om_parent_1", replied, current)
if !strings.HasPrefix(got, "[replied_message id=\"om_parent_1\"]") {
t.Fatalf("outer replied_message wrapper missing: %q", got)
}
if strings.Contains(got, "\n[replied_message id=\"x\"]") {
t.Fatalf("nested replied_message tag should be escaped: %q", got)
}
if strings.Contains(got, "\n[current_message]injected") {
t.Fatalf("nested current_message tag should be escaped: %q", got)
}
if !strings.Contains(got, `\[replied_message id="x"]`) {
t.Fatalf("escaped replied tag missing: %q", got)
}
})
t.Run("preserves leading slash command prefix", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "/help")
want := "/help\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
t.Run("preserves leading bang command prefix", func(t *testing.T) {
got := formatReplyContext("om_parent_1", "original message", "!status now")
want := "!status now\n\n[replied_message id=\"om_parent_1\"]\noriginal message\n[/replied_message]"
if got != want {
t.Fatalf("formatReplyContext() = %q, want %q", got, want)
}
})
}
func TestReplyTargetID(t *testing.T) {
strPtr := func(s string) *string { return &s }
t.Run("prefer parent_id", func(t *testing.T) {
msg := &larkim.EventMessage{ParentId: strPtr("om_parent"), RootId: strPtr("om_root")}
if got := replyTargetID(msg); got != "om_parent" {
t.Fatalf("replyTargetID() = %q, want %q", got, "om_parent")
}
})
t.Run("fallback to root_id", func(t *testing.T) {
msg := &larkim.EventMessage{RootId: strPtr("om_root")}
if got := replyTargetID(msg); got != "om_root" {
t.Fatalf("replyTargetID() = %q, want %q", got, "om_root")
}
})
t.Run("empty when no fields", func(t *testing.T) {
if got := replyTargetID(&larkim.EventMessage{}); got != "" {
t.Fatalf("replyTargetID() = %q, want empty", got)
}
})
}
func TestNormalizeRepliedContent(t *testing.T) {
t.Run("filters feishu upgrade placeholder for interactive", func(t *testing.T) {
raw := `{"text":"\u8bf7\u5347\u7ea7\u81f3\u6700\u65b0\u7248\u672c\u5ba2\u6237\u7aef\uff0c\u4ee5\u67e5\u770b\u5185\u5bb9"}`
got := normalizeRepliedContent("interactive", raw, nil)
if got != "[replied interactive card]" {
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied interactive card]")
}
})
t.Run("keeps filename and file tag for replied file", func(t *testing.T) {
got := normalizeRepliedContent("file", `{"file_key":"file_xxx","file_name":"doc.pdf"}`, []string{"media://r1"})
if got != "doc.pdf [file]" {
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "doc.pdf [file]")
}
})
t.Run("falls back when file content missing", func(t *testing.T) {
got := normalizeRepliedContent("file", `{"file_key":"file_xxx"}`, nil)
if got != "[replied file]" {
t.Fatalf("normalizeRepliedContent() = %q, want %q", got, "[replied file]")
}
})
}
func TestHasLeadingCommandPrefix(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{name: "slash command", input: "/help", want: true},
{name: "bang command", input: "!status", want: true},
{name: "leading spaces slash", input: " /ping arg", want: true},
{name: "normal text", input: "hello /help", want: false},
{name: "empty", input: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := hasLeadingCommandPrefix(tt.input); got != tt.want {
t.Fatalf("hasLeadingCommandPrefix(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
+13
View File
@@ -430,6 +430,19 @@ func (m *Manager) initChannels(channels *config.ChannelsConfig) error {
m.initChannel("vk", "VK")
}
if channels.TeamsWebhook.Enabled && len(channels.TeamsWebhook.Webhooks) > 0 {
hasValidTarget := false
for _, target := range channels.TeamsWebhook.Webhooks {
if target.WebhookURL.String() != "" {
hasValidTarget = true
break
}
}
if hasValidTarget {
m.initChannel("teams_webhook", "Teams Webhook")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]any{
"enabled_channels": len(m.channels),
})
+16
View File
@@ -62,6 +62,13 @@ func hiddenValues(key string, value map[string]any, ch config.ChannelsConfig) {
value["app_secret"] = ch.Feishu.AppSecret.String()
value["encrypt_key"] = ch.Feishu.EncryptKey.String()
value["verification_token"] = ch.Feishu.VerificationToken.String()
case "teams_webhook":
// Expose webhook URLs for hash computation (they contain secrets)
webhooks := make(map[string]string)
for name, target := range ch.TeamsWebhook.Webhooks {
webhooks[name] = target.WebhookURL.String()
}
value["webhooks"] = webhooks
}
}
@@ -166,4 +173,13 @@ func updateKeys(newcfg, old *config.ChannelsConfig) {
newcfg.Feishu.EncryptKey = old.Feishu.EncryptKey
newcfg.Feishu.VerificationToken = old.Feishu.VerificationToken
}
if newcfg.TeamsWebhook.Enabled {
// Copy SecureString webhook URLs from old config
for name, oldTarget := range old.TeamsWebhook.Webhooks {
if newTarget, ok := newcfg.TeamsWebhook.Webhooks[name]; ok {
newTarget.WebhookURL = oldTarget.WebhookURL
newcfg.TeamsWebhook.Webhooks[name] = newTarget
}
}
}
}
+13
View File
@@ -0,0 +1,13 @@
package teamswebhook
import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
)
func init() {
channels.RegisterFactory("teams_webhook", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) {
return NewTeamsWebhookChannel(cfg.Channels.TeamsWebhook, b)
})
}
+422
View File
@@ -0,0 +1,422 @@
package teamswebhook
import (
"context"
"fmt"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
"github.com/atc0005/go-teams-notify/v2/adaptivecard"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
// statusCodeRe extracts HTTP status codes from error messages like "401 Unauthorized".
var statusCodeRe = regexp.MustCompile(`\b([45]\d{2})\b`)
// markdownTableRe matches a markdown table block (header + separator + rows).
// It captures the entire table including all rows.
var markdownTableRe = regexp.MustCompile(`(?m)^(\|[^\n]+\|)\n(\|[-:\|\s]+\|)\n((?:\|[^\n]+\|\n?)+)`)
// teamsMessageSender abstracts the Teams client for testability.
type teamsMessageSender interface {
SendWithContext(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
}
// classifyTeamsError extracts HTTP status code from error message and classifies it.
// The go-teams-notify library returns errors like "error on notification: 401 Unauthorized, ...".
// This allows proper retry behavior: 4xx errors are permanent, 5xx are temporary.
func classifyTeamsError(err error) error {
if err == nil {
return nil
}
errMsg := err.Error()
if matches := statusCodeRe.FindStringSubmatch(errMsg); len(matches) > 1 {
if statusCode, parseErr := strconv.Atoi(matches[1]); parseErr == nil {
return channels.ClassifySendError(statusCode, err)
}
}
// Fallback: treat as temporary network error (retryable)
return channels.ClassifyNetError(err)
}
// TeamsWebhookChannel is an output-only channel that sends messages
// to Microsoft Teams via Power Automate workflow webhooks.
// Multiple webhook targets can be configured and selected via ChatID.
type TeamsWebhookChannel struct {
*channels.BaseChannel
config config.TeamsWebhookConfig
client teamsMessageSender
}
// NewTeamsWebhookChannel creates a new Teams webhook channel.
func NewTeamsWebhookChannel(
cfg config.TeamsWebhookConfig,
bus *bus.MessageBus,
) (*TeamsWebhookChannel, error) {
if len(cfg.Webhooks) == 0 {
return nil, fmt.Errorf("teams_webhook: at least one webhook target is required")
}
// Require "default" webhook target
if _, hasDefault := cfg.Webhooks["default"]; !hasDefault {
return nil, fmt.Errorf("teams_webhook: a 'default' webhook target is required")
}
// Validate all webhook targets have valid HTTPS URLs
for name, target := range cfg.Webhooks {
webhookURL := target.WebhookURL.String()
if webhookURL == "" {
return nil, fmt.Errorf("teams_webhook: webhook %q has empty webhook_url", name)
}
parsed, err := url.Parse(webhookURL)
if err != nil {
return nil, fmt.Errorf("teams_webhook: webhook %q has invalid URL: %w", name, err)
}
if !strings.EqualFold(parsed.Scheme, "https") {
return nil, fmt.Errorf("teams_webhook: webhook %q must use HTTPS (got %q)", name, parsed.Scheme)
}
}
base := channels.NewBaseChannel(
"teams_webhook",
cfg,
bus,
[]string{
"*",
}, // Output-only channel; "*" suppresses misleading "allows EVERYONE" audit warning
channels.WithMaxMessageLength(24000), // Power Automate webhook payload limit is 28KB
)
client := goteamsnotify.NewTeamsClient()
return &TeamsWebhookChannel{
BaseChannel: base,
config: cfg,
client: client,
}, nil
}
// Start initializes the channel. For output-only channels, this is a no-op.
func (c *TeamsWebhookChannel) Start(ctx context.Context) error {
targets := make([]string, 0, len(c.config.Webhooks))
for name := range c.config.Webhooks {
targets = append(targets, name)
}
sort.Strings(targets)
logger.InfoCF("teams_webhook", "Starting Teams webhook channel (output-only)", map[string]any{
"targets": targets,
})
c.SetRunning(true)
return nil
}
// Stop shuts down the channel.
func (c *TeamsWebhookChannel) Stop(ctx context.Context) error {
logger.InfoC("teams_webhook", "Stopping Teams webhook channel")
c.SetRunning(false)
return nil
}
// Send delivers a message to the specified Teams webhook target.
// The target is selected by msg.ChatID which must match a key in the webhooks map.
func (c *TeamsWebhookChannel) Send(ctx context.Context, msg bus.OutboundMessage) ([]string, error) {
if !c.IsRunning() {
return nil, channels.ErrNotRunning
}
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
// Look up webhook target by ChatID, fall back to "default" if empty or unknown
targetName := msg.ChatID
if targetName == "" {
targetName = "default"
}
target, ok := c.config.Webhooks[targetName]
if !ok {
// Log warning and fall back to default target
logger.WarnCF("teams_webhook", "Unknown target, falling back to default", map[string]any{
"requested": msg.ChatID,
"using": "default",
})
target = c.config.Webhooks["default"]
}
// Build an Adaptive Card for rich formatting
card, err := c.buildAdaptiveCard(msg, target)
if err != nil {
return nil, fmt.Errorf("teams_webhook: failed to build card: %w", err)
}
// Create the message with the card
teamsMsg, err := adaptivecard.NewMessageFromCard(card)
if err != nil {
return nil, fmt.Errorf("teams_webhook: failed to create message: %w", err)
}
// Send to Teams
if err := c.client.SendWithContext(ctx, target.WebhookURL.String(), teamsMsg); err != nil {
// Log without raw error to avoid leaking webhook URL (embedded in net/http errors)
logger.ErrorCF("teams_webhook", "Failed to send message to Teams webhook", map[string]any{
"target": msg.ChatID,
})
// Classify error based on status code extracted from error message.
// The go-teams-notify library includes status in errors like "401 Unauthorized".
// Use ClassifySendError for proper retry behavior (4xx = permanent, 5xx = temporary).
classifiedErr := classifyTeamsError(err)
return nil, fmt.Errorf("teams_webhook: send failed: %w", classifiedErr)
}
logger.DebugCF("teams_webhook", "Message sent successfully", map[string]any{
"target": msg.ChatID,
})
return nil, nil
}
// buildAdaptiveCard creates a formatted Adaptive Card from the outbound message.
// It detects markdown tables and converts them to native Adaptive Card Table elements,
// since TextBlocks only support a limited markdown subset (no tables).
func (c *TeamsWebhookChannel) buildAdaptiveCard(
msg bus.OutboundMessage,
target config.TeamsWebhookTarget,
) (adaptivecard.Card, error) {
card := adaptivecard.NewCard()
card.Type = adaptivecard.TypeAdaptiveCard
// Set full width for Teams rendering
card.MSTeams.Width = "Full"
// Add title if configured on the target
title := target.Title
if title == "" {
title = "PicoClaw Notification"
}
titleBlock := adaptivecard.NewTextBlock(title, true)
titleBlock.Size = adaptivecard.SizeLarge
titleBlock.Weight = adaptivecard.WeightBolder
titleBlock.Style = adaptivecard.TextBlockStyleHeading
if err := card.AddElement(false, titleBlock); err != nil {
return card, err
}
content := msg.Content
if content == "" {
content = "(empty message)"
}
// Split content into text segments and tables
// TextBlocks support: bold, italic, bullet/numbered lists, links
// TextBlocks do NOT support: headers, tables, images
segments := splitContentWithTables(content)
for _, seg := range segments {
if seg.isTable {
// Convert markdown table to Adaptive Card Table element
tableElement, err := parseMarkdownTable(seg.content)
if err != nil {
// Fallback: render as preformatted text if parsing fails
logger.WarnCF("teams_webhook", "Failed to parse markdown table, using fallback", map[string]any{
"error": err.Error(),
})
block := adaptivecard.NewTextBlock("```\n"+seg.content+"\n```", true)
block.Wrap = true
if err := card.AddElement(false, block); err != nil {
return card, err
}
continue
}
if err := card.AddElement(false, tableElement); err != nil {
return card, err
}
} else {
// Regular text content
text := strings.TrimSpace(seg.content)
if text == "" {
continue
}
block := adaptivecard.NewTextBlock(text, true)
block.Wrap = true
if err := card.AddElement(false, block); err != nil {
return card, err
}
}
}
return card, nil
}
// contentSegment represents either a text block or a table in the message content.
type contentSegment struct {
content string
isTable bool
}
// splitContentWithTables splits content into alternating text and table segments.
func splitContentWithTables(content string) []contentSegment {
var segments []contentSegment
matches := markdownTableRe.FindAllStringSubmatchIndex(content, -1)
if len(matches) == 0 {
// No tables found, return entire content as text
return []contentSegment{{content: content, isTable: false}}
}
lastEnd := 0
for _, match := range matches {
// Text before this table
if match[0] > lastEnd {
segments = append(segments, contentSegment{
content: content[lastEnd:match[0]],
isTable: false,
})
}
// The table itself
segments = append(segments, contentSegment{
content: content[match[0]:match[1]],
isTable: true,
})
lastEnd = match[1]
}
// Text after the last table
if lastEnd < len(content) {
segments = append(segments, contentSegment{
content: content[lastEnd:],
isTable: false,
})
}
return segments
}
// parseMarkdownTable converts a markdown table string to an Adaptive Card Table element.
func parseMarkdownTable(tableStr string) (adaptivecard.Element, error) {
lines := strings.Split(strings.TrimSpace(tableStr), "\n")
if len(lines) < 2 {
return adaptivecard.Element{}, fmt.Errorf("table must have at least header and separator rows")
}
// Track header content length per column for width calculation
var headerLengths []int
// Parse all rows (header + data rows, skip separator)
var allRows [][]adaptivecard.TableCell
for i, line := range lines {
// Skip separator row (contains only |, -, :, and spaces)
if i == 1 && isSeparatorRow(line) {
continue
}
cells := parseTableRow(line)
if len(cells) == 0 {
continue
}
var tableCells []adaptivecard.TableCell
for _, cellText := range cells {
trimmedText := strings.TrimSpace(cellText)
// Use header row (first row) to determine column widths
if i == 0 {
headerLengths = append(headerLengths, len(trimmedText))
}
textBlock := adaptivecard.Element{
Type: adaptivecard.TypeElementTextBlock,
Text: trimmedText,
Wrap: true,
}
cell := adaptivecard.TableCell{
Type: adaptivecard.TypeTableCell,
Items: []*adaptivecard.Element{&textBlock},
}
tableCells = append(tableCells, cell)
}
allRows = append(allRows, tableCells)
}
if len(allRows) == 0 {
return adaptivecard.Element{}, fmt.Errorf("no valid rows found in table")
}
// Create table with first row as headers
firstRowAsHeaders := true
showGridLines := true
table, err := adaptivecard.NewTableFromTableCells(allRows, 0, firstRowAsHeaders, showGridLines)
if err != nil {
return adaptivecard.Element{}, fmt.Errorf("failed to create table: %w", err)
}
// Set column widths based on header content length
table.Columns = calculateColumnWidths(headerLengths)
return table, nil
}
// calculateColumnWidths creates TableColumnDefinition entries with widths
// proportional to the max content length of each column.
func calculateColumnWidths(maxLengths []int) []adaptivecard.Column {
if len(maxLengths) == 0 {
return nil
}
// Use content length as relative weight, with a minimum of 1
columns := make([]adaptivecard.Column, len(maxLengths))
for i, length := range maxLengths {
weight := length
if weight < 1 {
weight = 1
}
columns[i] = adaptivecard.Column{
Type: "TableColumnDefinition",
Width: weight,
}
}
return columns
}
// isSeparatorRow checks if a line is a markdown table separator (e.g., |---|---|).
func isSeparatorRow(line string) bool {
// Remove pipes and spaces, check if only dashes and colons remain
cleaned := strings.ReplaceAll(line, "|", "")
cleaned = strings.ReplaceAll(cleaned, " ", "")
cleaned = strings.ReplaceAll(cleaned, "-", "")
cleaned = strings.ReplaceAll(cleaned, ":", "")
return cleaned == ""
}
// parseTableRow extracts cell values from a markdown table row.
func parseTableRow(line string) []string {
// Trim leading/trailing pipes and split by |
line = strings.TrimSpace(line)
line = strings.TrimPrefix(line, "|")
line = strings.TrimSuffix(line, "|")
if line == "" {
return nil
}
parts := strings.Split(line, "|")
var cells []string
for _, p := range parts {
cells = append(cells, strings.TrimSpace(p))
}
return cells
}
@@ -0,0 +1,583 @@
package teamswebhook
import (
"context"
"errors"
"testing"
goteamsnotify "github.com/atc0005/go-teams-notify/v2"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// mockTeamsClient implements teamsMessageSender for testing.
type mockTeamsClient struct {
sendFunc func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error
}
func (m *mockTeamsClient) SendWithContext(
ctx context.Context,
webhookURL string,
message goteamsnotify.TeamsMessage,
) error {
if m.sendFunc != nil {
return m.sendFunc(ctx, webhookURL, message)
}
return nil
}
func TestNewTeamsWebhookChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
// Test missing webhooks
_, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: nil,
}, msgBus)
if err == nil {
t.Error("expected error for missing webhooks")
}
// Test missing "default" webhook
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
Title: "Alerts",
},
},
}, msgBus)
if err == nil {
t.Error("expected error for missing 'default' webhook")
}
// Test empty webhook URL
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {Title: "Default"},
},
}, msgBus)
if err == nil {
t.Error("expected error for empty webhook_url")
}
// Test HTTP URL (should fail, must be HTTPS)
_, err = NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("http://example.com/webhook"),
Title: "Default",
},
},
}, msgBus)
if err == nil {
t.Error("expected error for HTTP webhook URL (must be HTTPS)")
}
// Test valid config with HTTPS (must include "default")
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook1"),
Title: "Alerts",
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "teams_webhook" {
t.Errorf("expected name 'teams_webhook', got %q", ch.Name())
}
}
func TestTeamsWebhookChannel_StartStop(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ctx := context.Background()
if ch.IsRunning() {
t.Error("channel should not be running before Start")
}
if err := ch.Start(ctx); err != nil {
t.Fatalf("Start failed: %v", err)
}
if !ch.IsRunning() {
t.Error("channel should be running after Start")
}
if err := ch.Stop(ctx); err != nil {
t.Fatalf("Stop failed: %v", err)
}
if ch.IsRunning() {
t.Error("channel should not be running after Stop")
}
}
func TestTeamsWebhookChannel_BuildAdaptiveCard(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
Title: "Custom Title",
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
target := ch.config.Webhooks["alerts"]
msg := bus.OutboundMessage{
Content: "Test message content",
ChatID: "alerts",
}
card, err := ch.buildAdaptiveCard(msg, target)
if err != nil {
t.Fatalf("buildAdaptiveCard failed: %v", err)
}
if card.Type != "AdaptiveCard" {
t.Errorf("expected card type 'AdaptiveCard', got %q", card.Type)
}
}
func TestTeamsWebhookChannel_SendNotRunning(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ctx := context.Background()
msg := bus.OutboundMessage{Content: "test", ChatID: "default"}
_, err = ch.Send(ctx, msg)
if err == nil {
t.Error("expected error when sending while not running")
}
}
func TestTeamsWebhookChannel_SendDefaultTargetFallback(t *testing.T) {
tests := []struct {
name string
chatID string
}{
{"unknown target falls back to default", "unknown"},
{"empty ChatID uses default", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
var sentURL string
ch.client = &mockTeamsClient{
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
sentURL = webhookURL
return nil
},
}
ctx := context.Background()
_ = ch.Start(ctx)
defer ch.Stop(ctx)
msg := bus.OutboundMessage{Content: "test", ChatID: tt.chatID}
_, err = ch.Send(ctx, msg)
if err != nil {
t.Fatalf("expected success, got error: %v", err)
}
if sentURL != "https://example.com/webhook-default" {
t.Errorf("expected default webhook URL, got %q", sentURL)
}
})
}
}
func TestTeamsWebhookChannel_SendSuccess(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
Title: "Default",
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
Title: "Test Alerts",
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Inject mock client
var sentURL string
ch.client = &mockTeamsClient{
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
sentURL = webhookURL
return nil
},
}
ctx := context.Background()
_ = ch.Start(ctx)
defer ch.Stop(ctx)
msg := bus.OutboundMessage{Content: "Hello Teams!", ChatID: "alerts"}
_, err = ch.Send(ctx, msg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if sentURL != "https://example.com/webhook-alerts" {
t.Errorf("expected webhook URL 'https://example.com/webhook-alerts', got %q", sentURL)
}
}
func TestTeamsWebhookChannel_SendError(t *testing.T) {
msgBus := bus.NewMessageBus()
ch, err := NewTeamsWebhookChannel(config.TeamsWebhookConfig{
Enabled: true,
Webhooks: map[string]config.TeamsWebhookTarget{
"default": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-default"),
},
"alerts": {
WebhookURL: *config.NewSecureString("https://example.com/webhook-alerts"),
},
},
}, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Inject mock client that returns an error
ch.client = &mockTeamsClient{
sendFunc: func(ctx context.Context, webhookURL string, message goteamsnotify.TeamsMessage) error {
return errors.New("error on notification: 401 Unauthorized, forbidden")
},
}
ctx := context.Background()
_ = ch.Start(ctx)
defer ch.Stop(ctx)
msg := bus.OutboundMessage{Content: "test", ChatID: "alerts"}
_, err = ch.Send(ctx, msg)
if err == nil {
t.Error("expected error from failed send")
}
}
func TestSplitContentWithTables(t *testing.T) {
tests := []struct {
name string
content string
wantSegs int
wantTbl int // number of table segments
}{
{
name: "no tables",
content: "Just some text\nwith multiple lines",
wantSegs: 1,
wantTbl: 0,
},
{
name: "single table",
content: `| Col1 | Col2 |
|------|------|
| A | B |
| C | D |`,
wantSegs: 1,
wantTbl: 1,
},
{
name: "text before table",
content: `Here is some text.
| Col1 | Col2 |
|------|------|
| A | B |`,
wantSegs: 2,
wantTbl: 1,
},
{
name: "text before and after table",
content: `Before table.
| Col1 | Col2 |
|------|------|
| A | B |
After table.`,
wantSegs: 3,
wantTbl: 1,
},
{
name: "multiple tables",
content: `First table:
| A | B |
|---|---|
| 1 | 2 |
Second table:
| X | Y |
|---|---|
| 3 | 4 |`,
wantSegs: 4,
wantTbl: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
segs := splitContentWithTables(tt.content)
if len(segs) != tt.wantSegs {
t.Errorf("got %d segments, want %d", len(segs), tt.wantSegs)
}
tableCount := 0
for _, s := range segs {
if s.isTable {
tableCount++
}
}
if tableCount != tt.wantTbl {
t.Errorf("got %d tables, want %d", tableCount, tt.wantTbl)
}
})
}
}
func TestParseMarkdownTable(t *testing.T) {
tableStr := `| Name | Value |
|------|-------|
| foo | 123 |
| bar | 456 |`
elem, err := parseMarkdownTable(tableStr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if elem.Type != "Table" {
t.Errorf("expected type 'Table', got %q", elem.Type)
}
// Should have 3 rows (header + 2 data rows)
if len(elem.Rows) != 3 {
t.Errorf("expected 3 rows, got %d", len(elem.Rows))
}
// Should have 2 columns with widths based on content length
if len(elem.Columns) != 2 {
t.Errorf("expected 2 columns, got %d", len(elem.Columns))
}
}
func TestParseMarkdownTableColumnWidths(t *testing.T) {
// Column widths are based on HEADER row only:
// Col1: "Description" (11 chars)
// Col2: "X" (1 char)
// Col3: "Amount" (6 chars)
tableStr := `| Description | X | Amount |
|-------------|---|--------|
| Short | Y | 100 |
| Longer text | Z | 50 |`
elem, err := parseMarkdownTable(tableStr)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(elem.Columns) != 3 {
t.Fatalf("expected 3 columns, got %d", len(elem.Columns))
}
// Verify column widths are based on header content length
w1, ok1 := elem.Columns[0].Width.(int)
w2, ok2 := elem.Columns[1].Width.(int)
w3, ok3 := elem.Columns[2].Width.(int)
if !ok1 || !ok2 || !ok3 {
t.Fatalf("expected int widths, got types: %T, %T, %T",
elem.Columns[0].Width, elem.Columns[1].Width, elem.Columns[2].Width)
}
// Header lengths: "Description" = 11, "X" = 1, "Amount" = 6
if w1 != 11 {
t.Errorf("expected col1 width 11 (from 'Description'), got %d", w1)
}
if w2 != 1 {
t.Errorf("expected col2 width 1 (from 'X'), got %d", w2)
}
if w3 != 6 {
t.Errorf("expected col3 width 6 (from 'Amount'), got %d", w3)
}
}
func TestCalculateColumnWidths(t *testing.T) {
tests := []struct {
name string
maxLengths []int
wantWidths []int
}{
{
name: "equal lengths",
maxLengths: []int{10, 10, 10},
wantWidths: []int{10, 10, 10},
},
{
name: "varying lengths",
maxLengths: []int{5, 20, 10},
wantWidths: []int{5, 20, 10},
},
{
name: "zero length gets minimum of 1",
maxLengths: []int{0, 5, 0},
wantWidths: []int{1, 5, 1},
},
{
name: "empty input",
maxLengths: []int{},
wantWidths: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cols := calculateColumnWidths(tt.maxLengths)
if tt.wantWidths == nil {
if cols != nil {
t.Errorf("expected nil, got %v", cols)
}
return
}
if len(cols) != len(tt.wantWidths) {
t.Fatalf("expected %d columns, got %d", len(tt.wantWidths), len(cols))
}
for i, col := range cols {
width, ok := col.Width.(int)
if !ok {
t.Errorf("column %d: expected int width, got %T", i, col.Width)
continue
}
if width != tt.wantWidths[i] {
t.Errorf("column %d: expected width %d, got %d", i, tt.wantWidths[i], width)
}
if col.Type != "TableColumnDefinition" {
t.Errorf("column %d: expected type 'TableColumnDefinition', got %q", i, col.Type)
}
}
})
}
}
func TestParseTableRow(t *testing.T) {
tests := []struct {
line string
want []string
}{
{"| A | B | C |", []string{"A", "B", "C"}},
{"|A|B|C|", []string{"A", "B", "C"}},
{"| foo | bar |", []string{"foo", "bar"}},
{"", nil},
}
for _, tt := range tests {
got := parseTableRow(tt.line)
if len(got) != len(tt.want) {
t.Errorf("parseTableRow(%q): got %v, want %v", tt.line, got, tt.want)
continue
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("parseTableRow(%q)[%d]: got %q, want %q", tt.line, i, got[i], tt.want[i])
}
}
}
}
func TestIsSeparatorRow(t *testing.T) {
tests := []struct {
line string
want bool
}{
{"|---|---|", true},
{"| --- | --- |", true},
{"|:---|---:|", true},
{"| :---: | :---: |", true},
{"| A | B |", false},
{"| foo | bar |", false},
}
for _, tt := range tests {
got := isSeparatorRow(tt.line)
if got != tt.want {
t.Errorf("isSeparatorRow(%q): got %v, want %v", tt.line, got, tt.want)
}
}
}
+79 -35
View File
@@ -24,20 +24,21 @@ var rrCounter atomic.Uint64
// CurrentVersion is the latest config schema version
const CurrentVersion = 2
// Config is the current config structure with version support
// Config is the current config structure with version support.
type Config struct {
Version int `json:"version" yaml:"-"` // Config schema version for migration
Agents AgentsConfig `json:"agents" yaml:"-"`
Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"`
Session SessionConfig `json:"session,omitempty" yaml:"-"`
Channels ChannelsConfig `json:"channels" yaml:"channels"`
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway" yaml:"-"`
Hooks HooksConfig `json:"hooks,omitempty" yaml:"-"`
Tools ToolsConfig `json:"tools" yaml:",inline"`
Heartbeat HeartbeatConfig `json:"heartbeat" yaml:"-"`
Devices DevicesConfig `json:"devices" yaml:"-"`
Voice VoiceConfig `json:"voice" yaml:"-"`
Version int `json:"version" yaml:"-"` // Config schema version for migration
Isolation IsolationConfig `json:"isolation,omitempty" yaml:"-"`
Agents AgentsConfig `json:"agents" yaml:"-"`
Bindings []AgentBinding `json:"bindings,omitempty" yaml:"-"`
Session SessionConfig `json:"session,omitempty" yaml:"-"`
Channels ChannelsConfig `json:"channels" yaml:"channels"`
ModelList SecureModelList `json:"model_list" yaml:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway" yaml:"-"`
Hooks HooksConfig `json:"hooks,omitempty" yaml:"-"`
Tools ToolsConfig `json:"tools" yaml:",inline"`
Heartbeat HeartbeatConfig `json:"heartbeat" yaml:"-"`
Devices DevicesConfig `json:"devices" yaml:"-"`
Voice VoiceConfig `json:"voice" yaml:"-"`
// BuildInfo contains build-time version information
BuildInfo BuildInfo `json:"build_info,omitempty" yaml:"-"`
@@ -45,6 +46,21 @@ type Config struct {
sensitiveCache *SensitiveDataCache
}
// IsolationConfig controls subprocess isolation for commands started by PicoClaw.
// It is applied by the isolation package rather than by sandboxing the main process.
type IsolationConfig struct {
Enabled bool `json:"enabled,omitempty"`
ExposePaths []ExposePath `json:"expose_paths,omitempty"`
}
// ExposePath describes a host path that should remain visible inside the isolated
// child-process environment. This is currently implemented on Linux only.
type ExposePath struct {
Source string `json:"source"`
Target string `json:"target,omitempty"`
Mode string `json:"mode"`
}
// FilterSensitiveData filters sensitive values from content before sending to LLM.
// This prevents the LLM from seeing its own credentials.
// Uses strings.Replacer for O(n+m) performance (computed once per SecurityConfig).
@@ -280,23 +296,24 @@ func (d *AgentDefaults) GetModelName() string {
}
type ChannelsConfig struct {
WhatsApp WhatsAppConfig `json:"whatsapp" yaml:"-"`
Telegram TelegramConfig `json:"telegram" yaml:"telegram,omitempty"`
Feishu FeishuConfig `json:"feishu" yaml:"feishu,omitempty"`
Discord DiscordConfig `json:"discord" yaml:"discord,omitempty"`
MaixCam MaixCamConfig `json:"maixcam" yaml:"-"`
QQ QQConfig `json:"qq" yaml:"qq,omitempty"`
DingTalk DingTalkConfig `json:"dingtalk" yaml:"dingtalk,omitempty"`
Slack SlackConfig `json:"slack" yaml:"slack,omitempty"`
Matrix MatrixConfig `json:"matrix" yaml:"matrix,omitempty"`
LINE LINEConfig `json:"line" yaml:"line,omitempty"`
OneBot OneBotConfig `json:"onebot" yaml:"onebot,omitempty"`
WeCom WeComConfig `json:"wecom" yaml:"wecom,omitempty" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
Weixin WeixinConfig `json:"weixin" yaml:"weixin,omitempty"`
Pico PicoConfig `json:"pico" yaml:"pico,omitempty"`
PicoClient PicoClientConfig `json:"pico_client" yaml:"pico_client,omitempty"`
IRC IRCConfig `json:"irc" yaml:"irc,omitempty"`
VK VKConfig `json:"vk" yaml:"vk,omitempty"`
WhatsApp WhatsAppConfig `json:"whatsapp" yaml:"-"`
Telegram TelegramConfig `json:"telegram" yaml:"telegram,omitempty"`
Feishu FeishuConfig `json:"feishu" yaml:"feishu,omitempty"`
Discord DiscordConfig `json:"discord" yaml:"discord,omitempty"`
MaixCam MaixCamConfig `json:"maixcam" yaml:"-"`
QQ QQConfig `json:"qq" yaml:"qq,omitempty"`
DingTalk DingTalkConfig `json:"dingtalk" yaml:"dingtalk,omitempty"`
Slack SlackConfig `json:"slack" yaml:"slack,omitempty"`
Matrix MatrixConfig `json:"matrix" yaml:"matrix,omitempty"`
LINE LINEConfig `json:"line" yaml:"line,omitempty"`
OneBot OneBotConfig `json:"onebot" yaml:"onebot,omitempty"`
WeCom WeComConfig `json:"wecom" yaml:"wecom,omitempty" envPrefix:"PICOCLAW_CHANNELS_WECOM_"`
Weixin WeixinConfig `json:"weixin" yaml:"weixin,omitempty"`
Pico PicoConfig `json:"pico" yaml:"pico,omitempty"`
PicoClient PicoClientConfig `json:"pico_client" yaml:"pico_client,omitempty"`
IRC IRCConfig `json:"irc" yaml:"irc,omitempty"`
VK VKConfig `json:"vk" yaml:"vk,omitempty"`
TeamsWebhook TeamsWebhookConfig `json:"teams_webhook" yaml:"teams_webhook,omitempty"`
}
// GroupTriggerConfig controls when the bot responds in group chats.
@@ -566,6 +583,19 @@ func (c *VKConfig) SetToken(token string) {
c.Token = *NewSecureString(token)
}
// TeamsWebhookConfig configures the output-only Microsoft Teams webhook channel.
// Multiple webhook targets can be configured and selected via ChatID at send time.
type TeamsWebhookConfig struct {
Enabled bool `json:"enabled" yaml:"-" env:"PICOCLAW_CHANNELS_TEAMS_WEBHOOK_ENABLED"`
Webhooks map[string]TeamsWebhookTarget `json:"webhooks" yaml:"webhooks,omitempty"`
}
// TeamsWebhookTarget represents a single Teams webhook destination.
type TeamsWebhookTarget struct {
WebhookURL SecureString `json:"webhook_url,omitzero" yaml:"webhook_url,omitempty"`
Title string `json:"title,omitempty" yaml:"-"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
@@ -605,11 +635,12 @@ type ModelConfig struct {
Workspace string `json:"workspace,omitempty"` // Workspace path for CLI-based providers
// Optional optimizations
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
RPM int `json:"rpm,omitempty"` // Requests per minute limit
MaxTokensField string `json:"max_tokens_field,omitempty"` // Field name for max tokens (e.g., "max_completion_tokens")
RequestTimeout int `json:"request_timeout,omitempty"`
ThinkingLevel string `json:"thinking_level,omitempty"` // Extended thinking: off|low|medium|high|xhigh|adaptive
ExtraBody map[string]any `json:"extra_body,omitempty"` // Additional fields to inject into request body
CustomHeaders map[string]string `json:"custom_headers,omitempty"` // Additional headers to inject into every HTTP request
APIKeys SecureStrings `json:"api_keys,omitzero" yaml:"api_keys,omitempty"` // API authentication keys (multiple keys for failover)
@@ -943,10 +974,21 @@ type MCPServerConfig struct {
type MCPConfig struct {
ToolConfig ` envPrefix:"PICOCLAW_TOOLS_MCP_"`
Discovery ToolDiscoveryConfig ` json:"discovery"`
// MaxInlineTextChars controls how much MCP text stays inline before it is saved as an artifact.
MaxInlineTextChars int `json:"max_inline_text_chars,omitempty" env:"PICOCLAW_TOOLS_MCP_MAX_INLINE_TEXT_CHARS"`
// Servers is a map of server name to server configuration
Servers map[string]MCPServerConfig `json:"servers,omitempty"`
}
const DefaultMCPMaxInlineTextChars = 16 * 1024
func (c *MCPConfig) GetMaxInlineTextChars() int {
if c.MaxInlineTextChars > 0 {
return c.MaxInlineTextChars
}
return DefaultMCPMaxInlineTextChars
}
func LoadConfig(path string) (*Config, error) {
logger.Debugf("loading config from %s", path)
@@ -1268,6 +1310,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
isVirtual: true,
}
expanded = append(expanded, additionalEntry)
@@ -1288,6 +1331,7 @@ func expandMultiKeyModels(models []*ModelConfig) []*ModelConfig {
RequestTimeout: m.RequestTimeout,
ThinkingLevel: m.ThinkingLevel,
ExtraBody: m.ExtraBody,
CustomHeaders: m.CustomHeaders,
APIKeys: SimpleSecureStrings(keys[0]),
}
+102
View File
@@ -198,6 +198,41 @@ func TestAgentConfig_FullParse(t *testing.T) {
}
}
func TestDefaultConfig_MCPMaxInlineTextChars(t *testing.T) {
cfg := DefaultConfig()
if cfg.Tools.MCP.GetMaxInlineTextChars() != DefaultMCPMaxInlineTextChars {
t.Fatalf(
"DefaultConfig().Tools.MCP.GetMaxInlineTextChars() = %d, want %d",
cfg.Tools.MCP.GetMaxInlineTextChars(),
DefaultMCPMaxInlineTextChars,
)
}
}
func TestLoadConfig_MCPMaxInlineTextChars(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
raw := `{
"tools": {
"mcp": {
"enabled": true,
"max_inline_text_chars": 2048
}
}
}`
if err := os.WriteFile(configPath, []byte(raw), 0o644); err != nil {
t.Fatalf("WriteFile(configPath): %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
if got := cfg.Tools.MCP.GetMaxInlineTextChars(); got != 2048 {
t.Fatalf("cfg.Tools.MCP.GetMaxInlineTextChars() = %d, want 2048", got)
}
}
func TestConfig_BackwardCompat_NoAgentsList(t *testing.T) {
jsonData := `{
"agents": {
@@ -817,6 +852,37 @@ func TestDefaultConfig_WorkspacePath_WithPicoclawHome(t *testing.T) {
}
}
func TestDefaultConfig_IsolationEnabled(t *testing.T) {
cfg := DefaultConfig()
if cfg.Isolation.Enabled {
t.Fatal("DefaultConfig().Isolation.Enabled should be false")
}
}
func TestConfig_UnmarshalIsolation(t *testing.T) {
cfg := DefaultConfig()
raw := []byte(`{
"isolation": {
"enabled": false,
"expose_paths": [
{"source":"/src","target":"/dst","mode":"ro"}
]
}
}`)
if err := json.Unmarshal(raw, cfg); err != nil {
t.Fatalf("json.Unmarshal isolation config: %v", err)
}
if cfg.Isolation.Enabled {
t.Fatal("Isolation.Enabled should be false after unmarshal")
}
if len(cfg.Isolation.ExposePaths) != 1 {
t.Fatalf("ExposePaths len = %d, want 1", len(cfg.Isolation.ExposePaths))
}
if got := cfg.Isolation.ExposePaths[0]; got.Source != "/src" || got.Target != "/dst" || got.Mode != "ro" {
t.Fatalf("ExposePaths[0] = %+v, want source=/src target=/dst mode=ro", got)
}
}
// TestFlexibleStringSlice_UnmarshalText tests UnmarshalText with various comma separators
func TestFlexibleStringSlice_UnmarshalText(t *testing.T) {
tests := []struct {
@@ -1493,6 +1559,42 @@ func TestModelConfig_ExtraBodyRoundTrip(t *testing.T) {
}
}
func TestModelConfig_CustomHeadersRoundTrip(t *testing.T) {
dir := t.TempDir()
cfgPath := filepath.Join(dir, "config.json")
cfg := &Config{
Version: CurrentVersion,
ModelList: []*ModelConfig{
{
ModelName: "test-model",
Model: "openai/test",
APIKeys: SimpleSecureStrings("sk-test"),
CustomHeaders: map[string]string{"X-Source": "coding-plan", "X-Agent": "openclaw"},
},
},
}
if err := SaveConfig(cfgPath, cfg); err != nil {
t.Fatalf("SaveConfig error: %v", err)
}
loaded, err := LoadConfig(cfgPath)
if err != nil {
t.Fatalf("LoadConfig error: %v", err)
}
if loaded.ModelList[0].CustomHeaders == nil {
t.Fatal("CustomHeaders should not be nil after round-trip")
}
if got := loaded.ModelList[0].CustomHeaders["X-Source"]; got != "coding-plan" {
t.Errorf("CustomHeaders[X-Source] = %q, want coding-plan", got)
}
if got := loaded.ModelList[0].CustomHeaders["X-Agent"]; got != "openclaw" {
t.Errorf("CustomHeaders[X-Agent] = %q, want openclaw", got)
}
}
func TestDefaultConfig_MinimaxExtraBody(t *testing.T) {
cfg := DefaultConfig()
+7 -1
View File
@@ -17,6 +17,11 @@ func DefaultConfig() *Config {
return &Config{
Version: CurrentVersion,
// Isolation is opt-in so existing installations keep their current behavior
// until the user explicitly enables subprocess sandboxing.
Isolation: IsolationConfig{
Enabled: false,
},
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: workspacePath,
@@ -462,7 +467,8 @@ func DefaultConfig() *Config {
UseBM25: true,
UseRegex: false,
},
Servers: map[string]MCPServerConfig{},
MaxInlineTextChars: DefaultMCPMaxInlineTextChars,
Servers: map[string]MCPServerConfig{},
},
AppendFile: ToolConfig{
Enabled: true,
+1
View File
@@ -28,6 +28,7 @@ import (
"github.com/sipeed/picoclaw/pkg/channels/pico"
_ "github.com/sipeed/picoclaw/pkg/channels/qq"
_ "github.com/sipeed/picoclaw/pkg/channels/slack"
_ "github.com/sipeed/picoclaw/pkg/channels/teams_webhook"
_ "github.com/sipeed/picoclaw/pkg/channels/telegram"
_ "github.com/sipeed/picoclaw/pkg/channels/vk"
_ "github.com/sipeed/picoclaw/pkg/channels/wecom"
+3
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"maps"
"net/http"
"os"
"sync"
"time"
)
@@ -31,6 +32,7 @@ type Check struct {
type StatusResponse struct {
Status string `json:"status"`
Uptime string `json:"uptime"`
PID int `json:"pid,omitempty"`
Checks map[string]Check `json:"checks,omitempty"`
}
@@ -170,6 +172,7 @@ func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
resp := StatusResponse{
Status: "ok",
Uptime: uptime.String(),
PID: os.Getpid(),
}
json.NewEncoder(w).Encode(resp)
+238
View File
@@ -0,0 +1,238 @@
# `pkg/isolation`
`pkg/isolation` provides process-level isolation for child processes started by `picoclaw`.
It does not sandbox the main `picoclaw` process itself.
## Scope
The current scope is the child-process startup path:
- `exec` tool
- CLI providers such as `claude-cli` and `codex-cli`
- process hooks
- MCP `stdio` servers
## One-Sentence Model
- The `picoclaw` main process still runs in the host environment.
- Every child process should enter the shared `pkg/isolation` startup path first.
- The startup path applies platform-specific isolation according to config.
## Architecture
The implementation has four layers:
1. Configuration layer: reads `config.Config.Isolation` and injects it through `isolation.Configure(cfg)`.
2. Instance layout layer: resolves `config.GetHome()`, prepares instance directories, and builds the runtime user environment.
3. Platform backend layer: Linux uses `bwrap`; Windows uses a restricted token, low integrity, and a `Job Object`; other platforms are not implemented.
4. Unified startup layer: `PrepareCommand(cmd)`, `Start(cmd)`, and `Run(cmd)`.
All integrations that spawn subprocesses should reuse these helpers instead of calling `cmd.Start` or `cmd.Run` directly.
## Configuration
Isolation lives under:
```json
{
"isolation": {
"enabled": false,
"expose_paths": []
}
}
```
Field meanings:
- `enabled`: enables or disables subprocess isolation. Default: `false`.
- `expose_paths`: explicitly exposes host paths inside the isolated environment. It only matters when `enabled=true`. This is currently supported on Linux only.
Example:
```json
{
"isolation": {
"enabled": true,
"expose_paths": [
{
"source": "/opt/toolchains/go",
"target": "/opt/toolchains/go",
"mode": "ro"
},
{
"source": "/data/shared-assets",
"target": "/opt/picoclaw-instance-a/workspace/assets",
"mode": "rw"
}
]
}
}
```
Rules for `expose_paths`:
- `source` is a host path.
- `target` is the path inside the isolated environment.
- `mode` must be `ro` or `rw`.
- When `target` is empty, it defaults to `source`.
- Only one final rule may exist for the same `target`.
- Later-loaded config overrides earlier rules for the same `target`.
Platform note:
- Linux uses a real `source -> target` mount view.
- Windows does not currently support `expose_paths`.
## Instance Root And Directories
The instance root follows `config.GetHome()`:
- If `PICOCLAW_HOME` is set, use it.
- Otherwise use the default `.picoclaw` directory under the user home.
If `config.GetHome()` falls back to `.` while isolation is enabled, startup should fail.
Default instance directories include:
- instance root
- `skills`
- `logs`
- `cache`
- `state`
- `runtime-user-env`
`workspace` is derived from `cfg.WorkspacePath()` when configured, otherwise from the default workspace rule.
Windows also prepares:
- `runtime-user-env/AppData/Roaming`
- `runtime-user-env/AppData/Local`
## User Environment Redirect
When isolation is enabled, child processes receive a redirected per-instance user environment.
Linux variables:
- `HOME`
- `TMPDIR`
- `XDG_CONFIG_HOME`
- `XDG_CACHE_HOME`
- `XDG_STATE_HOME`
Windows variables:
- `USERPROFILE`
- `HOME`
- `TEMP`
- `TMP`
- `APPDATA`
- `LOCALAPPDATA`
These paths point into `runtime-user-env` under the instance root.
## Platform Behavior
### Linux
The Linux backend currently depends on `bwrap` (`bubblewrap`).
Capabilities:
- minimal filesystem view
- `ipc` namespace isolation
- redirected child-process user environment
- `source -> target` read-only or read-write mounts
Default mounts include the instance root plus the minimum runtime system paths such as `/usr`, `/bin`, `/lib`, `/lib64`, and `/etc/resolv.conf`.
At runtime, PicoClaw also adds the executable path, its directory, the effective working directory, and absolute path arguments when needed.
There is no automatic fallback when `bwrap` is missing.
Install examples:
- `apt install bubblewrap`
- `dnf install bubblewrap`
- `yum install bubblewrap`
- `pacman -S bubblewrap`
- `apk add bubblewrap`
If isolation must be disabled temporarily:
```json
{
"isolation": {
"enabled": false
}
}
```
Disabling isolation increases the risk that child processes can access or modify more host files.
### Windows
Windows isolation currently supports process-level restrictions such as restricted tokens, low integrity, job objects, and redirected user-environment directories.
`expose_paths` is not currently supported on Windows. If it is configured, startup should fail instead of pretending the paths were exposed.
The Windows backend currently uses:
- a restricted primary token
- low integrity level
- a `Job Object`
- redirected child-process user environment
It does not currently implement true `source -> target` filesystem remapping.
### macOS And Other Platforms
They are not implemented yet.
When isolation is explicitly enabled on an unsupported platform, the higher-level runtime should surface that as an unsupported configuration instead of pretending isolation succeeded.
## Logging And Debugging
When isolation is enabled, PicoClaw logs the generated isolation plan.
Linux log name:
- `linux isolation mount plan`
Windows log name:
- `windows isolation access rules`
If you suspect isolation is ineffective, check whether unexpected host paths appear in those logs.
## Relationship To `restrict_to_workspace`
- `restrict_to_workspace` limits the paths an agent is normally allowed to access.
- `pkg/isolation` limits what a child process can see and where its user environment points.
They complement each other and do not replace each other.
## Current Limits
- Linux isolation is implemented with `bwrap`, not a custom in-process isolation runtime.
- Linux does not currently enable a dedicated `pid` namespace by default.
- Windows does not yet implement full host ACL enforcement for every allowed or denied path.
- macOS is not implemented.
- The current design isolates child processes, not the main `picoclaw` process.
## Suggested Reading Order
If you are new to this code, read it in this order:
1. `pkg/config/config.go`
2. `pkg/isolation/runtime.go`
3. `pkg/isolation/platform_linux.go`
4. `pkg/isolation/platform_windows.go`
5. Call sites:
6. `pkg/tools/shell.go`
7. `pkg/providers/*.go`
8. `pkg/agent/hook_process.go`
9. `pkg/mcp/manager.go`
That path gives the fastest overview of the configuration model, runtime flow, and platform-specific limits.
+238
View File
@@ -0,0 +1,238 @@
# `pkg/isolation`
`pkg/isolation``picoclaw` 启动的子进程提供进程级隔离能力。
它当前不会把 `picoclaw` 主进程自身放进沙箱中运行。
## 生效范围
当前生效范围是子进程启动链路:
- `exec` 工具
- `claude-cli``codex-cli` 等 CLI provider
- 进程型 hooks
- MCP `stdio` server
## 一句话理解
- `picoclaw` 主进程仍运行在宿主环境中。
- 所有子进程都应先经过 `pkg/isolation` 的统一启动入口。
- 入口会根据配置和平台,为子进程施加对应隔离。
## 架构
当前实现可以分为四层:
1. 配置层:读取 `config.Config.Isolation`,并通过 `isolation.Configure(cfg)` 注入运行时。
2. 实例目录层:解析 `config.GetHome()`,准备实例目录,并构建运行时用户环境目录。
3. 平台后端层:Linux 使用 `bwrap`Windows 使用受限 token、低完整性级别和 `Job Object`;其他平台未实现。
4. 统一启动层:`PrepareCommand(cmd)``Start(cmd)``Run(cmd)`
所有启动子进程的接入点都应复用这组入口,而不是各自直接调用 `cmd.Start``cmd.Run`
## 配置
隔离配置位于:
```json
{
"isolation": {
"enabled": false,
"expose_paths": []
}
}
```
字段说明:
- `enabled`:是否启用子进程隔离。默认值:`false`
- `expose_paths`:显式把宿主路径带入隔离环境。仅在 `enabled=true` 时生效。目前只在 Linux 上支持。
示例:
```json
{
"isolation": {
"enabled": true,
"expose_paths": [
{
"source": "/opt/toolchains/go",
"target": "/opt/toolchains/go",
"mode": "ro"
},
{
"source": "/data/shared-assets",
"target": "/opt/picoclaw-instance-a/workspace/assets",
"mode": "rw"
}
]
}
}
```
`expose_paths` 规则:
- `source`:宿主机路径。
- `target`:隔离环境内的目标路径。
- `mode`:只能是 `ro``rw`
- `target` 为空时,默认等于 `source`
- 同一个 `target` 最终只能保留一条规则。
- 后加载的配置会覆盖先加载的同目标规则。
平台说明:
- Linux 会真实使用 `source -> target` 挂载视图。
- Windows 当前不支持 `expose_paths`
## 实例根与目录
实例根遵循 `config.GetHome()`
- 如果设置了 `PICOCLAW_HOME`,使用该值。
- 否则默认使用用户目录下的 `.picoclaw`
如果 `config.GetHome()` 在隔离开启时最终回退到当前目录 `.`,启动应直接失败。
默认实例目录包括:
- 实例根本身
- `skills`
- `logs`
- `cache`
- `state`
- `runtime-user-env`
`workspace` 优先使用 `cfg.WorkspacePath()` 的结果;未显式配置时才按默认规则派生。
Windows 还会额外准备:
- `runtime-user-env/AppData/Roaming`
- `runtime-user-env/AppData/Local`
## 用户环境重定向
隔离开启后,子进程会收到重定向到实例目录下的独立用户环境。
Linux 注入变量:
- `HOME`
- `TMPDIR`
- `XDG_CONFIG_HOME`
- `XDG_CACHE_HOME`
- `XDG_STATE_HOME`
Windows 注入变量:
- `USERPROFILE`
- `HOME`
- `TEMP`
- `TMP`
- `APPDATA`
- `LOCALAPPDATA`
这些路径都会指向实例根下的 `runtime-user-env`
## 平台行为
### Linux
Linux 后端当前依赖 `bwrap``bubblewrap`)。
能力:
- 最小文件系统视图
- `ipc namespace`
- 子进程用户环境重定向
- `source -> target` 只读或读写挂载
默认映射包括实例根,以及 `/usr``/bin``/lib``/lib64``/etc/resolv.conf` 等最小运行时系统路径。
运行时还会按需补充可执行文件本身、其所在目录、生效后的工作目录,以及命令行中的绝对路径参数。
缺少 `bwrap` 时不会自动回退。
安装示例:
- `apt install bubblewrap`
- `dnf install bubblewrap`
- `yum install bubblewrap`
- `pacman -S bubblewrap`
- `apk add bubblewrap`
如果需要临时关闭隔离:
```json
{
"isolation": {
"enabled": false
}
}
```
关闭隔离后,子进程访问或修改更多宿主文件的风险会明显上升。
### Windows
Windows 隔离当前提供的是进程级限制,例如 restricted token、low integrity、job object,以及用户环境目录重定向。
`expose_paths` 目前不支持 Windows。如果配置了该字段,启动应直接失败,而不是假装这些路径已经被暴露进隔离环境。
Windows 后端当前使用:
- 受限 primary token
- 低完整性级别
- `Job Object`
- 子进程用户环境重定向
它当前不会实现真正的 `source -> target` 文件系统重映射。
### macOS 与其他平台
当前尚未实现。
当在未支持的平台上显式开启隔离时,上层运行时应将其视为不支持的配置,而不是假装隔离成功。
## 日志与排障
隔离开启后,PicoClaw 会打印生成后的隔离计划,便于排障。
Linux 日志名:
- `linux isolation mount plan`
Windows 日志名:
- `windows isolation access rules`
如果你怀疑隔离未生效,先检查这些日志里是否出现了不应暴露的宿主路径。
## 与 `restrict_to_workspace` 的关系
- `restrict_to_workspace` 限制的是 agent 默认可访问的路径。
- `pkg/isolation` 限制的是子进程运行时能看到什么文件系统,以及它的用户环境指向哪里。
两者互补,不互相替代。
## 当前限制
- Linux 基于 `bwrap` 实现,而不是纯内建 isolation runtime。
- Linux 当前没有默认启用独立的 `pid namespace`
- Windows 还没有对所有允许/拒绝路径做完整 ACL 落地。
- macOS 尚未实现。
- 当前隔离的是子进程,不是 `picoclaw` 主进程自身。
## 建议阅读顺序
如果你是第一次看这部分代码,建议按这个顺序阅读:
1. `pkg/config/config.go`
2. `pkg/isolation/runtime.go`
3. `pkg/isolation/platform_linux.go`
4. `pkg/isolation/platform_windows.go`
5. 调用点:
6. `pkg/tools/shell.go`
7. `pkg/providers/*.go`
8. `pkg/agent/hook_process.go`
9. `pkg/mcp/manager.go`
这样能最快建立对配置模型、运行流程和平台边界的整体理解。
+264
View File
@@ -0,0 +1,264 @@
//go:build linux
package isolation
import (
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
func applyPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
if !isolation.Enabled {
return nil
}
// Bubblewrap is the only supported Linux backend right now. Fail closed when
// it is unavailable instead of silently running the child process unisolated.
bwrapPath, err := exec.LookPath("bwrap")
if err != nil {
hint := bwrapInstallHint()
disableHint := `set "isolation.enabled": false in config.json`
logger.WarnCF("isolation", "bubblewrap is required for Linux isolation",
map[string]any{
"binary": "bwrap",
"install": hint,
"disable_isolation": disableHint,
"risk": "disabling isolation lets child processes run without Linux filesystem isolation",
})
return fmt.Errorf(
"linux isolation requires bwrap and does not fall back automatically: %w; install bubblewrap with one of: %s; or disable isolation by setting %s; disabling isolation means child processes can run without Linux filesystem isolation and may access or modify more host files",
err,
hint,
disableHint,
)
}
if cmd == nil || cmd.Path == "" || len(cmd.Args) == 0 {
return nil
}
originalPath := cmd.Path
originalArgs := append([]string{}, cmd.Args...)
_, execDir, err := resolveLinuxWorkingDir(cmd.Dir, originalPath)
if err != nil {
return err
}
resolvedPath, err := resolveLinuxCommandPath(originalPath, execDir)
if err != nil {
return err
}
// Start from the configured mount plan, then add only the executable, its
// resolved path, the effective working directory, and any absolute path
// arguments needed to preserve the original command semantics.
plan := BuildLinuxMountPlan(root, isolation.ExposePaths)
plan = ensureLinuxMountRule(plan, resolvedPath, resolvedPath, "ro")
plan = ensureLinuxMountRule(plan, filepath.Dir(resolvedPath), filepath.Dir(resolvedPath), "ro")
if resolved, resolveErr := filepath.EvalSymlinks(resolvedPath); resolveErr == nil && resolved != resolvedPath {
plan = ensureLinuxMountRule(plan, resolved, resolved, "ro")
plan = ensureLinuxMountRule(plan, filepath.Dir(resolved), filepath.Dir(resolved), "ro")
}
if execDir != "" {
plan = ensureLinuxMountRule(plan, execDir, execDir, "rw")
if resolved, resolveErr := filepath.EvalSymlinks(execDir); resolveErr == nil && resolved != execDir {
plan = ensureLinuxMountRule(plan, resolved, resolved, "rw")
}
}
plan = appendLinuxArgumentMounts(plan, originalArgs[1:])
logger.DebugCF("isolation", "linux isolation mount plan",
map[string]any{
"root": root,
"command": resolvedPath,
"working_dir": execDir,
"mounts": formatLinuxMountPlan(plan),
})
bwrapArgs, err := buildLinuxBwrapArgs(originalPath, resolvedPath, originalArgs, execDir, plan)
if err != nil {
return err
}
cmd.Path = bwrapPath
cmd.Args = bwrapArgs
cmd.Dir = ""
return nil
}
func bwrapInstallHint() string {
return "apt install bubblewrap; dnf install bubblewrap; yum install bubblewrap; pacman -S bubblewrap; apk add bubblewrap"
}
// formatLinuxMountPlan reshapes the internal plan for structured logging.
func formatLinuxMountPlan(plan []MountRule) []map[string]string {
formatted := make([]map[string]string, 0, len(plan))
for _, rule := range plan {
formatted = append(formatted, map[string]string{
"source": rule.Source,
"target": rule.Target,
"mode": rule.Mode,
})
}
return formatted
}
func postStartPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
return nil
}
func cleanupPendingPlatformResources(cmd *exec.Cmd) {
}
// buildLinuxBwrapArgs translates the mount plan into the bubblewrap command
// line that re-executes the original process inside the isolated mount view.
func buildLinuxBwrapArgs(
originalPath string,
resolvedPath string,
originalArgs []string,
execDir string,
plan []MountRule,
) ([]string, error) {
bwrapArgs := []string{
"bwrap",
"--die-with-parent",
"--unshare-ipc",
"--proc", "/proc",
"--dev", "/dev",
}
for _, rule := range plan {
flag, err := linuxBindFlag(rule)
if err != nil {
return nil, err
}
bwrapArgs = append(bwrapArgs, flag, rule.Source, rule.Target)
}
if execDir != "" {
bwrapArgs = append(bwrapArgs, "--chdir", execDir)
}
execPath := originalPath
if isRelativeCommandPath(originalPath) {
execPath = resolvedPath
}
bwrapArgs = append(bwrapArgs, "--", execPath)
if len(originalArgs) > 1 {
bwrapArgs = append(bwrapArgs, originalArgs[1:]...)
}
return bwrapArgs, nil
}
func resolveLinuxWorkingDir(originalDir, originalPath string) (string, string, error) {
if originalDir != "" {
resolved, err := filepath.Abs(originalDir)
if err != nil {
return "", "", fmt.Errorf("resolve command dir %s: %w", originalDir, err)
}
return resolved, resolved, nil
}
if !isRelativeCommandPath(originalPath) {
return "", "", nil
}
wd, err := os.Getwd()
if err != nil {
return "", "", fmt.Errorf("resolve current working dir: %w", err)
}
return "", wd, nil
}
func resolveLinuxCommandPath(originalPath, execDir string) (string, error) {
if filepath.IsAbs(originalPath) || !isRelativeCommandPath(originalPath) {
return filepath.Clean(originalPath), nil
}
base := execDir
if base == "" {
var err error
base, err = os.Getwd()
if err != nil {
return "", fmt.Errorf("resolve current working dir: %w", err)
}
}
return filepath.Clean(filepath.Join(base, originalPath)), nil
}
func appendLinuxArgumentMounts(plan []MountRule, args []string) []MountRule {
for _, arg := range args {
path, ok := linuxArgumentPath(arg)
if !ok {
continue
}
clean := filepath.Clean(path)
if info, err := os.Stat(clean); err == nil {
mode := "ro"
if info.IsDir() {
mode = "rw"
}
plan = ensureLinuxMountRule(plan, clean, clean, mode)
if resolved, resolveErr := filepath.EvalSymlinks(clean); resolveErr == nil && resolved != clean {
plan = ensureLinuxMountRule(plan, resolved, resolved, mode)
}
continue
} else if !errors.Is(err, os.ErrNotExist) {
continue
}
parent := filepath.Dir(clean)
if parent == clean {
continue
}
if _, err := os.Stat(parent); err == nil {
plan = ensureLinuxMountRule(plan, parent, parent, "rw")
}
}
return plan
}
func linuxArgumentPath(arg string) (string, bool) {
if filepath.IsAbs(arg) {
return arg, true
}
idx := strings.IndexRune(arg, '=')
if idx <= 0 || idx == len(arg)-1 {
return "", false
}
value := arg[idx+1:]
if !filepath.IsAbs(value) {
return "", false
}
return value, true
}
func isRelativeCommandPath(path string) bool {
return !filepath.IsAbs(path) && strings.ContainsRune(path, filepath.Separator)
}
// ensureLinuxMountRule appends a mount rule unless another rule already owns
// the same target path.
func ensureLinuxMountRule(plan []MountRule, source, target, mode string) []MountRule {
cleanSource := filepath.Clean(source)
cleanTarget := filepath.Clean(target)
for _, rule := range plan {
if filepath.Clean(rule.Target) == cleanTarget {
return plan
}
}
return append(plan, MountRule{Source: cleanSource, Target: cleanTarget, Mode: mode})
}
// linuxBindFlag selects the correct bubblewrap bind flag based on mount mode.
func linuxBindFlag(rule MountRule) (string, error) {
info, err := os.Stat(rule.Source)
if err != nil {
return "", fmt.Errorf("stat linux mount source %s: %w", rule.Source, err)
}
if !info.IsDir() {
if rule.Mode == "rw" {
return "--bind", nil
}
return "--ro-bind", nil
}
if rule.Mode == "rw" {
return "--bind", nil
}
return "--ro-bind", nil
}
+148
View File
@@ -0,0 +1,148 @@
//go:build linux
package isolation
import (
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestBuildLinuxBwrapArgs_IncludesNamespaceFlagsAndExec(t *testing.T) {
root := t.TempDir()
binaryDir := filepath.Join(root, "bin")
if err := os.MkdirAll(binaryDir, 0o755); err != nil {
t.Fatal(err)
}
binaryPath := filepath.Join(binaryDir, "tool")
if err := os.WriteFile(binaryPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatal(err)
}
plan := BuildLinuxMountPlan(root, []config.ExposePath{{Source: binaryDir, Target: binaryDir, Mode: "ro"}})
args, err := buildLinuxBwrapArgs(binaryPath, binaryPath, []string{binaryPath, "--flag"}, root, plan)
if err != nil {
t.Fatalf("buildLinuxBwrapArgs() error = %v", err)
}
hasNet := false
hasIPC := false
hasExec := false
for i := range args {
switch args[i] {
case "--unshare-net":
hasNet = true
case "--unshare-ipc":
hasIPC = true
case "--":
if i+1 < len(args) && args[i+1] == binaryPath {
hasExec = true
}
}
}
if hasNet {
t.Fatalf("bwrap args should not unshare net by default: %v", args)
}
if !hasIPC || !hasExec {
t.Fatalf("bwrap args missing required items: %v", args)
}
}
func TestResolveLinuxWorkingDir_ResolvesRelativeDir(t *testing.T) {
cwd := t.TempDir()
previous, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
defer func() {
if chdirErr := os.Chdir(previous); chdirErr != nil {
t.Fatalf("restore cwd: %v", chdirErr)
}
}()
if chdirErr := os.Chdir(cwd); chdirErr != nil {
t.Fatal(chdirErr)
}
resolvedDir, execDir, err := resolveLinuxWorkingDir("./hooks", "./hook.sh")
if err != nil {
t.Fatalf("resolveLinuxWorkingDir() error = %v", err)
}
want := filepath.Join(cwd, "hooks")
if resolvedDir != want || execDir != want {
t.Fatalf("resolveLinuxWorkingDir() = (%q, %q), want (%q, %q)", resolvedDir, execDir, want, want)
}
}
func TestResolveLinuxCommandPath_UsesExecDirForRelativeCommand(t *testing.T) {
execDir := filepath.Join(t.TempDir(), "hooks")
got, err := resolveLinuxCommandPath("./hook.sh", execDir)
if err != nil {
t.Fatalf("resolveLinuxCommandPath() error = %v", err)
}
want := filepath.Join(execDir, "hook.sh")
if got != want {
t.Fatalf("resolveLinuxCommandPath() = %q, want %q", got, want)
}
}
func TestBuildLinuxBwrapArgs_UsesResolvedPathForRelativeCommand(t *testing.T) {
root := t.TempDir()
execDir := filepath.Join(root, "hooks")
if err := os.MkdirAll(execDir, 0o755); err != nil {
t.Fatal(err)
}
resolvedPath := filepath.Join(execDir, "hook.sh")
if err := os.WriteFile(resolvedPath, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatal(err)
}
plan := []MountRule{
{Source: execDir, Target: execDir, Mode: "rw"},
{Source: resolvedPath, Target: resolvedPath, Mode: "ro"},
}
args, err := buildLinuxBwrapArgs("./hook.sh", resolvedPath, []string{"./hook.sh"}, execDir, plan)
if err != nil {
t.Fatalf("buildLinuxBwrapArgs() error = %v", err)
}
hasExecDir := false
for _, arg := range args {
if arg == execDir {
hasExecDir = true
break
}
}
if !hasExecDir {
t.Fatalf("buildLinuxBwrapArgs() missing resolved chdir: %v", args)
}
for i := range args {
if args[i] == "--" {
if i+1 >= len(args) || args[i+1] != resolvedPath {
t.Fatalf("buildLinuxBwrapArgs() exec path = %v, want %q after --", args, resolvedPath)
}
return
}
}
t.Fatalf("buildLinuxBwrapArgs() missing exec delimiter: %v", args)
}
func TestAppendLinuxArgumentMounts_AddsAbsoluteArgumentPaths(t *testing.T) {
root := t.TempDir()
input := filepath.Join(root, "input.txt")
if err := os.WriteFile(input, []byte("data"), 0o644); err != nil {
t.Fatal(err)
}
output := filepath.Join(root, "out", "result.txt")
if err := os.MkdirAll(filepath.Dir(output), 0o755); err != nil {
t.Fatal(err)
}
plan := appendLinuxArgumentMounts(nil, []string{input, "--output=" + output})
if len(plan) != 2 {
t.Fatalf("appendLinuxArgumentMounts() len = %d, want 2", len(plan))
}
if plan[0].Source != input || plan[0].Mode != "ro" {
t.Fatalf("appendLinuxArgumentMounts()[0] = %+v, want source=%q mode=ro", plan[0], input)
}
if plan[1].Source != filepath.Dir(output) || plan[1].Mode != "rw" {
t.Fatalf("appendLinuxArgumentMounts()[1] = %+v, want source=%q mode=rw", plan[1], filepath.Dir(output))
}
}
+22
View File
@@ -0,0 +1,22 @@
//go:build !linux && !windows
package isolation
import (
"os/exec"
"github.com/sipeed/picoclaw/pkg/config"
)
func applyPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
// Unsupported platforms currently keep the command unchanged. Callers rely on
// Preflight and higher-level checks to surface unsupported isolation modes.
return nil
}
func postStartPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
return nil
}
func cleanupPendingPlatformResources(cmd *exec.Cmd) {
}
+217
View File
@@ -0,0 +1,217 @@
//go:build windows
package isolation
import (
"fmt"
"os/exec"
"sync"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
)
const disableMaxPrivilege = 0x1
// windowsProcessResources holds native handles that must live for the lifetime
// of an isolated child process.
type windowsProcessResources struct {
job windows.Handle
token windows.Token
}
var (
windowsProcessResourcesByPID sync.Map
windowsPendingResources sync.Map
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
procCreateRestrictedToken = advapi32.NewProc("CreateRestrictedToken")
)
func applyPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
if !isolation.Enabled || cmd == nil {
return nil
}
if cmd.SysProcAttr == nil {
cmd.SysProcAttr = &syscall.SysProcAttr{}
}
rules := BuildWindowsAccessRules(root, isolation.ExposePaths)
logger.InfoCF("isolation", "windows isolation process constraints",
map[string]any{
"root": root,
"command": cmd.Path,
"rules": formatWindowsAccessRules(rules),
"note": "Windows currently enforces restricted token, low integrity, and job object limits; expose_paths filesystem remapping is rejected during preflight",
})
// Create the restricted token before the process starts so CreateProcess uses
// the reduced privilege set from the first instruction.
restrictedToken, err := createRestrictedPrimaryToken()
if err != nil {
return fmt.Errorf("create restricted primary token: %w", err)
}
cmd.SysProcAttr.CreationFlags |= windows.CREATE_NEW_PROCESS_GROUP | windows.CREATE_BREAKAWAY_FROM_JOB
cmd.SysProcAttr.Token = syscall.Token(restrictedToken)
windowsPendingResources.Store(cmd, windowsProcessResources{token: restrictedToken})
return nil
}
func postStartPlatformIsolation(cmd *exec.Cmd, isolation config.IsolationConfig, root string) error {
if !isolation.Enabled || cmd == nil || cmd.Process == nil {
return nil
}
resourcesAny, _ := windowsPendingResources.LoadAndDelete(cmd)
resources, _ := resourcesAny.(windowsProcessResources)
// Job objects can only be attached after the process exists, so the Windows
// backend finishes isolation in this post-start hook.
job, err := windows.CreateJobObject(nil, nil)
if err != nil {
if resources.token != 0 {
_ = resources.token.Close()
}
return fmt.Errorf("create windows job object: %w", err)
}
info := windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION{}
info.BasicLimitInformation.LimitFlags = windows.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
if _, err := windows.SetInformationJobObject(
job,
windows.JobObjectExtendedLimitInformation,
uintptr(unsafe.Pointer(&info)),
uint32(unsafe.Sizeof(info)),
); err != nil {
_ = windows.CloseHandle(job)
if resources.token != 0 {
_ = resources.token.Close()
}
return fmt.Errorf("set windows job object info: %w", err)
}
proc, err := windows.OpenProcess(
windows.PROCESS_SET_QUOTA|windows.PROCESS_TERMINATE|windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.SYNCHRONIZE,
false,
uint32(cmd.Process.Pid),
)
if err != nil {
_ = windows.CloseHandle(job)
if resources.token != 0 {
_ = resources.token.Close()
}
return fmt.Errorf("open process for job assignment: %w", err)
}
if err := windows.AssignProcessToJobObject(job, proc); err != nil {
_ = windows.CloseHandle(proc)
_ = windows.CloseHandle(job)
if resources.token != 0 {
_ = resources.token.Close()
}
return fmt.Errorf("assign process to job object: %w", err)
}
if resources.token != 0 {
_ = resources.token.Close()
}
resources.job = job
windowsProcessResourcesByPID.Store(cmd.Process.Pid, resources)
go reapWindowsProcessResources(cmd.Process.Pid, proc, job)
return nil
}
func cleanupPendingPlatformResources(cmd *exec.Cmd) {
if cmd == nil {
return
}
resourcesAny, ok := windowsPendingResources.LoadAndDelete(cmd)
if !ok {
return
}
resources, _ := resourcesAny.(windowsProcessResources)
if resources.token != 0 {
_ = resources.token.Close()
}
}
func reapWindowsProcessResources(pid int, proc windows.Handle, job windows.Handle) {
_, _ = windows.WaitForSingleObject(proc, windows.INFINITE)
_ = windows.CloseHandle(proc)
_ = windows.CloseHandle(job)
windowsProcessResourcesByPID.Delete(pid)
}
// createRestrictedPrimaryToken duplicates the current process token, removes
// maximum privileges, and lowers integrity before it is assigned to a child.
func createRestrictedPrimaryToken() (windows.Token, error) {
var current windows.Token
if err := windows.OpenProcessToken(
windows.CurrentProcess(),
windows.TOKEN_DUPLICATE|windows.TOKEN_ASSIGN_PRIMARY|windows.TOKEN_QUERY|windows.TOKEN_ADJUST_DEFAULT,
&current,
); err != nil {
return 0, err
}
defer current.Close()
var restricted windows.Token
r1, _, e1 := procCreateRestrictedToken.Call(
uintptr(current),
uintptr(disableMaxPrivilege),
0,
0,
0,
0,
0,
0,
0,
uintptr(unsafe.Pointer(&restricted)),
)
if r1 == 0 {
if e1 != nil && e1 != syscall.Errno(0) {
return 0, e1
}
return 0, syscall.EINVAL
}
if err := setTokenLowIntegrity(restricted); err != nil {
_ = restricted.Close()
return 0, err
}
return restricted, nil
}
// setTokenLowIntegrity lowers the token integrity level so writes to higher
// integrity locations are blocked by the OS.
func setTokenLowIntegrity(token windows.Token) error {
lowSID, err := windows.CreateWellKnownSid(windows.WinLowLabelSid)
if err != nil {
return fmt.Errorf("create low integrity sid: %w", err)
}
tml := windows.Tokenmandatorylabel{
Label: windows.SIDAndAttributes{
Sid: lowSID,
Attributes: windows.SE_GROUP_INTEGRITY,
},
}
if err := windows.SetTokenInformation(
token,
windows.TokenIntegrityLevel,
(*byte)(unsafe.Pointer(&tml)),
tml.Size(),
); err != nil {
return fmt.Errorf("set token low integrity: %w", err)
}
return nil
}
// formatWindowsAccessRules reshapes the internal rules for structured logging.
func formatWindowsAccessRules(rules []AccessRule) []map[string]string {
formatted := make([]map[string]string, 0, len(rules))
for _, rule := range rules {
formatted = append(formatted, map[string]string{
"path": rule.Path,
"mode": rule.Mode,
})
}
return formatted
}
+443
View File
@@ -0,0 +1,443 @@
package isolation
import (
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
)
// MountRule describes a source-to-target mount exposed inside the Linux
// isolation view.
type MountRule struct {
Source string
Target string
Mode string
}
// AccessRule describes the effective Windows-side access rule for a host path.
type AccessRule struct {
Path string
Mode string
}
// UserEnv contains the redirected per-instance user directories injected into
// isolated child processes.
type UserEnv struct {
Home string
Tmp string
Config string
Cache string
State string
AppData string
LocalAppData string
}
var (
isolationMu sync.RWMutex
currentIsolation = config.DefaultConfig().Isolation
)
// Configure updates the process-wide isolation state used by subsequent child
// process launches.
func Configure(cfg *config.Config) {
isolationMu.Lock()
defer isolationMu.Unlock()
if cfg == nil {
defaults := config.DefaultConfig()
currentIsolation = defaults.Isolation
return
}
currentIsolation = cfg.Isolation
}
// CurrentConfig returns the currently active isolation settings.
func CurrentConfig() config.IsolationConfig {
isolationMu.RLock()
defer isolationMu.RUnlock()
return currentIsolation
}
// ResolveInstanceRoot resolves the instance root used to build the isolated
// filesystem and redirected user environment.
func ResolveInstanceRoot() (string, error) {
root := filepath.Clean(config.GetHome())
if root == "." {
return "", fmt.Errorf("instance root resolved to current directory")
}
return root, nil
}
// PrepareInstanceRoot creates the directories required by the isolation runtime.
func PrepareInstanceRoot(root string) error {
for _, dir := range InstanceDirs(root) {
if err := os.MkdirAll(dir, 0o755); err != nil {
return fmt.Errorf("prepare instance dir %s: %w", dir, err)
}
}
return nil
}
// InstanceDirs returns the directories that must exist under the instance root
// for isolation-aware child processes.
func InstanceDirs(root string) []string {
dirs := []string{
root,
filepath.Join(root, "skills"),
filepath.Join(root, "logs"),
filepath.Join(root, "cache"),
filepath.Join(root, "state"),
filepath.Join(root, "runtime-user-env"),
filepath.Join(root, "runtime-user-env", "home"),
filepath.Join(root, "runtime-user-env", "tmp"),
filepath.Join(root, "runtime-user-env", "config"),
filepath.Join(root, "runtime-user-env", "cache"),
filepath.Join(root, "runtime-user-env", "state"),
}
dirs = append(dirs, filepath.Join(root, pkg.WorkspaceName))
if runtime.GOOS == "windows" {
dirs = append(dirs,
filepath.Join(root, "runtime-user-env", "AppData", "Roaming"),
filepath.Join(root, "runtime-user-env", "AppData", "Local"),
)
}
return dirs
}
// ResolveUserEnv derives the redirected user directories rooted under the
// instance runtime area.
func ResolveUserEnv(root string) UserEnv {
base := filepath.Join(root, "runtime-user-env")
return UserEnv{
Home: filepath.Join(base, "home"),
Tmp: filepath.Join(base, "tmp"),
Config: filepath.Join(base, "config"),
Cache: filepath.Join(base, "cache"),
State: filepath.Join(base, "state"),
AppData: filepath.Join(base, "AppData", "Roaming"),
LocalAppData: filepath.Join(base, "AppData", "Local"),
}
}
// ApplyUserEnv rewrites the child process environment so home, temp, and
// platform-specific user-data directories point into the instance root.
func ApplyUserEnv(cmd *exec.Cmd, root string) {
userEnv := ResolveUserEnv(root)
envMap := make(map[string]string)
for _, item := range cmd.Environ() {
if idx := strings.IndexRune(item, '='); idx > 0 {
envMap[item[:idx]] = item[idx+1:]
}
}
if runtime.GOOS == "windows" {
envMap["USERPROFILE"] = userEnv.Home
envMap["HOME"] = userEnv.Home
envMap["TEMP"] = userEnv.Tmp
envMap["TMP"] = userEnv.Tmp
envMap["APPDATA"] = userEnv.AppData
envMap["LOCALAPPDATA"] = userEnv.LocalAppData
} else {
envMap["HOME"] = userEnv.Home
envMap["TMPDIR"] = userEnv.Tmp
envMap["XDG_CONFIG_HOME"] = userEnv.Config
envMap["XDG_CACHE_HOME"] = userEnv.Cache
envMap["XDG_STATE_HOME"] = userEnv.State
}
env := make([]string, 0, len(envMap))
for k, v := range envMap {
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
cmd.Env = env
}
// ValidateExposePaths verifies the user-supplied path exposure rules before a
// child process is started.
func ValidateExposePaths(items []config.ExposePath) error {
seen := map[string]struct{}{}
for _, item := range items {
if item.Source == "" {
return fmt.Errorf("source is required")
}
if item.Mode != "ro" && item.Mode != "rw" {
return fmt.Errorf("invalid expose_paths mode: %s", item.Mode)
}
source := filepath.Clean(item.Source)
target := item.Target
if target == "" {
target = source
}
target = filepath.Clean(target)
if !filepath.IsAbs(source) || !filepath.IsAbs(target) {
return fmt.Errorf("source and target must be absolute paths")
}
if _, ok := seen[target]; ok {
return fmt.Errorf("duplicate expose_path target: %s", target)
}
seen[target] = struct{}{}
}
return nil
}
// NormalizeExposePath fills implicit defaults and cleans path values so merge
// and validation logic can work with canonical paths.
func NormalizeExposePath(item config.ExposePath) config.ExposePath {
source := filepath.Clean(item.Source)
target := item.Target
if target == "" {
target = source
}
return config.ExposePath{
Source: source,
Target: filepath.Clean(target),
Mode: item.Mode,
}
}
// DefaultExposePaths returns the minimum built-in host paths required for the
// current platform to run isolated child processes.
func DefaultExposePaths(root string) []config.ExposePath {
items := []config.ExposePath{{
Source: root,
Target: root,
Mode: "rw",
}}
if runtime.GOOS == "linux" {
items = append(items, defaultLinuxSystemExposePaths()...)
}
return items
}
func defaultLinuxSystemExposePaths() []config.ExposePath {
return existingExposePaths([]config.ExposePath{
{Source: "/usr", Target: "/usr", Mode: "ro"},
{Source: "/bin", Target: "/bin", Mode: "ro"},
{Source: "/lib", Target: "/lib", Mode: "ro"},
{Source: "/lib64", Target: "/lib64", Mode: "ro"},
{Source: "/etc/resolv.conf", Target: "/etc/resolv.conf", Mode: "ro"},
{Source: "/etc/hosts", Target: "/etc/hosts", Mode: "ro"},
{Source: "/etc/nsswitch.conf", Target: "/etc/nsswitch.conf", Mode: "ro"},
{Source: "/etc/passwd", Target: "/etc/passwd", Mode: "ro"},
{Source: "/etc/group", Target: "/etc/group", Mode: "ro"},
{Source: "/etc/ssl", Target: "/etc/ssl", Mode: "ro"},
{Source: "/etc/pki", Target: "/etc/pki", Mode: "ro"},
{Source: "/etc/ca-certificates", Target: "/etc/ca-certificates", Mode: "ro"},
{Source: "/usr/share/ca-certificates", Target: "/usr/share/ca-certificates", Mode: "ro"},
{Source: "/usr/local/share/ca-certificates", Target: "/usr/local/share/ca-certificates", Mode: "ro"},
{Source: "/etc/alternatives", Target: "/etc/alternatives", Mode: "ro"},
{Source: "/usr/share/zoneinfo", Target: "/usr/share/zoneinfo", Mode: "ro"},
{Source: "/etc/localtime", Target: "/etc/localtime", Mode: "ro"},
})
}
// existingExposePaths keeps only the builtin host paths that exist on the
// current machine so Linux isolation does not fail on distro-specific paths.
func existingExposePaths(items []config.ExposePath) []config.ExposePath {
filtered := make([]config.ExposePath, 0, len(items))
for _, item := range items {
if _, err := os.Stat(item.Source); err == nil {
filtered = append(filtered, item)
}
}
return filtered
}
// MergeExposePaths merges built-in rules with user overrides. Rules are keyed
// by target path so later entries replace earlier ones for the same target.
func MergeExposePaths(defaults []config.ExposePath, overrides []config.ExposePath) []config.ExposePath {
merged := make([]config.ExposePath, 0, len(defaults)+len(overrides))
indexByTarget := make(map[string]int, len(defaults)+len(overrides))
appendOrReplace := func(item config.ExposePath) {
normalized := NormalizeExposePath(item)
if idx, ok := indexByTarget[normalized.Target]; ok {
merged[idx] = normalized
return
}
indexByTarget[normalized.Target] = len(merged)
merged = append(merged, normalized)
}
for _, item := range defaults {
appendOrReplace(item)
}
for _, item := range overrides {
appendOrReplace(item)
}
return merged
}
// BuildLinuxMountPlan converts the merged expose-path configuration into the
// mount rules consumed by the Linux bubblewrap backend.
func BuildLinuxMountPlan(root string, overrides []config.ExposePath) []MountRule {
merged := MergeExposePaths(DefaultExposePaths(root), overrides)
plan := make([]MountRule, 0, len(merged))
for _, item := range merged {
plan = append(plan, MountRule{Source: item.Source, Target: item.Target, Mode: item.Mode})
}
return plan
}
// BuildWindowsAccessRules derives the host-path access policy used by the
// Windows restricted-token backend.
func BuildWindowsAccessRules(root string, overrides []config.ExposePath) []AccessRule {
merged := MergeExposePaths(nil, overrides)
rules := make([]AccessRule, 0, len(merged)+1)
rules = append(rules, AccessRule{Path: root, Mode: "rw"})
for _, item := range merged {
rules = append(rules, AccessRule{Path: item.Source, Mode: item.Mode})
}
return rules
}
func validateWindowsExposePaths(items []config.ExposePath) error {
if len(items) == 0 {
return nil
}
return fmt.Errorf("windows isolation does not yet support expose_paths filesystem rules")
}
// IsSupported reports whether the current platform has an implemented isolation
// backend.
func IsSupported() bool {
return isSupportedOn(runtime.GOOS)
}
func isSupportedOn(goos string) bool {
switch goos {
case "linux", "windows":
return true
default:
return false
}
}
// Preflight validates the configured isolation state and prepares the instance
// runtime directories before any child process is launched.
func Preflight() error {
isolation := CurrentConfig()
if !isolation.Enabled {
return nil
}
if !IsSupported() {
return fmt.Errorf("subprocess isolation is not supported on %s", runtime.GOOS)
}
root, err := ResolveInstanceRoot()
if err != nil {
return err
}
if err := PrepareInstanceRoot(root); err != nil {
return err
}
if err := ValidateExposePaths(isolation.ExposePaths); err != nil {
return err
}
if runtime.GOOS == "linux" {
for _, rule := range BuildLinuxMountPlan(root, isolation.ExposePaths) {
if rule.Source == "" || rule.Target == "" {
return fmt.Errorf("invalid linux mount rule")
}
}
}
if runtime.GOOS == "windows" {
if err := validateWindowsExposePaths(isolation.ExposePaths); err != nil {
return err
}
for _, rule := range BuildWindowsAccessRules(root, isolation.ExposePaths) {
if rule.Path == "" {
return fmt.Errorf("invalid windows access rule")
}
}
}
return nil
}
// Start prepares isolation for the command, starts it, and applies any
// post-start platform hooks required by the active backend.
func Start(cmd *exec.Cmd) error {
if err := PrepareCommand(cmd); err != nil {
return err
}
if err := cmd.Start(); err != nil {
cleanupPendingPlatformResources(cmd)
return err
}
isolation := CurrentConfig()
root := ""
if isolation.Enabled {
var err error
root, err = ResolveInstanceRoot()
if err != nil {
terminateStartedCommand(cmd)
return err
}
}
if err := postStartPlatformIsolation(cmd, isolation, root); err != nil {
terminateStartedCommand(cmd)
return err
}
return nil
}
// Run is the Start-and-Wait helper that keeps the same isolation behavior as
// Start while returning the command's final exit status.
func Run(cmd *exec.Cmd) error {
if err := PrepareCommand(cmd); err != nil {
return err
}
if err := cmd.Start(); err != nil {
cleanupPendingPlatformResources(cmd)
return err
}
isolation := CurrentConfig()
root := ""
if isolation.Enabled {
var err error
root, err = ResolveInstanceRoot()
if err != nil {
terminateStartedCommand(cmd)
return err
}
}
if err := postStartPlatformIsolation(cmd, isolation, root); err != nil {
terminateStartedCommand(cmd)
return err
}
return cmd.Wait()
}
func terminateStartedCommand(cmd *exec.Cmd) {
cleanupPendingPlatformResources(cmd)
if cmd == nil || cmd.Process == nil {
return
}
_ = cmd.Process.Kill()
_ = cmd.Wait()
}
// PrepareCommand mutates the command in-place so it inherits the configured
// isolated environment before being started by the caller.
func PrepareCommand(cmd *exec.Cmd) error {
isolation := CurrentConfig()
if err := Preflight(); err != nil {
return err
}
if isolation.Enabled {
root, err := ResolveInstanceRoot()
if err != nil {
return err
}
ApplyUserEnv(cmd, root)
if err := applyPlatformIsolation(cmd, isolation, root); err != nil {
return err
}
}
return nil
}
+245
View File
@@ -0,0 +1,245 @@
package isolation
import (
"os"
"os/exec"
"path/filepath"
"runtime"
"testing"
"github.com/sipeed/picoclaw/pkg"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestResolveInstanceRoot_UsesPicoclawHome(t *testing.T) {
t.Setenv(config.EnvHome, "/custom/picoclaw/home")
root, err := ResolveInstanceRoot()
if err != nil {
t.Fatalf("ResolveInstanceRoot() error = %v", err)
}
if root != "/custom/picoclaw/home" {
t.Fatalf("ResolveInstanceRoot() = %q, want %q", root, "/custom/picoclaw/home")
}
}
func TestPrepareInstanceRoot_CreatesDirectories(t *testing.T) {
root := filepath.Join(t.TempDir(), "instance")
if err := PrepareInstanceRoot(root); err != nil {
t.Fatalf("PrepareInstanceRoot() error = %v", err)
}
for _, dir := range InstanceDirs(root) {
if info, err := os.Stat(dir); err != nil {
t.Fatalf("os.Stat(%q): %v", dir, err)
} else if !info.IsDir() {
t.Fatalf("%q is not a directory", dir)
}
}
}
func TestInstanceDirs_UsesInstanceWorkspaceNotGlobalState(t *testing.T) {
root := filepath.Join(t.TempDir(), "instance")
cfg := config.DefaultConfig()
cfg.Isolation.Enabled = true
cfg.Agents.Defaults.Workspace = filepath.Join(t.TempDir(), "external-workspace")
Configure(cfg)
t.Cleanup(func() { Configure(config.DefaultConfig()) })
dirs := InstanceDirs(root)
wantWorkspace := filepath.Join(root, pkg.WorkspaceName)
found := false
for _, dir := range dirs {
if dir == wantWorkspace {
found = true
}
if dir == cfg.WorkspacePath() {
t.Fatalf("InstanceDirs() should not depend on process-wide workspace state: %q", dir)
}
}
if !found {
t.Fatalf("InstanceDirs() missing instance workspace dir %q", wantWorkspace)
}
}
func TestIsSupportedOn(t *testing.T) {
tests := []struct {
goos string
want bool
}{
{goos: "linux", want: true},
{goos: "windows", want: true},
{goos: "darwin", want: false},
{goos: "freebsd", want: false},
}
for _, tt := range tests {
if got := isSupportedOn(tt.goos); got != tt.want {
t.Fatalf("isSupportedOn(%q) = %v, want %v", tt.goos, got, tt.want)
}
}
}
func TestValidateExposePaths(t *testing.T) {
err := ValidateExposePaths([]config.ExposePath{{Source: "/src", Target: "/dst", Mode: "ro"}})
if err != nil {
t.Fatalf("ValidateExposePaths() error = %v", err)
}
err = ValidateExposePaths([]config.ExposePath{{Source: "/src", Target: "/dst", Mode: "bad"}})
if err == nil {
t.Fatal("ValidateExposePaths() expected invalid mode error")
}
err = ValidateExposePaths(
[]config.ExposePath{
{Source: "/src", Target: "/dst", Mode: "ro"},
{Source: "/other", Target: "/dst", Mode: "rw"},
},
)
if err == nil {
t.Fatal("ValidateExposePaths() expected duplicate target error")
}
}
func TestMergeExposePaths_OverrideByTarget(t *testing.T) {
merged := MergeExposePaths(
[]config.ExposePath{{Source: "/src-a", Target: "/dst", Mode: "ro"}},
[]config.ExposePath{{Source: "/src-b", Target: "/dst", Mode: "rw"}},
)
if len(merged) != 1 {
t.Fatalf("MergeExposePaths len = %d, want 1", len(merged))
}
if got := merged[0]; got.Source != "/src-b" || got.Target != "/dst" || got.Mode != "rw" {
t.Fatalf("merged[0] = %+v, want source=/src-b target=/dst mode=rw", got)
}
}
func TestBuildLinuxMountPlan(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("linux-only default mount set")
}
plan := BuildLinuxMountPlan("/rootdir", []config.ExposePath{{Source: "/src", Target: "/dst", Mode: "ro"}})
if len(plan) == 0 {
t.Fatal("BuildLinuxMountPlan returned empty plan")
}
foundRoot := false
foundOverride := false
for _, rule := range plan {
if rule.Source == "/rootdir" && rule.Target == "/rootdir" && rule.Mode == "rw" {
foundRoot = true
}
if rule.Source == "/src" && rule.Target == "/dst" && rule.Mode == "ro" {
foundOverride = true
}
}
if !foundRoot {
t.Fatal("BuildLinuxMountPlan missing root mapping")
}
if !foundOverride {
t.Fatal("BuildLinuxMountPlan missing override mapping")
}
}
func TestBuildWindowsAccessRules(t *testing.T) {
rules := BuildWindowsAccessRules(
`C:\picoclaw`,
[]config.ExposePath{{Source: `D:\data`, Target: `C:\mapped`, Mode: "ro"}},
)
if len(rules) == 0 {
t.Fatal("BuildWindowsAccessRules returned empty rules")
}
foundRoot := false
foundOverride := false
for _, rule := range rules {
if rule.Path == `C:\picoclaw` && rule.Mode == "rw" {
foundRoot = true
}
if rule.Path == `D:\data` && rule.Mode == "ro" {
foundOverride = true
}
}
if !foundRoot {
t.Fatal("BuildWindowsAccessRules missing root rule")
}
if !foundOverride {
t.Fatal("BuildWindowsAccessRules missing override rule")
}
}
func TestValidateWindowsExposePaths(t *testing.T) {
if err := validateWindowsExposePaths(nil); err != nil {
t.Fatalf("validateWindowsExposePaths(nil) error = %v", err)
}
err := validateWindowsExposePaths([]config.ExposePath{{Source: `D:\data`, Target: `D:\data`, Mode: "ro"}})
if err == nil {
t.Fatal("validateWindowsExposePaths() expected error for expose_paths")
}
}
func TestDefaultLinuxSystemExposePaths(t *testing.T) {
paths := defaultLinuxSystemExposePaths()
needed := map[string]bool{}
for _, path := range []string{"/etc/hosts", "/etc/nsswitch.conf", "/etc/ssl", "/usr/share/zoneinfo", "/etc/localtime"} {
if _, err := os.Stat(path); err == nil {
needed[path] = false
}
}
for _, item := range paths {
if _, ok := needed[item.Source]; ok {
needed[item.Source] = true
}
}
for path, found := range needed {
if !found {
t.Fatalf("defaultLinuxSystemExposePaths missing %s", path)
}
}
}
func TestExistingExposePaths_SkipsMissingPaths(t *testing.T) {
existing := filepath.Join(t.TempDir(), "existing")
if err := os.MkdirAll(existing, 0o755); err != nil {
t.Fatalf("os.MkdirAll() error = %v", err)
}
filtered := existingExposePaths([]config.ExposePath{
{Source: existing, Target: existing, Mode: "ro"},
{Source: filepath.Join(t.TempDir(), "missing"), Target: "/missing", Mode: "ro"},
})
if len(filtered) != 1 {
t.Fatalf("existingExposePaths() len = %d, want 1", len(filtered))
}
if got := filtered[0]; got.Source != existing {
t.Fatalf("existingExposePaths()[0] = %+v, want source=%q", got, existing)
}
}
func TestPrepareCommand_AppliesUserEnv(t *testing.T) {
t.Setenv(config.EnvHome, filepath.Join(t.TempDir(), "home"))
if runtime.GOOS == "linux" {
binDir := filepath.Join(t.TempDir(), "bin")
if err := os.MkdirAll(binDir, 0o755); err != nil {
t.Fatalf("os.MkdirAll() error = %v", err)
}
fakeBwrap := filepath.Join(binDir, "bwrap")
if err := os.WriteFile(fakeBwrap, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil {
t.Fatalf("os.WriteFile() error = %v", err)
}
t.Setenv("PATH", binDir+string(os.PathListSeparator)+os.Getenv("PATH"))
}
cfg := config.DefaultConfig()
cfg.Isolation.Enabled = true
Configure(cfg)
t.Cleanup(func() { Configure(config.DefaultConfig()) })
cmd := exec.Command("sh", "-c", "true")
if err := PrepareCommand(cmd); err != nil {
t.Fatalf("PrepareCommand() error = %v", err)
}
hasHome := false
for _, env := range cmd.Env {
if len(env) > 5 && env[:5] == "HOME=" {
hasHome = true
break
}
}
if runtime.GOOS != "windows" && !hasHome {
t.Fatal("PrepareCommand() did not inject HOME")
}
}
+226
View File
@@ -0,0 +1,226 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"io"
"os/exec"
"sync"
"syscall"
"time"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/sipeed/picoclaw/pkg/isolation"
)
var isolatedCommandTerminateDuration = 5 * time.Second
// isolatedCommandTransport mirrors the SDK command transport but routes
// process startup through pkg/isolation so Windows post-start hooks run too.
type isolatedCommandTransport struct {
Command *exec.Cmd
TerminateDuration time.Duration
}
func (t *isolatedCommandTransport) Connect(ctx context.Context) (sdkmcp.Connection, error) {
stdout, err := t.Command.StdoutPipe()
if err != nil {
return nil, err
}
stdout = io.NopCloser(stdout)
stdin, err := t.Command.StdinPipe()
if err != nil {
return nil, err
}
if err := isolation.Start(t.Command); err != nil {
return nil, err
}
td := t.TerminateDuration
if td <= 0 {
td = isolatedCommandTerminateDuration
}
return newIsolatedIOConn(&isolatedPipeRWC{cmd: t.Command, stdout: stdout, stdin: stdin, terminateDuration: td}), nil
}
type isolatedPipeRWC struct {
cmd *exec.Cmd
stdout io.ReadCloser
stdin io.WriteCloser
terminateDuration time.Duration
}
func (s *isolatedPipeRWC) Read(p []byte) (n int, err error) {
return s.stdout.Read(p)
}
func (s *isolatedPipeRWC) Write(p []byte) (n int, err error) {
return s.stdin.Write(p)
}
func (s *isolatedPipeRWC) Close() error {
if err := s.stdin.Close(); err != nil {
return fmt.Errorf("closing stdin: %v", err)
}
resChan := make(chan error, 1)
go func() {
resChan <- s.cmd.Wait()
}()
wait := func() (error, bool) {
select {
case err := <-resChan:
return err, true
case <-time.After(s.terminateDuration):
}
return nil, false
}
if err, ok := wait(); ok {
return err
}
if err := s.cmd.Process.Signal(syscall.SIGTERM); err == nil {
if err, ok := wait(); ok {
return err
}
}
if err := s.cmd.Process.Kill(); err != nil {
return err
}
if err, ok := wait(); ok {
return err
}
return fmt.Errorf("unresponsive subprocess")
}
type isolatedIOConn struct {
writeMu sync.Mutex
rwc io.ReadWriteCloser
incoming <-chan isolatedMsgOrErr
queue []jsonrpc.Message
closeOnce sync.Once
closed chan struct{}
closeErr error
}
type isolatedMsgOrErr struct {
msg json.RawMessage
err error
}
func newIsolatedIOConn(rwc io.ReadWriteCloser) *isolatedIOConn {
incoming := make(chan isolatedMsgOrErr)
closed := make(chan struct{})
go func() {
dec := json.NewDecoder(rwc)
for {
var raw json.RawMessage
err := dec.Decode(&raw)
if err == nil {
var tr [1]byte
if n, readErr := dec.Buffered().Read(tr[:]); n > 0 {
if tr[0] != '\n' && tr[0] != '\r' {
err = fmt.Errorf("invalid trailing data at the end of stream")
}
} else if readErr != nil && readErr != io.EOF {
err = readErr
}
}
select {
case incoming <- isolatedMsgOrErr{msg: raw, err: err}:
case <-closed:
return
}
if err != nil {
return
}
}
}()
return &isolatedIOConn{rwc: rwc, incoming: incoming, closed: closed}
}
func (c *isolatedIOConn) SessionID() string { return "" }
func (c *isolatedIOConn) Read(ctx context.Context) (jsonrpc.Message, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if len(c.queue) > 0 {
next := c.queue[0]
c.queue = c.queue[1:]
return next, nil
}
var raw json.RawMessage
select {
case <-ctx.Done():
return nil, ctx.Err()
case v := <-c.incoming:
if v.err != nil {
return nil, v.err
}
raw = v.msg
case <-c.closed:
return nil, io.EOF
}
msgs, err := readIsolatedBatch(raw)
if err != nil {
return nil, err
}
c.queue = msgs[1:]
return msgs[0], nil
}
func readIsolatedBatch(data []byte) ([]jsonrpc.Message, error) {
var rawBatch []json.RawMessage
if err := json.Unmarshal(data, &rawBatch); err == nil {
if len(rawBatch) == 0 {
return nil, fmt.Errorf("empty batch")
}
msgs := make([]jsonrpc.Message, 0, len(rawBatch))
for _, raw := range rawBatch {
msg, err := jsonrpc.DecodeMessage(raw)
if err != nil {
return nil, err
}
msgs = append(msgs, msg)
}
return msgs, nil
}
msg, err := jsonrpc.DecodeMessage(data)
if err != nil {
return nil, err
}
return []jsonrpc.Message{msg}, nil
}
func (c *isolatedIOConn) Write(ctx context.Context, msg jsonrpc.Message) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
data, err := jsonrpc.EncodeMessage(msg)
if err != nil {
return fmt.Errorf("marshaling message: %v", err)
}
data = append(data, '\n')
_, err = c.rwc.Write(data)
return err
}
func (c *isolatedIOConn) Close() error {
c.closeOnce.Do(func() {
c.closeErr = c.rwc.Close()
close(c.closed)
})
return c.closeErr
}
var (
_ sdkmcp.Transport = (*isolatedCommandTransport)(nil)
_ sdkmcp.Connection = (*isolatedIOConn)(nil)
)
+1 -2
View File
@@ -365,8 +365,7 @@ func (m *Manager) ConnectServer(
env = append(env, fmt.Sprintf("%s=%s", k, v))
}
cmd.Env = env
transport = &mcp.CommandTransport{Command: cmd}
transport = &isolatedCommandTransport{Command: cmd}
default:
return fmt.Errorf(
"unsupported transport type: %s (supported: stdio, sse, http)",
+27
View File
@@ -455,6 +455,33 @@ func (s *JSONLStore) rewriteJSONL(
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
}
// ListSessions returns all known session keys by reading .meta.json files.
func (s *JSONLStore) ListSessions() []string {
entries, err := os.ReadDir(s.dir)
if err != nil {
return nil
}
var keys []string
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
continue
}
// Read the meta file to get the original key
data, err := os.ReadFile(filepath.Join(s.dir, entry.Name()))
if err != nil {
continue
}
var meta sessionMeta
if err := json.Unmarshal(data, &meta); err != nil {
continue
}
if meta.Key != "" {
keys = append(keys, meta.Key)
}
}
return keys
}
func (s *JSONLStore) Close() error {
return nil
}
+3
View File
@@ -37,6 +37,9 @@ type Store interface {
// data. Backends that do not accumulate dead data may return nil.
Compact(ctx context.Context, sessionKey string) error
// ListSessions returns all known session keys.
ListSessions() []string
// Close releases any resources held by the store.
Close() error
}
+37 -2
View File
@@ -4,6 +4,7 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
@@ -16,6 +17,8 @@ import (
const pidFileName = ".picoclaw.pid"
var errInvalidPidFile = errors.New("invalid pid file")
// PidFileData is the JSON structure stored in the PID file.
type PidFileData struct {
PID int `json:"pid"`
@@ -109,6 +112,14 @@ func ReadPidFileWithCheck(homePath string) *PidFileData {
pidPath := pidFilePath(homePath)
data, err := readPidFileUnlocked(pidPath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
if errors.Is(err, errInvalidPidFile) {
logger.Warnf("invalid pid file, remove it: %s (%v)", pidPath, err)
_ = os.Remove(pidPath)
return nil
}
logger.Debugf("failed to read pid file: %s", err)
return nil
}
@@ -140,6 +151,30 @@ func RemovePidFile(homePath string) {
os.Remove(pidPath)
}
// RemovePidFileIfPID deletes the PID file only when the recorded PID matches
// expectedPID. It returns true when the file is removed successfully.
func RemovePidFileIfPID(homePath string, expectedPID int) bool {
if expectedPID <= 0 {
return false
}
pidMu.Lock()
defer pidMu.Unlock()
pidPath := pidFilePath(homePath)
data, err := readPidFileUnlocked(pidPath)
if err != nil {
return false
}
if data.PID != expectedPID {
return false
}
if err := os.Remove(pidPath); err != nil {
return false
}
return true
}
// readPidFileUnlocked reads the PID file without acquiring the lock.
// Caller must hold pidMu.
func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
@@ -150,12 +185,12 @@ func readPidFileUnlocked(pidPath string) (*PidFileData, error) {
var data PidFileData
if err := json.Unmarshal(raw, &data); err != nil {
return nil, err
return nil, fmt.Errorf("%w: %v", errInvalidPidFile, err)
}
// Validate PID is a positive integer.
if data.PID <= 0 {
return nil, fmt.Errorf("invalid pid in pid file: %d", data.PID)
return nil, fmt.Errorf("%w: pid=%d", errInvalidPidFile, data.PID)
}
return &data, nil
+50
View File
@@ -191,6 +191,22 @@ func TestReadPidFileWithCheckStalePID(t *testing.T) {
}
}
// TestReadPidFileWithCheckInvalidFile auto-cleans malformed PID file.
func TestReadPidFileWithCheckInvalidFile(t *testing.T) {
dir := tmpDir(t)
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, []byte("not json"), 0o600)
data := ReadPidFileWithCheck(dir)
if data != nil {
t.Error("expected nil for malformed pid file")
}
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Error("malformed PID file should be removed")
}
}
// TestRemovePidFile removes the PID file for the current process.
func TestRemovePidFile(t *testing.T) {
dir := tmpDir(t)
@@ -228,6 +244,40 @@ func TestRemovePidFileNonexistent(t *testing.T) {
RemovePidFile(dir)
}
func TestRemovePidFileIfPID(t *testing.T) {
dir := tmpDir(t)
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(other, "", " ")
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, raw, 0o600)
removed := RemovePidFileIfPID(dir, 99999999)
if !removed {
t.Fatal("expected RemovePidFileIfPID to remove matching pid file")
}
if _, err := os.Stat(path); !os.IsNotExist(err) {
t.Error("PID file should be removed for matching expected PID")
}
}
func TestRemovePidFileIfPIDMismatch(t *testing.T) {
dir := tmpDir(t)
other := PidFileData{PID: 99999999, Token: "deadbeef12345678deadbeef12345678"}
raw, _ := json.MarshalIndent(other, "", " ")
path := filepath.Join(dir, pidFileName)
os.WriteFile(path, raw, 0o600)
removed := RemovePidFileIfPID(dir, 88888888)
if removed {
t.Fatal("expected RemovePidFileIfPID to keep non-matching pid file")
}
if _, err := os.Stat(path); os.IsNotExist(err) {
t.Error("PID file should NOT be removed for mismatching expected PID")
}
}
// TestReadPidFileUnlockedInvalidJSON returns error for malformed content.
func TestReadPidFileUnlockedInvalidJSON(t *testing.T) {
dir := tmpDir(t)
+8 -1
View File
@@ -3,6 +3,7 @@
package pid
import (
"errors"
"os"
"syscall"
)
@@ -18,5 +19,11 @@ func isProcessRunning(pid int) bool {
return false
}
// Signal(nil) does not kill the process but checks existence on Unix.
return p.Signal(syscall.Signal(0)) == nil
err = p.Signal(syscall.Signal(0))
if err == nil {
return true
}
var errno syscall.Errno
// EPERM means the process exists but we are not allowed to signal it.
return errors.As(err, &errno) && errno == syscall.EPERM
}
+4 -4
View File
@@ -23,19 +23,19 @@ func isProcessRunning(pid int) bool {
return false
}
handle, _, err := procOpenProcess.Call(
handle, _, _ := procOpenProcess.Call(
uintptr(processQueryLimitedInformation),
0,
uintptr(pid),
)
if handle == 0 || err != nil {
if handle == 0 {
return false
}
defer procCloseHandle.Call(handle)
var exitCode uint32
ret, _, err := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
if ret == 0 || err != nil {
ret, _, _ := procGetExitCodeProcess.Call(handle, uintptr(unsafe.Pointer(&exitCode)))
if ret == 0 {
return false
}
return exitCode == stillActive
+5 -1
View File
@@ -7,6 +7,8 @@ import (
"fmt"
"os/exec"
"strings"
"github.com/sipeed/picoclaw/pkg/isolation"
)
// ClaudeCliProvider implements LLMProvider using the claude CLI as a subprocess.
@@ -49,7 +51,9 @@ func (p *ClaudeCliProvider) Chat(
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
// Execute the CLI through the shared isolation wrapper so external provider
// processes honor the configured isolation policy.
if err := isolation.Run(cmd); err != nil {
stderrStr := strings.TrimSpace(stderr.String())
stdoutStr := strings.TrimSpace(stdout.String())
switch {
+5 -1
View File
@@ -8,6 +8,8 @@ import (
"fmt"
"os/exec"
"strings"
"github.com/sipeed/picoclaw/pkg/isolation"
)
// CodexCliProvider implements LLMProvider by wrapping the codex CLI as a subprocess.
@@ -56,7 +58,9 @@ func (p *CodexCliProvider) Chat(
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
// Execute the CLI through the shared isolation wrapper so external provider
// processes honor the configured isolation policy.
err := isolation.Run(cmd)
// Parse JSONL from stdout even if exit code is non-zero,
// because codex writes diagnostic noise to stderr (e.g. rollout errors)
+16
View File
@@ -262,6 +262,22 @@ func TestDecodeToolCallArguments_StringJSON(t *testing.T) {
}
}
func TestDecodeToolCallArguments_StringJSON_NewlineEscape(t *testing.T) {
raw := json.RawMessage(`"{\"content\":\"line1\\nline2\"}"`)
args := DecodeToolCallArguments(raw, "write_file")
if args["content"] != "line1\nline2" {
t.Errorf("content = %q, want newline-expanded string", args["content"])
}
}
func TestDecodeToolCallArguments_StringJSON_LiteralBackslashN(t *testing.T) {
raw := json.RawMessage(`"{\"content\":\"line1\\\\nline2\"}"`)
args := DecodeToolCallArguments(raw, "write_file")
if args["content"] != `line1\nline2` {
t.Errorf("content = %q, want literal backslash-n", args["content"])
}
}
func TestDecodeToolCallArguments_EmptyInput(t *testing.T) {
args := DecodeToolCallArguments(nil, "test")
if len(args) != 0 {
+4
View File
@@ -160,6 +160,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
case "azure", "azure-openai":
@@ -238,6 +239,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
case "minimax":
@@ -264,6 +266,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
userAgent,
cfg.RequestTimeout,
extraBody,
cfg.CustomHeaders,
), modelID, nil
case "anthropic":
@@ -291,6 +294,7 @@ func CreateProviderFromConfig(cfg *config.ModelConfig) (LLMProvider, string, err
userAgent,
cfg.RequestTimeout,
cfg.ExtraBody,
cfg.CustomHeaders,
), modelID, nil
case "anthropic-messages":
+43
View File
@@ -846,6 +846,49 @@ func TestCreateProviderFromConfig_MinimaxPreservesUserExtraBody(t *testing.T) {
}
}
func TestCreateProviderFromConfig_CustomHeaders(t *testing.T) {
var gotSource, gotAuth string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSource = r.Header.Get("X-Source")
gotAuth = r.Header.Get("Authorization")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
}))
defer server.Close()
cfg := &config.ModelConfig{
ModelName: "test-headers",
Model: "openai/gpt-4o",
APIBase: server.URL,
CustomHeaders: map[string]string{"X-Source": "coding-plan", "Authorization": "Token config-auth"},
}
cfg.SetAPIKey("test-key")
provider, modelID, err := CreateProviderFromConfig(cfg)
if err != nil {
t.Fatalf("CreateProviderFromConfig() error = %v", err)
}
_, err = provider.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
modelID,
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if gotSource != "coding-plan" {
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
}
if gotAuth != "Token config-auth" {
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token config-auth")
}
}
// openaiCompatResponse is the JSON response used by OpenAI-compatible providers.
const openaiCompatResponse = `{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`
+3 -1
View File
@@ -24,13 +24,14 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
}
func NewHTTPProviderWithMaxTokensField(apiKey, apiBase, proxy, maxTokensField string) *HTTPProvider {
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil)
return NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(apiKey, apiBase, proxy, maxTokensField, "", 0, nil, nil)
}
func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
apiKey, apiBase, proxy, maxTokensField, userAgent string,
requestTimeoutSeconds int,
extraBody map[string]any,
customHeaders map[string]string,
) *HTTPProvider {
return &HTTPProvider{
delegate: openai_compat.NewProvider(
@@ -40,6 +41,7 @@ func NewHTTPProviderWithMaxTokensFieldAndRequestTimeout(
openai_compat.WithMaxTokensField(maxTokensField),
openai_compat.WithRequestTimeout(time.Duration(requestTimeoutSeconds)*time.Second),
openai_compat.WithExtraBody(extraBody),
openai_compat.WithCustomHeaders(customHeaders),
openai_compat.WithUserAgent(userAgent),
),
}
+21
View File
@@ -36,6 +36,7 @@ type Provider struct {
maxTokensField string // Field name for max tokens (e.g., "max_completion_tokens" for o1/glm models)
httpClient *http.Client
extraBody map[string]any // Additional fields to inject into request body
customHeaders map[string]string
userAgent string
}
@@ -87,6 +88,12 @@ func WithExtraBody(extraBody map[string]any) Option {
}
}
func WithCustomHeaders(customHeaders map[string]string) Option {
return func(p *Provider) {
p.customHeaders = customHeaders
}
}
func NewProvider(apiKey, apiBase, proxy string, opts ...Option) *Provider {
p := &Provider{
apiKey: apiKey,
@@ -181,6 +188,15 @@ func (p *Provider) buildRequestBody(
return requestBody
}
func (p *Provider) applyCustomHeaders(req *http.Request) {
for k, v := range p.customHeaders {
if strings.TrimSpace(k) == "" {
continue
}
req.Header.Set(k, v)
}
}
func (p *Provider) Chat(
ctx context.Context,
messages []Message,
@@ -211,6 +227,7 @@ func (p *Provider) Chat(
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
p.applyCustomHeaders(req)
resp, err := p.httpClient.Do(req)
if err != nil {
@@ -254,9 +271,13 @@ func (p *Provider) ChatStream(
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
if p.userAgent != "" {
req.Header.Set("User-Agent", p.userAgent)
}
if p.apiKey != "" {
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
p.applyCustomHeaders(req)
// Use a client without Timeout for streaming — the http.Client.Timeout covers
// the entire request lifecycle including body reads, which would kill long streams.
@@ -710,6 +710,111 @@ func TestProviderChat_ExtraBodyOverridesOptions(t *testing.T) {
}
}
func TestProviderChat_CustomHeadersInjected(t *testing.T) {
var gotSource, gotAuth, gotUserAgent string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSource = r.Header.Get("X-Source")
gotAuth = r.Header.Get("Authorization")
gotUserAgent = r.Header.Get("User-Agent")
resp := map[string]any{
"choices": []map[string]any{
{
"message": map[string]any{"content": "ok"},
"finish_reason": "stop",
},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
p := NewProvider(
"key",
server.URL,
"",
WithUserAgent("PicoClaw/Test"),
WithCustomHeaders(map[string]string{
"X-Source": "coding-plan",
"Authorization": "Token custom-auth",
"User-Agent": "Custom-UA/1.0",
}),
)
_, err := p.Chat(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
nil,
)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if gotSource != "coding-plan" {
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
}
if gotAuth != "Token custom-auth" {
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token custom-auth")
}
if gotUserAgent != "Custom-UA/1.0" {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/1.0")
}
}
func TestProviderChatStream_CustomHeadersInjected(t *testing.T) {
var gotSource, gotAuth, gotUserAgent string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotSource = r.Header.Get("X-Source")
gotAuth = r.Header.Get("Authorization")
gotUserAgent = r.Header.Get("User-Agent")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"ok\"},\"finish_reason\":\"stop\"}]}\n\n"))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
}))
defer server.Close()
p := NewProvider(
"key",
server.URL,
"",
WithUserAgent("PicoClaw/Test"),
WithCustomHeaders(map[string]string{
"X-Source": "coding-plan",
"Authorization": "Token stream-auth",
"User-Agent": "Custom-UA/Stream",
}),
)
out, err := p.ChatStream(
t.Context(),
[]Message{{Role: "user", Content: "hi"}},
nil,
"gpt-4o",
nil,
nil,
)
if err != nil {
t.Fatalf("ChatStream() error = %v", err)
}
if out.Content != "ok" {
t.Fatalf("Content = %q, want %q", out.Content, "ok")
}
if gotSource != "coding-plan" {
t.Fatalf("X-Source = %q, want %q", gotSource, "coding-plan")
}
if gotAuth != "Token stream-auth" {
t.Fatalf("Authorization = %q, want %q", gotAuth, "Token stream-auth")
}
if gotUserAgent != "Custom-UA/Stream" {
t.Fatalf("User-Agent = %q, want %q", gotUserAgent, "Custom-UA/Stream")
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
+6
View File
@@ -23,6 +23,12 @@ func buildCLIToolsPrompt(tools []ToolDefinition) string {
)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("Escaping rules (what to type in `function.arguments`):\n")
sb.WriteString("- Use `\\n` to represent a real newline character.\n")
sb.WriteString("- Use `\\\\n` to represent a literal backslash+n sequence (`\\n`).\n")
sb.WriteString(
"- `function.arguments` is a JSON-encoded string, so quotes/backslashes must be escaped in the outer payload.\n\n",
)
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
@@ -0,0 +1,7 @@
{
"tool_name": "Bash",
"tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}",
"error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1",
"timestamp": "2026-04-04T02:38:32.067Z",
"retry_count": 6
}
+58
View File
@@ -0,0 +1,58 @@
package seahorse
import (
"context"
"testing"
)
// =============================================================================
// CompactUntilUnder iteration cap
// =============================================================================
func TestCompactUntilUnderIterationCap(t *testing.T) {
// Setup: create a conversation with so many tokens that compaction
// will never reach the budget. The iteration cap prevents infinite loops.
//
// We use a mock CompleteFn that always returns the same content,
// and a budget of 0 which tokens can never reach.
// Without the cap, this would loop forever.
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
s := &Store{db: db}
conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap")
convID := conv.ConversationID
// Add many messages to ensure there's plenty to compact
for i := 0; i < 40; i++ {
m, _ := s.AddMessage(context.Background(), convID, "user",
"this is a long message with lots of tokens to push context over budget", 100)
s.AppendContextMessage(context.Background(), convID, m.ID)
}
// A completeFn that always succeeds but returns non-reducing content
mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "Summary that doesn't reduce tokens much.", nil
}
ce, cancel := newTestCompactionEngineWithStore(s, mockComplete)
defer cancel()
// Use budget=1 so tokens can never reach budget
// (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable)
// The function should stop after maxCompactIterations, not loop forever
ce.config = Config{} // ensure defaults
result, err := ce.CompactUntilUnder(context.Background(), convID, 1)
if err != nil {
// Should not error — should stop gracefully
t.Fatalf("CompactUntilUnder with budget=0: %v", err)
}
// The function should have completed within reasonable time
// If it exceeded the cap, it would still return (not hang)
_ = result
}
+144
View File
@@ -0,0 +1,144 @@
package seahorse
import (
"context"
"testing"
"time"
)
// =============================================================================
// Bug 1: formatMessagesForSummary ignores Parts
// - formatMessagesForSummary only reads m.Content, empty for Part-based messages
// - truncateSummary has same issue
// =============================================================================
func TestFormatMessagesForSummaryIncludesParts(t *testing.T) {
ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
messages := []Message{
{ID: 1, Role: "user", Content: "hello world", CreatedAt: ts},
{
ID: 2,
Role: "assistant",
Content: "", // empty — real content is in Parts
Parts: []MessagePart{
{Type: "text", Text: "I will run a command"},
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"},
},
CreatedAt: ts.Add(time.Minute),
},
{
ID: 3,
Role: "tool",
Content: "", // empty — real content is in Parts
Parts: []MessagePart{
{Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
},
CreatedAt: ts.Add(2 * time.Minute),
},
}
result := formatMessagesForSummary(messages)
// Must contain the plain text message
if !contains(result, "hello world") {
t.Error("formatMessagesForSummary: missing plain text content")
}
// Must contain tool_use info (not blank)
if !contains(result, "bash") || !contains(result, "ls -la") {
t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result)
}
// Must contain tool_result info (not blank)
if !contains(result, "file1.txt") {
t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result)
}
}
func TestTruncateSummaryIncludesParts(t *testing.T) {
messages := []Message{
{ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()},
{
ID: 2,
Role: "assistant",
Content: "", // empty
Parts: []MessagePart{
{Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"},
},
CreatedAt: time.Now(),
},
{
ID: 3,
Role: "tool",
Content: "", // empty
Parts: []MessagePart{
{Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"},
},
CreatedAt: time.Now(),
},
}
result := truncateSummary(messages)
// Must contain plain text
if !contains(result, "run the tests") {
t.Error("truncateSummary: missing plain text content")
}
// Must contain tool info from Parts (not blank)
if !contains(result, "bash") || !contains(result, "go test") {
t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result)
}
// Must contain tool_result from Parts
if !contains(result, "PASS") {
t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result)
}
}
// =============================================================================
// Bug 2: SearchMessages cannot find Part-based messages
// - FTS5 indexes empty content, LIKE queries empty content
// =============================================================================
func TestSearchMessagesFindsPartBasedMessages(t *testing.T) {
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts")
convID := conv.ConversationID
// Add a plain message (searchable)
s.AddMessage(ctx, convID, "user", "list the files please", 5)
// Add a Part-based message (tool_use) — currently NOT searchable
parts := []MessagePart{
{Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"},
}
s.AddMessageWithParts(ctx, convID, "assistant", parts, 10)
// Add a Part-based message (tool_result) — currently NOT searchable
resultParts := []MessagePart{
{Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"},
}
s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10)
// Search for "grep" — should find the tool_use message
results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"})
if err != nil {
t.Fatalf("SearchMessages: %v", err)
}
if len(results) == 0 {
t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search")
}
// Search for "TODO fix" — should find the tool_result message
results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"})
if err != nil {
t.Fatalf("SearchMessages: %v", err)
}
if len(results2) == 0 {
t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search")
}
}
+185
View File
@@ -0,0 +1,185 @@
package seahorse
import (
"database/sql"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
)
// SQL statements for FTS5 tables with trigram tokenizer.
const (
sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5(
summary_id,
content,
tokenize="trigram"
)`
sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
message_id,
content,
tokenize="trigram"
)`
sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)`
sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")`
sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check`
sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check`
)
// runSchema creates or upgrades the database schema.
// All schemas are idempotent (safe to run multiple times).
func runSchema(db *sql.DB) error {
// Check FTS5 support before creating tables
if err := checkFTS5Support(db); err != nil {
return fmt.Errorf("FTS5 check: %w", err)
}
stmts := []string{
`CREATE TABLE IF NOT EXISTS conversations (
conversation_id INTEGER PRIMARY KEY AUTOINCREMENT,
session_key TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
`CREATE TABLE IF NOT EXISTS messages (
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
role TEXT NOT NULL,
content TEXT NOT NULL DEFAULT '',
token_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
`CREATE TABLE IF NOT EXISTS message_parts (
part_id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id INTEGER NOT NULL REFERENCES messages(message_id),
type TEXT NOT NULL,
text TEXT,
name TEXT,
arguments TEXT,
tool_call_id TEXT,
media_uri TEXT,
mime_type TEXT,
ordinal INTEGER NOT NULL DEFAULT 0
)`,
`CREATE TABLE IF NOT EXISTS summaries (
summary_id TEXT PRIMARY KEY,
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
kind TEXT NOT NULL,
depth INTEGER NOT NULL DEFAULT 0,
content TEXT NOT NULL,
token_count INTEGER NOT NULL DEFAULT 0,
earliest_at TEXT,
latest_at TEXT,
descendant_count INTEGER NOT NULL DEFAULT 0,
descendant_token_count INTEGER NOT NULL DEFAULT 0,
source_message_token_count INTEGER NOT NULL DEFAULT 0,
model TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
`CREATE TABLE IF NOT EXISTS summary_parents (
summary_id TEXT NOT NULL,
parent_summary_id TEXT NOT NULL,
PRIMARY KEY (summary_id, parent_summary_id)
)`,
`CREATE TABLE IF NOT EXISTS summary_messages (
summary_id TEXT NOT NULL,
message_id INTEGER NOT NULL,
ordinal INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (summary_id, message_id)
)`,
`CREATE TABLE IF NOT EXISTS context_items (
conversation_id INTEGER NOT NULL,
ordinal INTEGER NOT NULL,
item_type TEXT NOT NULL,
summary_id TEXT,
message_id INTEGER,
token_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
PRIMARY KEY (conversation_id, ordinal)
)`,
// FTS5 virtual table with trigram tokenizer for CJK support
sqlCreateSummariesFTS,
// FTS5 virtual table for message search with trigram tokenizer
sqlCreateMessagesFTS,
// Indexes for common query patterns
`CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`,
`CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`,
`CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`,
`CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`,
`CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`,
`CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`,
`CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`,
// FTS5 triggers to keep summaries_fts in sync with summaries table
`CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
END`,
`CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
END`,
`CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
END`,
// FTS5 triggers to keep messages_fts in sync with messages table
`CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
END`,
`CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
DELETE FROM messages_fts WHERE message_id = old.message_id;
END`,
`CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
DELETE FROM messages_fts WHERE message_id = old.message_id;
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
END`,
}
for _, s := range stmts {
if _, err := db.Exec(s); err != nil {
return err
}
}
return nil
}
// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled.
// This is required for full-text search with CJK (Chinese, Japanese, Korean) support.
func checkFTS5Support(db *sql.DB) error {
// Check if FTS5 is compiled in
var fts5Enabled int
err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled)
if err != nil {
// sqlite_compileoption_used might not exist in older SQLite
// Try a different approach: create a test FTS5 table
_, testErr := db.Exec(sqlCheckFTS5Available)
if testErr != nil {
return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr)
}
db.Exec(sqlDropFTS5Check)
} else if fts5Enabled == 0 {
return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)")
}
// Check if trigram tokenizer is available by trying to create a test table
// Not all SQLite builds include the trigram tokenizer
_, err = db.Exec(sqlCheckTrigramAvailable)
if err != nil {
logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited",
map[string]any{"error": err.Error()})
// Trigram is not strictly required, just better for CJK
// Don't return error, just log warning
} else {
db.Exec(sqlDropTrigramCheck)
}
return nil
}
+223
View File
@@ -0,0 +1,223 @@
package seahorse
import (
"database/sql"
"fmt"
"strings"
"sync/atomic"
"testing"
_ "modernc.org/sqlite"
)
var testDBCounter uint64
func openTestDB(t *testing.T) *sql.DB {
t.Helper()
n := atomic.AddUint64(&testDBCounter, 1)
testName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
// Use a shared in-memory database so concurrent goroutines/connections in tests
// observe the same schema/data.
dsn := fmt.Sprintf("file:seahorse_test_%s_%d?mode=memory&cache=shared", testName, n)
db, err := sql.Open("sqlite", dsn)
if err != nil {
t.Fatalf("open test db: %v", err)
}
t.Cleanup(func() { db.Close() })
return db
}
func TestRunMigrations(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("runSchema: %v", err)
}
// Verify all tables exist
tables := []string{
"conversations",
"messages",
"message_parts",
"summaries",
"summary_parents",
"summary_messages",
"context_items",
}
for _, tbl := range tables {
var name string
err := db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl,
).Scan(&name)
if err != nil {
t.Errorf("table %q not found: %v", tbl, err)
}
}
// Verify FTS5 virtual table exists
var ftsName string
err := db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'",
).Scan(&ftsName)
if err != nil {
t.Errorf("FTS5 table summaries_fts not found: %v", err)
}
}
func TestRunMigrationsIdempotent(t *testing.T) {
db := openTestDB(t)
// Run migrations twice — should succeed both times
if err := runSchema(db); err != nil {
t.Fatalf("first migration: %v", err)
}
if err := runSchema(db); err != nil {
t.Fatalf("second migration (idempotent): %v", err)
}
// Verify we can still insert data after double migration
res, err := db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"test-session",
)
if err != nil {
t.Fatalf("insert after double migration: %v", err)
}
id, _ := res.LastInsertId()
if id == 0 {
t.Error("expected non-zero conversation id")
}
}
func TestMigrationConversationUnique(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
// Insert first
_, err := db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"unique-key",
)
if err != nil {
t.Fatalf("first insert: %v", err)
}
// Duplicate should fail
_, err = db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"unique-key",
)
if err == nil {
t.Error("expected unique constraint violation for duplicate session_key")
}
}
func TestMigrationSummaryFTSInsert(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
// Insert a conversation first
_, err := db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"fts-test",
)
if err != nil {
t.Fatalf("insert conversation: %v", err)
}
// Insert a summary
_, err = db.Exec(
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`)
if err != nil {
t.Fatalf("insert summary: %v", err)
}
// FTS should find it — trigram tokenizer requires >= 3 chars
rows, err := db.Query(
"SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?",
"你好世",
)
if err != nil {
t.Fatalf("FTS query: %v", err)
}
defer rows.Close()
var found string
if rows.Next() {
if err := rows.Scan(&found); err != nil {
t.Fatalf("scan: %v", err)
}
}
if err := rows.Err(); err != nil {
t.Fatalf("rows.Err: %v", err)
}
if found != "sum_test1" {
t.Errorf("FTS: expected 'sum_test1', got %q", found)
}
}
func TestMigrationSummaryParentsPK(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
// Insert two summaries
for _, id := range []string{"sum_a", "sum_b"} {
_, err := db.Exec(
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id)
if err != nil {
t.Fatalf("insert summary %s: %v", id, err)
}
}
// Link child to parent
_, err := db.Exec(
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
if err != nil {
t.Fatalf("link: %v", err)
}
// Duplicate link should fail (composite PK)
_, err = db.Exec(
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
if err == nil {
t.Error("expected unique constraint violation for duplicate summary_parents link")
}
}
func TestFTS5SQLConstants(t *testing.T) {
db := openTestDB(t)
// Verify FTS5 check SQL executes without error
_, err := db.Exec(sqlCheckFTS5Available)
if err != nil {
t.Errorf("sqlCheckFTS5Available failed: %v", err)
}
// Verify trigram check SQL executes without error
_, err = db.Exec(sqlCheckTrigramAvailable)
if err != nil {
t.Errorf("sqlCheckTrigramAvailable failed: %v", err)
}
// Verify summaries_fts SQL executes without error
_, err = db.Exec(sqlCreateSummariesFTS)
if err != nil {
t.Errorf("sqlCreateSummariesFTS failed: %v", err)
}
// Verify messages_fts SQL executes without error
_, err = db.Exec(sqlCreateMessagesFTS)
if err != nil {
t.Errorf("sqlCreateMessagesFTS failed: %v", err)
}
}
+261
View File
@@ -0,0 +1,261 @@
package seahorse
import (
"context"
"fmt"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
// escapeXML escapes special characters for safe inclusion in XML content.
func escapeXML(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
s = strings.ReplaceAll(s, "\"", "&quot;")
s = strings.ReplaceAll(s, "'", "&apos;")
return s
}
// resolvedItem is a context item resolved to its full content with token count.
type resolvedItem struct {
ordinal int
itemType string // "message" or "summary"
message *Message
summary *Summary
tokenCount int
}
// Assemble builds budget-constrained context from summaries + messages.
//
// Algorithm:
// 1. Fetch context_items, resolve to full content
// 2. Split into evictable prefix + protected fresh tail
// 3. If evictable fits in remaining budget → include all
// 4. Else walk evictable from newest to oldest, keep while fits
func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
items, err := a.store.GetContextItems(ctx, convID)
if err != nil {
return nil, fmt.Errorf("get context items: %w", err)
}
if len(items) == 0 {
return &AssembleResult{}, nil
}
// Resolve all items
resolved := make([]resolvedItem, len(items))
for i, item := range items {
r, err := a.resolveItem(ctx, item)
if err != nil {
return nil, err
}
resolved[i] = r
}
// Split into evictable prefix and protected fresh tail
tailStart := len(resolved) - FreshTailCount
if tailStart < 0 {
tailStart = 0
}
evictable := resolved[:tailStart]
freshTail := resolved[tailStart:]
// Calculate fresh tail tokens
freshTailTokens := 0
for _, r := range freshTail {
freshTailTokens += r.tokenCount
}
// Budget-aware selection of evictable items
remainingBudget := input.Budget - freshTailTokens
if remainingBudget < 0 {
// Fresh tail alone exceeds budget - we keep it anyway (design decision)
// Log for debugging retry/overflow issues
logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
"budget": input.Budget,
"fresh_tail_tokens": freshTailTokens,
"fresh_tail_count": len(freshTail),
"over_budget_by": freshTailTokens - input.Budget,
})
remainingBudget = 0
}
var selected []resolvedItem
evictableTokens := 0
for _, r := range evictable {
evictableTokens += r.tokenCount
}
if evictableTokens <= remainingBudget {
// All evictable fit
selected = append(selected, evictable...)
} else {
// Walk from newest to oldest, keep while fits
var kept []resolvedItem
accum := 0
for i := len(evictable) - 1; i >= 0; i-- {
if accum+evictable[i].tokenCount <= remainingBudget {
kept = append(kept, evictable[i])
accum += evictable[i].tokenCount
} else {
break
}
}
// Reverse to restore chronological order
for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
kept[i], kept[j] = kept[j], kept[i]
}
selected = append(selected, kept...)
}
// Combine: selected evictable + fresh tail
final := append(selected, freshTail...)
// Build result
var messages []Message
var summaries []Summary
var sourceIDs []string
totalTokens := 0
maxDepth := 0
condensedCount := 0
for _, r := range final {
totalTokens += r.tokenCount
if r.itemType == "message" && r.message != nil {
messages = append(messages, *r.message)
sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
} else if r.itemType == "summary" && r.summary != nil {
summaries = append(summaries, *r.summary)
if r.summary.Depth > maxDepth {
maxDepth = r.summary.Depth
}
if r.summary.Kind == SummaryKindCondensed {
condensedCount++
}
}
}
// Build depth-aware system prompt addition
systemPromptAddition := ""
if len(summaries) > 0 {
if maxDepth >= 2 || condensedCount >= 2 {
systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
"- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
"- When uncertain, use expand to recover original detail before making claims.\n" +
"- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
} else {
systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
}
}
// Build Summary field: all XML summaries + system prompt addition
var summaryParts []string
for _, sum := range summaries {
if sum.Content == "" {
continue
}
// Load parent IDs for XML formatting
parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
if err != nil {
logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
"summary_id": sum.SummaryID,
"error": err.Error(),
})
}
var parentIDs []string
for _, ps := range parentSummaries {
parentIDs = append(parentIDs, ps.SummaryID)
}
summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
}
summary := strings.Join(summaryParts, "\n\n")
if systemPromptAddition != "" {
if summary != "" {
summary += "\n\n"
}
summary += systemPromptAddition
}
return &AssembleResult{
Messages: messages,
Summary: summary,
}, nil
}
// resolveItem loads the full message or summary for a context item.
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
if item.ItemType == "message" {
msg, err := a.store.GetMessageByID(ctx, item.MessageID)
if err != nil {
return resolvedItem{}, err
}
tokens := item.TokenCount
if tokens == 0 {
tokens = msg.TokenCount
}
return resolvedItem{
ordinal: item.Ordinal,
itemType: "message",
message: msg,
tokenCount: tokens,
}, nil
}
if item.ItemType == "summary" {
sum, err := a.store.GetSummary(ctx, item.SummaryID)
if err != nil {
return resolvedItem{}, err
}
tokens := item.TokenCount
if tokens == 0 {
tokens = sum.TokenCount
}
return resolvedItem{
ordinal: item.Ordinal,
itemType: "summary",
summary: sum,
tokenCount: tokens,
}, nil
}
return resolvedItem{
ordinal: item.Ordinal,
itemType: item.ItemType,
tokenCount: item.TokenCount,
}, nil
}
// FormatSummaryXML formats a summary as XML for LLM context.
// This is exported so context managers can format summaries consistently.
func FormatSummaryXML(s *Summary, parentIDs []string) string {
// Build time attributes if available
var attrs string
if s.EarliestAt != nil {
attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
}
if s.LatestAt != nil {
attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
}
var parentsSection string
if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
parents := "<parents>\n"
for _, pid := range parentIDs {
parents += fmt.Sprintf(" <summary_ref id=\"%s\" />\n", pid)
}
parents += " </parents>\n"
parentsSection = parents
}
return fmt.Sprintf(
"<summary id=\"%s\" kind=\"%s\" depth=\"%d\" descendant_count=\"%d\"%s>\n <content>\n %s\n </content>\n%s</summary>",
s.SummaryID,
string(s.Kind),
s.Depth,
s.DescendantCount,
attrs,
escapeXML(s.Content),
parentsSection,
)
}
+536
View File
@@ -0,0 +1,536 @@
package seahorse
import (
"context"
"strings"
"testing"
"time"
)
// --- Assembler Tests ---
// helper: create a store with messages and summaries for assembly tests
func setupAssemblerStore(t *testing.T) (*Store, int64) {
t.Helper()
s := openTestStore(t)
ctx := context.Background()
conv, err := s.GetOrCreateConversation(ctx, "test:assemble")
if err != nil {
t.Fatalf("create conversation: %v", err)
}
return s, conv.ConversationID
}
func TestAssemblerAssembleEmpty(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if len(result.Messages) != 0 {
t.Errorf("Messages = %d, want 0", len(result.Messages))
}
if result.Summary != "" {
t.Errorf("Summary = %q, want empty", result.Summary)
}
}
func TestAssemblerAssembleMessagesOnly(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create messages
msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5)
// Create context items
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
{Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("Messages = %d, want 2", len(result.Messages))
}
if result.Messages[0].Content != "hello" {
t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content)
}
if result.Messages[1].Content != "world" {
t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content)
}
// No summaries, so Summary should be empty
if result.Summary != "" {
t.Errorf("Summary = %q, want empty", result.Summary)
}
}
func TestAssemblerAssembleWithSummary(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a summary
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "summary of early messages",
TokenCount: 50,
})
// Create recent messages
msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5)
msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5)
// Context: summary + recent messages
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50},
{Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
{Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Messages = 2 raw messages (summaries are in Summary field, not Messages)
if len(result.Messages) != 2 {
t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages))
}
// Summary should contain XML with summary content
if result.Summary == "" {
t.Error("Summary should not be empty when summary exists")
}
if !strings.Contains(result.Summary, summary.Content) {
t.Errorf("Summary should contain summary content %q", summary.Content)
}
if !strings.Contains(result.Summary, "<summary") {
t.Error("Summary should contain <summary XML tag")
}
}
func TestAssemblerBudgetEvictsOldest(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create 40 messages, each with 10 tokens = 400 total
msgs := make([]*Message, 40)
for i := 0; i < 40; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
msgs[i] = m
}
// Context items for all messages
items := make([]ContextItem, 40)
for i := 0; i < 40; i++ {
items[i] = ContextItem{
Ordinal: (i + 1) * 100,
ItemType: "message",
MessageID: msgs[i].ID,
TokenCount: 10,
}
}
s.UpsertContextItems(ctx, convID, items)
// Budget of 200 tokens with FreshTailCount=32
// Fresh tail = last 32 messages (320 tokens, over budget, but always included)
// Evictable = first 8 messages (80 tokens)
// Budget after tail: max(0, 200-320) = 0 → no evictable items included
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Should only include the 32-item fresh tail
if len(result.Messages) != 32 {
t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages))
}
// Should be the LAST 32 messages
if result.Messages[0].ID != msgs[8].ID {
t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID)
}
}
func TestAssemblerBudgetFitsAll(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
msgs := make([]*Message, 5)
for i := 0; i < 5; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
msgs[i] = m
}
items := make([]ContextItem, 5)
for i := 0; i < 5; i++ {
items[i] = ContextItem{
Ordinal: (i + 1) * 100,
ItemType: "message",
MessageID: msgs[i].ID,
TokenCount: 10,
}
}
s.UpsertContextItems(ctx, convID, items)
// Budget = 100, total = 50, FreshTailCount=32 → all items in tail
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if len(result.Messages) != 5 {
t.Errorf("Messages = %d, want 5", len(result.Messages))
}
}
func TestAssemblerSummaryXMLFormat(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "test summary content",
TokenCount: 20,
})
msg, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Messages should only contain raw messages (no XML summary in Messages)
if len(result.Messages) != 1 {
t.Errorf("Messages = %d, want 1 (raw message only)", len(result.Messages))
}
// Summary should contain XML with summary content
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
if !contains(result.Summary, "<summary") {
t.Errorf("Summary missing <summary tag: %q", result.Summary)
}
if !contains(result.Summary, summary.SummaryID) {
t.Errorf("Summary missing summary ID: %q", result.Summary)
}
}
func TestAssemblerSummaryXMLEscaping(t *testing.T) {
// Summary content with special XML characters should be properly escaped
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create summary with content containing XML special characters
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: `User said: "hello" & asked about <tags>`,
TokenCount: 20,
})
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Summary field should contain XML with escaped special characters
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
// Check that special characters are escaped
if strings.Contains(result.Summary, "<tags>") {
t.Errorf("BUG: unescaped < in summary content: %q", result.Summary)
}
if strings.Contains(result.Summary, `"hello"`) {
t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary)
}
// & should be escaped as &amp;
if strings.Contains(result.Summary, " & ") {
t.Errorf("BUG: unescaped & in summary content: %q", result.Summary)
}
}
func TestAssemblerSummaryXMLWithParents(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a leaf and a condensed summary (condensed has parent)
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 20,
})
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindCondensed,
Depth: 1,
Content: "condensed content",
TokenCount: 15,
ParentIDs: []string{leaf.SummaryID},
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Summary field should contain XML with parent information
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
xmlContent := result.Summary
// Should contain <parents> section with parent ID
if !contains(xmlContent, "<parents>") {
t.Errorf("condensed summary XML missing <parents> section: %q", xmlContent)
}
if !contains(xmlContent, leaf.SummaryID) {
t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent)
}
// Should contain kind="condensed"
if !contains(xmlContent, `kind="condensed"`) {
t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent)
}
}
func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a leaf summary with specific descendant count
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 20,
DescendantCount: 8,
DescendantTokenCount: 1200,
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
xmlContent := result.Summary
// Should contain descendant_count="8"
if !contains(xmlContent, `descendant_count="8"`) {
t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent)
}
}
func TestAssemblerLeafSummaryNoParents(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Leaf summary has no parents
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 20,
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
xmlContent := result.Summary
// Leaf summary should NOT have <parents> section
if contains(xmlContent, "<parents>") {
t.Errorf("leaf summary XML should not have <parents> section: %q", xmlContent)
}
}
func TestAssemblerDepthAwarePrompt(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a condensed summary (depth >= 2) to trigger full guidance
now := time.Now().UTC()
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf summary",
TokenCount: 20,
EarliestAt: &now,
LatestAt: &now,
})
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindCondensed,
Depth: 2,
Content: "condensed summary",
TokenCount: 15,
ParentIDs: []string{leaf.SummaryID},
DescendantCount: 1,
DescendantTokenCount: 20,
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Should have a depth-aware prompt in Summary field
if result.Summary == "" {
t.Error("expected non-empty Summary when depth >= 2")
}
// SystemPromptAddition is embedded in Summary field
if !strings.Contains(result.Summary, "multi-level summarization") {
t.Error("Summary should contain system prompt addition about multi-level summarization")
}
}
func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) {
// Spec: condensed summaries use <summary_ref id="parentId" /> not <parent>parentId</parent>
now := time.Now().UTC()
s := Summary{
SummaryID: "sum_condensed1",
Kind: SummaryKindCondensed,
Depth: 1,
Content: "condensed content",
TokenCount: 50,
DescendantCount: 2,
EarliestAt: &now,
LatestAt: &now,
}
parentIDs := []string{"sum_leaf1", "sum_leaf2"}
xml := FormatSummaryXML(&s, parentIDs)
// Must use <summary_ref id="..." /> per spec
if !contains(xml, `<summary_ref id="sum_leaf1" />`) {
t.Errorf("expected <summary_ref id=\"sum_leaf1\" />, got: %s", xml)
}
if !contains(xml, `<summary_ref id="sum_leaf2" />`) {
t.Errorf("expected <summary_ref id=\"sum_leaf2\" />, got: %s", xml)
}
// Must NOT use old <parent> tag
if contains(xml, "<parent>") {
t.Errorf("should not use <parent> tag, got: %s", xml)
}
}
func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) {
// Spec: summary XML includes earliest_at and latest_at attributes
earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC)
latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC)
s := Summary{
SummaryID: "sum_leaf1",
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 30,
DescendantCount: 0,
EarliestAt: &earliest,
LatestAt: &latest,
}
xml := FormatSummaryXML(&s, nil)
if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) {
t.Errorf("missing earliest_at attribute, got: %s", xml)
}
if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) {
t.Errorf("missing latest_at attribute, got: %s", xml)
}
}
func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) {
// When EarliestAt/LatestAt are nil, attributes should be omitted
s := Summary{
SummaryID: "sum_leaf1",
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 30,
DescendantCount: 0,
}
xml := FormatSummaryXML(&s, nil)
if contains(xml, "earliest_at=") {
t.Errorf("should not have earliest_at when nil, got: %s", xml)
}
if contains(xml, "latest_at=") {
t.Errorf("should not have latest_at when nil, got: %s", xml)
}
}
+336
View File
@@ -0,0 +1,336 @@
package seahorse
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
_ "modernc.org/sqlite"
)
// newBenchStore creates a test store for benchmarks.
func newBenchStore(b *testing.B) (*Store, func()) {
b.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
b.Fatalf("open test db: %v", err)
}
if err := runSchema(db); err != nil {
db.Close()
b.Fatalf("migration: %v", err)
}
return &Store{db: db}, func() { db.Close() }
}
// --- Ingest benchmarks ---
func BenchmarkIngest_SingleMessage(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest")
convID := conv.ConversationID
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.AddMessage(ctx, convID, "user", "Test message content", 15)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkIngest_BatchMessages(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i))
convID := conv.ConversationID
for j := 0; j < 10; j++ {
added, err := s.AddMessage(ctx, convID, "user",
fmt.Sprintf("Message %d in batch", j), 10)
if err != nil {
b.Fatal(err)
}
s.AppendContextMessage(ctx, convID, added.ID)
}
}
}
// --- Assemble benchmarks ---
func BenchmarkAssemble_MessagesOnly(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs")
convID := conv.ConversationID
// Add 100 messages
for i := 0; i < 100; i++ {
m, _ := s.AddMessage(ctx, convID, "user",
fmt.Sprintf("Message content %d with some text", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
a := &Assembler{store: s}
input := AssembleInput{Budget: 50000}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := a.Assemble(ctx, convID, input)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAssemble_WithSummaries(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums")
convID := conv.ConversationID
now := time.Now().UTC()
// Add 10 leaf summaries
for i := 0; i < 10; i++ {
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("Leaf summary %d", i),
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, sum.SummaryID)
}
// Add 20 fresh messages
for i := 0; i < 20; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
a := &Assembler{store: s}
input := AssembleInput{Budget: 10000}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := a.Assemble(ctx, convID, input)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAssemble_BudgetEviction(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict")
convID := conv.ConversationID
now := time.Now().UTC()
// Add 50 leaf summaries (more than budget can hold)
for i := 0; i < 50; i++ {
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("Summary %d", i),
TokenCount: 300,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, sum.SummaryID)
}
// Add fresh tail
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
a := &Assembler{store: s}
input := AssembleInput{Budget: 5000} // Force eviction
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := a.Assemble(ctx, convID, input)
if err != nil {
b.Fatal(err)
}
}
}
// --- Search (FTS5) benchmarks ---
// benchSeedSummaries adds n summaries to a conversation for search benchmarks.
func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) {
b.Helper()
now := time.Now().UTC()
for i := 0; i < n; i++ {
sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf(contentTpl, i),
TokenCount: 200,
EarliestAt: &now,
LatestAt: &now,
})
if err != nil {
b.Fatalf("create summary: %v", err)
}
s.AppendContextSummary(context.Background(), convID, sum.SummaryID)
}
}
func BenchmarkSearchSummaries_FTS5(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts")
convID := conv.ConversationID
benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.SearchSummaries(ctx, SearchInput{
Pattern: "database",
Mode: "full_text",
ConversationID: convID,
})
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSearchSummaries_Like(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like")
convID := conv.ConversationID
benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.SearchSummaries(ctx, SearchInput{
Pattern: "config",
Mode: "like",
ConversationID: convID,
})
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSearchMessages_FTS5(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts")
convID := conv.ConversationID
// Add 500 messages
for i := 0; i < 500; i++ {
m, _ := s.AddMessage(ctx, convID, "user",
fmt.Sprintf("User message about API and database integration %d", i), 20)
s.AppendContextMessage(ctx, convID, m.ID)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.SearchMessages(ctx, SearchInput{
Pattern: "API database",
Mode: "full_text",
ConversationID: convID,
})
if err != nil {
b.Fatal(err)
}
}
}
// --- Bootstrap benchmarks ---
func BenchmarkBootstrap_Empty(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i))
convID := conv.ConversationID
_ = convID // Bootstrap with empty history
}
}
func BenchmarkBootstrap_100Messages(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
// Prepare 100 messages
msgs := make([]Message, 100)
for i := 0; i < 100; i++ {
msgs[i] = Message{
Role: "user",
Content: fmt.Sprintf("Bootstrap message %d", i),
TokenCount: 15,
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i))
convID := conv.ConversationID
for _, m := range msgs {
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
s.AppendContextMessage(ctx, convID, added.ID)
}
}
}
func BenchmarkBootstrap_500Messages(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
msgs := make([]Message, 500)
for i := 0; i < 500; i++ {
msgs[i] = Message{
Role: "user",
Content: fmt.Sprintf("Bootstrap message %d", i),
TokenCount: 15,
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i))
convID := conv.ConversationID
for _, m := range msgs {
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
s.AppendContextMessage(ctx, convID, added.ID)
}
}
}
+898
View File
@@ -0,0 +1,898 @@
package seahorse
import (
"context"
"fmt"
"sort"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// CompactInput controls compaction behavior.
type CompactInput struct {
Budget *int // Token budget override
Force bool // Force compaction even if below threshold
}
// CompactResult describes what was compacted.
type CompactResult struct {
SummariesCreated []string `json:"summariesCreated"`
TokensSaved int `json:"tokensSaved"`
LeafSummaries int `json:"leafSummaries"`
CondensedSummaries int `json:"condensedSummaries"`
}
// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow.
func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) {
tokens, err := e.store.GetContextTokenCount(ctx, convID)
if err != nil {
return false, fmt.Errorf("get token count: %w", err)
}
threshold := int(float64(contextWindow) * ContextThreshold)
return tokens >= threshold, nil
}
// Close cancels the shutdown context, stopping async goroutines.
func (e *CompactionEngine) Close() {
if e.shutdownCancel != nil {
e.shutdownCancel()
}
}
// Compact runs leaf compaction (sync) and optionally condensed compaction.
func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) {
result := &CompactResult{}
// Phase 1: leaf compaction (synchronous, every turn)
summaryID, err := e.compactLeaf(ctx, convID)
if err != nil {
return nil, fmt.Errorf("compact leaf: %w", err)
}
if summaryID != nil {
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
result.LeafSummaries++
logger.InfoCF("seahorse", "compact: leaf", map[string]any{
"conv_id": convID,
"summary_id": *summaryID,
})
}
// Phase 2: condensed compaction if over threshold
tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID)
var budget int
if input.Budget != nil {
budget = *input.Budget
if budget == 0 {
logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{
"conv_id": convID,
})
}
} else {
budget = int(float64(tokensBefore) * ContextThreshold)
}
if input.Force || (tokensBefore > budget && budget > 0) {
// Launch async condensed compaction with dedup
if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded {
go func() {
defer e.condensing.Delete(convID)
e.runCondensedLoop(e.shutdownCtx, convID)
}()
}
}
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
if tokensAfter < tokensBefore {
result.TokensSaved = tokensBefore - tokensAfter
}
return result, nil
}
// CompactUntilUnder aggressively compacts until context is under budget.
func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) {
result := &CompactResult{}
prevTokens := 0
logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget})
for iter := 0; iter < MaxCompactIterations; iter++ {
tokens, err := e.store.GetContextTokenCount(ctx, convID)
if err != nil {
return result, fmt.Errorf("get tokens: %w", err)
}
if tokens <= budget {
logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{
"conv_id": convID,
"budget": budget,
"tokens": tokens,
"leaf": result.LeafSummaries,
"condensed": result.CondensedSummaries,
})
return result, nil
}
// Try leaf first
summaryID, err := e.compactLeaf(ctx, convID, true)
if err != nil {
return result, err
}
if summaryID != nil {
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
result.LeafSummaries++
logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{
"conv_id": convID,
"summary_id": *summaryID,
})
continue
}
// Try condensed with forced fanout
condensedID, err := e.compactCondensed(ctx, convID)
if err != nil {
return result, err
}
if condensedID != nil {
result.SummariesCreated = append(result.SummariesCreated, *condensedID)
result.CondensedSummaries++
logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{
"conv_id": convID,
"summary_id": *condensedID,
})
continue
}
// No progress
newTokens, _ := e.store.GetContextTokenCount(ctx, convID)
if newTokens >= prevTokens {
logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{
"conv_id": convID,
"tokens": newTokens,
})
return result, nil
}
prevTokens = newTokens
}
// Safety cap exceeded — see MaxCompactIterations doc for rationale.
logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{
"conv_id": convID,
"budget": budget,
"iterations": MaxCompactIterations,
"tokens": prevTokens,
})
return result, nil
}
// compactLeaf compresses the oldest contiguous message chunk into a leaf summary.
// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder).
func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) {
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
// Find oldest contiguous message chunk outside fresh tail
msgCount := 0
msgTokens := 0
for _, item := range items {
if item.ItemType == "message" {
msgCount++
msgTokens += item.TokenCount
}
}
// Trigger if either message count or token threshold is met
if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens {
return nil, nil
}
// Calculate fresh tail boundary (bypass when forced)
useForce := len(force) > 0 && force[0]
tailStartIdx := len(items) - FreshTailCount
if useForce {
tailStartIdx = len(items) // allow compacting everything
}
if tailStartIdx < 0 {
tailStartIdx = 0
}
// Find oldest contiguous message chunk, accumulating up to LeafChunkTokens
var chunk []ContextItem
chunkStart := -1
chunkEnd := -1
accumTokens := 0
for i := 0; i < tailStartIdx; i++ {
if items[i].ItemType == "message" {
if chunkStart == -1 {
chunkStart = i
}
chunkEnd = i
accumTokens += items[i].TokenCount
// Stop accumulating once we reach the token budget
if accumTokens >= LeafChunkTokens {
break
}
} else {
// Non-message breaks the chunk
if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout {
break
}
chunkStart = -1
chunkEnd = -1
accumTokens = 0
}
}
if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout {
return nil, nil
}
chunk = items[chunkStart : chunkEnd+1]
// Collect messages for the chunk
var messages []Message
for _, item := range chunk {
msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID)
if innerErr != nil {
return nil, innerErr
}
messages = append(messages, *msg)
}
// Get prior summaries for context
priorSummary := ""
priorCount := 0
for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- {
if items[i].ItemType == "summary" {
sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID)
if innerErr2 == nil {
priorSummary = sum.Content + "\n" + priorSummary
priorCount++
}
}
}
// Generate summary
content, err := e.generateLeafSummary(ctx, messages, priorSummary)
if err != nil {
return nil, err
}
// Create summary in store
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
var earliestAt, latestAt *time.Time
if len(messages) > 0 {
earliestAt = &messages[0].CreatedAt
latestAt = &messages[len(messages)-1].CreatedAt
}
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: content,
TokenCount: tokenCount,
EarliestAt: earliestAt,
LatestAt: latestAt,
SourceMessageTokens: sumMessageTokens(messages),
})
if err != nil {
return nil, err
}
// Link to source messages
msgIDs := make([]int64, len(messages))
for i, m := range messages {
msgIDs[i] = m.ID
}
if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil {
return nil, err
}
// Replace context range with summary
if err := e.store.ReplaceContextRangeWithSummary(
ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID,
); err != nil {
return nil, err
}
return &summary.SummaryID, nil
}
// compactCondensed compresses multiple summaries into one higher-level summary.
func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) {
// Try ordinal-aware selection first (respects consecutive ordering)
var candidates []Summary
depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0)
if err != nil {
return nil, err
}
for _, depth := range depths {
var chunkAtDepth []Summary
var err2 error
chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth)
if err2 != nil {
continue
}
if len(chunkAtDepth) > 0 {
candidates = chunkAtDepth
break
}
}
// Fallback to depth-grouping selection
if len(candidates) == 0 {
candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false)
if err != nil {
return nil, err
}
}
if len(candidates) == 0 {
return nil, nil
}
// Generate condensed summary
content, err := e.generateCondensedSummary(ctx, candidates)
if err != nil {
return nil, err
}
// Merge metadata
maxDepth := 0
descendantCount := 0
descendantTokenCount := 0
sourceMessageTokens := 0
var earliestAt, latestAt *time.Time
parentIDs := make([]string, len(candidates))
for i, c := range candidates {
parentIDs[i] = c.SummaryID
if c.Depth > maxDepth {
maxDepth = c.Depth
}
descendantCount += c.DescendantCount + 1
descendantTokenCount += c.TokenCount + c.DescendantTokenCount
sourceMessageTokens += c.SourceMessageTokenCount
if c.EarliestAt != nil {
if earliestAt == nil || c.EarliestAt.Before(*earliestAt) {
earliestAt = c.EarliestAt
}
}
if c.LatestAt != nil {
if latestAt == nil || c.LatestAt.After(*latestAt) {
latestAt = c.LatestAt
}
}
}
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindCondensed,
Depth: maxDepth + 1,
Content: content,
TokenCount: tokenCount,
EarliestAt: earliestAt,
LatestAt: latestAt,
DescendantCount: descendantCount,
DescendantTokenCount: descendantTokenCount,
SourceMessageTokens: sourceMessageTokens,
ParentIDs: parentIDs,
})
if err != nil {
return nil, err
}
// Find the ordinal range for the candidate summaries in context
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
candidateSet := make(map[string]bool)
for _, c := range candidates {
candidateSet[c.SummaryID] = true
}
startOrd := -1
endOrd := -1
hasNonCandidate := false
for _, item := range items {
if item.ItemType == "summary" && candidateSet[item.SummaryID] {
if startOrd == -1 {
startOrd, endOrd = item.Ordinal, item.Ordinal
} else {
// Check for non-candidate items between endOrd and current ordinal
for _, it := range items {
if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal {
if it.ItemType != "summary" || !candidateSet[it.SummaryID] {
hasNonCandidate = true
break
}
}
}
if hasNonCandidate {
break
}
if item.Ordinal < startOrd {
startOrd = item.Ordinal
}
if item.Ordinal > endOrd {
endOrd = item.Ordinal
}
}
}
}
if startOrd == -1 || endOrd == -1 {
return nil, nil
}
// Collect candidate summary IDs
candidateIDs := make([]string, 0, len(candidates))
for _, c := range candidates {
candidateIDs = append(candidateIDs, c.SummaryID)
}
if hasNonCandidate {
// Use safe per-item deletion to avoid deleting non-candidate items
if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil {
return nil, err
}
} else {
// Candidates are consecutive, use efficient range deletion
if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil {
return nil, err
}
}
return &summary.SummaryID, nil
}
// selectShallowestCondensationCandidate finds the shallowest consecutive summary group.
func (e *CompactionEngine) selectShallowestCondensationCandidate(
ctx context.Context, convID int64, forced bool,
) ([]Summary, error) {
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
// Group by depth, find consecutive runs
tailStartIdx := len(items) - FreshTailCount
if tailStartIdx < 0 {
tailStartIdx = 0
}
minFanout := CondensedMinFanout
if forced {
minFanout = CondensedMinFanoutHard
}
// Track depth groups
depthGroups := make(map[int][]ContextItem)
for i := 0; i < tailStartIdx; i++ {
item := items[i]
if item.ItemType != "summary" {
continue
}
sum, err := e.store.GetSummary(ctx, item.SummaryID)
if err != nil {
continue
}
depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item)
}
// Find shallowest depth with enough candidates
// Collect all depths and sort to handle non-consecutive depths
var depths []int
for depth := range depthGroups {
depths = append(depths, depth)
}
sort.Ints(depths)
for _, depth := range depths {
group := depthGroups[depth]
if len(group) >= minFanout {
// Load summaries
var result []Summary
for _, item := range group[:minFanout] {
sum, err := e.store.GetSummary(ctx, item.SummaryID)
if err != nil {
continue
}
result = append(result, *sum)
}
return result, nil
}
}
return nil, nil
}
// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive
// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or
// token overflow. Returns contiguous chunk of summaries.
func (e *CompactionEngine) selectOldestChunkAtDepth(
ctx context.Context, convID int64, targetDepth int,
) ([]Summary, error) {
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
tailStartIdx := len(items) - FreshTailCount
if tailStartIdx < 0 {
tailStartIdx = 0
}
var chunk []Summary
accumTokens := 0
for i := 0; i < tailStartIdx; i++ {
item := items[i]
if item.ItemType != "summary" {
// Non-summary breaks the chunk
break
}
sum, err := e.store.GetSummary(ctx, item.SummaryID)
if err != nil {
break
}
if sum.Depth != targetDepth {
// Different depth breaks the chunk
break
}
if accumTokens+sum.TokenCount > LeafChunkTokens {
// Token overflow stops collection
break
}
chunk = append(chunk, *sum)
accumTokens += sum.TokenCount
}
// Min tokens check: spec line 808
// chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000
minTokens := CondensedTargetTokens // 2000
if accumTokens < minTokens {
return nil, nil
}
return chunk, nil
}
// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation.
// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation.
func (e *CompactionEngine) generateLeafSummary(
ctx context.Context,
messages []Message,
previousSummary string,
) (string, error) {
if e.complete == nil {
return truncateSummary(messages), nil
}
sourceText := formatMessagesForSummary(messages)
inputTokens := sumMessageTokens(messages)
targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35))
// Level 1: normal prompt
prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens)
content, err := e.complete(ctx, prompt, CompleteOptions{
MaxTokens: LeafTargetTokens * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content == "" {
// Retry with temperature=0
content, err = e.complete(ctx, prompt, CompleteOptions{
MaxTokens: LeafTargetTokens * 2,
Temperature: 0,
})
if err != nil {
return "", err
}
}
// Check if level 1 succeeded
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
return content, nil
}
// Level 2: aggressive prompt
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget)
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
MaxTokens: aggressiveTarget * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content == "" {
// Retry with temperature=0
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
MaxTokens: aggressiveTarget * 2,
Temperature: 0,
})
if err != nil {
return "", err
}
}
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
return content, nil
}
// Level 3: deterministic truncation
return truncateSummary(messages), nil
}
// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation.
func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) {
if e.complete == nil {
return truncateCondensedSummaries(summaries), nil
}
sourceText := formatSummariesForCondensation(summaries)
inputTokens := sumSummaryTokens(summaries)
targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35))
// Level 1: normal prompt
prompt := buildCondensedSummaryPrompt(sourceText, targetTokens)
content, err := e.complete(ctx, prompt, CompleteOptions{
MaxTokens: CondensedTargetTokens * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content == "" {
content, err = e.complete(ctx, prompt, CompleteOptions{
MaxTokens: CondensedTargetTokens * 2,
Temperature: 0,
})
if err != nil {
return "", err
}
}
if content != "" {
return content, nil
}
// Level 2: aggressive prompt
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget)
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
MaxTokens: aggressiveTarget * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content != "" {
return content, nil
}
// Level 3: deterministic fallback
return truncateCondensedSummaries(summaries), nil
}
// runCondensedLoop runs condensed compaction in a loop until:
// a) context tokens <= threshold (success), OR
// b) No candidate found (nothing to condense), OR
// c) tokensAfter >= tokensBefore (no progress this iteration), OR
// d) tokensAfter >= previousTokens (no improvement over last iteration)
func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) {
var prevTokens int
for {
select {
case <-ctx.Done():
return
default:
}
tokensBefore, err := e.store.GetContextTokenCount(ctx, convID)
if err != nil {
logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()})
return
}
condensedID, err := e.compactCondensed(ctx, convID)
if err != nil {
logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()})
return
}
if condensedID == nil {
// No candidate found
logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID})
return
}
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
if tokensAfter >= tokensBefore {
// No progress this iteration
logger.DebugCF(
"seahorse",
"condensed: no progress",
map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter},
)
return
}
if tokensAfter >= prevTokens && prevTokens > 0 {
// No improvement over last iteration
logger.DebugCF(
"seahorse",
"condensed: no improvement",
map[string]any{"conv_id": convID, "tokens": tokensAfter},
)
return
}
prevTokens = tokensAfter
}
}
// --- Helper functions ---
func formatMessagesForSummary(messages []Message) string {
var result string
for _, m := range messages {
ts := m.CreatedAt.Format("2006-01-02 15:04 MST")
content := m.Content
if content == "" && len(m.Parts) > 0 {
content = partsToReadableContent(m.Parts)
}
result += fmt.Sprintf("[%s]\n%s\n\n", ts, content)
}
return result
}
func formatSummariesForCondensation(summaries []Summary) string {
var result string
for _, s := range summaries {
earliest := ""
if s.EarliestAt != nil {
earliest = s.EarliestAt.Format("2006-01-02")
}
latest := ""
if s.LatestAt != nil {
latest = s.LatestAt.Format("2006-01-02")
}
result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content)
}
return result
}
func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
prev := "(none)"
if previousSummary != "" {
prev = previousSummary
}
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
Treat this as incremental memory compaction input, not a full-conversation summary.
Normal summary policy:
- Preserve key decisions, rationale, constraints, and active tasks.
- Keep essential technical details needed to continue work safely.
- Remove obvious repetition and conversational filler.
Output requirements:
- Plain text only.
- No preamble, headings, or markdown formatting.
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
- If no file operations appear, include exactly: "Files: none".
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
- Target length: about %d tokens or less.
<previous_context>
%s
</previous_context>
<conversation_segment>
%s
</conversation_segment>`, targetTokens, prev, sourceText)
}
func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string {
return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary.
Preserve all important decisions, constraints, and outcomes.
Merge overlapping topics. Keep technical details intact.
Output requirements:
- Plain text only.
- No preamble, headings, or markdown formatting.
- End with exactly: "Expand for details about: <comma-separated list>".
- Target length: about %d tokens or less.
<summaries>
%s
</summaries>`, targetTokens, sourceText)
}
func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
prev := "(none)"
if previousSummary != "" {
prev = previousSummary
}
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
Aggressive summary policy:
- Keep only durable facts and current task state.
- Remove examples, repetition, and low-value narrative details.
- Preserve explicit TODOs, blockers, decisions, and constraints.
Output requirements:
- Plain text only.
- No preamble, headings, or markdown formatting.
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
- If no file operations appear, include exactly: "Files: none".
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
- Target length: about %d tokens or less.
<previous_context>
%s
</previous_context>
<conversation_segment>
%s
</conversation_segment>`, targetTokens, prev, sourceText)
}
func truncateSummary(messages []Message) string {
content := ""
for _, m := range messages {
c := m.Content
if c == "" && len(m.Parts) > 0 {
c = partsToReadableContent(m.Parts)
}
content += c + "\n"
}
if len(content) > 2048 {
content = content[:2048]
}
content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages))
return content
}
func truncateCondensedSummaries(summaries []Summary) string {
content := ""
for _, s := range summaries {
content += s.Content + "\n"
}
if len(content) > 2048 {
content = content[:2048]
}
content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries))
return content
}
func sumMessageTokens(messages []Message) int {
total := 0
for _, m := range messages {
total += m.TokenCount
}
return total
}
func sumSummaryTokens(summaries []Summary) int {
total := 0
for _, s := range summaries {
total += s.TokenCount
}
return total
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
+974
View File
@@ -0,0 +1,974 @@
package seahorse
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// --- Test Helpers ---
// waitForCondensed blocks until the async condensed goroutine for convID finishes.
// Returns false if timeout is reached.
func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if _, exists := ce.condensing.Load(convID); !exists {
return true
}
time.Sleep(50 * time.Millisecond)
}
return false
}
// --- Compaction Tests ---
func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) {
t.Helper()
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
s := &Store{db: db}
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:compact")
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
ce := &CompactionEngine{
store: s,
config: Config{},
complete: mockCompleteFn,
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
}
convID := conv.ConversationID
// Ensure async goroutines are stopped before database is closed.
// Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close().
t.Cleanup(func() {
shutdownCancel()
// Wait for async condensed goroutine to finish (poll condensing map)
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if _, exists := ce.condensing.Load(convID); !exists {
break
}
time.Sleep(50 * time.Millisecond)
}
})
return ce, s, conv.ConversationID
}
// newTestCompactionEngineWithStore creates a CompactionEngine with existing store.
// Note: Caller is responsible for calling shutdownCancel when test ends.
func newTestCompactionEngineWithStore(
s *Store, complete CompleteFn,
) (ce *CompactionEngine, shutdownCancel context.CancelFunc) {
shutdownCtx, cancel := context.WithCancel(context.Background())
return &CompactionEngine{
store: s,
config: Config{},
complete: complete,
shutdownCtx: shutdownCtx,
shutdownCancel: cancel,
}, cancel
}
// mockCompleteFn returns a simple summary for testing
var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "Mock summary of the conversation segment.", nil
}
func TestNeedsCompaction(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Empty context — no compaction needed
needed, err := ce.NeedsCompaction(ctx, convID, 10000)
if err != nil {
t.Fatalf("NeedsCompaction: %v", err)
}
if needed {
t.Error("expected no compaction for empty context")
}
// Add messages to context, total tokens = 8000
for i := 0; i < 8; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction
needed, err = ce.NeedsCompaction(ctx, convID, 10000)
if err != nil {
t.Fatalf("NeedsCompaction: %v", err)
}
if !needed {
t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)")
}
// Below threshold: 5000 / 10000 → no compaction
s.UpsertContextItems(ctx, convID, nil) // clear
for i := 0; i < 5; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "test", 1000)
s.AppendContextMessage(ctx, convID, m.ID)
}
needed, _ = ce.NeedsCompaction(ctx, convID, 10000)
if needed {
t.Error("expected no compaction at 5000/10000 tokens")
}
}
func TestCompactLeaf(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create enough messages to trigger leaf compaction:
// Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous
for i := 0; i < 40; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Compact
result, err := ce.Compact(ctx, convID, CompactInput{})
if err != nil {
t.Fatalf("Compact: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
// Should have created at least one leaf summary
if result.LeafSummaries == 0 {
t.Error("expected at least 1 leaf summary")
}
// Context should now contain a summary item
items, _ := s.GetContextItems(ctx, convID)
foundSummary := false
for _, item := range items {
if item.ItemType == "summary" {
foundSummary = true
break
}
}
if !foundSummary {
t.Error("expected a summary in context_items after leaf compaction")
}
// Some messages should have been replaced
if len(result.SummariesCreated) == 0 {
t.Error("expected at least 1 summary created")
}
}
func TestCompactLeafNoCandidate(t *testing.T) {
ce, _, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Too few messages to trigger leaf compaction
m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10)
ce.store.AppendContextMessage(ctx, convID, m.ID)
result, err := ce.Compact(ctx, convID, CompactInput{})
if err != nil {
t.Fatalf("Compact: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result even with no candidate")
}
if result.LeafSummaries != 0 {
t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries)
}
}
func TestCompactCondensed(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create enough leaf summaries and fresh messages to enable condensation
leafIDs := make([]string, CondensedMinFanout)
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
summary, err := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf summary content " + time.Now().String(),
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
if err != nil {
t.Fatalf("CreateSummary %d: %v", i, err)
}
leafIDs[i] = summary.SummaryID
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add enough fresh messages to have a fresh tail (>= FreshTailCount)
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Compact with force to trigger condensation
_, err := ce.Compact(ctx, convID, CompactInput{Force: true})
if err != nil {
t.Fatalf("Compact: %v", err)
}
// Wait for async condensed goroutine to complete
if !waitForCondensed(ce, convID, 2*time.Second) {
t.Fatal("timeout waiting for condensed compaction")
}
// Should have created a condensed summary in the DB
summaries, _ := s.GetSummariesByConversation(ctx, convID)
foundCondensed := false
for _, sum := range summaries {
if sum.Kind == SummaryKindCondensed {
foundCondensed = true
break
}
}
if !foundCondensed {
t.Error("expected at least 1 condensed summary")
}
}
func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) {
// Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed
// from context_items between candidate selection and ordinal range scan.
// Use a slow CompleteFn with barrier sync to control timing.
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race")
convID := conv.ConversationID
// Create leaf summaries with enough tokens for condensation
var leafIDs []string
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
sum, err := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf summary %d", i),
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
if err != nil {
t.Fatalf("CreateSummary: %v", err)
}
leafIDs = append(leafIDs, sum.SummaryID)
s.AppendContextSummary(ctx, convID, sum.SummaryID)
}
// Add fresh tail so leaf summaries are in evictable range
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Barrier: CompleteFn waits until test removes context_items, then returns
var barrier1, barrier2 sync.WaitGroup
barrier1.Add(1) // CompleteFn signals when called
barrier2.Add(1) // test signals when context_items removed
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
barrier1.Done() // signal: LLM called, candidates selected
barrier2.Wait() // wait: test removes context_items
return "Condensed summary.", nil
}
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
t.Cleanup(func() {
cancel()
time.Sleep(100 * time.Millisecond)
})
// Run compactCondensed in background
type compactResult struct {
summaryID *string
err error
}
resultCh := make(chan compactResult, 1)
go func() {
sid, err := ce.compactCondensed(context.Background(), convID)
resultCh <- compactResult{summaryID: sid, err: err}
}()
// Wait for CompleteFn to be called (candidates selected)
barrier1.Wait()
// Remove leaf summaries from context_items (simulating concurrent replacement)
items, _ := s.GetContextItems(ctx, convID)
var preserved []ContextItem
for _, item := range items {
isLeaf := false
for _, lid := range leafIDs {
if item.SummaryID == lid {
isLeaf = true
break
}
}
if !isLeaf {
preserved = append(preserved, item)
}
}
s.UpsertContextItems(ctx, convID, preserved)
// Let CompleteFn return
barrier2.Done()
// Get result
res := <-resultCh
if res.err != nil {
t.Fatalf("compactCondensed: %v", res.err)
}
// With the bug: returns non-nil summaryID even though context_items has no matching ordinals
// The fix: should return nil when startOrd == -1
if res.summaryID != nil {
t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID)
// Verify the orphan exists in DB
summary, _ := s.GetSummary(context.Background(), *res.summaryID)
if summary != nil && summary.Kind == SummaryKindCondensed {
// Check it's NOT in context_items (orphan)
items2, _ := s.GetContextItems(context.Background(), convID)
found := false
for _, item := range items2 {
if item.SummaryID == *res.summaryID {
found = true
break
}
}
if !found {
t.Error("condensed summary exists in DB but not in context_items — orphan confirmed")
}
}
}
}
func TestCompactUntilUnder(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create many leaf summaries to ensure we can condense
for i := 0; i < 8; i++ {
now := time.Now().UTC()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf summary for condensation test",
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Force compact until under budget
result, err := ce.CompactUntilUnder(ctx, convID, 2000)
if err != nil {
t.Fatalf("CompactUntilUnder: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
}
func TestSelectShallowestCondensationCandidate(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create enough leaf summaries + fresh messages for candidates
for i := 0; i < LeafMinFanout; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf",
TokenCount: 100,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail messages so summaries are in evictable range
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.AppendContextMessage(ctx, convID, m.ID)
}
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
if err != nil {
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
}
// Should find leaf summaries at depth 0
if len(candidates) < CondensedMinFanout {
t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout)
}
}
func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) {
ce, _, convID := newTestCompactionEngine(t)
ctx := context.Background()
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
if err != nil {
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
}
if len(candidates) != 0 {
t.Errorf("candidates = %d, want 0 for empty context", len(candidates))
}
}
func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) {
// Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth
// rather than just grouping by depth without regard to order
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create interleaved summaries at depth 0 with a message in between:
// sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300)
for i := 0; i < LeafMinFanout+2; i++ {
now := time.Now().UTC()
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf summary %d", i),
TokenCount: 100,
EarliestAt: &now,
LatestAt: &now,
})
}
// Insert a message between first two summaries to break contiguity
// for selectShallowestCondensationCandidate but would still find all 3
// but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3)
msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5)
s.AppendContextMessage(ctx, convID, msg.ID)
// Run compactCondensed
result, err := ce.compactCondensed(ctx, convID)
if err != nil {
t.Fatalf("compactCondensed: %v", err)
}
// The result should have merged the two summaries at the start
// (skipping the message in between), This proves ordinal-aware selection works.
_ = result // verify summary was created
if result != nil {
summaries, _ := s.GetSummariesByConversation(ctx, convID)
found := false
for _, sum := range summaries {
if sum.Kind == SummaryKindCondensed {
found = true
break
}
}
if !found {
t.Error("expected condensed summary to be created via ordinal-aware selection")
}
}
}
func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens)
for i := 0; i < 5; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf summary %d", i),
TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000)
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.AppendContextMessage(ctx, convID, m.ID)
}
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if err != nil {
t.Fatalf("selectOldestChunkAtDepth: %v", err)
}
if len(chunk) < 2 {
t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk))
}
for _, s := range chunk {
if s.Depth != 0 {
t.Errorf("got depth %d, want 0", s.Depth)
}
}
}
func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create 3 summaries, then a message, then 3 more summaries
for i := 0; i < 3; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf %d", i),
TokenCount: 100,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
msg, _ := s.AddMessage(ctx, convID, "user", "break", 10)
s.AppendContextMessage(ctx, convID, msg.ID)
for i := 0; i < 3; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf-after %d", i),
TokenCount: 100,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.AppendContextMessage(ctx, convID, m.ID)
}
chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if len(chunk) > 3 {
t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk))
}
}
func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create summaries with very low token counts (total < 2000)
for i := 0; i < 5; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("tiny summary %d", i),
TokenCount: 50, // very small
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail to protect from compaction
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Should return nil because total tokens (250) < 2000 minimum
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if err != nil {
t.Fatalf("selectOldestChunkAtDepth: %v", err)
}
if len(chunk) > 0 {
t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk))
}
}
func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create summaries with enough tokens (total >= 2000)
for i := 0; i < 5; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf(
"substantial summary with enough content to meet minimum token threshold for condensation candidate %d",
i,
),
TokenCount: 500, // 5 × 500 = 2500 >= 2000
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Should return chunk because total tokens (2500) >= 2000
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if err != nil {
t.Fatalf("selectOldestChunkAtDepth: %v", err)
}
if len(chunk) == 0 {
t.Error("expected non-empty chunk when tokens >= 2000")
}
}
func TestGenerateLeafSummary(t *testing.T) {
ce, _, _ := newTestCompactionEngine(t)
ctx := context.Background()
msgs := []Message{
{Role: "user", Content: "hello world", TokenCount: 5},
{Role: "assistant", Content: "hi there", TokenCount: 5},
}
content, err := ce.generateLeafSummary(ctx, msgs, "")
if err != nil {
t.Fatalf("generateLeafSummary: %v", err)
}
if content == "" {
t.Error("expected non-empty summary content")
}
}
func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) {
// Level 1 returns summary that's too large (tokens >= input), should escalate to level 2
var calls []string
escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
if contains(prompt, "Aggressive summary policy") {
calls = append(calls, "aggressive")
return "Short aggressive summary.", nil
}
calls = append(calls, "normal")
// Return a very long summary to trigger escalation
longContent := make([]byte, 5000)
for i := range longContent {
longContent[i] = 'x'
}
return string(longContent), nil
}
s := openTestStore(t)
ce, _ := newTestCompactionEngineWithStore(s, escalateComplete)
msgs := []Message{
{Role: "user", Content: "hello world", TokenCount: 10},
{Role: "assistant", Content: "response", TokenCount: 10},
}
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
if err != nil {
t.Fatalf("generateLeafSummary: %v", err)
}
if content == "" {
t.Error("expected non-empty summary content")
}
// Should have called both normal and aggressive
foundNormal := false
foundAggressive := false
for _, c := range calls {
if c == "normal" {
foundNormal = true
}
if c == "aggressive" {
foundAggressive = true
}
}
if !foundNormal {
t.Error("expected normal LLM call")
}
if !foundAggressive {
t.Error("expected aggressive LLM call (level 2 escalation)")
}
}
func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) {
// Both normal and aggressive return empty, should escalate to level 3 truncation
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "", nil
}
s := openTestStore(t)
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
msgs := []Message{
{Role: "user", Content: "hello world from test", TokenCount: 10},
{Role: "assistant", Content: "response text here", TokenCount: 10},
}
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
if err != nil {
t.Fatalf("generateLeafSummary: %v", err)
}
// Level 3 truncation should have produced something
if content == "" {
t.Error("expected non-empty content from level 3 truncation fallback")
}
if !contains(content, "Truncated from") {
t.Errorf("expected truncation marker in content: %q", content)
}
}
func TestGenerateCondensedSummary(t *testing.T) {
ce, _, _ := newTestCompactionEngine(t)
ctx := context.Background()
summaries := []Summary{
{SummaryID: "sum_a", Content: "first summary", TokenCount: 100},
{SummaryID: "sum_b", Content: "second summary", TokenCount: 100},
}
content, err := ce.generateCondensedSummary(ctx, summaries)
if err != nil {
t.Fatalf("generateCondensedSummary: %v", err)
}
if content == "" {
t.Error("expected non-empty condensed summary content")
}
}
func TestGenerateCondensedSummaryEscalation(t *testing.T) {
// When LLM returns empty, should fall back to deterministic concatenation
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "", nil
}
s := openTestStore(t)
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
summaries := []Summary{
{SummaryID: "sum_a", Content: "first summary text", TokenCount: 50},
{SummaryID: "sum_b", Content: "second summary text", TokenCount: 50},
}
content, err := ce.generateCondensedSummary(context.Background(), summaries)
if err != nil {
t.Fatalf("generateCondensedSummary: %v", err)
}
// Should fall back to concatenation
if content == "" {
t.Error("expected non-empty content from fallback")
}
}
// --- Async Condensed Compaction (Phase 2) ---
func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) {
// Use a slow CompleteFn to verify Compact returns before condensed finishes
var callCount int32
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
atomic.AddInt32(&callCount, 1)
time.Sleep(500 * time.Millisecond) // simulate slow LLM
return "Slow condensed summary.", nil
}
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:async")
convID := conv.ConversationID
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
t.Cleanup(func() {
cancel()
time.Sleep(100 * time.Millisecond)
})
// Create enough leaf summaries for condensation + fresh tail
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf for async test",
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Compact with force — should return quickly, condensed runs async
start := time.Now()
result, err := ce.Compact(ctx, convID, CompactInput{Force: true})
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Compact: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
// Should return well before the 500ms LLM call
if elapsed > 200*time.Millisecond {
t.Errorf("Compact took %v, should return before async condensed finishes", elapsed)
}
// Wait for async to complete
time.Sleep(800 * time.Millisecond)
// Verify condensed summary was created by background goroutine
summaries, _ := s.GetSummariesByConversation(ctx, convID)
foundCondensed := false
for _, sum := range summaries {
if sum.Kind == SummaryKindCondensed {
foundCondensed = true
break
}
}
if !foundCondensed {
t.Error("expected at least one condensed summary from async Phase 2")
}
}
func TestCompactAsyncDedup(t *testing.T) {
var callCount int32
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
atomic.AddInt32(&callCount, 1)
time.Sleep(300 * time.Millisecond)
return "Slow condensed summary.", nil
}
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:dedup")
convID := conv.ConversationID
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
t.Cleanup(func() {
cancel()
waitForCondensed(ce, convID, 2*time.Second)
})
// Create conditions for condensed compaction
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf for dedup",
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Call Compact twice rapidly
ce.Compact(ctx, convID, CompactInput{Force: true})
ce.Compact(ctx, convID, CompactInput{Force: true})
// Wait for async to finish
time.Sleep(600 * time.Millisecond)
// LLM should only be called once for condensed (dedup)
// callCount may be 0 if no leaf was created (only condensed in goroutine)
// The key is that we don't get 2+ condensed calls
if atomic.LoadInt32(&callCount) > 1 {
t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount)
}
}
func TestCompactLeafForceBypassesFreshTail(t *testing.T) {
// Spec: compactLeaf with force=true should bypass FreshTailCount protection
// so CompactUntilUnder can compress messages inside the fresh tail
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create exactly FreshTailCount+4 messages (36 total)
// Without force: all messages are in fresh tail → no candidate
// With force: should compact the oldest messages
total := FreshTailCount + 4
for i := 0; i < total; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Without force: should return nil (all in fresh tail)
summaryID, err := ce.compactLeaf(ctx, convID)
if err != nil {
t.Fatalf("compactLeaf no-force: %v", err)
}
if summaryID != nil {
t.Error("expected nil without force (all messages in fresh tail)")
}
// With force: should compact despite fresh tail protection
summaryID, err = ce.compactLeaf(ctx, convID, true)
if err != nil {
t.Fatalf("compactLeaf force: %v", err)
}
if summaryID == nil {
t.Error("expected summary with force=true (bypasses fresh tail)")
}
}
func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) {
// Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping
// It should NOT take the entire contiguous chunk regardless of token count
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create messages totaling far more than LeafChunkTokens (20000)
// Each message is ~500 tokens, create 80 messages = 40000 tokens
for i := 0; i < 80; i++ {
m, _ := s.AddMessage(
ctx,
convID,
"user",
fmt.Sprintf(
"message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn",
i,
),
500,
)
s.AppendContextMessage(ctx, convID, m.ID)
}
summaryID, err := ce.compactLeaf(ctx, convID)
if err != nil {
t.Fatalf("compactLeaf: %v", err)
}
if summaryID == nil {
t.Fatal("expected a summary to be created")
}
// The source messages that were compacted should total roughly LeafChunkTokens (20000),
// not the entire 40000 tokens worth of messages
summary, _ := s.GetSummary(ctx, *summaryID)
if summary == nil {
t.Fatal("summary not found")
}
// Source message tokens should be roughly <= LeafChunkTokens (20000)
// Spec says: "Stop when accumulated tokens >= LeafChunkTokens"
if summary.SourceMessageTokenCount > LeafChunkTokens {
t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)",
summary.SourceMessageTokenCount, LeafChunkTokens)
}
}
+30
View File
@@ -0,0 +1,30 @@
package seahorse
// Short-term memory configuration constants — all are experience-based defaults.
const (
// OrdinalStep is the gap between ordinals in context_items.
// Insert at midpoint; resequence only when precision exhausted.
OrdinalStep = 100
// ContextThreshold is the compaction trigger for the context window.
ContextThreshold float64 = 0.75 // Compact at 75% of context window
FreshTailCount int = 32 // Recent messages protected from compaction
// LeafMinFanout is the fanout parameter.
LeafMinFanout int = 8 // Min messages per leaf summary
CondensedMinFanout int = 4 // Min summaries per condensed
CondensedMinFanoutHard int = 2 // Min for forced compaction
// LeafChunkTokens is the token target.
LeafChunkTokens int = 20000 // Max tokens per leaf chunk
LeafTargetTokens int = 1200 // Target tokens for leaf summaries
CondensedTargetTokens int = 2000 // Target tokens for condensed summaries
MaxExpandTokens int = 4000 // Token cap for expansion queries
// MaxCompactIterations caps CompactUntilUnder to prevent infinite loops.
// Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction.
// With a 200k token context window and 75% threshold, ~20 iterations is enough
// for any realistic scenario. If exceeded, the issue is logged as a warning.
MaxCompactIterations int = 20
)
+568
View File
@@ -0,0 +1,568 @@
package seahorse
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
_ "modernc.org/sqlite"
"github.com/sipeed/picoclaw/pkg/logger"
)
// Config holds engine configuration.
type Config struct {
DBPath string `json:"dbPath"`
IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"`
StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"`
}
// CompleteFn is the LLM completion function type.
type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error)
// CompleteOptions holds LLM completion parameters.
type CompleteOptions struct {
Model string
MaxTokens int
Temperature float64
}
// IngestResult is the result of message ingestion.
type IngestResult struct {
MessageCount int `json:"messageCount"`
TokenCount int `json:"tokenCount"`
}
// AssembleInput controls context assembly.
type AssembleInput struct {
Budget int `json:"budget"`
Query string `json:"query,omitempty"`
}
// AssembleResult contains assembled context.
type AssembleResult struct {
Messages []Message `json:"messages"`
Summary string `json:"summary"` // formatted XML summaries + system prompt addition
}
const numSessionShards = 256
// Engine is the main short-term memory engine.
type Engine struct {
store *Store
compaction *CompactionEngine
compactionMu sync.Mutex
assembler *Assembler
assemblerMu sync.Mutex
retrieval *RetrievalEngine
config Config
complete CompleteFn
ignorePatterns []*regexp.Regexp
statelessPatterns []*regexp.Regexp
sessionShards [numSessionShards]struct {
mu sync.Mutex
}
}
// CompactionEngine handles LLM-based summarization (defined in short_compaction.go).
type CompactionEngine struct {
store *Store
config Config
complete CompleteFn
condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
// Assembler handles budget-aware context assembly (defined in short_assembler.go).
type Assembler struct {
store *Store
config Config
}
// RetrievalEngine handles search and expansion (defined in short_retrieval.go).
type RetrievalEngine struct {
store *Store
config Config
}
// Store returns the underlying store for direct access.
func (r *RetrievalEngine) Store() *Store {
return r.store
}
// NewEngine creates a new short-term memory engine.
func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) {
dir := filepath.Dir(config.DBPath)
if dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create db directory: %w", err)
}
}
db, err := sql.Open("sqlite", config.DBPath)
if err != nil {
return nil, fmt.Errorf("open db: %w", err)
}
// Configure SQLite for concurrent access
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
db.Close()
return nil, fmt.Errorf("enable WAL: %w", err)
}
if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil {
db.Close()
return nil, fmt.Errorf("set busy_timeout: %w", err)
}
if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
db.Close()
return nil, fmt.Errorf("set synchronous: %w", err)
}
if err := runSchema(db); err != nil {
db.Close()
return nil, fmt.Errorf("migrations: %w", err)
}
store := &Store{db: db}
// Prepend hardcoded ignore patterns (spec lines 1326-1328)
ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns))
ignorePatterns = append(ignorePatterns, "heartbeat")
ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...)
retrieval := &RetrievalEngine{store: store, config: config}
return &Engine{
store: store,
compaction: nil,
assembler: nil,
retrieval: retrieval,
config: config,
complete: completeFn,
ignorePatterns: compileSessionPatterns(ignorePatterns),
statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns),
}, nil
}
// compileSessionPattern converts a glob pattern to a compiled regex.
// Pattern rules:
// - * matches any sequence of non-colon characters ([^:]*)
// - ** matches any sequence of characters including colons (.*)
// - All other characters are treated literally
// - Pattern is anchored (^...$)
func compileSessionPattern(pattern string) *regexp.Regexp {
var b strings.Builder
b.WriteByte('^')
i := 0
for i < len(pattern) {
if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' {
b.WriteString(".*")
i += 2
continue
}
if pattern[i] == '*' {
b.WriteString("[^:]*")
i++
continue
}
b.WriteString(regexp.QuoteMeta(string(pattern[i])))
i++
}
b.WriteByte('$')
return regexp.MustCompile(b.String())
}
// compileSessionPatterns compiles multiple glob patterns into regex patterns.
func compileSessionPatterns(patterns []string) []*regexp.Regexp {
result := make([]*regexp.Regexp, 0, len(patterns))
for _, p := range patterns {
if p == "" {
continue
}
result = append(result, compileSessionPattern(p))
}
return result
}
// shouldIgnoreSession returns true if the session key matches any ignore pattern.
func (e *Engine) shouldIgnoreSession(sessionKey string) bool {
for _, p := range e.ignorePatterns {
if p.MatchString(sessionKey) {
return true
}
}
return false
}
// isStatelessSession returns true if the session key matches any stateless pattern.
func (e *Engine) isStatelessSession(sessionKey string) bool {
for _, p := range e.statelessPatterns {
if p.MatchString(sessionKey) {
return true
}
}
return false
}
// fnv32 computes FNV-1a 32-bit hash for session key sharding.
func fnv32(key string) uint32 {
h := uint32(2166136261)
for _, c := range key {
h ^= uint32(c)
h *= 16777619
}
return h
}
// getSessionMutex returns the sharded mutex for a session key.
func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex {
h := fnv32(sessionKey)
shard := h % numSessionShards
return &e.sessionShards[shard].mu
}
// Ingest adds messages to a conversation identified by sessionKey.
func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
if e.shouldIgnoreSession(sessionKey) {
return nil, nil
}
if e.isStatelessSession(sessionKey) {
return nil, nil
}
mu := e.getSessionMutex(sessionKey)
mu.Lock()
defer mu.Unlock()
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
var totalTokens int
var msgIDs []int64
for _, msg := range messages {
var added *Message
var err error
if len(msg.Parts) > 0 {
added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount)
} else {
added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount)
}
if err != nil {
return nil, fmt.Errorf("add message: %w", err)
}
totalTokens += msg.TokenCount
msgIDs = append(msgIDs, added.ID)
}
// Append to context_items using actual inserted IDs
if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil {
return nil, fmt.Errorf("append context: %w", err)
}
logger.InfoCF("seahorse", "ingest", map[string]any{
"conv_id": conv.ConversationID,
"messages": len(messages),
"tokens": totalTokens,
})
return &IngestResult{
MessageCount: len(messages),
TokenCount: totalTokens,
}, nil
}
// Close releases resources.
func (e *Engine) Close() error {
// Signal compaction goroutines to stop
if e.compaction != nil {
e.compaction.Close()
}
if e.store != nil && e.store.db != nil {
return e.store.db.Close()
}
return nil
}
// GetRetrieval returns the retrieval engine for tool implementations.
func (e *Engine) GetRetrieval() *RetrievalEngine {
return e.retrieval
}
// Assemble builds budget-constrained context for a session.
func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) {
if e.shouldIgnoreSession(sessionKey) {
return nil, nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
e.initAssemblerOnce()
return e.assembler.Assemble(ctx, conv.ConversationID, input)
}
// Compact compresses conversation history for a session.
func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) {
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
return &CompactResult{}, nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
e.initCompactionOnce()
return e.compaction.Compact(ctx, conv.ConversationID, input)
}
// CompactUntilUnder aggressively compacts until context is under budget.
// Used for emergency compaction after LLM overflow (retry reason).
func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) {
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
return &CompactResult{}, nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
e.initCompactionOnce()
return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget)
}
// initCompactionOnce lazily initializes the compaction engine.
func (e *Engine) initCompactionOnce() {
if e.compaction == nil {
e.compactionMu.Lock()
defer e.compactionMu.Unlock()
if e.compaction == nil {
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
e.compaction = &CompactionEngine{
store: e.store,
config: e.config,
complete: e.complete,
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
}
}
}
}
// initAssemblerOnce lazily initializes the assembler.
func (e *Engine) initAssemblerOnce() {
if e.assembler == nil {
e.assemblerMu.Lock()
defer e.assemblerMu.Unlock()
if e.assembler == nil {
e.assembler = &Assembler{store: e.store, config: e.config}
}
}
}
// IngestMessages is an alias for Ingest.
func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
return e.Ingest(ctx, sessionKey, messages)
}
// Bootstrap reconciles a session's messages with the database.
// Called once at startup for each known session.
// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta.
// Simple approach: find longest matching prefix and append delta.
// If any mismatch is detected, clear and rebuild.
func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error {
if e.shouldIgnoreSession(sessionKey) {
return nil
}
if e.isStatelessSession(sessionKey) {
return nil
}
if len(messages) == 0 {
return nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return fmt.Errorf("bootstrap: get conversation: %w", err)
}
// Get messages already in DB
dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0)
if err != nil {
return fmt.Errorf("bootstrap: get messages: %w", err)
}
// Fast path: DB has same count and exact match → no-op
if len(dbMsgs) == len(messages) {
matched := true
for i := 0; i < len(messages); i++ {
if !messageMatches(dbMsgs[i], messages[i]) {
matched = false
break
}
}
if matched {
return nil // DB is up to date
}
}
// Find longest matching prefix from the start
anchor := -1
compareLen := len(dbMsgs)
if compareLen > len(messages) {
compareLen = len(messages)
}
for i := 0; i < compareLen; i++ {
if messageMatches(dbMsgs[i], messages[i]) {
anchor = i
} else {
// Mismatch detected - log details and rebuild
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
"conv_id": conv.ConversationID,
"index": i,
"db_role": dbMsgs[i].Role,
"db_content": truncate(dbMsgs[i].Content, 50),
"db_parts": len(dbMsgs[i].Parts),
"msg_role": messages[i].Role,
"msg_content": truncate(messages[i].Content, 50),
"msg_parts": len(messages[i].Parts),
})
break
}
}
// If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest
// Note: anchor can be -1 if first message didn't match (history completely changed)
if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 {
anchorID := dbMsgs[anchor].ID
logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{
"conv_id": conv.ConversationID,
"db_count": len(dbMsgs),
"anchor": anchor,
"anchor_id": anchorID,
"msg_count": len(messages),
"delta_start": anchor + 1,
})
// Delete messages after anchor (also clears context_items)
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil {
return fmt.Errorf("bootstrap: delete messages: %w", err)
}
// Re-ingest from anchor+1 to end
delta := messages[anchor+1:]
if len(delta) > 0 {
_, err := e.Ingest(ctx, sessionKey, delta)
if err != nil {
return fmt.Errorf("bootstrap: re-ingest: %w", err)
}
}
return nil
}
// Normal case: append delta after anchor
if anchor >= 0 && anchor < len(messages)-1 {
delta := messages[anchor+1:]
if len(delta) > 0 {
_, err := e.Ingest(ctx, sessionKey, delta)
if err != nil {
return fmt.Errorf("bootstrap: ingest delta: %w", err)
}
}
} else if anchor == -1 && len(dbMsgs) > 0 {
// First message changed (history completely different) - rebuild from scratch
logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{
"conv_id": conv.ConversationID,
"db_count": len(dbMsgs),
"msg_count": len(messages),
})
// Delete all existing messages
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil {
return fmt.Errorf("bootstrap: delete all messages: %w", err)
}
// Re-ingest everything
if len(messages) > 0 {
_, err := e.Ingest(ctx, sessionKey, messages)
if err != nil {
return fmt.Errorf("bootstrap: re-ingest all: %w", err)
}
}
} else if anchor == -1 && len(dbMsgs) == 0 {
// DB is empty, ingest everything
_, err := e.Ingest(ctx, sessionKey, messages)
if err != nil {
return fmt.Errorf("bootstrap: ingest all: %w", err)
}
}
return nil
}
// truncate shortens a string for logging.
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
// messageMatches compares two messages using (role, content) or (role, parts).
// TokenCount is NOT compared because it may be re-estimated differently
// during bootstrap (e.g., via tokenizer.EstimateMessageTokens).
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
// since AddMessageWithParts stores empty Content in DB.
func messageMatches(a, b Message) bool {
if a.Role != b.Role {
return false
}
// If either message has Parts, compare Parts
if len(a.Parts) > 0 || len(b.Parts) > 0 {
return partsMatch(a.Parts, b.Parts)
}
// Simple text messages: compare Content
return a.Content == b.Content
}
// partsMatch compares two slices of MessagePart for equality.
func partsMatch(a, b []MessagePart) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Type != b[i].Type {
return false
}
switch a[i].Type {
case "text":
if a[i].Text != b[i].Text {
return false
}
case "tool_use":
if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID {
return false
}
case "tool_result":
if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text {
return false
}
case "media":
if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType {
return false
}
}
}
return true
}
File diff suppressed because it is too large Load Diff
+212
View File
@@ -0,0 +1,212 @@
package seahorse
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m".
// Returns the duration and nil error, or zero and error if invalid.
func ParseLastDuration(s string) (time.Duration, error) {
if s == "" {
return 0, fmt.Errorf("empty duration")
}
re := regexp.MustCompile(`^(\d+)([hdwm])$`)
matches := re.FindStringSubmatch(s)
if matches == nil {
return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s)
}
value, _ := strconv.Atoi(matches[1])
unit := matches[2]
switch unit {
case "h":
return time.Duration(value) * time.Hour, nil
case "d":
return time.Duration(value) * 24 * time.Hour, nil
case "w":
return time.Duration(value) * 7 * 24 * time.Hour, nil
case "m":
return time.Duration(value) * 30 * 24 * time.Hour, nil
default:
return 0, fmt.Errorf("unknown unit: %q", unit)
}
}
// GrepInput controls search across summaries and messages.
type GrepInput struct {
Pattern string `json:"pattern"`
Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message"
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
AllConversations bool `json:"allConversations,omitempty"`
Since *time.Time `json:"since,omitempty"`
Before *time.Time `json:"before,omitempty"`
Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m"
Limit int `json:"limit,omitempty"`
}
// GrepResult contains search results.
type GrepResult struct {
Success bool `json:"success"`
Summaries []GrepSummaryResult `json:"summaries"`
Messages []GrepMessageResult `json:"messages"`
TotalSummaries int `json:"totalSummaries"`
TotalMessages int `json:"totalMessages"`
Hint string `json:"hint,omitempty"`
}
// GrepSummaryResult is a summary match from grep.
type GrepSummaryResult struct {
ID string `json:"id"`
Content string `json:"content"`
Depth int `json:"depth"`
Kind SummaryKind `json:"kind"`
ConversationID int64 `json:"conversationId"`
// Rank is the bm25 relevance score (negative value, lower = better match).
// Examples: -5.0 = excellent match, -2.0 = good match, -0.5 = partial match.
Rank float64 `json:"rank,omitempty"`
}
// GrepMessageResult is a message match from grep.
type GrepMessageResult struct {
ID int64 `json:"id,string"`
Snippet string `json:"snippet"`
Role string `json:"role"`
ConversationID int64 `json:"conversationId"`
Rank float64 `json:"rank,omitempty"` // Relevance score (more negative = better match)
}
// ExpandMessagesResult contains expanded messages.
type ExpandMessagesResult struct {
Messages []Message `json:"messages"`
TokenCount int `json:"tokenCount"`
}
// Grep searches summaries and messages for matching content.
func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) {
if input.Pattern == "" {
return nil, fmt.Errorf("grep: pattern is required")
}
limit := input.Limit
if limit == 0 {
limit = 20
}
// Handle Last parameter: convert to Since
since := input.Since
if input.Last != "" {
dur, err := ParseLastDuration(input.Last)
if err != nil {
return nil, fmt.Errorf("grep: invalid last: %w", err)
}
t := time.Now().UTC().Add(-dur)
since = &t
}
// Auto-detect mode: use LIKE if pattern contains %, otherwise full-text
mode := ""
if strings.Contains(input.Pattern, "%") {
mode = "like"
}
searchInput := SearchInput{
Pattern: input.Pattern,
Mode: mode,
Role: input.Role,
AllConversations: input.AllConversations,
Since: since,
Before: input.Before,
Limit: limit,
}
result := &GrepResult{
Success: true,
Summaries: make([]GrepSummaryResult, 0),
Messages: make([]GrepMessageResult, 0),
TotalSummaries: 0,
TotalMessages: 0,
}
// Determine scope
scope := input.Scope
if scope == "" {
scope = "both"
}
// Search summaries if requested
if scope == "both" || scope == "summary" {
sumResults, err := r.store.SearchSummaries(ctx, searchInput)
if err != nil {
return nil, fmt.Errorf("search summaries: %w", err)
}
for _, sr := range sumResults {
if sr.SummaryID != "" {
result.Summaries = append(result.Summaries, GrepSummaryResult{
ID: sr.SummaryID,
Content: sr.Content,
Depth: sr.Depth,
Kind: sr.Kind,
ConversationID: sr.ConversationID,
Rank: sr.Rank,
})
}
}
if len(sumResults) > 0 {
result.TotalSummaries = sumResults[0].TotalCount
}
}
// Search messages if requested
if scope == "both" || scope == "message" {
msgResults, err := r.store.SearchMessages(ctx, searchInput)
if err != nil {
return nil, fmt.Errorf("search messages: %w", err)
}
for _, sr := range msgResults {
if sr.MessageID > 0 {
result.Messages = append(result.Messages, GrepMessageResult{
ID: sr.MessageID,
Snippet: sr.Snippet,
Role: sr.Role,
ConversationID: sr.ConversationID,
Rank: sr.Rank,
})
}
}
if len(msgResults) > 0 {
result.TotalMessages = msgResults[0].TotalCount
}
}
// Add hint if no results
if len(result.Summaries) == 0 && len(result.Messages) == 0 {
result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true"
}
return result, nil
}
// ExpandMessages retrieves full message content by IDs.
func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) {
result := &ExpandMessagesResult{
Messages: make([]Message, 0, len(messageIDs)),
}
for _, msgID := range messageIDs {
msg, err := r.store.GetMessageByID(ctx, msgID)
if err != nil {
continue
}
result.Messages = append(result.Messages, *msg)
result.TokenCount += msg.TokenCount
}
return result, nil
}
+362
View File
@@ -0,0 +1,362 @@
package seahorse
import (
"context"
"fmt"
"testing"
"time"
)
// --- Retrieval Tests ---
func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) {
t.Helper()
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval")
return &RetrievalEngine{store: s}, s, conv.ConversationID
}
func TestRetrievalGrepSummaries(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "数据库连接配置说明",
TokenCount: 50,
})
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "API endpoint documentation",
TokenCount: 50,
})
// FTS5 search (trigram, needs >= 3 chars)
results, err := r.Grep(ctx, GrepInput{
Pattern: "数据库连",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Summaries) == 0 {
t.Error("expected at least 1 FTS result")
}
// LIKE search with wildcard
results, err = r.Grep(ctx, GrepInput{
Pattern: "%endpoint%",
})
if err != nil {
t.Fatalf("Grep LIKE: %v", err)
}
if len(results.Summaries) == 0 {
t.Error("expected at least 1 LIKE result")
}
}
func TestRetrievalGrepMessages(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
s.AddMessage(ctx, convID, "user", "find this message about testing", 5)
s.AddMessage(ctx, convID, "user", "unrelated content here", 5)
results, err := r.Grep(ctx, GrepInput{
Pattern: "testing",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Messages) == 0 {
t.Error("expected at least 1 result for 'testing'")
}
}
func TestRetrievalExpandMessages(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10)
result, err := r.ExpandMessages(ctx, []int64{msg.ID})
if err != nil {
t.Fatalf("ExpandMessages: %v", err)
}
if len(result.Messages) != 1 {
t.Errorf("Messages = %d, want 1", len(result.Messages))
}
if result.Messages[0].Content != "expand this message" {
t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content)
}
}
func TestRetrievalExpandMultipleMessages(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10)
msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10)
msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10)
result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID})
if err != nil {
t.Fatalf("ExpandMessages: %v", err)
}
if len(result.Messages) != 3 {
t.Errorf("Messages = %d, want 3", len(result.Messages))
}
if result.TokenCount != 30 {
t.Errorf("TokenCount = %d, want 30", result.TokenCount)
}
}
func TestRetrievalGrepWithTimeFilter(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
now := time.Now().UTC()
before := now.Add(-2 * time.Hour)
// Create messages at different times
s.AddMessage(ctx, convID, "user", "old message about auth", 5)
s.AddMessage(ctx, convID, "user", "recent message about auth", 5)
// Search with time filter
results, err := r.Grep(ctx, GrepInput{
Pattern: "auth",
Since: &before,
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
_ = results // Just verify no error
}
func TestRetrievalGrepAllConversations(t *testing.T) {
r, s, _ := newTestRetrieval(t)
ctx := context.Background()
// Create another conversation
conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2")
// Add messages to both
s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5)
// Search all conversations
results, err := r.Grep(ctx, GrepInput{
Pattern: "xyz",
AllConversations: true,
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Messages) == 0 {
t.Error("expected to find message in other conversation")
}
}
// --- Last Duration Parsing Tests ---
func TestParseLastDuration(t *testing.T) {
tests := []struct {
input string
wantDur time.Duration
wantErr bool
}{
{"6h", 6 * time.Hour, false},
{"1d", 24 * time.Hour, false},
{"7d", 7 * 24 * time.Hour, false},
{"2w", 14 * 24 * time.Hour, false},
{"1m", 30 * 24 * time.Hour, false}, // month = 30 days
{"3m", 90 * 24 * time.Hour, false},
{"", 0, true},
{"invalid", 0, true},
{"5x", 0, true}, // unknown unit
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got, err := ParseLastDuration(tt.input)
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.wantDur {
t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur)
}
}
})
}
}
// --- Role Filter Tests ---
func TestRetrievalGrepRoleFilter(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
s.AddMessage(ctx, convID, "user", "user message about alpha", 5)
s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5)
s.AddMessage(ctx, convID, "user", "another user message", 5)
// Search all roles
allResults, err := r.Grep(ctx, GrepInput{
Pattern: "alpha",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(allResults.Messages) != 2 {
t.Errorf("expected 2 messages, got %d", len(allResults.Messages))
}
// Search user only
userResults, err := r.Grep(ctx, GrepInput{
Pattern: "alpha",
Role: "user",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(userResults.Messages) != 1 {
t.Errorf("expected 1 user message, got %d", len(userResults.Messages))
}
if userResults.Messages[0].Role != "user" {
t.Errorf("expected role=user, got %s", userResults.Messages[0].Role)
}
// Search assistant only
assistantResults, err := r.Grep(ctx, GrepInput{
Pattern: "alpha",
Role: "assistant",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(assistantResults.Messages) != 1 {
t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages))
}
}
// --- Last Parameter Tests ---
func TestRetrievalGrepWithLast(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
// Add messages (we can't control timestamps in SQLite easily,
// but we can verify the parameter is parsed correctly)
s.AddMessage(ctx, convID, "user", "recent message about testing", 5)
// Test that Last parameter is converted to Since
results, err := r.Grep(ctx, GrepInput{
Pattern: "testing",
Last: "1d", // last 1 day
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
// Should still find the message since it's recent
if len(results.Messages) == 0 {
t.Error("expected to find recent message")
}
}
// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when
// searching both summaries and messages (summaries don't have role column).
func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
// Create a summary (no role column)
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "summary about testing",
TokenCount: 50,
})
// Add messages with different roles
s.AddMessage(ctx, convID, "user", "user message about testing", 5)
s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5)
// Search with role filter and scope=both (default), using LIKE mode (%)
// This should NOT error even though summaries don't have role column
bothResults, err := r.Grep(ctx, GrepInput{
Pattern: "%testing%", // LIKE mode to trigger the bug
Role: "user",
Scope: "both",
})
if err != nil {
t.Fatalf("Grep with role and scope=both: %v", err)
}
// Should only return user messages, not summaries or assistant messages
if len(bothResults.Messages) != 1 {
t.Errorf("expected 1 user message, got %d", len(bothResults.Messages))
}
if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" {
t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role)
}
// Summaries should be empty since they don't have roles to filter
// (or we could return all summaries - either is acceptable)
}
// TestRetrievalGrepTotalCounts tests that grep returns total counts.
func TestRetrievalGrepTotalCounts(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
// Create 3 summaries
for i := 0; i < 3; i++ {
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("summary about testing %d", i),
TokenCount: 50,
})
}
// Add 5 messages
for i := 0; i < 5; i++ {
s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5)
}
// Search with limit smaller than total
results, err := r.Grep(ctx, GrepInput{
Pattern: "%testing%", // LIKE mode
Scope: "both",
Limit: 2,
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
// Should return limited results
if len(results.Summaries) > 2 {
t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries))
}
if len(results.Messages) > 2 {
t.Errorf("expected at most 2 messages, got %d", len(results.Messages))
}
// But total counts should reflect all matches
if results.TotalSummaries != 3 {
t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries)
}
if results.TotalMessages != 5 {
t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages)
}
}

Some files were not shown because too many files have changed in this diff Show More