feat(seahorse): implement short-term memory engine (LCM) (#2285)

* feat(seahorse): implement short-term memory engine of seahorse

Add pkg/seahorse/ module implementing a SQLite-backed DAG-based summary
hierarchy for context management, ported from lossless-claw's LCM design:

- types.go + short_constants.go: core types (Message, Summary, Conversation,
  ContextItem) and configuration constants (fanout, token targets, thresholds)
- migration.go: idempotent DB schema with FTS5 trigram tokenizer for CJK
- store.go: full SQLite CRUD (conversations, messages, summaries DAG,
  context_items with ordinal gap numbering, FTS5 search)
- short_engine.go: Engine lifecycle (NewEngine, Ingest, Assemble, Compact),
  session pattern filtering (ignore/stateless glob→regex compilation),
  per-session mutex via sync.Map
- short_assembler.go: budget-aware context assembly with fresh tail protection
  (32 messages), oldest-first eviction, summary XML formatting, RebuildContextItems
- short_compaction.go: leaf compaction (messages→summary) and condensed
  compaction (summaries→higher-level summary), 3-level LLM escalation,
  CompactUntilUnder for emergency overflow
- short_retrieval.go: lookupByID, FTS5/LIKE search, recursive expand with
  token cap
- context_seahorse.go: agent.ContextManager adapter, registered as "seahorse",
  provider↔seahorse message type conversion (ToolCalls, tool_result)

* fix(seahorse): correct 3 adapter bugs in context management

- TokenCount: use full message (Content+ToolCalls+Media) instead of Content-only
- Empty Content: rebuild Content from tool_result Parts when stored empty
- Duplicate summaries: summaries only in Summary field, not in History messages
- Grep: fix SearchResult.Snippet→Content for summaries
- Schema: fix FTS5 SQL uses VIRTUAL TABLE not TEMP TABLE
- TestFTS5SQLConstants: verify FTS5 SQL syntax correctness
- Test: fix flaky TestCompactLeaf

* fix(agent): ingest steering messages into seahorse SQLite

Steering messages were only persisted to session JSONL but not ingested
into seahorse SQLite, causing them to be missing from context assembly.

Added `ts.ingestMessage(turnCtx, al, pm)` call in the steering message
injection block alongside the existing JSONL persistence.

Test: TestSeahorseSteeringMessageIngested verifies steering messages
appear in seahorse SQLite DB after being processed.

* fix(seahorse): address 3 blocking bugs from code review

- Fix resequenceContextItemsTx scan error handling (store.go:850)
  Changed `return err` to `return scanErr` to properly propagate scan errors
  instead of returning nil (which silently corrupts data)

- Fix sql.NullString for INTEGER column (store.go:847)
  Changed `mid` from sql.NullString to sql.NullInt64 since message_id
  is INTEGER in schema. Removed unnecessary strconv.ParseInt call.

- Fix compactCondensed fallback deleting non-candidate items
  Added ReplaceContextItemsWithSummary method for per-item deletion
  when candidates are not contiguous in ordinal space.
  Optimized to use range deletion when candidates are consecutive.

* fix(seahorse): pass Budget to Compact for correct condensed threshold

Issue #4 from PR review: When Budget was not passed to seahorse.Compact,
it defaulted to `tokensBefore * 0.75`, making `tokensBefore > budget`
always true and causing condensed compaction to trigger unnecessarily.

Changes:
- context_seahorse.go: Forward Budget from CompactRequest to CompactInput
- loop.go: Pass Budget (ContextWindow) in all 3 Compact calls
- Add test verifying condensed is skipped when tokens < threshold
- Fix lint issues in store.go and store_test.go

* fix(seahorse): add mutex for assembler lazy initialization

Issue #5 from PR review: The check-then-create pattern for e.assembler
was a data race when multiple goroutines called Assemble() concurrently:
    if e.assembler == nil {
        e.assembler = &Assembler{...}
    }

Changes:
- Add assemblerMu sync.Mutex to Engine struct
- Add initAssemblerOnce() using double-checked locking (same pattern as initCompactionOnce)
- Add TestAssemblerLazyInitRace to verify thread-safety

* fix(seahorse): handle non-consecutive depths in selectShallowestCondensationCandidate

Issue #8 from PR review: the loop iterated depth 0, 1, 2... assuming
consecutive keys, but break when key was missing caused deeper depths
to never be checked.

Fix: collect all existing depth keys, sort, then iterate in order.

* fix(seahorse): wrap DeleteMessagesAfterID and appendContextItems in transactions

- DeleteMessagesAfterID: wrap all DELETE operations in a transaction for
  atomicity, remove redundant manual FTS delete (handled by trigger)
- appendContextItems: use transaction to fix read-then-write race condition
- Add GetMaxOrdinalTx and resolveItemTokenCountTx for transaction-scoped queries
- Remove unused resolveItemTokenCount function

Fixes PR review issues 6 and 7.

* fix(seahorse): derive readable content from Parts and cap CompactUntilUnder iterations

- Derive readable content from MessageParts in AddMessageWithParts so
  FTS5 indexing and summary formatting can access tool call information
- formatMessagesForSummary and truncateSummary now fall back to Parts
  when Content is empty, fixing blank summaries for Part-based messages
- Add MaxCompactIterations (20) to prevent CompactUntilUnder infinite
  loops; exceeded iterations are logged as warnings
This commit is contained in:
Liu Yuan
2026-04-05 09:05:16 +08:00
committed by GitHub
parent 71337b6f52
commit 15a70ac45c
39 changed files with 11271 additions and 108 deletions
+2
View File
@@ -67,3 +67,5 @@ web/backend/dist/*
.claude/ .claude/
docker/data docker/data
.omc/
+1
View File
@@ -12,6 +12,7 @@ linters:
- exhaustruct - exhaustruct
- funcorder - funcorder
- gochecknoglobals - gochecknoglobals
- gosmopolitan # Project legitimately uses CJK text in tests (FTS5, token counting)
- godot - godot
- intrange - intrange
- ireturn - ireturn
+11 -85
View File
@@ -6,10 +6,8 @@
package agent package agent
import ( import (
"encoding/json"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tokenizer"
) )
// parseTurnBoundaries returns the starting index of each Turn in the history. // parseTurnBoundaries returns the starting index of each Turn in the history.
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
return 0 return 0
} }
// estimateMessageTokens estimates the token count for a single message, // EstimateMessageTokens estimates the token count for a single message.
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID // Delegates to the shared tokenizer package for consistency across agent and seahorse.
// metadata, and Media items. Uses a heuristic of 2.5 characters per token. func EstimateMessageTokens(msg providers.Message) int {
func estimateMessageTokens(msg providers.Message) int { return tokenizer.EstimateMessageTokens(msg)
contentChars := utf8.RuneCountInString(msg.Content)
// SystemParts are structured system blocks used for cache-aware adapters.
// They carry the same content as Content, but in multiple blocks.
// We estimate them as an alternative representation, not additive.
systemPartsChars := 0
if len(msg.SystemParts) > 0 {
for _, part := range msg.SystemParts {
systemPartsChars += utf8.RuneCountInString(part.Text)
}
// Per-part overhead for JSON structure (type, text, cache_control).
const perPartOverhead = 20
systemPartsChars += len(msg.SystemParts) * perPartOverhead
}
// Use the larger of the two representations to stay conservative.
chars := contentChars
if systemPartsChars > chars {
chars = systemPartsChars
}
chars += utf8.RuneCountInString(msg.ReasoningContent)
for _, tc := range msg.ToolCalls {
chars += len(tc.ID) + len(tc.Type)
if tc.Function != nil {
// Count function name + arguments (the wire format for most providers).
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
} else {
// Fallback: some provider formats use top-level Name without Function.
chars += len(tc.Name)
}
}
if msg.ToolCallID != "" {
chars += len(msg.ToolCallID)
}
// Per-message overhead for role label, JSON structure, separators.
const messageOverhead = 12
chars += messageOverhead
tokens := chars * 2 / 5
// Media items (images, files) are serialized by provider adapters into
// multipart or image_url payloads. Add a fixed per-item token estimate
// directly (not through the chars heuristic) since actual cost depends
// on resolution and provider-specific image tokenization.
const mediaTokensPerItem = 256
tokens += len(msg.Media) * mediaTokensPerItem
return tokens
} }
// estimateToolDefsTokens estimates the total token cost of tool definitions // EstimateToolDefsTokens estimates the total token cost of tool definitions
// as they appear in the LLM request. Each tool's name, description, and // as they appear in the LLM request. Delegates to the shared tokenizer package.
// JSON schema parameters contribute to the context window budget. func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
func estimateToolDefsTokens(defs []providers.ToolDefinition) int { return tokenizer.EstimateToolDefsTokens(defs)
if len(defs) == 0 {
return 0
}
totalChars := 0
for _, d := range defs {
totalChars += len(d.Function.Name) + len(d.Function.Description)
if d.Function.Parameters != nil {
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
totalChars += len(paramJSON)
}
}
// Per-tool overhead: type field, JSON structure, separators.
totalChars += 20
}
return totalChars * 2 / 5
} }
// isOverContextBudget checks whether the assembled messages plus tool definitions // isOverContextBudget checks whether the assembled messages plus tool definitions
@@ -181,10 +107,10 @@ func isOverContextBudget(
) bool { ) bool {
msgTokens := 0 msgTokens := 0
for _, m := range messages { for _, m := range messages {
msgTokens += estimateMessageTokens(m) msgTokens += EstimateMessageTokens(m)
} }
toolTokens := estimateToolDefsTokens(toolDefs) toolTokens := EstimateToolDefsTokens(toolDefs)
total := msgTokens + toolTokens + maxTokens total := msgTokens + toolTokens + maxTokens
return total > contextWindow return total > contextWindow
+19 -19
View File
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := estimateMessageTokens(tt.msg) got := EstimateMessageTokens(tt.msg)
if got < tt.want { if got < tt.want {
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want) t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
} }
}) })
} }
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
}, },
} }
plainTokens := estimateMessageTokens(plain) plainTokens := EstimateMessageTokens(plain)
withTCTokens := estimateMessageTokens(withTC) withTCTokens := EstimateMessageTokens(withTC)
if withTCTokens <= plainTokens { if withTCTokens <= plainTokens {
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)", t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
// but may map to different token counts. The heuristic should still produce // but may map to different token counts. The heuristic should still produce
// reasonable estimates via RuneCountInString. // reasonable estimates via RuneCountInString.
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe") msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
tokens := estimateMessageTokens(msg) tokens := EstimateMessageTokens(msg)
if tokens <= 0 { if tokens <= 0 {
t.Errorf("multibyte message should produce positive token count, got %d", tokens) t.Errorf("multibyte message should produce positive token count, got %d", tokens)
} }
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
}, },
} }
tokens := estimateMessageTokens(msg) tokens := EstimateMessageTokens(msg)
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic // 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
if tokens < 2000 { if tokens < 2000 {
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens) t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
ReasoningContent: strings.Repeat("thinking step ", 200), ReasoningContent: strings.Repeat("thinking step ", 200),
} }
plainTokens := estimateMessageTokens(plain) plainTokens := EstimateMessageTokens(plain)
reasoningTokens := estimateMessageTokens(withReasoning) reasoningTokens := EstimateMessageTokens(withReasoning)
if reasoningTokens <= plainTokens { if reasoningTokens <= plainTokens {
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)", t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
Media: []string{"media://img1.png", "media://img2.png"}, Media: []string{"media://img1.png", "media://img2.png"},
} }
plainTokens := estimateMessageTokens(plain) plainTokens := EstimateMessageTokens(plain)
mediaTokens := estimateMessageTokens(withMedia) mediaTokens := EstimateMessageTokens(withMedia)
if mediaTokens <= plainTokens { if mediaTokens <= plainTokens {
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)", t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
}, },
} }
plainTokens := estimateMessageTokens(plain) plainTokens := EstimateMessageTokens(plain)
partsTokens := estimateMessageTokens(withParts) partsTokens := EstimateMessageTokens(withParts)
if partsTokens <= plainTokens { if partsTokens <= plainTokens {
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)", t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
} }
} }
// --- estimateToolDefsTokens tests --- // --- EstimateToolDefsTokens tests ---
func TestEstimateToolDefsTokens(t *testing.T) { func TestEstimateToolDefsTokens(t *testing.T) {
tests := []struct { tests := []struct {
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := estimateToolDefsTokens(tt.defs) got := EstimateToolDefsTokens(tt.defs)
if got < tt.want { if got < tt.want {
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want) t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
} }
}) })
} }
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
} }
} }
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")}) one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
three := estimateToolDefsTokens([]providers.ToolDefinition{ three := EstimateToolDefsTokens([]providers.ToolDefinition{
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"), makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
}) })
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
}, },
} }
tokens := estimateMessageTokens(msg) tokens := EstimateMessageTokens(msg)
// ReasoningContent alone is ~1700 chars → ~680 tokens. // ReasoningContent alone is ~1700 chars → ~680 tokens.
// Content + TC + overhead adds more. Should be well above 500. // Content + TC + overhead adds more. Should be well above 500.
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
// Compare without reasoning to ensure it's counted. // Compare without reasoning to ensure it's counted.
msgNoReasoning := msg msgNoReasoning := msg
msgNoReasoning.ReasoningContent = "" msgNoReasoning.ReasoningContent = ""
tokensNoReasoning := estimateMessageTokens(msgNoReasoning) tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
if tokens <= tokensNoReasoning { if tokens <= tokensNoReasoning {
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning) t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
+1 -1
View File
@@ -373,7 +373,7 @@ func (m *legacyContextManager) summarizeBatch(
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int { func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
total := 0 total := 0
for _, msg := range messages { for _, msg := range messages {
total += estimateMessageTokens(msg) total += EstimateMessageTokens(msg)
} }
return total return total
} }
+1
View File
@@ -43,6 +43,7 @@ type AssembleResponse struct {
type CompactRequest struct { type CompactRequest struct {
SessionKey string // session identifier SessionKey string // session identifier
Reason ContextCompressReason // proactive_budget | llm_retry | summarize Reason ContextCompressReason // proactive_budget | llm_retry | summarize
Budget int // context window budget (used for retry aggressive compaction)
} }
// IngestRequest is the input to Ingest. // IngestRequest is the input to Ingest.
+267
View File
@@ -0,0 +1,267 @@
package agent
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
"github.com/sipeed/picoclaw/pkg/seahorse"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
type seahorseContextManager struct {
engine *seahorse.Engine
sessions session.SessionStore // for startup bootstrap
}
// newSeahorseContextManager creates a seahorse-backed ContextManager.
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
if al == nil {
return nil, fmt.Errorf("seahorse: AgentLoop is required")
}
// Resolve workspace for DB path
// DB stores session data, so it goes in sessions/ directory
agent := al.registry.GetDefaultAgent()
dbPath := agent.Workspace + "/sessions/seahorse.db"
// Create CompleteFn from provider
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
// Create engine
engine, err := seahorse.NewEngine(seahorse.Config{
DBPath: dbPath,
}, completeFn)
if err != nil {
return nil, fmt.Errorf("seahorse: create engine: %w", err)
}
mgr := &seahorseContextManager{
engine: engine,
sessions: agent.Sessions,
}
// Register seahorse tools with the agent's tool registry
retrieval := mgr.engine.GetRetrieval()
al.RegisterTool(seahorse.NewGrepTool(retrieval))
al.RegisterTool(seahorse.NewExpandTool(retrieval))
// Bootstrap all existing sessions at startup
if agent.Sessions != nil {
ctx := context.Background()
for _, sessionKey := range agent.Sessions.ListSessions() {
mgr.bootstrapSession(ctx, sessionKey)
}
}
return mgr, nil
}
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
resp, err := provider.Chat(
ctx,
[]providers.Message{{Role: "user", Content: prompt}},
nil, // no tools for summarization
model,
map[string]any{
"max_tokens": opts.MaxTokens,
"temperature": opts.Temperature,
"prompt_cache_key": "seahorse",
},
)
if err != nil {
return "", err
}
return resp.Content, nil
}
}
// Assemble builds budget-aware context from seahorse SQLite.
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
if req == nil {
return nil, fmt.Errorf("seahorse assemble: nil request")
}
budget := req.Budget
if budget <= 0 {
budget = 100000
}
// Reserve space for model response (spec lines 1400-1410)
effectiveBudget := budget - req.MaxTokens
if effectiveBudget <= 0 {
// MaxTokens >= budget is a configuration problem
// Use 50% as minimum to avoid guaranteed overflow
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
effectiveBudget = budget / 2
}
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
Budget: effectiveBudget,
})
if err != nil {
return nil, fmt.Errorf("seahorse assemble: %w", err)
}
history := seahorseToProviderMessages(result)
// Summary is already formatted as XML with system prompt addition by assembler
return &AssembleResponse{
History: history,
Summary: result.Summary,
}, nil
}
// Compact compresses conversation history via seahorse summarization.
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
if req == nil {
return nil
}
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
// context shrinks below budget (spec lines ~1410).
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
return err
}
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
Force: req.Reason == ContextCompressReasonRetry,
Budget: &req.Budget,
})
return err
}
// Ingest records a message into seahorse SQLite.
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
if req == nil {
return nil
}
msg := providerToSeahorseMessage(req.Message)
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
return err
}
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
if m.sessions == nil {
return
}
history := m.sessions.GetHistory(sessionKey)
if len(history) == 0 {
return
}
// Convert provider messages to seahorse messages
msgs := make([]seahorse.Message, len(history))
for i, h := range history {
msgs[i] = providerToSeahorseMessage(h)
}
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
logger.WarnCF("seahorse", "bootstrap", map[string]any{
"session": sessionKey,
"error": err.Error(),
})
}
}
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
result := seahorse.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
TokenCount: tokenizer.EstimateMessageTokens(msg),
}
// Convert ToolCalls → MessageParts
for _, tc := range msg.ToolCalls {
part := seahorse.MessagePart{
Type: "tool_use",
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
ToolCallID: tc.ID,
}
result.Parts = append(result.Parts, part)
}
// Convert tool result
if msg.ToolCallID != "" {
part := seahorse.MessagePart{
Type: "tool_result",
ToolCallID: msg.ToolCallID,
Text: msg.Content,
}
result.Parts = append(result.Parts, part)
}
// Convert media attachments
for _, mediaURI := range msg.Media {
part := seahorse.MessagePart{
Type: "media",
MediaURI: mediaURI,
}
result.Parts = append(result.Parts, part)
}
return result
}
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
messages := make([]protocoltypes.Message, 0, len(result.Messages))
// Convert assembled messages (which already include summary XML messages)
for _, msg := range result.Messages {
pm := protocoltypes.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
}
// Reconstruct ToolCalls from parts
for _, part := range msg.Parts {
if part.Type == "tool_use" {
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
ID: part.ToolCallID,
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
Function: &protocoltypes.FunctionCall{
Name: part.Name,
Arguments: part.Arguments,
},
})
}
if part.Type == "tool_result" {
pm.ToolCallID = part.ToolCallID
if pm.Content == "" && part.Text != "" {
pm.Content = part.Text
}
}
if part.Type == "media" && part.MediaURI != "" {
pm.Media = append(pm.Media, part.MediaURI)
}
}
messages = append(messages, pm)
}
return messages
}
func init() {
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
panic(fmt.Sprintf("register seahorse context manager: %v", err))
}
}
File diff suppressed because it is too large Load Diff
+5 -1
View File
@@ -1742,6 +1742,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
if err := al.contextManager.Compact(turnCtx, &CompactRequest{ if err := al.contextManager.Compact(turnCtx, &CompactRequest{
SessionKey: ts.sessionKey, SessionKey: ts.sessionKey,
Reason: ContextCompressReasonProactive, Reason: ContextCompressReasonProactive,
Budget: ts.agent.ContextWindow,
}); err != nil { }); err != nil {
logger.WarnCF("agent", "Proactive compact failed", map[string]any{ logger.WarnCF("agent", "Proactive compact failed", map[string]any{
"session_key": ts.sessionKey, "session_key": ts.sessionKey,
@@ -1857,6 +1858,7 @@ turnLoop:
if !ts.opts.NoHistory { if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm) ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
ts.recordPersistedMessage(pm) ts.recordPersistedMessage(pm)
ts.ingestMessage(turnCtx, al, pm)
} }
logger.InfoCF("agent", "Injected steering message into context", logger.InfoCF("agent", "Injected steering message into context",
map[string]any{ map[string]any{
@@ -2128,6 +2130,7 @@ turnLoop:
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{ if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
SessionKey: ts.sessionKey, SessionKey: ts.sessionKey,
Reason: ContextCompressReasonRetry, Reason: ContextCompressReasonRetry,
Budget: ts.agent.ContextWindow,
}); compactErr != nil { }); compactErr != nil {
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{ logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
"session_key": ts.sessionKey, "session_key": ts.sessionKey,
@@ -2773,7 +2776,7 @@ turnLoop:
} }
} }
if ts.opts.EnableSummary { if ts.opts.EnableSummary {
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize}) al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
} }
ts.setPhase(TurnPhaseCompleted) ts.setPhase(TurnPhaseCompleted)
@@ -2849,6 +2852,7 @@ turnLoop:
&CompactRequest{ &CompactRequest{
SessionKey: ts.sessionKey, SessionKey: ts.sessionKey,
Reason: ContextCompressReasonSummarize, Reason: ContextCompressReasonSummarize,
Budget: ts.agent.ContextWindow,
}, },
) )
} }
+4 -2
View File
@@ -604,6 +604,7 @@ type ephemeralSessionStoreIface interface {
SetHistory(key string, history []providers.Message) SetHistory(key string, history []providers.Message)
TruncateHistory(key string, keepLast int) TruncateHistory(key string, keepLast int)
Save(key string) error Save(key string) error
ListSessions() []string
Close() error Close() error
} }
@@ -663,8 +664,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
e.history = e.history[len(e.history)-keepLast:] e.history = e.history[len(e.history)-keepLast:]
} }
func (e *ephemeralSessionStore) Save(_ string) error { return nil } func (e *ephemeralSessionStore) Save(_ string) error { return nil }
func (e *ephemeralSessionStore) Close() error { return nil } func (e *ephemeralSessionStore) Close() error { return nil }
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
func (e *ephemeralSessionStore) truncateLocked() { func (e *ephemeralSessionStore) truncateLocked() {
if len(e.history) > maxEphemeralHistorySize { if len(e.history) > maxEphemeralHistorySize {
+27
View File
@@ -455,6 +455,33 @@ func (s *JSONLStore) rewriteJSONL(
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644) return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
} }
// ListSessions returns all known session keys by reading .meta.json files.
func (s *JSONLStore) ListSessions() []string {
entries, err := os.ReadDir(s.dir)
if err != nil {
return nil
}
var keys []string
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
continue
}
// Read the meta file to get the original key
data, err := os.ReadFile(filepath.Join(s.dir, entry.Name()))
if err != nil {
continue
}
var meta sessionMeta
if err := json.Unmarshal(data, &meta); err != nil {
continue
}
if meta.Key != "" {
keys = append(keys, meta.Key)
}
}
return keys
}
func (s *JSONLStore) Close() error { func (s *JSONLStore) Close() error {
return nil return nil
} }
+3
View File
@@ -37,6 +37,9 @@ type Store interface {
// data. Backends that do not accumulate dead data may return nil. // data. Backends that do not accumulate dead data may return nil.
Compact(ctx context.Context, sessionKey string) error Compact(ctx context.Context, sessionKey string) error
// ListSessions returns all known session keys.
ListSessions() []string
// Close releases any resources held by the store. // Close releases any resources held by the store.
Close() error Close() error
} }
@@ -0,0 +1,7 @@
{
"tool_name": "Bash",
"tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}",
"error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1",
"timestamp": "2026-04-04T02:38:32.067Z",
"retry_count": 6
}
+58
View File
@@ -0,0 +1,58 @@
package seahorse
import (
"context"
"testing"
)
// =============================================================================
// CompactUntilUnder iteration cap
// =============================================================================
func TestCompactUntilUnderIterationCap(t *testing.T) {
// Setup: create a conversation with so many tokens that compaction
// will never reach the budget. The iteration cap prevents infinite loops.
//
// We use a mock CompleteFn that always returns the same content,
// and a budget of 0 which tokens can never reach.
// Without the cap, this would loop forever.
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
s := &Store{db: db}
conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap")
convID := conv.ConversationID
// Add many messages to ensure there's plenty to compact
for i := 0; i < 40; i++ {
m, _ := s.AddMessage(context.Background(), convID, "user",
"this is a long message with lots of tokens to push context over budget", 100)
s.AppendContextMessage(context.Background(), convID, m.ID)
}
// A completeFn that always succeeds but returns non-reducing content
mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "Summary that doesn't reduce tokens much.", nil
}
ce, cancel := newTestCompactionEngineWithStore(s, mockComplete)
defer cancel()
// Use budget=1 so tokens can never reach budget
// (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable)
// The function should stop after maxCompactIterations, not loop forever
ce.config = Config{} // ensure defaults
result, err := ce.CompactUntilUnder(context.Background(), convID, 1)
if err != nil {
// Should not error — should stop gracefully
t.Fatalf("CompactUntilUnder with budget=0: %v", err)
}
// The function should have completed within reasonable time
// If it exceeded the cap, it would still return (not hang)
_ = result
}
+144
View File
@@ -0,0 +1,144 @@
package seahorse
import (
"context"
"testing"
"time"
)
// =============================================================================
// Bug 1: formatMessagesForSummary ignores Parts
// - formatMessagesForSummary only reads m.Content, empty for Part-based messages
// - truncateSummary has same issue
// =============================================================================
func TestFormatMessagesForSummaryIncludesParts(t *testing.T) {
ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
messages := []Message{
{ID: 1, Role: "user", Content: "hello world", CreatedAt: ts},
{
ID: 2,
Role: "assistant",
Content: "", // empty — real content is in Parts
Parts: []MessagePart{
{Type: "text", Text: "I will run a command"},
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"},
},
CreatedAt: ts.Add(time.Minute),
},
{
ID: 3,
Role: "tool",
Content: "", // empty — real content is in Parts
Parts: []MessagePart{
{Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
},
CreatedAt: ts.Add(2 * time.Minute),
},
}
result := formatMessagesForSummary(messages)
// Must contain the plain text message
if !contains(result, "hello world") {
t.Error("formatMessagesForSummary: missing plain text content")
}
// Must contain tool_use info (not blank)
if !contains(result, "bash") || !contains(result, "ls -la") {
t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result)
}
// Must contain tool_result info (not blank)
if !contains(result, "file1.txt") {
t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result)
}
}
func TestTruncateSummaryIncludesParts(t *testing.T) {
messages := []Message{
{ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()},
{
ID: 2,
Role: "assistant",
Content: "", // empty
Parts: []MessagePart{
{Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"},
},
CreatedAt: time.Now(),
},
{
ID: 3,
Role: "tool",
Content: "", // empty
Parts: []MessagePart{
{Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"},
},
CreatedAt: time.Now(),
},
}
result := truncateSummary(messages)
// Must contain plain text
if !contains(result, "run the tests") {
t.Error("truncateSummary: missing plain text content")
}
// Must contain tool info from Parts (not blank)
if !contains(result, "bash") || !contains(result, "go test") {
t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result)
}
// Must contain tool_result from Parts
if !contains(result, "PASS") {
t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result)
}
}
// =============================================================================
// Bug 2: SearchMessages cannot find Part-based messages
// - FTS5 indexes empty content, LIKE queries empty content
// =============================================================================
func TestSearchMessagesFindsPartBasedMessages(t *testing.T) {
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts")
convID := conv.ConversationID
// Add a plain message (searchable)
s.AddMessage(ctx, convID, "user", "list the files please", 5)
// Add a Part-based message (tool_use) — currently NOT searchable
parts := []MessagePart{
{Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"},
}
s.AddMessageWithParts(ctx, convID, "assistant", parts, 10)
// Add a Part-based message (tool_result) — currently NOT searchable
resultParts := []MessagePart{
{Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"},
}
s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10)
// Search for "grep" — should find the tool_use message
results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"})
if err != nil {
t.Fatalf("SearchMessages: %v", err)
}
if len(results) == 0 {
t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search")
}
// Search for "TODO fix" — should find the tool_result message
results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"})
if err != nil {
t.Fatalf("SearchMessages: %v", err)
}
if len(results2) == 0 {
t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search")
}
}
+185
View File
@@ -0,0 +1,185 @@
package seahorse
import (
"database/sql"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
)
// SQL statements for FTS5 tables with trigram tokenizer.
const (
sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5(
summary_id,
content,
tokenize="trigram"
)`
sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
message_id,
content,
tokenize="trigram"
)`
sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)`
sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")`
sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check`
sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check`
)
// runSchema creates or upgrades the database schema.
// All schemas are idempotent (safe to run multiple times).
func runSchema(db *sql.DB) error {
// Check FTS5 support before creating tables
if err := checkFTS5Support(db); err != nil {
return fmt.Errorf("FTS5 check: %w", err)
}
stmts := []string{
`CREATE TABLE IF NOT EXISTS conversations (
conversation_id INTEGER PRIMARY KEY AUTOINCREMENT,
session_key TEXT NOT NULL UNIQUE,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
`CREATE TABLE IF NOT EXISTS messages (
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
role TEXT NOT NULL,
content TEXT NOT NULL DEFAULT '',
token_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
`CREATE TABLE IF NOT EXISTS message_parts (
part_id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id INTEGER NOT NULL REFERENCES messages(message_id),
type TEXT NOT NULL,
text TEXT,
name TEXT,
arguments TEXT,
tool_call_id TEXT,
media_uri TEXT,
mime_type TEXT,
ordinal INTEGER NOT NULL DEFAULT 0
)`,
`CREATE TABLE IF NOT EXISTS summaries (
summary_id TEXT PRIMARY KEY,
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
kind TEXT NOT NULL,
depth INTEGER NOT NULL DEFAULT 0,
content TEXT NOT NULL,
token_count INTEGER NOT NULL DEFAULT 0,
earliest_at TEXT,
latest_at TEXT,
descendant_count INTEGER NOT NULL DEFAULT 0,
descendant_token_count INTEGER NOT NULL DEFAULT 0,
source_message_token_count INTEGER NOT NULL DEFAULT 0,
model TEXT,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
`CREATE TABLE IF NOT EXISTS summary_parents (
summary_id TEXT NOT NULL,
parent_summary_id TEXT NOT NULL,
PRIMARY KEY (summary_id, parent_summary_id)
)`,
`CREATE TABLE IF NOT EXISTS summary_messages (
summary_id TEXT NOT NULL,
message_id INTEGER NOT NULL,
ordinal INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (summary_id, message_id)
)`,
`CREATE TABLE IF NOT EXISTS context_items (
conversation_id INTEGER NOT NULL,
ordinal INTEGER NOT NULL,
item_type TEXT NOT NULL,
summary_id TEXT,
message_id INTEGER,
token_count INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
PRIMARY KEY (conversation_id, ordinal)
)`,
// FTS5 virtual table with trigram tokenizer for CJK support
sqlCreateSummariesFTS,
// FTS5 virtual table for message search with trigram tokenizer
sqlCreateMessagesFTS,
// Indexes for common query patterns
`CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`,
`CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`,
`CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`,
`CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`,
`CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`,
`CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`,
`CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`,
// FTS5 triggers to keep summaries_fts in sync with summaries table
`CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
END`,
`CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
END`,
`CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
END`,
// FTS5 triggers to keep messages_fts in sync with messages table
`CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
END`,
`CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
DELETE FROM messages_fts WHERE message_id = old.message_id;
END`,
`CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
DELETE FROM messages_fts WHERE message_id = old.message_id;
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
END`,
}
for _, s := range stmts {
if _, err := db.Exec(s); err != nil {
return err
}
}
return nil
}
// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled.
// This is required for full-text search with CJK (Chinese, Japanese, Korean) support.
func checkFTS5Support(db *sql.DB) error {
// Check if FTS5 is compiled in
var fts5Enabled int
err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled)
if err != nil {
// sqlite_compileoption_used might not exist in older SQLite
// Try a different approach: create a test FTS5 table
_, testErr := db.Exec(sqlCheckFTS5Available)
if testErr != nil {
return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr)
}
db.Exec(sqlDropFTS5Check)
} else if fts5Enabled == 0 {
return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)")
}
// Check if trigram tokenizer is available by trying to create a test table
// Not all SQLite builds include the trigram tokenizer
_, err = db.Exec(sqlCheckTrigramAvailable)
if err != nil {
logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited",
map[string]any{"error": err.Error()})
// Trigram is not strictly required, just better for CJK
// Don't return error, just log warning
} else {
db.Exec(sqlDropTrigramCheck)
}
return nil
}
+211
View File
@@ -0,0 +1,211 @@
package seahorse
import (
"database/sql"
"testing"
_ "modernc.org/sqlite"
)
func openTestDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("open test db: %v", err)
}
t.Cleanup(func() { db.Close() })
return db
}
func TestRunMigrations(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("runSchema: %v", err)
}
// Verify all tables exist
tables := []string{
"conversations",
"messages",
"message_parts",
"summaries",
"summary_parents",
"summary_messages",
"context_items",
}
for _, tbl := range tables {
var name string
err := db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl,
).Scan(&name)
if err != nil {
t.Errorf("table %q not found: %v", tbl, err)
}
}
// Verify FTS5 virtual table exists
var ftsName string
err := db.QueryRow(
"SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'",
).Scan(&ftsName)
if err != nil {
t.Errorf("FTS5 table summaries_fts not found: %v", err)
}
}
func TestRunMigrationsIdempotent(t *testing.T) {
db := openTestDB(t)
// Run migrations twice — should succeed both times
if err := runSchema(db); err != nil {
t.Fatalf("first migration: %v", err)
}
if err := runSchema(db); err != nil {
t.Fatalf("second migration (idempotent): %v", err)
}
// Verify we can still insert data after double migration
res, err := db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"test-session",
)
if err != nil {
t.Fatalf("insert after double migration: %v", err)
}
id, _ := res.LastInsertId()
if id == 0 {
t.Error("expected non-zero conversation id")
}
}
func TestMigrationConversationUnique(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
// Insert first
_, err := db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"unique-key",
)
if err != nil {
t.Fatalf("first insert: %v", err)
}
// Duplicate should fail
_, err = db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"unique-key",
)
if err == nil {
t.Error("expected unique constraint violation for duplicate session_key")
}
}
func TestMigrationSummaryFTSInsert(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
// Insert a conversation first
_, err := db.Exec(
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
"fts-test",
)
if err != nil {
t.Fatalf("insert conversation: %v", err)
}
// Insert a summary
_, err = db.Exec(
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`)
if err != nil {
t.Fatalf("insert summary: %v", err)
}
// FTS should find it — trigram tokenizer requires >= 3 chars
rows, err := db.Query(
"SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?",
"你好世",
)
if err != nil {
t.Fatalf("FTS query: %v", err)
}
defer rows.Close()
var found string
if rows.Next() {
if err := rows.Scan(&found); err != nil {
t.Fatalf("scan: %v", err)
}
}
if err := rows.Err(); err != nil {
t.Fatalf("rows.Err: %v", err)
}
if found != "sum_test1" {
t.Errorf("FTS: expected 'sum_test1', got %q", found)
}
}
func TestMigrationSummaryParentsPK(t *testing.T) {
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
// Insert two summaries
for _, id := range []string{"sum_a", "sum_b"} {
_, err := db.Exec(
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id)
if err != nil {
t.Fatalf("insert summary %s: %v", id, err)
}
}
// Link child to parent
_, err := db.Exec(
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
if err != nil {
t.Fatalf("link: %v", err)
}
// Duplicate link should fail (composite PK)
_, err = db.Exec(
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
if err == nil {
t.Error("expected unique constraint violation for duplicate summary_parents link")
}
}
func TestFTS5SQLConstants(t *testing.T) {
db := openTestDB(t)
// Verify FTS5 check SQL executes without error
_, err := db.Exec(sqlCheckFTS5Available)
if err != nil {
t.Errorf("sqlCheckFTS5Available failed: %v", err)
}
// Verify trigram check SQL executes without error
_, err = db.Exec(sqlCheckTrigramAvailable)
if err != nil {
t.Errorf("sqlCheckTrigramAvailable failed: %v", err)
}
// Verify summaries_fts SQL executes without error
_, err = db.Exec(sqlCreateSummariesFTS)
if err != nil {
t.Errorf("sqlCreateSummariesFTS failed: %v", err)
}
// Verify messages_fts SQL executes without error
_, err = db.Exec(sqlCreateMessagesFTS)
if err != nil {
t.Errorf("sqlCreateMessagesFTS failed: %v", err)
}
}
+261
View File
@@ -0,0 +1,261 @@
package seahorse
import (
"context"
"fmt"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
// escapeXML escapes special characters for safe inclusion in XML content.
func escapeXML(s string) string {
s = strings.ReplaceAll(s, "&", "&amp;")
s = strings.ReplaceAll(s, "<", "&lt;")
s = strings.ReplaceAll(s, ">", "&gt;")
s = strings.ReplaceAll(s, "\"", "&quot;")
s = strings.ReplaceAll(s, "'", "&apos;")
return s
}
// resolvedItem is a context item resolved to its full content with token count.
type resolvedItem struct {
ordinal int
itemType string // "message" or "summary"
message *Message
summary *Summary
tokenCount int
}
// Assemble builds budget-constrained context from summaries + messages.
//
// Algorithm:
// 1. Fetch context_items, resolve to full content
// 2. Split into evictable prefix + protected fresh tail
// 3. If evictable fits in remaining budget → include all
// 4. Else walk evictable from newest to oldest, keep while fits
func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
items, err := a.store.GetContextItems(ctx, convID)
if err != nil {
return nil, fmt.Errorf("get context items: %w", err)
}
if len(items) == 0 {
return &AssembleResult{}, nil
}
// Resolve all items
resolved := make([]resolvedItem, len(items))
for i, item := range items {
r, err := a.resolveItem(ctx, item)
if err != nil {
return nil, err
}
resolved[i] = r
}
// Split into evictable prefix and protected fresh tail
tailStart := len(resolved) - FreshTailCount
if tailStart < 0 {
tailStart = 0
}
evictable := resolved[:tailStart]
freshTail := resolved[tailStart:]
// Calculate fresh tail tokens
freshTailTokens := 0
for _, r := range freshTail {
freshTailTokens += r.tokenCount
}
// Budget-aware selection of evictable items
remainingBudget := input.Budget - freshTailTokens
if remainingBudget < 0 {
// Fresh tail alone exceeds budget - we keep it anyway (design decision)
// Log for debugging retry/overflow issues
logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
"budget": input.Budget,
"fresh_tail_tokens": freshTailTokens,
"fresh_tail_count": len(freshTail),
"over_budget_by": freshTailTokens - input.Budget,
})
remainingBudget = 0
}
var selected []resolvedItem
evictableTokens := 0
for _, r := range evictable {
evictableTokens += r.tokenCount
}
if evictableTokens <= remainingBudget {
// All evictable fit
selected = append(selected, evictable...)
} else {
// Walk from newest to oldest, keep while fits
var kept []resolvedItem
accum := 0
for i := len(evictable) - 1; i >= 0; i-- {
if accum+evictable[i].tokenCount <= remainingBudget {
kept = append(kept, evictable[i])
accum += evictable[i].tokenCount
} else {
break
}
}
// Reverse to restore chronological order
for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
kept[i], kept[j] = kept[j], kept[i]
}
selected = append(selected, kept...)
}
// Combine: selected evictable + fresh tail
final := append(selected, freshTail...)
// Build result
var messages []Message
var summaries []Summary
var sourceIDs []string
totalTokens := 0
maxDepth := 0
condensedCount := 0
for _, r := range final {
totalTokens += r.tokenCount
if r.itemType == "message" && r.message != nil {
messages = append(messages, *r.message)
sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
} else if r.itemType == "summary" && r.summary != nil {
summaries = append(summaries, *r.summary)
if r.summary.Depth > maxDepth {
maxDepth = r.summary.Depth
}
if r.summary.Kind == SummaryKindCondensed {
condensedCount++
}
}
}
// Build depth-aware system prompt addition
systemPromptAddition := ""
if len(summaries) > 0 {
if maxDepth >= 2 || condensedCount >= 2 {
systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
"- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
"- When uncertain, use expand to recover original detail before making claims.\n" +
"- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
} else {
systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
}
}
// Build Summary field: all XML summaries + system prompt addition
var summaryParts []string
for _, sum := range summaries {
if sum.Content == "" {
continue
}
// Load parent IDs for XML formatting
parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
if err != nil {
logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
"summary_id": sum.SummaryID,
"error": err.Error(),
})
}
var parentIDs []string
for _, ps := range parentSummaries {
parentIDs = append(parentIDs, ps.SummaryID)
}
summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
}
summary := strings.Join(summaryParts, "\n\n")
if systemPromptAddition != "" {
if summary != "" {
summary += "\n\n"
}
summary += systemPromptAddition
}
return &AssembleResult{
Messages: messages,
Summary: summary,
}, nil
}
// resolveItem loads the full message or summary for a context item.
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
if item.ItemType == "message" {
msg, err := a.store.GetMessageByID(ctx, item.MessageID)
if err != nil {
return resolvedItem{}, err
}
tokens := item.TokenCount
if tokens == 0 {
tokens = msg.TokenCount
}
return resolvedItem{
ordinal: item.Ordinal,
itemType: "message",
message: msg,
tokenCount: tokens,
}, nil
}
if item.ItemType == "summary" {
sum, err := a.store.GetSummary(ctx, item.SummaryID)
if err != nil {
return resolvedItem{}, err
}
tokens := item.TokenCount
if tokens == 0 {
tokens = sum.TokenCount
}
return resolvedItem{
ordinal: item.Ordinal,
itemType: "summary",
summary: sum,
tokenCount: tokens,
}, nil
}
return resolvedItem{
ordinal: item.Ordinal,
itemType: item.ItemType,
tokenCount: item.TokenCount,
}, nil
}
// FormatSummaryXML formats a summary as XML for LLM context.
// This is exported so context managers can format summaries consistently.
func FormatSummaryXML(s *Summary, parentIDs []string) string {
// Build time attributes if available
var attrs string
if s.EarliestAt != nil {
attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
}
if s.LatestAt != nil {
attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
}
var parentsSection string
if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
parents := "<parents>\n"
for _, pid := range parentIDs {
parents += fmt.Sprintf(" <summary_ref id=\"%s\" />\n", pid)
}
parents += " </parents>\n"
parentsSection = parents
}
return fmt.Sprintf(
"<summary id=\"%s\" kind=\"%s\" depth=\"%d\" descendant_count=\"%d\"%s>\n <content>\n %s\n </content>\n%s</summary>",
s.SummaryID,
string(s.Kind),
s.Depth,
s.DescendantCount,
attrs,
escapeXML(s.Content),
parentsSection,
)
}
+536
View File
@@ -0,0 +1,536 @@
package seahorse
import (
"context"
"strings"
"testing"
"time"
)
// --- Assembler Tests ---
// helper: create a store with messages and summaries for assembly tests
func setupAssemblerStore(t *testing.T) (*Store, int64) {
t.Helper()
s := openTestStore(t)
ctx := context.Background()
conv, err := s.GetOrCreateConversation(ctx, "test:assemble")
if err != nil {
t.Fatalf("create conversation: %v", err)
}
return s, conv.ConversationID
}
func TestAssemblerAssembleEmpty(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if len(result.Messages) != 0 {
t.Errorf("Messages = %d, want 0", len(result.Messages))
}
if result.Summary != "" {
t.Errorf("Summary = %q, want empty", result.Summary)
}
}
func TestAssemblerAssembleMessagesOnly(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create messages
msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5)
// Create context items
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
{Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if len(result.Messages) != 2 {
t.Fatalf("Messages = %d, want 2", len(result.Messages))
}
if result.Messages[0].Content != "hello" {
t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content)
}
if result.Messages[1].Content != "world" {
t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content)
}
// No summaries, so Summary should be empty
if result.Summary != "" {
t.Errorf("Summary = %q, want empty", result.Summary)
}
}
func TestAssemblerAssembleWithSummary(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a summary
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "summary of early messages",
TokenCount: 50,
})
// Create recent messages
msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5)
msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5)
// Context: summary + recent messages
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50},
{Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
{Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Messages = 2 raw messages (summaries are in Summary field, not Messages)
if len(result.Messages) != 2 {
t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages))
}
// Summary should contain XML with summary content
if result.Summary == "" {
t.Error("Summary should not be empty when summary exists")
}
if !strings.Contains(result.Summary, summary.Content) {
t.Errorf("Summary should contain summary content %q", summary.Content)
}
if !strings.Contains(result.Summary, "<summary") {
t.Error("Summary should contain <summary XML tag")
}
}
func TestAssemblerBudgetEvictsOldest(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create 40 messages, each with 10 tokens = 400 total
msgs := make([]*Message, 40)
for i := 0; i < 40; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
msgs[i] = m
}
// Context items for all messages
items := make([]ContextItem, 40)
for i := 0; i < 40; i++ {
items[i] = ContextItem{
Ordinal: (i + 1) * 100,
ItemType: "message",
MessageID: msgs[i].ID,
TokenCount: 10,
}
}
s.UpsertContextItems(ctx, convID, items)
// Budget of 200 tokens with FreshTailCount=32
// Fresh tail = last 32 messages (320 tokens, over budget, but always included)
// Evictable = first 8 messages (80 tokens)
// Budget after tail: max(0, 200-320) = 0 → no evictable items included
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Should only include the 32-item fresh tail
if len(result.Messages) != 32 {
t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages))
}
// Should be the LAST 32 messages
if result.Messages[0].ID != msgs[8].ID {
t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID)
}
}
func TestAssemblerBudgetFitsAll(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
msgs := make([]*Message, 5)
for i := 0; i < 5; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
msgs[i] = m
}
items := make([]ContextItem, 5)
for i := 0; i < 5; i++ {
items[i] = ContextItem{
Ordinal: (i + 1) * 100,
ItemType: "message",
MessageID: msgs[i].ID,
TokenCount: 10,
}
}
s.UpsertContextItems(ctx, convID, items)
// Budget = 100, total = 50, FreshTailCount=32 → all items in tail
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if len(result.Messages) != 5 {
t.Errorf("Messages = %d, want 5", len(result.Messages))
}
}
func TestAssemblerSummaryXMLFormat(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "test summary content",
TokenCount: 20,
})
msg, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Messages should only contain raw messages (no XML summary in Messages)
if len(result.Messages) != 1 {
t.Errorf("Messages = %d, want 1 (raw message only)", len(result.Messages))
}
// Summary should contain XML with summary content
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
if !contains(result.Summary, "<summary") {
t.Errorf("Summary missing <summary tag: %q", result.Summary)
}
if !contains(result.Summary, summary.SummaryID) {
t.Errorf("Summary missing summary ID: %q", result.Summary)
}
}
func TestAssemblerSummaryXMLEscaping(t *testing.T) {
// Summary content with special XML characters should be properly escaped
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create summary with content containing XML special characters
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: `User said: "hello" & asked about <tags>`,
TokenCount: 20,
})
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Summary field should contain XML with escaped special characters
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
// Check that special characters are escaped
if strings.Contains(result.Summary, "<tags>") {
t.Errorf("BUG: unescaped < in summary content: %q", result.Summary)
}
if strings.Contains(result.Summary, `"hello"`) {
t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary)
}
// & should be escaped as &amp;
if strings.Contains(result.Summary, " & ") {
t.Errorf("BUG: unescaped & in summary content: %q", result.Summary)
}
}
func TestAssemblerSummaryXMLWithParents(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a leaf and a condensed summary (condensed has parent)
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 20,
})
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindCondensed,
Depth: 1,
Content: "condensed content",
TokenCount: 15,
ParentIDs: []string{leaf.SummaryID},
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Summary field should contain XML with parent information
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
xmlContent := result.Summary
// Should contain <parents> section with parent ID
if !contains(xmlContent, "<parents>") {
t.Errorf("condensed summary XML missing <parents> section: %q", xmlContent)
}
if !contains(xmlContent, leaf.SummaryID) {
t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent)
}
// Should contain kind="condensed"
if !contains(xmlContent, `kind="condensed"`) {
t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent)
}
}
func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a leaf summary with specific descendant count
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 20,
DescendantCount: 8,
DescendantTokenCount: 1200,
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
xmlContent := result.Summary
// Should contain descendant_count="8"
if !contains(xmlContent, `descendant_count="8"`) {
t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent)
}
}
func TestAssemblerLeafSummaryNoParents(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Leaf summary has no parents
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 20,
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
if result.Summary == "" {
t.Fatal("Summary should not be empty")
}
xmlContent := result.Summary
// Leaf summary should NOT have <parents> section
if contains(xmlContent, "<parents>") {
t.Errorf("leaf summary XML should not have <parents> section: %q", xmlContent)
}
}
func TestAssemblerDepthAwarePrompt(t *testing.T) {
s, convID := setupAssemblerStore(t)
ctx := context.Background()
// Create a condensed summary (depth >= 2) to trigger full guidance
now := time.Now().UTC()
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf summary",
TokenCount: 20,
EarliestAt: &now,
LatestAt: &now,
})
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindCondensed,
Depth: 2,
Content: "condensed summary",
TokenCount: 15,
ParentIDs: []string{leaf.SummaryID},
DescendantCount: 1,
DescendantTokenCount: 20,
})
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.UpsertContextItems(ctx, convID, []ContextItem{
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
})
a := &Assembler{store: s, config: Config{}}
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
if err != nil {
t.Fatalf("Assemble: %v", err)
}
// Should have a depth-aware prompt in Summary field
if result.Summary == "" {
t.Error("expected non-empty Summary when depth >= 2")
}
// SystemPromptAddition is embedded in Summary field
if !strings.Contains(result.Summary, "multi-level summarization") {
t.Error("Summary should contain system prompt addition about multi-level summarization")
}
}
func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) {
// Spec: condensed summaries use <summary_ref id="parentId" /> not <parent>parentId</parent>
now := time.Now().UTC()
s := Summary{
SummaryID: "sum_condensed1",
Kind: SummaryKindCondensed,
Depth: 1,
Content: "condensed content",
TokenCount: 50,
DescendantCount: 2,
EarliestAt: &now,
LatestAt: &now,
}
parentIDs := []string{"sum_leaf1", "sum_leaf2"}
xml := FormatSummaryXML(&s, parentIDs)
// Must use <summary_ref id="..." /> per spec
if !contains(xml, `<summary_ref id="sum_leaf1" />`) {
t.Errorf("expected <summary_ref id=\"sum_leaf1\" />, got: %s", xml)
}
if !contains(xml, `<summary_ref id="sum_leaf2" />`) {
t.Errorf("expected <summary_ref id=\"sum_leaf2\" />, got: %s", xml)
}
// Must NOT use old <parent> tag
if contains(xml, "<parent>") {
t.Errorf("should not use <parent> tag, got: %s", xml)
}
}
func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) {
// Spec: summary XML includes earliest_at and latest_at attributes
earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC)
latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC)
s := Summary{
SummaryID: "sum_leaf1",
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 30,
DescendantCount: 0,
EarliestAt: &earliest,
LatestAt: &latest,
}
xml := FormatSummaryXML(&s, nil)
if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) {
t.Errorf("missing earliest_at attribute, got: %s", xml)
}
if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) {
t.Errorf("missing latest_at attribute, got: %s", xml)
}
}
func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) {
// When EarliestAt/LatestAt are nil, attributes should be omitted
s := Summary{
SummaryID: "sum_leaf1",
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf content",
TokenCount: 30,
DescendantCount: 0,
}
xml := FormatSummaryXML(&s, nil)
if contains(xml, "earliest_at=") {
t.Errorf("should not have earliest_at when nil, got: %s", xml)
}
if contains(xml, "latest_at=") {
t.Errorf("should not have latest_at when nil, got: %s", xml)
}
}
+336
View File
@@ -0,0 +1,336 @@
package seahorse
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
_ "modernc.org/sqlite"
)
// newBenchStore creates a test store for benchmarks.
func newBenchStore(b *testing.B) (*Store, func()) {
b.Helper()
db, err := sql.Open("sqlite", ":memory:")
if err != nil {
b.Fatalf("open test db: %v", err)
}
if err := runSchema(db); err != nil {
db.Close()
b.Fatalf("migration: %v", err)
}
return &Store{db: db}, func() { db.Close() }
}
// --- Ingest benchmarks ---
func BenchmarkIngest_SingleMessage(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest")
convID := conv.ConversationID
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.AddMessage(ctx, convID, "user", "Test message content", 15)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkIngest_BatchMessages(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i))
convID := conv.ConversationID
for j := 0; j < 10; j++ {
added, err := s.AddMessage(ctx, convID, "user",
fmt.Sprintf("Message %d in batch", j), 10)
if err != nil {
b.Fatal(err)
}
s.AppendContextMessage(ctx, convID, added.ID)
}
}
}
// --- Assemble benchmarks ---
func BenchmarkAssemble_MessagesOnly(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs")
convID := conv.ConversationID
// Add 100 messages
for i := 0; i < 100; i++ {
m, _ := s.AddMessage(ctx, convID, "user",
fmt.Sprintf("Message content %d with some text", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
a := &Assembler{store: s}
input := AssembleInput{Budget: 50000}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := a.Assemble(ctx, convID, input)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAssemble_WithSummaries(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums")
convID := conv.ConversationID
now := time.Now().UTC()
// Add 10 leaf summaries
for i := 0; i < 10; i++ {
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("Leaf summary %d", i),
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, sum.SummaryID)
}
// Add 20 fresh messages
for i := 0; i < 20; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
a := &Assembler{store: s}
input := AssembleInput{Budget: 10000}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := a.Assemble(ctx, convID, input)
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAssemble_BudgetEviction(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict")
convID := conv.ConversationID
now := time.Now().UTC()
// Add 50 leaf summaries (more than budget can hold)
for i := 0; i < 50; i++ {
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("Summary %d", i),
TokenCount: 300,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, sum.SummaryID)
}
// Add fresh tail
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
a := &Assembler{store: s}
input := AssembleInput{Budget: 5000} // Force eviction
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := a.Assemble(ctx, convID, input)
if err != nil {
b.Fatal(err)
}
}
}
// --- Search (FTS5) benchmarks ---
// benchSeedSummaries adds n summaries to a conversation for search benchmarks.
func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) {
b.Helper()
now := time.Now().UTC()
for i := 0; i < n; i++ {
sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf(contentTpl, i),
TokenCount: 200,
EarliestAt: &now,
LatestAt: &now,
})
if err != nil {
b.Fatalf("create summary: %v", err)
}
s.AppendContextSummary(context.Background(), convID, sum.SummaryID)
}
}
func BenchmarkSearchSummaries_FTS5(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts")
convID := conv.ConversationID
benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.SearchSummaries(ctx, SearchInput{
Pattern: "database",
Mode: "full_text",
ConversationID: convID,
})
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSearchSummaries_Like(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like")
convID := conv.ConversationID
benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.SearchSummaries(ctx, SearchInput{
Pattern: "config",
Mode: "like",
ConversationID: convID,
})
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSearchMessages_FTS5(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts")
convID := conv.ConversationID
// Add 500 messages
for i := 0; i < 500; i++ {
m, _ := s.AddMessage(ctx, convID, "user",
fmt.Sprintf("User message about API and database integration %d", i), 20)
s.AppendContextMessage(ctx, convID, m.ID)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := s.SearchMessages(ctx, SearchInput{
Pattern: "API database",
Mode: "full_text",
ConversationID: convID,
})
if err != nil {
b.Fatal(err)
}
}
}
// --- Bootstrap benchmarks ---
func BenchmarkBootstrap_Empty(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i))
convID := conv.ConversationID
_ = convID // Bootstrap with empty history
}
}
func BenchmarkBootstrap_100Messages(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
// Prepare 100 messages
msgs := make([]Message, 100)
for i := 0; i < 100; i++ {
msgs[i] = Message{
Role: "user",
Content: fmt.Sprintf("Bootstrap message %d", i),
TokenCount: 15,
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i))
convID := conv.ConversationID
for _, m := range msgs {
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
s.AppendContextMessage(ctx, convID, added.ID)
}
}
}
func BenchmarkBootstrap_500Messages(b *testing.B) {
s, cleanup := newBenchStore(b)
defer cleanup()
ctx := context.Background()
msgs := make([]Message, 500)
for i := 0; i < 500; i++ {
msgs[i] = Message{
Role: "user",
Content: fmt.Sprintf("Bootstrap message %d", i),
TokenCount: 15,
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i))
convID := conv.ConversationID
for _, m := range msgs {
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
s.AppendContextMessage(ctx, convID, added.ID)
}
}
}
+898
View File
@@ -0,0 +1,898 @@
package seahorse
import (
"context"
"fmt"
"sort"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// CompactInput controls compaction behavior.
type CompactInput struct {
Budget *int // Token budget override
Force bool // Force compaction even if below threshold
}
// CompactResult describes what was compacted.
type CompactResult struct {
SummariesCreated []string `json:"summariesCreated"`
TokensSaved int `json:"tokensSaved"`
LeafSummaries int `json:"leafSummaries"`
CondensedSummaries int `json:"condensedSummaries"`
}
// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow.
func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) {
tokens, err := e.store.GetContextTokenCount(ctx, convID)
if err != nil {
return false, fmt.Errorf("get token count: %w", err)
}
threshold := int(float64(contextWindow) * ContextThreshold)
return tokens >= threshold, nil
}
// Close cancels the shutdown context, stopping async goroutines.
func (e *CompactionEngine) Close() {
if e.shutdownCancel != nil {
e.shutdownCancel()
}
}
// Compact runs leaf compaction (sync) and optionally condensed compaction.
func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) {
result := &CompactResult{}
// Phase 1: leaf compaction (synchronous, every turn)
summaryID, err := e.compactLeaf(ctx, convID)
if err != nil {
return nil, fmt.Errorf("compact leaf: %w", err)
}
if summaryID != nil {
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
result.LeafSummaries++
logger.InfoCF("seahorse", "compact: leaf", map[string]any{
"conv_id": convID,
"summary_id": *summaryID,
})
}
// Phase 2: condensed compaction if over threshold
tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID)
var budget int
if input.Budget != nil {
budget = *input.Budget
if budget == 0 {
logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{
"conv_id": convID,
})
}
} else {
budget = int(float64(tokensBefore) * ContextThreshold)
}
if input.Force || (tokensBefore > budget && budget > 0) {
// Launch async condensed compaction with dedup
if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded {
go func() {
defer e.condensing.Delete(convID)
e.runCondensedLoop(e.shutdownCtx, convID)
}()
}
}
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
if tokensAfter < tokensBefore {
result.TokensSaved = tokensBefore - tokensAfter
}
return result, nil
}
// CompactUntilUnder aggressively compacts until context is under budget.
func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) {
result := &CompactResult{}
prevTokens := 0
logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget})
for iter := 0; iter < MaxCompactIterations; iter++ {
tokens, err := e.store.GetContextTokenCount(ctx, convID)
if err != nil {
return result, fmt.Errorf("get tokens: %w", err)
}
if tokens <= budget {
logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{
"conv_id": convID,
"budget": budget,
"tokens": tokens,
"leaf": result.LeafSummaries,
"condensed": result.CondensedSummaries,
})
return result, nil
}
// Try leaf first
summaryID, err := e.compactLeaf(ctx, convID, true)
if err != nil {
return result, err
}
if summaryID != nil {
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
result.LeafSummaries++
logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{
"conv_id": convID,
"summary_id": *summaryID,
})
continue
}
// Try condensed with forced fanout
condensedID, err := e.compactCondensed(ctx, convID)
if err != nil {
return result, err
}
if condensedID != nil {
result.SummariesCreated = append(result.SummariesCreated, *condensedID)
result.CondensedSummaries++
logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{
"conv_id": convID,
"summary_id": *condensedID,
})
continue
}
// No progress
newTokens, _ := e.store.GetContextTokenCount(ctx, convID)
if newTokens >= prevTokens {
logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{
"conv_id": convID,
"tokens": newTokens,
})
return result, nil
}
prevTokens = newTokens
}
// Safety cap exceeded — see MaxCompactIterations doc for rationale.
logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{
"conv_id": convID,
"budget": budget,
"iterations": MaxCompactIterations,
"tokens": prevTokens,
})
return result, nil
}
// compactLeaf compresses the oldest contiguous message chunk into a leaf summary.
// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder).
func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) {
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
// Find oldest contiguous message chunk outside fresh tail
msgCount := 0
msgTokens := 0
for _, item := range items {
if item.ItemType == "message" {
msgCount++
msgTokens += item.TokenCount
}
}
// Trigger if either message count or token threshold is met
if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens {
return nil, nil
}
// Calculate fresh tail boundary (bypass when forced)
useForce := len(force) > 0 && force[0]
tailStartIdx := len(items) - FreshTailCount
if useForce {
tailStartIdx = len(items) // allow compacting everything
}
if tailStartIdx < 0 {
tailStartIdx = 0
}
// Find oldest contiguous message chunk, accumulating up to LeafChunkTokens
var chunk []ContextItem
chunkStart := -1
chunkEnd := -1
accumTokens := 0
for i := 0; i < tailStartIdx; i++ {
if items[i].ItemType == "message" {
if chunkStart == -1 {
chunkStart = i
}
chunkEnd = i
accumTokens += items[i].TokenCount
// Stop accumulating once we reach the token budget
if accumTokens >= LeafChunkTokens {
break
}
} else {
// Non-message breaks the chunk
if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout {
break
}
chunkStart = -1
chunkEnd = -1
accumTokens = 0
}
}
if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout {
return nil, nil
}
chunk = items[chunkStart : chunkEnd+1]
// Collect messages for the chunk
var messages []Message
for _, item := range chunk {
msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID)
if innerErr != nil {
return nil, innerErr
}
messages = append(messages, *msg)
}
// Get prior summaries for context
priorSummary := ""
priorCount := 0
for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- {
if items[i].ItemType == "summary" {
sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID)
if innerErr2 == nil {
priorSummary = sum.Content + "\n" + priorSummary
priorCount++
}
}
}
// Generate summary
content, err := e.generateLeafSummary(ctx, messages, priorSummary)
if err != nil {
return nil, err
}
// Create summary in store
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
var earliestAt, latestAt *time.Time
if len(messages) > 0 {
earliestAt = &messages[0].CreatedAt
latestAt = &messages[len(messages)-1].CreatedAt
}
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: content,
TokenCount: tokenCount,
EarliestAt: earliestAt,
LatestAt: latestAt,
SourceMessageTokens: sumMessageTokens(messages),
})
if err != nil {
return nil, err
}
// Link to source messages
msgIDs := make([]int64, len(messages))
for i, m := range messages {
msgIDs[i] = m.ID
}
if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil {
return nil, err
}
// Replace context range with summary
if err := e.store.ReplaceContextRangeWithSummary(
ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID,
); err != nil {
return nil, err
}
return &summary.SummaryID, nil
}
// compactCondensed compresses multiple summaries into one higher-level summary.
func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) {
// Try ordinal-aware selection first (respects consecutive ordering)
var candidates []Summary
depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0)
if err != nil {
return nil, err
}
for _, depth := range depths {
var chunkAtDepth []Summary
var err2 error
chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth)
if err2 != nil {
continue
}
if len(chunkAtDepth) > 0 {
candidates = chunkAtDepth
break
}
}
// Fallback to depth-grouping selection
if len(candidates) == 0 {
candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false)
if err != nil {
return nil, err
}
}
if len(candidates) == 0 {
return nil, nil
}
// Generate condensed summary
content, err := e.generateCondensedSummary(ctx, candidates)
if err != nil {
return nil, err
}
// Merge metadata
maxDepth := 0
descendantCount := 0
descendantTokenCount := 0
sourceMessageTokens := 0
var earliestAt, latestAt *time.Time
parentIDs := make([]string, len(candidates))
for i, c := range candidates {
parentIDs[i] = c.SummaryID
if c.Depth > maxDepth {
maxDepth = c.Depth
}
descendantCount += c.DescendantCount + 1
descendantTokenCount += c.TokenCount + c.DescendantTokenCount
sourceMessageTokens += c.SourceMessageTokenCount
if c.EarliestAt != nil {
if earliestAt == nil || c.EarliestAt.Before(*earliestAt) {
earliestAt = c.EarliestAt
}
}
if c.LatestAt != nil {
if latestAt == nil || c.LatestAt.After(*latestAt) {
latestAt = c.LatestAt
}
}
}
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindCondensed,
Depth: maxDepth + 1,
Content: content,
TokenCount: tokenCount,
EarliestAt: earliestAt,
LatestAt: latestAt,
DescendantCount: descendantCount,
DescendantTokenCount: descendantTokenCount,
SourceMessageTokens: sourceMessageTokens,
ParentIDs: parentIDs,
})
if err != nil {
return nil, err
}
// Find the ordinal range for the candidate summaries in context
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
candidateSet := make(map[string]bool)
for _, c := range candidates {
candidateSet[c.SummaryID] = true
}
startOrd := -1
endOrd := -1
hasNonCandidate := false
for _, item := range items {
if item.ItemType == "summary" && candidateSet[item.SummaryID] {
if startOrd == -1 {
startOrd, endOrd = item.Ordinal, item.Ordinal
} else {
// Check for non-candidate items between endOrd and current ordinal
for _, it := range items {
if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal {
if it.ItemType != "summary" || !candidateSet[it.SummaryID] {
hasNonCandidate = true
break
}
}
}
if hasNonCandidate {
break
}
if item.Ordinal < startOrd {
startOrd = item.Ordinal
}
if item.Ordinal > endOrd {
endOrd = item.Ordinal
}
}
}
}
if startOrd == -1 || endOrd == -1 {
return nil, nil
}
// Collect candidate summary IDs
candidateIDs := make([]string, 0, len(candidates))
for _, c := range candidates {
candidateIDs = append(candidateIDs, c.SummaryID)
}
if hasNonCandidate {
// Use safe per-item deletion to avoid deleting non-candidate items
if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil {
return nil, err
}
} else {
// Candidates are consecutive, use efficient range deletion
if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil {
return nil, err
}
}
return &summary.SummaryID, nil
}
// selectShallowestCondensationCandidate finds the shallowest consecutive summary group.
func (e *CompactionEngine) selectShallowestCondensationCandidate(
ctx context.Context, convID int64, forced bool,
) ([]Summary, error) {
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
// Group by depth, find consecutive runs
tailStartIdx := len(items) - FreshTailCount
if tailStartIdx < 0 {
tailStartIdx = 0
}
minFanout := CondensedMinFanout
if forced {
minFanout = CondensedMinFanoutHard
}
// Track depth groups
depthGroups := make(map[int][]ContextItem)
for i := 0; i < tailStartIdx; i++ {
item := items[i]
if item.ItemType != "summary" {
continue
}
sum, err := e.store.GetSummary(ctx, item.SummaryID)
if err != nil {
continue
}
depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item)
}
// Find shallowest depth with enough candidates
// Collect all depths and sort to handle non-consecutive depths
var depths []int
for depth := range depthGroups {
depths = append(depths, depth)
}
sort.Ints(depths)
for _, depth := range depths {
group := depthGroups[depth]
if len(group) >= minFanout {
// Load summaries
var result []Summary
for _, item := range group[:minFanout] {
sum, err := e.store.GetSummary(ctx, item.SummaryID)
if err != nil {
continue
}
result = append(result, *sum)
}
return result, nil
}
}
return nil, nil
}
// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive
// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or
// token overflow. Returns contiguous chunk of summaries.
func (e *CompactionEngine) selectOldestChunkAtDepth(
ctx context.Context, convID int64, targetDepth int,
) ([]Summary, error) {
items, err := e.store.GetContextItems(ctx, convID)
if err != nil {
return nil, err
}
tailStartIdx := len(items) - FreshTailCount
if tailStartIdx < 0 {
tailStartIdx = 0
}
var chunk []Summary
accumTokens := 0
for i := 0; i < tailStartIdx; i++ {
item := items[i]
if item.ItemType != "summary" {
// Non-summary breaks the chunk
break
}
sum, err := e.store.GetSummary(ctx, item.SummaryID)
if err != nil {
break
}
if sum.Depth != targetDepth {
// Different depth breaks the chunk
break
}
if accumTokens+sum.TokenCount > LeafChunkTokens {
// Token overflow stops collection
break
}
chunk = append(chunk, *sum)
accumTokens += sum.TokenCount
}
// Min tokens check: spec line 808
// chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000
minTokens := CondensedTargetTokens // 2000
if accumTokens < minTokens {
return nil, nil
}
return chunk, nil
}
// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation.
// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation.
func (e *CompactionEngine) generateLeafSummary(
ctx context.Context,
messages []Message,
previousSummary string,
) (string, error) {
if e.complete == nil {
return truncateSummary(messages), nil
}
sourceText := formatMessagesForSummary(messages)
inputTokens := sumMessageTokens(messages)
targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35))
// Level 1: normal prompt
prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens)
content, err := e.complete(ctx, prompt, CompleteOptions{
MaxTokens: LeafTargetTokens * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content == "" {
// Retry with temperature=0
content, err = e.complete(ctx, prompt, CompleteOptions{
MaxTokens: LeafTargetTokens * 2,
Temperature: 0,
})
if err != nil {
return "", err
}
}
// Check if level 1 succeeded
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
return content, nil
}
// Level 2: aggressive prompt
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget)
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
MaxTokens: aggressiveTarget * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content == "" {
// Retry with temperature=0
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
MaxTokens: aggressiveTarget * 2,
Temperature: 0,
})
if err != nil {
return "", err
}
}
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
return content, nil
}
// Level 3: deterministic truncation
return truncateSummary(messages), nil
}
// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation.
func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) {
if e.complete == nil {
return truncateCondensedSummaries(summaries), nil
}
sourceText := formatSummariesForCondensation(summaries)
inputTokens := sumSummaryTokens(summaries)
targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35))
// Level 1: normal prompt
prompt := buildCondensedSummaryPrompt(sourceText, targetTokens)
content, err := e.complete(ctx, prompt, CompleteOptions{
MaxTokens: CondensedTargetTokens * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content == "" {
content, err = e.complete(ctx, prompt, CompleteOptions{
MaxTokens: CondensedTargetTokens * 2,
Temperature: 0,
})
if err != nil {
return "", err
}
}
if content != "" {
return content, nil
}
// Level 2: aggressive prompt
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget)
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
MaxTokens: aggressiveTarget * 2,
Temperature: 0.3,
})
if err != nil {
return "", err
}
if content != "" {
return content, nil
}
// Level 3: deterministic fallback
return truncateCondensedSummaries(summaries), nil
}
// runCondensedLoop runs condensed compaction in a loop until:
// a) context tokens <= threshold (success), OR
// b) No candidate found (nothing to condense), OR
// c) tokensAfter >= tokensBefore (no progress this iteration), OR
// d) tokensAfter >= previousTokens (no improvement over last iteration)
func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) {
var prevTokens int
for {
select {
case <-ctx.Done():
return
default:
}
tokensBefore, err := e.store.GetContextTokenCount(ctx, convID)
if err != nil {
logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()})
return
}
condensedID, err := e.compactCondensed(ctx, convID)
if err != nil {
logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()})
return
}
if condensedID == nil {
// No candidate found
logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID})
return
}
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
if tokensAfter >= tokensBefore {
// No progress this iteration
logger.DebugCF(
"seahorse",
"condensed: no progress",
map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter},
)
return
}
if tokensAfter >= prevTokens && prevTokens > 0 {
// No improvement over last iteration
logger.DebugCF(
"seahorse",
"condensed: no improvement",
map[string]any{"conv_id": convID, "tokens": tokensAfter},
)
return
}
prevTokens = tokensAfter
}
}
// --- Helper functions ---
func formatMessagesForSummary(messages []Message) string {
var result string
for _, m := range messages {
ts := m.CreatedAt.Format("2006-01-02 15:04 MST")
content := m.Content
if content == "" && len(m.Parts) > 0 {
content = partsToReadableContent(m.Parts)
}
result += fmt.Sprintf("[%s]\n%s\n\n", ts, content)
}
return result
}
func formatSummariesForCondensation(summaries []Summary) string {
var result string
for _, s := range summaries {
earliest := ""
if s.EarliestAt != nil {
earliest = s.EarliestAt.Format("2006-01-02")
}
latest := ""
if s.LatestAt != nil {
latest = s.LatestAt.Format("2006-01-02")
}
result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content)
}
return result
}
func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
prev := "(none)"
if previousSummary != "" {
prev = previousSummary
}
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
Treat this as incremental memory compaction input, not a full-conversation summary.
Normal summary policy:
- Preserve key decisions, rationale, constraints, and active tasks.
- Keep essential technical details needed to continue work safely.
- Remove obvious repetition and conversational filler.
Output requirements:
- Plain text only.
- No preamble, headings, or markdown formatting.
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
- If no file operations appear, include exactly: "Files: none".
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
- Target length: about %d tokens or less.
<previous_context>
%s
</previous_context>
<conversation_segment>
%s
</conversation_segment>`, targetTokens, prev, sourceText)
}
func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string {
return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary.
Preserve all important decisions, constraints, and outcomes.
Merge overlapping topics. Keep technical details intact.
Output requirements:
- Plain text only.
- No preamble, headings, or markdown formatting.
- End with exactly: "Expand for details about: <comma-separated list>".
- Target length: about %d tokens or less.
<summaries>
%s
</summaries>`, targetTokens, sourceText)
}
func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
prev := "(none)"
if previousSummary != "" {
prev = previousSummary
}
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
Aggressive summary policy:
- Keep only durable facts and current task state.
- Remove examples, repetition, and low-value narrative details.
- Preserve explicit TODOs, blockers, decisions, and constraints.
Output requirements:
- Plain text only.
- No preamble, headings, or markdown formatting.
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
- If no file operations appear, include exactly: "Files: none".
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
- Target length: about %d tokens or less.
<previous_context>
%s
</previous_context>
<conversation_segment>
%s
</conversation_segment>`, targetTokens, prev, sourceText)
}
func truncateSummary(messages []Message) string {
content := ""
for _, m := range messages {
c := m.Content
if c == "" && len(m.Parts) > 0 {
c = partsToReadableContent(m.Parts)
}
content += c + "\n"
}
if len(content) > 2048 {
content = content[:2048]
}
content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages))
return content
}
func truncateCondensedSummaries(summaries []Summary) string {
content := ""
for _, s := range summaries {
content += s.Content + "\n"
}
if len(content) > 2048 {
content = content[:2048]
}
content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries))
return content
}
func sumMessageTokens(messages []Message) int {
total := 0
for _, m := range messages {
total += m.TokenCount
}
return total
}
func sumSummaryTokens(summaries []Summary) int {
total := 0
for _, s := range summaries {
total += s.TokenCount
}
return total
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
+974
View File
@@ -0,0 +1,974 @@
package seahorse
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
// --- Test Helpers ---
// waitForCondensed blocks until the async condensed goroutine for convID finishes.
// Returns false if timeout is reached.
func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if _, exists := ce.condensing.Load(convID); !exists {
return true
}
time.Sleep(50 * time.Millisecond)
}
return false
}
// --- Compaction Tests ---
func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) {
t.Helper()
db := openTestDB(t)
if err := runSchema(db); err != nil {
t.Fatalf("migration: %v", err)
}
s := &Store{db: db}
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:compact")
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
ce := &CompactionEngine{
store: s,
config: Config{},
complete: mockCompleteFn,
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
}
convID := conv.ConversationID
// Ensure async goroutines are stopped before database is closed.
// Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close().
t.Cleanup(func() {
shutdownCancel()
// Wait for async condensed goroutine to finish (poll condensing map)
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if _, exists := ce.condensing.Load(convID); !exists {
break
}
time.Sleep(50 * time.Millisecond)
}
})
return ce, s, conv.ConversationID
}
// newTestCompactionEngineWithStore creates a CompactionEngine with existing store.
// Note: Caller is responsible for calling shutdownCancel when test ends.
func newTestCompactionEngineWithStore(
s *Store, complete CompleteFn,
) (ce *CompactionEngine, shutdownCancel context.CancelFunc) {
shutdownCtx, cancel := context.WithCancel(context.Background())
return &CompactionEngine{
store: s,
config: Config{},
complete: complete,
shutdownCtx: shutdownCtx,
shutdownCancel: cancel,
}, cancel
}
// mockCompleteFn returns a simple summary for testing
var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "Mock summary of the conversation segment.", nil
}
func TestNeedsCompaction(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Empty context — no compaction needed
needed, err := ce.NeedsCompaction(ctx, convID, 10000)
if err != nil {
t.Fatalf("NeedsCompaction: %v", err)
}
if needed {
t.Error("expected no compaction for empty context")
}
// Add messages to context, total tokens = 8000
for i := 0; i < 8; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction
needed, err = ce.NeedsCompaction(ctx, convID, 10000)
if err != nil {
t.Fatalf("NeedsCompaction: %v", err)
}
if !needed {
t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)")
}
// Below threshold: 5000 / 10000 → no compaction
s.UpsertContextItems(ctx, convID, nil) // clear
for i := 0; i < 5; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "test", 1000)
s.AppendContextMessage(ctx, convID, m.ID)
}
needed, _ = ce.NeedsCompaction(ctx, convID, 10000)
if needed {
t.Error("expected no compaction at 5000/10000 tokens")
}
}
func TestCompactLeaf(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create enough messages to trigger leaf compaction:
// Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous
for i := 0; i < 40; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Compact
result, err := ce.Compact(ctx, convID, CompactInput{})
if err != nil {
t.Fatalf("Compact: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
// Should have created at least one leaf summary
if result.LeafSummaries == 0 {
t.Error("expected at least 1 leaf summary")
}
// Context should now contain a summary item
items, _ := s.GetContextItems(ctx, convID)
foundSummary := false
for _, item := range items {
if item.ItemType == "summary" {
foundSummary = true
break
}
}
if !foundSummary {
t.Error("expected a summary in context_items after leaf compaction")
}
// Some messages should have been replaced
if len(result.SummariesCreated) == 0 {
t.Error("expected at least 1 summary created")
}
}
func TestCompactLeafNoCandidate(t *testing.T) {
ce, _, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Too few messages to trigger leaf compaction
m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10)
ce.store.AppendContextMessage(ctx, convID, m.ID)
result, err := ce.Compact(ctx, convID, CompactInput{})
if err != nil {
t.Fatalf("Compact: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result even with no candidate")
}
if result.LeafSummaries != 0 {
t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries)
}
}
func TestCompactCondensed(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create enough leaf summaries and fresh messages to enable condensation
leafIDs := make([]string, CondensedMinFanout)
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
summary, err := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf summary content " + time.Now().String(),
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
if err != nil {
t.Fatalf("CreateSummary %d: %v", i, err)
}
leafIDs[i] = summary.SummaryID
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add enough fresh messages to have a fresh tail (>= FreshTailCount)
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Compact with force to trigger condensation
_, err := ce.Compact(ctx, convID, CompactInput{Force: true})
if err != nil {
t.Fatalf("Compact: %v", err)
}
// Wait for async condensed goroutine to complete
if !waitForCondensed(ce, convID, 2*time.Second) {
t.Fatal("timeout waiting for condensed compaction")
}
// Should have created a condensed summary in the DB
summaries, _ := s.GetSummariesByConversation(ctx, convID)
foundCondensed := false
for _, sum := range summaries {
if sum.Kind == SummaryKindCondensed {
foundCondensed = true
break
}
}
if !foundCondensed {
t.Error("expected at least 1 condensed summary")
}
}
func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) {
// Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed
// from context_items between candidate selection and ordinal range scan.
// Use a slow CompleteFn with barrier sync to control timing.
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race")
convID := conv.ConversationID
// Create leaf summaries with enough tokens for condensation
var leafIDs []string
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
sum, err := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf summary %d", i),
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
if err != nil {
t.Fatalf("CreateSummary: %v", err)
}
leafIDs = append(leafIDs, sum.SummaryID)
s.AppendContextSummary(ctx, convID, sum.SummaryID)
}
// Add fresh tail so leaf summaries are in evictable range
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Barrier: CompleteFn waits until test removes context_items, then returns
var barrier1, barrier2 sync.WaitGroup
barrier1.Add(1) // CompleteFn signals when called
barrier2.Add(1) // test signals when context_items removed
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
barrier1.Done() // signal: LLM called, candidates selected
barrier2.Wait() // wait: test removes context_items
return "Condensed summary.", nil
}
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
t.Cleanup(func() {
cancel()
time.Sleep(100 * time.Millisecond)
})
// Run compactCondensed in background
type compactResult struct {
summaryID *string
err error
}
resultCh := make(chan compactResult, 1)
go func() {
sid, err := ce.compactCondensed(context.Background(), convID)
resultCh <- compactResult{summaryID: sid, err: err}
}()
// Wait for CompleteFn to be called (candidates selected)
barrier1.Wait()
// Remove leaf summaries from context_items (simulating concurrent replacement)
items, _ := s.GetContextItems(ctx, convID)
var preserved []ContextItem
for _, item := range items {
isLeaf := false
for _, lid := range leafIDs {
if item.SummaryID == lid {
isLeaf = true
break
}
}
if !isLeaf {
preserved = append(preserved, item)
}
}
s.UpsertContextItems(ctx, convID, preserved)
// Let CompleteFn return
barrier2.Done()
// Get result
res := <-resultCh
if res.err != nil {
t.Fatalf("compactCondensed: %v", res.err)
}
// With the bug: returns non-nil summaryID even though context_items has no matching ordinals
// The fix: should return nil when startOrd == -1
if res.summaryID != nil {
t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID)
// Verify the orphan exists in DB
summary, _ := s.GetSummary(context.Background(), *res.summaryID)
if summary != nil && summary.Kind == SummaryKindCondensed {
// Check it's NOT in context_items (orphan)
items2, _ := s.GetContextItems(context.Background(), convID)
found := false
for _, item := range items2 {
if item.SummaryID == *res.summaryID {
found = true
break
}
}
if !found {
t.Error("condensed summary exists in DB but not in context_items — orphan confirmed")
}
}
}
}
func TestCompactUntilUnder(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create many leaf summaries to ensure we can condense
for i := 0; i < 8; i++ {
now := time.Now().UTC()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf summary for condensation test",
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Force compact until under budget
result, err := ce.CompactUntilUnder(ctx, convID, 2000)
if err != nil {
t.Fatalf("CompactUntilUnder: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
}
func TestSelectShallowestCondensationCandidate(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create enough leaf summaries + fresh messages for candidates
for i := 0; i < LeafMinFanout; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf",
TokenCount: 100,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail messages so summaries are in evictable range
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.AppendContextMessage(ctx, convID, m.ID)
}
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
if err != nil {
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
}
// Should find leaf summaries at depth 0
if len(candidates) < CondensedMinFanout {
t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout)
}
}
func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) {
ce, _, convID := newTestCompactionEngine(t)
ctx := context.Background()
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
if err != nil {
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
}
if len(candidates) != 0 {
t.Errorf("candidates = %d, want 0 for empty context", len(candidates))
}
}
func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) {
// Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth
// rather than just grouping by depth without regard to order
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create interleaved summaries at depth 0 with a message in between:
// sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300)
for i := 0; i < LeafMinFanout+2; i++ {
now := time.Now().UTC()
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf summary %d", i),
TokenCount: 100,
EarliestAt: &now,
LatestAt: &now,
})
}
// Insert a message between first two summaries to break contiguity
// for selectShallowestCondensationCandidate but would still find all 3
// but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3)
msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5)
s.AppendContextMessage(ctx, convID, msg.ID)
// Run compactCondensed
result, err := ce.compactCondensed(ctx, convID)
if err != nil {
t.Fatalf("compactCondensed: %v", err)
}
// The result should have merged the two summaries at the start
// (skipping the message in between), This proves ordinal-aware selection works.
_ = result // verify summary was created
if result != nil {
summaries, _ := s.GetSummariesByConversation(ctx, convID)
found := false
for _, sum := range summaries {
if sum.Kind == SummaryKindCondensed {
found = true
break
}
}
if !found {
t.Error("expected condensed summary to be created via ordinal-aware selection")
}
}
}
func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens)
for i := 0; i < 5; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf summary %d", i),
TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000)
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.AppendContextMessage(ctx, convID, m.ID)
}
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if err != nil {
t.Fatalf("selectOldestChunkAtDepth: %v", err)
}
if len(chunk) < 2 {
t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk))
}
for _, s := range chunk {
if s.Depth != 0 {
t.Errorf("got depth %d, want 0", s.Depth)
}
}
}
func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create 3 summaries, then a message, then 3 more summaries
for i := 0; i < 3; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf %d", i),
TokenCount: 100,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
msg, _ := s.AddMessage(ctx, convID, "user", "break", 10)
s.AppendContextMessage(ctx, convID, msg.ID)
for i := 0; i < 3; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("leaf-after %d", i),
TokenCount: 100,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
s.AppendContextMessage(ctx, convID, m.ID)
}
chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if len(chunk) > 3 {
t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk))
}
}
func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create summaries with very low token counts (total < 2000)
for i := 0; i < 5; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("tiny summary %d", i),
TokenCount: 50, // very small
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail to protect from compaction
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Should return nil because total tokens (250) < 2000 minimum
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if err != nil {
t.Fatalf("selectOldestChunkAtDepth: %v", err)
}
if len(chunk) > 0 {
t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk))
}
}
func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) {
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create summaries with enough tokens (total >= 2000)
for i := 0; i < 5; i++ {
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf(
"substantial summary with enough content to meet minimum token threshold for condensation candidate %d",
i,
),
TokenCount: 500, // 5 × 500 = 2500 >= 2000
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
// Add fresh tail
for i := 0; i < FreshTailCount+1; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Should return chunk because total tokens (2500) >= 2000
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
if err != nil {
t.Fatalf("selectOldestChunkAtDepth: %v", err)
}
if len(chunk) == 0 {
t.Error("expected non-empty chunk when tokens >= 2000")
}
}
func TestGenerateLeafSummary(t *testing.T) {
ce, _, _ := newTestCompactionEngine(t)
ctx := context.Background()
msgs := []Message{
{Role: "user", Content: "hello world", TokenCount: 5},
{Role: "assistant", Content: "hi there", TokenCount: 5},
}
content, err := ce.generateLeafSummary(ctx, msgs, "")
if err != nil {
t.Fatalf("generateLeafSummary: %v", err)
}
if content == "" {
t.Error("expected non-empty summary content")
}
}
func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) {
// Level 1 returns summary that's too large (tokens >= input), should escalate to level 2
var calls []string
escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
if contains(prompt, "Aggressive summary policy") {
calls = append(calls, "aggressive")
return "Short aggressive summary.", nil
}
calls = append(calls, "normal")
// Return a very long summary to trigger escalation
longContent := make([]byte, 5000)
for i := range longContent {
longContent[i] = 'x'
}
return string(longContent), nil
}
s := openTestStore(t)
ce, _ := newTestCompactionEngineWithStore(s, escalateComplete)
msgs := []Message{
{Role: "user", Content: "hello world", TokenCount: 10},
{Role: "assistant", Content: "response", TokenCount: 10},
}
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
if err != nil {
t.Fatalf("generateLeafSummary: %v", err)
}
if content == "" {
t.Error("expected non-empty summary content")
}
// Should have called both normal and aggressive
foundNormal := false
foundAggressive := false
for _, c := range calls {
if c == "normal" {
foundNormal = true
}
if c == "aggressive" {
foundAggressive = true
}
}
if !foundNormal {
t.Error("expected normal LLM call")
}
if !foundAggressive {
t.Error("expected aggressive LLM call (level 2 escalation)")
}
}
func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) {
// Both normal and aggressive return empty, should escalate to level 3 truncation
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "", nil
}
s := openTestStore(t)
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
msgs := []Message{
{Role: "user", Content: "hello world from test", TokenCount: 10},
{Role: "assistant", Content: "response text here", TokenCount: 10},
}
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
if err != nil {
t.Fatalf("generateLeafSummary: %v", err)
}
// Level 3 truncation should have produced something
if content == "" {
t.Error("expected non-empty content from level 3 truncation fallback")
}
if !contains(content, "Truncated from") {
t.Errorf("expected truncation marker in content: %q", content)
}
}
func TestGenerateCondensedSummary(t *testing.T) {
ce, _, _ := newTestCompactionEngine(t)
ctx := context.Background()
summaries := []Summary{
{SummaryID: "sum_a", Content: "first summary", TokenCount: 100},
{SummaryID: "sum_b", Content: "second summary", TokenCount: 100},
}
content, err := ce.generateCondensedSummary(ctx, summaries)
if err != nil {
t.Fatalf("generateCondensedSummary: %v", err)
}
if content == "" {
t.Error("expected non-empty condensed summary content")
}
}
func TestGenerateCondensedSummaryEscalation(t *testing.T) {
// When LLM returns empty, should fall back to deterministic concatenation
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
return "", nil
}
s := openTestStore(t)
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
summaries := []Summary{
{SummaryID: "sum_a", Content: "first summary text", TokenCount: 50},
{SummaryID: "sum_b", Content: "second summary text", TokenCount: 50},
}
content, err := ce.generateCondensedSummary(context.Background(), summaries)
if err != nil {
t.Fatalf("generateCondensedSummary: %v", err)
}
// Should fall back to concatenation
if content == "" {
t.Error("expected non-empty content from fallback")
}
}
// --- Async Condensed Compaction (Phase 2) ---
func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) {
// Use a slow CompleteFn to verify Compact returns before condensed finishes
var callCount int32
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
atomic.AddInt32(&callCount, 1)
time.Sleep(500 * time.Millisecond) // simulate slow LLM
return "Slow condensed summary.", nil
}
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:async")
convID := conv.ConversationID
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
t.Cleanup(func() {
cancel()
time.Sleep(100 * time.Millisecond)
})
// Create enough leaf summaries for condensation + fresh tail
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf for async test",
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Compact with force — should return quickly, condensed runs async
start := time.Now()
result, err := ce.Compact(ctx, convID, CompactInput{Force: true})
elapsed := time.Since(start)
if err != nil {
t.Fatalf("Compact: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
// Should return well before the 500ms LLM call
if elapsed > 200*time.Millisecond {
t.Errorf("Compact took %v, should return before async condensed finishes", elapsed)
}
// Wait for async to complete
time.Sleep(800 * time.Millisecond)
// Verify condensed summary was created by background goroutine
summaries, _ := s.GetSummariesByConversation(ctx, convID)
foundCondensed := false
for _, sum := range summaries {
if sum.Kind == SummaryKindCondensed {
foundCondensed = true
break
}
}
if !foundCondensed {
t.Error("expected at least one condensed summary from async Phase 2")
}
}
func TestCompactAsyncDedup(t *testing.T) {
var callCount int32
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
atomic.AddInt32(&callCount, 1)
time.Sleep(300 * time.Millisecond)
return "Slow condensed summary.", nil
}
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:dedup")
convID := conv.ConversationID
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
t.Cleanup(func() {
cancel()
waitForCondensed(ce, convID, 2*time.Second)
})
// Create conditions for condensed compaction
for i := 0; i < CondensedMinFanout; i++ {
now := time.Now().UTC()
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "leaf for dedup",
TokenCount: 500,
EarliestAt: &now,
LatestAt: &now,
})
s.AppendContextSummary(ctx, convID, summary.SummaryID)
}
for i := 0; i < FreshTailCount; i++ {
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Call Compact twice rapidly
ce.Compact(ctx, convID, CompactInput{Force: true})
ce.Compact(ctx, convID, CompactInput{Force: true})
// Wait for async to finish
time.Sleep(600 * time.Millisecond)
// LLM should only be called once for condensed (dedup)
// callCount may be 0 if no leaf was created (only condensed in goroutine)
// The key is that we don't get 2+ condensed calls
if atomic.LoadInt32(&callCount) > 1 {
t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount)
}
}
func TestCompactLeafForceBypassesFreshTail(t *testing.T) {
// Spec: compactLeaf with force=true should bypass FreshTailCount protection
// so CompactUntilUnder can compress messages inside the fresh tail
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create exactly FreshTailCount+4 messages (36 total)
// Without force: all messages are in fresh tail → no candidate
// With force: should compact the oldest messages
total := FreshTailCount + 4
for i := 0; i < total; i++ {
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100)
s.AppendContextMessage(ctx, convID, m.ID)
}
// Without force: should return nil (all in fresh tail)
summaryID, err := ce.compactLeaf(ctx, convID)
if err != nil {
t.Fatalf("compactLeaf no-force: %v", err)
}
if summaryID != nil {
t.Error("expected nil without force (all messages in fresh tail)")
}
// With force: should compact despite fresh tail protection
summaryID, err = ce.compactLeaf(ctx, convID, true)
if err != nil {
t.Fatalf("compactLeaf force: %v", err)
}
if summaryID == nil {
t.Error("expected summary with force=true (bypasses fresh tail)")
}
}
func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) {
// Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping
// It should NOT take the entire contiguous chunk regardless of token count
ce, s, convID := newTestCompactionEngine(t)
ctx := context.Background()
// Create messages totaling far more than LeafChunkTokens (20000)
// Each message is ~500 tokens, create 80 messages = 40000 tokens
for i := 0; i < 80; i++ {
m, _ := s.AddMessage(
ctx,
convID,
"user",
fmt.Sprintf(
"message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn",
i,
),
500,
)
s.AppendContextMessage(ctx, convID, m.ID)
}
summaryID, err := ce.compactLeaf(ctx, convID)
if err != nil {
t.Fatalf("compactLeaf: %v", err)
}
if summaryID == nil {
t.Fatal("expected a summary to be created")
}
// The source messages that were compacted should total roughly LeafChunkTokens (20000),
// not the entire 40000 tokens worth of messages
summary, _ := s.GetSummary(ctx, *summaryID)
if summary == nil {
t.Fatal("summary not found")
}
// Source message tokens should be roughly <= LeafChunkTokens (20000)
// Spec says: "Stop when accumulated tokens >= LeafChunkTokens"
if summary.SourceMessageTokenCount > LeafChunkTokens {
t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)",
summary.SourceMessageTokenCount, LeafChunkTokens)
}
}
+30
View File
@@ -0,0 +1,30 @@
package seahorse
// Short-term memory configuration constants — all are experience-based defaults.
const (
// OrdinalStep is the gap between ordinals in context_items.
// Insert at midpoint; resequence only when precision exhausted.
OrdinalStep = 100
// ContextThreshold is the compaction trigger for the context window.
ContextThreshold float64 = 0.75 // Compact at 75% of context window
FreshTailCount int = 32 // Recent messages protected from compaction
// LeafMinFanout is the fanout parameter.
LeafMinFanout int = 8 // Min messages per leaf summary
CondensedMinFanout int = 4 // Min summaries per condensed
CondensedMinFanoutHard int = 2 // Min for forced compaction
// LeafChunkTokens is the token target.
LeafChunkTokens int = 20000 // Max tokens per leaf chunk
LeafTargetTokens int = 1200 // Target tokens for leaf summaries
CondensedTargetTokens int = 2000 // Target tokens for condensed summaries
MaxExpandTokens int = 4000 // Token cap for expansion queries
// MaxCompactIterations caps CompactUntilUnder to prevent infinite loops.
// Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction.
// With a 200k token context window and 75% threshold, ~20 iterations is enough
// for any realistic scenario. If exceeded, the issue is logged as a warning.
MaxCompactIterations int = 20
)
+568
View File
@@ -0,0 +1,568 @@
package seahorse
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
_ "modernc.org/sqlite"
"github.com/sipeed/picoclaw/pkg/logger"
)
// Config holds engine configuration.
type Config struct {
DBPath string `json:"dbPath"`
IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"`
StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"`
}
// CompleteFn is the LLM completion function type.
type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error)
// CompleteOptions holds LLM completion parameters.
type CompleteOptions struct {
Model string
MaxTokens int
Temperature float64
}
// IngestResult is the result of message ingestion.
type IngestResult struct {
MessageCount int `json:"messageCount"`
TokenCount int `json:"tokenCount"`
}
// AssembleInput controls context assembly.
type AssembleInput struct {
Budget int `json:"budget"`
Query string `json:"query,omitempty"`
}
// AssembleResult contains assembled context.
type AssembleResult struct {
Messages []Message `json:"messages"`
Summary string `json:"summary"` // formatted XML summaries + system prompt addition
}
const numSessionShards = 256
// Engine is the main short-term memory engine.
type Engine struct {
store *Store
compaction *CompactionEngine
compactionMu sync.Mutex
assembler *Assembler
assemblerMu sync.Mutex
retrieval *RetrievalEngine
config Config
complete CompleteFn
ignorePatterns []*regexp.Regexp
statelessPatterns []*regexp.Regexp
sessionShards [numSessionShards]struct {
mu sync.Mutex
}
}
// CompactionEngine handles LLM-based summarization (defined in short_compaction.go).
type CompactionEngine struct {
store *Store
config Config
complete CompleteFn
condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
// Assembler handles budget-aware context assembly (defined in short_assembler.go).
type Assembler struct {
store *Store
config Config
}
// RetrievalEngine handles search and expansion (defined in short_retrieval.go).
type RetrievalEngine struct {
store *Store
config Config
}
// Store returns the underlying store for direct access.
func (r *RetrievalEngine) Store() *Store {
return r.store
}
// NewEngine creates a new short-term memory engine.
func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) {
dir := filepath.Dir(config.DBPath)
if dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("create db directory: %w", err)
}
}
db, err := sql.Open("sqlite", config.DBPath)
if err != nil {
return nil, fmt.Errorf("open db: %w", err)
}
// Configure SQLite for concurrent access
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
db.Close()
return nil, fmt.Errorf("enable WAL: %w", err)
}
if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil {
db.Close()
return nil, fmt.Errorf("set busy_timeout: %w", err)
}
if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
db.Close()
return nil, fmt.Errorf("set synchronous: %w", err)
}
if err := runSchema(db); err != nil {
db.Close()
return nil, fmt.Errorf("migrations: %w", err)
}
store := &Store{db: db}
// Prepend hardcoded ignore patterns (spec lines 1326-1328)
ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns))
ignorePatterns = append(ignorePatterns, "heartbeat")
ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...)
retrieval := &RetrievalEngine{store: store, config: config}
return &Engine{
store: store,
compaction: nil,
assembler: nil,
retrieval: retrieval,
config: config,
complete: completeFn,
ignorePatterns: compileSessionPatterns(ignorePatterns),
statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns),
}, nil
}
// compileSessionPattern converts a glob pattern to a compiled regex.
// Pattern rules:
// - * matches any sequence of non-colon characters ([^:]*)
// - ** matches any sequence of characters including colons (.*)
// - All other characters are treated literally
// - Pattern is anchored (^...$)
func compileSessionPattern(pattern string) *regexp.Regexp {
var b strings.Builder
b.WriteByte('^')
i := 0
for i < len(pattern) {
if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' {
b.WriteString(".*")
i += 2
continue
}
if pattern[i] == '*' {
b.WriteString("[^:]*")
i++
continue
}
b.WriteString(regexp.QuoteMeta(string(pattern[i])))
i++
}
b.WriteByte('$')
return regexp.MustCompile(b.String())
}
// compileSessionPatterns compiles multiple glob patterns into regex patterns.
func compileSessionPatterns(patterns []string) []*regexp.Regexp {
result := make([]*regexp.Regexp, 0, len(patterns))
for _, p := range patterns {
if p == "" {
continue
}
result = append(result, compileSessionPattern(p))
}
return result
}
// shouldIgnoreSession returns true if the session key matches any ignore pattern.
func (e *Engine) shouldIgnoreSession(sessionKey string) bool {
for _, p := range e.ignorePatterns {
if p.MatchString(sessionKey) {
return true
}
}
return false
}
// isStatelessSession returns true if the session key matches any stateless pattern.
func (e *Engine) isStatelessSession(sessionKey string) bool {
for _, p := range e.statelessPatterns {
if p.MatchString(sessionKey) {
return true
}
}
return false
}
// fnv32 computes FNV-1a 32-bit hash for session key sharding.
func fnv32(key string) uint32 {
h := uint32(2166136261)
for _, c := range key {
h ^= uint32(c)
h *= 16777619
}
return h
}
// getSessionMutex returns the sharded mutex for a session key.
func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex {
h := fnv32(sessionKey)
shard := h % numSessionShards
return &e.sessionShards[shard].mu
}
// Ingest adds messages to a conversation identified by sessionKey.
func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
if e.shouldIgnoreSession(sessionKey) {
return nil, nil
}
if e.isStatelessSession(sessionKey) {
return nil, nil
}
mu := e.getSessionMutex(sessionKey)
mu.Lock()
defer mu.Unlock()
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
var totalTokens int
var msgIDs []int64
for _, msg := range messages {
var added *Message
var err error
if len(msg.Parts) > 0 {
added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount)
} else {
added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount)
}
if err != nil {
return nil, fmt.Errorf("add message: %w", err)
}
totalTokens += msg.TokenCount
msgIDs = append(msgIDs, added.ID)
}
// Append to context_items using actual inserted IDs
if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil {
return nil, fmt.Errorf("append context: %w", err)
}
logger.InfoCF("seahorse", "ingest", map[string]any{
"conv_id": conv.ConversationID,
"messages": len(messages),
"tokens": totalTokens,
})
return &IngestResult{
MessageCount: len(messages),
TokenCount: totalTokens,
}, nil
}
// Close releases resources.
func (e *Engine) Close() error {
// Signal compaction goroutines to stop
if e.compaction != nil {
e.compaction.Close()
}
if e.store != nil && e.store.db != nil {
return e.store.db.Close()
}
return nil
}
// GetRetrieval returns the retrieval engine for tool implementations.
func (e *Engine) GetRetrieval() *RetrievalEngine {
return e.retrieval
}
// Assemble builds budget-constrained context for a session.
func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) {
if e.shouldIgnoreSession(sessionKey) {
return nil, nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
e.initAssemblerOnce()
return e.assembler.Assemble(ctx, conv.ConversationID, input)
}
// Compact compresses conversation history for a session.
func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) {
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
return &CompactResult{}, nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
e.initCompactionOnce()
return e.compaction.Compact(ctx, conv.ConversationID, input)
}
// CompactUntilUnder aggressively compacts until context is under budget.
// Used for emergency compaction after LLM overflow (retry reason).
func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) {
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
return &CompactResult{}, nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return nil, fmt.Errorf("get conversation: %w", err)
}
e.initCompactionOnce()
return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget)
}
// initCompactionOnce lazily initializes the compaction engine.
func (e *Engine) initCompactionOnce() {
if e.compaction == nil {
e.compactionMu.Lock()
defer e.compactionMu.Unlock()
if e.compaction == nil {
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
e.compaction = &CompactionEngine{
store: e.store,
config: e.config,
complete: e.complete,
shutdownCtx: shutdownCtx,
shutdownCancel: shutdownCancel,
}
}
}
}
// initAssemblerOnce lazily initializes the assembler.
func (e *Engine) initAssemblerOnce() {
if e.assembler == nil {
e.assemblerMu.Lock()
defer e.assemblerMu.Unlock()
if e.assembler == nil {
e.assembler = &Assembler{store: e.store, config: e.config}
}
}
}
// IngestMessages is an alias for Ingest.
func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
return e.Ingest(ctx, sessionKey, messages)
}
// Bootstrap reconciles a session's messages with the database.
// Called once at startup for each known session.
// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta.
// Simple approach: find longest matching prefix and append delta.
// If any mismatch is detected, clear and rebuild.
func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error {
if e.shouldIgnoreSession(sessionKey) {
return nil
}
if e.isStatelessSession(sessionKey) {
return nil
}
if len(messages) == 0 {
return nil
}
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
if err != nil {
return fmt.Errorf("bootstrap: get conversation: %w", err)
}
// Get messages already in DB
dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0)
if err != nil {
return fmt.Errorf("bootstrap: get messages: %w", err)
}
// Fast path: DB has same count and exact match → no-op
if len(dbMsgs) == len(messages) {
matched := true
for i := 0; i < len(messages); i++ {
if !messageMatches(dbMsgs[i], messages[i]) {
matched = false
break
}
}
if matched {
return nil // DB is up to date
}
}
// Find longest matching prefix from the start
anchor := -1
compareLen := len(dbMsgs)
if compareLen > len(messages) {
compareLen = len(messages)
}
for i := 0; i < compareLen; i++ {
if messageMatches(dbMsgs[i], messages[i]) {
anchor = i
} else {
// Mismatch detected - log details and rebuild
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
"conv_id": conv.ConversationID,
"index": i,
"db_role": dbMsgs[i].Role,
"db_content": truncate(dbMsgs[i].Content, 50),
"db_parts": len(dbMsgs[i].Parts),
"msg_role": messages[i].Role,
"msg_content": truncate(messages[i].Content, 50),
"msg_parts": len(messages[i].Parts),
})
break
}
}
// If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest
// Note: anchor can be -1 if first message didn't match (history completely changed)
if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 {
anchorID := dbMsgs[anchor].ID
logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{
"conv_id": conv.ConversationID,
"db_count": len(dbMsgs),
"anchor": anchor,
"anchor_id": anchorID,
"msg_count": len(messages),
"delta_start": anchor + 1,
})
// Delete messages after anchor (also clears context_items)
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil {
return fmt.Errorf("bootstrap: delete messages: %w", err)
}
// Re-ingest from anchor+1 to end
delta := messages[anchor+1:]
if len(delta) > 0 {
_, err := e.Ingest(ctx, sessionKey, delta)
if err != nil {
return fmt.Errorf("bootstrap: re-ingest: %w", err)
}
}
return nil
}
// Normal case: append delta after anchor
if anchor >= 0 && anchor < len(messages)-1 {
delta := messages[anchor+1:]
if len(delta) > 0 {
_, err := e.Ingest(ctx, sessionKey, delta)
if err != nil {
return fmt.Errorf("bootstrap: ingest delta: %w", err)
}
}
} else if anchor == -1 && len(dbMsgs) > 0 {
// First message changed (history completely different) - rebuild from scratch
logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{
"conv_id": conv.ConversationID,
"db_count": len(dbMsgs),
"msg_count": len(messages),
})
// Delete all existing messages
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil {
return fmt.Errorf("bootstrap: delete all messages: %w", err)
}
// Re-ingest everything
if len(messages) > 0 {
_, err := e.Ingest(ctx, sessionKey, messages)
if err != nil {
return fmt.Errorf("bootstrap: re-ingest all: %w", err)
}
}
} else if anchor == -1 && len(dbMsgs) == 0 {
// DB is empty, ingest everything
_, err := e.Ingest(ctx, sessionKey, messages)
if err != nil {
return fmt.Errorf("bootstrap: ingest all: %w", err)
}
}
return nil
}
// truncate shortens a string for logging.
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
// messageMatches compares two messages using (role, content) or (role, parts).
// TokenCount is NOT compared because it may be re-estimated differently
// during bootstrap (e.g., via tokenizer.EstimateMessageTokens).
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
// since AddMessageWithParts stores empty Content in DB.
func messageMatches(a, b Message) bool {
if a.Role != b.Role {
return false
}
// If either message has Parts, compare Parts
if len(a.Parts) > 0 || len(b.Parts) > 0 {
return partsMatch(a.Parts, b.Parts)
}
// Simple text messages: compare Content
return a.Content == b.Content
}
// partsMatch compares two slices of MessagePart for equality.
func partsMatch(a, b []MessagePart) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i].Type != b[i].Type {
return false
}
switch a[i].Type {
case "text":
if a[i].Text != b[i].Text {
return false
}
case "tool_use":
if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID {
return false
}
case "tool_result":
if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text {
return false
}
case "media":
if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType {
return false
}
}
}
return true
}
File diff suppressed because it is too large Load Diff
+212
View File
@@ -0,0 +1,212 @@
package seahorse
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"
)
// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m".
// Returns the duration and nil error, or zero and error if invalid.
func ParseLastDuration(s string) (time.Duration, error) {
if s == "" {
return 0, fmt.Errorf("empty duration")
}
re := regexp.MustCompile(`^(\d+)([hdwm])$`)
matches := re.FindStringSubmatch(s)
if matches == nil {
return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s)
}
value, _ := strconv.Atoi(matches[1])
unit := matches[2]
switch unit {
case "h":
return time.Duration(value) * time.Hour, nil
case "d":
return time.Duration(value) * 24 * time.Hour, nil
case "w":
return time.Duration(value) * 7 * 24 * time.Hour, nil
case "m":
return time.Duration(value) * 30 * 24 * time.Hour, nil
default:
return 0, fmt.Errorf("unknown unit: %q", unit)
}
}
// GrepInput controls search across summaries and messages.
type GrepInput struct {
Pattern string `json:"pattern"`
Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message"
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
AllConversations bool `json:"allConversations,omitempty"`
Since *time.Time `json:"since,omitempty"`
Before *time.Time `json:"before,omitempty"`
Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m"
Limit int `json:"limit,omitempty"`
}
// GrepResult contains search results.
type GrepResult struct {
Success bool `json:"success"`
Summaries []GrepSummaryResult `json:"summaries"`
Messages []GrepMessageResult `json:"messages"`
TotalSummaries int `json:"totalSummaries"`
TotalMessages int `json:"totalMessages"`
Hint string `json:"hint,omitempty"`
}
// GrepSummaryResult is a summary match from grep.
type GrepSummaryResult struct {
ID string `json:"id"`
Content string `json:"content"`
Depth int `json:"depth"`
Kind SummaryKind `json:"kind"`
ConversationID int64 `json:"conversationId"`
// Rank is the bm25 relevance score (negative value, closer to 0 = better match).
// Examples: -0.5 = excellent match, -2.0 = good match, -10.0 = partial match.
Rank float64 `json:"rank,omitempty"`
}
// GrepMessageResult is a message match from grep.
type GrepMessageResult struct {
ID int64 `json:"id,string"`
Snippet string `json:"snippet"`
Role string `json:"role"`
ConversationID int64 `json:"conversationId"`
Rank float64 `json:"rank,omitempty"` // Relevance score (lower = better match)
}
// ExpandMessagesResult contains expanded messages.
type ExpandMessagesResult struct {
Messages []Message `json:"messages"`
TokenCount int `json:"tokenCount"`
}
// Grep searches summaries and messages for matching content.
func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) {
if input.Pattern == "" {
return nil, fmt.Errorf("grep: pattern is required")
}
limit := input.Limit
if limit == 0 {
limit = 20
}
// Handle Last parameter: convert to Since
since := input.Since
if input.Last != "" {
dur, err := ParseLastDuration(input.Last)
if err != nil {
return nil, fmt.Errorf("grep: invalid last: %w", err)
}
t := time.Now().UTC().Add(-dur)
since = &t
}
// Auto-detect mode: use LIKE if pattern contains %, otherwise full-text
mode := ""
if strings.Contains(input.Pattern, "%") {
mode = "like"
}
searchInput := SearchInput{
Pattern: input.Pattern,
Mode: mode,
Role: input.Role,
AllConversations: input.AllConversations,
Since: since,
Before: input.Before,
Limit: limit,
}
result := &GrepResult{
Success: true,
Summaries: make([]GrepSummaryResult, 0),
Messages: make([]GrepMessageResult, 0),
TotalSummaries: 0,
TotalMessages: 0,
}
// Determine scope
scope := input.Scope
if scope == "" {
scope = "both"
}
// Search summaries if requested
if scope == "both" || scope == "summary" {
sumResults, err := r.store.SearchSummaries(ctx, searchInput)
if err != nil {
return nil, fmt.Errorf("search summaries: %w", err)
}
for _, sr := range sumResults {
if sr.SummaryID != "" {
result.Summaries = append(result.Summaries, GrepSummaryResult{
ID: sr.SummaryID,
Content: sr.Content,
Depth: sr.Depth,
Kind: sr.Kind,
ConversationID: sr.ConversationID,
Rank: sr.Rank,
})
}
}
if len(sumResults) > 0 {
result.TotalSummaries = sumResults[0].TotalCount
}
}
// Search messages if requested
if scope == "both" || scope == "message" {
msgResults, err := r.store.SearchMessages(ctx, searchInput)
if err != nil {
return nil, fmt.Errorf("search messages: %w", err)
}
for _, sr := range msgResults {
if sr.MessageID > 0 {
result.Messages = append(result.Messages, GrepMessageResult{
ID: sr.MessageID,
Snippet: sr.Snippet,
Role: sr.Role,
ConversationID: sr.ConversationID,
Rank: sr.Rank,
})
}
}
if len(msgResults) > 0 {
result.TotalMessages = msgResults[0].TotalCount
}
}
// Add hint if no results
if len(result.Summaries) == 0 && len(result.Messages) == 0 {
result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true"
}
return result, nil
}
// ExpandMessages retrieves full message content by IDs.
func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) {
result := &ExpandMessagesResult{
Messages: make([]Message, 0, len(messageIDs)),
}
for _, msgID := range messageIDs {
msg, err := r.store.GetMessageByID(ctx, msgID)
if err != nil {
continue
}
result.Messages = append(result.Messages, *msg)
result.TokenCount += msg.TokenCount
}
return result, nil
}
+362
View File
@@ -0,0 +1,362 @@
package seahorse
import (
"context"
"fmt"
"testing"
"time"
)
// --- Retrieval Tests ---
func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) {
t.Helper()
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval")
return &RetrievalEngine{store: s}, s, conv.ConversationID
}
func TestRetrievalGrepSummaries(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "数据库连接配置说明",
TokenCount: 50,
})
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "API endpoint documentation",
TokenCount: 50,
})
// FTS5 search (trigram, needs >= 3 chars)
results, err := r.Grep(ctx, GrepInput{
Pattern: "数据库连",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Summaries) == 0 {
t.Error("expected at least 1 FTS result")
}
// LIKE search with wildcard
results, err = r.Grep(ctx, GrepInput{
Pattern: "%endpoint%",
})
if err != nil {
t.Fatalf("Grep LIKE: %v", err)
}
if len(results.Summaries) == 0 {
t.Error("expected at least 1 LIKE result")
}
}
func TestRetrievalGrepMessages(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
s.AddMessage(ctx, convID, "user", "find this message about testing", 5)
s.AddMessage(ctx, convID, "user", "unrelated content here", 5)
results, err := r.Grep(ctx, GrepInput{
Pattern: "testing",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Messages) == 0 {
t.Error("expected at least 1 result for 'testing'")
}
}
func TestRetrievalExpandMessages(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10)
result, err := r.ExpandMessages(ctx, []int64{msg.ID})
if err != nil {
t.Fatalf("ExpandMessages: %v", err)
}
if len(result.Messages) != 1 {
t.Errorf("Messages = %d, want 1", len(result.Messages))
}
if result.Messages[0].Content != "expand this message" {
t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content)
}
}
func TestRetrievalExpandMultipleMessages(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10)
msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10)
msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10)
result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID})
if err != nil {
t.Fatalf("ExpandMessages: %v", err)
}
if len(result.Messages) != 3 {
t.Errorf("Messages = %d, want 3", len(result.Messages))
}
if result.TokenCount != 30 {
t.Errorf("TokenCount = %d, want 30", result.TokenCount)
}
}
func TestRetrievalGrepWithTimeFilter(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
now := time.Now().UTC()
before := now.Add(-2 * time.Hour)
// Create messages at different times
s.AddMessage(ctx, convID, "user", "old message about auth", 5)
s.AddMessage(ctx, convID, "user", "recent message about auth", 5)
// Search with time filter
results, err := r.Grep(ctx, GrepInput{
Pattern: "auth",
Since: &before,
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
_ = results // Just verify no error
}
func TestRetrievalGrepAllConversations(t *testing.T) {
r, s, _ := newTestRetrieval(t)
ctx := context.Background()
// Create another conversation
conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2")
// Add messages to both
s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5)
// Search all conversations
results, err := r.Grep(ctx, GrepInput{
Pattern: "xyz",
AllConversations: true,
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Messages) == 0 {
t.Error("expected to find message in other conversation")
}
}
// --- Last Duration Parsing Tests ---
func TestParseLastDuration(t *testing.T) {
tests := []struct {
input string
wantDur time.Duration
wantErr bool
}{
{"6h", 6 * time.Hour, false},
{"1d", 24 * time.Hour, false},
{"7d", 7 * 24 * time.Hour, false},
{"2w", 14 * 24 * time.Hour, false},
{"1m", 30 * 24 * time.Hour, false}, // month = 30 days
{"3m", 90 * 24 * time.Hour, false},
{"", 0, true},
{"invalid", 0, true},
{"5x", 0, true}, // unknown unit
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got, err := ParseLastDuration(tt.input)
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tt.wantDur {
t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur)
}
}
})
}
}
// --- Role Filter Tests ---
func TestRetrievalGrepRoleFilter(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
s.AddMessage(ctx, convID, "user", "user message about alpha", 5)
s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5)
s.AddMessage(ctx, convID, "user", "another user message", 5)
// Search all roles
allResults, err := r.Grep(ctx, GrepInput{
Pattern: "alpha",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(allResults.Messages) != 2 {
t.Errorf("expected 2 messages, got %d", len(allResults.Messages))
}
// Search user only
userResults, err := r.Grep(ctx, GrepInput{
Pattern: "alpha",
Role: "user",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(userResults.Messages) != 1 {
t.Errorf("expected 1 user message, got %d", len(userResults.Messages))
}
if userResults.Messages[0].Role != "user" {
t.Errorf("expected role=user, got %s", userResults.Messages[0].Role)
}
// Search assistant only
assistantResults, err := r.Grep(ctx, GrepInput{
Pattern: "alpha",
Role: "assistant",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(assistantResults.Messages) != 1 {
t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages))
}
}
// --- Last Parameter Tests ---
func TestRetrievalGrepWithLast(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
// Add messages (we can't control timestamps in SQLite easily,
// but we can verify the parameter is parsed correctly)
s.AddMessage(ctx, convID, "user", "recent message about testing", 5)
// Test that Last parameter is converted to Since
results, err := r.Grep(ctx, GrepInput{
Pattern: "testing",
Last: "1d", // last 1 day
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
// Should still find the message since it's recent
if len(results.Messages) == 0 {
t.Error("expected to find recent message")
}
}
// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when
// searching both summaries and messages (summaries don't have role column).
func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
// Create a summary (no role column)
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "summary about testing",
TokenCount: 50,
})
// Add messages with different roles
s.AddMessage(ctx, convID, "user", "user message about testing", 5)
s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5)
// Search with role filter and scope=both (default), using LIKE mode (%)
// This should NOT error even though summaries don't have role column
bothResults, err := r.Grep(ctx, GrepInput{
Pattern: "%testing%", // LIKE mode to trigger the bug
Role: "user",
Scope: "both",
})
if err != nil {
t.Fatalf("Grep with role and scope=both: %v", err)
}
// Should only return user messages, not summaries or assistant messages
if len(bothResults.Messages) != 1 {
t.Errorf("expected 1 user message, got %d", len(bothResults.Messages))
}
if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" {
t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role)
}
// Summaries should be empty since they don't have roles to filter
// (or we could return all summaries - either is acceptable)
}
// TestRetrievalGrepTotalCounts tests that grep returns total counts.
func TestRetrievalGrepTotalCounts(t *testing.T) {
r, s, convID := newTestRetrieval(t)
ctx := context.Background()
// Create 3 summaries
for i := 0; i < 3; i++ {
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: convID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: fmt.Sprintf("summary about testing %d", i),
TokenCount: 50,
})
}
// Add 5 messages
for i := 0; i < 5; i++ {
s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5)
}
// Search with limit smaller than total
results, err := r.Grep(ctx, GrepInput{
Pattern: "%testing%", // LIKE mode
Scope: "both",
Limit: 2,
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
// Should return limited results
if len(results.Summaries) > 2 {
t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries))
}
if len(results.Messages) > 2 {
t.Errorf("expected at most 2 messages, got %d", len(results.Messages))
}
// But total counts should reflect all matches
if results.TotalSummaries != 3 {
t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries)
}
if results.TotalMessages != 5 {
t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages)
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+129
View File
@@ -0,0 +1,129 @@
package seahorse
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/tools"
)
// ExpandTool recovers full message content by ID.
type ExpandTool struct {
engine *RetrievalEngine
}
func NewExpandTool(engine *RetrievalEngine) *ExpandTool {
return &ExpandTool{engine: engine}
}
func (t *ExpandTool) Name() string {
return "short_expand"
}
func (t *ExpandTool) Description() string {
return `Get full message content by ID.
Use when short_grep returns messages and you need complete content (not just snippet).
Parameters:
- message_ids (required): Array of message ID strings (from short_grep results)
Returns message with:
- content: Full text content
- parts: Structured content
- text: Full text
- tool_use: name, arguments, toolCallId
- tool_result: toolCallId only (content omitted - re-run tool if needed)
- media: mediaUri (file path), mimeType
Notes:
- tool_result content is not returned (can be large). Re-run the tool if you need the result.
- Media files are stored on disk at mediaUri path, use bash to access.
Example:
{"message_ids": ["10", "25"]}`
}
func (t *ExpandTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"message_ids": map[string]any{
"type": "array",
"items": map[string]any{"type": "string"},
"description": "Message IDs to expand (from short_grep results, e.g., [\"10\", \"25\"])",
},
},
"required": []string{"message_ids"},
}
}
func (t *ExpandTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
idsRaw, ok := args["message_ids"].([]any)
if !ok || len(idsRaw) == 0 {
return tools.ErrorResult(
"Missing required 'message_ids' argument. " +
"Example: {\"message_ids\": [\"10\", \"25\"]}")
}
// Parse message IDs
messageIDs := make([]int64, 0, len(idsRaw))
for _, id := range idsRaw {
switch v := id.(type) {
case string:
var n int64
if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
return tools.ErrorResult(fmt.Sprintf("Invalid message_id %q: %v", v, err))
}
messageIDs = append(messageIDs, n)
case float64:
messageIDs = append(messageIDs, int64(v))
}
}
result, err := t.engine.ExpandMessages(ctx, messageIDs)
if err != nil {
return tools.ErrorResult("Expand failed: " + err.Error())
}
// Build response with filtered parts
messages := make([]map[string]any, 0, len(result.Messages))
for _, msg := range result.Messages {
parts := make([]map[string]any, 0, len(msg.Parts))
for _, p := range msg.Parts {
part := map[string]any{"type": p.Type}
switch p.Type {
case "text":
part["text"] = p.Text
case "tool_use":
part["name"] = p.Name
part["arguments"] = p.Arguments
part["toolCallId"] = p.ToolCallID
case "tool_result":
// Omit content - can be large, re-run tool if needed
part["toolCallId"] = p.ToolCallID
case "media":
part["mediaUri"] = p.MediaURI
part["mimeType"] = p.MimeType
}
parts = append(parts, part)
}
messages = append(messages, map[string]any{
"id": fmt.Sprintf("%d", msg.ID),
"role": msg.Role,
"content": msg.Content,
"parts": parts,
"conversationId": msg.ConversationID,
})
}
output := map[string]any{
"success": true,
"tokenCount": result.TokenCount,
"messages": messages,
}
data, _ := json.Marshal(output)
return tools.NewToolResult(string(data))
}
+136
View File
@@ -0,0 +1,136 @@
package seahorse
import (
"context"
"encoding/json"
"fmt"
"testing"
)
func TestExpandToolByMessageIDs(t *testing.T) {
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:expand-tool")
msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "first message", 10)
msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "second message", 10)
re := &RetrievalEngine{store: s}
tool := NewExpandTool(re)
result := tool.Execute(ctx, map[string]any{
"message_ids": []any{fmt.Sprintf("%d", msg1.ID), fmt.Sprintf("%d", msg2.ID)},
})
if result.IsError {
t.Fatalf("Expand failed: %s", result.ForLLM)
}
// Parse result
var output struct {
Success bool `json:"success"`
TokenCount int `json:"tokenCount"`
Messages []map[string]any `json:"messages"`
}
if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
t.Fatalf("Parse result: %v", err)
}
if !output.Success {
t.Error("expected success=true")
}
if len(output.Messages) != 2 {
t.Errorf("Messages = %d, want 2", len(output.Messages))
}
if output.TokenCount != 20 {
t.Errorf("TokenCount = %d, want 20", output.TokenCount)
}
}
func TestExpandToolMissingIDs(t *testing.T) {
s := openTestStore(t)
re := &RetrievalEngine{store: s}
tool := NewExpandTool(re)
result := tool.Execute(context.Background(), map[string]any{})
if !result.IsError {
t.Error("expected error for missing message_ids")
}
}
func TestExpandToolWithParts(t *testing.T) {
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:expand-parts")
// Create message with parts
parts := []MessagePart{
{Type: "text", Text: "Hello"},
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls"}`, ToolCallID: "call_123"},
{Type: "tool_result", ToolCallID: "call_123", Text: "file1.txt\nfile2.txt"},
}
msg, _ := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 50)
re := &RetrievalEngine{store: s}
tool := NewExpandTool(re)
result := tool.Execute(ctx, map[string]any{
"message_ids": []any{fmt.Sprintf("%d", msg.ID)},
})
if result.IsError {
t.Fatalf("Expand failed: %s", result.ForLLM)
}
var output struct {
Messages []struct {
Parts []map[string]any `json:"parts"`
} `json:"messages"`
}
if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
t.Fatalf("Parse result: %v", err)
}
if len(output.Messages) != 1 {
t.Fatalf("Messages = %d, want 1", len(output.Messages))
}
// Verify parts are filtered correctly
foundText := false
foundToolUse := false
foundToolResult := false
for _, p := range output.Messages[0].Parts {
switch p["type"].(string) {
case "text":
foundText = true
if p["text"] != "Hello" {
t.Errorf("text = %v, want Hello", p["text"])
}
case "tool_use":
foundToolUse = true
if p["name"] != "bash" {
t.Errorf("name = %v, want bash", p["name"])
}
case "tool_result":
foundToolResult = true
// tool_result should NOT have content
if _, hasContent := p["content"]; hasContent {
t.Error("tool_result should not have content field")
}
if p["toolCallId"] != "call_123" {
t.Errorf("toolCallId = %v, want call_123", p["toolCallId"])
}
}
}
if !foundText {
t.Error("missing text part")
}
if !foundToolUse {
t.Error("missing tool_use part")
}
if !foundToolResult {
t.Error("missing tool_result part")
}
}
+172
View File
@@ -0,0 +1,172 @@
package seahorse
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/sipeed/picoclaw/pkg/tools"
)
// GrepTool searches summaries and messages for matching content.
type GrepTool struct {
engine *RetrievalEngine
}
func NewGrepTool(engine *RetrievalEngine) *GrepTool {
return &GrepTool{engine: engine}
}
func (t *GrepTool) Name() string {
return "short_grep"
}
func (t *GrepTool) Description() string {
return `Search summaries and messages for matching content.
Pattern syntax:
- Words: "authentication" - matches content containing this word
- AND: "auth AND login" - matches content with both words
- OR: "auth OR signin" - matches content with either word
- NOT: "bug NOT fixed" - matches "bug" but excludes "fixed"
- Wildcard: "%auth%" - matches any text containing "auth" (e.g., "auth", "authentication")
Each summary has a "depth" field:
- depth 0: Created from messages, most detailed
- depth 1+: Created from other summaries, more compressed but covers longer time
Parameters:
- pattern (required): Search pattern
- scope: "both" (default), "summary", or "message" - what to search
- role: "user", "assistant", or omit for all - filter by message role
- last: Time shortcut like "6h", "7d", "2w", "1m" (hours/days/weeks/months)
- all_conversations: Search all conversations (default: current only)
- since: ISO8601 timestamp, content after this time
- before: ISO8601 timestamp, content before this time
- limit: Max results (default: 20)
Returns:
{
"success": true,
"summaries": [{"id": "sum_abc", "content": "...", "depth": 0, "kind": "leaf", "conversationId": 1, "rank": -0.5}],
"messages": [{"id": "10", "snippet": "...matched...", "role": "user", "conversationId": 1, "rank": -1.2}],
"totalSummaries": 5,
"totalMessages": 10,
"hint": "No matches. Try: %keyword% for fuzzy search"
}
Rank field (FTS5 mode only): bm25 relevance score, negative value where closer to 0 = better match.
Examples: -0.5=excellent, -2=good, -5=partial, -10=weak. LIKE mode (%pattern%) has no rank.
Examples:
{"pattern": "authentication"}
{"pattern": "bug AND login"}
{"pattern": "%snake%"}
{"pattern": "project", "scope": "summary"}
{"pattern": "error", "role": "assistant", "last": "7d"}
{"pattern": "error", "all_conversations": true}`
}
func (t *GrepTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{
"pattern": map[string]any{
"type": "string",
"description": "Search pattern. Supports: words, AND/OR/NOT operators, % wildcard",
},
"scope": map[string]any{
"type": "string",
"enum": []string{"both", "summary", "message"},
"description": "What to search: 'both' (default), 'summary', or 'message'",
},
"role": map[string]any{
"type": "string",
"enum": []string{"user", "assistant"},
"description": "Filter by message role (default: all roles)",
},
"last": map[string]any{
"type": "string",
"description": "Time shortcut: '6h' (6 hours), '7d' (7 days), '2w' (2 weeks), '1m' (1 month)",
},
"all_conversations": map[string]any{
"type": "boolean",
"description": "Search across all conversations (default: searches current conversation only)",
},
"since": map[string]any{
"type": "string",
"description": "ISO8601 timestamp, only return content after this time",
},
"before": map[string]any{
"type": "string",
"description": "ISO8601 timestamp, only return content before this time",
},
"limit": map[string]any{
"type": "integer",
"description": "Maximum number of results (default: 20)",
},
},
"required": []string{"pattern"},
}
}
func (t *GrepTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
pattern, ok := args["pattern"].(string)
if !ok || pattern == "" {
return tools.ErrorResult("Missing required 'pattern' argument. Example: {\"pattern\": \"authentication\"}")
}
input := GrepInput{Pattern: pattern}
if scope, ok := args["scope"].(string); ok && scope != "" {
input.Scope = scope
}
if role, ok := args["role"].(string); ok && role != "" {
input.Role = role
}
if last, ok := args["last"].(string); ok && last != "" {
input.Last = last
}
if allConv, ok := args["all_conversations"].(bool); ok {
input.AllConversations = allConv
}
if limit, ok := args["limit"].(float64); ok {
input.Limit = int(limit)
}
if sinceStr, ok := args["since"].(string); ok && sinceStr != "" {
parsed, err := time.Parse(time.RFC3339, sinceStr)
if err != nil {
return tools.ErrorResult(fmt.Sprintf(
"Invalid 'since' timestamp. Use RFC3339 format like '2024-01-15T10:00:00Z'. Error: %v", err))
}
input.Since = &parsed
}
if beforeStr, ok := args["before"].(string); ok && beforeStr != "" {
parsed, err := time.Parse(time.RFC3339, beforeStr)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Invalid 'before' timestamp format: %v", err))
}
input.Before = &parsed
}
result, err := t.engine.Grep(ctx, input)
if err != nil {
return tools.ErrorResult("Grep failed: " + err.Error())
}
// Build response
output := map[string]any{
"success": result.Success,
"summaries": result.Summaries,
"messages": result.Messages,
}
// Add hint if provided
if result.Hint != "" {
output["hint"] = result.Hint
}
data, _ := json.Marshal(output)
return tools.NewToolResult(string(data))
}
+72
View File
@@ -0,0 +1,72 @@
package seahorse
import (
"context"
"testing"
)
func TestGrepSearchSummaries(t *testing.T) {
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:grep-tool")
s.CreateSummary(ctx, CreateSummaryInput{
ConversationID: conv.ConversationID,
Kind: SummaryKindLeaf,
Depth: 0,
Content: "database connection pool configuration",
TokenCount: 50,
})
re := &RetrievalEngine{store: s}
results, err := re.Grep(ctx, GrepInput{
Pattern: "database",
})
if err != nil {
t.Fatalf("Grep: %v", err)
}
if len(results.Summaries) == 0 {
t.Error("expected at least 1 summary result")
}
}
func TestGrepSearchMessages(t *testing.T) {
s := openTestStore(t)
ctx := context.Background()
conv, _ := s.GetOrCreateConversation(ctx, "test:grep-msg")
s.AddMessage(ctx, conv.ConversationID, "user", "find this message about testing", 5)
s.AddMessage(ctx, conv.ConversationID, "user", "unrelated content", 3)
re := &RetrievalEngine{store: s}
results, err := re.Grep(ctx, GrepInput{
Pattern: "testing",
})
if err != nil {
t.Fatalf("Grep messages: %v", err)
}
if len(results.Messages) == 0 {
t.Error("expected at least 1 message result")
}
}
func TestGrepMissingPattern(t *testing.T) {
s := openTestStore(t)
re := &RetrievalEngine{store: s}
_, err := re.Grep(context.Background(), GrepInput{})
if err == nil {
t.Error("expected error for missing pattern")
}
}
func TestGrepToolSupportsAllConversations(t *testing.T) {
s := openTestStore(t)
tool := NewGrepTool(&RetrievalEngine{store: s})
params := tool.Parameters()
props := params["properties"].(map[string]any)
// GrepTool should accept all_conversations parameter
if _, ok := props["all_conversations"]; !ok {
t.Error("Parameters missing 'all_conversations' field")
}
}
+161
View File
@@ -0,0 +1,161 @@
package seahorse
import (
"time"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tokenizer"
)
// SummaryKind distinguishes leaf summaries (from raw messages) vs condensed
// summaries (from other summaries).
type SummaryKind string
const (
SummaryKindLeaf SummaryKind = "leaf"
SummaryKindCondensed SummaryKind = "condensed"
)
// Message represents a single chat message with role and content.
type Message struct {
ID int64 `json:"id"`
ConversationID int64 `json:"conversationId"`
Role string `json:"role"`
Content string `json:"content"`
ReasoningContent string `json:"reasoningContent,omitempty"`
TokenCount int `json:"tokenCount"`
CreatedAt time.Time `json:"createdAt"`
Parts []MessagePart `json:"parts,omitempty"`
}
// MessagePart holds structured content (tool calls, media, etc.)
type MessagePart struct {
ID int64 `json:"id"`
MessageID int64 `json:"messageId"`
Type string `json:"type"` // "text", "tool_use", "tool_result", "media"
Text string `json:"text"`
Name string `json:"name"`
Arguments string `json:"arguments"`
ToolCallID string `json:"toolCallId"`
MediaURI string `json:"mediaUri"`
MimeType string `json:"mimeType"`
}
// Summary represents a compressed representation of messages or other summaries.
type Summary struct {
SummaryID string `json:"summaryId"`
ConversationID int64 `json:"conversationId"`
Kind SummaryKind `json:"kind"`
Depth int `json:"depth"`
Content string `json:"content"`
TokenCount int `json:"tokenCount"`
EarliestAt *time.Time `json:"earliestAt,omitempty"`
LatestAt *time.Time `json:"latestAt,omitempty"`
DescendantCount int `json:"descendantCount"`
DescendantTokenCount int `json:"descendantTokenCount"`
SourceMessageTokenCount int `json:"sourceMessageTokenCount"`
Model string `json:"model"`
CreatedAt time.Time `json:"createdAt"`
}
// SummaryNode is a Summary with graph relationships for tree traversal.
type SummaryNode struct {
Summary
Children []string `json:"children"` // Child summary IDs
Expanded bool `json:"expanded"` // UI state for expansion
}
// Conversation represents a session's conversation with metadata.
type Conversation struct {
ConversationID int64 `json:"conversationId"`
SessionKey string `json:"sessionKey"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
// SessionStatus contains status information for a session.
type SessionStatus struct {
SessionKey string `json:"sessionKey"`
ConversationID int64 `json:"conversationId"`
Messages int `json:"messages"`
TotalTokens int `json:"totalTokens"`
Summaries int `json:"summaries"`
OldestAt time.Time `json:"oldestAt"`
NewestAt time.Time `json:"newestAt"`
}
// ContextItem represents one item in the assembled context window.
type ContextItem struct {
ConversationID int64 `json:"conversationId"`
Ordinal int `json:"ordinal"`
ItemType string `json:"itemType"` // "summary" or "message"
SummaryID string `json:"summaryId,omitempty"`
MessageID int64 `json:"messageId,omitempty"`
TokenCount int `json:"tokenCount"`
CreatedAt time.Time `json:"createdAt"`
}
// SummarySubtreeNode is a node in a summary DAG subtree.
type SummarySubtreeNode struct {
SummaryID string `json:"summaryId"`
DepthFromRoot int `json:"depthFromRoot"`
}
// SearchInput controls summary search.
type SearchInput struct {
Pattern string `json:"pattern"`
Mode string `json:"mode"` // "like" (LIKE search) or "full_text" (FTS5, default)
Scope string `json:"scope,omitempty"` // "messages", "summaries", "both"
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
Since *time.Time `json:"since,omitempty"`
Before *time.Time `json:"before,omitempty"`
Limit int `json:"limit,omitempty"`
ConversationID int64 `json:"conversationId,omitempty"`
AllConversations bool `json:"allConversations,omitempty"`
}
// SearchResult is a search match.
type SearchResult struct {
SummaryID string `json:"summaryId,omitempty"`
MessageID int64 `json:"messageId,omitempty"`
ConversationID int64 `json:"conversationId"`
Kind SummaryKind `json:"kind,omitempty"`
Depth int `json:"depth,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"` // Full content for summaries
Snippet string `json:"snippet"`
CreatedAt time.Time `json:"createdAt"`
Rank float64 `json:"rank,omitempty"`
TotalCount int `json:"totalCount,omitempty"` // Total matching rows (from window function)
}
// EstimateMessageTokens estimates token count for a full message using the
// shared tokenizer package for consistency with agent.context_budget.
func EstimateMessageTokens(msg Message) int {
pm := providers.Message{
Role: msg.Role,
Content: msg.Content,
ReasoningContent: msg.ReasoningContent,
}
// Convert MessageParts to ToolCalls / ToolCallID / Media
for _, part := range msg.Parts {
switch part.Type {
case "tool_use":
pm.ToolCalls = append(pm.ToolCalls, providers.ToolCall{
ID: part.ToolCallID,
Type: "function",
Function: &providers.FunctionCall{
Name: part.Name,
Arguments: part.Arguments,
},
})
case "tool_result":
pm.ToolCallID = part.ToolCallID
case "media":
pm.Media = append(pm.Media, part.MediaURI)
}
}
return tokenizer.EstimateMessageTokens(pm)
}
+54
View File
@@ -0,0 +1,54 @@
package seahorse
import (
"testing"
)
func TestSummaryKindValues(t *testing.T) {
if SummaryKindLeaf != "leaf" {
t.Errorf("expected SummaryKindLeaf = 'leaf', got %q", SummaryKindLeaf)
}
if SummaryKindCondensed != "condensed" {
t.Errorf("expected SummaryKindCondensed = 'condensed', got %q", SummaryKindCondensed)
}
}
func TestConstants(t *testing.T) {
// Ordinal gap step
if OrdinalStep != 100 {
t.Errorf("expected OrdinalStep = 100, got %d", OrdinalStep)
}
// Compaction triggers
if ContextThreshold != 0.75 {
t.Errorf("expected ContextThreshold = 0.75, got %f", ContextThreshold)
}
if FreshTailCount != 32 {
t.Errorf("expected FreshTailCount = 32, got %d", FreshTailCount)
}
// Fanout
if LeafMinFanout != 8 {
t.Errorf("expected LeafMinFanout = 8, got %d", LeafMinFanout)
}
if CondensedMinFanout != 4 {
t.Errorf("expected CondensedMinFanout = 4, got %d", CondensedMinFanout)
}
if CondensedMinFanoutHard != 2 {
t.Errorf("expected CondensedMinFanoutHard = 2, got %d", CondensedMinFanoutHard)
}
// Token targets
if LeafChunkTokens != 20000 {
t.Errorf("expected LeafChunkTokens = 20000, got %d", LeafChunkTokens)
}
if LeafTargetTokens != 1200 {
t.Errorf("expected LeafTargetTokens = 1200, got %d", LeafTargetTokens)
}
if CondensedTargetTokens != 2000 {
t.Errorf("expected CondensedTargetTokens = 2000, got %d", CondensedTargetTokens)
}
if MaxExpandTokens != 4000 {
t.Errorf("expected MaxExpandTokens = 4000, got %d", MaxExpandTokens)
}
}
+5
View File
@@ -79,3 +79,8 @@ func (b *JSONLBackend) Save(key string) error {
func (b *JSONLBackend) Close() error { func (b *JSONLBackend) Close() error {
return b.store.Close() return b.store.Close()
} }
// ListSessions returns all known session keys.
func (b *JSONLBackend) ListSessions() []string {
return b.store.ListSessions()
}
+10
View File
@@ -145,6 +145,16 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
session.Updated = time.Now() session.Updated = time.Now()
} }
func (sm *SessionManager) ListSessions() []string {
sm.mu.RLock()
defer sm.mu.RUnlock()
keys := make([]string, 0, len(sm.sessions))
for k := range sm.sessions {
keys = append(keys, k)
}
return keys
}
// sanitizeFilename converts a session key into a cross-platform safe filename. // sanitizeFilename converts a session key into a cross-platform safe filename.
// Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so // Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so
// composite IDs (e.g. Telegram forum "chatID/threadID") do not create // composite IDs (e.g. Telegram forum "chatID/threadID") do not create
+2
View File
@@ -27,6 +27,8 @@ type SessionStore interface {
TruncateHistory(key string, keepLast int) TruncateHistory(key string, keepLast int)
// Save persists any pending state to durable storage. // Save persists any pending state to durable storage.
Save(key string) error Save(key string) error
// ListSessions returns all known session keys.
ListSessions() []string
// Close releases resources held by the store. // Close releases resources held by the store.
Close() error Close() error
} }
+91
View File
@@ -0,0 +1,91 @@
package tokenizer
import (
"encoding/json"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// EstimateMessageTokens estimates the token count for a single message,
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
func EstimateMessageTokens(msg providers.Message) int {
contentChars := utf8.RuneCountInString(msg.Content)
// SystemParts are structured system blocks used for cache-aware adapters.
// They carry the same content as Content, but in multiple blocks.
// We estimate them as an alternative representation, not additive.
systemPartsChars := 0
if len(msg.SystemParts) > 0 {
for _, part := range msg.SystemParts {
systemPartsChars += utf8.RuneCountInString(part.Text)
}
// Per-part overhead for JSON structure (type, text, cache_control).
const perPartOverhead = 20
systemPartsChars += len(msg.SystemParts) * perPartOverhead
}
// Use the larger of the two representations to stay conservative.
chars := contentChars
if systemPartsChars > chars {
chars = systemPartsChars
}
chars += utf8.RuneCountInString(msg.ReasoningContent)
for _, tc := range msg.ToolCalls {
chars += len(tc.ID) + len(tc.Type)
if tc.Function != nil {
// Count function name + arguments (the wire format for most providers).
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
} else {
// Fallback: some provider formats use top-level Name without Function.
chars += len(tc.Name)
}
}
if msg.ToolCallID != "" {
chars += len(msg.ToolCallID)
}
// Per-message overhead for role label, JSON structure, separators.
const messageOverhead = 12
chars += messageOverhead
tokens := chars * 2 / 5
// Media items (images, files) are serialized by provider adapters into
// multipart or image_url payloads. Add a fixed per-item token estimate
// directly (not through the chars heuristic) since actual cost depends
// on resolution and provider-specific image tokenization.
const mediaTokensPerItem = 256
tokens += len(msg.Media) * mediaTokensPerItem
return tokens
}
// EstimateToolDefsTokens estimates the total token cost of tool definitions
// as they appear in the LLM request.
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
if len(defs) == 0 {
return 0
}
totalChars := 0
for _, d := range defs {
totalChars += len(d.Function.Name) + len(d.Function.Description)
if d.Function.Parameters != nil {
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
totalChars += len(paramJSON)
}
}
// Per-tool overhead: type field, JSON structure, separators.
totalChars += 20
}
return totalChars * 2 / 5
}