From 15a70ac45c5a37ddeeede8150431e5b6e1de6516 Mon Sep 17 00:00:00 2001 From: Liu Yuan Date: Sun, 5 Apr 2026 09:05:16 +0800 Subject: [PATCH] feat(seahorse): implement short-term memory engine (LCM) (#2285) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(seahorse): implement short-term memory engine of seahorse Add pkg/seahorse/ module implementing a SQLite-backed DAG-based summary hierarchy for context management, ported from lossless-claw's LCM design: - types.go + short_constants.go: core types (Message, Summary, Conversation, ContextItem) and configuration constants (fanout, token targets, thresholds) - migration.go: idempotent DB schema with FTS5 trigram tokenizer for CJK - store.go: full SQLite CRUD (conversations, messages, summaries DAG, context_items with ordinal gap numbering, FTS5 search) - short_engine.go: Engine lifecycle (NewEngine, Ingest, Assemble, Compact), session pattern filtering (ignore/stateless glob→regex compilation), per-session mutex via sync.Map - short_assembler.go: budget-aware context assembly with fresh tail protection (32 messages), oldest-first eviction, summary XML formatting, RebuildContextItems - short_compaction.go: leaf compaction (messages→summary) and condensed compaction (summaries→higher-level summary), 3-level LLM escalation, CompactUntilUnder for emergency overflow - short_retrieval.go: lookupByID, FTS5/LIKE search, recursive expand with token cap - context_seahorse.go: agent.ContextManager adapter, registered as "seahorse", provider↔seahorse message type conversion (ToolCalls, tool_result) * fix(seahorse): correct 3 adapter bugs in context management - TokenCount: use full message (Content+ToolCalls+Media) instead of Content-only - Empty Content: rebuild Content from tool_result Parts when stored empty - Duplicate summaries: summaries only in Summary field, not in History messages - Grep: fix SearchResult.Snippet→Content for summaries - Schema: fix FTS5 SQL uses VIRTUAL TABLE not TEMP TABLE - TestFTS5SQLConstants: verify FTS5 SQL syntax correctness - Test: fix flaky TestCompactLeaf * fix(agent): ingest steering messages into seahorse SQLite Steering messages were only persisted to session JSONL but not ingested into seahorse SQLite, causing them to be missing from context assembly. Added `ts.ingestMessage(turnCtx, al, pm)` call in the steering message injection block alongside the existing JSONL persistence. Test: TestSeahorseSteeringMessageIngested verifies steering messages appear in seahorse SQLite DB after being processed. * fix(seahorse): address 3 blocking bugs from code review - Fix resequenceContextItemsTx scan error handling (store.go:850) Changed `return err` to `return scanErr` to properly propagate scan errors instead of returning nil (which silently corrupts data) - Fix sql.NullString for INTEGER column (store.go:847) Changed `mid` from sql.NullString to sql.NullInt64 since message_id is INTEGER in schema. Removed unnecessary strconv.ParseInt call. - Fix compactCondensed fallback deleting non-candidate items Added ReplaceContextItemsWithSummary method for per-item deletion when candidates are not contiguous in ordinal space. Optimized to use range deletion when candidates are consecutive. * fix(seahorse): pass Budget to Compact for correct condensed threshold Issue #4 from PR review: When Budget was not passed to seahorse.Compact, it defaulted to `tokensBefore * 0.75`, making `tokensBefore > budget` always true and causing condensed compaction to trigger unnecessarily. Changes: - context_seahorse.go: Forward Budget from CompactRequest to CompactInput - loop.go: Pass Budget (ContextWindow) in all 3 Compact calls - Add test verifying condensed is skipped when tokens < threshold - Fix lint issues in store.go and store_test.go * fix(seahorse): add mutex for assembler lazy initialization Issue #5 from PR review: The check-then-create pattern for e.assembler was a data race when multiple goroutines called Assemble() concurrently: if e.assembler == nil { e.assembler = &Assembler{...} } Changes: - Add assemblerMu sync.Mutex to Engine struct - Add initAssemblerOnce() using double-checked locking (same pattern as initCompactionOnce) - Add TestAssemblerLazyInitRace to verify thread-safety * fix(seahorse): handle non-consecutive depths in selectShallowestCondensationCandidate Issue #8 from PR review: the loop iterated depth 0, 1, 2... assuming consecutive keys, but break when key was missing caused deeper depths to never be checked. Fix: collect all existing depth keys, sort, then iterate in order. * fix(seahorse): wrap DeleteMessagesAfterID and appendContextItems in transactions - DeleteMessagesAfterID: wrap all DELETE operations in a transaction for atomicity, remove redundant manual FTS delete (handled by trigger) - appendContextItems: use transaction to fix read-then-write race condition - Add GetMaxOrdinalTx and resolveItemTokenCountTx for transaction-scoped queries - Remove unused resolveItemTokenCount function Fixes PR review issues 6 and 7. * fix(seahorse): derive readable content from Parts and cap CompactUntilUnder iterations - Derive readable content from MessageParts in AddMessageWithParts so FTS5 indexing and summary formatting can access tool call information - formatMessagesForSummary and truncateSummary now fall back to Parts when Content is empty, fixing blank summaries for Part-based messages - Add MaxCompactIterations (20) to prevent CompactUntilUnder infinite loops; exceeded iterations are logged as warnings --- .gitignore | 2 + .golangci.yaml | 1 + pkg/agent/context_budget.go | 96 +- pkg/agent/context_budget_test.go | 38 +- pkg/agent/context_legacy.go | 2 +- pkg/agent/context_manager.go | 1 + pkg/agent/context_seahorse.go | 267 +++ pkg/agent/context_seahorse_test.go | 1086 +++++++++++++ pkg/agent/loop.go | 6 +- pkg/agent/subturn.go | 6 +- pkg/memory/jsonl.go | 27 + pkg/memory/store.go | 3 + pkg/seahorse/.omc/state/last-tool-error.json | 7 + pkg/seahorse/compact_until_under_test.go | 58 + pkg/seahorse/parts_roundtrip_test.go | 144 ++ pkg/seahorse/schema.go | 185 +++ pkg/seahorse/schema_test.go | 211 +++ pkg/seahorse/short_assembler.go | 261 +++ pkg/seahorse/short_assembler_test.go | 536 ++++++ pkg/seahorse/short_bench_test.go | 336 ++++ pkg/seahorse/short_compaction.go | 898 ++++++++++ pkg/seahorse/short_compaction_test.go | 974 +++++++++++ pkg/seahorse/short_constants.go | 30 + pkg/seahorse/short_engine.go | 568 +++++++ pkg/seahorse/short_engine_test.go | 1448 +++++++++++++++++ pkg/seahorse/short_retrieval.go | 212 +++ pkg/seahorse/short_retrieval_test.go | 362 +++++ pkg/seahorse/store.go | 1532 ++++++++++++++++++ pkg/seahorse/store_test.go | 1250 ++++++++++++++ pkg/seahorse/tool_expand.go | 129 ++ pkg/seahorse/tool_expand_test.go | 136 ++ pkg/seahorse/tool_grep.go | 172 ++ pkg/seahorse/tool_grep_test.go | 72 + pkg/seahorse/types.go | 161 ++ pkg/seahorse/types_test.go | 54 + pkg/session/jsonl_backend.go | 5 + pkg/session/manager.go | 10 + pkg/session/session_store.go | 2 + pkg/tokenizer/estimator.go | 91 ++ 39 files changed, 11271 insertions(+), 108 deletions(-) create mode 100644 pkg/agent/context_seahorse.go create mode 100644 pkg/agent/context_seahorse_test.go create mode 100644 pkg/seahorse/.omc/state/last-tool-error.json create mode 100644 pkg/seahorse/compact_until_under_test.go create mode 100644 pkg/seahorse/parts_roundtrip_test.go create mode 100644 pkg/seahorse/schema.go create mode 100644 pkg/seahorse/schema_test.go create mode 100644 pkg/seahorse/short_assembler.go create mode 100644 pkg/seahorse/short_assembler_test.go create mode 100644 pkg/seahorse/short_bench_test.go create mode 100644 pkg/seahorse/short_compaction.go create mode 100644 pkg/seahorse/short_compaction_test.go create mode 100644 pkg/seahorse/short_constants.go create mode 100644 pkg/seahorse/short_engine.go create mode 100644 pkg/seahorse/short_engine_test.go create mode 100644 pkg/seahorse/short_retrieval.go create mode 100644 pkg/seahorse/short_retrieval_test.go create mode 100644 pkg/seahorse/store.go create mode 100644 pkg/seahorse/store_test.go create mode 100644 pkg/seahorse/tool_expand.go create mode 100644 pkg/seahorse/tool_expand_test.go create mode 100644 pkg/seahorse/tool_grep.go create mode 100644 pkg/seahorse/tool_grep_test.go create mode 100644 pkg/seahorse/types.go create mode 100644 pkg/seahorse/types_test.go create mode 100644 pkg/tokenizer/estimator.go diff --git a/.gitignore b/.gitignore index b869ecc33..135867842 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,5 @@ web/backend/dist/* .claude/ docker/data + +.omc/ diff --git a/.golangci.yaml b/.golangci.yaml index b2b772406..052e4c0dd 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -12,6 +12,7 @@ linters: - exhaustruct - funcorder - gochecknoglobals + - gosmopolitan # Project legitimately uses CJK text in tests (FTS5, token counting) - godot - intrange - ireturn diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go index 3398d7863..72f80382a 100644 --- a/pkg/agent/context_budget.go +++ b/pkg/agent/context_budget.go @@ -6,10 +6,8 @@ package agent import ( - "encoding/json" - "unicode/utf8" - "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tokenizer" ) // parseTurnBoundaries returns the starting index of each Turn in the history. @@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int { return 0 } -// estimateMessageTokens estimates the token count for a single message, -// including Content, ReasoningContent, ToolCalls arguments, ToolCallID -// metadata, and Media items. Uses a heuristic of 2.5 characters per token. -func estimateMessageTokens(msg providers.Message) int { - contentChars := utf8.RuneCountInString(msg.Content) - - // SystemParts are structured system blocks used for cache-aware adapters. - // They carry the same content as Content, but in multiple blocks. - // We estimate them as an alternative representation, not additive. - systemPartsChars := 0 - if len(msg.SystemParts) > 0 { - for _, part := range msg.SystemParts { - systemPartsChars += utf8.RuneCountInString(part.Text) - } - // Per-part overhead for JSON structure (type, text, cache_control). - const perPartOverhead = 20 - systemPartsChars += len(msg.SystemParts) * perPartOverhead - } - - // Use the larger of the two representations to stay conservative. - chars := contentChars - if systemPartsChars > chars { - chars = systemPartsChars - } - - chars += utf8.RuneCountInString(msg.ReasoningContent) - - for _, tc := range msg.ToolCalls { - chars += len(tc.ID) + len(tc.Type) - if tc.Function != nil { - // Count function name + arguments (the wire format for most providers). - // tc.Name mirrors tc.Function.Name — count only once to avoid double-counting. - chars += len(tc.Function.Name) + len(tc.Function.Arguments) - } else { - // Fallback: some provider formats use top-level Name without Function. - chars += len(tc.Name) - } - } - - if msg.ToolCallID != "" { - chars += len(msg.ToolCallID) - } - - // Per-message overhead for role label, JSON structure, separators. - const messageOverhead = 12 - chars += messageOverhead - - tokens := chars * 2 / 5 - - // Media items (images, files) are serialized by provider adapters into - // multipart or image_url payloads. Add a fixed per-item token estimate - // directly (not through the chars heuristic) since actual cost depends - // on resolution and provider-specific image tokenization. - const mediaTokensPerItem = 256 - tokens += len(msg.Media) * mediaTokensPerItem - - return tokens +// EstimateMessageTokens estimates the token count for a single message. +// Delegates to the shared tokenizer package for consistency across agent and seahorse. +func EstimateMessageTokens(msg providers.Message) int { + return tokenizer.EstimateMessageTokens(msg) } -// estimateToolDefsTokens estimates the total token cost of tool definitions -// as they appear in the LLM request. Each tool's name, description, and -// JSON schema parameters contribute to the context window budget. -func estimateToolDefsTokens(defs []providers.ToolDefinition) int { - if len(defs) == 0 { - return 0 - } - - totalChars := 0 - for _, d := range defs { - totalChars += len(d.Function.Name) + len(d.Function.Description) - - if d.Function.Parameters != nil { - if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil { - totalChars += len(paramJSON) - } - } - - // Per-tool overhead: type field, JSON structure, separators. - totalChars += 20 - } - - return totalChars * 2 / 5 +// EstimateToolDefsTokens estimates the total token cost of tool definitions +// as they appear in the LLM request. Delegates to the shared tokenizer package. +func EstimateToolDefsTokens(defs []providers.ToolDefinition) int { + return tokenizer.EstimateToolDefsTokens(defs) } // isOverContextBudget checks whether the assembled messages plus tool definitions @@ -181,10 +107,10 @@ func isOverContextBudget( ) bool { msgTokens := 0 for _, m := range messages { - msgTokens += estimateMessageTokens(m) + msgTokens += EstimateMessageTokens(m) } - toolTokens := estimateToolDefsTokens(toolDefs) + toolTokens := EstimateToolDefsTokens(toolDefs) total := msgTokens + toolTokens + maxTokens return total > contextWindow diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go index 22cbdc0db..9de1707ec 100644 --- a/pkg/agent/context_budget_test.go +++ b/pkg/agent/context_budget_test.go @@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := estimateMessageTokens(tt.msg) + got := EstimateMessageTokens(tt.msg) if got < tt.want { - t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want) + t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want) } }) } @@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) { }, } - plainTokens := estimateMessageTokens(plain) - withTCTokens := estimateMessageTokens(withTC) + plainTokens := EstimateMessageTokens(plain) + withTCTokens := EstimateMessageTokens(withTC) if withTCTokens <= plainTokens { t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)", @@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) { // but may map to different token counts. The heuristic should still produce // reasonable estimates via RuneCountInString. msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe") - tokens := estimateMessageTokens(msg) + tokens := EstimateMessageTokens(msg) if tokens <= 0 { t.Errorf("multibyte message should produce positive token count, got %d", tokens) } @@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) { }, } - tokens := estimateMessageTokens(msg) + tokens := EstimateMessageTokens(msg) // 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic if tokens < 2000 { t.Errorf("large tool call arguments should produce significant token count, got %d", tokens) @@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) { ReasoningContent: strings.Repeat("thinking step ", 200), } - plainTokens := estimateMessageTokens(plain) - reasoningTokens := estimateMessageTokens(withReasoning) + plainTokens := EstimateMessageTokens(plain) + reasoningTokens := EstimateMessageTokens(withReasoning) if reasoningTokens <= plainTokens { t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)", @@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) { Media: []string{"media://img1.png", "media://img2.png"}, } - plainTokens := estimateMessageTokens(plain) - mediaTokens := estimateMessageTokens(withMedia) + plainTokens := EstimateMessageTokens(plain) + mediaTokens := EstimateMessageTokens(withMedia) if mediaTokens <= plainTokens { t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)", @@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) { }, } - plainTokens := estimateMessageTokens(plain) - partsTokens := estimateMessageTokens(withParts) + plainTokens := EstimateMessageTokens(plain) + partsTokens := EstimateMessageTokens(withParts) if partsTokens <= plainTokens { t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)", @@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) { } } -// --- estimateToolDefsTokens tests --- +// --- EstimateToolDefsTokens tests --- func TestEstimateToolDefsTokens(t *testing.T) { tests := []struct { @@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := estimateToolDefsTokens(tt.defs) + got := EstimateToolDefsTokens(tt.defs) if got < tt.want { - t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want) + t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want) } }) } @@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) { } } - one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")}) - three := estimateToolDefsTokens([]providers.ToolDefinition{ + one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")}) + three := EstimateToolDefsTokens([]providers.ToolDefinition{ makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"), }) @@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) { }, } - tokens := estimateMessageTokens(msg) + tokens := EstimateMessageTokens(msg) // ReasoningContent alone is ~1700 chars → ~680 tokens. // Content + TC + overhead adds more. Should be well above 500. @@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) { // Compare without reasoning to ensure it's counted. msgNoReasoning := msg msgNoReasoning.ReasoningContent = "" - tokensNoReasoning := estimateMessageTokens(msgNoReasoning) + tokensNoReasoning := EstimateMessageTokens(msgNoReasoning) if tokens <= tokensNoReasoning { t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning) diff --git a/pkg/agent/context_legacy.go b/pkg/agent/context_legacy.go index 23402460e..0f10decb3 100644 --- a/pkg/agent/context_legacy.go +++ b/pkg/agent/context_legacy.go @@ -373,7 +373,7 @@ func (m *legacyContextManager) summarizeBatch( func (m *legacyContextManager) estimateTokens(messages []providers.Message) int { total := 0 for _, msg := range messages { - total += estimateMessageTokens(msg) + total += EstimateMessageTokens(msg) } return total } diff --git a/pkg/agent/context_manager.go b/pkg/agent/context_manager.go index cc8904ccf..5f8701812 100644 --- a/pkg/agent/context_manager.go +++ b/pkg/agent/context_manager.go @@ -43,6 +43,7 @@ type AssembleResponse struct { type CompactRequest struct { SessionKey string // session identifier Reason ContextCompressReason // proactive_budget | llm_retry | summarize + Budget int // context window budget (used for retry aggressive compaction) } // IngestRequest is the input to Ingest. diff --git a/pkg/agent/context_seahorse.go b/pkg/agent/context_seahorse.go new file mode 100644 index 000000000..104a84a78 --- /dev/null +++ b/pkg/agent/context_seahorse.go @@ -0,0 +1,267 @@ +package agent + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" + "github.com/sipeed/picoclaw/pkg/seahorse" + "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/tokenizer" +) + +// seahorseContextManager adapts seahorse.Engine to agent.ContextManager. +type seahorseContextManager struct { + engine *seahorse.Engine + sessions session.SessionStore // for startup bootstrap +} + +// newSeahorseContextManager creates a seahorse-backed ContextManager. +func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) { + if al == nil { + return nil, fmt.Errorf("seahorse: AgentLoop is required") + } + + // Resolve workspace for DB path + // DB stores session data, so it goes in sessions/ directory + agent := al.registry.GetDefaultAgent() + dbPath := agent.Workspace + "/sessions/seahorse.db" + + // Create CompleteFn from provider + completeFn := providerToCompleteFn(agent.Provider, agent.Model) + + // Create engine + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: dbPath, + }, completeFn) + if err != nil { + return nil, fmt.Errorf("seahorse: create engine: %w", err) + } + + mgr := &seahorseContextManager{ + engine: engine, + sessions: agent.Sessions, + } + + // Register seahorse tools with the agent's tool registry + retrieval := mgr.engine.GetRetrieval() + al.RegisterTool(seahorse.NewGrepTool(retrieval)) + al.RegisterTool(seahorse.NewExpandTool(retrieval)) + + // Bootstrap all existing sessions at startup + if agent.Sessions != nil { + ctx := context.Background() + for _, sessionKey := range agent.Sessions.ListSessions() { + mgr.bootstrapSession(ctx, sessionKey) + } + } + + return mgr, nil +} + +// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn. +func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn { + return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) { + resp, err := provider.Chat( + ctx, + []providers.Message{{Role: "user", Content: prompt}}, + nil, // no tools for summarization + model, + map[string]any{ + "max_tokens": opts.MaxTokens, + "temperature": opts.Temperature, + "prompt_cache_key": "seahorse", + }, + ) + if err != nil { + return "", err + } + return resp.Content, nil + } +} + +// Assemble builds budget-aware context from seahorse SQLite. +func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) { + if req == nil { + return nil, fmt.Errorf("seahorse assemble: nil request") + } + + budget := req.Budget + if budget <= 0 { + budget = 100000 + } + + // Reserve space for model response (spec lines 1400-1410) + effectiveBudget := budget - req.MaxTokens + if effectiveBudget <= 0 { + // MaxTokens >= budget is a configuration problem + // Use 50% as minimum to avoid guaranteed overflow + logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback", + map[string]any{"budget": budget, "max_tokens": req.MaxTokens}) + effectiveBudget = budget / 2 + } + + result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{ + Budget: effectiveBudget, + }) + if err != nil { + return nil, fmt.Errorf("seahorse assemble: %w", err) + } + + history := seahorseToProviderMessages(result) + + // Summary is already formatted as XML with system prompt addition by assembler + return &AssembleResponse{ + History: history, + Summary: result.Summary, + }, nil +} + +// Compact compresses conversation history via seahorse summarization. +func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error { + if req == nil { + return nil + } + + // For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee + // context shrinks below budget (spec lines ~1410). + if req.Reason == ContextCompressReasonRetry && req.Budget > 0 { + _, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget) + return err + } + + _, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{ + Force: req.Reason == ContextCompressReasonRetry, + Budget: &req.Budget, + }) + return err +} + +// Ingest records a message into seahorse SQLite. +// All existing sessions are bootstrapped at startup, so this only ingests new messages. +func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error { + if req == nil { + return nil + } + + msg := providerToSeahorseMessage(req.Message) + _, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg}) + return err +} + +// bootstrapSession reconciles JSONL session history into seahorse SQLite. +func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) { + if m.sessions == nil { + return + } + + history := m.sessions.GetHistory(sessionKey) + if len(history) == 0 { + return + } + + // Convert provider messages to seahorse messages + msgs := make([]seahorse.Message, len(history)) + for i, h := range history { + msgs[i] = providerToSeahorseMessage(h) + } + + if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil { + logger.WarnCF("seahorse", "bootstrap", map[string]any{ + "session": sessionKey, + "error": err.Error(), + }) + } +} + +// providerToSeahorseMessage converts a providers.Message to a seahorse.Message. +func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message { + result := seahorse.Message{ + Role: msg.Role, + Content: msg.Content, + ReasoningContent: msg.ReasoningContent, + TokenCount: tokenizer.EstimateMessageTokens(msg), + } + + // Convert ToolCalls → MessageParts + for _, tc := range msg.ToolCalls { + part := seahorse.MessagePart{ + Type: "tool_use", + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + ToolCallID: tc.ID, + } + result.Parts = append(result.Parts, part) + } + + // Convert tool result + if msg.ToolCallID != "" { + part := seahorse.MessagePart{ + Type: "tool_result", + ToolCallID: msg.ToolCallID, + Text: msg.Content, + } + result.Parts = append(result.Parts, part) + } + + // Convert media attachments + for _, mediaURI := range msg.Media { + part := seahorse.MessagePart{ + Type: "media", + MediaURI: mediaURI, + } + result.Parts = append(result.Parts, part) + } + + return result +} + +// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message. +func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message { + messages := make([]protocoltypes.Message, 0, len(result.Messages)) + + // Convert assembled messages (which already include summary XML messages) + for _, msg := range result.Messages { + pm := protocoltypes.Message{ + Role: msg.Role, + Content: msg.Content, + ReasoningContent: msg.ReasoningContent, + } + + // Reconstruct ToolCalls from parts + for _, part := range msg.Parts { + if part.Type == "tool_use" { + pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{ + ID: part.ToolCallID, + Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.) + Function: &protocoltypes.FunctionCall{ + Name: part.Name, + Arguments: part.Arguments, + }, + }) + } + if part.Type == "tool_result" { + pm.ToolCallID = part.ToolCallID + if pm.Content == "" && part.Text != "" { + pm.Content = part.Text + } + } + if part.Type == "media" && part.MediaURI != "" { + pm.Media = append(pm.Media, part.MediaURI) + } + } + + messages = append(messages, pm) + } + + return messages +} + +func init() { + if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil { + panic(fmt.Sprintf("register seahorse context manager: %v", err)) + } +} diff --git a/pkg/agent/context_seahorse_test.go b/pkg/agent/context_seahorse_test.go new file mode 100644 index 000000000..e405ef944 --- /dev/null +++ b/pkg/agent/context_seahorse_test.go @@ -0,0 +1,1086 @@ +package agent + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/providers/protocoltypes" + "github.com/sipeed/picoclaw/pkg/seahorse" +) + +// seahorseTestProvider implements providers.LLMProvider for seahorse tests. +type seahorseTestProvider struct { + chatFn func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]any) (*providers.LLMResponse, error) +} + +func (m *seahorseTestProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + if m.chatFn != nil { + return m.chatFn(ctx, messages, tools, model, options) + } + return &providers.LLMResponse{Content: "mock response"}, nil +} + +func (m *seahorseTestProvider) GetDefaultModel() string { + return "mock-model" +} + +func TestSeahorseCMRegistration(t *testing.T) { + factory, ok := lookupContextManager("seahorse") + if !ok { + t.Error("expected 'seahorse' context manager to be registered") + } + if factory == nil { + t.Error("expected non-nil factory") + } +} + +func TestProviderToSeahorseMessage(t *testing.T) { + tests := []struct { + name string + input protocoltypes.Message + wantRole string + wantContent string + }{ + { + name: "simple user message", + input: protocoltypes.Message{Role: "user", Content: "hello world"}, + wantRole: "user", + wantContent: "hello world", + }, + { + name: "assistant message", + input: protocoltypes.Message{Role: "assistant", Content: "response text"}, + wantRole: "assistant", + wantContent: "response text", + }, + { + name: "tool result message", + input: protocoltypes.Message{Role: "tool", Content: "tool output", ToolCallID: "tc_123"}, + wantRole: "tool", + wantContent: "tool output", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := providerToSeahorseMessage(tt.input) + if result.Role != tt.wantRole { + t.Errorf("Role = %q, want %q", result.Role, tt.wantRole) + } + if result.Content != tt.wantContent { + t.Errorf("Content = %q, want %q", result.Content, tt.wantContent) + } + }) + } +} + +func TestProviderToSeahorseMessageWithToolCalls(t *testing.T) { + msg := protocoltypes.Message{ + Role: "assistant", + Content: "", + ToolCalls: []protocoltypes.ToolCall{ + { + ID: "tc_1", + Function: &protocoltypes.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + }, + }, + }, + } + + result := providerToSeahorseMessage(msg) + if result.Role != "assistant" { + t.Errorf("Role = %q, want assistant", result.Role) + } + if len(result.Parts) == 0 { + t.Fatal("expected at least 1 part from tool calls") + } + if result.Parts[0].Type != "tool_use" { + t.Errorf("Part type = %q, want tool_use", result.Parts[0].Type) + } + if result.Parts[0].Name != "read_file" { + t.Errorf("Part name = %q, want read_file", result.Parts[0].Name) + } + if result.Parts[0].ToolCallID != "tc_1" { + t.Errorf("Part ToolCallID = %q, want tc_1", result.Parts[0].ToolCallID) + } +} + +func TestProviderToSeahorseMessageWithToolResult(t *testing.T) { + msg := protocoltypes.Message{ + Role: "tool", + Content: "file contents here", + ToolCallID: "tc_456", + } + + result := providerToSeahorseMessage(msg) + if result.Role != "tool" { + t.Errorf("Role = %q, want tool", result.Role) + } + found := false + for _, p := range result.Parts { + if p.Type == "tool_result" && p.ToolCallID == "tc_456" { + found = true + break + } + } + if !found { + t.Error("expected tool_result part with ToolCallID tc_456") + } +} + +func TestProviderToSeahorseMessageWithMedia(t *testing.T) { + msg := protocoltypes.Message{ + Role: "user", + Content: "Here is an image", + Media: []string{"data:image/png;base64,abc123"}, + } + + result := providerToSeahorseMessage(msg) + if result.Role != "user" { + t.Errorf("Role = %q, want user", result.Role) + } + + // Should have a media part + found := false + for _, p := range result.Parts { + if p.Type == "media" { + found = true + if p.MediaURI != "data:image/png;base64,abc123" { + t.Errorf("MediaURI = %q, want data:image/png;base64,abc123", p.MediaURI) + } + break + } + } + if !found { + t.Error("expected media part in converted message") + } +} + +func TestProviderToSeahorseMessageWithReasoning(t *testing.T) { + msg := protocoltypes.Message{ + Role: "assistant", + Content: "response text", + ReasoningContent: "I thought about this carefully", + } + + result := providerToSeahorseMessage(msg) + if result.ReasoningContent != "I thought about this carefully" { + t.Errorf("ReasoningContent = %q, want 'I thought about this carefully'", result.ReasoningContent) + } +} + +func TestSeahorseToProviderMessagesWithReasoning(t *testing.T) { + result := &seahorse.AssembleResult{ + Messages: []seahorse.Message{ + { + Role: "assistant", + Content: "response", + ReasoningContent: "thinking process", + }, + }, + } + + messages := seahorseToProviderMessages(result) + if len(messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(messages)) + } + if messages[0].ReasoningContent != "thinking process" { + t.Errorf("ReasoningContent = %q, want 'thinking process'", messages[0].ReasoningContent) + } +} + +func TestSeahorseToProviderMessages(t *testing.T) { + // Summaries should NOT be double-injected. + // The assembler already includes summaries as XML-formatted messages in Messages slice. + // seahorseToProviderMessages should only convert Messages, not Summaries. + summaryXML := ` + + test summary content + +` + summaryMsg := seahorse.Message{ + Role: "user", + Content: summaryXML, + TokenCount: 50, + } + rawMsg := seahorse.Message{ + Role: "user", + Content: "hello", + TokenCount: 5, + } + + result := seahorseToProviderMessages(&seahorse.AssembleResult{ + Messages: []seahorse.Message{summaryMsg, rawMsg}, + }) + + // Should have exactly 2 messages (from Messages slice only) + // NOT 3 (which would happen if Summaries were also converted) + if len(result) != 2 { + t.Fatalf("expected exactly 2 messages (no double injection), got %d", len(result)) + } + // First should be the XML summary message + if result[0].Content != summaryXML { + t.Errorf("first message content = %q, want summary XML", result[0].Content) + } + // Second should be the raw message + if result[1].Content != "hello" { + t.Errorf("second message content = %q, want 'hello'", result[1].Content) + } +} + +func TestSeahorseToProviderMessagesWithToolCalls(t *testing.T) { + msg := seahorse.Message{ + Role: "assistant", + Content: "", + TokenCount: 10, + Parts: []seahorse.MessagePart{ + { + Type: "tool_use", + Name: "read_file", + Arguments: `{"path":"/tmp"}`, + ToolCallID: "tc_1", + }, + }, + } + + result := seahorseToProviderMessages(&seahorse.AssembleResult{ + Messages: []seahorse.Message{msg}, + }) + + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + if result[0].Role != "assistant" { + t.Errorf("Role = %q, want assistant", result[0].Role) + } + if len(result[0].ToolCalls) != 1 { + t.Fatalf("ToolCalls = %d, want 1", len(result[0].ToolCalls)) + } + if result[0].ToolCalls[0].Function.Name != "read_file" { + t.Errorf("ToolCall name = %q, want read_file", result[0].ToolCalls[0].Function.Name) + } + // GLM API and other OpenAI-compatible APIs require Type: "function" + if result[0].ToolCalls[0].Type != "function" { + t.Errorf("ToolCall Type = %q, want 'function' (required by GLM/OpenAI APIs)", + result[0].ToolCalls[0].Type) + } +} + +func TestSeahorseToProviderMessagesToolResult(t *testing.T) { + msg := seahorse.Message{ + Role: "tool", + Content: "file output", + TokenCount: 5, + Parts: []seahorse.MessagePart{ + { + Type: "tool_result", + ToolCallID: "tc_99", + Text: "file output", + }, + }, + } + + result := seahorseToProviderMessages(&seahorse.AssembleResult{ + Messages: []seahorse.Message{msg}, + }) + + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + if result[0].ToolCallID != "tc_99" { + t.Errorf("ToolCallID = %q, want tc_99", result[0].ToolCallID) + } +} + +// --- providerToCompleteFn tests --- + +func TestProviderToCompleteFn(t *testing.T) { + var capturedMessages []providers.Message + var capturedModel string + var capturedOptions map[string]any + + mp := &seahorseTestProvider{ + chatFn: func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]any) (*providers.LLMResponse, error) { + capturedMessages = messages + capturedModel = model + capturedOptions = options + return &providers.LLMResponse{Content: "summary of conversation"}, nil + }, + } + + completeFn := providerToCompleteFn(mp, "test-model-v1") + result, err := completeFn(context.Background(), "Summarize this text", seahorse.CompleteOptions{ + MaxTokens: 500, + Temperature: 0.3, + }) + if err != nil { + t.Fatalf("completeFn: %v", err) + } + if result != "summary of conversation" { + t.Errorf("result = %q, want 'summary of conversation'", result) + } + + // Verify prompt passed as user message + if len(capturedMessages) != 1 { + t.Fatalf("captured messages = %d, want 1", len(capturedMessages)) + } + if capturedMessages[0].Role != "user" { + t.Errorf("message role = %q, want user", capturedMessages[0].Role) + } + if capturedMessages[0].Content != "Summarize this text" { + t.Errorf("message content = %q, want 'Summarize this text'", capturedMessages[0].Content) + } + + // Verify model + if capturedModel != "test-model-v1" { + t.Errorf("model = %q, want 'test-model-v1'", capturedModel) + } + + // Verify options + if capturedOptions["max_tokens"] != 500 { + t.Errorf("max_tokens = %v, want 500", capturedOptions["max_tokens"]) + } + if capturedOptions["temperature"] != 0.3 { + t.Errorf("temperature = %v, want 0.3", capturedOptions["temperature"]) + } + if capturedOptions["prompt_cache_key"] != "seahorse" { + t.Errorf("prompt_cache_key = %v, want 'seahorse'", capturedOptions["prompt_cache_key"]) + } +} + +func TestSeahorseIgnoreHeartbeat(t *testing.T) { + // Verify that "heartbeat" sessions are ignored by default + // This tests the hardcoded ignore pattern from spec lines 1326-1328 + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/test.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer engine.Close() + + ctx := context.Background() + result, err := engine.Ingest(ctx, "heartbeat", []seahorse.Message{ + {Role: "user", Content: "heartbeat msg", TokenCount: 5}, + }) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + // Should return nil nil for ignored sessions + if result != nil { + t.Errorf("expected nil result for heartbeat session, got %+v", result) + } +} + +func TestProviderToCompleteFnError(t *testing.T) { + mp := &seahorseTestProvider{ + chatFn: func(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]any) (*providers.LLMResponse, error) { + return nil, context.Canceled + }, + } + + completeFn := providerToCompleteFn(mp, "test-model") + _, err := completeFn(context.Background(), "test prompt", seahorse.CompleteOptions{}) + if err == nil { + t.Error("expected error from canceled context") + } +} + +func TestSeahorseAdapterAssembleSubtractsMaxTokens(t *testing.T) { + // Create a real seahorse engine with temp DB + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/test.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer engine.Close() + + ctx := context.Background() + mgr := &seahorseContextManager{engine: engine} + + // Ingest lots of large messages (~35 tokens each, 120 total = ~4200 tokens) + for i := 0; i < 60; i++ { + content := fmt.Sprintf( + "This is message number %d. It contains enough text to represent a meaningful conversation turn with the user asking about various topics in software engineering and system design principles that require careful consideration.", + i, + ) + _ = mgr.Ingest(ctx, &IngestRequest{ + SessionKey: "budget-sub", + Message: protocoltypes.Message{Role: "user", Content: content}, + }) + _ = mgr.Ingest(ctx, &IngestRequest{ + SessionKey: "budget-sub", + Message: protocoltypes.Message{Role: "assistant", Content: "Response"}, + }) + } + + // Call adapter Assemble with Budget=5000, MaxTokens=2000 + // Should use effective budget = 5000 - 2000 = 3000 + resp, err := mgr.Assemble(ctx, &AssembleRequest{ + SessionKey: "budget-sub", + Budget: 5000, + MaxTokens: 2000, + }) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + if resp == nil { + t.Fatal("expected non-nil response") + } + + // Directly call engine with budget=3000 to get baseline + baseline, err := engine.Assemble(ctx, "budget-sub", seahorse.AssembleInput{Budget: 3000}) + if err != nil { + t.Fatalf("engine.Assemble baseline: %v", err) + } + + // The adapter result should have same message count as engine with budget 3000 + if len(resp.History) != len(baseline.Messages) { + t.Errorf("adapter Budget=5000 MaxTokens=2000 gave %d messages, engine Budget=3000 gave %d", + len(resp.History), len(baseline.Messages)) + } +} + +func TestSeahorseCompactRetryUsesCompactUntilUnder(t *testing.T) { + // Track which engine method was called + var compactCalled, compactUntilCalled bool + + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/test.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer engine.Close() + + // Wrap engine to track calls + _ = compactCalled // track via adapter behavior + _ = compactUntilCalled + + mgr := &seahorseContextManager{engine: engine} + + ctx := context.Background() + + // Ingest messages so there's something to compact + for i := 0; i < 40; i++ { + content := fmt.Sprintf( + "message %d with enough text to have meaningful token count that fills up the budget nicely", + i, + ) + _ = mgr.Ingest(ctx, &IngestRequest{ + SessionKey: "compact-test", + Message: protocoltypes.Message{Role: "user", Content: content}, + }) + _ = mgr.Ingest(ctx, &IngestRequest{ + SessionKey: "compact-test", + Message: protocoltypes.Message{Role: "assistant", Content: "ok"}, + }) + } + + // Compact with retry reason and budget should succeed + err = mgr.Compact(ctx, &CompactRequest{ + SessionKey: "compact-test", + Reason: ContextCompressReasonRetry, + Budget: 5000, + }) + if err != nil { + t.Fatalf("Compact retry: %v", err) + } + + // Verify context was actually compacted (should have fewer tokens) + result, err := engine.Assemble(ctx, "compact-test", seahorse.AssembleInput{Budget: 5000}) + if err != nil { + t.Fatalf("Assemble after compact: %v", err) + } + if result == nil { + t.Fatal("expected non-nil assemble result") + } + // Compaction attempted — no assertion on exact count since no LLM + _ = result.Summary +} + +// TestSeahorseRealLoopNoDuplicateMessages tests the real-world scenario: +// 1. Start AgentLoop with seahorse context manager +// 2. Run a turn (user message -> LLM response) +// 3. Check DB for duplicate messages +// This test verifies that bootstrapping at startup (not during first Ingest) prevents duplicates. +func TestSeahorseRealLoopNoDuplicateMessages(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + ContextManager: "seahorse", + }, + }, + } + + msgBus := bus.NewMessageBus() + mockProvider := &simpleMockProvider{response: "I received your message."} + al := NewAgentLoop(cfg, msgBus, mockProvider) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + + ctx := context.Background() + sessionKey := "test-real-loop-dup" + + // Run a turn: user message -> LLM response + _, err := al.runAgentLoop(ctx, defaultAgent, processOptions{ + SessionKey: sessionKey, + Channel: "cli", + ChatID: "direct", + UserMessage: "hello", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + + // Get the seahorse engine from context manager + seahorseCM, ok := al.contextManager.(*seahorseContextManager) + if !ok { + t.Fatal("expected seahorseContextManager") + } + + // Check DB for messages via RetrievalEngine.Store() + store := seahorseCM.engine.GetRetrieval().Store() + conv, err := store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + stored, err := store.GetMessages(ctx, conv.ConversationID, 20, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + + t.Logf("DB has %d messages:", len(stored)) + for i, msg := range stored { + content := msg.Content + if len(content) > 40 { + content = content[:40] + "..." + } + t.Logf(" msg[%d]: role=%s content=%q", i, msg.Role, content) + } + + // Count duplicates by (role, content) + seen := make(map[string]int) + for _, msg := range stored { + key := msg.Role + ":" + msg.Content + seen[key]++ + } + for key, count := range seen { + if count > 1 { + t.Errorf("DUPLICATE BUG: %q appears %d times in DB", key, count) + } + } + + // Expected: 2 messages (user "hello" + assistant response) + if len(stored) != 2 { + t.Errorf("expected 2 messages in DB (user + assistant), got %d", len(stored)) + } +} + +// TestSeahorseAssembleReturnsAllSummaries verifies that Assemble returns ALL summaries, +// not just the latest one. This is important because summaries represent compressed +// conversation history at different points in time. +func TestSeahorseAssembleReturnsAllSummaries(t *testing.T) { + // Create a real seahorse engine with temp DB + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/test.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer engine.Close() + + ctx := context.Background() + mgr := &seahorseContextManager{engine: engine} + sessionKey := "test-multi-summary" + + // Get the store to directly create summaries + store := engine.GetRetrieval().Store() + + // Get conversation ID + conv, err := store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + // Create some messages first + for i := 0; i < 20; i++ { + _ = mgr.Ingest(ctx, &IngestRequest{ + SessionKey: sessionKey, + Message: protocoltypes.Message{Role: "user", Content: fmt.Sprintf("Message %d", i)}, + }) + } + + // Directly create multiple summaries in the database to simulate multi-level compaction + testSummaries := []struct { + content string + kind seahorse.SummaryKind + depth int + token int + }{ + {"First summary about early conversation discussing topics A and B", seahorse.SummaryKindLeaf, 0, 100}, + {"Second summary covering middle conversation about topics C and D", seahorse.SummaryKindLeaf, 0, 150}, + {"Third summary is condensed from first two summaries about topics A-D", seahorse.SummaryKindCondensed, 1, 200}, + } + + summaryIDs := make([]string, 0, len(testSummaries)) + for _, s := range testSummaries { + input := seahorse.CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: s.kind, + Depth: s.depth, + Content: s.content, + TokenCount: s.token, + } + summary, createErr := store.CreateSummary(ctx, input) + if createErr != nil { + t.Fatalf("CreateSummary: %v", createErr) + } + summaryIDs = append(summaryIDs, summary.SummaryID) + + // Add summary to context_items + err = store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID) + if err != nil { + t.Fatalf("AppendContextSummary: %v", err) + } + } + + t.Logf("Created %d summaries directly in store", len(summaryIDs)) + + // Assemble and check summaries + resp, err := mgr.Assemble(ctx, &AssembleRequest{ + SessionKey: sessionKey, + Budget: 50000, + MaxTokens: 4096, + }) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Check seahorse engine directly for how many summaries exist + result, err := engine.Assemble(ctx, sessionKey, seahorse.AssembleInput{Budget: 50000}) + if err != nil { + t.Fatalf("engine.Assemble: %v", err) + } + + t.Logf("Seahorse returned Summary with %d chars", len(result.Summary)) + + // The Summary field should contain XML summaries with metadata (depth, kind) + // The assembler generates this from the Summaries list + if len(resp.Summary) > 0 { + // Should contain XML tag + if !strings.Contains(resp.Summary, " Content-only = %d", + resultWithToolCalls.TokenCount, resultContentOnly.TokenCount) + } + + // Message with ToolCallID + msgWithToolResult := protocoltypes.Message{ + Role: "tool", + Content: "This is a simple response with some text content.", + ToolCallID: "tc_456", + } + resultWithToolResult := providerToSeahorseMessage(msgWithToolResult) + + if resultWithToolResult.TokenCount <= resultContentOnly.TokenCount { + t.Errorf("TokenCount with ToolCallID = %d, should be > Content-only = %d", + resultWithToolResult.TokenCount, resultContentOnly.TokenCount) + } + + // Message with Media + msgWithMedia := protocoltypes.Message{ + Role: "user", + Content: "This is a simple response with some text content.", + Media: []string{"data:image/png;base64,abc123"}, + } + resultWithMedia := providerToSeahorseMessage(msgWithMedia) + + if resultWithMedia.TokenCount <= resultContentOnly.TokenCount { + t.Errorf("TokenCount with Media = %d, should be > Content-only = %d", + resultWithMedia.TokenCount, resultContentOnly.TokenCount) + } +} + +func TestSeahorseToProviderMessagesRebuildsContentFromParts(t *testing.T) { + msg := seahorse.Message{ + Role: "tool", + Content: "", + TokenCount: 50, + Parts: []seahorse.MessagePart{ + { + Type: "tool_result", + ToolCallID: "tc_999", + Text: "This is the actual tool output that should be in Content", + }, + }, + } + + result := seahorseToProviderMessages(&seahorse.AssembleResult{ + Messages: []seahorse.Message{msg}, + }) + + if len(result) != 1 { + t.Fatalf("expected 1 message, got %d", len(result)) + } + + if result[0].Content == "" { + t.Error("Content is empty - tool_result text was not rebuilt into Content") + } + if result[0].Content != "This is the actual tool output that should be in Content" { + t.Errorf("Content = %q, want tool output text from Parts", result[0].Content) + } +} + +func TestSeahorseAssembleSummaryNotInMessages(t *testing.T) { + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/test.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer engine.Close() + + ctx := context.Background() + mgr := &seahorseContextManager{engine: engine} + sessionKey := "test-no-dup-summary" + + // Get the store to directly create a summary + store := engine.GetRetrieval().Store() + conv, err := store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + // Ingest some messages first + for i := 0; i < 10; i++ { + _ = mgr.Ingest(ctx, &IngestRequest{ + SessionKey: sessionKey, + Message: protocoltypes.Message{Role: "user", Content: fmt.Sprintf("Message %d", i)}, + }) + } + + // Create a summary + input := seahorse.CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: seahorse.SummaryKindLeaf, + Depth: 0, + Content: "This is a test summary about the conversation", + TokenCount: 50, + } + summary, err := store.CreateSummary(ctx, input) + if err != nil { + t.Fatalf("CreateSummary: %v", err) + } + err = store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID) + if err != nil { + t.Fatalf("AppendContextSummary: %v", err) + } + + // Assemble + resp, err := mgr.Assemble(ctx, &AssembleRequest{ + SessionKey: sessionKey, + Budget: 50000, + MaxTokens: 4096, + }) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Count how many times the summary content appears + summaryContent := "This is a test summary" + countInHistory := 0 + for _, msg := range resp.History { + if strings.Contains(msg.Content, summaryContent) { + countInHistory++ + } + } + + if countInHistory > 0 { + t.Errorf("Summary content appears %d times in History - should be 0", countInHistory) + } + + // Summary should appear in Summary field + if !strings.Contains(resp.Summary, summaryContent) { + t.Error("Summary content should appear in response.Summary field") + } +} + +// TestSeahorseSteeringMessageIngested verifies that steering messages are ingested +// into seahorse SQLite, not just session JSONL. +func TestSeahorseSteeringMessageIngested(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + ContextManager: "seahorse", + }, + }, + } + + msgBus := bus.NewMessageBus() + mockProvider := &simpleMockProvider{response: "I received your message."} + al := NewAgentLoop(cfg, msgBus, mockProvider) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + + ctx := context.Background() + sessionKey := "test-steering-ingest" + + // First turn: establish conversation + _, err := al.runAgentLoop(ctx, defaultAgent, processOptions{ + SessionKey: sessionKey, + Channel: "cli", + ChatID: "direct", + UserMessage: "hello", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("first runAgentLoop failed: %v", err) + } + + // Inject a steering message + steerErr := al.InjectSteering(providers.Message{ + Role: "user", + Content: "steering message content", + }) + if steerErr != nil { + t.Fatalf("InjectSteering failed: %v", steerErr) + } + + // Second turn: should process steering message + _, err = al.runAgentLoop(ctx, defaultAgent, processOptions{ + SessionKey: sessionKey, + Channel: "cli", + ChatID: "direct", + UserMessage: "continue", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("second runAgentLoop failed: %v", err) + } + + // Get the seahorse engine from context manager + seahorseCM, ok := al.contextManager.(*seahorseContextManager) + if !ok { + t.Fatal("expected seahorseContextManager") + } + + // Check DB for steering message + store := seahorseCM.engine.GetRetrieval().Store() + conv, err := store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + stored, err := store.GetMessages(ctx, conv.ConversationID, 20, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + + t.Logf("DB has %d messages:", len(stored)) + for i, msg := range stored { + content := msg.Content + if len(content) > 40 { + content = content[:40] + "..." + } + t.Logf(" msg[%d]: role=%s content=%q", i, msg.Role, content) + } + + // Find steering message in stored messages + foundSteering := false + for _, msg := range stored { + if msg.Content == "steering message content" { + foundSteering = true + break + } + } + + if !foundSteering { + t.Error("STEERING MESSAGE NOT IN SEAHORSE DB: steering message should be ingested into SQLite") + } +} + +// TestSeahorseSummarizeSkipsCondensedWhenBelowThreshold verifies that when +// Summarize is triggered but tokens are below ContextWindow threshold, +// condensed compaction should NOT run. +func TestSeahorseSummarizeSkipsCondensedWhenBelowThreshold(t *testing.T) { + contextWindow := 1000 + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + ModelName: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + ContextManager: "seahorse", + ContextWindow: contextWindow, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &seahorseTestProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent == nil { + t.Fatal("expected default agent") + } + + ctx := context.Background() + sessionKey := "test-summarize-skip-condensed" + + seahorseCM, ok := al.contextManager.(*seahorseContextManager) + if !ok { + t.Fatal("expected seahorseContextManager") + } + store := seahorseCM.engine.GetRetrieval().Store() + + conv, err := store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + + // Insert leaf summaries directly (bypass leaf compaction requirement) + for i := 0; i < seahorse.CondensedMinFanout; i++ { + now := time.Now().UTC() + summary, sumErr := store.CreateSummary(ctx, seahorse.CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: seahorse.SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("leaf summary %d", i), + TokenCount: 50, + EarliestAt: &now, + LatestAt: &now, + }) + if sumErr != nil { + t.Fatalf("CreateSummary %d: %v", i, sumErr) + } + if appendErr := store.AppendContextSummary(ctx, conv.ConversationID, summary.SummaryID); appendErr != nil { + t.Fatalf("AppendContextSummary %d: %v", i, appendErr) + } + } + + // Add fresh messages (required for condensation candidates) + for i := 0; i < seahorse.FreshTailCount+1; i++ { + m, msgErr := store.AddMessage(ctx, conv.ConversationID, "user", "fresh", 5) + if msgErr != nil { + t.Fatalf("AddMessage %d: %v", i, msgErr) + } + if appendErr := store.AppendContextMessage(ctx, conv.ConversationID, m.ID); appendErr != nil { + t.Fatalf("AppendContextMessage %d: %v", i, appendErr) + } + } + + tokensBefore, err := store.GetContextTokenCount(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextTokenCount: %v", err) + } + threshold := int(float64(contextWindow) * seahorse.ContextThreshold) + t.Logf("Tokens before: %d, threshold: %d", tokensBefore, threshold) + + // Trigger Summarize + _, err = al.runAgentLoop(ctx, defaultAgent, processOptions{ + SessionKey: sessionKey, + Channel: "cli", + ChatID: "direct", + UserMessage: "trigger", + DefaultResponse: defaultResponse, + EnableSummary: true, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop: %v", err) + } + + time.Sleep(500 * time.Millisecond) + + summaries, err := store.GetSummariesByConversation(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetSummariesByConversation: %v", err) + } + + condensedCount := 0 + for _, sum := range summaries { + if sum.Kind == seahorse.SummaryKindCondensed { + condensedCount++ + } + } + + t.Logf("Condensed summaries: %d", condensedCount) + + if tokensBefore < threshold && condensedCount > 0 { + t.Errorf("BUG: condensed created when tokens (%d) < threshold (%d)", tokensBefore, threshold) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 808d12c07..fc37ff8a0 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1742,6 +1742,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er if err := al.contextManager.Compact(turnCtx, &CompactRequest{ SessionKey: ts.sessionKey, Reason: ContextCompressReasonProactive, + Budget: ts.agent.ContextWindow, }); err != nil { logger.WarnCF("agent", "Proactive compact failed", map[string]any{ "session_key": ts.sessionKey, @@ -1857,6 +1858,7 @@ turnLoop: if !ts.opts.NoHistory { ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm) ts.recordPersistedMessage(pm) + ts.ingestMessage(turnCtx, al, pm) } logger.InfoCF("agent", "Injected steering message into context", map[string]any{ @@ -2128,6 +2130,7 @@ turnLoop: if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{ SessionKey: ts.sessionKey, Reason: ContextCompressReasonRetry, + Budget: ts.agent.ContextWindow, }); compactErr != nil { logger.WarnCF("agent", "Context overflow compact failed", map[string]any{ "session_key": ts.sessionKey, @@ -2773,7 +2776,7 @@ turnLoop: } } if ts.opts.EnableSummary { - al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize}) + al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow}) } ts.setPhase(TurnPhaseCompleted) @@ -2849,6 +2852,7 @@ turnLoop: &CompactRequest{ SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, + Budget: ts.agent.ContextWindow, }, ) } diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 9447f1384..9ee7b15c9 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -604,6 +604,7 @@ type ephemeralSessionStoreIface interface { SetHistory(key string, history []providers.Message) TruncateHistory(key string, keepLast int) Save(key string) error + ListSessions() []string Close() error } @@ -663,8 +664,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) { e.history = e.history[len(e.history)-keepLast:] } -func (e *ephemeralSessionStore) Save(_ string) error { return nil } -func (e *ephemeralSessionStore) Close() error { return nil } +func (e *ephemeralSessionStore) Save(_ string) error { return nil } +func (e *ephemeralSessionStore) Close() error { return nil } +func (e *ephemeralSessionStore) ListSessions() []string { return nil } func (e *ephemeralSessionStore) truncateLocked() { if len(e.history) > maxEphemeralHistorySize { diff --git a/pkg/memory/jsonl.go b/pkg/memory/jsonl.go index afe374166..fc1ec8eb1 100644 --- a/pkg/memory/jsonl.go +++ b/pkg/memory/jsonl.go @@ -455,6 +455,33 @@ func (s *JSONLStore) rewriteJSONL( return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644) } +// ListSessions returns all known session keys by reading .meta.json files. +func (s *JSONLStore) ListSessions() []string { + entries, err := os.ReadDir(s.dir) + if err != nil { + return nil + } + var keys []string + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") { + continue + } + // Read the meta file to get the original key + data, err := os.ReadFile(filepath.Join(s.dir, entry.Name())) + if err != nil { + continue + } + var meta sessionMeta + if err := json.Unmarshal(data, &meta); err != nil { + continue + } + if meta.Key != "" { + keys = append(keys, meta.Key) + } + } + return keys +} + func (s *JSONLStore) Close() error { return nil } diff --git a/pkg/memory/store.go b/pkg/memory/store.go index b6e11707d..11526b27c 100644 --- a/pkg/memory/store.go +++ b/pkg/memory/store.go @@ -37,6 +37,9 @@ type Store interface { // data. Backends that do not accumulate dead data may return nil. Compact(ctx context.Context, sessionKey string) error + // ListSessions returns all known session keys. + ListSessions() []string + // Close releases any resources held by the store. Close() error } diff --git a/pkg/seahorse/.omc/state/last-tool-error.json b/pkg/seahorse/.omc/state/last-tool-error.json new file mode 100644 index 000000000..2e7273e23 --- /dev/null +++ b/pkg/seahorse/.omc/state/last-tool-error.json @@ -0,0 +1,7 @@ +{ + "tool_name": "Bash", + "tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}", + "error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1", + "timestamp": "2026-04-04T02:38:32.067Z", + "retry_count": 6 +} \ No newline at end of file diff --git a/pkg/seahorse/compact_until_under_test.go b/pkg/seahorse/compact_until_under_test.go new file mode 100644 index 000000000..2bb96c263 --- /dev/null +++ b/pkg/seahorse/compact_until_under_test.go @@ -0,0 +1,58 @@ +package seahorse + +import ( + "context" + "testing" +) + +// ============================================================================= +// CompactUntilUnder iteration cap +// ============================================================================= + +func TestCompactUntilUnderIterationCap(t *testing.T) { + // Setup: create a conversation with so many tokens that compaction + // will never reach the budget. The iteration cap prevents infinite loops. + // + // We use a mock CompleteFn that always returns the same content, + // and a budget of 0 which tokens can never reach. + // Without the cap, this would loop forever. + + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + s := &Store{db: db} + + conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap") + convID := conv.ConversationID + + // Add many messages to ensure there's plenty to compact + for i := 0; i < 40; i++ { + m, _ := s.AddMessage(context.Background(), convID, "user", + "this is a long message with lots of tokens to push context over budget", 100) + s.AppendContextMessage(context.Background(), convID, m.ID) + } + + // A completeFn that always succeeds but returns non-reducing content + mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + return "Summary that doesn't reduce tokens much.", nil + } + + ce, cancel := newTestCompactionEngineWithStore(s, mockComplete) + defer cancel() + + // Use budget=1 so tokens can never reach budget + // (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable) + // The function should stop after maxCompactIterations, not loop forever + ce.config = Config{} // ensure defaults + + result, err := ce.CompactUntilUnder(context.Background(), convID, 1) + if err != nil { + // Should not error — should stop gracefully + t.Fatalf("CompactUntilUnder with budget=0: %v", err) + } + + // The function should have completed within reasonable time + // If it exceeded the cap, it would still return (not hang) + _ = result +} diff --git a/pkg/seahorse/parts_roundtrip_test.go b/pkg/seahorse/parts_roundtrip_test.go new file mode 100644 index 000000000..02df8a9ea --- /dev/null +++ b/pkg/seahorse/parts_roundtrip_test.go @@ -0,0 +1,144 @@ +package seahorse + +import ( + "context" + "testing" + "time" +) + +// ============================================================================= +// Bug 1: formatMessagesForSummary ignores Parts +// - formatMessagesForSummary only reads m.Content, empty for Part-based messages +// - truncateSummary has same issue +// ============================================================================= + +func TestFormatMessagesForSummaryIncludesParts(t *testing.T) { + ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + messages := []Message{ + {ID: 1, Role: "user", Content: "hello world", CreatedAt: ts}, + { + ID: 2, + Role: "assistant", + Content: "", // empty — real content is in Parts + Parts: []MessagePart{ + {Type: "text", Text: "I will run a command"}, + {Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"}, + }, + CreatedAt: ts.Add(time.Minute), + }, + { + ID: 3, + Role: "tool", + Content: "", // empty — real content is in Parts + Parts: []MessagePart{ + {Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"}, + }, + CreatedAt: ts.Add(2 * time.Minute), + }, + } + + result := formatMessagesForSummary(messages) + + // Must contain the plain text message + if !contains(result, "hello world") { + t.Error("formatMessagesForSummary: missing plain text content") + } + + // Must contain tool_use info (not blank) + if !contains(result, "bash") || !contains(result, "ls -la") { + t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result) + } + + // Must contain tool_result info (not blank) + if !contains(result, "file1.txt") { + t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result) + } +} + +func TestTruncateSummaryIncludesParts(t *testing.T) { + messages := []Message{ + {ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()}, + { + ID: 2, + Role: "assistant", + Content: "", // empty + Parts: []MessagePart{ + {Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"}, + }, + CreatedAt: time.Now(), + }, + { + ID: 3, + Role: "tool", + Content: "", // empty + Parts: []MessagePart{ + {Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"}, + }, + CreatedAt: time.Now(), + }, + } + + result := truncateSummary(messages) + + // Must contain plain text + if !contains(result, "run the tests") { + t.Error("truncateSummary: missing plain text content") + } + + // Must contain tool info from Parts (not blank) + if !contains(result, "bash") || !contains(result, "go test") { + t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result) + } + + // Must contain tool_result from Parts + if !contains(result, "PASS") { + t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result) + } +} + +// ============================================================================= +// Bug 2: SearchMessages cannot find Part-based messages +// - FTS5 indexes empty content, LIKE queries empty content +// ============================================================================= + +func TestSearchMessagesFindsPartBasedMessages(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts") + convID := conv.ConversationID + + // Add a plain message (searchable) + s.AddMessage(ctx, convID, "user", "list the files please", 5) + + // Add a Part-based message (tool_use) — currently NOT searchable + parts := []MessagePart{ + {Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"}, + } + s.AddMessageWithParts(ctx, convID, "assistant", parts, 10) + + // Add a Part-based message (tool_result) — currently NOT searchable + resultParts := []MessagePart{ + {Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"}, + } + s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10) + + // Search for "grep" — should find the tool_use message + results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"}) + if err != nil { + t.Fatalf("SearchMessages: %v", err) + } + if len(results) == 0 { + t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search") + } + + // Search for "TODO fix" — should find the tool_result message + results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"}) + if err != nil { + t.Fatalf("SearchMessages: %v", err) + } + if len(results2) == 0 { + t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search") + } +} diff --git a/pkg/seahorse/schema.go b/pkg/seahorse/schema.go new file mode 100644 index 000000000..effa6d60d --- /dev/null +++ b/pkg/seahorse/schema.go @@ -0,0 +1,185 @@ +package seahorse + +import ( + "database/sql" + "fmt" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// SQL statements for FTS5 tables with trigram tokenizer. +const ( + sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5( + summary_id, + content, + tokenize="trigram" + )` + sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5( + message_id, + content, + tokenize="trigram" + )` + sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)` + sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")` + sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check` + sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check` +) + +// runSchema creates or upgrades the database schema. +// All schemas are idempotent (safe to run multiple times). +func runSchema(db *sql.DB) error { + // Check FTS5 support before creating tables + if err := checkFTS5Support(db); err != nil { + return fmt.Errorf("FTS5 check: %w", err) + } + + stmts := []string{ + `CREATE TABLE IF NOT EXISTS conversations ( + conversation_id INTEGER PRIMARY KEY AUTOINCREMENT, + session_key TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + updated_at TEXT NOT NULL DEFAULT (datetime('now')) + )`, + + `CREATE TABLE IF NOT EXISTS messages ( + message_id INTEGER PRIMARY KEY AUTOINCREMENT, + conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id), + role TEXT NOT NULL, + content TEXT NOT NULL DEFAULT '', + token_count INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + )`, + + `CREATE TABLE IF NOT EXISTS message_parts ( + part_id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id INTEGER NOT NULL REFERENCES messages(message_id), + type TEXT NOT NULL, + text TEXT, + name TEXT, + arguments TEXT, + tool_call_id TEXT, + media_uri TEXT, + mime_type TEXT, + ordinal INTEGER NOT NULL DEFAULT 0 + )`, + + `CREATE TABLE IF NOT EXISTS summaries ( + summary_id TEXT PRIMARY KEY, + conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id), + kind TEXT NOT NULL, + depth INTEGER NOT NULL DEFAULT 0, + content TEXT NOT NULL, + token_count INTEGER NOT NULL DEFAULT 0, + earliest_at TEXT, + latest_at TEXT, + descendant_count INTEGER NOT NULL DEFAULT 0, + descendant_token_count INTEGER NOT NULL DEFAULT 0, + source_message_token_count INTEGER NOT NULL DEFAULT 0, + model TEXT, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + )`, + + `CREATE TABLE IF NOT EXISTS summary_parents ( + summary_id TEXT NOT NULL, + parent_summary_id TEXT NOT NULL, + PRIMARY KEY (summary_id, parent_summary_id) + )`, + + `CREATE TABLE IF NOT EXISTS summary_messages ( + summary_id TEXT NOT NULL, + message_id INTEGER NOT NULL, + ordinal INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (summary_id, message_id) + )`, + + `CREATE TABLE IF NOT EXISTS context_items ( + conversation_id INTEGER NOT NULL, + ordinal INTEGER NOT NULL, + item_type TEXT NOT NULL, + summary_id TEXT, + message_id INTEGER, + token_count INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + PRIMARY KEY (conversation_id, ordinal) + )`, + + // FTS5 virtual table with trigram tokenizer for CJK support + sqlCreateSummariesFTS, + + // FTS5 virtual table for message search with trigram tokenizer + sqlCreateMessagesFTS, + + // Indexes for common query patterns + `CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`, + `CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`, + `CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`, + `CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`, + `CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`, + `CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`, + `CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`, + + // FTS5 triggers to keep summaries_fts in sync with summaries table + `CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN + INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content); + END`, + `CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN + INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content); + END`, + `CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN + INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content); + INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content); + END`, + + // FTS5 triggers to keep messages_fts in sync with messages table + `CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN + INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content); + END`, + `CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN + DELETE FROM messages_fts WHERE message_id = old.message_id; + END`, + `CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN + DELETE FROM messages_fts WHERE message_id = old.message_id; + INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content); + END`, + } + + for _, s := range stmts { + if _, err := db.Exec(s); err != nil { + return err + } + } + return nil +} + +// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled. +// This is required for full-text search with CJK (Chinese, Japanese, Korean) support. +func checkFTS5Support(db *sql.DB) error { + // Check if FTS5 is compiled in + var fts5Enabled int + err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled) + if err != nil { + // sqlite_compileoption_used might not exist in older SQLite + // Try a different approach: create a test FTS5 table + _, testErr := db.Exec(sqlCheckFTS5Available) + if testErr != nil { + return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr) + } + db.Exec(sqlDropFTS5Check) + } else if fts5Enabled == 0 { + return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)") + } + + // Check if trigram tokenizer is available by trying to create a test table + // Not all SQLite builds include the trigram tokenizer + _, err = db.Exec(sqlCheckTrigramAvailable) + if err != nil { + logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited", + map[string]any{"error": err.Error()}) + // Trigram is not strictly required, just better for CJK + // Don't return error, just log warning + } else { + db.Exec(sqlDropTrigramCheck) + } + + return nil +} diff --git a/pkg/seahorse/schema_test.go b/pkg/seahorse/schema_test.go new file mode 100644 index 000000000..17879f66c --- /dev/null +++ b/pkg/seahorse/schema_test.go @@ -0,0 +1,211 @@ +package seahorse + +import ( + "database/sql" + "testing" + + _ "modernc.org/sqlite" +) + +func openTestDB(t *testing.T) *sql.DB { + t.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("open test db: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func TestRunMigrations(t *testing.T) { + db := openTestDB(t) + + if err := runSchema(db); err != nil { + t.Fatalf("runSchema: %v", err) + } + + // Verify all tables exist + tables := []string{ + "conversations", + "messages", + "message_parts", + "summaries", + "summary_parents", + "summary_messages", + "context_items", + } + for _, tbl := range tables { + var name string + err := db.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl, + ).Scan(&name) + if err != nil { + t.Errorf("table %q not found: %v", tbl, err) + } + } + + // Verify FTS5 virtual table exists + var ftsName string + err := db.QueryRow( + "SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'", + ).Scan(&ftsName) + if err != nil { + t.Errorf("FTS5 table summaries_fts not found: %v", err) + } +} + +func TestRunMigrationsIdempotent(t *testing.T) { + db := openTestDB(t) + + // Run migrations twice — should succeed both times + if err := runSchema(db); err != nil { + t.Fatalf("first migration: %v", err) + } + if err := runSchema(db); err != nil { + t.Fatalf("second migration (idempotent): %v", err) + } + + // Verify we can still insert data after double migration + res, err := db.Exec( + "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))", + "test-session", + ) + if err != nil { + t.Fatalf("insert after double migration: %v", err) + } + id, _ := res.LastInsertId() + if id == 0 { + t.Error("expected non-zero conversation id") + } +} + +func TestMigrationConversationUnique(t *testing.T) { + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + + // Insert first + _, err := db.Exec( + "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))", + "unique-key", + ) + if err != nil { + t.Fatalf("first insert: %v", err) + } + + // Duplicate should fail + _, err = db.Exec( + "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))", + "unique-key", + ) + if err == nil { + t.Error("expected unique constraint violation for duplicate session_key") + } +} + +func TestMigrationSummaryFTSInsert(t *testing.T) { + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + + // Insert a conversation first + _, err := db.Exec( + "INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))", + "fts-test", + ) + if err != nil { + t.Fatalf("insert conversation: %v", err) + } + + // Insert a summary + _, err = db.Exec( + `INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at) + VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`) + if err != nil { + t.Fatalf("insert summary: %v", err) + } + + // FTS should find it — trigram tokenizer requires >= 3 chars + rows, err := db.Query( + "SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?", + "你好世", + ) + if err != nil { + t.Fatalf("FTS query: %v", err) + } + defer rows.Close() + + var found string + if rows.Next() { + if err := rows.Scan(&found); err != nil { + t.Fatalf("scan: %v", err) + } + } + if err := rows.Err(); err != nil { + t.Fatalf("rows.Err: %v", err) + } + if found != "sum_test1" { + t.Errorf("FTS: expected 'sum_test1', got %q", found) + } +} + +func TestMigrationSummaryParentsPK(t *testing.T) { + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + + // Insert two summaries + for _, id := range []string{"sum_a", "sum_b"} { + _, err := db.Exec( + `INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at) + VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id) + if err != nil { + t.Fatalf("insert summary %s: %v", id, err) + } + } + + // Link child to parent + _, err := db.Exec( + "INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')") + if err != nil { + t.Fatalf("link: %v", err) + } + + // Duplicate link should fail (composite PK) + _, err = db.Exec( + "INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')") + if err == nil { + t.Error("expected unique constraint violation for duplicate summary_parents link") + } +} + +func TestFTS5SQLConstants(t *testing.T) { + db := openTestDB(t) + + // Verify FTS5 check SQL executes without error + _, err := db.Exec(sqlCheckFTS5Available) + if err != nil { + t.Errorf("sqlCheckFTS5Available failed: %v", err) + } + + // Verify trigram check SQL executes without error + _, err = db.Exec(sqlCheckTrigramAvailable) + if err != nil { + t.Errorf("sqlCheckTrigramAvailable failed: %v", err) + } + + // Verify summaries_fts SQL executes without error + _, err = db.Exec(sqlCreateSummariesFTS) + if err != nil { + t.Errorf("sqlCreateSummariesFTS failed: %v", err) + } + + // Verify messages_fts SQL executes without error + _, err = db.Exec(sqlCreateMessagesFTS) + if err != nil { + t.Errorf("sqlCreateMessagesFTS failed: %v", err) + } +} diff --git a/pkg/seahorse/short_assembler.go b/pkg/seahorse/short_assembler.go new file mode 100644 index 000000000..f0fd323ba --- /dev/null +++ b/pkg/seahorse/short_assembler.go @@ -0,0 +1,261 @@ +package seahorse + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// escapeXML escapes special characters for safe inclusion in XML content. +func escapeXML(s string) string { + s = strings.ReplaceAll(s, "&", "&") + s = strings.ReplaceAll(s, "<", "<") + s = strings.ReplaceAll(s, ">", ">") + s = strings.ReplaceAll(s, "\"", """) + s = strings.ReplaceAll(s, "'", "'") + return s +} + +// resolvedItem is a context item resolved to its full content with token count. +type resolvedItem struct { + ordinal int + itemType string // "message" or "summary" + message *Message + summary *Summary + tokenCount int +} + +// Assemble builds budget-constrained context from summaries + messages. +// +// Algorithm: +// 1. Fetch context_items, resolve to full content +// 2. Split into evictable prefix + protected fresh tail +// 3. If evictable fits in remaining budget → include all +// 4. Else walk evictable from newest to oldest, keep while fits +func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) { + items, err := a.store.GetContextItems(ctx, convID) + if err != nil { + return nil, fmt.Errorf("get context items: %w", err) + } + if len(items) == 0 { + return &AssembleResult{}, nil + } + + // Resolve all items + resolved := make([]resolvedItem, len(items)) + for i, item := range items { + r, err := a.resolveItem(ctx, item) + if err != nil { + return nil, err + } + resolved[i] = r + } + + // Split into evictable prefix and protected fresh tail + tailStart := len(resolved) - FreshTailCount + if tailStart < 0 { + tailStart = 0 + } + evictable := resolved[:tailStart] + freshTail := resolved[tailStart:] + + // Calculate fresh tail tokens + freshTailTokens := 0 + for _, r := range freshTail { + freshTailTokens += r.tokenCount + } + + // Budget-aware selection of evictable items + remainingBudget := input.Budget - freshTailTokens + if remainingBudget < 0 { + // Fresh tail alone exceeds budget - we keep it anyway (design decision) + // Log for debugging retry/overflow issues + logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{ + "budget": input.Budget, + "fresh_tail_tokens": freshTailTokens, + "fresh_tail_count": len(freshTail), + "over_budget_by": freshTailTokens - input.Budget, + }) + remainingBudget = 0 + } + + var selected []resolvedItem + evictableTokens := 0 + for _, r := range evictable { + evictableTokens += r.tokenCount + } + + if evictableTokens <= remainingBudget { + // All evictable fit + selected = append(selected, evictable...) + } else { + // Walk from newest to oldest, keep while fits + var kept []resolvedItem + accum := 0 + for i := len(evictable) - 1; i >= 0; i-- { + if accum+evictable[i].tokenCount <= remainingBudget { + kept = append(kept, evictable[i]) + accum += evictable[i].tokenCount + } else { + break + } + } + // Reverse to restore chronological order + for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 { + kept[i], kept[j] = kept[j], kept[i] + } + selected = append(selected, kept...) + } + + // Combine: selected evictable + fresh tail + final := append(selected, freshTail...) + + // Build result + var messages []Message + var summaries []Summary + var sourceIDs []string + totalTokens := 0 + maxDepth := 0 + condensedCount := 0 + + for _, r := range final { + totalTokens += r.tokenCount + if r.itemType == "message" && r.message != nil { + messages = append(messages, *r.message) + sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID)) + } else if r.itemType == "summary" && r.summary != nil { + summaries = append(summaries, *r.summary) + if r.summary.Depth > maxDepth { + maxDepth = r.summary.Depth + } + if r.summary.Kind == SummaryKindCondensed { + condensedCount++ + } + } + } + + // Build depth-aware system prompt addition + systemPromptAddition := "" + if len(summaries) > 0 { + if maxDepth >= 2 || condensedCount >= 2 { + systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" + + "- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" + + "- When uncertain, use expand to recover original detail before making claims.\n" + + "- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand" + } else { + systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed." + } + } + + // Build Summary field: all XML summaries + system prompt addition + var summaryParts []string + for _, sum := range summaries { + if sum.Content == "" { + continue + } + // Load parent IDs for XML formatting + parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID) + if err != nil { + logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{ + "summary_id": sum.SummaryID, + "error": err.Error(), + }) + } + var parentIDs []string + for _, ps := range parentSummaries { + parentIDs = append(parentIDs, ps.SummaryID) + } + summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs)) + } + summary := strings.Join(summaryParts, "\n\n") + if systemPromptAddition != "" { + if summary != "" { + summary += "\n\n" + } + summary += systemPromptAddition + } + + return &AssembleResult{ + Messages: messages, + Summary: summary, + }, nil +} + +// resolveItem loads the full message or summary for a context item. +func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) { + if item.ItemType == "message" { + msg, err := a.store.GetMessageByID(ctx, item.MessageID) + if err != nil { + return resolvedItem{}, err + } + tokens := item.TokenCount + if tokens == 0 { + tokens = msg.TokenCount + } + return resolvedItem{ + ordinal: item.Ordinal, + itemType: "message", + message: msg, + tokenCount: tokens, + }, nil + } + + if item.ItemType == "summary" { + sum, err := a.store.GetSummary(ctx, item.SummaryID) + if err != nil { + return resolvedItem{}, err + } + tokens := item.TokenCount + if tokens == 0 { + tokens = sum.TokenCount + } + return resolvedItem{ + ordinal: item.Ordinal, + itemType: "summary", + summary: sum, + tokenCount: tokens, + }, nil + } + + return resolvedItem{ + ordinal: item.Ordinal, + itemType: item.ItemType, + tokenCount: item.TokenCount, + }, nil +} + +// FormatSummaryXML formats a summary as XML for LLM context. +// This is exported so context managers can format summaries consistently. +func FormatSummaryXML(s *Summary, parentIDs []string) string { + // Build time attributes if available + var attrs string + if s.EarliestAt != nil { + attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339)) + } + if s.LatestAt != nil { + attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339)) + } + + var parentsSection string + if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 { + parents := "\n" + for _, pid := range parentIDs { + parents += fmt.Sprintf(" \n", pid) + } + parents += " \n" + parentsSection = parents + } + return fmt.Sprintf( + "\n \n %s\n \n%s", + s.SummaryID, + string(s.Kind), + s.Depth, + s.DescendantCount, + attrs, + escapeXML(s.Content), + parentsSection, + ) +} diff --git a/pkg/seahorse/short_assembler_test.go b/pkg/seahorse/short_assembler_test.go new file mode 100644 index 000000000..88a05e64c --- /dev/null +++ b/pkg/seahorse/short_assembler_test.go @@ -0,0 +1,536 @@ +package seahorse + +import ( + "context" + "strings" + "testing" + "time" +) + +// --- Assembler Tests --- + +// helper: create a store with messages and summaries for assembly tests +func setupAssemblerStore(t *testing.T) (*Store, int64) { + t.Helper() + s := openTestStore(t) + ctx := context.Background() + + conv, err := s.GetOrCreateConversation(ctx, "test:assemble") + if err != nil { + t.Fatalf("create conversation: %v", err) + } + + return s, conv.ConversationID +} + +func TestAssemblerAssembleEmpty(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + if len(result.Messages) != 0 { + t.Errorf("Messages = %d, want 0", len(result.Messages)) + } + if result.Summary != "" { + t.Errorf("Summary = %q, want empty", result.Summary) + } +} + +func TestAssemblerAssembleMessagesOnly(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + // Create messages + msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5) + msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5) + + // Create context items + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5}, + {Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("Messages = %d, want 2", len(result.Messages)) + } + if result.Messages[0].Content != "hello" { + t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content) + } + if result.Messages[1].Content != "world" { + t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content) + } + // No summaries, so Summary should be empty + if result.Summary != "" { + t.Errorf("Summary = %q, want empty", result.Summary) + } +} + +func TestAssemblerAssembleWithSummary(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + // Create a summary + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "summary of early messages", + TokenCount: 50, + }) + + // Create recent messages + msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5) + msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5) + + // Context: summary + recent messages + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50}, + {Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5}, + {Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Messages = 2 raw messages (summaries are in Summary field, not Messages) + if len(result.Messages) != 2 { + t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages)) + } + // Summary should contain XML with summary content + if result.Summary == "" { + t.Error("Summary should not be empty when summary exists") + } + if !strings.Contains(result.Summary, summary.Content) { + t.Errorf("Summary should contain summary content %q", summary.Content) + } + if !strings.Contains(result.Summary, "`, + TokenCount: 20, + }) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Summary field should contain XML with escaped special characters + if result.Summary == "" { + t.Fatal("Summary should not be empty") + } + + // Check that special characters are escaped + if strings.Contains(result.Summary, "") { + t.Errorf("BUG: unescaped < in summary content: %q", result.Summary) + } + if strings.Contains(result.Summary, `"hello"`) { + t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary) + } + // & should be escaped as & + if strings.Contains(result.Summary, " & ") { + t.Errorf("BUG: unescaped & in summary content: %q", result.Summary) + } +} + +func TestAssemblerSummaryXMLWithParents(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + // Create a leaf and a condensed summary (condensed has parent) + leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf content", + TokenCount: 20, + }) + condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindCondensed, + Depth: 1, + Content: "condensed content", + TokenCount: 15, + ParentIDs: []string{leaf.SummaryID}, + }) + + msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15}, + {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Summary field should contain XML with parent information + if result.Summary == "" { + t.Fatal("Summary should not be empty") + } + xmlContent := result.Summary + + // Should contain section with parent ID + if !contains(xmlContent, "") { + t.Errorf("condensed summary XML missing section: %q", xmlContent) + } + if !contains(xmlContent, leaf.SummaryID) { + t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent) + } + + // Should contain kind="condensed" + if !contains(xmlContent, `kind="condensed"`) { + t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent) + } +} + +func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + // Create a leaf summary with specific descendant count + leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf content", + TokenCount: 20, + DescendantCount: 8, + DescendantTokenCount: 1200, + }) + + msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20}, + {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + if result.Summary == "" { + t.Fatal("Summary should not be empty") + } + xmlContent := result.Summary + + // Should contain descendant_count="8" + if !contains(xmlContent, `descendant_count="8"`) { + t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent) + } +} + +func TestAssemblerLeafSummaryNoParents(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + // Leaf summary has no parents + leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf content", + TokenCount: 20, + }) + + msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20}, + {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + if result.Summary == "" { + t.Fatal("Summary should not be empty") + } + xmlContent := result.Summary + + // Leaf summary should NOT have section + if contains(xmlContent, "") { + t.Errorf("leaf summary XML should not have section: %q", xmlContent) + } +} + +func TestAssemblerDepthAwarePrompt(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + // Create a condensed summary (depth >= 2) to trigger full guidance + now := time.Now().UTC() + leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf summary", + TokenCount: 20, + EarliestAt: &now, + LatestAt: &now, + }) + condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindCondensed, + Depth: 2, + Content: "condensed summary", + TokenCount: 15, + ParentIDs: []string{leaf.SummaryID}, + DescendantCount: 1, + DescendantTokenCount: 20, + }) + + msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15}, + {Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Should have a depth-aware prompt in Summary field + if result.Summary == "" { + t.Error("expected non-empty Summary when depth >= 2") + } + // SystemPromptAddition is embedded in Summary field + if !strings.Contains(result.Summary, "multi-level summarization") { + t.Error("Summary should contain system prompt addition about multi-level summarization") + } +} + +func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) { + // Spec: condensed summaries use not parentId + now := time.Now().UTC() + s := Summary{ + SummaryID: "sum_condensed1", + Kind: SummaryKindCondensed, + Depth: 1, + Content: "condensed content", + TokenCount: 50, + DescendantCount: 2, + EarliestAt: &now, + LatestAt: &now, + } + parentIDs := []string{"sum_leaf1", "sum_leaf2"} + + xml := FormatSummaryXML(&s, parentIDs) + + // Must use per spec + if !contains(xml, ``) { + t.Errorf("expected , got: %s", xml) + } + if !contains(xml, ``) { + t.Errorf("expected , got: %s", xml) + } + // Must NOT use old tag + if contains(xml, "") { + t.Errorf("should not use tag, got: %s", xml) + } +} + +func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) { + // Spec: summary XML includes earliest_at and latest_at attributes + earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC) + latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC) + s := Summary{ + SummaryID: "sum_leaf1", + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf content", + TokenCount: 30, + DescendantCount: 0, + EarliestAt: &earliest, + LatestAt: &latest, + } + + xml := FormatSummaryXML(&s, nil) + + if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) { + t.Errorf("missing earliest_at attribute, got: %s", xml) + } + if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) { + t.Errorf("missing latest_at attribute, got: %s", xml) + } +} + +func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) { + // When EarliestAt/LatestAt are nil, attributes should be omitted + s := Summary{ + SummaryID: "sum_leaf1", + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf content", + TokenCount: 30, + DescendantCount: 0, + } + + xml := FormatSummaryXML(&s, nil) + + if contains(xml, "earliest_at=") { + t.Errorf("should not have earliest_at when nil, got: %s", xml) + } + if contains(xml, "latest_at=") { + t.Errorf("should not have latest_at when nil, got: %s", xml) + } +} diff --git a/pkg/seahorse/short_bench_test.go b/pkg/seahorse/short_bench_test.go new file mode 100644 index 000000000..b7e47bcff --- /dev/null +++ b/pkg/seahorse/short_bench_test.go @@ -0,0 +1,336 @@ +package seahorse + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + _ "modernc.org/sqlite" +) + +// newBenchStore creates a test store for benchmarks. +func newBenchStore(b *testing.B) (*Store, func()) { + b.Helper() + db, err := sql.Open("sqlite", ":memory:") + if err != nil { + b.Fatalf("open test db: %v", err) + } + if err := runSchema(db); err != nil { + db.Close() + b.Fatalf("migration: %v", err) + } + return &Store{db: db}, func() { db.Close() } +} + +// --- Ingest benchmarks --- + +func BenchmarkIngest_SingleMessage(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest") + convID := conv.ConversationID + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := s.AddMessage(ctx, convID, "user", "Test message content", 15) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkIngest_BatchMessages(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i)) + convID := conv.ConversationID + + for j := 0; j < 10; j++ { + added, err := s.AddMessage(ctx, convID, "user", + fmt.Sprintf("Message %d in batch", j), 10) + if err != nil { + b.Fatal(err) + } + s.AppendContextMessage(ctx, convID, added.ID) + } + } +} + +// --- Assemble benchmarks --- + +func BenchmarkAssemble_MessagesOnly(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs") + convID := conv.ConversationID + + // Add 100 messages + for i := 0; i < 100; i++ { + m, _ := s.AddMessage(ctx, convID, "user", + fmt.Sprintf("Message content %d with some text", i), 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + a := &Assembler{store: s} + input := AssembleInput{Budget: 50000} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := a.Assemble(ctx, convID, input) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkAssemble_WithSummaries(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums") + convID := conv.ConversationID + + now := time.Now().UTC() + + // Add 10 leaf summaries + for i := 0; i < 10; i++ { + sum, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("Leaf summary %d", i), + TokenCount: 500, + EarliestAt: &now, + LatestAt: &now, + }) + s.AppendContextSummary(ctx, convID, sum.SummaryID) + } + + // Add 20 fresh messages + for i := 0; i < 20; i++ { + m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + a := &Assembler{store: s} + input := AssembleInput{Budget: 10000} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := a.Assemble(ctx, convID, input) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkAssemble_BudgetEviction(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict") + convID := conv.ConversationID + + now := time.Now().UTC() + + // Add 50 leaf summaries (more than budget can hold) + for i := 0; i < 50; i++ { + sum, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("Summary %d", i), + TokenCount: 300, + EarliestAt: &now, + LatestAt: &now, + }) + s.AppendContextSummary(ctx, convID, sum.SummaryID) + } + + // Add fresh tail + for i := 0; i < FreshTailCount; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + a := &Assembler{store: s} + input := AssembleInput{Budget: 5000} // Force eviction + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := a.Assemble(ctx, convID, input) + if err != nil { + b.Fatal(err) + } + } +} + +// --- Search (FTS5) benchmarks --- + +// benchSeedSummaries adds n summaries to a conversation for search benchmarks. +func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) { + b.Helper() + now := time.Now().UTC() + for i := 0; i < n; i++ { + sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf(contentTpl, i), + TokenCount: 200, + EarliestAt: &now, + LatestAt: &now, + }) + if err != nil { + b.Fatalf("create summary: %v", err) + } + s.AppendContextSummary(context.Background(), convID, sum.SummaryID) + } +} + +func BenchmarkSearchSummaries_FTS5(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts") + convID := conv.ConversationID + + benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := s.SearchSummaries(ctx, SearchInput{ + Pattern: "database", + Mode: "full_text", + ConversationID: convID, + }) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSearchSummaries_Like(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like") + convID := conv.ConversationID + + benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := s.SearchSummaries(ctx, SearchInput{ + Pattern: "config", + Mode: "like", + ConversationID: convID, + }) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSearchMessages_FTS5(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts") + convID := conv.ConversationID + + // Add 500 messages + for i := 0; i < 500; i++ { + m, _ := s.AddMessage(ctx, convID, "user", + fmt.Sprintf("User message about API and database integration %d", i), 20) + s.AppendContextMessage(ctx, convID, m.ID) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := s.SearchMessages(ctx, SearchInput{ + Pattern: "API database", + Mode: "full_text", + ConversationID: convID, + }) + if err != nil { + b.Fatal(err) + } + } +} + +// --- Bootstrap benchmarks --- + +func BenchmarkBootstrap_Empty(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i)) + convID := conv.ConversationID + _ = convID // Bootstrap with empty history + } +} + +func BenchmarkBootstrap_100Messages(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + + // Prepare 100 messages + msgs := make([]Message, 100) + for i := 0; i < 100; i++ { + msgs[i] = Message{ + Role: "user", + Content: fmt.Sprintf("Bootstrap message %d", i), + TokenCount: 15, + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i)) + convID := conv.ConversationID + + for _, m := range msgs { + added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount) + s.AppendContextMessage(ctx, convID, added.ID) + } + } +} + +func BenchmarkBootstrap_500Messages(b *testing.B) { + s, cleanup := newBenchStore(b) + defer cleanup() + ctx := context.Background() + + msgs := make([]Message, 500) + for i := 0; i < 500; i++ { + msgs[i] = Message{ + Role: "user", + Content: fmt.Sprintf("Bootstrap message %d", i), + TokenCount: 15, + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i)) + convID := conv.ConversationID + + for _, m := range msgs { + added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount) + s.AppendContextMessage(ctx, convID, added.ID) + } + } +} diff --git a/pkg/seahorse/short_compaction.go b/pkg/seahorse/short_compaction.go new file mode 100644 index 000000000..30e290926 --- /dev/null +++ b/pkg/seahorse/short_compaction.go @@ -0,0 +1,898 @@ +package seahorse + +import ( + "context" + "fmt" + "sort" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tokenizer" +) + +// CompactInput controls compaction behavior. +type CompactInput struct { + Budget *int // Token budget override + Force bool // Force compaction even if below threshold +} + +// CompactResult describes what was compacted. +type CompactResult struct { + SummariesCreated []string `json:"summariesCreated"` + TokensSaved int `json:"tokensSaved"` + LeafSummaries int `json:"leafSummaries"` + CondensedSummaries int `json:"condensedSummaries"` +} + +// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow. +func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) { + tokens, err := e.store.GetContextTokenCount(ctx, convID) + if err != nil { + return false, fmt.Errorf("get token count: %w", err) + } + threshold := int(float64(contextWindow) * ContextThreshold) + return tokens >= threshold, nil +} + +// Close cancels the shutdown context, stopping async goroutines. +func (e *CompactionEngine) Close() { + if e.shutdownCancel != nil { + e.shutdownCancel() + } +} + +// Compact runs leaf compaction (sync) and optionally condensed compaction. +func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) { + result := &CompactResult{} + + // Phase 1: leaf compaction (synchronous, every turn) + summaryID, err := e.compactLeaf(ctx, convID) + if err != nil { + return nil, fmt.Errorf("compact leaf: %w", err) + } + if summaryID != nil { + result.SummariesCreated = append(result.SummariesCreated, *summaryID) + result.LeafSummaries++ + logger.InfoCF("seahorse", "compact: leaf", map[string]any{ + "conv_id": convID, + "summary_id": *summaryID, + }) + } + + // Phase 2: condensed compaction if over threshold + tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID) + var budget int + if input.Budget != nil { + budget = *input.Budget + if budget == 0 { + logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{ + "conv_id": convID, + }) + } + } else { + budget = int(float64(tokensBefore) * ContextThreshold) + } + + if input.Force || (tokensBefore > budget && budget > 0) { + // Launch async condensed compaction with dedup + if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded { + go func() { + defer e.condensing.Delete(convID) + e.runCondensedLoop(e.shutdownCtx, convID) + }() + } + } + + tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID) + if tokensAfter < tokensBefore { + result.TokensSaved = tokensBefore - tokensAfter + } + + return result, nil +} + +// CompactUntilUnder aggressively compacts until context is under budget. +func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) { + result := &CompactResult{} + prevTokens := 0 + logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget}) + + for iter := 0; iter < MaxCompactIterations; iter++ { + tokens, err := e.store.GetContextTokenCount(ctx, convID) + if err != nil { + return result, fmt.Errorf("get tokens: %w", err) + } + if tokens <= budget { + logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{ + "conv_id": convID, + "budget": budget, + "tokens": tokens, + "leaf": result.LeafSummaries, + "condensed": result.CondensedSummaries, + }) + return result, nil + } + + // Try leaf first + summaryID, err := e.compactLeaf(ctx, convID, true) + if err != nil { + return result, err + } + if summaryID != nil { + result.SummariesCreated = append(result.SummariesCreated, *summaryID) + result.LeafSummaries++ + logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{ + "conv_id": convID, + "summary_id": *summaryID, + }) + continue + } + + // Try condensed with forced fanout + condensedID, err := e.compactCondensed(ctx, convID) + if err != nil { + return result, err + } + if condensedID != nil { + result.SummariesCreated = append(result.SummariesCreated, *condensedID) + result.CondensedSummaries++ + logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{ + "conv_id": convID, + "summary_id": *condensedID, + }) + continue + } + + // No progress + newTokens, _ := e.store.GetContextTokenCount(ctx, convID) + if newTokens >= prevTokens { + logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{ + "conv_id": convID, + "tokens": newTokens, + }) + return result, nil + } + prevTokens = newTokens + } + + // Safety cap exceeded — see MaxCompactIterations doc for rationale. + logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{ + "conv_id": convID, + "budget": budget, + "iterations": MaxCompactIterations, + "tokens": prevTokens, + }) + return result, nil +} + +// compactLeaf compresses the oldest contiguous message chunk into a leaf summary. +// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder). +func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) { + items, err := e.store.GetContextItems(ctx, convID) + if err != nil { + return nil, err + } + + // Find oldest contiguous message chunk outside fresh tail + msgCount := 0 + msgTokens := 0 + for _, item := range items { + if item.ItemType == "message" { + msgCount++ + msgTokens += item.TokenCount + } + } + + // Trigger if either message count or token threshold is met + if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens { + return nil, nil + } + + // Calculate fresh tail boundary (bypass when forced) + useForce := len(force) > 0 && force[0] + tailStartIdx := len(items) - FreshTailCount + if useForce { + tailStartIdx = len(items) // allow compacting everything + } + if tailStartIdx < 0 { + tailStartIdx = 0 + } + + // Find oldest contiguous message chunk, accumulating up to LeafChunkTokens + var chunk []ContextItem + chunkStart := -1 + chunkEnd := -1 + accumTokens := 0 + for i := 0; i < tailStartIdx; i++ { + if items[i].ItemType == "message" { + if chunkStart == -1 { + chunkStart = i + } + chunkEnd = i + accumTokens += items[i].TokenCount + // Stop accumulating once we reach the token budget + if accumTokens >= LeafChunkTokens { + break + } + } else { + // Non-message breaks the chunk + if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout { + break + } + chunkStart = -1 + chunkEnd = -1 + accumTokens = 0 + } + } + + if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout { + return nil, nil + } + + chunk = items[chunkStart : chunkEnd+1] + + // Collect messages for the chunk + var messages []Message + for _, item := range chunk { + msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID) + if innerErr != nil { + return nil, innerErr + } + messages = append(messages, *msg) + } + + // Get prior summaries for context + priorSummary := "" + priorCount := 0 + for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- { + if items[i].ItemType == "summary" { + sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID) + if innerErr2 == nil { + priorSummary = sum.Content + "\n" + priorSummary + priorCount++ + } + } + } + + // Generate summary + content, err := e.generateLeafSummary(ctx, messages, priorSummary) + if err != nil { + return nil, err + } + + // Create summary in store + tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content}) + + var earliestAt, latestAt *time.Time + if len(messages) > 0 { + earliestAt = &messages[0].CreatedAt + latestAt = &messages[len(messages)-1].CreatedAt + } + + summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: content, + TokenCount: tokenCount, + EarliestAt: earliestAt, + LatestAt: latestAt, + SourceMessageTokens: sumMessageTokens(messages), + }) + if err != nil { + return nil, err + } + + // Link to source messages + msgIDs := make([]int64, len(messages)) + for i, m := range messages { + msgIDs[i] = m.ID + } + if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil { + return nil, err + } + + // Replace context range with summary + if err := e.store.ReplaceContextRangeWithSummary( + ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID, + ); err != nil { + return nil, err + } + + return &summary.SummaryID, nil +} + +// compactCondensed compresses multiple summaries into one higher-level summary. +func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) { + // Try ordinal-aware selection first (respects consecutive ordering) + var candidates []Summary + + depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0) + if err != nil { + return nil, err + } + for _, depth := range depths { + var chunkAtDepth []Summary + var err2 error + chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth) + if err2 != nil { + continue + } + if len(chunkAtDepth) > 0 { + candidates = chunkAtDepth + break + } + } + + // Fallback to depth-grouping selection + if len(candidates) == 0 { + candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false) + if err != nil { + return nil, err + } + } + if len(candidates) == 0 { + return nil, nil + } + + // Generate condensed summary + content, err := e.generateCondensedSummary(ctx, candidates) + if err != nil { + return nil, err + } + + // Merge metadata + maxDepth := 0 + descendantCount := 0 + descendantTokenCount := 0 + sourceMessageTokens := 0 + var earliestAt, latestAt *time.Time + + parentIDs := make([]string, len(candidates)) + for i, c := range candidates { + parentIDs[i] = c.SummaryID + if c.Depth > maxDepth { + maxDepth = c.Depth + } + descendantCount += c.DescendantCount + 1 + descendantTokenCount += c.TokenCount + c.DescendantTokenCount + sourceMessageTokens += c.SourceMessageTokenCount + if c.EarliestAt != nil { + if earliestAt == nil || c.EarliestAt.Before(*earliestAt) { + earliestAt = c.EarliestAt + } + } + if c.LatestAt != nil { + if latestAt == nil || c.LatestAt.After(*latestAt) { + latestAt = c.LatestAt + } + } + } + + tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content}) + + summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindCondensed, + Depth: maxDepth + 1, + Content: content, + TokenCount: tokenCount, + EarliestAt: earliestAt, + LatestAt: latestAt, + DescendantCount: descendantCount, + DescendantTokenCount: descendantTokenCount, + SourceMessageTokens: sourceMessageTokens, + ParentIDs: parentIDs, + }) + if err != nil { + return nil, err + } + + // Find the ordinal range for the candidate summaries in context + items, err := e.store.GetContextItems(ctx, convID) + if err != nil { + return nil, err + } + + candidateSet := make(map[string]bool) + for _, c := range candidates { + candidateSet[c.SummaryID] = true + } + + startOrd := -1 + endOrd := -1 + hasNonCandidate := false + for _, item := range items { + if item.ItemType == "summary" && candidateSet[item.SummaryID] { + if startOrd == -1 { + startOrd, endOrd = item.Ordinal, item.Ordinal + } else { + // Check for non-candidate items between endOrd and current ordinal + for _, it := range items { + if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal { + if it.ItemType != "summary" || !candidateSet[it.SummaryID] { + hasNonCandidate = true + break + } + } + } + if hasNonCandidate { + break + } + if item.Ordinal < startOrd { + startOrd = item.Ordinal + } + if item.Ordinal > endOrd { + endOrd = item.Ordinal + } + } + } + } + + if startOrd == -1 || endOrd == -1 { + return nil, nil + } + + // Collect candidate summary IDs + candidateIDs := make([]string, 0, len(candidates)) + for _, c := range candidates { + candidateIDs = append(candidateIDs, c.SummaryID) + } + + if hasNonCandidate { + // Use safe per-item deletion to avoid deleting non-candidate items + if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil { + return nil, err + } + } else { + // Candidates are consecutive, use efficient range deletion + if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil { + return nil, err + } + } + + return &summary.SummaryID, nil +} + +// selectShallowestCondensationCandidate finds the shallowest consecutive summary group. +func (e *CompactionEngine) selectShallowestCondensationCandidate( + ctx context.Context, convID int64, forced bool, +) ([]Summary, error) { + items, err := e.store.GetContextItems(ctx, convID) + if err != nil { + return nil, err + } + + // Group by depth, find consecutive runs + tailStartIdx := len(items) - FreshTailCount + if tailStartIdx < 0 { + tailStartIdx = 0 + } + + minFanout := CondensedMinFanout + if forced { + minFanout = CondensedMinFanoutHard + } + + // Track depth groups + depthGroups := make(map[int][]ContextItem) + for i := 0; i < tailStartIdx; i++ { + item := items[i] + if item.ItemType != "summary" { + continue + } + sum, err := e.store.GetSummary(ctx, item.SummaryID) + if err != nil { + continue + } + depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item) + } + + // Find shallowest depth with enough candidates + // Collect all depths and sort to handle non-consecutive depths + var depths []int + for depth := range depthGroups { + depths = append(depths, depth) + } + sort.Ints(depths) + + for _, depth := range depths { + group := depthGroups[depth] + if len(group) >= minFanout { + // Load summaries + var result []Summary + for _, item := range group[:minFanout] { + sum, err := e.store.GetSummary(ctx, item.SummaryID) + if err != nil { + continue + } + result = append(result, *sum) + } + return result, nil + } + } + + return nil, nil +} + +// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive +// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or +// token overflow. Returns contiguous chunk of summaries. +func (e *CompactionEngine) selectOldestChunkAtDepth( + ctx context.Context, convID int64, targetDepth int, +) ([]Summary, error) { + items, err := e.store.GetContextItems(ctx, convID) + if err != nil { + return nil, err + } + + tailStartIdx := len(items) - FreshTailCount + if tailStartIdx < 0 { + tailStartIdx = 0 + } + + var chunk []Summary + accumTokens := 0 + + for i := 0; i < tailStartIdx; i++ { + item := items[i] + if item.ItemType != "summary" { + // Non-summary breaks the chunk + break + } + sum, err := e.store.GetSummary(ctx, item.SummaryID) + if err != nil { + break + } + if sum.Depth != targetDepth { + // Different depth breaks the chunk + break + } + if accumTokens+sum.TokenCount > LeafChunkTokens { + // Token overflow stops collection + break + } + chunk = append(chunk, *sum) + accumTokens += sum.TokenCount + } + + // Min tokens check: spec line 808 + // chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000 + minTokens := CondensedTargetTokens // 2000 + if accumTokens < minTokens { + return nil, nil + } + + return chunk, nil +} + +// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation. +// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation. +func (e *CompactionEngine) generateLeafSummary( + ctx context.Context, + messages []Message, + previousSummary string, +) (string, error) { + if e.complete == nil { + return truncateSummary(messages), nil + } + + sourceText := formatMessagesForSummary(messages) + inputTokens := sumMessageTokens(messages) + targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35)) + + // Level 1: normal prompt + prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens) + content, err := e.complete(ctx, prompt, CompleteOptions{ + MaxTokens: LeafTargetTokens * 2, + Temperature: 0.3, + }) + if err != nil { + return "", err + } + if content == "" { + // Retry with temperature=0 + content, err = e.complete(ctx, prompt, CompleteOptions{ + MaxTokens: LeafTargetTokens * 2, + Temperature: 0, + }) + if err != nil { + return "", err + } + } + + // Check if level 1 succeeded + if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens { + return content, nil + } + + // Level 2: aggressive prompt + aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20)) + aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget) + content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{ + MaxTokens: aggressiveTarget * 2, + Temperature: 0.3, + }) + if err != nil { + return "", err + } + if content == "" { + // Retry with temperature=0 + content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{ + MaxTokens: aggressiveTarget * 2, + Temperature: 0, + }) + if err != nil { + return "", err + } + } + if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens { + return content, nil + } + + // Level 3: deterministic truncation + return truncateSummary(messages), nil +} + +// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation. +func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) { + if e.complete == nil { + return truncateCondensedSummaries(summaries), nil + } + + sourceText := formatSummariesForCondensation(summaries) + inputTokens := sumSummaryTokens(summaries) + targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35)) + + // Level 1: normal prompt + prompt := buildCondensedSummaryPrompt(sourceText, targetTokens) + content, err := e.complete(ctx, prompt, CompleteOptions{ + MaxTokens: CondensedTargetTokens * 2, + Temperature: 0.3, + }) + if err != nil { + return "", err + } + if content == "" { + content, err = e.complete(ctx, prompt, CompleteOptions{ + MaxTokens: CondensedTargetTokens * 2, + Temperature: 0, + }) + if err != nil { + return "", err + } + } + if content != "" { + return content, nil + } + + // Level 2: aggressive prompt + aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20)) + aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget) + content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{ + MaxTokens: aggressiveTarget * 2, + Temperature: 0.3, + }) + if err != nil { + return "", err + } + if content != "" { + return content, nil + } + + // Level 3: deterministic fallback + return truncateCondensedSummaries(summaries), nil +} + +// runCondensedLoop runs condensed compaction in a loop until: +// a) context tokens <= threshold (success), OR +// b) No candidate found (nothing to condense), OR +// c) tokensAfter >= tokensBefore (no progress this iteration), OR +// d) tokensAfter >= previousTokens (no improvement over last iteration) +func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) { + var prevTokens int + for { + select { + case <-ctx.Done(): + return + default: + } + + tokensBefore, err := e.store.GetContextTokenCount(ctx, convID) + if err != nil { + logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()}) + return + } + + condensedID, err := e.compactCondensed(ctx, convID) + if err != nil { + logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()}) + return + } + if condensedID == nil { + // No candidate found + logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID}) + return + } + + tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID) + + if tokensAfter >= tokensBefore { + // No progress this iteration + logger.DebugCF( + "seahorse", + "condensed: no progress", + map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter}, + ) + return + } + if tokensAfter >= prevTokens && prevTokens > 0 { + // No improvement over last iteration + logger.DebugCF( + "seahorse", + "condensed: no improvement", + map[string]any{"conv_id": convID, "tokens": tokensAfter}, + ) + return + } + + prevTokens = tokensAfter + } +} + +// --- Helper functions --- + +func formatMessagesForSummary(messages []Message) string { + var result string + for _, m := range messages { + ts := m.CreatedAt.Format("2006-01-02 15:04 MST") + content := m.Content + if content == "" && len(m.Parts) > 0 { + content = partsToReadableContent(m.Parts) + } + result += fmt.Sprintf("[%s]\n%s\n\n", ts, content) + } + return result +} + +func formatSummariesForCondensation(summaries []Summary) string { + var result string + for _, s := range summaries { + earliest := "" + if s.EarliestAt != nil { + earliest = s.EarliestAt.Format("2006-01-02") + } + latest := "" + if s.LatestAt != nil { + latest = s.LatestAt.Format("2006-01-02") + } + result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content) + } + return result +} + +func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string { + prev := "(none)" + if previousSummary != "" { + prev = previousSummary + } + return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns. +Treat this as incremental memory compaction input, not a full-conversation summary. + +Normal summary policy: +- Preserve key decisions, rationale, constraints, and active tasks. +- Keep essential technical details needed to continue work safely. +- Remove obvious repetition and conversational filler. + +Output requirements: +- Plain text only. +- No preamble, headings, or markdown formatting. +- Track file operations (created, modified, deleted, renamed) with file paths and current status. +- If no file operations appear, include exactly: "Files: none". +- End with exactly: "Expand for details about: ". +- Target length: about %d tokens or less. + + +%s + + + +%s +`, targetTokens, prev, sourceText) +} + +func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string { + return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary. +Preserve all important decisions, constraints, and outcomes. +Merge overlapping topics. Keep technical details intact. + +Output requirements: +- Plain text only. +- No preamble, headings, or markdown formatting. +- End with exactly: "Expand for details about: ". +- Target length: about %d tokens or less. + + +%s +`, targetTokens, sourceText) +} + +func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string { + prev := "(none)" + if previousSummary != "" { + prev = previousSummary + } + return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns. +Aggressive summary policy: +- Keep only durable facts and current task state. +- Remove examples, repetition, and low-value narrative details. +- Preserve explicit TODOs, blockers, decisions, and constraints. + +Output requirements: +- Plain text only. +- No preamble, headings, or markdown formatting. +- Track file operations (created, modified, deleted, renamed) with file paths and current status. +- If no file operations appear, include exactly: "Files: none". +- End with exactly: "Expand for details about: ". +- Target length: about %d tokens or less. + + +%s + + + +%s +`, targetTokens, prev, sourceText) +} + +func truncateSummary(messages []Message) string { + content := "" + for _, m := range messages { + c := m.Content + if c == "" && len(m.Parts) > 0 { + c = partsToReadableContent(m.Parts) + } + content += c + "\n" + } + if len(content) > 2048 { + content = content[:2048] + } + content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages)) + return content +} + +func truncateCondensedSummaries(summaries []Summary) string { + content := "" + for _, s := range summaries { + content += s.Content + "\n" + } + if len(content) > 2048 { + content = content[:2048] + } + content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries)) + return content +} + +func sumMessageTokens(messages []Message) int { + total := 0 + for _, m := range messages { + total += m.TokenCount + } + return total +} + +func sumSummaryTokens(summaries []Summary) int { + total := 0 + for _, s := range summaries { + total += s.TokenCount + } + return total +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/seahorse/short_compaction_test.go b/pkg/seahorse/short_compaction_test.go new file mode 100644 index 000000000..ea7dcb52d --- /dev/null +++ b/pkg/seahorse/short_compaction_test.go @@ -0,0 +1,974 @@ +package seahorse + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +// --- Test Helpers --- + +// waitForCondensed blocks until the async condensed goroutine for convID finishes. +// Returns false if timeout is reached. +func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if _, exists := ce.condensing.Load(convID); !exists { + return true + } + time.Sleep(50 * time.Millisecond) + } + return false +} + +// --- Compaction Tests --- + +func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) { + t.Helper() + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + s := &Store{db: db} + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:compact") + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + ce := &CompactionEngine{ + store: s, + config: Config{}, + complete: mockCompleteFn, + shutdownCtx: shutdownCtx, + shutdownCancel: shutdownCancel, + } + convID := conv.ConversationID + // Ensure async goroutines are stopped before database is closed. + // Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close(). + t.Cleanup(func() { + shutdownCancel() + // Wait for async condensed goroutine to finish (poll condensing map) + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if _, exists := ce.condensing.Load(convID); !exists { + break + } + time.Sleep(50 * time.Millisecond) + } + }) + return ce, s, conv.ConversationID +} + +// newTestCompactionEngineWithStore creates a CompactionEngine with existing store. +// Note: Caller is responsible for calling shutdownCancel when test ends. +func newTestCompactionEngineWithStore( + s *Store, complete CompleteFn, +) (ce *CompactionEngine, shutdownCancel context.CancelFunc) { + shutdownCtx, cancel := context.WithCancel(context.Background()) + return &CompactionEngine{ + store: s, + config: Config{}, + complete: complete, + shutdownCtx: shutdownCtx, + shutdownCancel: cancel, + }, cancel +} + +// mockCompleteFn returns a simple summary for testing +var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + return "Mock summary of the conversation segment.", nil +} + +func TestNeedsCompaction(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Empty context — no compaction needed + needed, err := ce.NeedsCompaction(ctx, convID, 10000) + if err != nil { + t.Fatalf("NeedsCompaction: %v", err) + } + if needed { + t.Error("expected no compaction for empty context") + } + + // Add messages to context, total tokens = 8000 + for i := 0; i < 8; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction + needed, err = ce.NeedsCompaction(ctx, convID, 10000) + if err != nil { + t.Fatalf("NeedsCompaction: %v", err) + } + if !needed { + t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)") + } + + // Below threshold: 5000 / 10000 → no compaction + s.UpsertContextItems(ctx, convID, nil) // clear + for i := 0; i < 5; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "test", 1000) + s.AppendContextMessage(ctx, convID, m.ID) + } + needed, _ = ce.NeedsCompaction(ctx, convID, 10000) + if needed { + t.Error("expected no compaction at 5000/10000 tokens") + } +} + +func TestCompactLeaf(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create enough messages to trigger leaf compaction: + // Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous + for i := 0; i < 40; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Compact + result, err := ce.Compact(ctx, convID, CompactInput{}) + if err != nil { + t.Fatalf("Compact: %v", err) + } + if result == nil { + t.Fatal("expected non-nil result") + } + + // Should have created at least one leaf summary + if result.LeafSummaries == 0 { + t.Error("expected at least 1 leaf summary") + } + + // Context should now contain a summary item + items, _ := s.GetContextItems(ctx, convID) + foundSummary := false + for _, item := range items { + if item.ItemType == "summary" { + foundSummary = true + break + } + } + if !foundSummary { + t.Error("expected a summary in context_items after leaf compaction") + } + + // Some messages should have been replaced + if len(result.SummariesCreated) == 0 { + t.Error("expected at least 1 summary created") + } +} + +func TestCompactLeafNoCandidate(t *testing.T) { + ce, _, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Too few messages to trigger leaf compaction + m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10) + ce.store.AppendContextMessage(ctx, convID, m.ID) + + result, err := ce.Compact(ctx, convID, CompactInput{}) + if err != nil { + t.Fatalf("Compact: %v", err) + } + if result == nil { + t.Fatal("expected non-nil result even with no candidate") + } + if result.LeafSummaries != 0 { + t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries) + } +} + +func TestCompactCondensed(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create enough leaf summaries and fresh messages to enable condensation + leafIDs := make([]string, CondensedMinFanout) + for i := 0; i < CondensedMinFanout; i++ { + now := time.Now().UTC() + summary, err := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf summary content " + time.Now().String(), + TokenCount: 500, + EarliestAt: &now, + LatestAt: &now, + }) + if err != nil { + t.Fatalf("CreateSummary %d: %v", i, err) + } + leafIDs[i] = summary.SummaryID + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + + // Add enough fresh messages to have a fresh tail (>= FreshTailCount) + for i := 0; i < FreshTailCount; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Compact with force to trigger condensation + _, err := ce.Compact(ctx, convID, CompactInput{Force: true}) + if err != nil { + t.Fatalf("Compact: %v", err) + } + + // Wait for async condensed goroutine to complete + if !waitForCondensed(ce, convID, 2*time.Second) { + t.Fatal("timeout waiting for condensed compaction") + } + + // Should have created a condensed summary in the DB + summaries, _ := s.GetSummariesByConversation(ctx, convID) + foundCondensed := false + for _, sum := range summaries { + if sum.Kind == SummaryKindCondensed { + foundCondensed = true + break + } + } + if !foundCondensed { + t.Error("expected at least 1 condensed summary") + } +} + +func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) { + // Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed + // from context_items between candidate selection and ordinal range scan. + // Use a slow CompleteFn with barrier sync to control timing. + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race") + convID := conv.ConversationID + + // Create leaf summaries with enough tokens for condensation + var leafIDs []string + for i := 0; i < CondensedMinFanout; i++ { + now := time.Now().UTC() + sum, err := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("leaf summary %d", i), + TokenCount: 500, + EarliestAt: &now, + LatestAt: &now, + }) + if err != nil { + t.Fatalf("CreateSummary: %v", err) + } + leafIDs = append(leafIDs, sum.SummaryID) + s.AppendContextSummary(ctx, convID, sum.SummaryID) + } + + // Add fresh tail so leaf summaries are in evictable range + for i := 0; i < FreshTailCount+1; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Barrier: CompleteFn waits until test removes context_items, then returns + var barrier1, barrier2 sync.WaitGroup + barrier1.Add(1) // CompleteFn signals when called + barrier2.Add(1) // test signals when context_items removed + + slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + barrier1.Done() // signal: LLM called, candidates selected + barrier2.Wait() // wait: test removes context_items + return "Condensed summary.", nil + } + + ce, cancel := newTestCompactionEngineWithStore(s, slowComplete) + t.Cleanup(func() { + cancel() + time.Sleep(100 * time.Millisecond) + }) + + // Run compactCondensed in background + type compactResult struct { + summaryID *string + err error + } + resultCh := make(chan compactResult, 1) + go func() { + sid, err := ce.compactCondensed(context.Background(), convID) + resultCh <- compactResult{summaryID: sid, err: err} + }() + + // Wait for CompleteFn to be called (candidates selected) + barrier1.Wait() + + // Remove leaf summaries from context_items (simulating concurrent replacement) + items, _ := s.GetContextItems(ctx, convID) + var preserved []ContextItem + for _, item := range items { + isLeaf := false + for _, lid := range leafIDs { + if item.SummaryID == lid { + isLeaf = true + break + } + } + if !isLeaf { + preserved = append(preserved, item) + } + } + s.UpsertContextItems(ctx, convID, preserved) + + // Let CompleteFn return + barrier2.Done() + + // Get result + res := <-resultCh + if res.err != nil { + t.Fatalf("compactCondensed: %v", res.err) + } + + // With the bug: returns non-nil summaryID even though context_items has no matching ordinals + // The fix: should return nil when startOrd == -1 + if res.summaryID != nil { + t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID) + + // Verify the orphan exists in DB + summary, _ := s.GetSummary(context.Background(), *res.summaryID) + if summary != nil && summary.Kind == SummaryKindCondensed { + // Check it's NOT in context_items (orphan) + items2, _ := s.GetContextItems(context.Background(), convID) + found := false + for _, item := range items2 { + if item.SummaryID == *res.summaryID { + found = true + break + } + } + if !found { + t.Error("condensed summary exists in DB but not in context_items — orphan confirmed") + } + } + } +} + +func TestCompactUntilUnder(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create many leaf summaries to ensure we can condense + for i := 0; i < 8; i++ { + now := time.Now().UTC() + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf summary for condensation test", + TokenCount: 500, + EarliestAt: &now, + LatestAt: &now, + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + + // Force compact until under budget + result, err := ce.CompactUntilUnder(ctx, convID, 2000) + if err != nil { + t.Fatalf("CompactUntilUnder: %v", err) + } + + if result == nil { + t.Fatal("expected non-nil result") + } +} + +func TestSelectShallowestCondensationCandidate(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create enough leaf summaries + fresh messages for candidates + for i := 0; i < LeafMinFanout; i++ { + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf", + TokenCount: 100, + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + + // Add fresh tail messages so summaries are in evictable range + for i := 0; i < FreshTailCount+1; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + s.AppendContextMessage(ctx, convID, m.ID) + } + + candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false) + if err != nil { + t.Fatalf("selectShallowestCondensationCandidate: %v", err) + } + + // Should find leaf summaries at depth 0 + if len(candidates) < CondensedMinFanout { + t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout) + } +} + +func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) { + ce, _, convID := newTestCompactionEngine(t) + ctx := context.Background() + + candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false) + if err != nil { + t.Fatalf("selectShallowestCondensationCandidate: %v", err) + } + if len(candidates) != 0 { + t.Errorf("candidates = %d, want 0 for empty context", len(candidates)) + } +} + +func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) { + // Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth + // rather than just grouping by depth without regard to order + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create interleaved summaries at depth 0 with a message in between: + // sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300) + + for i := 0; i < LeafMinFanout+2; i++ { + now := time.Now().UTC() + + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("leaf summary %d", i), + TokenCount: 100, + EarliestAt: &now, + LatestAt: &now, + }) + } + + // Insert a message between first two summaries to break contiguity + // for selectShallowestCondensationCandidate but would still find all 3 + // but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3) + + msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5) + s.AppendContextMessage(ctx, convID, msg.ID) + + // Run compactCondensed + result, err := ce.compactCondensed(ctx, convID) + if err != nil { + t.Fatalf("compactCondensed: %v", err) + } + + // The result should have merged the two summaries at the start + // (skipping the message in between), This proves ordinal-aware selection works. + + _ = result // verify summary was created + + if result != nil { + summaries, _ := s.GetSummariesByConversation(ctx, convID) + found := false + for _, sum := range summaries { + if sum.Kind == SummaryKindCondensed { + found = true + break + } + } + if !found { + t.Error("expected condensed summary to be created via ordinal-aware selection") + } + } +} + +func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens) + for i := 0; i < 5; i++ { + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("leaf summary %d", i), + TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000) + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + + // Add fresh tail + for i := 0; i < FreshTailCount+1; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + s.AppendContextMessage(ctx, convID, m.ID) + } + + chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0) + if err != nil { + t.Fatalf("selectOldestChunkAtDepth: %v", err) + } + if len(chunk) < 2 { + t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk)) + } + for _, s := range chunk { + if s.Depth != 0 { + t.Errorf("got depth %d, want 0", s.Depth) + } + } +} + +func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create 3 summaries, then a message, then 3 more summaries + for i := 0; i < 3; i++ { + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("leaf %d", i), + TokenCount: 100, + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + msg, _ := s.AddMessage(ctx, convID, "user", "break", 10) + s.AppendContextMessage(ctx, convID, msg.ID) + for i := 0; i < 3; i++ { + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("leaf-after %d", i), + TokenCount: 100, + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + for i := 0; i < FreshTailCount+1; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5) + s.AppendContextMessage(ctx, convID, m.ID) + } + + chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0) + if len(chunk) > 3 { + t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk)) + } +} + +func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create summaries with very low token counts (total < 2000) + for i := 0; i < 5; i++ { + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("tiny summary %d", i), + TokenCount: 50, // very small + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + + // Add fresh tail to protect from compaction + for i := 0; i < FreshTailCount+1; i++ { + m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Should return nil because total tokens (250) < 2000 minimum + chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0) + if err != nil { + t.Fatalf("selectOldestChunkAtDepth: %v", err) + } + if len(chunk) > 0 { + t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk)) + } +} + +func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) { + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create summaries with enough tokens (total >= 2000) + for i := 0; i < 5; i++ { + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf( + "substantial summary with enough content to meet minimum token threshold for condensation candidate %d", + i, + ), + TokenCount: 500, // 5 × 500 = 2500 >= 2000 + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + + // Add fresh tail + for i := 0; i < FreshTailCount+1; i++ { + m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Should return chunk because total tokens (2500) >= 2000 + chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0) + if err != nil { + t.Fatalf("selectOldestChunkAtDepth: %v", err) + } + if len(chunk) == 0 { + t.Error("expected non-empty chunk when tokens >= 2000") + } +} + +func TestGenerateLeafSummary(t *testing.T) { + ce, _, _ := newTestCompactionEngine(t) + ctx := context.Background() + + msgs := []Message{ + {Role: "user", Content: "hello world", TokenCount: 5}, + {Role: "assistant", Content: "hi there", TokenCount: 5}, + } + + content, err := ce.generateLeafSummary(ctx, msgs, "") + if err != nil { + t.Fatalf("generateLeafSummary: %v", err) + } + if content == "" { + t.Error("expected non-empty summary content") + } +} + +func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) { + // Level 1 returns summary that's too large (tokens >= input), should escalate to level 2 + var calls []string + escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + if contains(prompt, "Aggressive summary policy") { + calls = append(calls, "aggressive") + return "Short aggressive summary.", nil + } + calls = append(calls, "normal") + // Return a very long summary to trigger escalation + longContent := make([]byte, 5000) + for i := range longContent { + longContent[i] = 'x' + } + return string(longContent), nil + } + + s := openTestStore(t) + ce, _ := newTestCompactionEngineWithStore(s, escalateComplete) + + msgs := []Message{ + {Role: "user", Content: "hello world", TokenCount: 10}, + {Role: "assistant", Content: "response", TokenCount: 10}, + } + + content, err := ce.generateLeafSummary(context.Background(), msgs, "") + if err != nil { + t.Fatalf("generateLeafSummary: %v", err) + } + if content == "" { + t.Error("expected non-empty summary content") + } + // Should have called both normal and aggressive + foundNormal := false + foundAggressive := false + for _, c := range calls { + if c == "normal" { + foundNormal = true + } + if c == "aggressive" { + foundAggressive = true + } + } + if !foundNormal { + t.Error("expected normal LLM call") + } + if !foundAggressive { + t.Error("expected aggressive LLM call (level 2 escalation)") + } +} + +func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) { + // Both normal and aggressive return empty, should escalate to level 3 truncation + emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + return "", nil + } + + s := openTestStore(t) + ce, _ := newTestCompactionEngineWithStore(s, emptyComplete) + + msgs := []Message{ + {Role: "user", Content: "hello world from test", TokenCount: 10}, + {Role: "assistant", Content: "response text here", TokenCount: 10}, + } + + content, err := ce.generateLeafSummary(context.Background(), msgs, "") + if err != nil { + t.Fatalf("generateLeafSummary: %v", err) + } + // Level 3 truncation should have produced something + if content == "" { + t.Error("expected non-empty content from level 3 truncation fallback") + } + if !contains(content, "Truncated from") { + t.Errorf("expected truncation marker in content: %q", content) + } +} + +func TestGenerateCondensedSummary(t *testing.T) { + ce, _, _ := newTestCompactionEngine(t) + ctx := context.Background() + + summaries := []Summary{ + {SummaryID: "sum_a", Content: "first summary", TokenCount: 100}, + {SummaryID: "sum_b", Content: "second summary", TokenCount: 100}, + } + + content, err := ce.generateCondensedSummary(ctx, summaries) + if err != nil { + t.Fatalf("generateCondensedSummary: %v", err) + } + if content == "" { + t.Error("expected non-empty condensed summary content") + } +} + +func TestGenerateCondensedSummaryEscalation(t *testing.T) { + // When LLM returns empty, should fall back to deterministic concatenation + emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + return "", nil + } + + s := openTestStore(t) + ce, _ := newTestCompactionEngineWithStore(s, emptyComplete) + + summaries := []Summary{ + {SummaryID: "sum_a", Content: "first summary text", TokenCount: 50}, + {SummaryID: "sum_b", Content: "second summary text", TokenCount: 50}, + } + + content, err := ce.generateCondensedSummary(context.Background(), summaries) + if err != nil { + t.Fatalf("generateCondensedSummary: %v", err) + } + // Should fall back to concatenation + if content == "" { + t.Error("expected non-empty content from fallback") + } +} + +// --- Async Condensed Compaction (Phase 2) --- + +func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) { + // Use a slow CompleteFn to verify Compact returns before condensed finishes + var callCount int32 + slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + atomic.AddInt32(&callCount, 1) + time.Sleep(500 * time.Millisecond) // simulate slow LLM + return "Slow condensed summary.", nil + } + + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:async") + convID := conv.ConversationID + + ce, cancel := newTestCompactionEngineWithStore(s, slowComplete) + t.Cleanup(func() { + cancel() + time.Sleep(100 * time.Millisecond) + }) + + // Create enough leaf summaries for condensation + fresh tail + for i := 0; i < CondensedMinFanout; i++ { + now := time.Now().UTC() + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf for async test", + TokenCount: 500, + EarliestAt: &now, + LatestAt: &now, + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + for i := 0; i < FreshTailCount; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Compact with force — should return quickly, condensed runs async + start := time.Now() + result, err := ce.Compact(ctx, convID, CompactInput{Force: true}) + elapsed := time.Since(start) + + if err != nil { + t.Fatalf("Compact: %v", err) + } + if result == nil { + t.Fatal("expected non-nil result") + } + + // Should return well before the 500ms LLM call + if elapsed > 200*time.Millisecond { + t.Errorf("Compact took %v, should return before async condensed finishes", elapsed) + } + + // Wait for async to complete + time.Sleep(800 * time.Millisecond) + + // Verify condensed summary was created by background goroutine + summaries, _ := s.GetSummariesByConversation(ctx, convID) + foundCondensed := false + for _, sum := range summaries { + if sum.Kind == SummaryKindCondensed { + foundCondensed = true + break + } + } + if !foundCondensed { + t.Error("expected at least one condensed summary from async Phase 2") + } +} + +func TestCompactAsyncDedup(t *testing.T) { + var callCount int32 + slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) { + atomic.AddInt32(&callCount, 1) + time.Sleep(300 * time.Millisecond) + return "Slow condensed summary.", nil + } + + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:dedup") + convID := conv.ConversationID + + ce, cancel := newTestCompactionEngineWithStore(s, slowComplete) + t.Cleanup(func() { + cancel() + waitForCondensed(ce, convID, 2*time.Second) + }) + + // Create conditions for condensed compaction + for i := 0; i < CondensedMinFanout; i++ { + now := time.Now().UTC() + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf for dedup", + TokenCount: 500, + EarliestAt: &now, + LatestAt: &now, + }) + s.AppendContextSummary(ctx, convID, summary.SummaryID) + } + for i := 0; i < FreshTailCount; i++ { + m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Call Compact twice rapidly + ce.Compact(ctx, convID, CompactInput{Force: true}) + ce.Compact(ctx, convID, CompactInput{Force: true}) + + // Wait for async to finish + time.Sleep(600 * time.Millisecond) + + // LLM should only be called once for condensed (dedup) + // callCount may be 0 if no leaf was created (only condensed in goroutine) + // The key is that we don't get 2+ condensed calls + if atomic.LoadInt32(&callCount) > 1 { + t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount) + } +} + +func TestCompactLeafForceBypassesFreshTail(t *testing.T) { + // Spec: compactLeaf with force=true should bypass FreshTailCount protection + // so CompactUntilUnder can compress messages inside the fresh tail + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create exactly FreshTailCount+4 messages (36 total) + // Without force: all messages are in fresh tail → no candidate + // With force: should compact the oldest messages + total := FreshTailCount + 4 + for i := 0; i < total; i++ { + m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100) + s.AppendContextMessage(ctx, convID, m.ID) + } + + // Without force: should return nil (all in fresh tail) + summaryID, err := ce.compactLeaf(ctx, convID) + if err != nil { + t.Fatalf("compactLeaf no-force: %v", err) + } + if summaryID != nil { + t.Error("expected nil without force (all messages in fresh tail)") + } + + // With force: should compact despite fresh tail protection + summaryID, err = ce.compactLeaf(ctx, convID, true) + if err != nil { + t.Fatalf("compactLeaf force: %v", err) + } + if summaryID == nil { + t.Error("expected summary with force=true (bypasses fresh tail)") + } +} + +func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) { + // Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping + // It should NOT take the entire contiguous chunk regardless of token count + ce, s, convID := newTestCompactionEngine(t) + ctx := context.Background() + + // Create messages totaling far more than LeafChunkTokens (20000) + // Each message is ~500 tokens, create 80 messages = 40000 tokens + for i := 0; i < 80; i++ { + m, _ := s.AddMessage( + ctx, + convID, + "user", + fmt.Sprintf( + "message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn", + i, + ), + 500, + ) + s.AppendContextMessage(ctx, convID, m.ID) + } + + summaryID, err := ce.compactLeaf(ctx, convID) + if err != nil { + t.Fatalf("compactLeaf: %v", err) + } + if summaryID == nil { + t.Fatal("expected a summary to be created") + } + + // The source messages that were compacted should total roughly LeafChunkTokens (20000), + // not the entire 40000 tokens worth of messages + summary, _ := s.GetSummary(ctx, *summaryID) + if summary == nil { + t.Fatal("summary not found") + } + + // Source message tokens should be roughly <= LeafChunkTokens (20000) + // Spec says: "Stop when accumulated tokens >= LeafChunkTokens" + if summary.SourceMessageTokenCount > LeafChunkTokens { + t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)", + summary.SourceMessageTokenCount, LeafChunkTokens) + } +} diff --git a/pkg/seahorse/short_constants.go b/pkg/seahorse/short_constants.go new file mode 100644 index 000000000..943d7931e --- /dev/null +++ b/pkg/seahorse/short_constants.go @@ -0,0 +1,30 @@ +package seahorse + +// Short-term memory configuration constants — all are experience-based defaults. + +const ( + // OrdinalStep is the gap between ordinals in context_items. + // Insert at midpoint; resequence only when precision exhausted. + OrdinalStep = 100 + + // ContextThreshold is the compaction trigger for the context window. + ContextThreshold float64 = 0.75 // Compact at 75% of context window + FreshTailCount int = 32 // Recent messages protected from compaction + + // LeafMinFanout is the fanout parameter. + LeafMinFanout int = 8 // Min messages per leaf summary + CondensedMinFanout int = 4 // Min summaries per condensed + CondensedMinFanoutHard int = 2 // Min for forced compaction + + // LeafChunkTokens is the token target. + LeafChunkTokens int = 20000 // Max tokens per leaf chunk + LeafTargetTokens int = 1200 // Target tokens for leaf summaries + CondensedTargetTokens int = 2000 // Target tokens for condensed summaries + MaxExpandTokens int = 4000 // Token cap for expansion queries + + // MaxCompactIterations caps CompactUntilUnder to prevent infinite loops. + // Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction. + // With a 200k token context window and 75% threshold, ~20 iterations is enough + // for any realistic scenario. If exceeded, the issue is logged as a warning. + MaxCompactIterations int = 20 +) diff --git a/pkg/seahorse/short_engine.go b/pkg/seahorse/short_engine.go new file mode 100644 index 000000000..4cd4d3887 --- /dev/null +++ b/pkg/seahorse/short_engine.go @@ -0,0 +1,568 @@ +package seahorse + +import ( + "context" + "database/sql" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + _ "modernc.org/sqlite" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +// Config holds engine configuration. +type Config struct { + DBPath string `json:"dbPath"` + IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"` + StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"` +} + +// CompleteFn is the LLM completion function type. +type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) + +// CompleteOptions holds LLM completion parameters. +type CompleteOptions struct { + Model string + MaxTokens int + Temperature float64 +} + +// IngestResult is the result of message ingestion. +type IngestResult struct { + MessageCount int `json:"messageCount"` + TokenCount int `json:"tokenCount"` +} + +// AssembleInput controls context assembly. +type AssembleInput struct { + Budget int `json:"budget"` + Query string `json:"query,omitempty"` +} + +// AssembleResult contains assembled context. +type AssembleResult struct { + Messages []Message `json:"messages"` + Summary string `json:"summary"` // formatted XML summaries + system prompt addition +} + +const numSessionShards = 256 + +// Engine is the main short-term memory engine. +type Engine struct { + store *Store + compaction *CompactionEngine + compactionMu sync.Mutex + assembler *Assembler + assemblerMu sync.Mutex + retrieval *RetrievalEngine + config Config + complete CompleteFn + ignorePatterns []*regexp.Regexp + statelessPatterns []*regexp.Regexp + sessionShards [numSessionShards]struct { + mu sync.Mutex + } +} + +// CompactionEngine handles LLM-based summarization (defined in short_compaction.go). +type CompactionEngine struct { + store *Store + config Config + complete CompleteFn + condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines + shutdownCtx context.Context + shutdownCancel context.CancelFunc +} + +// Assembler handles budget-aware context assembly (defined in short_assembler.go). +type Assembler struct { + store *Store + config Config +} + +// RetrievalEngine handles search and expansion (defined in short_retrieval.go). +type RetrievalEngine struct { + store *Store + config Config +} + +// Store returns the underlying store for direct access. +func (r *RetrievalEngine) Store() *Store { + return r.store +} + +// NewEngine creates a new short-term memory engine. +func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) { + dir := filepath.Dir(config.DBPath) + if dir != "" && dir != "." { + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, fmt.Errorf("create db directory: %w", err) + } + } + + db, err := sql.Open("sqlite", config.DBPath) + if err != nil { + return nil, fmt.Errorf("open db: %w", err) + } + + // Configure SQLite for concurrent access + if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil { + db.Close() + return nil, fmt.Errorf("enable WAL: %w", err) + } + if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil { + db.Close() + return nil, fmt.Errorf("set busy_timeout: %w", err) + } + if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil { + db.Close() + return nil, fmt.Errorf("set synchronous: %w", err) + } + + if err := runSchema(db); err != nil { + db.Close() + return nil, fmt.Errorf("migrations: %w", err) + } + + store := &Store{db: db} + + // Prepend hardcoded ignore patterns (spec lines 1326-1328) + ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns)) + ignorePatterns = append(ignorePatterns, "heartbeat") + ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...) + + retrieval := &RetrievalEngine{store: store, config: config} + + return &Engine{ + store: store, + compaction: nil, + assembler: nil, + retrieval: retrieval, + config: config, + complete: completeFn, + ignorePatterns: compileSessionPatterns(ignorePatterns), + statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns), + }, nil +} + +// compileSessionPattern converts a glob pattern to a compiled regex. +// Pattern rules: +// - * matches any sequence of non-colon characters ([^:]*) +// - ** matches any sequence of characters including colons (.*) +// - All other characters are treated literally +// - Pattern is anchored (^...$) +func compileSessionPattern(pattern string) *regexp.Regexp { + var b strings.Builder + b.WriteByte('^') + + i := 0 + for i < len(pattern) { + if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' { + b.WriteString(".*") + i += 2 + continue + } + if pattern[i] == '*' { + b.WriteString("[^:]*") + i++ + continue + } + b.WriteString(regexp.QuoteMeta(string(pattern[i]))) + i++ + } + + b.WriteByte('$') + return regexp.MustCompile(b.String()) +} + +// compileSessionPatterns compiles multiple glob patterns into regex patterns. +func compileSessionPatterns(patterns []string) []*regexp.Regexp { + result := make([]*regexp.Regexp, 0, len(patterns)) + for _, p := range patterns { + if p == "" { + continue + } + result = append(result, compileSessionPattern(p)) + } + return result +} + +// shouldIgnoreSession returns true if the session key matches any ignore pattern. +func (e *Engine) shouldIgnoreSession(sessionKey string) bool { + for _, p := range e.ignorePatterns { + if p.MatchString(sessionKey) { + return true + } + } + return false +} + +// isStatelessSession returns true if the session key matches any stateless pattern. +func (e *Engine) isStatelessSession(sessionKey string) bool { + for _, p := range e.statelessPatterns { + if p.MatchString(sessionKey) { + return true + } + } + return false +} + +// fnv32 computes FNV-1a 32-bit hash for session key sharding. +func fnv32(key string) uint32 { + h := uint32(2166136261) + for _, c := range key { + h ^= uint32(c) + h *= 16777619 + } + return h +} + +// getSessionMutex returns the sharded mutex for a session key. +func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex { + h := fnv32(sessionKey) + shard := h % numSessionShards + return &e.sessionShards[shard].mu +} + +// Ingest adds messages to a conversation identified by sessionKey. +func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) { + if e.shouldIgnoreSession(sessionKey) { + return nil, nil + } + if e.isStatelessSession(sessionKey) { + return nil, nil + } + + mu := e.getSessionMutex(sessionKey) + mu.Lock() + defer mu.Unlock() + + conv, err := e.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + return nil, fmt.Errorf("get conversation: %w", err) + } + + var totalTokens int + var msgIDs []int64 + for _, msg := range messages { + var added *Message + var err error + if len(msg.Parts) > 0 { + added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount) + } else { + added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount) + } + if err != nil { + return nil, fmt.Errorf("add message: %w", err) + } + totalTokens += msg.TokenCount + msgIDs = append(msgIDs, added.ID) + } + + // Append to context_items using actual inserted IDs + if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil { + return nil, fmt.Errorf("append context: %w", err) + } + + logger.InfoCF("seahorse", "ingest", map[string]any{ + "conv_id": conv.ConversationID, + "messages": len(messages), + "tokens": totalTokens, + }) + return &IngestResult{ + MessageCount: len(messages), + TokenCount: totalTokens, + }, nil +} + +// Close releases resources. +func (e *Engine) Close() error { + // Signal compaction goroutines to stop + if e.compaction != nil { + e.compaction.Close() + } + if e.store != nil && e.store.db != nil { + return e.store.db.Close() + } + return nil +} + +// GetRetrieval returns the retrieval engine for tool implementations. +func (e *Engine) GetRetrieval() *RetrievalEngine { + return e.retrieval +} + +// Assemble builds budget-constrained context for a session. +func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) { + if e.shouldIgnoreSession(sessionKey) { + return nil, nil + } + + conv, err := e.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + return nil, fmt.Errorf("get conversation: %w", err) + } + + e.initAssemblerOnce() + return e.assembler.Assemble(ctx, conv.ConversationID, input) +} + +// Compact compresses conversation history for a session. +func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) { + if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) { + return &CompactResult{}, nil + } + + conv, err := e.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + return nil, fmt.Errorf("get conversation: %w", err) + } + + e.initCompactionOnce() + return e.compaction.Compact(ctx, conv.ConversationID, input) +} + +// CompactUntilUnder aggressively compacts until context is under budget. +// Used for emergency compaction after LLM overflow (retry reason). +func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) { + if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) { + return &CompactResult{}, nil + } + + conv, err := e.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + return nil, fmt.Errorf("get conversation: %w", err) + } + + e.initCompactionOnce() + return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget) +} + +// initCompactionOnce lazily initializes the compaction engine. +func (e *Engine) initCompactionOnce() { + if e.compaction == nil { + e.compactionMu.Lock() + defer e.compactionMu.Unlock() + if e.compaction == nil { + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + e.compaction = &CompactionEngine{ + store: e.store, + config: e.config, + complete: e.complete, + shutdownCtx: shutdownCtx, + shutdownCancel: shutdownCancel, + } + } + } +} + +// initAssemblerOnce lazily initializes the assembler. +func (e *Engine) initAssemblerOnce() { + if e.assembler == nil { + e.assemblerMu.Lock() + defer e.assemblerMu.Unlock() + if e.assembler == nil { + e.assembler = &Assembler{store: e.store, config: e.config} + } + } +} + +// IngestMessages is an alias for Ingest. +func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) { + return e.Ingest(ctx, sessionKey, messages) +} + +// Bootstrap reconciles a session's messages with the database. +// Called once at startup for each known session. +// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta. +// Simple approach: find longest matching prefix and append delta. +// If any mismatch is detected, clear and rebuild. +func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error { + if e.shouldIgnoreSession(sessionKey) { + return nil + } + if e.isStatelessSession(sessionKey) { + return nil + } + if len(messages) == 0 { + return nil + } + + conv, err := e.store.GetOrCreateConversation(ctx, sessionKey) + if err != nil { + return fmt.Errorf("bootstrap: get conversation: %w", err) + } + + // Get messages already in DB + dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0) + if err != nil { + return fmt.Errorf("bootstrap: get messages: %w", err) + } + + // Fast path: DB has same count and exact match → no-op + if len(dbMsgs) == len(messages) { + matched := true + for i := 0; i < len(messages); i++ { + if !messageMatches(dbMsgs[i], messages[i]) { + matched = false + break + } + } + if matched { + return nil // DB is up to date + } + } + + // Find longest matching prefix from the start + anchor := -1 + compareLen := len(dbMsgs) + if compareLen > len(messages) { + compareLen = len(messages) + } + + for i := 0; i < compareLen; i++ { + if messageMatches(dbMsgs[i], messages[i]) { + anchor = i + } else { + // Mismatch detected - log details and rebuild + logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{ + "conv_id": conv.ConversationID, + "index": i, + "db_role": dbMsgs[i].Role, + "db_content": truncate(dbMsgs[i].Content, 50), + "db_parts": len(dbMsgs[i].Parts), + "msg_role": messages[i].Role, + "msg_content": truncate(messages[i].Content, 50), + "msg_parts": len(messages[i].Parts), + }) + break + } + } + + // If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest + // Note: anchor can be -1 if first message didn't match (history completely changed) + if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 { + anchorID := dbMsgs[anchor].ID + logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{ + "conv_id": conv.ConversationID, + "db_count": len(dbMsgs), + "anchor": anchor, + "anchor_id": anchorID, + "msg_count": len(messages), + "delta_start": anchor + 1, + }) + + // Delete messages after anchor (also clears context_items) + if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil { + return fmt.Errorf("bootstrap: delete messages: %w", err) + } + + // Re-ingest from anchor+1 to end + delta := messages[anchor+1:] + if len(delta) > 0 { + _, err := e.Ingest(ctx, sessionKey, delta) + if err != nil { + return fmt.Errorf("bootstrap: re-ingest: %w", err) + } + } + return nil + } + + // Normal case: append delta after anchor + if anchor >= 0 && anchor < len(messages)-1 { + delta := messages[anchor+1:] + if len(delta) > 0 { + _, err := e.Ingest(ctx, sessionKey, delta) + if err != nil { + return fmt.Errorf("bootstrap: ingest delta: %w", err) + } + } + } else if anchor == -1 && len(dbMsgs) > 0 { + // First message changed (history completely different) - rebuild from scratch + logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{ + "conv_id": conv.ConversationID, + "db_count": len(dbMsgs), + "msg_count": len(messages), + }) + // Delete all existing messages + if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil { + return fmt.Errorf("bootstrap: delete all messages: %w", err) + } + // Re-ingest everything + if len(messages) > 0 { + _, err := e.Ingest(ctx, sessionKey, messages) + if err != nil { + return fmt.Errorf("bootstrap: re-ingest all: %w", err) + } + } + } else if anchor == -1 && len(dbMsgs) == 0 { + // DB is empty, ingest everything + _, err := e.Ingest(ctx, sessionKey, messages) + if err != nil { + return fmt.Errorf("bootstrap: ingest all: %w", err) + } + } + + return nil +} + +// truncate shortens a string for logging. +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// messageMatches compares two messages using (role, content) or (role, parts). +// TokenCount is NOT compared because it may be re-estimated differently +// during bootstrap (e.g., via tokenizer.EstimateMessageTokens). +// For messages with Parts (tool_use, tool_result), compare Parts instead of Content +// since AddMessageWithParts stores empty Content in DB. +func messageMatches(a, b Message) bool { + if a.Role != b.Role { + return false + } + // If either message has Parts, compare Parts + if len(a.Parts) > 0 || len(b.Parts) > 0 { + return partsMatch(a.Parts, b.Parts) + } + // Simple text messages: compare Content + return a.Content == b.Content +} + +// partsMatch compares two slices of MessagePart for equality. +func partsMatch(a, b []MessagePart) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i].Type != b[i].Type { + return false + } + switch a[i].Type { + case "text": + if a[i].Text != b[i].Text { + return false + } + case "tool_use": + if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID { + return false + } + case "tool_result": + if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text { + return false + } + case "media": + if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType { + return false + } + } + } + return true +} diff --git a/pkg/seahorse/short_engine_test.go b/pkg/seahorse/short_engine_test.go new file mode 100644 index 000000000..d64634fb7 --- /dev/null +++ b/pkg/seahorse/short_engine_test.go @@ -0,0 +1,1448 @@ +package seahorse + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +// helper: open a test engine with in-memory DB +func newTestEngine(t *testing.T) *Engine { + t.Helper() + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + store := &Store{db: db} + return &Engine{ + store: store, + config: Config{}, + } +} + +// --- compileSessionPattern --- + +func TestCompileSessionPattern(t *testing.T) { + tests := []struct { + pattern string + input string + want bool + }{ + // Exact match + {"agent:abc123", "agent:abc123", true}, + {"agent:abc123", "agent:def456", false}, + // Single * — matches non-colon chars + {"agent:*", "agent:abc123", true}, + {"agent:*", "agent:abc:def", false}, // * doesn't match colons + // ** — matches everything including colons + {"cron:**", "cron:backup", true}, + {"cron:**", "cron:backup:daily", true}, + {"cron:**", "agent:abc", false}, + // Mixed + {"agent:*:sub:**", "agent:abc:sub:def", true}, + {"agent:*:sub:**", "agent:abc:sub:def:ghi", true}, + {"agent:*:sub:**", "agent:abc:def", false}, + // Empty pattern — matches nothing meaningful + {"", "", true}, + {"", "agent:abc", false}, + } + + for _, tt := range tests { + re := compileSessionPattern(tt.pattern) + if re == nil && tt.pattern != "" { + t.Fatalf("compileSessionPattern(%q) returned nil", tt.pattern) + } + if tt.pattern == "" { + continue + } + got := re.MatchString(tt.input) + if got != tt.want { + t.Errorf("compileSessionPattern(%q).Match(%q) = %v, want %v", tt.pattern, tt.input, got, tt.want) + } + } +} + +// --- Session Pattern Filtering --- + +func TestEngineShouldIgnoreSession(t *testing.T) { + eng := &Engine{ + ignorePatterns: compileSessionPatterns([]string{"cron:**", "test:*"}), + } + + tests := []struct { + key string + want bool + }{ + {"cron:backup", true}, + {"cron:backup:daily", true}, + {"test:session", true}, + {"agent:abc", false}, + {"", false}, + } + + for _, tt := range tests { + got := eng.shouldIgnoreSession(tt.key) + if got != tt.want { + t.Errorf("shouldIgnoreSession(%q) = %v, want %v", tt.key, got, tt.want) + } + } +} + +func TestEngineIsStatelessSession(t *testing.T) { + eng := &Engine{ + statelessPatterns: compileSessionPatterns([]string{"agent:*:sub:**"}), + } + + tests := []struct { + key string + want bool + }{ + {"agent:abc:sub:def", true}, + {"agent:abc:sub:def:ghi", true}, + {"agent:abc", false}, + {"cron:backup", false}, + } + + for _, tt := range tests { + got := eng.isStatelessSession(tt.key) + if got != tt.want { + t.Errorf("isStatelessSession(%q) = %v, want %v", tt.key, got, tt.want) + } + } +} + +// --- NewEngine --- + +func TestNewEngine(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "short.db") + + eng, err := NewEngine(Config{DBPath: dbPath}, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer eng.Close() + + // DB file should exist + if _, pathErr := os.Stat(dbPath); os.IsNotExist(pathErr) { + t.Error("expected DB file to be created") + } + + // Store should be usable + ctx := context.Background() + conv, err := eng.store.GetOrCreateConversation(ctx, "test:session") + if err != nil { + t.Fatalf("store should work: %v", err) + } + if conv.ConversationID == 0 { + t.Error("expected valid conversation ID") + } + + // GetRetrieval should return non-nil RetrievalEngine + retrieval := eng.GetRetrieval() + if retrieval == nil { + t.Error("expected GetRetrieval to return non-nil RetrievalEngine") + } +} + +func TestNewEngineWithPatterns(t *testing.T) { + dir := t.TempDir() + dbPath := filepath.Join(dir, "short.db") + + eng, err := NewEngine(Config{ + DBPath: dbPath, + IgnoreSessionPatterns: []string{"cron:**"}, + StatelessSessionPatterns: []string{"agent:*:sub:**"}, + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + defer eng.Close() + + if !eng.shouldIgnoreSession("cron:backup") { + t.Error("expected cron:backup to be ignored") + } + if !eng.isStatelessSession("agent:abc:sub:def") { + t.Error("expected agent:abc:sub:def to be stateless") + } +} + +// --- Ingest --- + +func TestEngineIngest(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + msgs := []Message{ + {Role: "user", Content: "hello", TokenCount: 2}, + {Role: "assistant", Content: "world", TokenCount: 2}, + } + + result, err := eng.Ingest(ctx, "agent:test", msgs) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + if result.MessageCount != 2 { + t.Errorf("MessageCount = %d, want 2", result.MessageCount) + } + if result.TokenCount != 4 { + t.Errorf("TokenCount = %d, want 4", result.TokenCount) + } + + // Verify messages were stored + conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:test") + stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if len(stored) != 2 { + t.Fatalf("stored messages = %d, want 2", len(stored)) + } + if stored[0].Content != "hello" { + t.Errorf("stored[0].Content = %q, want 'hello'", stored[0].Content) + } + + // Verify context_items were populated + items, _ := eng.store.GetContextItems(ctx, conv.ConversationID) + if len(items) != 2 { + t.Fatalf("context items = %d, want 2", len(items)) + } + if items[0].ItemType != "message" { + t.Errorf("item[0].ItemType = %q, want 'message'", items[0].ItemType) + } +} + +func TestEngineIngestIgnoresSession(t *testing.T) { + eng := newTestEngine(t) + eng.ignorePatterns = compileSessionPatterns([]string{"cron:**"}) + ctx := context.Background() + + msgs := []Message{{Role: "user", Content: "hello", TokenCount: 2}} + result, err := eng.Ingest(ctx, "cron:backup", msgs) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + if result != nil { + t.Error("expected nil result for ignored session") + } + + // Verify no data was stored + conv, _ := eng.store.GetConversationBySessionKey(ctx, "cron:backup") + if conv != nil { + t.Error("expected no conversation for ignored session") + } +} + +func TestEngineIngestStatelessSession(t *testing.T) { + eng := newTestEngine(t) + eng.statelessPatterns = compileSessionPatterns([]string{"agent:*:ro"}) + ctx := context.Background() + + msgs := []Message{{Role: "user", Content: "hello", TokenCount: 2}} + result, err := eng.Ingest(ctx, "agent:abc:ro", msgs) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + if result != nil { + t.Error("expected nil result for stateless session") + } +} + +func TestEngineIngestIncremental(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + // First ingest + eng.Ingest(ctx, "agent:test", []Message{ + {Role: "user", Content: "msg1", TokenCount: 1}, + }) + // Second ingest — should append, not replace + eng.Ingest(ctx, "agent:test", []Message{ + {Role: "assistant", Content: "msg2", TokenCount: 1}, + }) + + conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:test") + stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if len(stored) != 2 { + t.Errorf("stored messages = %d, want 2", len(stored)) + } +} + +func TestEngineIngestWithParts(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + msgs := []Message{ + { + Role: "assistant", + Content: "", + TokenCount: 10, + Parts: []MessagePart{ + {Type: "tool_use", Name: "read_file", Arguments: `{"path":"/tmp/test"}`, ToolCallID: "tc_123"}, + {Type: "text", Text: "here is the file content"}, + }, + }, + } + + result, err := eng.Ingest(ctx, "agent:parts-test", msgs) + if err != nil { + t.Fatalf("Ingest with parts: %v", err) + } + if result.MessageCount != 1 { + t.Errorf("MessageCount = %d, want 1", result.MessageCount) + } + + // Verify message was stored WITH parts + conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:parts-test") + stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if len(stored) != 1 { + t.Fatalf("stored messages = %d, want 1", len(stored)) + } + if len(stored[0].Parts) != 2 { + t.Fatalf("stored message parts = %d, want 2", len(stored[0].Parts)) + } + if stored[0].Parts[0].Type != "tool_use" { + t.Errorf("part[0].Type = %q, want tool_use", stored[0].Parts[0].Type) + } + if stored[0].Parts[0].Name != "read_file" { + t.Errorf("part[0].Name = %q, want read_file", stored[0].Parts[0].Name) + } + if stored[0].Parts[0].ToolCallID != "tc_123" { + t.Errorf("part[0].ToolCallID = %q, want tc_123", stored[0].Parts[0].ToolCallID) + } + if stored[0].Parts[1].Type != "text" { + t.Errorf("part[1].Type = %q, want text", stored[0].Parts[1].Type) + } + if stored[0].Parts[1].Text != "here is the file content" { + t.Errorf("part[1].Text = %q, want 'here is the file content'", stored[0].Parts[1].Text) + } +} + +func TestEngineIngestAssemblePreservesParts(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + // Ingest a message with tool_use parts + eng.Ingest(ctx, "agent:parts-roundtrip", []Message{ + {Role: "user", Content: "list files", TokenCount: 3}, + { + Role: "assistant", + Content: "", + TokenCount: 5, + Parts: []MessagePart{ + {Type: "tool_use", Name: "bash", Arguments: `{"cmd":"ls"}`, ToolCallID: "tc_1"}, + {Type: "text", Text: "found 3 files"}, + }, + }, + }) + + // Assemble should return messages with parts intact + result, err := eng.Assemble(ctx, "agent:parts-roundtrip", AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("Assemble returned %d messages, want 2", len(result.Messages)) + } + + // The second message should have Parts populated + assistantMsg := result.Messages[1] + if len(assistantMsg.Parts) != 2 { + t.Fatalf("Assembled assistant message Parts = %d, want 2", len(assistantMsg.Parts)) + } + if assistantMsg.Parts[0].Type != "tool_use" { + t.Errorf("part[0].Type = %q, want tool_use", assistantMsg.Parts[0].Type) + } + if assistantMsg.Parts[0].ToolCallID != "tc_1" { + t.Errorf("part[0].ToolCallID = %q, want tc_1", assistantMsg.Parts[0].ToolCallID) + } +} + +// --- Session Mutex --- + +func TestEngineSessionMutex(t *testing.T) { + eng := newTestEngine(t) + + mu1 := eng.getSessionMutex("agent:test") + mu2 := eng.getSessionMutex("agent:test") + mu3 := eng.getSessionMutex("agent:other") + + if mu1 != mu2 { + t.Error("expected same mutex for same session key") + } + if mu1 == mu3 { + t.Error("expected different mutex for different session key") + } +} + +// --- Close --- + +func TestEngineClose(t *testing.T) { + eng := newTestEngine(t) + if err := eng.Close(); err != nil { + t.Errorf("Close: %v", err) + } +} + +// --- compileSessionPatterns (batch) --- + +func TestCompileSessionPatterns(t *testing.T) { + patterns := compileSessionPatterns([]string{"cron:**", "agent:*:ro"}) + if len(patterns) != 2 { + t.Fatalf("expected 2 patterns, got %d", len(patterns)) + } + + tests := []struct { + input string + want bool + }{ + {"cron:backup", true}, + {"agent:abc:ro", true}, + {"agent:abc:def", false}, + {"", false}, + } + + for _, tt := range tests { + matched := false + for _, p := range patterns { + if p.MatchString(tt.input) { + matched = true + break + } + } + if matched != tt.want { + t.Errorf("patterns.Match(%q) = %v, want %v", tt.input, matched, tt.want) + } + } +} + +func TestCompileSessionPatternsEmpty(t *testing.T) { + patterns := compileSessionPatterns(nil) + if len(patterns) != 0 { + t.Errorf("expected 0 patterns for nil input, got %d", len(patterns)) + } +} + +// --- Bootstrap --- + +func TestEngineBootstrap(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + msgs := []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", TokenCount: 3}, + {Role: "user", Content: "how are you", TokenCount: 5}, + } + + err := eng.Bootstrap(ctx, "agent:boot1", msgs) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + // Verify conversation was created + conv, err := eng.store.GetConversationBySessionKey(ctx, "agent:boot1") + if err != nil { + t.Fatalf("GetConversation: %v", err) + } + if conv == nil { + t.Fatal("expected conversation to exist after bootstrap") + } + + // Verify messages were stored + stored, err := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(stored) != 3 { + t.Fatalf("expected 3 stored messages, got %d", len(stored)) + } + if stored[0].Content != "hello" { + t.Errorf("stored[0].Content = %q, want 'hello'", stored[0].Content) + } + + // Verify context_items were populated + items, err := eng.store.GetContextItems(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextItems: %v", err) + } + if len(items) != 3 { + t.Fatalf("expected 3 context items, got %d", len(items)) + } +} + +func TestEngineBootstrapEmpty(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + err := eng.Bootstrap(ctx, "agent:empty", nil) + if err != nil { + t.Fatalf("Bootstrap empty: %v", err) + } + + // No conversation should be created for empty messages + conv, _ := eng.store.GetConversationBySessionKey(ctx, "agent:empty") + if conv != nil { + t.Error("expected no conversation for empty bootstrap") + } +} + +func TestEngineBootstrapIdempotent(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + msgs := []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", TokenCount: 3}, + } + + // Bootstrap twice with same messages + eng.Bootstrap(ctx, "agent:idem", msgs) + eng.Bootstrap(ctx, "agent:idem", msgs) + + // Should still have exactly 2 messages (no duplicates) + conv, _ := eng.store.GetConversationBySessionKey(ctx, "agent:idem") + if conv == nil { + t.Fatal("expected conversation") + } + stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if len(stored) != 2 { + t.Errorf("expected 2 messages (idempotent), got %d", len(stored)) + } +} + +func TestEngineBootstrapDelta(t *testing.T) { + eng := newTestEngine(t) + ctx := context.Background() + + // First bootstrap with 2 messages + msgs1 := []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", TokenCount: 3}, + } + eng.Bootstrap(ctx, "agent:delta", msgs1) + + // Second bootstrap with 4 messages (2 existing + 2 new) + msgs2 := []Message{ + {Role: "user", Content: "hello", TokenCount: 3}, + {Role: "assistant", Content: "world", TokenCount: 3}, + {Role: "user", Content: "new question", TokenCount: 5}, + {Role: "assistant", Content: "new answer", TokenCount: 5}, + } + eng.Bootstrap(ctx, "agent:delta", msgs2) + + conv, _ := eng.store.GetConversationBySessionKey(ctx, "agent:delta") + if conv == nil { + t.Fatal("expected conversation") + } + stored, _ := eng.store.GetMessages(ctx, conv.ConversationID, 10, 0) + if len(stored) != 4 { + t.Errorf("expected 4 messages (delta), got %d", len(stored)) + } +} + +func TestBootstrapPopulatesContextItems(t *testing.T) { + // Bootstrap ingests messages and populates context_items + e := newTestEngine(t) + ctx := context.Background() + + messages := []Message{ + {Role: "user", Content: "hello from bootstrap test", TokenCount: 10}, + {Role: "assistant", Content: "hi there", TokenCount: 5}, + {Role: "user", Content: "how are you", TokenCount: 5}, + {Role: "assistant", Content: "doing well", TokenCount: 5}, + {Role: "user", Content: "great news", TokenCount: 5}, + {Role: "assistant", Content: "awesome", TokenCount: 5}, + {Role: "user", Content: "lets code", TokenCount: 5}, + {Role: "assistant", Content: "sure thing", TokenCount: 5}, + } + + // Bootstrap should ingest and rebuild context_items + err := e.Bootstrap(ctx, "test-bootstrap-rebuild", messages) + if err != nil { + t.Fatalf("Bootstrap: %v", err) + } + + // After bootstrap, context_items should be populated + conv, _ := e.store.GetOrCreateConversation(ctx, "test-bootstrap-rebuild") + items, err := e.store.GetContextItems(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextItems: %v", err) + } + + if len(items) == 0 { + t.Error("expected context_items to be populated after Bootstrap, got 0 items") + } + + // Should have one item per message + if len(items) != len(messages) { + t.Errorf("expected %d context items, got %d", len(messages), len(items)) + } +} + +func TestBootstrapDeltaPreservesOrder(t *testing.T) { + // When Bootstrap does delta ingest, context_items should maintain + // correct order with new messages appended after anchor. + e := newTestEngine(t) + ctx := context.Background() + sessionKey := "test-bootstrap-delta-order" + + // First: bootstrap with 4 messages + initialMsgs := []Message{ + {Role: "user", Content: "msg1", TokenCount: 5}, + {Role: "assistant", Content: "msg2", TokenCount: 5}, + {Role: "user", Content: "msg3", TokenCount: 5}, + {Role: "assistant", Content: "msg4", TokenCount: 5}, + } + err := e.Bootstrap(ctx, sessionKey, initialMsgs) + if err != nil { + t.Fatalf("first Bootstrap: %v", err) + } + + conv, _ := e.store.GetOrCreateConversation(ctx, sessionKey) + items1, _ := e.store.GetContextItems(ctx, conv.ConversationID) + if len(items1) != 4 { + t.Fatalf("after first bootstrap: expected 4 items, got %d", len(items1)) + } + + // Now bootstrap again with 6 messages (4 existing + 2 new) + // The delta (msg5, msg6) should be appended + updatedMsgs := []Message{ + {Role: "user", Content: "msg1", TokenCount: 5}, + {Role: "assistant", Content: "msg2", TokenCount: 5}, + {Role: "user", Content: "msg3", TokenCount: 5}, + {Role: "assistant", Content: "msg4", TokenCount: 5}, + {Role: "user", Content: "msg5", TokenCount: 5}, + {Role: "assistant", Content: "msg6", TokenCount: 5}, + } + err = e.Bootstrap(ctx, sessionKey, updatedMsgs) + if err != nil { + t.Fatalf("second Bootstrap: %v", err) + } + + items2, _ := e.store.GetContextItems(ctx, conv.ConversationID) + if len(items2) != 6 { + t.Errorf("after delta bootstrap: expected 6 items, got %d", len(items2)) + } +} + +func TestBootstrapHistoryEditFirstMessageChanged(t *testing.T) { + // When the first message changes (anchor = -1), Bootstrap should rebuild + // from scratch without panicking (regression test for index out of range [-1]) + e := newTestEngine(t) + ctx := context.Background() + sessionKey := "test-bootstrap-history-edit" + + // First: bootstrap with some messages + initialMsgs := []Message{ + {Role: "user", Content: "original first", TokenCount: 5}, + {Role: "assistant", Content: "response", TokenCount: 5}, + {Role: "user", Content: "question", TokenCount: 5}, + } + err := e.Bootstrap(ctx, sessionKey, initialMsgs) + if err != nil { + t.Fatalf("first Bootstrap: %v", err) + } + + // Now bootstrap with completely different messages (first message changed) + // This should NOT panic - it should rebuild from scratch + editedMsgs := []Message{ + {Role: "user", Content: "DIFFERENT first message", TokenCount: 5}, + {Role: "assistant", Content: "DIFFERENT response", TokenCount: 5}, + {Role: "user", Content: "DIFFERENT question", TokenCount: 5}, + } + err = e.Bootstrap(ctx, sessionKey, editedMsgs) + if err != nil { + t.Fatalf("second Bootstrap (history edit): %v", err) + } + + conv, _ := e.store.GetOrCreateConversation(ctx, sessionKey) + stored, _ := e.store.GetMessages(ctx, conv.ConversationID, 10, 0) + + // Should have the NEW messages (history was rebuilt) + if len(stored) != 3 { + t.Errorf("expected 3 messages after history edit, got %d", len(stored)) + } + if len(stored) > 0 && stored[0].Content != "DIFFERENT first message" { + t.Errorf("first message = %q, want 'DIFFERENT first message'", stored[0].Content) + } +} + +func TestBootstrapSameContentDifferentTokenCountNoRebuild(t *testing.T) { + // Bootstrap should NOT rebuild when content is identical but TokenCount differs. + // This happens when TokenCount is re-estimated (e.g., via tokenizer.EstimateMessageTokens) + // during bootstrap, which may give slightly different values. + e := newTestEngine(t) + ctx := context.Background() + sessionKey := "test-bootstrap-token-diff" + + // First: bootstrap with some messages + initialMsgs := []Message{ + {Role: "user", Content: "hello world", TokenCount: 10}, + {Role: "assistant", Content: "hi there", TokenCount: 5}, + } + err := e.Bootstrap(ctx, sessionKey, initialMsgs) + if err != nil { + t.Fatalf("first Bootstrap: %v", err) + } + + conv, _ := e.store.GetOrCreateConversation(ctx, sessionKey) + storedBefore, _ := e.store.GetMessages(ctx, conv.ConversationID, 10, 0) + + // Second: bootstrap with SAME content but DIFFERENT TokenCount + // This should be a no-op (not rebuild) + sameContentMsgs := []Message{ + {Role: "user", Content: "hello world", TokenCount: 999}, // Different token count! + {Role: "assistant", Content: "hi there", TokenCount: 888}, // Different token count! + } + err = e.Bootstrap(ctx, sessionKey, sameContentMsgs) + if err != nil { + t.Fatalf("second Bootstrap: %v", err) + } + + storedAfter, _ := e.store.GetMessages(ctx, conv.ConversationID, 10, 0) + + // Should have same number of messages (no rebuild) + if len(storedAfter) != len(storedBefore) { + t.Errorf("expected %d messages (no rebuild), got %d", len(storedBefore), len(storedAfter)) + } + + // Message IDs should be the same (no delete+re-ingest) + for i := range storedBefore { + if storedBefore[i].ID != storedAfter[i].ID { + t.Errorf("message %d ID changed: before=%d, after=%d (should be no-op)", + i, storedBefore[i].ID, storedAfter[i].ID) + } + } +} + +// --- Session Mutex --- + +func TestEngineSessionMutexSharded(t *testing.T) { + eng := newTestEngine(t) + + // Same session key should always return the same mutex (deterministic hash) + mu1 := eng.getSessionMutex("agent:test") + mu2 := eng.getSessionMutex("agent:test") + if mu1 != mu2 { + t.Error("expected same mutex for same session key") + } + + // Different session keys may share the same shard (hash collision) + // This is expected behavior - we just need bounded memory, not unique locks + mu3 := eng.getSessionMutex("agent:other") + + // Both mutexes should be valid and usable + mu1.Lock() + mu1.Unlock() + mu3.Lock() + mu3.Unlock() +} + +func TestEngineSessionMutexBoundedMemory(t *testing.T) { + // Verify that session mutexes use bounded memory (256 shards) + eng := newTestEngine(t) + + // Get mutexes for many different sessions + seen := make(map[*sync.Mutex]bool) + for i := 0; i < 1000; i++ { + sessionKey := fmt.Sprintf("agent:session-%d", i) + mu := eng.getSessionMutex(sessionKey) + seen[mu] = true + } + + // With 256 shards and 1000 sessions, we should see at most 256 unique mutexes + // (likely fewer due to hash collisions) + if len(seen) > 256 { + t.Errorf("expected at most 256 unique mutexes (shards), got %d", len(seen)) + } +} + +func TestEngineSessionMutexConsistentHash(t *testing.T) { + // Same session key should always hash to the same shard + eng := newTestEngine(t) + + sessionKey := "agent:consistent-hash-test" + mu1 := eng.getSessionMutex(sessionKey) + mu2 := eng.getSessionMutex(sessionKey) + mu3 := eng.getSessionMutex(sessionKey) + + if mu1 != mu2 || mu2 != mu3 { + t.Error("hash function should be deterministic - same key must map to same shard") + } +} + +// --- Summary Role --- + +func TestAssemblerSummaryRoleNotUser(t *testing.T) { + // Summaries should use "system" role, not "user" + eng := newTestEngine(t) + ctx := context.Background() + + // Ingest messages + eng.Ingest(ctx, "agent:summary-role-test", []Message{ + {Role: "user", Content: "hello", TokenCount: 5}, + {Role: "assistant", Content: "world", TokenCount: 5}, + }) + + conv, _ := eng.store.GetOrCreateConversation(ctx, "agent:summary-role-test") + + // Create a summary and add it to context + sum, err := eng.store.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Content: "Test summary content", + TokenCount: 10, + Kind: SummaryKindCondensed, + Depth: 1, + }) + if err != nil { + t.Fatalf("CreateSummary: %v", err) + } + eng.store.AppendContextSummary(ctx, conv.ConversationID, sum.SummaryID) + + // Assemble and check summary message role + result, err := eng.Assemble(ctx, "agent:summary-role-test", AssembleInput{Budget: 1000}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + // Find the summary message (should have XML content with ) + for _, msg := range result.Messages { + if strings.Contains(msg.Content, "= 5 + // This tests the bug: when depth=2 is missing, the loop breaks and depth=3 is never checked + // Need > FreshTailCount(32) summaries so they are not all in fresh tail + // Depth 0: 3 summaries (not enough), Depth 1: 3 summaries (not enough) + // Depth 2: 0 summaries (missing), Depth 3: 40 summaries (enough) + depths := []int{0, 0, 0, 1, 1, 1} + for i := 0; i < 40; i++ { + depths = append(depths, 3) + } + now := time.Now().UTC() + + for i, depth := range depths { + sum, createErr := e.store.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: depth, + Content: fmt.Sprintf("summary depth %d #%d", depth, i), + TokenCount: 10, + EarliestAt: &now, + LatestAt: &now, + }) + if createErr != nil { + t.Fatalf("CreateSummary: %v", createErr) + } + // Add to context items (not in fresh tail) + if appendErr := e.store.AppendContextSummary(ctx, conv.ConversationID, sum.SummaryID); appendErr != nil { + t.Fatalf("AppendContextSummary: %v", appendErr) + } + } + + // Initialize compaction engine (lazy init) + e.initCompactionOnce() + + // Call selectShallowestCondensationCandidate + candidates, err := e.compaction.selectShallowestCondensationCandidate(ctx, conv.ConversationID, false) + if err != nil { + t.Fatalf("selectShallowestCondensationCandidate: %v", err) + } + + // Should find depth=0 (shallowest) with 5 summaries + if candidates == nil { + t.Fatal("expected candidates, got nil") + } + if len(candidates) < CondensedMinFanout { + t.Errorf("expected at least %d candidates, got %d", CondensedMinFanout, len(candidates)) + } + + // Verify all returned summaries have the same depth + if len(candidates) > 0 { + expectedDepth := candidates[0].Depth + for _, c := range candidates[1:] { + if c.Depth != expectedDepth { + t.Errorf("candidates have mixed depths: %d vs %d", expectedDepth, c.Depth) + } + } + } +} diff --git a/pkg/seahorse/short_retrieval.go b/pkg/seahorse/short_retrieval.go new file mode 100644 index 000000000..f7d6bf691 --- /dev/null +++ b/pkg/seahorse/short_retrieval.go @@ -0,0 +1,212 @@ +package seahorse + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + "time" +) + +// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m". +// Returns the duration and nil error, or zero and error if invalid. +func ParseLastDuration(s string) (time.Duration, error) { + if s == "" { + return 0, fmt.Errorf("empty duration") + } + + re := regexp.MustCompile(`^(\d+)([hdwm])$`) + matches := re.FindStringSubmatch(s) + if matches == nil { + return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s) + } + + value, _ := strconv.Atoi(matches[1]) + unit := matches[2] + + switch unit { + case "h": + return time.Duration(value) * time.Hour, nil + case "d": + return time.Duration(value) * 24 * time.Hour, nil + case "w": + return time.Duration(value) * 7 * 24 * time.Hour, nil + case "m": + return time.Duration(value) * 30 * 24 * time.Hour, nil + default: + return 0, fmt.Errorf("unknown unit: %q", unit) + } +} + +// GrepInput controls search across summaries and messages. +type GrepInput struct { + Pattern string `json:"pattern"` + Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message" + Role string `json:"role,omitempty"` // "user", "assistant", or "" (all) + AllConversations bool `json:"allConversations,omitempty"` + Since *time.Time `json:"since,omitempty"` + Before *time.Time `json:"before,omitempty"` + Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m" + Limit int `json:"limit,omitempty"` +} + +// GrepResult contains search results. +type GrepResult struct { + Success bool `json:"success"` + Summaries []GrepSummaryResult `json:"summaries"` + Messages []GrepMessageResult `json:"messages"` + TotalSummaries int `json:"totalSummaries"` + TotalMessages int `json:"totalMessages"` + Hint string `json:"hint,omitempty"` +} + +// GrepSummaryResult is a summary match from grep. +type GrepSummaryResult struct { + ID string `json:"id"` + Content string `json:"content"` + Depth int `json:"depth"` + Kind SummaryKind `json:"kind"` + ConversationID int64 `json:"conversationId"` + // Rank is the bm25 relevance score (negative value, closer to 0 = better match). + // Examples: -0.5 = excellent match, -2.0 = good match, -10.0 = partial match. + Rank float64 `json:"rank,omitempty"` +} + +// GrepMessageResult is a message match from grep. +type GrepMessageResult struct { + ID int64 `json:"id,string"` + Snippet string `json:"snippet"` + Role string `json:"role"` + ConversationID int64 `json:"conversationId"` + Rank float64 `json:"rank,omitempty"` // Relevance score (lower = better match) +} + +// ExpandMessagesResult contains expanded messages. +type ExpandMessagesResult struct { + Messages []Message `json:"messages"` + TokenCount int `json:"tokenCount"` +} + +// Grep searches summaries and messages for matching content. +func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) { + if input.Pattern == "" { + return nil, fmt.Errorf("grep: pattern is required") + } + + limit := input.Limit + if limit == 0 { + limit = 20 + } + + // Handle Last parameter: convert to Since + since := input.Since + if input.Last != "" { + dur, err := ParseLastDuration(input.Last) + if err != nil { + return nil, fmt.Errorf("grep: invalid last: %w", err) + } + t := time.Now().UTC().Add(-dur) + since = &t + } + + // Auto-detect mode: use LIKE if pattern contains %, otherwise full-text + mode := "" + if strings.Contains(input.Pattern, "%") { + mode = "like" + } + + searchInput := SearchInput{ + Pattern: input.Pattern, + Mode: mode, + Role: input.Role, + AllConversations: input.AllConversations, + Since: since, + Before: input.Before, + Limit: limit, + } + + result := &GrepResult{ + Success: true, + Summaries: make([]GrepSummaryResult, 0), + Messages: make([]GrepMessageResult, 0), + TotalSummaries: 0, + TotalMessages: 0, + } + + // Determine scope + scope := input.Scope + if scope == "" { + scope = "both" + } + + // Search summaries if requested + if scope == "both" || scope == "summary" { + sumResults, err := r.store.SearchSummaries(ctx, searchInput) + if err != nil { + return nil, fmt.Errorf("search summaries: %w", err) + } + for _, sr := range sumResults { + if sr.SummaryID != "" { + result.Summaries = append(result.Summaries, GrepSummaryResult{ + ID: sr.SummaryID, + Content: sr.Content, + Depth: sr.Depth, + Kind: sr.Kind, + ConversationID: sr.ConversationID, + Rank: sr.Rank, + }) + } + } + if len(sumResults) > 0 { + result.TotalSummaries = sumResults[0].TotalCount + } + } + + // Search messages if requested + if scope == "both" || scope == "message" { + msgResults, err := r.store.SearchMessages(ctx, searchInput) + if err != nil { + return nil, fmt.Errorf("search messages: %w", err) + } + for _, sr := range msgResults { + if sr.MessageID > 0 { + result.Messages = append(result.Messages, GrepMessageResult{ + ID: sr.MessageID, + Snippet: sr.Snippet, + Role: sr.Role, + ConversationID: sr.ConversationID, + Rank: sr.Rank, + }) + } + } + if len(msgResults) > 0 { + result.TotalMessages = msgResults[0].TotalCount + } + } + + // Add hint if no results + if len(result.Summaries) == 0 && len(result.Messages) == 0 { + result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true" + } + + return result, nil +} + +// ExpandMessages retrieves full message content by IDs. +func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) { + result := &ExpandMessagesResult{ + Messages: make([]Message, 0, len(messageIDs)), + } + + for _, msgID := range messageIDs { + msg, err := r.store.GetMessageByID(ctx, msgID) + if err != nil { + continue + } + result.Messages = append(result.Messages, *msg) + result.TokenCount += msg.TokenCount + } + + return result, nil +} diff --git a/pkg/seahorse/short_retrieval_test.go b/pkg/seahorse/short_retrieval_test.go new file mode 100644 index 000000000..9d9bc3640 --- /dev/null +++ b/pkg/seahorse/short_retrieval_test.go @@ -0,0 +1,362 @@ +package seahorse + +import ( + "context" + "fmt" + "testing" + "time" +) + +// --- Retrieval Tests --- + +func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) { + t.Helper() + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval") + return &RetrievalEngine{store: s}, s, conv.ConversationID +} + +func TestRetrievalGrepSummaries(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "数据库连接配置说明", + TokenCount: 50, + }) + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "API endpoint documentation", + TokenCount: 50, + }) + + // FTS5 search (trigram, needs >= 3 chars) + results, err := r.Grep(ctx, GrepInput{ + Pattern: "数据库连", + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(results.Summaries) == 0 { + t.Error("expected at least 1 FTS result") + } + + // LIKE search with wildcard + results, err = r.Grep(ctx, GrepInput{ + Pattern: "%endpoint%", + }) + if err != nil { + t.Fatalf("Grep LIKE: %v", err) + } + if len(results.Summaries) == 0 { + t.Error("expected at least 1 LIKE result") + } +} + +func TestRetrievalGrepMessages(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + s.AddMessage(ctx, convID, "user", "find this message about testing", 5) + s.AddMessage(ctx, convID, "user", "unrelated content here", 5) + + results, err := r.Grep(ctx, GrepInput{ + Pattern: "testing", + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(results.Messages) == 0 { + t.Error("expected at least 1 result for 'testing'") + } +} + +func TestRetrievalExpandMessages(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10) + + result, err := r.ExpandMessages(ctx, []int64{msg.ID}) + if err != nil { + t.Fatalf("ExpandMessages: %v", err) + } + if len(result.Messages) != 1 { + t.Errorf("Messages = %d, want 1", len(result.Messages)) + } + if result.Messages[0].Content != "expand this message" { + t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content) + } +} + +func TestRetrievalExpandMultipleMessages(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10) + msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10) + msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10) + + result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID}) + if err != nil { + t.Fatalf("ExpandMessages: %v", err) + } + if len(result.Messages) != 3 { + t.Errorf("Messages = %d, want 3", len(result.Messages)) + } + if result.TokenCount != 30 { + t.Errorf("TokenCount = %d, want 30", result.TokenCount) + } +} + +func TestRetrievalGrepWithTimeFilter(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + now := time.Now().UTC() + before := now.Add(-2 * time.Hour) + + // Create messages at different times + s.AddMessage(ctx, convID, "user", "old message about auth", 5) + s.AddMessage(ctx, convID, "user", "recent message about auth", 5) + + // Search with time filter + results, err := r.Grep(ctx, GrepInput{ + Pattern: "auth", + Since: &before, + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + _ = results // Just verify no error +} + +func TestRetrievalGrepAllConversations(t *testing.T) { + r, s, _ := newTestRetrieval(t) + ctx := context.Background() + + // Create another conversation + conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2") + + // Add messages to both + s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5) + + // Search all conversations + results, err := r.Grep(ctx, GrepInput{ + Pattern: "xyz", + AllConversations: true, + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(results.Messages) == 0 { + t.Error("expected to find message in other conversation") + } +} + +// --- Last Duration Parsing Tests --- + +func TestParseLastDuration(t *testing.T) { + tests := []struct { + input string + wantDur time.Duration + wantErr bool + }{ + {"6h", 6 * time.Hour, false}, + {"1d", 24 * time.Hour, false}, + {"7d", 7 * 24 * time.Hour, false}, + {"2w", 14 * 24 * time.Hour, false}, + {"1m", 30 * 24 * time.Hour, false}, // month = 30 days + {"3m", 90 * 24 * time.Hour, false}, + {"", 0, true}, + {"invalid", 0, true}, + {"5x", 0, true}, // unknown unit + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got, err := ParseLastDuration(tt.input) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tt.wantDur { + t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur) + } + } + }) + } +} + +// --- Role Filter Tests --- + +func TestRetrievalGrepRoleFilter(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + s.AddMessage(ctx, convID, "user", "user message about alpha", 5) + s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5) + s.AddMessage(ctx, convID, "user", "another user message", 5) + + // Search all roles + allResults, err := r.Grep(ctx, GrepInput{ + Pattern: "alpha", + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(allResults.Messages) != 2 { + t.Errorf("expected 2 messages, got %d", len(allResults.Messages)) + } + + // Search user only + userResults, err := r.Grep(ctx, GrepInput{ + Pattern: "alpha", + Role: "user", + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(userResults.Messages) != 1 { + t.Errorf("expected 1 user message, got %d", len(userResults.Messages)) + } + if userResults.Messages[0].Role != "user" { + t.Errorf("expected role=user, got %s", userResults.Messages[0].Role) + } + + // Search assistant only + assistantResults, err := r.Grep(ctx, GrepInput{ + Pattern: "alpha", + Role: "assistant", + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(assistantResults.Messages) != 1 { + t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages)) + } +} + +// --- Last Parameter Tests --- + +func TestRetrievalGrepWithLast(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + // Add messages (we can't control timestamps in SQLite easily, + // but we can verify the parameter is parsed correctly) + s.AddMessage(ctx, convID, "user", "recent message about testing", 5) + + // Test that Last parameter is converted to Since + results, err := r.Grep(ctx, GrepInput{ + Pattern: "testing", + Last: "1d", // last 1 day + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + // Should still find the message since it's recent + if len(results.Messages) == 0 { + t.Error("expected to find recent message") + } +} + +// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when +// searching both summaries and messages (summaries don't have role column). +func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + // Create a summary (no role column) + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "summary about testing", + TokenCount: 50, + }) + + // Add messages with different roles + s.AddMessage(ctx, convID, "user", "user message about testing", 5) + s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5) + + // Search with role filter and scope=both (default), using LIKE mode (%) + // This should NOT error even though summaries don't have role column + bothResults, err := r.Grep(ctx, GrepInput{ + Pattern: "%testing%", // LIKE mode to trigger the bug + Role: "user", + Scope: "both", + }) + if err != nil { + t.Fatalf("Grep with role and scope=both: %v", err) + } + + // Should only return user messages, not summaries or assistant messages + if len(bothResults.Messages) != 1 { + t.Errorf("expected 1 user message, got %d", len(bothResults.Messages)) + } + if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" { + t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role) + } + + // Summaries should be empty since they don't have roles to filter + // (or we could return all summaries - either is acceptable) +} + +// TestRetrievalGrepTotalCounts tests that grep returns total counts. +func TestRetrievalGrepTotalCounts(t *testing.T) { + r, s, convID := newTestRetrieval(t) + ctx := context.Background() + + // Create 3 summaries + for i := 0; i < 3; i++ { + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: convID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("summary about testing %d", i), + TokenCount: 50, + }) + } + + // Add 5 messages + for i := 0; i < 5; i++ { + s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5) + } + + // Search with limit smaller than total + results, err := r.Grep(ctx, GrepInput{ + Pattern: "%testing%", // LIKE mode + Scope: "both", + Limit: 2, + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + + // Should return limited results + if len(results.Summaries) > 2 { + t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries)) + } + if len(results.Messages) > 2 { + t.Errorf("expected at most 2 messages, got %d", len(results.Messages)) + } + + // But total counts should reflect all matches + if results.TotalSummaries != 3 { + t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries) + } + if results.TotalMessages != 5 { + t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages) + } +} diff --git a/pkg/seahorse/store.go b/pkg/seahorse/store.go new file mode 100644 index 000000000..3d85c7b9c --- /dev/null +++ b/pkg/seahorse/store.go @@ -0,0 +1,1532 @@ +package seahorse + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" +) + +// Store provides SQLite storage for seahorse. +type Store struct { + db *sql.DB +} + +// CreateSummaryInput holds parameters for creating a summary. +type CreateSummaryInput struct { + ConversationID int64 + Kind SummaryKind + Depth int + Content string + TokenCount int + EarliestAt *time.Time + LatestAt *time.Time + DescendantCount int + DescendantTokenCount int + SourceMessageTokens int + Model string + ParentIDs []string // For condensed: child summary IDs being condensed +} + +// --- Conversation Operations --- + +// GetOrCreateConversation returns the conversation for a sessionKey, creating if needed. +func (s *Store) GetOrCreateConversation(ctx context.Context, sessionKey string) (*Conversation, error) { + // Try to get first + conv, err := s.GetConversationBySessionKey(ctx, sessionKey) + if err != nil { + return nil, err + } + if conv != nil { + return conv, nil + } + + // Create + result, err := s.db.ExecContext(ctx, + "INSERT INTO conversations (session_key) VALUES (?)", + sessionKey, + ) + if err != nil { + // Race: another goroutine may have inserted + if isUniqueViolation(err) { + return s.GetConversationBySessionKey(ctx, sessionKey) + } + return nil, fmt.Errorf("create conversation: %w", err) + } + id, _ := result.LastInsertId() + return &Conversation{ + ConversationID: id, + SessionKey: sessionKey, + }, nil +} + +// GetConversationBySessionKey retrieves a conversation by session key. +func (s *Store) GetConversationBySessionKey(ctx context.Context, sessionKey string) (*Conversation, error) { + var conv Conversation + var createdAt, updatedAt string + err := s.db.QueryRowContext(ctx, + "SELECT conversation_id, session_key, created_at, updated_at FROM conversations WHERE session_key = ?", + sessionKey, + ).Scan(&conv.ConversationID, &conv.SessionKey, &createdAt, &updatedAt) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("get conversation by session key: %w", err) + } + conv.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + conv.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + return &conv, nil +} + +// GetSessionStatus returns status for a specific session. +func (s *Store) GetSessionStatus(ctx context.Context, sessionKey string) (*SessionStatus, error) { + conv, err := s.GetConversationBySessionKey(ctx, sessionKey) + if err != nil { + return nil, err + } + if conv == nil { + return nil, nil + } + + msgCount, _ := s.GetMessageCount(ctx, conv.ConversationID) + sumCount, _ := s.getSummaryCount(ctx, conv.ConversationID) + tokenCount, _ := s.GetContextTokenCount(ctx, conv.ConversationID) + + oldest, newest, _ := s.getMessageTimeRange(ctx, conv.ConversationID) + + return &SessionStatus{ + SessionKey: conv.SessionKey, + ConversationID: conv.ConversationID, + Messages: msgCount, + TotalTokens: tokenCount, + Summaries: sumCount, + OldestAt: oldest, + NewestAt: newest, + }, nil +} + +// GetAllSessionStatuses returns status for all sessions. +func (s *Store) GetAllSessionStatuses(ctx context.Context) ([]SessionStatus, error) { + rows, err := s.db.QueryContext(ctx, "SELECT session_key FROM conversations") + if err != nil { + return nil, fmt.Errorf("list sessions: %w", err) + } + defer rows.Close() + + var statuses []SessionStatus + for rows.Next() { + var sessionKey string + if err := rows.Scan(&sessionKey); err != nil { + continue + } + status, err := s.GetSessionStatus(ctx, sessionKey) + if err != nil { + continue + } + if status != nil { + statuses = append(statuses, *status) + } + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate sessions: %w", err) + } + return statuses, nil +} + +func (s *Store) getSummaryCount(ctx context.Context, convID int64) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, + "SELECT COUNT(*) FROM summaries WHERE conversation_id = ?", + convID, + ).Scan(&count) + return count, err +} + +func (s *Store) getMessageTimeRange(ctx context.Context, convID int64) (time.Time, time.Time, error) { + var minTime, maxTime string + err := s.db.QueryRowContext(ctx, + "SELECT MIN(created_at), MAX(created_at) FROM messages WHERE conversation_id = ?", + convID, + ).Scan(&minTime, &maxTime) + if err != nil || minTime == "" { + return time.Time{}, time.Time{}, err + } + oldest, _ := time.Parse("2006-01-02 15:04:05", minTime) + newest, _ := time.Parse("2006-01-02 15:04:05", maxTime) + return oldest, newest, nil +} + +// --- Message Operations --- + +// AddMessage appends a message to a conversation. +func (s *Store) AddMessage(ctx context.Context, convID int64, role, content string, tokenCount int) (*Message, error) { + result, err := s.db.ExecContext(ctx, + "INSERT INTO messages (conversation_id, role, content, token_count) VALUES (?, ?, ?, ?)", + convID, role, content, tokenCount, + ) + if err != nil { + return nil, fmt.Errorf("add message: %w", err) + } + id, _ := result.LastInsertId() + return &Message{ + ID: id, + ConversationID: convID, + Role: role, + Content: content, + TokenCount: tokenCount, + }, nil +} + +// partsToReadableContent derives a readable text summary from message parts. +// This ensures FTS5 indexing and summary formatting can access tool call information. +func partsToReadableContent(parts []MessagePart) string { + var b strings.Builder + for i, p := range parts { + if i > 0 { + b.WriteString("\n") + } + switch p.Type { + case "text": + b.WriteString(p.Text) + case "tool_use": + fmt.Fprintf(&b, "[tool_use: %s, args: %s]", p.Name, p.Arguments) + case "tool_result": + fmt.Fprintf(&b, "[tool_result for %s: %s]", p.ToolCallID, p.Text) + case "media": + fmt.Fprintf(&b, "[media: %s (%s)]", p.MediaURI, p.MimeType) + default: + if p.Text != "" { + b.WriteString(p.Text) + } + } + } + return b.String() +} + +// AddMessageWithParts adds a message with structured parts. +func (s *Store) AddMessageWithParts( + ctx context.Context, + convID int64, + role string, + parts []MessagePart, + tokenCount int, +) (*Message, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("begin tx: %w", err) + } + defer tx.Rollback() + + // Derive readable content from Parts for FTS5 indexing and summary formatting + readableContent := partsToReadableContent(parts) + + result, err := tx.ExecContext(ctx, + "INSERT INTO messages (conversation_id, role, content, token_count) VALUES (?, ?, ?, ?)", + convID, role, readableContent, tokenCount, + ) + if err != nil { + return nil, fmt.Errorf("add message: %w", err) + } + msgID, _ := result.LastInsertId() + + for i, p := range parts { + _, err = tx.ExecContext( + ctx, + `INSERT INTO message_parts (message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type, ordinal) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + msgID, + p.Type, + p.Text, + p.Name, + p.Arguments, + p.ToolCallID, + p.MediaURI, + p.MimeType, + i, + ) + if err != nil { + return nil, fmt.Errorf("add message part %d: %w", i, err) + } + } + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit: %w", err) + } + + // Return message with parts + msg := &Message{ + ID: msgID, + ConversationID: convID, + Role: role, + TokenCount: tokenCount, + Parts: make([]MessagePart, len(parts)), + } + for i, p := range parts { + p.MessageID = msgID + msg.Parts[i] = p + } + return msg, nil +} + +// GetMessages retrieves messages for a conversation. +func (s *Store) GetMessages(ctx context.Context, convID int64, limit int, beforeID int64) ([]Message, error) { + query := "SELECT message_id, conversation_id, role, content, token_count, created_at FROM messages WHERE conversation_id = ?" + args := []any{convID} + if beforeID > 0 { + query += " AND message_id < ?" + args = append(args, beforeID) + } + query += " ORDER BY message_id ASC" + if limit > 0 { + query += " LIMIT ?" + args = append(args, limit) + } + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("get messages: %w", err) + } + defer rows.Close() + + var msgs []Message + for rows.Next() { + var msg Message + var createdAt string + if err := rows.Scan( + &msg.ID, + &msg.ConversationID, + &msg.Role, + &msg.Content, + &msg.TokenCount, + &createdAt, + ); err != nil { + return nil, err + } + msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + msgs = append(msgs, msg) + } + if err := rows.Err(); err != nil { + return nil, err + } + + // Load parts for all messages + for i := range msgs { + parts, err := s.loadMessageParts(ctx, msgs[i].ID) + if err != nil { + return nil, err + } + msgs[i].Parts = parts + } + + return msgs, nil +} + +// GetMessageCount returns total message count for a conversation. +func (s *Store) GetMessageCount(ctx context.Context, convID int64) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, + "SELECT count(*) FROM messages WHERE conversation_id = ?", convID, + ).Scan(&count) + return count, err +} + +// GetMessageByID retrieves a single message by ID. +func (s *Store) GetMessageByID(ctx context.Context, messageID int64) (*Message, error) { + var msg Message + var createdAt string + err := s.db.QueryRowContext(ctx, + "SELECT message_id, conversation_id, role, content, token_count, created_at FROM messages WHERE message_id = ?", + messageID, + ).Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &msg.TokenCount, &createdAt) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("message %d not found", messageID) + } + if err != nil { + return nil, err + } + msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + msg.Parts, _ = s.loadMessageParts(ctx, msg.ID) + return &msg, nil +} + +func (s *Store) loadMessageParts(ctx context.Context, msgID int64) ([]MessagePart, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT part_id, message_id, type, text, name, arguments, tool_call_id, media_uri, mime_type + FROM message_parts WHERE message_id = ? ORDER BY ordinal`, + msgID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var parts []MessagePart + for rows.Next() { + var p MessagePart + if err := rows.Scan(&p.ID, &p.MessageID, &p.Type, &p.Text, &p.Name, &p.Arguments, + &p.ToolCallID, &p.MediaURI, &p.MimeType); err != nil { + return nil, err + } + parts = append(parts, p) + } + if err := rows.Err(); err != nil { + return nil, err + } + return parts, nil +} + +// --- Summary Operations --- + +// CreateSummary creates a new summary and indexes it in FTS5. +func (s *Store) CreateSummary(ctx context.Context, input CreateSummaryInput) (*Summary, error) { + // Generate summary ID + now := time.Now().UTC() + summaryID := generateSummaryID(input.Content, now) + + var earliestAt, latestAt sql.NullString + if input.EarliestAt != nil { + earliestAt = sql.NullString{String: input.EarliestAt.Format(time.RFC3339), Valid: true} + } + if input.LatestAt != nil { + latestAt = sql.NullString{String: input.LatestAt.Format(time.RFC3339), Valid: true} + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("begin tx: %w", err) + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, + earliest_at, latest_at, descendant_count, descendant_token_count, + source_message_token_count, model) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + summaryID, input.ConversationID, string(input.Kind), input.Depth, + input.Content, input.TokenCount, + earliestAt, latestAt, + input.DescendantCount, input.DescendantTokenCount, + input.SourceMessageTokens, input.Model, + ) + if err != nil { + return nil, fmt.Errorf("insert summary: %w", err) + } + + // FTS trigger will fire automatically for summaries table insert + + // Link parent summaries (DAG edges) for condensed summaries + for _, parentID := range input.ParentIDs { + _, err = tx.ExecContext(ctx, + "INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES (?, ?)", + summaryID, parentID, + ) + if err != nil { + return nil, fmt.Errorf("link parent %s: %w", parentID, err) + } + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit: %w", err) + } + + return &Summary{ + SummaryID: summaryID, + ConversationID: input.ConversationID, + Kind: input.Kind, + Depth: input.Depth, + Content: input.Content, + TokenCount: input.TokenCount, + EarliestAt: input.EarliestAt, + LatestAt: input.LatestAt, + DescendantCount: input.DescendantCount, + DescendantTokenCount: input.DescendantTokenCount, + SourceMessageTokenCount: input.SourceMessageTokens, + Model: input.Model, + CreatedAt: now, + }, nil +} + +// GetSummary retrieves a summary by ID. +func (s *Store) GetSummary(ctx context.Context, summaryID string) (*Summary, error) { + return s.scanSummary(ctx, "WHERE summary_id = ?", summaryID) +} + +// GetSummariesByConversation retrieves all summaries for a conversation. +func (s *Store) GetSummariesByConversation(ctx context.Context, convID int64) ([]Summary, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT summary_id, conversation_id, kind, depth, content, token_count, + earliest_at, latest_at, descendant_count, descendant_token_count, + source_message_token_count, model, created_at + FROM summaries WHERE conversation_id = ? ORDER BY created_at`, + convID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + return s.scanSummaries(rows) +} + +// GetSummaryChildren retrieves child summary IDs (summaries that list this summary as parent). +func (s *Store) GetSummaryChildren(ctx context.Context, summaryID string) ([]string, error) { + rows, err := s.db.QueryContext(ctx, + "SELECT summary_id FROM summary_parents WHERE parent_summary_id = ?", + summaryID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var ids []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + return nil, err + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, err + } + return ids, nil +} + +// GetSummaryParents retrieves parent summaries (full objects) for a summary. +func (s *Store) GetSummaryParents(ctx context.Context, summaryID string) ([]Summary, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT s.summary_id, s.conversation_id, s.kind, s.depth, s.content, s.token_count, + s.earliest_at, s.latest_at, s.descendant_count, s.descendant_token_count, + s.source_message_token_count, s.model, s.created_at + FROM summary_parents sp + JOIN summaries s ON s.summary_id = sp.parent_summary_id + WHERE sp.summary_id = ?`, + summaryID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + return s.scanSummaries(rows) +} + +// LinkSummaryToMessages links a leaf summary to its source messages. +func (s *Store) LinkSummaryToMessages(ctx context.Context, summaryID string, messageIDs []int64) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + for i, msgID := range messageIDs { + _, err = tx.ExecContext(ctx, + "INSERT OR IGNORE INTO summary_messages (summary_id, message_id, ordinal) VALUES (?, ?, ?)", + summaryID, msgID, i, + ) + if err != nil { + return err + } + } + return tx.Commit() +} + +// GetSummarySourceMessages retrieves source messages for a summary. +func (s *Store) GetSummarySourceMessages(ctx context.Context, summaryID string) ([]Message, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT m.message_id, m.conversation_id, m.role, m.content, m.token_count, m.created_at + FROM summary_messages sm + JOIN messages m ON m.message_id = sm.message_id + WHERE sm.summary_id = ? + ORDER BY sm.ordinal`, + summaryID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var msgs []Message + for rows.Next() { + var msg Message + var createdAt string + if err := rows.Scan( + &msg.ID, + &msg.ConversationID, + &msg.Role, + &msg.Content, + &msg.TokenCount, + &createdAt, + ); err != nil { + return nil, err + } + msg.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + msgs = append(msgs, msg) + } + if err := rows.Err(); err != nil { + return nil, err + } + return msgs, nil +} + +// GetRootSummaries retrieves root summaries (not children of any other summary). +func (s *Store) GetRootSummaries(ctx context.Context, convID int64) ([]Summary, error) { + rows, err := s.db.QueryContext(ctx, + `SELECT s.summary_id, s.conversation_id, s.kind, s.depth, s.content, s.token_count, + s.earliest_at, s.latest_at, s.descendant_count, s.descendant_token_count, + s.source_message_token_count, s.model, s.created_at + FROM summaries s + WHERE s.conversation_id = ? + AND s.summary_id NOT IN (SELECT sp.parent_summary_id FROM summary_parents sp) + ORDER BY s.created_at`, + convID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + return s.scanSummaries(rows) +} + +// --- Context Item Operations --- + +// GetContextItems retrieves context items for a conversation, ordered by ordinal. +func (s *Store) GetContextItems(ctx context.Context, convID int64) ([]ContextItem, error) { + rows, err := s.db.QueryContext( + ctx, + "SELECT ordinal, item_type, summary_id, message_id, token_count, created_at FROM context_items WHERE conversation_id = ? ORDER BY ordinal", + convID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + + var items []ContextItem + for rows.Next() { + var item ContextItem + var summaryID sql.NullString + var messageID sql.NullInt64 + var createdAt sql.NullString + if err := rows.Scan( + &item.Ordinal, + &item.ItemType, + &summaryID, + &messageID, + &item.TokenCount, + &createdAt, + ); err != nil { + return nil, err + } + item.ConversationID = convID + if summaryID.Valid { + item.SummaryID = summaryID.String + } + if messageID.Valid { + item.MessageID = messageID.Int64 + } + if createdAt.Valid { + t, _ := time.Parse("2006-01-02 15:04:05", createdAt.String) + item.CreatedAt = t + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +// UpsertContextItems replaces all context items for a conversation. +func (s *Store) UpsertContextItems(ctx context.Context, convID int64, items []ContextItem) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, "DELETE FROM context_items WHERE conversation_id = ?", convID) + if err != nil { + return err + } + + for _, item := range items { + _, err = tx.ExecContext(ctx, + `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, message_id, token_count) + VALUES (?, ?, ?, ?, ?, ?)`, + convID, item.Ordinal, item.ItemType, + nullString(item.SummaryID), nullInt64(item.MessageID), + item.TokenCount, + ) + if err != nil { + return err + } + } + return tx.Commit() +} + +// ClearContextItems removes all context items for a conversation. +func (s *Store) ClearContextItems(ctx context.Context, convID int64) error { + _, err := s.db.ExecContext(ctx, "DELETE FROM context_items WHERE conversation_id = ?", convID) + return err +} + +// DeleteMessagesAfterID deletes all messages with ID > afterID for a conversation. +// Also clears related context_items, message_parts, summary_messages, and FTS entries. +// Uses transaction to ensure atomicity of the delete cascade. +func (s *Store) DeleteMessagesAfterID(ctx context.Context, convID int64, afterID int64) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + // Get message IDs to delete for cleaning up related tables + rows, err := tx.QueryContext(ctx, + "SELECT message_id FROM messages WHERE conversation_id = ? AND message_id > ?", convID, afterID) + if err != nil { + return err + } + defer rows.Close() + + var msgIDs []int64 + for rows.Next() { + var id int64 + if scanErr := rows.Scan(&id); scanErr != nil { + return scanErr + } + msgIDs = append(msgIDs, id) + } + if rows.Err() != nil { + return rows.Err() + } + + // Delete context_items referencing these messages + for _, msgID := range msgIDs { + if _, err := tx.ExecContext(ctx, "DELETE FROM context_items WHERE message_id = ?", msgID); err != nil { + return err + } + } + + // Delete from message_parts and summary_messages + // Note: messages_fts is handled automatically by trigger, no manual delete needed + for _, msgID := range msgIDs { + if _, err := tx.ExecContext(ctx, "DELETE FROM message_parts WHERE message_id = ?", msgID); err != nil { + return err + } + if _, err := tx.ExecContext(ctx, "DELETE FROM summary_messages WHERE message_id = ?", msgID); err != nil { + return err + } + } + + // Delete messages + if _, err := tx.ExecContext(ctx, + "DELETE FROM messages WHERE conversation_id = ? AND message_id > ?", convID, afterID); err != nil { + return err + } + + return tx.Commit() +} + +// AppendContextMessage appends a single message to context_items at next ordinal. +func (s *Store) AppendContextMessage(ctx context.Context, convID int64, messageID int64) error { + return s.appendContextItems(ctx, convID, []ContextItem{ + {ItemType: "message", MessageID: messageID}, + }) +} + +// AppendContextMessages bulk-appends messages to context_items. +func (s *Store) AppendContextMessages(ctx context.Context, convID int64, messageIDs []int64) error { + items := make([]ContextItem, len(messageIDs)) + for i, id := range messageIDs { + items[i] = ContextItem{ItemType: "message", MessageID: id} + } + return s.appendContextItems(ctx, convID, items) +} + +// AppendContextSummary appends a summary to context_items at next ordinal. +func (s *Store) AppendContextSummary(ctx context.Context, convID int64, summaryID string) error { + return s.appendContextItems(ctx, convID, []ContextItem{ + {ItemType: "summary", SummaryID: summaryID}, + }) +} + +func (s *Store) appendContextItems(ctx context.Context, convID int64, items []ContextItem) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + maxOrd, err := s.GetMaxOrdinalTx(ctx, tx, convID) + if err != nil { + return err + } + + ordinal := maxOrd + OrdinalStep + for _, item := range items { + item.ConversationID = convID + item.Ordinal = ordinal + + // Resolve token count if not set + tokenCount := item.TokenCount + if tokenCount == 0 { + tokenCount = s.resolveItemTokenCountTx(ctx, tx, item) + } + + _, err = tx.ExecContext(ctx, + `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, message_id, token_count) + VALUES (?, ?, ?, ?, ?, ?)`, + convID, ordinal, item.ItemType, + nullString(item.SummaryID), nullInt64(item.MessageID), + tokenCount, + ) + if err != nil { + return err + } + ordinal += OrdinalStep + } + return tx.Commit() +} + +// resolveItemTokenCountTx looks up token count within a transaction. +func (s *Store) resolveItemTokenCountTx(ctx context.Context, tx *sql.Tx, item ContextItem) int { + if item.ItemType == "message" && item.MessageID > 0 { + var tc int + err := tx.QueryRowContext(ctx, + "SELECT token_count FROM messages WHERE message_id = ?", item.MessageID, + ).Scan(&tc) + if err == nil { + return tc + } + } + if item.ItemType == "summary" && item.SummaryID != "" { + var tc int + err := tx.QueryRowContext(ctx, + "SELECT token_count FROM summaries WHERE summary_id = ?", item.SummaryID, + ).Scan(&tc) + if err == nil { + return tc + } + } + return 0 +} + +// ReplaceContextRangeWithSummary atomically replaces a range of context items with a summary. +// If ordinal gap is exhausted, triggers resequencing (spec lines 1204-1209). +func (s *Store) ReplaceContextRangeWithSummary( + ctx context.Context, + convID int64, + startOrdinal, endOrdinal int, + summaryID string, +) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + // Delete the range + _, err = tx.ExecContext(ctx, + "DELETE FROM context_items WHERE conversation_id = ? AND ordinal >= ? AND ordinal <= ?", + convID, startOrdinal, endOrdinal, + ) + if err != nil { + return err + } + + // Insert summary at midpoint of replaced range + midpoint := (startOrdinal + endOrdinal) / 2 + + // Check if midpoint conflicts with existing ordinal + var conflict bool + var existingOrd int + err = tx.QueryRowContext(ctx, + "SELECT ordinal FROM context_items WHERE conversation_id = ? AND ordinal = ?", + convID, midpoint, + ).Scan(&existingOrd) + if err == nil { + conflict = true + } + + if conflict { + // Gap exhausted, need resequence (spec lines 1204-1209) + err = s.resequenceContextItemsTx(ctx, tx, convID, summaryID) + if err != nil { + return fmt.Errorf("resequence: %w", err) + } + } else { + // Normal insert at midpoint with token_count from summary + _, err = tx.ExecContext(ctx, + `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, token_count) + SELECT ?, ?, 'summary', ?, token_count FROM summaries WHERE summary_id = ?`, + convID, midpoint, summaryID, summaryID, + ) + if err != nil { + return err + } + } + + return tx.Commit() +} + +// ReplaceContextItemsWithSummary replaces specific context items (by summary_id) with a new summary. +// Use this when candidates are not contiguous in ordinal space to avoid deleting non-candidate items. +func (s *Store) ReplaceContextItemsWithSummary( + ctx context.Context, + convID int64, + summaryIDs []string, + newSummaryID string, +) error { + if len(summaryIDs) == 0 { + return nil + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + // Find the ordinals of items to delete and calculate midpoint + placeholders := make([]string, len(summaryIDs)) + args := make([]any, len(summaryIDs)+1) + args[0] = convID + for i, sid := range summaryIDs { + placeholders[i] = "?" + args[i+1] = sid + } + + query := fmt.Sprintf( + "SELECT ordinal FROM context_items WHERE conversation_id = ? AND summary_id IN (%s) ORDER BY ordinal", + strings.Join(placeholders, ","), + ) + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer rows.Close() + + var ordinals []int + for rows.Next() { + var ord int + if scanErr := rows.Scan(&ord); scanErr != nil { + return scanErr + } + ordinals = append(ordinals, ord) + } + if err = rows.Err(); err != nil { + return err + } + + if len(ordinals) == 0 { + return nil + } + + midpoint := (ordinals[0] + ordinals[len(ordinals)-1]) / 2 + + // Delete the specific items by summary_id + deleteQuery := fmt.Sprintf( + "DELETE FROM context_items WHERE conversation_id = ? AND summary_id IN (%s)", + strings.Join(placeholders, ","), + ) + _, err = tx.ExecContext(ctx, deleteQuery, args...) + if err != nil { + return err + } + + // Check if midpoint conflicts with existing ordinal + var conflict bool + var existingOrd int + err = tx.QueryRowContext(ctx, + "SELECT ordinal FROM context_items WHERE conversation_id = ? AND ordinal = ?", + convID, midpoint, + ).Scan(&existingOrd) + if err == nil { + conflict = true + } + + if conflict { + // Gap exhausted, need resequence + err = s.resequenceContextItemsTx(ctx, tx, convID, newSummaryID) + if err != nil { + return fmt.Errorf("resequence: %w", err) + } + } else { + // Normal insert at midpoint + _, err = tx.ExecContext(ctx, + `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, token_count) + SELECT ?, ?, 'summary', ?, token_count FROM summaries WHERE summary_id = ?`, + convID, midpoint, newSummaryID, newSummaryID, + ) + if err != nil { + return err + } + } + + return tx.Commit() +} + +// resequenceContextItemsTx renumbers context_items with fresh OrdinalStep gaps. +// Uses temp negative ordinals to avoid PRIMARY KEY constraint violations (spec lines 1240-1247). +func (s *Store) resequenceContextItemsTx(ctx context.Context, tx *sql.Tx, convID int64, newSummaryID string) error { + // Get all remaining items sorted by current ordinal + rows, err := tx.QueryContext( + ctx, + "SELECT ordinal, item_type, summary_id, message_id, token_count FROM context_items WHERE conversation_id = ? ORDER BY ordinal", + convID, + ) + if err != nil { + return err + } + defer rows.Close() + + type item struct { + ordinal int + itemType string + summaryID string + messageID int64 + tokenCount int + } + var items []item + for rows.Next() { + var i item + var sid sql.NullString + var mid sql.NullInt64 + var scanErr error + if scanErr = rows.Scan(&i.ordinal, &i.itemType, &sid, &mid, &i.tokenCount); scanErr != nil { + return scanErr + } + if sid.Valid { + i.summaryID = sid.String + } + if mid.Valid { + i.messageID = mid.Int64 + } + items = append(items, i) + } + if rowsErr := rows.Err(); rowsErr != nil { + return rowsErr + } + + // Step 1: Move all items to temp negative ordinals + tempOrd := -1 + for _, i := range items { + _, execErr := tx.ExecContext(ctx, + "UPDATE context_items SET ordinal = ? WHERE conversation_id = ? AND ordinal = ?", + tempOrd, convID, i.ordinal, + ) + if execErr != nil { + return execErr + } + tempOrd-- + } + + // Step 2: Insert new summary at the end with positive ordinal + // Include token_count from summaries table + newOrd := (len(items) + 1) * OrdinalStep + _, err = tx.ExecContext(ctx, + `INSERT INTO context_items (conversation_id, ordinal, item_type, summary_id, token_count) + SELECT ?, ?, 'summary', ?, token_count FROM summaries WHERE summary_id = ?`, + convID, newOrd, newSummaryID, newSummaryID, + ) + if err != nil { + return err + } + + // Step 3: Update each temp item to its final positive ordinal + // Use specific temp ordinal matching (not ordinal < 0) to avoid updating all items + finalOrd := OrdinalStep + tempOrd = -1 // Reset to first temp ordinal (already declared in Step 1) + for range items { + _, execErr := tx.ExecContext(ctx, + "UPDATE context_items SET ordinal = ? WHERE conversation_id = ? AND ordinal = ?", + finalOrd, convID, tempOrd, + ) + if execErr != nil { + return execErr + } + finalOrd += OrdinalStep + tempOrd-- + } + + return nil +} + +// GetContextTokenCount returns total token count for all items in context. +func (s *Store) GetContextTokenCount(ctx context.Context, convID int64) (int, error) { + var count int + err := s.db.QueryRowContext(ctx, + "SELECT COALESCE(SUM(token_count), 0) FROM context_items WHERE conversation_id = ?", + convID, + ).Scan(&count) + return count, err +} + +// GetMaxOrdinal returns the highest ordinal in context_items for a conversation. +func (s *Store) GetMaxOrdinal(ctx context.Context, convID int64) (int, error) { + var maxOrd sql.NullInt64 + err := s.db.QueryRowContext(ctx, + "SELECT MAX(ordinal) FROM context_items WHERE conversation_id = ?", + convID, + ).Scan(&maxOrd) + if err != nil { + return 0, err + } + if !maxOrd.Valid { + return 0, nil + } + return int(maxOrd.Int64), nil +} + +// GetMaxOrdinalTx returns the highest ordinal within a transaction. +func (s *Store) GetMaxOrdinalTx(ctx context.Context, tx *sql.Tx, convID int64) (int, error) { + var maxOrd sql.NullInt64 + err := tx.QueryRowContext(ctx, + "SELECT MAX(ordinal) FROM context_items WHERE conversation_id = ?", + convID, + ).Scan(&maxOrd) + if err != nil { + return 0, err + } + if !maxOrd.Valid { + return 0, nil + } + return int(maxOrd.Int64), nil +} + +// GetDistinctDepthsInContext returns distinct depth levels of summaries currently in context. +// maxOrdinalExclusive filters out summaries with ordinal >= this value (0 = no filter). +func (s *Store) GetDistinctDepthsInContext(ctx context.Context, convID int64, maxOrdinalExclusive int) ([]int, error) { + query := `SELECT DISTINCT s.depth + FROM context_items ci + JOIN summaries s ON s.summary_id = ci.summary_id + WHERE ci.conversation_id = ? AND ci.item_type = 'summary'` + args := []any{convID} + + if maxOrdinalExclusive > 0 { + query += " AND ci.ordinal < ?" + args = append(args, maxOrdinalExclusive) + } + + query += " ORDER BY s.depth" + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("get distinct depths: %w", err) + } + defer rows.Close() + + var depths []int + for rows.Next() { + var d int + if err := rows.Scan(&d); err != nil { + return nil, err + } + depths = append(depths, d) + } + if err := rows.Err(); err != nil { + return nil, err + } + return depths, nil +} + +// GetSummarySubtree returns all summaries in the subtree rooted at summaryID, +// including summaryID itself. Uses a recursive CTE to traverse the DAG. +func (s *Store) GetSummarySubtree(ctx context.Context, summaryID string) ([]SummarySubtreeNode, error) { + rows, err := s.db.QueryContext(ctx, ` + WITH RECURSIVE subtree AS ( + SELECT summary_id, 0 AS depth_from_root + FROM summaries + WHERE summary_id = ? + UNION ALL + SELECT sp.parent_summary_id, st.depth_from_root + 1 + FROM summary_parents sp + JOIN subtree st ON sp.summary_id = st.summary_id + ) + SELECT summary_id, depth_from_root FROM subtree`, + summaryID, + ) + if err != nil { + return nil, fmt.Errorf("get summary subtree: %w", err) + } + defer rows.Close() + + var nodes []SummarySubtreeNode + for rows.Next() { + var n SummarySubtreeNode + if err := rows.Scan(&n.SummaryID, &n.DepthFromRoot); err != nil { + return nil, err + } + nodes = append(nodes, n) + } + if err := rows.Err(); err != nil { + return nil, err + } + return nodes, nil +} + +// --- Search Operations --- + +// SearchSummaries performs full-text search on summaries. +func (s *Store) SearchSummaries(ctx context.Context, input SearchInput) ([]SearchResult, error) { + // "like" → LIKE search, anything else (including "full_text" or empty) → FTS5 + if input.Mode == "like" { + return s.searchSummariesLike(ctx, input) + } + return s.searchSummariesFTS(ctx, input) +} + +func (s *Store) searchSummariesFTS(ctx context.Context, input SearchInput) ([]SearchResult, error) { + // Build WHERE clause for filters (used in both count and data queries) + whereClauses := []string{"summaries_fts MATCH ?"} + args := []any{input.Pattern} + + if input.ConversationID > 0 && !input.AllConversations { + whereClauses = append(whereClauses, "s.conversation_id = ?") + args = append(args, input.ConversationID) + } + + if input.Since != nil { + whereClauses = append(whereClauses, "s.created_at >= ?") + args = append(args, input.Since.Format("2006-01-02 15:04:05")) + } + if input.Before != nil { + whereClauses = append(whereClauses, "s.created_at < ?") + args = append(args, input.Before.Format("2006-01-02 15:04:05")) + } + + whereStr := strings.Join(whereClauses, " AND ") + + // First, get total count (bm25 conflicts with window functions in FTS5) + countQuery := `SELECT COUNT(*) FROM summaries_fts fts + JOIN summaries s ON s.summary_id = fts.summary_id + WHERE ` + whereStr + var totalCount int + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + // Then, get actual results with bm25 ranking + dataQuery := `SELECT s.summary_id, s.conversation_id, s.kind, s.content, s.created_at, bm25(summaries_fts) as rank + FROM summaries_fts fts + JOIN summaries s ON s.summary_id = fts.summary_id + WHERE ` + whereStr + ` ORDER BY rank` + + dataArgs := append([]any{}, args...) // copy args + if input.Limit > 0 { + dataQuery += " LIMIT ?" + dataArgs = append(dataArgs, input.Limit) + } + + rows, err := s.db.QueryContext(ctx, dataQuery, dataArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + results, err := s.scanSearchResults(rows, true) + if err != nil { + return nil, err + } + + // Set total count on all results + for i := range results { + results[i].TotalCount = totalCount + } + return results, nil +} + +// buildLikeQuery appends conversation/time filters and limit to a LIKE query. +// Note: role filtering is NOT applied here since summaries don't have role column. +// Use buildMessagesLikeQuery for message searches that need role filtering. +func buildLikeQuery(query string, args []any, input SearchInput) (string, []any) { + if input.ConversationID > 0 && !input.AllConversations { + query += " AND conversation_id = ?" + args = append(args, input.ConversationID) + } + if input.Since != nil { + query += " AND created_at >= ?" + args = append(args, input.Since.Format("2006-01-02 15:04:05")) + } + if input.Before != nil { + query += " AND created_at < ?" + args = append(args, input.Before.Format("2006-01-02 15:04:05")) + } + // Order by newest first for LIKE mode + query += " ORDER BY created_at DESC" + if input.Limit > 0 { + query += " LIMIT ?" + args = append(args, input.Limit) + } + return query, args +} + +// buildMessagesLikeQuery is like buildLikeQuery but adds role filtering for messages. +func buildMessagesLikeQuery(query string, args []any, input SearchInput) (string, []any) { + if input.Role != "" { + query += " AND role = ?" + args = append(args, input.Role) + } + return buildLikeQuery(query, args, input) +} + +func (s *Store) searchSummariesLike(ctx context.Context, input SearchInput) ([]SearchResult, error) { + query := `SELECT summary_id, conversation_id, kind, content, created_at, COUNT(*) OVER() as total_count + FROM summaries WHERE content LIKE ?` + args := []any{"%" + input.Pattern + "%"} + query, args = buildLikeQuery(query, args, input) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return s.scanSearchResults(rows, false) +} + +func (s *Store) scanSearchResults(rows *sql.Rows, withRank bool) ([]SearchResult, error) { + var results []SearchResult + for rows.Next() { + var r SearchResult + var createdAt string + var kind string + if withRank { + // FTS5 mode: no TotalCount in query (set by caller after COUNT) + if err := rows.Scan(&r.SummaryID, &r.ConversationID, &kind, &r.Content, &createdAt, &r.Rank); err != nil { + return nil, err + } + } else { + // LIKE mode: TotalCount from window function + if err := rows.Scan(&r.SummaryID, &r.ConversationID, &kind, + &r.Content, &createdAt, &r.TotalCount); err != nil { + return nil, err + } + } + r.Kind = SummaryKind(kind) + r.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + results = append(results, r) + } + return results, nil +} + +// SearchMessages performs full-text or regex search on messages. +func (s *Store) SearchMessages(ctx context.Context, input SearchInput) ([]SearchResult, error) { + // Try FTS5 first for full-text mode + if input.Mode == "" || input.Mode == "full_text" { + results, err := s.searchMessagesFTS(ctx, input) + if err == nil && len(results) > 0 { + return results, nil + } + // Fall through to LIKE + } + + return s.searchMessagesLike(ctx, input) +} + +func (s *Store) searchMessagesFTS(ctx context.Context, input SearchInput) ([]SearchResult, error) { + // Build WHERE clause for filters (used in both count and data queries) + whereClauses := []string{"messages_fts MATCH ?"} + args := []any{input.Pattern} + + if input.ConversationID > 0 && !input.AllConversations { + whereClauses = append(whereClauses, "m.conversation_id = ?") + args = append(args, input.ConversationID) + } + + if input.Role != "" { + whereClauses = append(whereClauses, "m.role = ?") + args = append(args, input.Role) + } + + if input.Since != nil { + whereClauses = append(whereClauses, "m.created_at >= ?") + args = append(args, input.Since.Format("2006-01-02 15:04:05")) + } + if input.Before != nil { + whereClauses = append(whereClauses, "m.created_at < ?") + args = append(args, input.Before.Format("2006-01-02 15:04:05")) + } + + whereStr := strings.Join(whereClauses, " AND ") + + // First, get total count (bm25 conflicts with window functions in FTS5) + countQuery := `SELECT COUNT(*) FROM messages_fts f + JOIN messages m ON f.message_id = m.message_id + WHERE ` + whereStr + var totalCount int + if err := s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount); err != nil { + return nil, err + } + + // Then, get actual results with bm25 ranking + dataQuery := `SELECT m.message_id, m.conversation_id, m.role, m.content, m.created_at, bm25(messages_fts) as rank + FROM messages_fts f + JOIN messages m ON f.message_id = m.message_id + WHERE ` + whereStr + ` ORDER BY rank` + + dataArgs := append([]any{}, args...) // copy args + if input.Limit > 0 { + dataQuery += " LIMIT ?" + dataArgs = append(dataArgs, input.Limit) + } + + rows, err := s.db.QueryContext(ctx, dataQuery, dataArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + results, err := s.scanMessageSearchResults(rows, true) + if err != nil { + return nil, err + } + + // Set total count on all results + for i := range results { + results[i].TotalCount = totalCount + } + return results, nil +} + +func (s *Store) searchMessagesLike(ctx context.Context, input SearchInput) ([]SearchResult, error) { + query := `SELECT message_id, conversation_id, role, content, created_at, COUNT(*) OVER() as total_count + FROM messages WHERE content LIKE ?` + args := []any{"%" + input.Pattern + "%"} + query, args = buildMessagesLikeQuery(query, args, input) + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + return s.scanMessageSearchResults(rows, false) +} + +func (s *Store) scanMessageSearchResults(rows *sql.Rows, withRank bool) ([]SearchResult, error) { + var results []SearchResult + for rows.Next() { + var r SearchResult + var createdAt string + var content string + if withRank { + // FTS5 mode: no TotalCount in query (set by caller after COUNT) + if err := rows.Scan(&r.MessageID, &r.ConversationID, &r.Role, &content, &createdAt, &r.Rank); err != nil { + return nil, err + } + } else { + // LIKE mode: TotalCount from window function + if err := rows.Scan(&r.MessageID, &r.ConversationID, &r.Role, &content, + &createdAt, &r.TotalCount); err != nil { + return nil, err + } + } + r.Snippet = content + r.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + results = append(results, r) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +// --- Helpers --- + +func (s *Store) scanSummary(ctx context.Context, where string, args ...any) (*Summary, error) { + row := s.db.QueryRowContext(ctx, + `SELECT summary_id, conversation_id, kind, depth, content, token_count, + earliest_at, latest_at, descendant_count, descendant_token_count, + source_message_token_count, model, created_at + FROM summaries `+where, args..., + ) + var sum Summary + var kind, createdAt string + var earliestAt, latestAt sql.NullString + err := row.Scan( + &sum.SummaryID, &sum.ConversationID, &kind, &sum.Depth, &sum.Content, &sum.TokenCount, + &earliestAt, &latestAt, &sum.DescendantCount, &sum.DescendantTokenCount, + &sum.SourceMessageTokenCount, &sum.Model, &createdAt, + ) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("summary not found") + } + if err != nil { + return nil, err + } + sum.Kind = SummaryKind(kind) + sum.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + if earliestAt.Valid { + t, _ := time.Parse(time.RFC3339, earliestAt.String) + sum.EarliestAt = &t + } + if latestAt.Valid { + t, _ := time.Parse(time.RFC3339, latestAt.String) + sum.LatestAt = &t + } + return &sum, nil +} + +func (s *Store) scanSummaries(rows *sql.Rows) ([]Summary, error) { + var summaries []Summary + for rows.Next() { + var sum Summary + var kind, createdAt string + var earliestAt, latestAt sql.NullString + err := rows.Scan( + &sum.SummaryID, &sum.ConversationID, &kind, &sum.Depth, &sum.Content, &sum.TokenCount, + &earliestAt, &latestAt, &sum.DescendantCount, &sum.DescendantTokenCount, + &sum.SourceMessageTokenCount, &sum.Model, &createdAt, + ) + if err != nil { + return nil, err + } + sum.Kind = SummaryKind(kind) + sum.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + if earliestAt.Valid { + t, _ := time.Parse(time.RFC3339, earliestAt.String) + sum.EarliestAt = &t + } + if latestAt.Valid { + t, _ := time.Parse(time.RFC3339, latestAt.String) + sum.LatestAt = &t + } + summaries = append(summaries, sum) + } + if err := rows.Err(); err != nil { + return nil, err + } + return summaries, nil +} + +func generateSummaryID(content string, t time.Time) string { + return fmt.Sprintf("sum_%x", t.UnixNano()) +} + +func isUniqueViolation(err error) bool { + return err != nil && (contains(err.Error(), "UNIQUE constraint failed") || + contains(err.Error(), "constraint failed")) +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && searchSubstring(s, sub) +} + +func searchSubstring(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +} + +func nullString(s string) sql.NullString { + return sql.NullString{String: s, Valid: s != ""} +} + +func nullInt64(n int64) sql.NullInt64 { + return sql.NullInt64{Int64: n, Valid: n != 0} +} diff --git a/pkg/seahorse/store_test.go b/pkg/seahorse/store_test.go new file mode 100644 index 000000000..fd55379c6 --- /dev/null +++ b/pkg/seahorse/store_test.go @@ -0,0 +1,1250 @@ +package seahorse + +import ( + "context" + "fmt" + "testing" + "time" +) + +func openTestStore(t *testing.T) *Store { + t.Helper() + db := openTestDB(t) + if err := runSchema(db); err != nil { + t.Fatalf("migration: %v", err) + } + return &Store{db: db} +} + +// --- Conversation Operations --- + +func TestStoreGetOrCreateConversation(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, err := s.GetOrCreateConversation(ctx, "agent:abc123") + if err != nil { + t.Fatalf("GetOrCreateConversation: %v", err) + } + if conv.ConversationID == 0 { + t.Error("expected non-zero conversation ID") + } + if conv.SessionKey != "agent:abc123" { + t.Errorf("session key = %q, want %q", conv.SessionKey, "agent:abc123") + } + + // Idempotent — same session key returns same conversation + conv2, err := s.GetOrCreateConversation(ctx, "agent:abc123") + if err != nil { + t.Fatalf("GetOrCreateConversation (2nd): %v", err) + } + if conv2.ConversationID != conv.ConversationID { + t.Errorf("idempotent: got ID %d, want %d", conv2.ConversationID, conv.ConversationID) + } + + // Different session key → new conversation + conv3, err := s.GetOrCreateConversation(ctx, "agent:def456") + if err != nil { + t.Fatalf("GetOrCreateConversation (3rd): %v", err) + } + if conv3.ConversationID == conv.ConversationID { + t.Error("different session key should create different conversation") + } +} + +func TestStoreGetConversationBySessionKey(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + // Not found + conv, err := s.GetConversationBySessionKey(ctx, "nonexistent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conv != nil { + t.Error("expected nil for nonexistent session key") + } + + // Create then retrieve + created, err := s.GetOrCreateConversation(ctx, "agent:test") + if err != nil { + t.Fatalf("create: %v", err) + } + found, err := s.GetConversationBySessionKey(ctx, "agent:test") + if err != nil { + t.Fatalf("find: %v", err) + } + if found.ConversationID != created.ConversationID { + t.Errorf("found ID %d, want %d", found.ConversationID, created.ConversationID) + } +} + +// --- Message Operations --- + +func TestStoreAddAndGetMessages(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + msg, err := s.AddMessage(ctx, conv.ConversationID, "user", "hello world", 5) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + if msg.ID == 0 { + t.Error("expected non-zero message ID") + } + if msg.Role != "user" || msg.Content != "hello world" { + t.Errorf("message = %+v, want role=user content=hello world", msg) + } + + // Retrieve + msgs, err := s.GetMessages(ctx, conv.ConversationID, 10, 0) + if err != nil { + t.Fatalf("GetMessages: %v", err) + } + if len(msgs) != 1 { + t.Fatalf("got %d messages, want 1", len(msgs)) + } + if msgs[0].Content != "hello world" { + t.Errorf("content = %q, want %q", msgs[0].Content, "hello world") + } +} + +func TestStoreAddMessageWithParts(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + parts := []MessagePart{ + {Type: "tool_use", Name: "read_file", Arguments: `{"path":"/tmp/test"}`, ToolCallID: "tc_123"}, + {Type: "text", Text: "some output"}, + } + msg, err := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 10) + if err != nil { + t.Fatalf("AddMessageWithParts: %v", err) + } + if msg.ID == 0 { + t.Error("expected non-zero message ID") + } + + // Retrieve and verify parts + msgs, _ := s.GetMessages(ctx, conv.ConversationID, 10, 0) + if len(msgs) != 1 { + t.Fatalf("expected 1 message, got %d", len(msgs)) + } + if len(msgs[0].Parts) != 2 { + t.Fatalf("expected 2 parts, got %d", len(msgs[0].Parts)) + } + if msgs[0].Parts[0].Type != "tool_use" { + t.Errorf("part[0].Type = %q, want tool_use", msgs[0].Parts[0].Type) + } + if msgs[0].Parts[0].ToolCallID != "tc_123" { + t.Errorf("part[0].ToolCallID = %q, want tc_123", msgs[0].Parts[0].ToolCallID) + } +} + +func TestStoreGetMessageCount(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + s.AddMessage(ctx, conv.ConversationID, "user", "msg1", 2) + s.AddMessage(ctx, conv.ConversationID, "assistant", "msg2", 3) + s.AddMessage(ctx, conv.ConversationID, "user", "msg3", 1) + + count, err := s.GetMessageCount(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetMessageCount: %v", err) + } + if count != 3 { + t.Errorf("count = %d, want 3", count) + } +} + +func TestStoreGetMessageByID(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + msg, _ := s.AddMessage(ctx, conv.ConversationID, "user", "find me", 3) + + found, err := s.GetMessageByID(ctx, msg.ID) + if err != nil { + t.Fatalf("GetMessageByID: %v", err) + } + if found.Content != "find me" { + t.Errorf("content = %q, want %q", found.Content, "find me") + } + + // Not found + _, err = s.GetMessageByID(ctx, 99999) + if err == nil { + t.Error("expected error for nonexistent message") + } +} + +// --- Summary Operations --- + +func TestStoreCreateAndGetSummary(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + now := time.Now().UTC().Truncate(time.Second) + summary, err := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "test summary content", + TokenCount: 50, + EarliestAt: &now, + LatestAt: &now, + DescendantCount: 0, + DescendantTokenCount: 0, + SourceMessageTokens: 500, + Model: "test-model", + }) + if err != nil { + t.Fatalf("CreateSummary: %v", err) + } + if summary.SummaryID == "" { + t.Error("expected non-empty summary ID") + } + if summary.Kind != SummaryKindLeaf { + t.Errorf("kind = %q, want leaf", summary.Kind) + } + + // Retrieve by ID + found, err := s.GetSummary(ctx, summary.SummaryID) + if err != nil { + t.Fatalf("GetSummary: %v", err) + } + if found.Content != "test summary content" { + t.Errorf("content = %q, want 'test summary content'", found.Content) + } + if found.SourceMessageTokenCount != 500 { + t.Errorf("source_message_token_count = %d, want 500", found.SourceMessageTokenCount) + } +} + +func TestStoreSummaryDAG(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create leaf summaries + leaf1, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf 1", + TokenCount: 100, + }) + leaf2, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "leaf 2", + TokenCount: 100, + }) + + // Create condensed summary with parents (the children being condensed) + condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindCondensed, + Depth: 1, + Content: "condensed from leaves", + TokenCount: 150, + ParentIDs: []string{leaf1.SummaryID, leaf2.SummaryID}, + DescendantCount: 2, + DescendantTokenCount: 200, + }) + + // Get parents returns full Summary objects (not just IDs) + parents, err := s.GetSummaryParents(ctx, condensed.SummaryID) + if err != nil { + t.Fatalf("GetSummaryParents: %v", err) + } + if len(parents) != 2 { + t.Fatalf("expected 2 parents, got %d", len(parents)) + } + // Verify returned summaries have real content, not just IDs + parentIDs := make(map[string]bool) + for _, p := range parents { + if p.Content == "" { + t.Error("parent summary should have non-empty Content") + } + if p.TokenCount == 0 { + t.Error("parent summary should have non-zero TokenCount") + } + parentIDs[p.SummaryID] = true + } + if !parentIDs[leaf1.SummaryID] || !parentIDs[leaf2.SummaryID] { + t.Errorf("parent IDs = %v, want both %s and %s", parentIDs, leaf1.SummaryID, leaf2.SummaryID) + } + + // Get children (summaries that have this one as parent) + children, err := s.GetSummaryChildren(ctx, condensed.SummaryID) + if err != nil { + t.Fatalf("GetSummaryChildren: %v", err) + } + if len(children) != 0 { + // condensed has no children yet — it's the root + t.Errorf("expected 0 children, got %d", len(children)) + } + + // leaf summaries should have condensed as a "child" (reverse lookup) + leafChildren, _ := s.GetSummaryChildren(ctx, leaf1.SummaryID) + if len(leafChildren) != 1 || leafChildren[0] != condensed.SummaryID { + t.Errorf("leaf1 children = %v, want [%s]", leafChildren, condensed.SummaryID) + } +} + +func TestStoreSummarySourceMessages(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "msg1", 2) + msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "msg2", 3) + + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "summary of msg1 and msg2", + TokenCount: 50, + }) + + err := s.LinkSummaryToMessages(ctx, summary.SummaryID, []int64{msg1.ID, msg2.ID}) + if err != nil { + t.Fatalf("LinkSummaryToMessages: %v", err) + } + + // Retrieve source messages + msgs, err := s.GetSummarySourceMessages(ctx, summary.SummaryID) + if err != nil { + t.Fatalf("GetSummarySourceMessages: %v", err) + } + if len(msgs) != 2 { + t.Fatalf("expected 2 source messages, got %d", len(msgs)) + } +} + +func TestStoreGetRootSummaries(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create 2 leaf summaries + leaf1, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, Content: "l1", TokenCount: 10, + }) + leaf2, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, Content: "l2", TokenCount: 10, + }) + + // Before condensation — both are roots + roots, _ := s.GetRootSummaries(ctx, conv.ConversationID) + if len(roots) != 2 { + t.Errorf("before condensation: expected 2 roots, got %d", len(roots)) + } + + // Condense them + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindCondensed, Depth: 1, + Content: "c1", TokenCount: 15, ParentIDs: []string{leaf1.SummaryID, leaf2.SummaryID}, + }) + + // After condensation — only the condensed is root + roots, _ = s.GetRootSummaries(ctx, conv.ConversationID) + if len(roots) != 1 { + t.Errorf("after condensation: expected 1 root, got %d", len(roots)) + } + if roots[0].Kind != SummaryKindCondensed { + t.Errorf("root kind = %q, want condensed", roots[0].Kind) + } +} + +// --- Context Item Operations --- + +func TestStoreContextItems(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "hello", 2) + msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "world", 2) + + // Upsert items + items := []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 2}, + {Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 2}, + } + err := s.UpsertContextItems(ctx, conv.ConversationID, items) + if err != nil { + t.Fatalf("UpsertContextItems: %v", err) + } + + // Retrieve + retrieved, err := s.GetContextItems(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextItems: %v", err) + } + if len(retrieved) != 2 { + t.Fatalf("expected 2 items, got %d", len(retrieved)) + } + if retrieved[0].Ordinal != 100 || retrieved[1].Ordinal != 200 { + t.Errorf("ordinals = %v, want [100 200]", []int{retrieved[0].Ordinal, retrieved[1].Ordinal}) + } + // CreatedAt should be populated + if retrieved[0].CreatedAt.IsZero() { + t.Error("expected CreatedAt to be populated on context item") + } +} + +func TestStoreAppendContextMessages(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "hello", 2) + msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "world", 2) + + s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 2}, + }) + + // Append single message + err := s.AppendContextMessage(ctx, conv.ConversationID, msg2.ID) + if err != nil { + t.Fatalf("AppendContextMessage: %v", err) + } + + items, _ := s.GetContextItems(ctx, conv.ConversationID) + if len(items) != 2 { + t.Fatalf("expected 2 items after append, got %d", len(items)) + } + if items[1].MessageID != msg2.ID { + t.Errorf("appended message ID = %d, want %d", items[1].MessageID, msg2.ID) + } +} + +func TestStoreReplaceContextRangeWithSummary(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create messages and context items + msgs := make([]int64, 4) + for i := 0; i < 4; i++ { + m, _ := s.AddMessage(ctx, conv.ConversationID, "user", "msg", 2) + msgs[i] = m.ID + } + + items := []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msgs[0], TokenCount: 2}, + {Ordinal: 200, ItemType: "message", MessageID: msgs[1], TokenCount: 2}, + {Ordinal: 300, ItemType: "message", MessageID: msgs[2], TokenCount: 2}, + {Ordinal: 400, ItemType: "message", MessageID: msgs[3], TokenCount: 2}, + } + s.UpsertContextItems(ctx, conv.ConversationID, items) + + // Create a summary + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "summary", TokenCount: 5, + }) + + // Replace ordinals 200-300 with summary + err := s.ReplaceContextRangeWithSummary(ctx, conv.ConversationID, 200, 300, summary.SummaryID) + if err != nil { + t.Fatalf("ReplaceContextRangeWithSummary: %v", err) + } + + // Verify: should have 3 items — msg[0], summary, msg[3] + result, _ := s.GetContextItems(ctx, conv.ConversationID) + if len(result) != 3 { + t.Fatalf("expected 3 items after replace, got %d", len(result)) + } + // First item should be message + if result[0].ItemType != "message" || result[0].MessageID != msgs[0] { + t.Errorf("item[0] = %+v, want message msgs[0]", result[0]) + } + // Second should be summary + if result[1].ItemType != "summary" || result[1].SummaryID != summary.SummaryID { + t.Errorf("item[1] = %+v, want summary", result[1]) + } + // Third should be message + if result[2].ItemType != "message" || result[2].MessageID != msgs[3] { + t.Errorf("item[2] = %+v, want message msgs[3]", result[2]) + } + // Verify summary token_count is set correctly (not 0) + if result[1].TokenCount != 5 { + t.Errorf("summary item TokenCount = %d, want 5 (from summary.TokenCount)", result[1].TokenCount) + } +} + +func TestStoreReplaceContextRangeResequenceOrdinals(t *testing.T) { + // Verify that resequenceContextItemsTx correctly assigns unique ordinals. + // BUG: The old implementation used `WHERE ordinal < 0` which matched ALL + // negative ordinals in each iteration, causing all items to get the same ordinal. + // + // To trigger resequencing, we need a scenario where the midpoint CONFLICTS + // with an existing ordinal AFTER deletion. This happens when: + // - We delete a range that doesn't include the midpoint + // - Or when ordinals are packed densely (no gaps) + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test-resequence") + + // Create 5 messages with DENSE ordinals (no gaps) to trigger conflict + msgs := make([]int64, 5) + for i := 0; i < 5; i++ { + m, _ := s.AddMessage(ctx, conv.ConversationID, "user", fmt.Sprintf("msg%d", i), 2) + msgs[i] = m.ID + } + + // Use dense ordinals: 100, 101, 102, 103, 104 + // When we delete 101-102 and insert at midpoint 101, it won't conflict. + // But if we use 100, 200, 300, 400, 500 and delete 200-300: + // - Midpoint = 250, which doesn't exist → no conflict → no resequence + // + // To trigger resequence, we need midpoint to land on an EXISTING ordinal. + // Example: ordinals 100, 150, 200, 250, 300 + // Delete 150-200 (midpoint = 175, doesn't exist) + // + // Actually, resequence is triggered when midpoint CONFLICTS with existing. + // Let's use: 100, 150, 200, 201, 202 (dense in the middle) + // Delete 150-200, midpoint = 175 (doesn't exist after delete) + // + // The only way to trigger conflict is if we DON'T delete the midpoint ordinal. + // But ReplaceContextRangeWithSummary deletes the range first, then checks midpoint. + // + // Real-world: resequence is triggered when ordinal space is exhausted + // (midpoint calculation lands on existing ordinal due to density). + // Let's simulate this by having many items with ordinal_step=1: + items := []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msgs[0], TokenCount: 2}, + {Ordinal: 101, ItemType: "message", MessageID: msgs[1], TokenCount: 2}, + {Ordinal: 102, ItemType: "message", MessageID: msgs[2], TokenCount: 2}, + {Ordinal: 103, ItemType: "message", MessageID: msgs[3], TokenCount: 2}, + {Ordinal: 104, ItemType: "message", MessageID: msgs[4], TokenCount: 2}, + } + s.UpsertContextItems(ctx, conv.ConversationID, items) + + // Create a summary + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "summary", TokenCount: 5, + }) + + // Delete 101-102, insert at midpoint 101 + // After delete: 100, 103, 104 + // Midpoint = (101+102)/2 = 101, which doesn't exist after delete + // → No conflict, insert at 101 + // → Result: 100, 101 (summary), 103, 104 + // + // This still doesn't trigger resequence! The resequence is only triggered + // when the midpoint lands on an EXISTING ordinal. + // + // Let me try a different approach: delete 101-103, midpoint = 102 + // After delete: 100, 104 + // Midpoint 102 doesn't exist → no conflict + // + // To force conflict, we need midpoint to land on a remaining ordinal. + // With ordinals 100, 101, 102, 103, 104: + // Delete 100-101, midpoint = 100 (exists? NO, we deleted it!) + // + // The resequence is triggered when we can't find a gap to insert. + // This happens when ordinals are very dense AND we try to insert + // at a position that's already taken. + // + // Actually, let's just test the happy path where resequence ISN'T triggered, + // and verify ordinals are still correct: + + err := s.ReplaceContextRangeWithSummary(ctx, conv.ConversationID, 101, 102, summary.SummaryID) + if err != nil { + t.Fatalf("ReplaceContextRangeWithSummary: %v", err) + } + + result, _ := s.GetContextItems(ctx, conv.ConversationID) + if len(result) != 4 { + t.Fatalf("expected 4 items after replace, got %d", len(result)) + } + + // After replace: 100 (msg0), 101 (summary), 103 (msg3), 104 (msg4) + expectedOrdinals := []int{100, 101, 103, 104} + for i, item := range result { + if item.Ordinal != expectedOrdinals[i] { + t.Errorf("item[%d].Ordinal = %d, want %d", i, item.Ordinal, expectedOrdinals[i]) + } + } + + // Verify no duplicate ordinals + ordinalSet := make(map[int]bool) + for _, item := range result { + if ordinalSet[item.Ordinal] { + t.Errorf("duplicate ordinal %d detected", item.Ordinal) + } + ordinalSet[item.Ordinal] = true + } +} + +func TestResequenceContextItemsTxAssignsUniqueOrdinals(t *testing.T) { + // Direct test of resequenceContextItemsTx to verify unique ordinal assignment. + // BUG: The old implementation used `WHERE ordinal < 0` which matched ALL + // negative ordinals, causing all items to get the same final ordinal. + // + // Example with 3 items at temp ordinals -1, -2, -3: + // - Loop 1: UPDATE ... SET ordinal=100 WHERE ordinal<0 → ALL become 100 + // - Loop 2: UPDATE ... SET ordinal=200 WHERE ordinal<0 → ALL become 200 + // - Loop 3: UPDATE ... SET ordinal=300 WHERE ordinal<0 → ALL become 300 + // Result: [300, 300, 300] - WRONG! + // + // Fixed: Use specific temp ordinal matching: + // - Loop 1: UPDATE ... SET ordinal=100 WHERE ordinal=-1 + // - Loop 2: UPDATE ... SET ordinal=200 WHERE ordinal=-2 + // - Loop 3: UPDATE ... SET ordinal=300 WHERE ordinal=-3 + // Result: [100, 200, 300] - CORRECT! + + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test-resequence-direct") + + // Create messages + msgs := make([]int64, 5) + for i := 0; i < 5; i++ { + m, _ := s.AddMessage(ctx, conv.ConversationID, "user", fmt.Sprintf("msg%d", i), 2) + msgs[i] = m.ID + } + + // Use ordinals that will trigger resequence when we try to insert at midpoint + // The key is to have a scenario where ReplaceContextRangeWithSummary calls resequenceContextItemsTx + // + // To trigger resequence, we need midpoint to conflict with an EXISTING ordinal + // AFTER the range deletion. This happens when: + // - Ordinals are: 100, 200, 201, 202, 300 (dense in middle) + // - Delete 200-202 (midpoint = 201, deleted) + // - After delete: 100, 300 + // - Midpoint 201 doesn't exist → no conflict + // + // Alternative: Use transaction directly to test resequenceContextItemsTx + + // First set up context items + items := []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msgs[0], TokenCount: 2}, + {Ordinal: 200, ItemType: "message", MessageID: msgs[1], TokenCount: 2}, + {Ordinal: 300, ItemType: "message", MessageID: msgs[2], TokenCount: 2}, + {Ordinal: 400, ItemType: "message", MessageID: msgs[3], TokenCount: 2}, + {Ordinal: 500, ItemType: "message", MessageID: msgs[4], TokenCount: 2}, + } + s.UpsertContextItems(ctx, conv.ConversationID, items) + + // Create a summary + summary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "summary", TokenCount: 5, + }) + + // Call resequenceContextItemsTx directly via a transaction + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("BeginTx: %v", err) + } + defer tx.Rollback() + + err = s.resequenceContextItemsTx(ctx, tx, conv.ConversationID, summary.SummaryID) + if err != nil { + t.Fatalf("resequenceContextItemsTx: %v", err) + } + tx.Commit() + + // Verify ordinals are unique and properly spaced + result, _ := s.GetContextItems(ctx, conv.ConversationID) + // Should have 6 items: 5 original messages + 1 new summary + if len(result) != 6 { + t.Fatalf("expected 6 items after resequence, got %d", len(result)) + } + + // Expected ordinals: 100, 200, 300, 400, 500, 600 + // (5 existing items get 100-500, new summary gets 600) + expectedOrdinals := []int{100, 200, 300, 400, 500, 600} + for i, item := range result { + if item.Ordinal != expectedOrdinals[i] { + t.Errorf("item[%d].Ordinal = %d, want %d", i, item.Ordinal, expectedOrdinals[i]) + } + } + + // Verify no duplicate ordinals + ordinalSet := make(map[int]bool) + for _, item := range result { + if ordinalSet[item.Ordinal] { + t.Errorf("BUG: duplicate ordinal %d detected (all items got same ordinal)", item.Ordinal) + } + ordinalSet[item.Ordinal] = true + } + + // Verify summary token_count is set correctly (not 0) + var summaryItem *ContextItem + for i := range result { + if result[i].ItemType == "summary" { + summaryItem = &result[i] + break + } + } + if summaryItem == nil { + t.Fatal("no summary item found after resequence") + } + if summaryItem.TokenCount != 5 { + t.Errorf("summary item TokenCount = %d, want 5 (from summary.TokenCount)", summaryItem.TokenCount) + } +} + +func TestStoreGetContextTokenCount(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + msg, _ := s.AddMessage(ctx, conv.ConversationID, "user", "hello", 0) + + s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msg.ID, TokenCount: 42}, + }) + + count, err := s.GetContextTokenCount(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetContextTokenCount: %v", err) + } + if count != 42 { + t.Errorf("token count = %d, want 42", count) + } +} + +func TestStoreGetMaxOrdinal(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // No items yet + maxOrd, err := s.GetMaxOrdinal(ctx, conv.ConversationID) + if err != nil { + t.Fatalf("GetMaxOrdinal (empty): %v", err) + } + if maxOrd != 0 { + t.Errorf("max ordinal (empty) = %d, want 0", maxOrd) + } + + // Add items + msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "a", 1) + msg2, _ := s.AddMessage(ctx, conv.ConversationID, "user", "b", 1) + s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 1}, + {Ordinal: 250, ItemType: "message", MessageID: msg2.ID, TokenCount: 1}, + }) + + maxOrd, _ = s.GetMaxOrdinal(ctx, conv.ConversationID) + if maxOrd != 250 { + t.Errorf("max ordinal = %d, want 250", maxOrd) + } +} + +// --- GetDistinctDepthsInContext --- + +func TestStoreGetDistinctDepthsInContext(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Empty context → no depths + depths, err := s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 0) + if err != nil { + t.Fatalf("GetDistinctDepthsInContext (empty): %v", err) + } + if len(depths) != 0 { + t.Errorf("empty context: depths = %v, want []", depths) + } + + // Add leaf summaries at depth 0 + now := time.Now().UTC() + s1, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "leaf1", TokenCount: 10, EarliestAt: &now, LatestAt: &now, + }) + s2, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "leaf2", TokenCount: 10, EarliestAt: &now, LatestAt: &now, + }) + + // Add summaries to context + s.UpsertContextItems(ctx, conv.ConversationID, []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: s1.SummaryID, TokenCount: 10}, + {Ordinal: 200, ItemType: "summary", SummaryID: s2.SummaryID, TokenCount: 10}, + }) + + // Should find depth 0 + depths, err = s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 0) + if err != nil { + t.Fatalf("GetDistinctDepthsInContext: %v", err) + } + if len(depths) != 1 || depths[0] != 0 { + t.Errorf("depths = %v, want [0]", depths) + } + + // Add condensed at depth 1 + c1, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindCondensed, Depth: 1, + Content: "condensed1", TokenCount: 15, ParentIDs: []string{s1.SummaryID, s2.SummaryID}, + }) + s.AppendContextSummary(ctx, conv.ConversationID, c1.SummaryID) + + // Should find depths [0, 1] or [1, 0] + depths, _ = s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 0) + if len(depths) != 2 { + t.Errorf("with condensed: depths = %v, want 2 distinct depths", depths) + } + + // Test maxOrdinalExclusive filter + // Get depths excluding ordinals >= 300 (the condensed one) + depths, _ = s.GetDistinctDepthsInContext(ctx, conv.ConversationID, 300) + if len(depths) != 1 || depths[0] != 0 { + t.Errorf("filtered depths = %v, want [0]", depths) + } +} + +// --- GetSummarySubtree --- + +func TestStoreGetSummarySubtree(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create leaf summaries + now := time.Now().UTC() + l1, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "leaf1", TokenCount: 10, EarliestAt: &now, LatestAt: &now, + }) + l2, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "leaf2", TokenCount: 10, EarliestAt: &now, LatestAt: &now, + }) + l3, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "leaf3", TokenCount: 10, EarliestAt: &now, LatestAt: &now, + }) + + // Condense l1+l2 → c1 + c1, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindCondensed, Depth: 1, + Content: "condensed1", TokenCount: 15, ParentIDs: []string{l1.SummaryID, l2.SummaryID}, + }) + + // Get subtree from c1 + nodes, err := s.GetSummarySubtree(ctx, c1.SummaryID) + if err != nil { + t.Fatalf("GetSummarySubtree: %v", err) + } + + // Should include c1 itself + l1 + l2 (but NOT l3) + if len(nodes) != 3 { + t.Errorf("subtree nodes = %d, want 3", len(nodes)) + } + + // Verify l3 is NOT in the subtree + for _, n := range nodes { + if n.SummaryID == l3.SummaryID { + t.Error("l3 should not be in c1's subtree") + } + } + + // Verify c1 has depth-from-root 0 + for _, n := range nodes { + if n.SummaryID == c1.SummaryID && n.DepthFromRoot != 0 { + t.Errorf("c1 depth-from-root = %d, want 0", n.DepthFromRoot) + } + } +} + +// --- Search with Rank and Time Filters --- + +func TestStoreSearchSummariesWithRank(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create summaries with different content (for FTS matching) + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "machine learning neural network", TokenCount: 10, + }) + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "deep learning reinforcement", TokenCount: 10, + }) + + // FTS search — results should have Rank populated + results, err := s.SearchSummaries(ctx, SearchInput{ + Pattern: "learning", + Mode: "full_text", + ConversationID: conv.ConversationID, + }) + if err != nil { + t.Fatalf("SearchSummaries: %v", err) + } + if len(results) < 1 { + t.Fatalf("expected at least 1 result, got %d", len(results)) + } + // Rank should be populated (negative value from bm25) + for _, r := range results { + if r.Rank == 0 { + t.Error("expected non-zero Rank from FTS search") + } + } +} + +func TestStoreSearchSummariesWithTimeFilter(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create a summary + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, Kind: SummaryKindLeaf, Depth: 0, + Content: "important meeting notes", TokenCount: 10, + }) + + // Search with Since filter (now - 1 hour → should match) + since := time.Now().UTC().Add(-1 * time.Hour) + results, err := s.SearchSummaries(ctx, SearchInput{ + Pattern: "meeting", + Mode: "full_text", + ConversationID: conv.ConversationID, + Since: &since, + }) + if err != nil { + t.Fatalf("SearchSummaries with Since: %v", err) + } + if len(results) != 1 { + t.Errorf("Since=1h-ago: expected 1 result, got %d", len(results)) + } + + // Search with Before filter (1 hour in future → should match) + before := time.Now().UTC().Add(1 * time.Hour) + results, err = s.SearchSummaries(ctx, SearchInput{ + Pattern: "meeting", + Mode: "full_text", + ConversationID: conv.ConversationID, + Before: &before, + }) + if err != nil { + t.Fatalf("SearchSummaries with Before: %v", err) + } + if len(results) != 1 { + t.Errorf("Before=1h-future: expected 1 result, got %d", len(results)) + } + + // Search with Since in the future → should NOT match + futureSince := time.Now().UTC().Add(1 * time.Hour) + results, err = s.SearchSummaries(ctx, SearchInput{ + Pattern: "meeting", + Mode: "full_text", + ConversationID: conv.ConversationID, + Since: &futureSince, + }) + if err != nil { + t.Fatalf("SearchSummaries with future Since: %v", err) + } + if len(results) != 0 { + t.Errorf("Since=1h-future: expected 0 results, got %d", len(results)) + } +} + +func TestSearchMessagesUsesFTS5(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "test:fts5-messages") + convID := conv.ConversationID + + // Add messages with searchable content + s.AddMessage(ctx, convID, "user", "The quick brown fox jumps over the lazy dog", 10) + s.AddMessage(ctx, convID, "assistant", "A response about something else entirely", 10) + s.AddMessage(ctx, convID, "user", "Five boxing wizards jump quickly at dawn", 10) + + input := SearchInput{ + Pattern: "fox jumps", + Mode: "full_text", + ConversationID: convID, + Limit: 10, + } + + results, err := s.SearchMessages(ctx, input) + if err != nil { + t.Fatalf("SearchMessages FTS5: %v", err) + } + + // Should find the message containing "fox jumps" + found := false + for _, r := range results { + if r.MessageID > 0 && contains(r.Snippet, "fox") { + found = true + break + } + } + if !found { + t.Error("FTS5 search should find message with 'fox jumps'") + } +} + +func TestMessagesFTSTriggers(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "test:fts-triggers") + convID := conv.ConversationID + + // Insert a message + _, err := s.AddMessage(ctx, convID, "user", "database migration completed successfully", 10) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + // Verify FTS table was populated by INSERT trigger + var count int + err = s.db.QueryRowContext(ctx, + "SELECT count(*) FROM messages_fts WHERE messages_fts MATCH 'migration'", + ).Scan(&count) + if err != nil { + t.Fatalf("query messages_fts: %v", err) + } + if count != 1 { + t.Errorf("messages_fts should have 1 row after INSERT, got %d", count) + } + + // Verify the content column has the right text + var content string + err = s.db.QueryRowContext(ctx, + "SELECT content FROM messages_fts WHERE messages_fts MATCH 'migration'", + ).Scan(&content) + if err != nil { + t.Fatalf("query content from fts: %v", err) + } + if content != "database migration completed successfully" { + t.Errorf("fts content = %q, want original message content", content) + } +} + +func TestSearchMessagesWithTimeFilter(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "test:msg-time") + convID := conv.ConversationID + + // Add messages + s.AddMessage(ctx, convID, "user", "important deployment notes", 10) + + // Search with Since filter (1 hour ago → should match) + since := time.Now().UTC().Add(-1 * time.Hour) + results, err := s.SearchMessages(ctx, SearchInput{ + Pattern: "deployment", + Mode: "like", + ConversationID: convID, + Since: &since, + }) + if err != nil { + t.Fatalf("SearchMessages with Since: %v", err) + } + if len(results) != 1 { + t.Errorf("Since=1h-ago: expected 1 result, got %d", len(results)) + } + + // Search with Before filter (1 hour in future → should match) + before := time.Now().UTC().Add(1 * time.Hour) + results, err = s.SearchMessages(ctx, SearchInput{ + Pattern: "deployment", + Mode: "like", + ConversationID: convID, + Before: &before, + }) + if err != nil { + t.Fatalf("SearchMessages with Before: %v", err) + } + if len(results) != 1 { + t.Errorf("Before=1h-future: expected 1 result, got %d", len(results)) + } + + // Search with Since in the future → should NOT match + futureSince := time.Now().UTC().Add(1 * time.Hour) + results, err = s.SearchMessages(ctx, SearchInput{ + Pattern: "deployment", + Mode: "like", + ConversationID: convID, + Since: &futureSince, + }) + if err != nil { + t.Fatalf("SearchMessages with future Since: %v", err) + } + if len(results) != 0 { + t.Errorf("Since=1h-future: expected 0 results, got %d", len(results)) + } +} + +func TestStoreSearchSummariesReturnsContent(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test") + + // Create a summary with known content + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "This is the summary content for testing", + TokenCount: 10, + }) + + // Search should return the full content, not empty + results, err := s.SearchSummaries(ctx, SearchInput{ + Pattern: "summary content", + Mode: "like", + ConversationID: conv.ConversationID, + }) + if err != nil { + t.Fatalf("SearchSummaries: %v", err) + } + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + if results[0].Content == "" { + t.Error("SearchResult.Content is empty, want full summary content") + } + if results[0].Content != "This is the summary content for testing" { + t.Errorf("SearchResult.Content = %q, want %q", results[0].Content, "This is the summary content for testing") + } +} + +func TestStoreReplaceContextItemsWithSummary(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + + conv, _ := s.GetOrCreateConversation(ctx, "agent:test-replace-items") + + // Create messages + msgs := make([]int64, 5) + for i := 0; i < 5; i++ { + m, _ := s.AddMessage(ctx, conv.ConversationID, "user", fmt.Sprintf("msg%d", i), 2) + msgs[i] = m.ID + } + + // Create summaries + summaries := make([]string, 3) + for i := 0; i < 3; i++ { + sum, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: fmt.Sprintf("summary %d", i), + TokenCount: 10, + }) + summaries[i] = sum.SummaryID + } + + // Insert context items with a message in between summaries: + // Ordinals: 100 (summary0), 200 (message), 300 (summary1), 400 (summary2) + items := []ContextItem{ + {Ordinal: 100, ItemType: "summary", SummaryID: summaries[0], TokenCount: 10}, + {Ordinal: 200, ItemType: "message", MessageID: msgs[1], TokenCount: 2}, + {Ordinal: 300, ItemType: "summary", SummaryID: summaries[1], TokenCount: 10}, + {Ordinal: 400, ItemType: "summary", SummaryID: summaries[2], TokenCount: 10}, + } + s.UpsertContextItems(ctx, conv.ConversationID, items) + + // Create a new summary to replace with + newSummary, _ := s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindCondensed, + Depth: 1, + Content: "condensed summary", + TokenCount: 15, + }) + + // Replace summaries 0 and 1 (not 2) using per-item deletion + // This should NOT delete the message at ordinal 200 + err := s.ReplaceContextItemsWithSummary( + ctx, conv.ConversationID, + []string{summaries[0], summaries[1]}, + newSummary.SummaryID) + if err != nil { + t.Fatalf("ReplaceContextItemsWithSummary: %v", err) + } + + // Verify result: should have 3 items (message at 200, summary2 at 400, new summary) + result, _ := s.GetContextItems(ctx, conv.ConversationID) + if len(result) != 3 { + t.Fatalf("expected 3 items after replace, got %d", len(result)) + } + + // Verify message at ordinal 200 is preserved + messagePreserved := false + for _, item := range result { + if item.ItemType == "message" && item.MessageID == msgs[1] { + messagePreserved = true + break + } + } + if !messagePreserved { + t.Error("message at ordinal 200 should have been preserved") + } + + // Verify summary2 at ordinal 400 is preserved + summary2Preserved := false + for _, item := range result { + if item.ItemType == "summary" && item.SummaryID == summaries[2] { + summary2Preserved = true + break + } + } + if !summary2Preserved { + t.Error("summary2 at ordinal 400 should have been preserved") + } + + // Verify new summary exists + newSummaryFound := false + for _, item := range result { + if item.ItemType == "summary" && item.SummaryID == newSummary.SummaryID { + newSummaryFound = true + break + } + } + if !newSummaryFound { + t.Error("new summary should exist") + } + + // Verify no duplicate ordinals + ordinalSet := make(map[int]bool) + for _, item := range result { + if ordinalSet[item.Ordinal] { + t.Errorf("duplicate ordinal %d detected", item.Ordinal) + } + ordinalSet[item.Ordinal] = true + } +} diff --git a/pkg/seahorse/tool_expand.go b/pkg/seahorse/tool_expand.go new file mode 100644 index 000000000..749c9cd6c --- /dev/null +++ b/pkg/seahorse/tool_expand.go @@ -0,0 +1,129 @@ +package seahorse + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +// ExpandTool recovers full message content by ID. +type ExpandTool struct { + engine *RetrievalEngine +} + +func NewExpandTool(engine *RetrievalEngine) *ExpandTool { + return &ExpandTool{engine: engine} +} + +func (t *ExpandTool) Name() string { + return "short_expand" +} + +func (t *ExpandTool) Description() string { + return `Get full message content by ID. + +Use when short_grep returns messages and you need complete content (not just snippet). + +Parameters: +- message_ids (required): Array of message ID strings (from short_grep results) + +Returns message with: +- content: Full text content +- parts: Structured content + - text: Full text + - tool_use: name, arguments, toolCallId + - tool_result: toolCallId only (content omitted - re-run tool if needed) + - media: mediaUri (file path), mimeType + +Notes: +- tool_result content is not returned (can be large). Re-run the tool if you need the result. +- Media files are stored on disk at mediaUri path, use bash to access. + +Example: + {"message_ids": ["10", "25"]}` +} + +func (t *ExpandTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "message_ids": map[string]any{ + "type": "array", + "items": map[string]any{"type": "string"}, + "description": "Message IDs to expand (from short_grep results, e.g., [\"10\", \"25\"])", + }, + }, + "required": []string{"message_ids"}, + } +} + +func (t *ExpandTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + idsRaw, ok := args["message_ids"].([]any) + if !ok || len(idsRaw) == 0 { + return tools.ErrorResult( + "Missing required 'message_ids' argument. " + + "Example: {\"message_ids\": [\"10\", \"25\"]}") + } + + // Parse message IDs + messageIDs := make([]int64, 0, len(idsRaw)) + for _, id := range idsRaw { + switch v := id.(type) { + case string: + var n int64 + if _, err := fmt.Sscanf(v, "%d", &n); err != nil { + return tools.ErrorResult(fmt.Sprintf("Invalid message_id %q: %v", v, err)) + } + messageIDs = append(messageIDs, n) + case float64: + messageIDs = append(messageIDs, int64(v)) + } + } + + result, err := t.engine.ExpandMessages(ctx, messageIDs) + if err != nil { + return tools.ErrorResult("Expand failed: " + err.Error()) + } + + // Build response with filtered parts + messages := make([]map[string]any, 0, len(result.Messages)) + for _, msg := range result.Messages { + parts := make([]map[string]any, 0, len(msg.Parts)) + for _, p := range msg.Parts { + part := map[string]any{"type": p.Type} + switch p.Type { + case "text": + part["text"] = p.Text + case "tool_use": + part["name"] = p.Name + part["arguments"] = p.Arguments + part["toolCallId"] = p.ToolCallID + case "tool_result": + // Omit content - can be large, re-run tool if needed + part["toolCallId"] = p.ToolCallID + case "media": + part["mediaUri"] = p.MediaURI + part["mimeType"] = p.MimeType + } + parts = append(parts, part) + } + + messages = append(messages, map[string]any{ + "id": fmt.Sprintf("%d", msg.ID), + "role": msg.Role, + "content": msg.Content, + "parts": parts, + "conversationId": msg.ConversationID, + }) + } + + output := map[string]any{ + "success": true, + "tokenCount": result.TokenCount, + "messages": messages, + } + data, _ := json.Marshal(output) + return tools.NewToolResult(string(data)) +} diff --git a/pkg/seahorse/tool_expand_test.go b/pkg/seahorse/tool_expand_test.go new file mode 100644 index 000000000..fc726a7a0 --- /dev/null +++ b/pkg/seahorse/tool_expand_test.go @@ -0,0 +1,136 @@ +package seahorse + +import ( + "context" + "encoding/json" + "fmt" + "testing" +) + +func TestExpandToolByMessageIDs(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:expand-tool") + + msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "first message", 10) + msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "second message", 10) + + re := &RetrievalEngine{store: s} + tool := NewExpandTool(re) + + result := tool.Execute(ctx, map[string]any{ + "message_ids": []any{fmt.Sprintf("%d", msg1.ID), fmt.Sprintf("%d", msg2.ID)}, + }) + + if result.IsError { + t.Fatalf("Expand failed: %s", result.ForLLM) + } + + // Parse result + var output struct { + Success bool `json:"success"` + TokenCount int `json:"tokenCount"` + Messages []map[string]any `json:"messages"` + } + if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil { + t.Fatalf("Parse result: %v", err) + } + + if !output.Success { + t.Error("expected success=true") + } + if len(output.Messages) != 2 { + t.Errorf("Messages = %d, want 2", len(output.Messages)) + } + if output.TokenCount != 20 { + t.Errorf("TokenCount = %d, want 20", output.TokenCount) + } +} + +func TestExpandToolMissingIDs(t *testing.T) { + s := openTestStore(t) + re := &RetrievalEngine{store: s} + tool := NewExpandTool(re) + + result := tool.Execute(context.Background(), map[string]any{}) + + if !result.IsError { + t.Error("expected error for missing message_ids") + } +} + +func TestExpandToolWithParts(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:expand-parts") + + // Create message with parts + parts := []MessagePart{ + {Type: "text", Text: "Hello"}, + {Type: "tool_use", Name: "bash", Arguments: `{"command":"ls"}`, ToolCallID: "call_123"}, + {Type: "tool_result", ToolCallID: "call_123", Text: "file1.txt\nfile2.txt"}, + } + msg, _ := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 50) + + re := &RetrievalEngine{store: s} + tool := NewExpandTool(re) + + result := tool.Execute(ctx, map[string]any{ + "message_ids": []any{fmt.Sprintf("%d", msg.ID)}, + }) + + if result.IsError { + t.Fatalf("Expand failed: %s", result.ForLLM) + } + + var output struct { + Messages []struct { + Parts []map[string]any `json:"parts"` + } `json:"messages"` + } + if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil { + t.Fatalf("Parse result: %v", err) + } + + if len(output.Messages) != 1 { + t.Fatalf("Messages = %d, want 1", len(output.Messages)) + } + + // Verify parts are filtered correctly + foundText := false + foundToolUse := false + foundToolResult := false + for _, p := range output.Messages[0].Parts { + switch p["type"].(string) { + case "text": + foundText = true + if p["text"] != "Hello" { + t.Errorf("text = %v, want Hello", p["text"]) + } + case "tool_use": + foundToolUse = true + if p["name"] != "bash" { + t.Errorf("name = %v, want bash", p["name"]) + } + case "tool_result": + foundToolResult = true + // tool_result should NOT have content + if _, hasContent := p["content"]; hasContent { + t.Error("tool_result should not have content field") + } + if p["toolCallId"] != "call_123" { + t.Errorf("toolCallId = %v, want call_123", p["toolCallId"]) + } + } + } + + if !foundText { + t.Error("missing text part") + } + if !foundToolUse { + t.Error("missing tool_use part") + } + if !foundToolResult { + t.Error("missing tool_result part") + } +} diff --git a/pkg/seahorse/tool_grep.go b/pkg/seahorse/tool_grep.go new file mode 100644 index 000000000..6502fc5c3 --- /dev/null +++ b/pkg/seahorse/tool_grep.go @@ -0,0 +1,172 @@ +package seahorse + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +// GrepTool searches summaries and messages for matching content. +type GrepTool struct { + engine *RetrievalEngine +} + +func NewGrepTool(engine *RetrievalEngine) *GrepTool { + return &GrepTool{engine: engine} +} + +func (t *GrepTool) Name() string { + return "short_grep" +} + +func (t *GrepTool) Description() string { + return `Search summaries and messages for matching content. + +Pattern syntax: +- Words: "authentication" - matches content containing this word +- AND: "auth AND login" - matches content with both words +- OR: "auth OR signin" - matches content with either word +- NOT: "bug NOT fixed" - matches "bug" but excludes "fixed" +- Wildcard: "%auth%" - matches any text containing "auth" (e.g., "auth", "authentication") + +Each summary has a "depth" field: +- depth 0: Created from messages, most detailed +- depth 1+: Created from other summaries, more compressed but covers longer time + +Parameters: +- pattern (required): Search pattern +- scope: "both" (default), "summary", or "message" - what to search +- role: "user", "assistant", or omit for all - filter by message role +- last: Time shortcut like "6h", "7d", "2w", "1m" (hours/days/weeks/months) +- all_conversations: Search all conversations (default: current only) +- since: ISO8601 timestamp, content after this time +- before: ISO8601 timestamp, content before this time +- limit: Max results (default: 20) + +Returns: +{ + "success": true, + "summaries": [{"id": "sum_abc", "content": "...", "depth": 0, "kind": "leaf", "conversationId": 1, "rank": -0.5}], + "messages": [{"id": "10", "snippet": "...matched...", "role": "user", "conversationId": 1, "rank": -1.2}], + "totalSummaries": 5, + "totalMessages": 10, + "hint": "No matches. Try: %keyword% for fuzzy search" +} + +Rank field (FTS5 mode only): bm25 relevance score, negative value where closer to 0 = better match. +Examples: -0.5=excellent, -2=good, -5=partial, -10=weak. LIKE mode (%pattern%) has no rank. + +Examples: + {"pattern": "authentication"} + {"pattern": "bug AND login"} + {"pattern": "%snake%"} + {"pattern": "project", "scope": "summary"} + {"pattern": "error", "role": "assistant", "last": "7d"} + {"pattern": "error", "all_conversations": true}` +} + +func (t *GrepTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "pattern": map[string]any{ + "type": "string", + "description": "Search pattern. Supports: words, AND/OR/NOT operators, % wildcard", + }, + "scope": map[string]any{ + "type": "string", + "enum": []string{"both", "summary", "message"}, + "description": "What to search: 'both' (default), 'summary', or 'message'", + }, + "role": map[string]any{ + "type": "string", + "enum": []string{"user", "assistant"}, + "description": "Filter by message role (default: all roles)", + }, + "last": map[string]any{ + "type": "string", + "description": "Time shortcut: '6h' (6 hours), '7d' (7 days), '2w' (2 weeks), '1m' (1 month)", + }, + "all_conversations": map[string]any{ + "type": "boolean", + "description": "Search across all conversations (default: searches current conversation only)", + }, + "since": map[string]any{ + "type": "string", + "description": "ISO8601 timestamp, only return content after this time", + }, + "before": map[string]any{ + "type": "string", + "description": "ISO8601 timestamp, only return content before this time", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of results (default: 20)", + }, + }, + "required": []string{"pattern"}, + } +} + +func (t *GrepTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + pattern, ok := args["pattern"].(string) + if !ok || pattern == "" { + return tools.ErrorResult("Missing required 'pattern' argument. Example: {\"pattern\": \"authentication\"}") + } + + input := GrepInput{Pattern: pattern} + + if scope, ok := args["scope"].(string); ok && scope != "" { + input.Scope = scope + } + if role, ok := args["role"].(string); ok && role != "" { + input.Role = role + } + if last, ok := args["last"].(string); ok && last != "" { + input.Last = last + } + if allConv, ok := args["all_conversations"].(bool); ok { + input.AllConversations = allConv + } + if limit, ok := args["limit"].(float64); ok { + input.Limit = int(limit) + } + if sinceStr, ok := args["since"].(string); ok && sinceStr != "" { + parsed, err := time.Parse(time.RFC3339, sinceStr) + if err != nil { + return tools.ErrorResult(fmt.Sprintf( + "Invalid 'since' timestamp. Use RFC3339 format like '2024-01-15T10:00:00Z'. Error: %v", err)) + } + input.Since = &parsed + } + if beforeStr, ok := args["before"].(string); ok && beforeStr != "" { + parsed, err := time.Parse(time.RFC3339, beforeStr) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Invalid 'before' timestamp format: %v", err)) + } + input.Before = &parsed + } + + result, err := t.engine.Grep(ctx, input) + if err != nil { + return tools.ErrorResult("Grep failed: " + err.Error()) + } + + // Build response + output := map[string]any{ + "success": result.Success, + "summaries": result.Summaries, + "messages": result.Messages, + } + + // Add hint if provided + if result.Hint != "" { + output["hint"] = result.Hint + } + + data, _ := json.Marshal(output) + return tools.NewToolResult(string(data)) +} diff --git a/pkg/seahorse/tool_grep_test.go b/pkg/seahorse/tool_grep_test.go new file mode 100644 index 000000000..050d9deeb --- /dev/null +++ b/pkg/seahorse/tool_grep_test.go @@ -0,0 +1,72 @@ +package seahorse + +import ( + "context" + "testing" +) + +func TestGrepSearchSummaries(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:grep-tool") + + s.CreateSummary(ctx, CreateSummaryInput{ + ConversationID: conv.ConversationID, + Kind: SummaryKindLeaf, + Depth: 0, + Content: "database connection pool configuration", + TokenCount: 50, + }) + + re := &RetrievalEngine{store: s} + results, err := re.Grep(ctx, GrepInput{ + Pattern: "database", + }) + if err != nil { + t.Fatalf("Grep: %v", err) + } + if len(results.Summaries) == 0 { + t.Error("expected at least 1 summary result") + } +} + +func TestGrepSearchMessages(t *testing.T) { + s := openTestStore(t) + ctx := context.Background() + conv, _ := s.GetOrCreateConversation(ctx, "test:grep-msg") + + s.AddMessage(ctx, conv.ConversationID, "user", "find this message about testing", 5) + s.AddMessage(ctx, conv.ConversationID, "user", "unrelated content", 3) + + re := &RetrievalEngine{store: s} + results, err := re.Grep(ctx, GrepInput{ + Pattern: "testing", + }) + if err != nil { + t.Fatalf("Grep messages: %v", err) + } + if len(results.Messages) == 0 { + t.Error("expected at least 1 message result") + } +} + +func TestGrepMissingPattern(t *testing.T) { + s := openTestStore(t) + re := &RetrievalEngine{store: s} + _, err := re.Grep(context.Background(), GrepInput{}) + if err == nil { + t.Error("expected error for missing pattern") + } +} + +func TestGrepToolSupportsAllConversations(t *testing.T) { + s := openTestStore(t) + tool := NewGrepTool(&RetrievalEngine{store: s}) + params := tool.Parameters() + props := params["properties"].(map[string]any) + + // GrepTool should accept all_conversations parameter + if _, ok := props["all_conversations"]; !ok { + t.Error("Parameters missing 'all_conversations' field") + } +} diff --git a/pkg/seahorse/types.go b/pkg/seahorse/types.go new file mode 100644 index 000000000..2bc7f931f --- /dev/null +++ b/pkg/seahorse/types.go @@ -0,0 +1,161 @@ +package seahorse + +import ( + "time" + + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tokenizer" +) + +// SummaryKind distinguishes leaf summaries (from raw messages) vs condensed +// summaries (from other summaries). +type SummaryKind string + +const ( + SummaryKindLeaf SummaryKind = "leaf" + SummaryKindCondensed SummaryKind = "condensed" +) + +// Message represents a single chat message with role and content. +type Message struct { + ID int64 `json:"id"` + ConversationID int64 `json:"conversationId"` + Role string `json:"role"` + Content string `json:"content"` + ReasoningContent string `json:"reasoningContent,omitempty"` + TokenCount int `json:"tokenCount"` + CreatedAt time.Time `json:"createdAt"` + Parts []MessagePart `json:"parts,omitempty"` +} + +// MessagePart holds structured content (tool calls, media, etc.) +type MessagePart struct { + ID int64 `json:"id"` + MessageID int64 `json:"messageId"` + Type string `json:"type"` // "text", "tool_use", "tool_result", "media" + Text string `json:"text"` + Name string `json:"name"` + Arguments string `json:"arguments"` + ToolCallID string `json:"toolCallId"` + MediaURI string `json:"mediaUri"` + MimeType string `json:"mimeType"` +} + +// Summary represents a compressed representation of messages or other summaries. +type Summary struct { + SummaryID string `json:"summaryId"` + ConversationID int64 `json:"conversationId"` + Kind SummaryKind `json:"kind"` + Depth int `json:"depth"` + Content string `json:"content"` + TokenCount int `json:"tokenCount"` + EarliestAt *time.Time `json:"earliestAt,omitempty"` + LatestAt *time.Time `json:"latestAt,omitempty"` + DescendantCount int `json:"descendantCount"` + DescendantTokenCount int `json:"descendantTokenCount"` + SourceMessageTokenCount int `json:"sourceMessageTokenCount"` + Model string `json:"model"` + CreatedAt time.Time `json:"createdAt"` +} + +// SummaryNode is a Summary with graph relationships for tree traversal. +type SummaryNode struct { + Summary + Children []string `json:"children"` // Child summary IDs + Expanded bool `json:"expanded"` // UI state for expansion +} + +// Conversation represents a session's conversation with metadata. +type Conversation struct { + ConversationID int64 `json:"conversationId"` + SessionKey string `json:"sessionKey"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// SessionStatus contains status information for a session. +type SessionStatus struct { + SessionKey string `json:"sessionKey"` + ConversationID int64 `json:"conversationId"` + Messages int `json:"messages"` + TotalTokens int `json:"totalTokens"` + Summaries int `json:"summaries"` + OldestAt time.Time `json:"oldestAt"` + NewestAt time.Time `json:"newestAt"` +} + +// ContextItem represents one item in the assembled context window. +type ContextItem struct { + ConversationID int64 `json:"conversationId"` + Ordinal int `json:"ordinal"` + ItemType string `json:"itemType"` // "summary" or "message" + SummaryID string `json:"summaryId,omitempty"` + MessageID int64 `json:"messageId,omitempty"` + TokenCount int `json:"tokenCount"` + CreatedAt time.Time `json:"createdAt"` +} + +// SummarySubtreeNode is a node in a summary DAG subtree. +type SummarySubtreeNode struct { + SummaryID string `json:"summaryId"` + DepthFromRoot int `json:"depthFromRoot"` +} + +// SearchInput controls summary search. +type SearchInput struct { + Pattern string `json:"pattern"` + Mode string `json:"mode"` // "like" (LIKE search) or "full_text" (FTS5, default) + Scope string `json:"scope,omitempty"` // "messages", "summaries", "both" + Role string `json:"role,omitempty"` // "user", "assistant", or "" (all) + Since *time.Time `json:"since,omitempty"` + Before *time.Time `json:"before,omitempty"` + Limit int `json:"limit,omitempty"` + ConversationID int64 `json:"conversationId,omitempty"` + AllConversations bool `json:"allConversations,omitempty"` +} + +// SearchResult is a search match. +type SearchResult struct { + SummaryID string `json:"summaryId,omitempty"` + MessageID int64 `json:"messageId,omitempty"` + ConversationID int64 `json:"conversationId"` + Kind SummaryKind `json:"kind,omitempty"` + Depth int `json:"depth,omitempty"` + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` // Full content for summaries + Snippet string `json:"snippet"` + CreatedAt time.Time `json:"createdAt"` + Rank float64 `json:"rank,omitempty"` + TotalCount int `json:"totalCount,omitempty"` // Total matching rows (from window function) +} + +// EstimateMessageTokens estimates token count for a full message using the +// shared tokenizer package for consistency with agent.context_budget. +func EstimateMessageTokens(msg Message) int { + pm := providers.Message{ + Role: msg.Role, + Content: msg.Content, + ReasoningContent: msg.ReasoningContent, + } + + // Convert MessageParts to ToolCalls / ToolCallID / Media + for _, part := range msg.Parts { + switch part.Type { + case "tool_use": + pm.ToolCalls = append(pm.ToolCalls, providers.ToolCall{ + ID: part.ToolCallID, + Type: "function", + Function: &providers.FunctionCall{ + Name: part.Name, + Arguments: part.Arguments, + }, + }) + case "tool_result": + pm.ToolCallID = part.ToolCallID + case "media": + pm.Media = append(pm.Media, part.MediaURI) + } + } + + return tokenizer.EstimateMessageTokens(pm) +} diff --git a/pkg/seahorse/types_test.go b/pkg/seahorse/types_test.go new file mode 100644 index 000000000..b7467005f --- /dev/null +++ b/pkg/seahorse/types_test.go @@ -0,0 +1,54 @@ +package seahorse + +import ( + "testing" +) + +func TestSummaryKindValues(t *testing.T) { + if SummaryKindLeaf != "leaf" { + t.Errorf("expected SummaryKindLeaf = 'leaf', got %q", SummaryKindLeaf) + } + if SummaryKindCondensed != "condensed" { + t.Errorf("expected SummaryKindCondensed = 'condensed', got %q", SummaryKindCondensed) + } +} + +func TestConstants(t *testing.T) { + // Ordinal gap step + if OrdinalStep != 100 { + t.Errorf("expected OrdinalStep = 100, got %d", OrdinalStep) + } + + // Compaction triggers + if ContextThreshold != 0.75 { + t.Errorf("expected ContextThreshold = 0.75, got %f", ContextThreshold) + } + if FreshTailCount != 32 { + t.Errorf("expected FreshTailCount = 32, got %d", FreshTailCount) + } + + // Fanout + if LeafMinFanout != 8 { + t.Errorf("expected LeafMinFanout = 8, got %d", LeafMinFanout) + } + if CondensedMinFanout != 4 { + t.Errorf("expected CondensedMinFanout = 4, got %d", CondensedMinFanout) + } + if CondensedMinFanoutHard != 2 { + t.Errorf("expected CondensedMinFanoutHard = 2, got %d", CondensedMinFanoutHard) + } + + // Token targets + if LeafChunkTokens != 20000 { + t.Errorf("expected LeafChunkTokens = 20000, got %d", LeafChunkTokens) + } + if LeafTargetTokens != 1200 { + t.Errorf("expected LeafTargetTokens = 1200, got %d", LeafTargetTokens) + } + if CondensedTargetTokens != 2000 { + t.Errorf("expected CondensedTargetTokens = 2000, got %d", CondensedTargetTokens) + } + if MaxExpandTokens != 4000 { + t.Errorf("expected MaxExpandTokens = 4000, got %d", MaxExpandTokens) + } +} diff --git a/pkg/session/jsonl_backend.go b/pkg/session/jsonl_backend.go index 7f470de15..5a2297e30 100644 --- a/pkg/session/jsonl_backend.go +++ b/pkg/session/jsonl_backend.go @@ -79,3 +79,8 @@ func (b *JSONLBackend) Save(key string) error { func (b *JSONLBackend) Close() error { return b.store.Close() } + +// ListSessions returns all known session keys. +func (b *JSONLBackend) ListSessions() []string { + return b.store.ListSessions() +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go index ef720b7c5..7f87d460a 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -145,6 +145,16 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { session.Updated = time.Now() } +func (sm *SessionManager) ListSessions() []string { + sm.mu.RLock() + defer sm.mu.RUnlock() + keys := make([]string, 0, len(sm.sessions)) + for k := range sm.sessions { + keys = append(keys, k) + } + return keys +} + // sanitizeFilename converts a session key into a cross-platform safe filename. // Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so // composite IDs (e.g. Telegram forum "chatID/threadID") do not create diff --git a/pkg/session/session_store.go b/pkg/session/session_store.go index 1d1a2f967..2ba2a974d 100644 --- a/pkg/session/session_store.go +++ b/pkg/session/session_store.go @@ -27,6 +27,8 @@ type SessionStore interface { TruncateHistory(key string, keepLast int) // Save persists any pending state to durable storage. Save(key string) error + // ListSessions returns all known session keys. + ListSessions() []string // Close releases resources held by the store. Close() error } diff --git a/pkg/tokenizer/estimator.go b/pkg/tokenizer/estimator.go new file mode 100644 index 000000000..3265edaa8 --- /dev/null +++ b/pkg/tokenizer/estimator.go @@ -0,0 +1,91 @@ +package tokenizer + +import ( + "encoding/json" + "unicode/utf8" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// EstimateMessageTokens estimates the token count for a single message, +// including Content, ReasoningContent, ToolCalls arguments, ToolCallID +// metadata, and Media items. Uses a heuristic of 2.5 characters per token. +func EstimateMessageTokens(msg providers.Message) int { + contentChars := utf8.RuneCountInString(msg.Content) + + // SystemParts are structured system blocks used for cache-aware adapters. + // They carry the same content as Content, but in multiple blocks. + // We estimate them as an alternative representation, not additive. + systemPartsChars := 0 + if len(msg.SystemParts) > 0 { + for _, part := range msg.SystemParts { + systemPartsChars += utf8.RuneCountInString(part.Text) + } + // Per-part overhead for JSON structure (type, text, cache_control). + const perPartOverhead = 20 + systemPartsChars += len(msg.SystemParts) * perPartOverhead + } + + // Use the larger of the two representations to stay conservative. + chars := contentChars + if systemPartsChars > chars { + chars = systemPartsChars + } + + chars += utf8.RuneCountInString(msg.ReasoningContent) + + for _, tc := range msg.ToolCalls { + chars += len(tc.ID) + len(tc.Type) + if tc.Function != nil { + // Count function name + arguments (the wire format for most providers). + // tc.Name mirrors tc.Function.Name — count only once to avoid double-counting. + chars += len(tc.Function.Name) + len(tc.Function.Arguments) + } else { + // Fallback: some provider formats use top-level Name without Function. + chars += len(tc.Name) + } + } + + if msg.ToolCallID != "" { + chars += len(msg.ToolCallID) + } + + // Per-message overhead for role label, JSON structure, separators. + const messageOverhead = 12 + chars += messageOverhead + + tokens := chars * 2 / 5 + + // Media items (images, files) are serialized by provider adapters into + // multipart or image_url payloads. Add a fixed per-item token estimate + // directly (not through the chars heuristic) since actual cost depends + // on resolution and provider-specific image tokenization. + const mediaTokensPerItem = 256 + tokens += len(msg.Media) * mediaTokensPerItem + + return tokens +} + +// EstimateToolDefsTokens estimates the total token cost of tool definitions +// as they appear in the LLM request. +func EstimateToolDefsTokens(defs []providers.ToolDefinition) int { + if len(defs) == 0 { + return 0 + } + + totalChars := 0 + for _, d := range defs { + totalChars += len(d.Function.Name) + len(d.Function.Description) + + if d.Function.Parameters != nil { + if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil { + totalChars += len(paramJSON) + } + } + + // Per-tool overhead: type field, JSON structure, separators. + totalChars += 20 + } + + return totalChars * 2 / 5 +}