mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
feat(seahorse): implement short-term memory engine (LCM) (#2285)
* feat(seahorse): implement short-term memory engine of seahorse Add pkg/seahorse/ module implementing a SQLite-backed DAG-based summary hierarchy for context management, ported from lossless-claw's LCM design: - types.go + short_constants.go: core types (Message, Summary, Conversation, ContextItem) and configuration constants (fanout, token targets, thresholds) - migration.go: idempotent DB schema with FTS5 trigram tokenizer for CJK - store.go: full SQLite CRUD (conversations, messages, summaries DAG, context_items with ordinal gap numbering, FTS5 search) - short_engine.go: Engine lifecycle (NewEngine, Ingest, Assemble, Compact), session pattern filtering (ignore/stateless glob→regex compilation), per-session mutex via sync.Map - short_assembler.go: budget-aware context assembly with fresh tail protection (32 messages), oldest-first eviction, summary XML formatting, RebuildContextItems - short_compaction.go: leaf compaction (messages→summary) and condensed compaction (summaries→higher-level summary), 3-level LLM escalation, CompactUntilUnder for emergency overflow - short_retrieval.go: lookupByID, FTS5/LIKE search, recursive expand with token cap - context_seahorse.go: agent.ContextManager adapter, registered as "seahorse", provider↔seahorse message type conversion (ToolCalls, tool_result) * fix(seahorse): correct 3 adapter bugs in context management - TokenCount: use full message (Content+ToolCalls+Media) instead of Content-only - Empty Content: rebuild Content from tool_result Parts when stored empty - Duplicate summaries: summaries only in Summary field, not in History messages - Grep: fix SearchResult.Snippet→Content for summaries - Schema: fix FTS5 SQL uses VIRTUAL TABLE not TEMP TABLE - TestFTS5SQLConstants: verify FTS5 SQL syntax correctness - Test: fix flaky TestCompactLeaf * fix(agent): ingest steering messages into seahorse SQLite Steering messages were only persisted to session JSONL but not ingested into seahorse SQLite, causing them to be missing from context assembly. Added `ts.ingestMessage(turnCtx, al, pm)` call in the steering message injection block alongside the existing JSONL persistence. Test: TestSeahorseSteeringMessageIngested verifies steering messages appear in seahorse SQLite DB after being processed. * fix(seahorse): address 3 blocking bugs from code review - Fix resequenceContextItemsTx scan error handling (store.go:850) Changed `return err` to `return scanErr` to properly propagate scan errors instead of returning nil (which silently corrupts data) - Fix sql.NullString for INTEGER column (store.go:847) Changed `mid` from sql.NullString to sql.NullInt64 since message_id is INTEGER in schema. Removed unnecessary strconv.ParseInt call. - Fix compactCondensed fallback deleting non-candidate items Added ReplaceContextItemsWithSummary method for per-item deletion when candidates are not contiguous in ordinal space. Optimized to use range deletion when candidates are consecutive. * fix(seahorse): pass Budget to Compact for correct condensed threshold Issue #4 from PR review: When Budget was not passed to seahorse.Compact, it defaulted to `tokensBefore * 0.75`, making `tokensBefore > budget` always true and causing condensed compaction to trigger unnecessarily. Changes: - context_seahorse.go: Forward Budget from CompactRequest to CompactInput - loop.go: Pass Budget (ContextWindow) in all 3 Compact calls - Add test verifying condensed is skipped when tokens < threshold - Fix lint issues in store.go and store_test.go * fix(seahorse): add mutex for assembler lazy initialization Issue #5 from PR review: The check-then-create pattern for e.assembler was a data race when multiple goroutines called Assemble() concurrently: if e.assembler == nil { e.assembler = &Assembler{...} } Changes: - Add assemblerMu sync.Mutex to Engine struct - Add initAssemblerOnce() using double-checked locking (same pattern as initCompactionOnce) - Add TestAssemblerLazyInitRace to verify thread-safety * fix(seahorse): handle non-consecutive depths in selectShallowestCondensationCandidate Issue #8 from PR review: the loop iterated depth 0, 1, 2... assuming consecutive keys, but break when key was missing caused deeper depths to never be checked. Fix: collect all existing depth keys, sort, then iterate in order. * fix(seahorse): wrap DeleteMessagesAfterID and appendContextItems in transactions - DeleteMessagesAfterID: wrap all DELETE operations in a transaction for atomicity, remove redundant manual FTS delete (handled by trigger) - appendContextItems: use transaction to fix read-then-write race condition - Add GetMaxOrdinalTx and resolveItemTokenCountTx for transaction-scoped queries - Remove unused resolveItemTokenCount function Fixes PR review issues 6 and 7. * fix(seahorse): derive readable content from Parts and cap CompactUntilUnder iterations - Derive readable content from MessageParts in AddMessageWithParts so FTS5 indexing and summary formatting can access tool call information - formatMessagesForSummary and truncateSummary now fall back to Parts when Content is empty, fixing blank summaries for Part-based messages - Add MaxCompactIterations (20) to prevent CompactUntilUnder infinite loops; exceeded iterations are logged as warnings
This commit is contained in:
@@ -67,3 +67,5 @@ web/backend/dist/*
|
|||||||
.claude/
|
.claude/
|
||||||
|
|
||||||
docker/data
|
docker/data
|
||||||
|
|
||||||
|
.omc/
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ linters:
|
|||||||
- exhaustruct
|
- exhaustruct
|
||||||
- funcorder
|
- funcorder
|
||||||
- gochecknoglobals
|
- gochecknoglobals
|
||||||
|
- gosmopolitan # Project legitimately uses CJK text in tests (FTS5, token counting)
|
||||||
- godot
|
- godot
|
||||||
- intrange
|
- intrange
|
||||||
- ireturn
|
- ireturn
|
||||||
|
|||||||
+11
-85
@@ -6,10 +6,8 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
||||||
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// estimateMessageTokens estimates the token count for a single message,
|
// EstimateMessageTokens estimates the token count for a single message.
|
||||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
// Delegates to the shared tokenizer package for consistency across agent and seahorse.
|
||||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
func EstimateMessageTokens(msg providers.Message) int {
|
||||||
func estimateMessageTokens(msg providers.Message) int {
|
return tokenizer.EstimateMessageTokens(msg)
|
||||||
contentChars := utf8.RuneCountInString(msg.Content)
|
|
||||||
|
|
||||||
// SystemParts are structured system blocks used for cache-aware adapters.
|
|
||||||
// They carry the same content as Content, but in multiple blocks.
|
|
||||||
// We estimate them as an alternative representation, not additive.
|
|
||||||
systemPartsChars := 0
|
|
||||||
if len(msg.SystemParts) > 0 {
|
|
||||||
for _, part := range msg.SystemParts {
|
|
||||||
systemPartsChars += utf8.RuneCountInString(part.Text)
|
|
||||||
}
|
|
||||||
// Per-part overhead for JSON structure (type, text, cache_control).
|
|
||||||
const perPartOverhead = 20
|
|
||||||
systemPartsChars += len(msg.SystemParts) * perPartOverhead
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use the larger of the two representations to stay conservative.
|
|
||||||
chars := contentChars
|
|
||||||
if systemPartsChars > chars {
|
|
||||||
chars = systemPartsChars
|
|
||||||
}
|
|
||||||
|
|
||||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
|
||||||
|
|
||||||
for _, tc := range msg.ToolCalls {
|
|
||||||
chars += len(tc.ID) + len(tc.Type)
|
|
||||||
if tc.Function != nil {
|
|
||||||
// Count function name + arguments (the wire format for most providers).
|
|
||||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
|
||||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
|
||||||
} else {
|
|
||||||
// Fallback: some provider formats use top-level Name without Function.
|
|
||||||
chars += len(tc.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if msg.ToolCallID != "" {
|
|
||||||
chars += len(msg.ToolCallID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per-message overhead for role label, JSON structure, separators.
|
|
||||||
const messageOverhead = 12
|
|
||||||
chars += messageOverhead
|
|
||||||
|
|
||||||
tokens := chars * 2 / 5
|
|
||||||
|
|
||||||
// Media items (images, files) are serialized by provider adapters into
|
|
||||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
|
||||||
// directly (not through the chars heuristic) since actual cost depends
|
|
||||||
// on resolution and provider-specific image tokenization.
|
|
||||||
const mediaTokensPerItem = 256
|
|
||||||
tokens += len(msg.Media) * mediaTokensPerItem
|
|
||||||
|
|
||||||
return tokens
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
// EstimateToolDefsTokens estimates the total token cost of tool definitions
|
||||||
// as they appear in the LLM request. Each tool's name, description, and
|
// as they appear in the LLM request. Delegates to the shared tokenizer package.
|
||||||
// JSON schema parameters contribute to the context window budget.
|
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
return tokenizer.EstimateToolDefsTokens(defs)
|
||||||
if len(defs) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
totalChars := 0
|
|
||||||
for _, d := range defs {
|
|
||||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
|
||||||
|
|
||||||
if d.Function.Parameters != nil {
|
|
||||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
|
||||||
totalChars += len(paramJSON)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Per-tool overhead: type field, JSON structure, separators.
|
|
||||||
totalChars += 20
|
|
||||||
}
|
|
||||||
|
|
||||||
return totalChars * 2 / 5
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||||
@@ -181,10 +107,10 @@ func isOverContextBudget(
|
|||||||
) bool {
|
) bool {
|
||||||
msgTokens := 0
|
msgTokens := 0
|
||||||
for _, m := range messages {
|
for _, m := range messages {
|
||||||
msgTokens += estimateMessageTokens(m)
|
msgTokens += EstimateMessageTokens(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
toolTokens := EstimateToolDefsTokens(toolDefs)
|
||||||
total := msgTokens + toolTokens + maxTokens
|
total := msgTokens + toolTokens + maxTokens
|
||||||
|
|
||||||
return total > contextWindow
|
return total > contextWindow
|
||||||
|
|||||||
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := estimateMessageTokens(tt.msg)
|
got := EstimateMessageTokens(tt.msg)
|
||||||
if got < tt.want {
|
if got < tt.want {
|
||||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
plainTokens := estimateMessageTokens(plain)
|
plainTokens := EstimateMessageTokens(plain)
|
||||||
withTCTokens := estimateMessageTokens(withTC)
|
withTCTokens := EstimateMessageTokens(withTC)
|
||||||
|
|
||||||
if withTCTokens <= plainTokens {
|
if withTCTokens <= plainTokens {
|
||||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||||
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
|||||||
// but may map to different token counts. The heuristic should still produce
|
// but may map to different token counts. The heuristic should still produce
|
||||||
// reasonable estimates via RuneCountInString.
|
// reasonable estimates via RuneCountInString.
|
||||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||||
tokens := estimateMessageTokens(msg)
|
tokens := EstimateMessageTokens(msg)
|
||||||
if tokens <= 0 {
|
if tokens <= 0 {
|
||||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||||
}
|
}
|
||||||
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens := estimateMessageTokens(msg)
|
tokens := EstimateMessageTokens(msg)
|
||||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||||
if tokens < 2000 {
|
if tokens < 2000 {
|
||||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||||
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
|||||||
ReasoningContent: strings.Repeat("thinking step ", 200),
|
ReasoningContent: strings.Repeat("thinking step ", 200),
|
||||||
}
|
}
|
||||||
|
|
||||||
plainTokens := estimateMessageTokens(plain)
|
plainTokens := EstimateMessageTokens(plain)
|
||||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
reasoningTokens := EstimateMessageTokens(withReasoning)
|
||||||
|
|
||||||
if reasoningTokens <= plainTokens {
|
if reasoningTokens <= plainTokens {
|
||||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||||
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
|
|||||||
Media: []string{"media://img1.png", "media://img2.png"},
|
Media: []string{"media://img1.png", "media://img2.png"},
|
||||||
}
|
}
|
||||||
|
|
||||||
plainTokens := estimateMessageTokens(plain)
|
plainTokens := EstimateMessageTokens(plain)
|
||||||
mediaTokens := estimateMessageTokens(withMedia)
|
mediaTokens := EstimateMessageTokens(withMedia)
|
||||||
|
|
||||||
if mediaTokens <= plainTokens {
|
if mediaTokens <= plainTokens {
|
||||||
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
||||||
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
plainTokens := estimateMessageTokens(plain)
|
plainTokens := EstimateMessageTokens(plain)
|
||||||
partsTokens := estimateMessageTokens(withParts)
|
partsTokens := EstimateMessageTokens(withParts)
|
||||||
|
|
||||||
if partsTokens <= plainTokens {
|
if partsTokens <= plainTokens {
|
||||||
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
|
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
|
||||||
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- estimateToolDefsTokens tests ---
|
// --- EstimateToolDefsTokens tests ---
|
||||||
|
|
||||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got := estimateToolDefsTokens(tt.defs)
|
got := EstimateToolDefsTokens(tt.defs)
|
||||||
if got < tt.want {
|
if got < tt.want {
|
||||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
three := EstimateToolDefsTokens([]providers.ToolDefinition{
|
||||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens := estimateMessageTokens(msg)
|
tokens := EstimateMessageTokens(msg)
|
||||||
|
|
||||||
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
||||||
// Content + TC + overhead adds more. Should be well above 500.
|
// Content + TC + overhead adds more. Should be well above 500.
|
||||||
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
|||||||
// Compare without reasoning to ensure it's counted.
|
// Compare without reasoning to ensure it's counted.
|
||||||
msgNoReasoning := msg
|
msgNoReasoning := msg
|
||||||
msgNoReasoning.ReasoningContent = ""
|
msgNoReasoning.ReasoningContent = ""
|
||||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
|
||||||
|
|
||||||
if tokens <= tokensNoReasoning {
|
if tokens <= tokensNoReasoning {
|
||||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||||
|
|||||||
@@ -373,7 +373,7 @@ func (m *legacyContextManager) summarizeBatch(
|
|||||||
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
|
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
|
||||||
total := 0
|
total := 0
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
total += estimateMessageTokens(msg)
|
total += EstimateMessageTokens(msg)
|
||||||
}
|
}
|
||||||
return total
|
return total
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ type AssembleResponse struct {
|
|||||||
type CompactRequest struct {
|
type CompactRequest struct {
|
||||||
SessionKey string // session identifier
|
SessionKey string // session identifier
|
||||||
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
|
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
|
||||||
|
Budget int // context window budget (used for retry aggressive compaction)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IngestRequest is the input to Ingest.
|
// IngestRequest is the input to Ingest.
|
||||||
|
|||||||
@@ -0,0 +1,267 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/session"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
|
||||||
|
type seahorseContextManager struct {
|
||||||
|
engine *seahorse.Engine
|
||||||
|
sessions session.SessionStore // for startup bootstrap
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSeahorseContextManager creates a seahorse-backed ContextManager.
|
||||||
|
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||||
|
if al == nil {
|
||||||
|
return nil, fmt.Errorf("seahorse: AgentLoop is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve workspace for DB path
|
||||||
|
// DB stores session data, so it goes in sessions/ directory
|
||||||
|
agent := al.registry.GetDefaultAgent()
|
||||||
|
dbPath := agent.Workspace + "/sessions/seahorse.db"
|
||||||
|
|
||||||
|
// Create CompleteFn from provider
|
||||||
|
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
|
||||||
|
|
||||||
|
// Create engine
|
||||||
|
engine, err := seahorse.NewEngine(seahorse.Config{
|
||||||
|
DBPath: dbPath,
|
||||||
|
}, completeFn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("seahorse: create engine: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &seahorseContextManager{
|
||||||
|
engine: engine,
|
||||||
|
sessions: agent.Sessions,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register seahorse tools with the agent's tool registry
|
||||||
|
retrieval := mgr.engine.GetRetrieval()
|
||||||
|
al.RegisterTool(seahorse.NewGrepTool(retrieval))
|
||||||
|
al.RegisterTool(seahorse.NewExpandTool(retrieval))
|
||||||
|
|
||||||
|
// Bootstrap all existing sessions at startup
|
||||||
|
if agent.Sessions != nil {
|
||||||
|
ctx := context.Background()
|
||||||
|
for _, sessionKey := range agent.Sessions.ListSessions() {
|
||||||
|
mgr.bootstrapSession(ctx, sessionKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
|
||||||
|
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
|
||||||
|
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
|
||||||
|
resp, err := provider.Chat(
|
||||||
|
ctx,
|
||||||
|
[]providers.Message{{Role: "user", Content: prompt}},
|
||||||
|
nil, // no tools for summarization
|
||||||
|
model,
|
||||||
|
map[string]any{
|
||||||
|
"max_tokens": opts.MaxTokens,
|
||||||
|
"temperature": opts.Temperature,
|
||||||
|
"prompt_cache_key": "seahorse",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return resp.Content, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assemble builds budget-aware context from seahorse SQLite.
|
||||||
|
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("seahorse assemble: nil request")
|
||||||
|
}
|
||||||
|
|
||||||
|
budget := req.Budget
|
||||||
|
if budget <= 0 {
|
||||||
|
budget = 100000
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reserve space for model response (spec lines 1400-1410)
|
||||||
|
effectiveBudget := budget - req.MaxTokens
|
||||||
|
if effectiveBudget <= 0 {
|
||||||
|
// MaxTokens >= budget is a configuration problem
|
||||||
|
// Use 50% as minimum to avoid guaranteed overflow
|
||||||
|
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
|
||||||
|
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
|
||||||
|
effectiveBudget = budget / 2
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
|
||||||
|
Budget: effectiveBudget,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("seahorse assemble: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
history := seahorseToProviderMessages(result)
|
||||||
|
|
||||||
|
// Summary is already formatted as XML with system prompt addition by assembler
|
||||||
|
return &AssembleResponse{
|
||||||
|
History: history,
|
||||||
|
Summary: result.Summary,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact compresses conversation history via seahorse summarization.
|
||||||
|
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
|
||||||
|
// context shrinks below budget (spec lines ~1410).
|
||||||
|
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
|
||||||
|
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
|
||||||
|
Force: req.Reason == ContextCompressReasonRetry,
|
||||||
|
Budget: &req.Budget,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ingest records a message into seahorse SQLite.
|
||||||
|
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
|
||||||
|
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := providerToSeahorseMessage(req.Message)
|
||||||
|
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
|
||||||
|
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
|
||||||
|
if m.sessions == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
history := m.sessions.GetHistory(sessionKey)
|
||||||
|
if len(history) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert provider messages to seahorse messages
|
||||||
|
msgs := make([]seahorse.Message, len(history))
|
||||||
|
for i, h := range history {
|
||||||
|
msgs[i] = providerToSeahorseMessage(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
|
||||||
|
logger.WarnCF("seahorse", "bootstrap", map[string]any{
|
||||||
|
"session": sessionKey,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
|
||||||
|
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
|
||||||
|
result := seahorse.Message{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: msg.Content,
|
||||||
|
ReasoningContent: msg.ReasoningContent,
|
||||||
|
TokenCount: tokenizer.EstimateMessageTokens(msg),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert ToolCalls → MessageParts
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
part := seahorse.MessagePart{
|
||||||
|
Type: "tool_use",
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: tc.Function.Arguments,
|
||||||
|
ToolCallID: tc.ID,
|
||||||
|
}
|
||||||
|
result.Parts = append(result.Parts, part)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert tool result
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
part := seahorse.MessagePart{
|
||||||
|
Type: "tool_result",
|
||||||
|
ToolCallID: msg.ToolCallID,
|
||||||
|
Text: msg.Content,
|
||||||
|
}
|
||||||
|
result.Parts = append(result.Parts, part)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert media attachments
|
||||||
|
for _, mediaURI := range msg.Media {
|
||||||
|
part := seahorse.MessagePart{
|
||||||
|
Type: "media",
|
||||||
|
MediaURI: mediaURI,
|
||||||
|
}
|
||||||
|
result.Parts = append(result.Parts, part)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
|
||||||
|
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
|
||||||
|
messages := make([]protocoltypes.Message, 0, len(result.Messages))
|
||||||
|
|
||||||
|
// Convert assembled messages (which already include summary XML messages)
|
||||||
|
for _, msg := range result.Messages {
|
||||||
|
pm := protocoltypes.Message{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: msg.Content,
|
||||||
|
ReasoningContent: msg.ReasoningContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct ToolCalls from parts
|
||||||
|
for _, part := range msg.Parts {
|
||||||
|
if part.Type == "tool_use" {
|
||||||
|
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
|
||||||
|
ID: part.ToolCallID,
|
||||||
|
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
|
||||||
|
Function: &protocoltypes.FunctionCall{
|
||||||
|
Name: part.Name,
|
||||||
|
Arguments: part.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if part.Type == "tool_result" {
|
||||||
|
pm.ToolCallID = part.ToolCallID
|
||||||
|
if pm.Content == "" && part.Text != "" {
|
||||||
|
pm.Content = part.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if part.Type == "media" && part.MediaURI != "" {
|
||||||
|
pm.Media = append(pm.Media, part.MediaURI)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, pm)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
|
||||||
|
panic(fmt.Sprintf("register seahorse context manager: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
+5
-1
@@ -1742,6 +1742,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
|||||||
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
|
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||||
SessionKey: ts.sessionKey,
|
SessionKey: ts.sessionKey,
|
||||||
Reason: ContextCompressReasonProactive,
|
Reason: ContextCompressReasonProactive,
|
||||||
|
Budget: ts.agent.ContextWindow,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
||||||
"session_key": ts.sessionKey,
|
"session_key": ts.sessionKey,
|
||||||
@@ -1857,6 +1858,7 @@ turnLoop:
|
|||||||
if !ts.opts.NoHistory {
|
if !ts.opts.NoHistory {
|
||||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
||||||
ts.recordPersistedMessage(pm)
|
ts.recordPersistedMessage(pm)
|
||||||
|
ts.ingestMessage(turnCtx, al, pm)
|
||||||
}
|
}
|
||||||
logger.InfoCF("agent", "Injected steering message into context",
|
logger.InfoCF("agent", "Injected steering message into context",
|
||||||
map[string]any{
|
map[string]any{
|
||||||
@@ -2128,6 +2130,7 @@ turnLoop:
|
|||||||
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
|
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||||
SessionKey: ts.sessionKey,
|
SessionKey: ts.sessionKey,
|
||||||
Reason: ContextCompressReasonRetry,
|
Reason: ContextCompressReasonRetry,
|
||||||
|
Budget: ts.agent.ContextWindow,
|
||||||
}); compactErr != nil {
|
}); compactErr != nil {
|
||||||
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
||||||
"session_key": ts.sessionKey,
|
"session_key": ts.sessionKey,
|
||||||
@@ -2773,7 +2776,7 @@ turnLoop:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if ts.opts.EnableSummary {
|
if ts.opts.EnableSummary {
|
||||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize})
|
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.setPhase(TurnPhaseCompleted)
|
ts.setPhase(TurnPhaseCompleted)
|
||||||
@@ -2849,6 +2852,7 @@ turnLoop:
|
|||||||
&CompactRequest{
|
&CompactRequest{
|
||||||
SessionKey: ts.sessionKey,
|
SessionKey: ts.sessionKey,
|
||||||
Reason: ContextCompressReasonSummarize,
|
Reason: ContextCompressReasonSummarize,
|
||||||
|
Budget: ts.agent.ContextWindow,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -604,6 +604,7 @@ type ephemeralSessionStoreIface interface {
|
|||||||
SetHistory(key string, history []providers.Message)
|
SetHistory(key string, history []providers.Message)
|
||||||
TruncateHistory(key string, keepLast int)
|
TruncateHistory(key string, keepLast int)
|
||||||
Save(key string) error
|
Save(key string) error
|
||||||
|
ListSessions() []string
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -663,8 +664,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
|
|||||||
e.history = e.history[len(e.history)-keepLast:]
|
e.history = e.history[len(e.history)-keepLast:]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||||
|
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
|
||||||
|
|
||||||
func (e *ephemeralSessionStore) truncateLocked() {
|
func (e *ephemeralSessionStore) truncateLocked() {
|
||||||
if len(e.history) > maxEphemeralHistorySize {
|
if len(e.history) > maxEphemeralHistorySize {
|
||||||
|
|||||||
@@ -455,6 +455,33 @@ func (s *JSONLStore) rewriteJSONL(
|
|||||||
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
|
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListSessions returns all known session keys by reading .meta.json files.
|
||||||
|
func (s *JSONLStore) ListSessions() []string {
|
||||||
|
entries, err := os.ReadDir(s.dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var keys []string
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Read the meta file to get the original key
|
||||||
|
data, err := os.ReadFile(filepath.Join(s.dir, entry.Name()))
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var meta sessionMeta
|
||||||
|
if err := json.Unmarshal(data, &meta); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if meta.Key != "" {
|
||||||
|
keys = append(keys, meta.Key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
func (s *JSONLStore) Close() error {
|
func (s *JSONLStore) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,9 @@ type Store interface {
|
|||||||
// data. Backends that do not accumulate dead data may return nil.
|
// data. Backends that do not accumulate dead data may return nil.
|
||||||
Compact(ctx context.Context, sessionKey string) error
|
Compact(ctx context.Context, sessionKey string) error
|
||||||
|
|
||||||
|
// ListSessions returns all known session keys.
|
||||||
|
ListSessions() []string
|
||||||
|
|
||||||
// Close releases any resources held by the store.
|
// Close releases any resources held by the store.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"tool_name": "Bash",
|
||||||
|
"tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}",
|
||||||
|
"error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1",
|
||||||
|
"timestamp": "2026-04-04T02:38:32.067Z",
|
||||||
|
"retry_count": 6
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// CompactUntilUnder iteration cap
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestCompactUntilUnderIterationCap(t *testing.T) {
|
||||||
|
// Setup: create a conversation with so many tokens that compaction
|
||||||
|
// will never reach the budget. The iteration cap prevents infinite loops.
|
||||||
|
//
|
||||||
|
// We use a mock CompleteFn that always returns the same content,
|
||||||
|
// and a budget of 0 which tokens can never reach.
|
||||||
|
// Without the cap, this would loop forever.
|
||||||
|
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("migration: %v", err)
|
||||||
|
}
|
||||||
|
s := &Store{db: db}
|
||||||
|
|
||||||
|
conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
// Add many messages to ensure there's plenty to compact
|
||||||
|
for i := 0; i < 40; i++ {
|
||||||
|
m, _ := s.AddMessage(context.Background(), convID, "user",
|
||||||
|
"this is a long message with lots of tokens to push context over budget", 100)
|
||||||
|
s.AppendContextMessage(context.Background(), convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A completeFn that always succeeds but returns non-reducing content
|
||||||
|
mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
return "Summary that doesn't reduce tokens much.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ce, cancel := newTestCompactionEngineWithStore(s, mockComplete)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Use budget=1 so tokens can never reach budget
|
||||||
|
// (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable)
|
||||||
|
// The function should stop after maxCompactIterations, not loop forever
|
||||||
|
ce.config = Config{} // ensure defaults
|
||||||
|
|
||||||
|
result, err := ce.CompactUntilUnder(context.Background(), convID, 1)
|
||||||
|
if err != nil {
|
||||||
|
// Should not error — should stop gracefully
|
||||||
|
t.Fatalf("CompactUntilUnder with budget=0: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The function should have completed within reasonable time
|
||||||
|
// If it exceeded the cap, it would still return (not hang)
|
||||||
|
_ = result
|
||||||
|
}
|
||||||
@@ -0,0 +1,144 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Bug 1: formatMessagesForSummary ignores Parts
|
||||||
|
// - formatMessagesForSummary only reads m.Content, empty for Part-based messages
|
||||||
|
// - truncateSummary has same issue
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestFormatMessagesForSummaryIncludesParts(t *testing.T) {
|
||||||
|
ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
messages := []Message{
|
||||||
|
{ID: 1, Role: "user", Content: "hello world", CreatedAt: ts},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "", // empty — real content is in Parts
|
||||||
|
Parts: []MessagePart{
|
||||||
|
{Type: "text", Text: "I will run a command"},
|
||||||
|
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"},
|
||||||
|
},
|
||||||
|
CreatedAt: ts.Add(time.Minute),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 3,
|
||||||
|
Role: "tool",
|
||||||
|
Content: "", // empty — real content is in Parts
|
||||||
|
Parts: []MessagePart{
|
||||||
|
{Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
|
||||||
|
},
|
||||||
|
CreatedAt: ts.Add(2 * time.Minute),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := formatMessagesForSummary(messages)
|
||||||
|
|
||||||
|
// Must contain the plain text message
|
||||||
|
if !contains(result, "hello world") {
|
||||||
|
t.Error("formatMessagesForSummary: missing plain text content")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must contain tool_use info (not blank)
|
||||||
|
if !contains(result, "bash") || !contains(result, "ls -la") {
|
||||||
|
t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must contain tool_result info (not blank)
|
||||||
|
if !contains(result, "file1.txt") {
|
||||||
|
t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateSummaryIncludesParts(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "", // empty
|
||||||
|
Parts: []MessagePart{
|
||||||
|
{Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"},
|
||||||
|
},
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 3,
|
||||||
|
Role: "tool",
|
||||||
|
Content: "", // empty
|
||||||
|
Parts: []MessagePart{
|
||||||
|
{Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"},
|
||||||
|
},
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := truncateSummary(messages)
|
||||||
|
|
||||||
|
// Must contain plain text
|
||||||
|
if !contains(result, "run the tests") {
|
||||||
|
t.Error("truncateSummary: missing plain text content")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must contain tool info from Parts (not blank)
|
||||||
|
if !contains(result, "bash") || !contains(result, "go test") {
|
||||||
|
t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must contain tool_result from Parts
|
||||||
|
if !contains(result, "PASS") {
|
||||||
|
t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Bug 2: SearchMessages cannot find Part-based messages
|
||||||
|
// - FTS5 indexes empty content, LIKE queries empty content
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestSearchMessagesFindsPartBasedMessages(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
// Add a plain message (searchable)
|
||||||
|
s.AddMessage(ctx, convID, "user", "list the files please", 5)
|
||||||
|
|
||||||
|
// Add a Part-based message (tool_use) — currently NOT searchable
|
||||||
|
parts := []MessagePart{
|
||||||
|
{Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
s.AddMessageWithParts(ctx, convID, "assistant", parts, 10)
|
||||||
|
|
||||||
|
// Add a Part-based message (tool_result) — currently NOT searchable
|
||||||
|
resultParts := []MessagePart{
|
||||||
|
{Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10)
|
||||||
|
|
||||||
|
// Search for "grep" — should find the tool_use message
|
||||||
|
results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SearchMessages: %v", err)
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search for "TODO fix" — should find the tool_result message
|
||||||
|
results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SearchMessages: %v", err)
|
||||||
|
}
|
||||||
|
if len(results2) == 0 {
|
||||||
|
t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQL statements for FTS5 tables with trigram tokenizer.
|
||||||
|
const (
|
||||||
|
sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5(
|
||||||
|
summary_id,
|
||||||
|
content,
|
||||||
|
tokenize="trigram"
|
||||||
|
)`
|
||||||
|
sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
|
||||||
|
message_id,
|
||||||
|
content,
|
||||||
|
tokenize="trigram"
|
||||||
|
)`
|
||||||
|
sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)`
|
||||||
|
sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")`
|
||||||
|
sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check`
|
||||||
|
sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check`
|
||||||
|
)
|
||||||
|
|
||||||
|
// runSchema creates or upgrades the database schema.
|
||||||
|
// All schemas are idempotent (safe to run multiple times).
|
||||||
|
func runSchema(db *sql.DB) error {
|
||||||
|
// Check FTS5 support before creating tables
|
||||||
|
if err := checkFTS5Support(db); err != nil {
|
||||||
|
return fmt.Errorf("FTS5 check: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stmts := []string{
|
||||||
|
`CREATE TABLE IF NOT EXISTS conversations (
|
||||||
|
conversation_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
session_key TEXT NOT NULL UNIQUE,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||||
|
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||||
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE IF NOT EXISTS messages (
|
||||||
|
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||||
|
role TEXT NOT NULL,
|
||||||
|
content TEXT NOT NULL DEFAULT '',
|
||||||
|
token_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||||
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE IF NOT EXISTS message_parts (
|
||||||
|
part_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
message_id INTEGER NOT NULL REFERENCES messages(message_id),
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
text TEXT,
|
||||||
|
name TEXT,
|
||||||
|
arguments TEXT,
|
||||||
|
tool_call_id TEXT,
|
||||||
|
media_uri TEXT,
|
||||||
|
mime_type TEXT,
|
||||||
|
ordinal INTEGER NOT NULL DEFAULT 0
|
||||||
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE IF NOT EXISTS summaries (
|
||||||
|
summary_id TEXT PRIMARY KEY,
|
||||||
|
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||||
|
kind TEXT NOT NULL,
|
||||||
|
depth INTEGER NOT NULL DEFAULT 0,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
token_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
earliest_at TEXT,
|
||||||
|
latest_at TEXT,
|
||||||
|
descendant_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
descendant_token_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
source_message_token_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
model TEXT,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||||
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE IF NOT EXISTS summary_parents (
|
||||||
|
summary_id TEXT NOT NULL,
|
||||||
|
parent_summary_id TEXT NOT NULL,
|
||||||
|
PRIMARY KEY (summary_id, parent_summary_id)
|
||||||
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE IF NOT EXISTS summary_messages (
|
||||||
|
summary_id TEXT NOT NULL,
|
||||||
|
message_id INTEGER NOT NULL,
|
||||||
|
ordinal INTEGER NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY (summary_id, message_id)
|
||||||
|
)`,
|
||||||
|
|
||||||
|
`CREATE TABLE IF NOT EXISTS context_items (
|
||||||
|
conversation_id INTEGER NOT NULL,
|
||||||
|
ordinal INTEGER NOT NULL,
|
||||||
|
item_type TEXT NOT NULL,
|
||||||
|
summary_id TEXT,
|
||||||
|
message_id INTEGER,
|
||||||
|
token_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||||
|
PRIMARY KEY (conversation_id, ordinal)
|
||||||
|
)`,
|
||||||
|
|
||||||
|
// FTS5 virtual table with trigram tokenizer for CJK support
|
||||||
|
sqlCreateSummariesFTS,
|
||||||
|
|
||||||
|
// FTS5 virtual table for message search with trigram tokenizer
|
||||||
|
sqlCreateMessagesFTS,
|
||||||
|
|
||||||
|
// Indexes for common query patterns
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`,
|
||||||
|
|
||||||
|
// FTS5 triggers to keep summaries_fts in sync with summaries table
|
||||||
|
`CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN
|
||||||
|
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
|
||||||
|
END`,
|
||||||
|
`CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN
|
||||||
|
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
|
||||||
|
END`,
|
||||||
|
`CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN
|
||||||
|
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
|
||||||
|
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
|
||||||
|
END`,
|
||||||
|
|
||||||
|
// FTS5 triggers to keep messages_fts in sync with messages table
|
||||||
|
`CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
|
||||||
|
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
|
||||||
|
END`,
|
||||||
|
`CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
|
||||||
|
DELETE FROM messages_fts WHERE message_id = old.message_id;
|
||||||
|
END`,
|
||||||
|
`CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
|
||||||
|
DELETE FROM messages_fts WHERE message_id = old.message_id;
|
||||||
|
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
|
||||||
|
END`,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range stmts {
|
||||||
|
if _, err := db.Exec(s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled.
|
||||||
|
// This is required for full-text search with CJK (Chinese, Japanese, Korean) support.
|
||||||
|
func checkFTS5Support(db *sql.DB) error {
|
||||||
|
// Check if FTS5 is compiled in
|
||||||
|
var fts5Enabled int
|
||||||
|
err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled)
|
||||||
|
if err != nil {
|
||||||
|
// sqlite_compileoption_used might not exist in older SQLite
|
||||||
|
// Try a different approach: create a test FTS5 table
|
||||||
|
_, testErr := db.Exec(sqlCheckFTS5Available)
|
||||||
|
if testErr != nil {
|
||||||
|
return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr)
|
||||||
|
}
|
||||||
|
db.Exec(sqlDropFTS5Check)
|
||||||
|
} else if fts5Enabled == 0 {
|
||||||
|
return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if trigram tokenizer is available by trying to create a test table
|
||||||
|
// Not all SQLite builds include the trigram tokenizer
|
||||||
|
_, err = db.Exec(sqlCheckTrigramAvailable)
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited",
|
||||||
|
map[string]any{"error": err.Error()})
|
||||||
|
// Trigram is not strictly required, just better for CJK
|
||||||
|
// Don't return error, just log warning
|
||||||
|
} else {
|
||||||
|
db.Exec(sqlDropTrigramCheck)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openTestDB(t *testing.T) *sql.DB {
|
||||||
|
t.Helper()
|
||||||
|
db, err := sql.Open("sqlite", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("open test db: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { db.Close() })
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMigrations(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("runSchema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all tables exist
|
||||||
|
tables := []string{
|
||||||
|
"conversations",
|
||||||
|
"messages",
|
||||||
|
"message_parts",
|
||||||
|
"summaries",
|
||||||
|
"summary_parents",
|
||||||
|
"summary_messages",
|
||||||
|
"context_items",
|
||||||
|
}
|
||||||
|
for _, tbl := range tables {
|
||||||
|
var name string
|
||||||
|
err := db.QueryRow(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl,
|
||||||
|
).Scan(&name)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("table %q not found: %v", tbl, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify FTS5 virtual table exists
|
||||||
|
var ftsName string
|
||||||
|
err := db.QueryRow(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'",
|
||||||
|
).Scan(&ftsName)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("FTS5 table summaries_fts not found: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMigrationsIdempotent(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
|
||||||
|
// Run migrations twice — should succeed both times
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("first migration: %v", err)
|
||||||
|
}
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("second migration (idempotent): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we can still insert data after double migration
|
||||||
|
res, err := db.Exec(
|
||||||
|
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||||
|
"test-session",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert after double migration: %v", err)
|
||||||
|
}
|
||||||
|
id, _ := res.LastInsertId()
|
||||||
|
if id == 0 {
|
||||||
|
t.Error("expected non-zero conversation id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrationConversationUnique(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("migration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert first
|
||||||
|
_, err := db.Exec(
|
||||||
|
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||||
|
"unique-key",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first insert: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duplicate should fail
|
||||||
|
_, err = db.Exec(
|
||||||
|
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||||
|
"unique-key",
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected unique constraint violation for duplicate session_key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrationSummaryFTSInsert(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("migration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert a conversation first
|
||||||
|
_, err := db.Exec(
|
||||||
|
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||||
|
"fts-test",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert conversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert a summary
|
||||||
|
_, err = db.Exec(
|
||||||
|
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
|
||||||
|
VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert summary: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FTS should find it — trigram tokenizer requires >= 3 chars
|
||||||
|
rows, err := db.Query(
|
||||||
|
"SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?",
|
||||||
|
"你好世",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("FTS query: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var found string
|
||||||
|
if rows.Next() {
|
||||||
|
if err := rows.Scan(&found); err != nil {
|
||||||
|
t.Fatalf("scan: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
t.Fatalf("rows.Err: %v", err)
|
||||||
|
}
|
||||||
|
if found != "sum_test1" {
|
||||||
|
t.Errorf("FTS: expected 'sum_test1', got %q", found)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMigrationSummaryParentsPK(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("migration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert two summaries
|
||||||
|
for _, id := range []string{"sum_a", "sum_b"} {
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
|
||||||
|
VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert summary %s: %v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link child to parent
|
||||||
|
_, err := db.Exec(
|
||||||
|
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("link: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Duplicate link should fail (composite PK)
|
||||||
|
_, err = db.Exec(
|
||||||
|
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected unique constraint violation for duplicate summary_parents link")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFTS5SQLConstants(t *testing.T) {
|
||||||
|
db := openTestDB(t)
|
||||||
|
|
||||||
|
// Verify FTS5 check SQL executes without error
|
||||||
|
_, err := db.Exec(sqlCheckFTS5Available)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("sqlCheckFTS5Available failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify trigram check SQL executes without error
|
||||||
|
_, err = db.Exec(sqlCheckTrigramAvailable)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("sqlCheckTrigramAvailable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify summaries_fts SQL executes without error
|
||||||
|
_, err = db.Exec(sqlCreateSummariesFTS)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("sqlCreateSummariesFTS failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify messages_fts SQL executes without error
|
||||||
|
_, err = db.Exec(sqlCreateMessagesFTS)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("sqlCreateMessagesFTS failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,261 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// escapeXML escapes special characters for safe inclusion in XML content.
|
||||||
|
func escapeXML(s string) string {
|
||||||
|
s = strings.ReplaceAll(s, "&", "&")
|
||||||
|
s = strings.ReplaceAll(s, "<", "<")
|
||||||
|
s = strings.ReplaceAll(s, ">", ">")
|
||||||
|
s = strings.ReplaceAll(s, "\"", """)
|
||||||
|
s = strings.ReplaceAll(s, "'", "'")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvedItem is a context item resolved to its full content with token count.
|
||||||
|
type resolvedItem struct {
|
||||||
|
ordinal int
|
||||||
|
itemType string // "message" or "summary"
|
||||||
|
message *Message
|
||||||
|
summary *Summary
|
||||||
|
tokenCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assemble builds budget-constrained context from summaries + messages.
|
||||||
|
//
|
||||||
|
// Algorithm:
|
||||||
|
// 1. Fetch context_items, resolve to full content
|
||||||
|
// 2. Split into evictable prefix + protected fresh tail
|
||||||
|
// 3. If evictable fits in remaining budget → include all
|
||||||
|
// 4. Else walk evictable from newest to oldest, keep while fits
|
||||||
|
func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
|
||||||
|
items, err := a.store.GetContextItems(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get context items: %w", err)
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return &AssembleResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve all items
|
||||||
|
resolved := make([]resolvedItem, len(items))
|
||||||
|
for i, item := range items {
|
||||||
|
r, err := a.resolveItem(ctx, item)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
resolved[i] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split into evictable prefix and protected fresh tail
|
||||||
|
tailStart := len(resolved) - FreshTailCount
|
||||||
|
if tailStart < 0 {
|
||||||
|
tailStart = 0
|
||||||
|
}
|
||||||
|
evictable := resolved[:tailStart]
|
||||||
|
freshTail := resolved[tailStart:]
|
||||||
|
|
||||||
|
// Calculate fresh tail tokens
|
||||||
|
freshTailTokens := 0
|
||||||
|
for _, r := range freshTail {
|
||||||
|
freshTailTokens += r.tokenCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget-aware selection of evictable items
|
||||||
|
remainingBudget := input.Budget - freshTailTokens
|
||||||
|
if remainingBudget < 0 {
|
||||||
|
// Fresh tail alone exceeds budget - we keep it anyway (design decision)
|
||||||
|
// Log for debugging retry/overflow issues
|
||||||
|
logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
|
||||||
|
"budget": input.Budget,
|
||||||
|
"fresh_tail_tokens": freshTailTokens,
|
||||||
|
"fresh_tail_count": len(freshTail),
|
||||||
|
"over_budget_by": freshTailTokens - input.Budget,
|
||||||
|
})
|
||||||
|
remainingBudget = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var selected []resolvedItem
|
||||||
|
evictableTokens := 0
|
||||||
|
for _, r := range evictable {
|
||||||
|
evictableTokens += r.tokenCount
|
||||||
|
}
|
||||||
|
|
||||||
|
if evictableTokens <= remainingBudget {
|
||||||
|
// All evictable fit
|
||||||
|
selected = append(selected, evictable...)
|
||||||
|
} else {
|
||||||
|
// Walk from newest to oldest, keep while fits
|
||||||
|
var kept []resolvedItem
|
||||||
|
accum := 0
|
||||||
|
for i := len(evictable) - 1; i >= 0; i-- {
|
||||||
|
if accum+evictable[i].tokenCount <= remainingBudget {
|
||||||
|
kept = append(kept, evictable[i])
|
||||||
|
accum += evictable[i].tokenCount
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Reverse to restore chronological order
|
||||||
|
for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
|
||||||
|
kept[i], kept[j] = kept[j], kept[i]
|
||||||
|
}
|
||||||
|
selected = append(selected, kept...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine: selected evictable + fresh tail
|
||||||
|
final := append(selected, freshTail...)
|
||||||
|
|
||||||
|
// Build result
|
||||||
|
var messages []Message
|
||||||
|
var summaries []Summary
|
||||||
|
var sourceIDs []string
|
||||||
|
totalTokens := 0
|
||||||
|
maxDepth := 0
|
||||||
|
condensedCount := 0
|
||||||
|
|
||||||
|
for _, r := range final {
|
||||||
|
totalTokens += r.tokenCount
|
||||||
|
if r.itemType == "message" && r.message != nil {
|
||||||
|
messages = append(messages, *r.message)
|
||||||
|
sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
|
||||||
|
} else if r.itemType == "summary" && r.summary != nil {
|
||||||
|
summaries = append(summaries, *r.summary)
|
||||||
|
if r.summary.Depth > maxDepth {
|
||||||
|
maxDepth = r.summary.Depth
|
||||||
|
}
|
||||||
|
if r.summary.Kind == SummaryKindCondensed {
|
||||||
|
condensedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build depth-aware system prompt addition
|
||||||
|
systemPromptAddition := ""
|
||||||
|
if len(summaries) > 0 {
|
||||||
|
if maxDepth >= 2 || condensedCount >= 2 {
|
||||||
|
systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
|
||||||
|
"- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
|
||||||
|
"- When uncertain, use expand to recover original detail before making claims.\n" +
|
||||||
|
"- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
|
||||||
|
} else {
|
||||||
|
systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Summary field: all XML summaries + system prompt addition
|
||||||
|
var summaryParts []string
|
||||||
|
for _, sum := range summaries {
|
||||||
|
if sum.Content == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Load parent IDs for XML formatting
|
||||||
|
parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
|
||||||
|
if err != nil {
|
||||||
|
logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
|
||||||
|
"summary_id": sum.SummaryID,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
var parentIDs []string
|
||||||
|
for _, ps := range parentSummaries {
|
||||||
|
parentIDs = append(parentIDs, ps.SummaryID)
|
||||||
|
}
|
||||||
|
summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
|
||||||
|
}
|
||||||
|
summary := strings.Join(summaryParts, "\n\n")
|
||||||
|
if systemPromptAddition != "" {
|
||||||
|
if summary != "" {
|
||||||
|
summary += "\n\n"
|
||||||
|
}
|
||||||
|
summary += systemPromptAddition
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AssembleResult{
|
||||||
|
Messages: messages,
|
||||||
|
Summary: summary,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveItem loads the full message or summary for a context item.
|
||||||
|
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
|
||||||
|
if item.ItemType == "message" {
|
||||||
|
msg, err := a.store.GetMessageByID(ctx, item.MessageID)
|
||||||
|
if err != nil {
|
||||||
|
return resolvedItem{}, err
|
||||||
|
}
|
||||||
|
tokens := item.TokenCount
|
||||||
|
if tokens == 0 {
|
||||||
|
tokens = msg.TokenCount
|
||||||
|
}
|
||||||
|
return resolvedItem{
|
||||||
|
ordinal: item.Ordinal,
|
||||||
|
itemType: "message",
|
||||||
|
message: msg,
|
||||||
|
tokenCount: tokens,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.ItemType == "summary" {
|
||||||
|
sum, err := a.store.GetSummary(ctx, item.SummaryID)
|
||||||
|
if err != nil {
|
||||||
|
return resolvedItem{}, err
|
||||||
|
}
|
||||||
|
tokens := item.TokenCount
|
||||||
|
if tokens == 0 {
|
||||||
|
tokens = sum.TokenCount
|
||||||
|
}
|
||||||
|
return resolvedItem{
|
||||||
|
ordinal: item.Ordinal,
|
||||||
|
itemType: "summary",
|
||||||
|
summary: sum,
|
||||||
|
tokenCount: tokens,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolvedItem{
|
||||||
|
ordinal: item.Ordinal,
|
||||||
|
itemType: item.ItemType,
|
||||||
|
tokenCount: item.TokenCount,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatSummaryXML formats a summary as XML for LLM context.
|
||||||
|
// This is exported so context managers can format summaries consistently.
|
||||||
|
func FormatSummaryXML(s *Summary, parentIDs []string) string {
|
||||||
|
// Build time attributes if available
|
||||||
|
var attrs string
|
||||||
|
if s.EarliestAt != nil {
|
||||||
|
attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if s.LatestAt != nil {
|
||||||
|
attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
|
var parentsSection string
|
||||||
|
if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
|
||||||
|
parents := "<parents>\n"
|
||||||
|
for _, pid := range parentIDs {
|
||||||
|
parents += fmt.Sprintf(" <summary_ref id=\"%s\" />\n", pid)
|
||||||
|
}
|
||||||
|
parents += " </parents>\n"
|
||||||
|
parentsSection = parents
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"<summary id=\"%s\" kind=\"%s\" depth=\"%d\" descendant_count=\"%d\"%s>\n <content>\n %s\n </content>\n%s</summary>",
|
||||||
|
s.SummaryID,
|
||||||
|
string(s.Kind),
|
||||||
|
s.Depth,
|
||||||
|
s.DescendantCount,
|
||||||
|
attrs,
|
||||||
|
escapeXML(s.Content),
|
||||||
|
parentsSection,
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -0,0 +1,536 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Assembler Tests ---
|
||||||
|
|
||||||
|
// helper: create a store with messages and summaries for assembly tests
|
||||||
|
func setupAssemblerStore(t *testing.T) (*Store, int64) {
|
||||||
|
t.Helper()
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
conv, err := s.GetOrCreateConversation(ctx, "test:assemble")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create conversation: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, conv.ConversationID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerAssembleEmpty(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Messages) != 0 {
|
||||||
|
t.Errorf("Messages = %d, want 0", len(result.Messages))
|
||||||
|
}
|
||||||
|
if result.Summary != "" {
|
||||||
|
t.Errorf("Summary = %q, want empty", result.Summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerAssembleMessagesOnly(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create messages
|
||||||
|
msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
|
||||||
|
msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5)
|
||||||
|
|
||||||
|
// Create context items
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Messages) != 2 {
|
||||||
|
t.Fatalf("Messages = %d, want 2", len(result.Messages))
|
||||||
|
}
|
||||||
|
if result.Messages[0].Content != "hello" {
|
||||||
|
t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content)
|
||||||
|
}
|
||||||
|
if result.Messages[1].Content != "world" {
|
||||||
|
t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content)
|
||||||
|
}
|
||||||
|
// No summaries, so Summary should be empty
|
||||||
|
if result.Summary != "" {
|
||||||
|
t.Errorf("Summary = %q, want empty", result.Summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerAssembleWithSummary(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a summary
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "summary of early messages",
|
||||||
|
TokenCount: 50,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create recent messages
|
||||||
|
msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5)
|
||||||
|
msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5)
|
||||||
|
|
||||||
|
// Context: summary + recent messages
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
|
||||||
|
{Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Messages = 2 raw messages (summaries are in Summary field, not Messages)
|
||||||
|
if len(result.Messages) != 2 {
|
||||||
|
t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages))
|
||||||
|
}
|
||||||
|
// Summary should contain XML with summary content
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Error("Summary should not be empty when summary exists")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.Summary, summary.Content) {
|
||||||
|
t.Errorf("Summary should contain summary content %q", summary.Content)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result.Summary, "<summary") {
|
||||||
|
t.Error("Summary should contain <summary XML tag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerBudgetEvictsOldest(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 40 messages, each with 10 tokens = 400 total
|
||||||
|
msgs := make([]*Message, 40)
|
||||||
|
for i := 0; i < 40; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
|
||||||
|
msgs[i] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context items for all messages
|
||||||
|
items := make([]ContextItem, 40)
|
||||||
|
for i := 0; i < 40; i++ {
|
||||||
|
items[i] = ContextItem{
|
||||||
|
Ordinal: (i + 1) * 100,
|
||||||
|
ItemType: "message",
|
||||||
|
MessageID: msgs[i].ID,
|
||||||
|
TokenCount: 10,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.UpsertContextItems(ctx, convID, items)
|
||||||
|
|
||||||
|
// Budget of 200 tokens with FreshTailCount=32
|
||||||
|
// Fresh tail = last 32 messages (320 tokens, over budget, but always included)
|
||||||
|
// Evictable = first 8 messages (80 tokens)
|
||||||
|
// Budget after tail: max(0, 200-320) = 0 → no evictable items included
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should only include the 32-item fresh tail
|
||||||
|
if len(result.Messages) != 32 {
|
||||||
|
t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages))
|
||||||
|
}
|
||||||
|
// Should be the LAST 32 messages
|
||||||
|
if result.Messages[0].ID != msgs[8].ID {
|
||||||
|
t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerBudgetFitsAll(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msgs := make([]*Message, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
|
||||||
|
msgs[i] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
items := make([]ContextItem, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
items[i] = ContextItem{
|
||||||
|
Ordinal: (i + 1) * 100,
|
||||||
|
ItemType: "message",
|
||||||
|
MessageID: msgs[i].ID,
|
||||||
|
TokenCount: 10,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.UpsertContextItems(ctx, convID, items)
|
||||||
|
|
||||||
|
// Budget = 100, total = 50, FreshTailCount=32 → all items in tail
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Messages) != 5 {
|
||||||
|
t.Errorf("Messages = %d, want 5", len(result.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerSummaryXMLFormat(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "test summary content",
|
||||||
|
TokenCount: 20,
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
|
||||||
|
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Messages should only contain raw messages (no XML summary in Messages)
|
||||||
|
if len(result.Messages) != 1 {
|
||||||
|
t.Errorf("Messages = %d, want 1 (raw message only)", len(result.Messages))
|
||||||
|
}
|
||||||
|
// Summary should contain XML with summary content
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Fatal("Summary should not be empty")
|
||||||
|
}
|
||||||
|
if !contains(result.Summary, "<summary") {
|
||||||
|
t.Errorf("Summary missing <summary tag: %q", result.Summary)
|
||||||
|
}
|
||||||
|
if !contains(result.Summary, summary.SummaryID) {
|
||||||
|
t.Errorf("Summary missing summary ID: %q", result.Summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerSummaryXMLEscaping(t *testing.T) {
|
||||||
|
// Summary content with special XML characters should be properly escaped
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create summary with content containing XML special characters
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: `User said: "hello" & asked about <tags>`,
|
||||||
|
TokenCount: 20,
|
||||||
|
})
|
||||||
|
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary field should contain XML with escaped special characters
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Fatal("Summary should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that special characters are escaped
|
||||||
|
if strings.Contains(result.Summary, "<tags>") {
|
||||||
|
t.Errorf("BUG: unescaped < in summary content: %q", result.Summary)
|
||||||
|
}
|
||||||
|
if strings.Contains(result.Summary, `"hello"`) {
|
||||||
|
t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary)
|
||||||
|
}
|
||||||
|
// & should be escaped as &
|
||||||
|
if strings.Contains(result.Summary, " & ") {
|
||||||
|
t.Errorf("BUG: unescaped & in summary content: %q", result.Summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerSummaryXMLWithParents(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a leaf and a condensed summary (condensed has parent)
|
||||||
|
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf content",
|
||||||
|
TokenCount: 20,
|
||||||
|
})
|
||||||
|
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindCondensed,
|
||||||
|
Depth: 1,
|
||||||
|
Content: "condensed content",
|
||||||
|
TokenCount: 15,
|
||||||
|
ParentIDs: []string{leaf.SummaryID},
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary field should contain XML with parent information
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Fatal("Summary should not be empty")
|
||||||
|
}
|
||||||
|
xmlContent := result.Summary
|
||||||
|
|
||||||
|
// Should contain <parents> section with parent ID
|
||||||
|
if !contains(xmlContent, "<parents>") {
|
||||||
|
t.Errorf("condensed summary XML missing <parents> section: %q", xmlContent)
|
||||||
|
}
|
||||||
|
if !contains(xmlContent, leaf.SummaryID) {
|
||||||
|
t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should contain kind="condensed"
|
||||||
|
if !contains(xmlContent, `kind="condensed"`) {
|
||||||
|
t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a leaf summary with specific descendant count
|
||||||
|
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf content",
|
||||||
|
TokenCount: 20,
|
||||||
|
DescendantCount: 8,
|
||||||
|
DescendantTokenCount: 1200,
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Fatal("Summary should not be empty")
|
||||||
|
}
|
||||||
|
xmlContent := result.Summary
|
||||||
|
|
||||||
|
// Should contain descendant_count="8"
|
||||||
|
if !contains(xmlContent, `descendant_count="8"`) {
|
||||||
|
t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerLeafSummaryNoParents(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Leaf summary has no parents
|
||||||
|
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf content",
|
||||||
|
TokenCount: 20,
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Fatal("Summary should not be empty")
|
||||||
|
}
|
||||||
|
xmlContent := result.Summary
|
||||||
|
|
||||||
|
// Leaf summary should NOT have <parents> section
|
||||||
|
if contains(xmlContent, "<parents>") {
|
||||||
|
t.Errorf("leaf summary XML should not have <parents> section: %q", xmlContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAssemblerDepthAwarePrompt(t *testing.T) {
|
||||||
|
s, convID := setupAssemblerStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a condensed summary (depth >= 2) to trigger full guidance
|
||||||
|
now := time.Now().UTC()
|
||||||
|
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf summary",
|
||||||
|
TokenCount: 20,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindCondensed,
|
||||||
|
Depth: 2,
|
||||||
|
Content: "condensed summary",
|
||||||
|
TokenCount: 15,
|
||||||
|
ParentIDs: []string{leaf.SummaryID},
|
||||||
|
DescendantCount: 1,
|
||||||
|
DescendantTokenCount: 20,
|
||||||
|
})
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
|
||||||
|
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||||
|
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
|
||||||
|
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||||
|
})
|
||||||
|
|
||||||
|
a := &Assembler{store: s, config: Config{}}
|
||||||
|
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Assemble: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have a depth-aware prompt in Summary field
|
||||||
|
if result.Summary == "" {
|
||||||
|
t.Error("expected non-empty Summary when depth >= 2")
|
||||||
|
}
|
||||||
|
// SystemPromptAddition is embedded in Summary field
|
||||||
|
if !strings.Contains(result.Summary, "multi-level summarization") {
|
||||||
|
t.Error("Summary should contain system prompt addition about multi-level summarization")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) {
|
||||||
|
// Spec: condensed summaries use <summary_ref id="parentId" /> not <parent>parentId</parent>
|
||||||
|
now := time.Now().UTC()
|
||||||
|
s := Summary{
|
||||||
|
SummaryID: "sum_condensed1",
|
||||||
|
Kind: SummaryKindCondensed,
|
||||||
|
Depth: 1,
|
||||||
|
Content: "condensed content",
|
||||||
|
TokenCount: 50,
|
||||||
|
DescendantCount: 2,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
}
|
||||||
|
parentIDs := []string{"sum_leaf1", "sum_leaf2"}
|
||||||
|
|
||||||
|
xml := FormatSummaryXML(&s, parentIDs)
|
||||||
|
|
||||||
|
// Must use <summary_ref id="..." /> per spec
|
||||||
|
if !contains(xml, `<summary_ref id="sum_leaf1" />`) {
|
||||||
|
t.Errorf("expected <summary_ref id=\"sum_leaf1\" />, got: %s", xml)
|
||||||
|
}
|
||||||
|
if !contains(xml, `<summary_ref id="sum_leaf2" />`) {
|
||||||
|
t.Errorf("expected <summary_ref id=\"sum_leaf2\" />, got: %s", xml)
|
||||||
|
}
|
||||||
|
// Must NOT use old <parent> tag
|
||||||
|
if contains(xml, "<parent>") {
|
||||||
|
t.Errorf("should not use <parent> tag, got: %s", xml)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) {
|
||||||
|
// Spec: summary XML includes earliest_at and latest_at attributes
|
||||||
|
earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC)
|
||||||
|
latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC)
|
||||||
|
s := Summary{
|
||||||
|
SummaryID: "sum_leaf1",
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf content",
|
||||||
|
TokenCount: 30,
|
||||||
|
DescendantCount: 0,
|
||||||
|
EarliestAt: &earliest,
|
||||||
|
LatestAt: &latest,
|
||||||
|
}
|
||||||
|
|
||||||
|
xml := FormatSummaryXML(&s, nil)
|
||||||
|
|
||||||
|
if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) {
|
||||||
|
t.Errorf("missing earliest_at attribute, got: %s", xml)
|
||||||
|
}
|
||||||
|
if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) {
|
||||||
|
t.Errorf("missing latest_at attribute, got: %s", xml)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) {
|
||||||
|
// When EarliestAt/LatestAt are nil, attributes should be omitted
|
||||||
|
s := Summary{
|
||||||
|
SummaryID: "sum_leaf1",
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf content",
|
||||||
|
TokenCount: 30,
|
||||||
|
DescendantCount: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
xml := FormatSummaryXML(&s, nil)
|
||||||
|
|
||||||
|
if contains(xml, "earliest_at=") {
|
||||||
|
t.Errorf("should not have earliest_at when nil, got: %s", xml)
|
||||||
|
}
|
||||||
|
if contains(xml, "latest_at=") {
|
||||||
|
t.Errorf("should not have latest_at when nil, got: %s", xml)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,336 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newBenchStore creates a test store for benchmarks.
|
||||||
|
func newBenchStore(b *testing.B) (*Store, func()) {
|
||||||
|
b.Helper()
|
||||||
|
db, err := sql.Open("sqlite", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("open test db: %v", err)
|
||||||
|
}
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
db.Close()
|
||||||
|
b.Fatalf("migration: %v", err)
|
||||||
|
}
|
||||||
|
return &Store{db: db}, func() { db.Close() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Ingest benchmarks ---
|
||||||
|
|
||||||
|
func BenchmarkIngest_SingleMessage(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := s.AddMessage(ctx, convID, "user", "Test message content", 15)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIngest_BatchMessages(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i))
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
added, err := s.AddMessage(ctx, convID, "user",
|
||||||
|
fmt.Sprintf("Message %d in batch", j), 10)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
s.AppendContextMessage(ctx, convID, added.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Assemble benchmarks ---
|
||||||
|
|
||||||
|
func BenchmarkAssemble_MessagesOnly(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
// Add 100 messages
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user",
|
||||||
|
fmt.Sprintf("Message content %d with some text", i), 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &Assembler{store: s}
|
||||||
|
input := AssembleInput{Budget: 50000}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := a.Assemble(ctx, convID, input)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkAssemble_WithSummaries(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// Add 10 leaf summaries
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("Leaf summary %d", i),
|
||||||
|
TokenCount: 500,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 20 fresh messages
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &Assembler{store: s}
|
||||||
|
input := AssembleInput{Budget: 10000}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := a.Assemble(ctx, convID, input)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkAssemble_BudgetEviction(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// Add 50 leaf summaries (more than budget can hold)
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("Summary %d", i),
|
||||||
|
TokenCount: 300,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add fresh tail
|
||||||
|
for i := 0; i < FreshTailCount; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &Assembler{store: s}
|
||||||
|
input := AssembleInput{Budget: 5000} // Force eviction
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := a.Assemble(ctx, convID, input)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Search (FTS5) benchmarks ---
|
||||||
|
|
||||||
|
// benchSeedSummaries adds n summaries to a conversation for search benchmarks.
|
||||||
|
func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) {
|
||||||
|
b.Helper()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf(contentTpl, i),
|
||||||
|
TokenCount: 200,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("create summary: %v", err)
|
||||||
|
}
|
||||||
|
s.AppendContextSummary(context.Background(), convID, sum.SummaryID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSearchSummaries_FTS5(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := s.SearchSummaries(ctx, SearchInput{
|
||||||
|
Pattern: "database",
|
||||||
|
Mode: "full_text",
|
||||||
|
ConversationID: convID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSearchSummaries_Like(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := s.SearchSummaries(ctx, SearchInput{
|
||||||
|
Pattern: "config",
|
||||||
|
Mode: "like",
|
||||||
|
ConversationID: convID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSearchMessages_FTS5(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
// Add 500 messages
|
||||||
|
for i := 0; i < 500; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user",
|
||||||
|
fmt.Sprintf("User message about API and database integration %d", i), 20)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := s.SearchMessages(ctx, SearchInput{
|
||||||
|
Pattern: "API database",
|
||||||
|
Mode: "full_text",
|
||||||
|
ConversationID: convID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Bootstrap benchmarks ---
|
||||||
|
|
||||||
|
func BenchmarkBootstrap_Empty(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i))
|
||||||
|
convID := conv.ConversationID
|
||||||
|
_ = convID // Bootstrap with empty history
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBootstrap_100Messages(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Prepare 100 messages
|
||||||
|
msgs := make([]Message, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
msgs[i] = Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: fmt.Sprintf("Bootstrap message %d", i),
|
||||||
|
TokenCount: 15,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i))
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
for _, m := range msgs {
|
||||||
|
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
|
||||||
|
s.AppendContextMessage(ctx, convID, added.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBootstrap_500Messages(b *testing.B) {
|
||||||
|
s, cleanup := newBenchStore(b)
|
||||||
|
defer cleanup()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msgs := make([]Message, 500)
|
||||||
|
for i := 0; i < 500; i++ {
|
||||||
|
msgs[i] = Message{
|
||||||
|
Role: "user",
|
||||||
|
Content: fmt.Sprintf("Bootstrap message %d", i),
|
||||||
|
TokenCount: 15,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i))
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
for _, m := range msgs {
|
||||||
|
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
|
||||||
|
s.AppendContextMessage(ctx, convID, added.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,898 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompactInput controls compaction behavior.
|
||||||
|
type CompactInput struct {
|
||||||
|
Budget *int // Token budget override
|
||||||
|
Force bool // Force compaction even if below threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompactResult describes what was compacted.
|
||||||
|
type CompactResult struct {
|
||||||
|
SummariesCreated []string `json:"summariesCreated"`
|
||||||
|
TokensSaved int `json:"tokensSaved"`
|
||||||
|
LeafSummaries int `json:"leafSummaries"`
|
||||||
|
CondensedSummaries int `json:"condensedSummaries"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow.
|
||||||
|
func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) {
|
||||||
|
tokens, err := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("get token count: %w", err)
|
||||||
|
}
|
||||||
|
threshold := int(float64(contextWindow) * ContextThreshold)
|
||||||
|
return tokens >= threshold, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cancels the shutdown context, stopping async goroutines.
|
||||||
|
func (e *CompactionEngine) Close() {
|
||||||
|
if e.shutdownCancel != nil {
|
||||||
|
e.shutdownCancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact runs leaf compaction (sync) and optionally condensed compaction.
|
||||||
|
func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) {
|
||||||
|
result := &CompactResult{}
|
||||||
|
|
||||||
|
// Phase 1: leaf compaction (synchronous, every turn)
|
||||||
|
summaryID, err := e.compactLeaf(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("compact leaf: %w", err)
|
||||||
|
}
|
||||||
|
if summaryID != nil {
|
||||||
|
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
|
||||||
|
result.LeafSummaries++
|
||||||
|
logger.InfoCF("seahorse", "compact: leaf", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
"summary_id": *summaryID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: condensed compaction if over threshold
|
||||||
|
tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
var budget int
|
||||||
|
if input.Budget != nil {
|
||||||
|
budget = *input.Budget
|
||||||
|
if budget == 0 {
|
||||||
|
logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
budget = int(float64(tokensBefore) * ContextThreshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.Force || (tokensBefore > budget && budget > 0) {
|
||||||
|
// Launch async condensed compaction with dedup
|
||||||
|
if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded {
|
||||||
|
go func() {
|
||||||
|
defer e.condensing.Delete(convID)
|
||||||
|
e.runCondensedLoop(e.shutdownCtx, convID)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
if tokensAfter < tokensBefore {
|
||||||
|
result.TokensSaved = tokensBefore - tokensAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompactUntilUnder aggressively compacts until context is under budget.
|
||||||
|
func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) {
|
||||||
|
result := &CompactResult{}
|
||||||
|
prevTokens := 0
|
||||||
|
logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget})
|
||||||
|
|
||||||
|
for iter := 0; iter < MaxCompactIterations; iter++ {
|
||||||
|
tokens, err := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("get tokens: %w", err)
|
||||||
|
}
|
||||||
|
if tokens <= budget {
|
||||||
|
logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
"budget": budget,
|
||||||
|
"tokens": tokens,
|
||||||
|
"leaf": result.LeafSummaries,
|
||||||
|
"condensed": result.CondensedSummaries,
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try leaf first
|
||||||
|
summaryID, err := e.compactLeaf(ctx, convID, true)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
if summaryID != nil {
|
||||||
|
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
|
||||||
|
result.LeafSummaries++
|
||||||
|
logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
"summary_id": *summaryID,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try condensed with forced fanout
|
||||||
|
condensedID, err := e.compactCondensed(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
if condensedID != nil {
|
||||||
|
result.SummariesCreated = append(result.SummariesCreated, *condensedID)
|
||||||
|
result.CondensedSummaries++
|
||||||
|
logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
"summary_id": *condensedID,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// No progress
|
||||||
|
newTokens, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
if newTokens >= prevTokens {
|
||||||
|
logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
"tokens": newTokens,
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
prevTokens = newTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// Safety cap exceeded — see MaxCompactIterations doc for rationale.
|
||||||
|
logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{
|
||||||
|
"conv_id": convID,
|
||||||
|
"budget": budget,
|
||||||
|
"iterations": MaxCompactIterations,
|
||||||
|
"tokens": prevTokens,
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compactLeaf compresses the oldest contiguous message chunk into a leaf summary.
|
||||||
|
// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder).
|
||||||
|
func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) {
|
||||||
|
items, err := e.store.GetContextItems(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find oldest contiguous message chunk outside fresh tail
|
||||||
|
msgCount := 0
|
||||||
|
msgTokens := 0
|
||||||
|
for _, item := range items {
|
||||||
|
if item.ItemType == "message" {
|
||||||
|
msgCount++
|
||||||
|
msgTokens += item.TokenCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger if either message count or token threshold is met
|
||||||
|
if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate fresh tail boundary (bypass when forced)
|
||||||
|
useForce := len(force) > 0 && force[0]
|
||||||
|
tailStartIdx := len(items) - FreshTailCount
|
||||||
|
if useForce {
|
||||||
|
tailStartIdx = len(items) // allow compacting everything
|
||||||
|
}
|
||||||
|
if tailStartIdx < 0 {
|
||||||
|
tailStartIdx = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find oldest contiguous message chunk, accumulating up to LeafChunkTokens
|
||||||
|
var chunk []ContextItem
|
||||||
|
chunkStart := -1
|
||||||
|
chunkEnd := -1
|
||||||
|
accumTokens := 0
|
||||||
|
for i := 0; i < tailStartIdx; i++ {
|
||||||
|
if items[i].ItemType == "message" {
|
||||||
|
if chunkStart == -1 {
|
||||||
|
chunkStart = i
|
||||||
|
}
|
||||||
|
chunkEnd = i
|
||||||
|
accumTokens += items[i].TokenCount
|
||||||
|
// Stop accumulating once we reach the token budget
|
||||||
|
if accumTokens >= LeafChunkTokens {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-message breaks the chunk
|
||||||
|
if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
chunkStart = -1
|
||||||
|
chunkEnd = -1
|
||||||
|
accumTokens = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk = items[chunkStart : chunkEnd+1]
|
||||||
|
|
||||||
|
// Collect messages for the chunk
|
||||||
|
var messages []Message
|
||||||
|
for _, item := range chunk {
|
||||||
|
msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID)
|
||||||
|
if innerErr != nil {
|
||||||
|
return nil, innerErr
|
||||||
|
}
|
||||||
|
messages = append(messages, *msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get prior summaries for context
|
||||||
|
priorSummary := ""
|
||||||
|
priorCount := 0
|
||||||
|
for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- {
|
||||||
|
if items[i].ItemType == "summary" {
|
||||||
|
sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID)
|
||||||
|
if innerErr2 == nil {
|
||||||
|
priorSummary = sum.Content + "\n" + priorSummary
|
||||||
|
priorCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate summary
|
||||||
|
content, err := e.generateLeafSummary(ctx, messages, priorSummary)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create summary in store
|
||||||
|
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
|
||||||
|
|
||||||
|
var earliestAt, latestAt *time.Time
|
||||||
|
if len(messages) > 0 {
|
||||||
|
earliestAt = &messages[0].CreatedAt
|
||||||
|
latestAt = &messages[len(messages)-1].CreatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: content,
|
||||||
|
TokenCount: tokenCount,
|
||||||
|
EarliestAt: earliestAt,
|
||||||
|
LatestAt: latestAt,
|
||||||
|
SourceMessageTokens: sumMessageTokens(messages),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Link to source messages
|
||||||
|
msgIDs := make([]int64, len(messages))
|
||||||
|
for i, m := range messages {
|
||||||
|
msgIDs[i] = m.ID
|
||||||
|
}
|
||||||
|
if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace context range with summary
|
||||||
|
if err := e.store.ReplaceContextRangeWithSummary(
|
||||||
|
ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &summary.SummaryID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compactCondensed compresses multiple summaries into one higher-level summary.
|
||||||
|
func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) {
|
||||||
|
// Try ordinal-aware selection first (respects consecutive ordering)
|
||||||
|
var candidates []Summary
|
||||||
|
|
||||||
|
depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, depth := range depths {
|
||||||
|
var chunkAtDepth []Summary
|
||||||
|
var err2 error
|
||||||
|
chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth)
|
||||||
|
if err2 != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(chunkAtDepth) > 0 {
|
||||||
|
candidates = chunkAtDepth
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to depth-grouping selection
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate condensed summary
|
||||||
|
content, err := e.generateCondensedSummary(ctx, candidates)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge metadata
|
||||||
|
maxDepth := 0
|
||||||
|
descendantCount := 0
|
||||||
|
descendantTokenCount := 0
|
||||||
|
sourceMessageTokens := 0
|
||||||
|
var earliestAt, latestAt *time.Time
|
||||||
|
|
||||||
|
parentIDs := make([]string, len(candidates))
|
||||||
|
for i, c := range candidates {
|
||||||
|
parentIDs[i] = c.SummaryID
|
||||||
|
if c.Depth > maxDepth {
|
||||||
|
maxDepth = c.Depth
|
||||||
|
}
|
||||||
|
descendantCount += c.DescendantCount + 1
|
||||||
|
descendantTokenCount += c.TokenCount + c.DescendantTokenCount
|
||||||
|
sourceMessageTokens += c.SourceMessageTokenCount
|
||||||
|
if c.EarliestAt != nil {
|
||||||
|
if earliestAt == nil || c.EarliestAt.Before(*earliestAt) {
|
||||||
|
earliestAt = c.EarliestAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.LatestAt != nil {
|
||||||
|
if latestAt == nil || c.LatestAt.After(*latestAt) {
|
||||||
|
latestAt = c.LatestAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
|
||||||
|
|
||||||
|
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindCondensed,
|
||||||
|
Depth: maxDepth + 1,
|
||||||
|
Content: content,
|
||||||
|
TokenCount: tokenCount,
|
||||||
|
EarliestAt: earliestAt,
|
||||||
|
LatestAt: latestAt,
|
||||||
|
DescendantCount: descendantCount,
|
||||||
|
DescendantTokenCount: descendantTokenCount,
|
||||||
|
SourceMessageTokens: sourceMessageTokens,
|
||||||
|
ParentIDs: parentIDs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the ordinal range for the candidate summaries in context
|
||||||
|
items, err := e.store.GetContextItems(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
candidateSet := make(map[string]bool)
|
||||||
|
for _, c := range candidates {
|
||||||
|
candidateSet[c.SummaryID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
startOrd := -1
|
||||||
|
endOrd := -1
|
||||||
|
hasNonCandidate := false
|
||||||
|
for _, item := range items {
|
||||||
|
if item.ItemType == "summary" && candidateSet[item.SummaryID] {
|
||||||
|
if startOrd == -1 {
|
||||||
|
startOrd, endOrd = item.Ordinal, item.Ordinal
|
||||||
|
} else {
|
||||||
|
// Check for non-candidate items between endOrd and current ordinal
|
||||||
|
for _, it := range items {
|
||||||
|
if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal {
|
||||||
|
if it.ItemType != "summary" || !candidateSet[it.SummaryID] {
|
||||||
|
hasNonCandidate = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasNonCandidate {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if item.Ordinal < startOrd {
|
||||||
|
startOrd = item.Ordinal
|
||||||
|
}
|
||||||
|
if item.Ordinal > endOrd {
|
||||||
|
endOrd = item.Ordinal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if startOrd == -1 || endOrd == -1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect candidate summary IDs
|
||||||
|
candidateIDs := make([]string, 0, len(candidates))
|
||||||
|
for _, c := range candidates {
|
||||||
|
candidateIDs = append(candidateIDs, c.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasNonCandidate {
|
||||||
|
// Use safe per-item deletion to avoid deleting non-candidate items
|
||||||
|
if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Candidates are consecutive, use efficient range deletion
|
||||||
|
if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &summary.SummaryID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectShallowestCondensationCandidate finds the shallowest consecutive summary group.
|
||||||
|
func (e *CompactionEngine) selectShallowestCondensationCandidate(
|
||||||
|
ctx context.Context, convID int64, forced bool,
|
||||||
|
) ([]Summary, error) {
|
||||||
|
items, err := e.store.GetContextItems(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group by depth, find consecutive runs
|
||||||
|
tailStartIdx := len(items) - FreshTailCount
|
||||||
|
if tailStartIdx < 0 {
|
||||||
|
tailStartIdx = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
minFanout := CondensedMinFanout
|
||||||
|
if forced {
|
||||||
|
minFanout = CondensedMinFanoutHard
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track depth groups
|
||||||
|
depthGroups := make(map[int][]ContextItem)
|
||||||
|
for i := 0; i < tailStartIdx; i++ {
|
||||||
|
item := items[i]
|
||||||
|
if item.ItemType != "summary" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find shallowest depth with enough candidates
|
||||||
|
// Collect all depths and sort to handle non-consecutive depths
|
||||||
|
var depths []int
|
||||||
|
for depth := range depthGroups {
|
||||||
|
depths = append(depths, depth)
|
||||||
|
}
|
||||||
|
sort.Ints(depths)
|
||||||
|
|
||||||
|
for _, depth := range depths {
|
||||||
|
group := depthGroups[depth]
|
||||||
|
if len(group) >= minFanout {
|
||||||
|
// Load summaries
|
||||||
|
var result []Summary
|
||||||
|
for _, item := range group[:minFanout] {
|
||||||
|
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, *sum)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive
|
||||||
|
// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or
|
||||||
|
// token overflow. Returns contiguous chunk of summaries.
|
||||||
|
func (e *CompactionEngine) selectOldestChunkAtDepth(
|
||||||
|
ctx context.Context, convID int64, targetDepth int,
|
||||||
|
) ([]Summary, error) {
|
||||||
|
items, err := e.store.GetContextItems(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tailStartIdx := len(items) - FreshTailCount
|
||||||
|
if tailStartIdx < 0 {
|
||||||
|
tailStartIdx = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk []Summary
|
||||||
|
accumTokens := 0
|
||||||
|
|
||||||
|
for i := 0; i < tailStartIdx; i++ {
|
||||||
|
item := items[i]
|
||||||
|
if item.ItemType != "summary" {
|
||||||
|
// Non-summary breaks the chunk
|
||||||
|
break
|
||||||
|
}
|
||||||
|
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if sum.Depth != targetDepth {
|
||||||
|
// Different depth breaks the chunk
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if accumTokens+sum.TokenCount > LeafChunkTokens {
|
||||||
|
// Token overflow stops collection
|
||||||
|
break
|
||||||
|
}
|
||||||
|
chunk = append(chunk, *sum)
|
||||||
|
accumTokens += sum.TokenCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// Min tokens check: spec line 808
|
||||||
|
// chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000
|
||||||
|
minTokens := CondensedTargetTokens // 2000
|
||||||
|
if accumTokens < minTokens {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return chunk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation.
|
||||||
|
// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation.
|
||||||
|
func (e *CompactionEngine) generateLeafSummary(
|
||||||
|
ctx context.Context,
|
||||||
|
messages []Message,
|
||||||
|
previousSummary string,
|
||||||
|
) (string, error) {
|
||||||
|
if e.complete == nil {
|
||||||
|
return truncateSummary(messages), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceText := formatMessagesForSummary(messages)
|
||||||
|
inputTokens := sumMessageTokens(messages)
|
||||||
|
targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35))
|
||||||
|
|
||||||
|
// Level 1: normal prompt
|
||||||
|
prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens)
|
||||||
|
content, err := e.complete(ctx, prompt, CompleteOptions{
|
||||||
|
MaxTokens: LeafTargetTokens * 2,
|
||||||
|
Temperature: 0.3,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
// Retry with temperature=0
|
||||||
|
content, err = e.complete(ctx, prompt, CompleteOptions{
|
||||||
|
MaxTokens: LeafTargetTokens * 2,
|
||||||
|
Temperature: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if level 1 succeeded
|
||||||
|
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Level 2: aggressive prompt
|
||||||
|
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
|
||||||
|
aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget)
|
||||||
|
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||||
|
MaxTokens: aggressiveTarget * 2,
|
||||||
|
Temperature: 0.3,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
// Retry with temperature=0
|
||||||
|
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||||
|
MaxTokens: aggressiveTarget * 2,
|
||||||
|
Temperature: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Level 3: deterministic truncation
|
||||||
|
return truncateSummary(messages), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation.
|
||||||
|
func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) {
|
||||||
|
if e.complete == nil {
|
||||||
|
return truncateCondensedSummaries(summaries), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sourceText := formatSummariesForCondensation(summaries)
|
||||||
|
inputTokens := sumSummaryTokens(summaries)
|
||||||
|
targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35))
|
||||||
|
|
||||||
|
// Level 1: normal prompt
|
||||||
|
prompt := buildCondensedSummaryPrompt(sourceText, targetTokens)
|
||||||
|
content, err := e.complete(ctx, prompt, CompleteOptions{
|
||||||
|
MaxTokens: CondensedTargetTokens * 2,
|
||||||
|
Temperature: 0.3,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
content, err = e.complete(ctx, prompt, CompleteOptions{
|
||||||
|
MaxTokens: CondensedTargetTokens * 2,
|
||||||
|
Temperature: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Level 2: aggressive prompt
|
||||||
|
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
|
||||||
|
aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget)
|
||||||
|
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||||
|
MaxTokens: aggressiveTarget * 2,
|
||||||
|
Temperature: 0.3,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
return content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Level 3: deterministic fallback
|
||||||
|
return truncateCondensedSummaries(summaries), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runCondensedLoop runs condensed compaction in a loop until:
|
||||||
|
// a) context tokens <= threshold (success), OR
|
||||||
|
// b) No candidate found (nothing to condense), OR
|
||||||
|
// c) tokensAfter >= tokensBefore (no progress this iteration), OR
|
||||||
|
// d) tokensAfter >= previousTokens (no improvement over last iteration)
|
||||||
|
func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) {
|
||||||
|
var prevTokens int
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
tokensBefore, err := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
condensedID, err := e.compactCondensed(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if condensedID == nil {
|
||||||
|
// No candidate found
|
||||||
|
logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||||
|
|
||||||
|
if tokensAfter >= tokensBefore {
|
||||||
|
// No progress this iteration
|
||||||
|
logger.DebugCF(
|
||||||
|
"seahorse",
|
||||||
|
"condensed: no progress",
|
||||||
|
map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tokensAfter >= prevTokens && prevTokens > 0 {
|
||||||
|
// No improvement over last iteration
|
||||||
|
logger.DebugCF(
|
||||||
|
"seahorse",
|
||||||
|
"condensed: no improvement",
|
||||||
|
map[string]any{"conv_id": convID, "tokens": tokensAfter},
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prevTokens = tokensAfter
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper functions ---
|
||||||
|
|
||||||
|
func formatMessagesForSummary(messages []Message) string {
|
||||||
|
var result string
|
||||||
|
for _, m := range messages {
|
||||||
|
ts := m.CreatedAt.Format("2006-01-02 15:04 MST")
|
||||||
|
content := m.Content
|
||||||
|
if content == "" && len(m.Parts) > 0 {
|
||||||
|
content = partsToReadableContent(m.Parts)
|
||||||
|
}
|
||||||
|
result += fmt.Sprintf("[%s]\n%s\n\n", ts, content)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatSummariesForCondensation(summaries []Summary) string {
|
||||||
|
var result string
|
||||||
|
for _, s := range summaries {
|
||||||
|
earliest := ""
|
||||||
|
if s.EarliestAt != nil {
|
||||||
|
earliest = s.EarliestAt.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
latest := ""
|
||||||
|
if s.LatestAt != nil {
|
||||||
|
latest = s.LatestAt.Format("2006-01-02")
|
||||||
|
}
|
||||||
|
result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
|
||||||
|
prev := "(none)"
|
||||||
|
if previousSummary != "" {
|
||||||
|
prev = previousSummary
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
|
||||||
|
Treat this as incremental memory compaction input, not a full-conversation summary.
|
||||||
|
|
||||||
|
Normal summary policy:
|
||||||
|
- Preserve key decisions, rationale, constraints, and active tasks.
|
||||||
|
- Keep essential technical details needed to continue work safely.
|
||||||
|
- Remove obvious repetition and conversational filler.
|
||||||
|
|
||||||
|
Output requirements:
|
||||||
|
- Plain text only.
|
||||||
|
- No preamble, headings, or markdown formatting.
|
||||||
|
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
|
||||||
|
- If no file operations appear, include exactly: "Files: none".
|
||||||
|
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
|
||||||
|
- Target length: about %d tokens or less.
|
||||||
|
|
||||||
|
<previous_context>
|
||||||
|
%s
|
||||||
|
</previous_context>
|
||||||
|
|
||||||
|
<conversation_segment>
|
||||||
|
%s
|
||||||
|
</conversation_segment>`, targetTokens, prev, sourceText)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string {
|
||||||
|
return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary.
|
||||||
|
Preserve all important decisions, constraints, and outcomes.
|
||||||
|
Merge overlapping topics. Keep technical details intact.
|
||||||
|
|
||||||
|
Output requirements:
|
||||||
|
- Plain text only.
|
||||||
|
- No preamble, headings, or markdown formatting.
|
||||||
|
- End with exactly: "Expand for details about: <comma-separated list>".
|
||||||
|
- Target length: about %d tokens or less.
|
||||||
|
|
||||||
|
<summaries>
|
||||||
|
%s
|
||||||
|
</summaries>`, targetTokens, sourceText)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
|
||||||
|
prev := "(none)"
|
||||||
|
if previousSummary != "" {
|
||||||
|
prev = previousSummary
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
|
||||||
|
Aggressive summary policy:
|
||||||
|
- Keep only durable facts and current task state.
|
||||||
|
- Remove examples, repetition, and low-value narrative details.
|
||||||
|
- Preserve explicit TODOs, blockers, decisions, and constraints.
|
||||||
|
|
||||||
|
Output requirements:
|
||||||
|
- Plain text only.
|
||||||
|
- No preamble, headings, or markdown formatting.
|
||||||
|
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
|
||||||
|
- If no file operations appear, include exactly: "Files: none".
|
||||||
|
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
|
||||||
|
- Target length: about %d tokens or less.
|
||||||
|
|
||||||
|
<previous_context>
|
||||||
|
%s
|
||||||
|
</previous_context>
|
||||||
|
|
||||||
|
<conversation_segment>
|
||||||
|
%s
|
||||||
|
</conversation_segment>`, targetTokens, prev, sourceText)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateSummary(messages []Message) string {
|
||||||
|
content := ""
|
||||||
|
for _, m := range messages {
|
||||||
|
c := m.Content
|
||||||
|
if c == "" && len(m.Parts) > 0 {
|
||||||
|
c = partsToReadableContent(m.Parts)
|
||||||
|
}
|
||||||
|
content += c + "\n"
|
||||||
|
}
|
||||||
|
if len(content) > 2048 {
|
||||||
|
content = content[:2048]
|
||||||
|
}
|
||||||
|
content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages))
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateCondensedSummaries(summaries []Summary) string {
|
||||||
|
content := ""
|
||||||
|
for _, s := range summaries {
|
||||||
|
content += s.Content + "\n"
|
||||||
|
}
|
||||||
|
if len(content) > 2048 {
|
||||||
|
content = content[:2048]
|
||||||
|
}
|
||||||
|
content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries))
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
|
||||||
|
func sumMessageTokens(messages []Message) int {
|
||||||
|
total := 0
|
||||||
|
for _, m := range messages {
|
||||||
|
total += m.TokenCount
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func sumSummaryTokens(summaries []Summary) int {
|
||||||
|
total := 0
|
||||||
|
for _, s := range summaries {
|
||||||
|
total += s.TokenCount
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func minInt(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
@@ -0,0 +1,974 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Test Helpers ---
|
||||||
|
|
||||||
|
// waitForCondensed blocks until the async condensed goroutine for convID finishes.
|
||||||
|
// Returns false if timeout is reached.
|
||||||
|
func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if _, exists := ce.condensing.Load(convID); !exists {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Compaction Tests ---
|
||||||
|
|
||||||
|
func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) {
|
||||||
|
t.Helper()
|
||||||
|
db := openTestDB(t)
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
t.Fatalf("migration: %v", err)
|
||||||
|
}
|
||||||
|
s := &Store{db: db}
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:compact")
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
ce := &CompactionEngine{
|
||||||
|
store: s,
|
||||||
|
config: Config{},
|
||||||
|
complete: mockCompleteFn,
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: shutdownCancel,
|
||||||
|
}
|
||||||
|
convID := conv.ConversationID
|
||||||
|
// Ensure async goroutines are stopped before database is closed.
|
||||||
|
// Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close().
|
||||||
|
t.Cleanup(func() {
|
||||||
|
shutdownCancel()
|
||||||
|
// Wait for async condensed goroutine to finish (poll condensing map)
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if _, exists := ce.condensing.Load(convID); !exists {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return ce, s, conv.ConversationID
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestCompactionEngineWithStore creates a CompactionEngine with existing store.
|
||||||
|
// Note: Caller is responsible for calling shutdownCancel when test ends.
|
||||||
|
func newTestCompactionEngineWithStore(
|
||||||
|
s *Store, complete CompleteFn,
|
||||||
|
) (ce *CompactionEngine, shutdownCancel context.CancelFunc) {
|
||||||
|
shutdownCtx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &CompactionEngine{
|
||||||
|
store: s,
|
||||||
|
config: Config{},
|
||||||
|
complete: complete,
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: cancel,
|
||||||
|
}, cancel
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockCompleteFn returns a simple summary for testing
|
||||||
|
var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
return "Mock summary of the conversation segment.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedsCompaction(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Empty context — no compaction needed
|
||||||
|
needed, err := ce.NeedsCompaction(ctx, convID, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NeedsCompaction: %v", err)
|
||||||
|
}
|
||||||
|
if needed {
|
||||||
|
t.Error("expected no compaction for empty context")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add messages to context, total tokens = 8000
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction
|
||||||
|
needed, err = ce.NeedsCompaction(ctx, convID, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NeedsCompaction: %v", err)
|
||||||
|
}
|
||||||
|
if !needed {
|
||||||
|
t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Below threshold: 5000 / 10000 → no compaction
|
||||||
|
s.UpsertContextItems(ctx, convID, nil) // clear
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "test", 1000)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
needed, _ = ce.NeedsCompaction(ctx, convID, 10000)
|
||||||
|
if needed {
|
||||||
|
t.Error("expected no compaction at 5000/10000 tokens")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactLeaf(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create enough messages to trigger leaf compaction:
|
||||||
|
// Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous
|
||||||
|
for i := 0; i < 40; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact
|
||||||
|
result, err := ce.Compact(ctx, convID, CompactInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Compact: %v", err)
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have created at least one leaf summary
|
||||||
|
if result.LeafSummaries == 0 {
|
||||||
|
t.Error("expected at least 1 leaf summary")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context should now contain a summary item
|
||||||
|
items, _ := s.GetContextItems(ctx, convID)
|
||||||
|
foundSummary := false
|
||||||
|
for _, item := range items {
|
||||||
|
if item.ItemType == "summary" {
|
||||||
|
foundSummary = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundSummary {
|
||||||
|
t.Error("expected a summary in context_items after leaf compaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Some messages should have been replaced
|
||||||
|
if len(result.SummariesCreated) == 0 {
|
||||||
|
t.Error("expected at least 1 summary created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactLeafNoCandidate(t *testing.T) {
|
||||||
|
ce, _, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Too few messages to trigger leaf compaction
|
||||||
|
m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10)
|
||||||
|
ce.store.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
|
||||||
|
result, err := ce.Compact(ctx, convID, CompactInput{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Compact: %v", err)
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result even with no candidate")
|
||||||
|
}
|
||||||
|
if result.LeafSummaries != 0 {
|
||||||
|
t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactCondensed(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create enough leaf summaries and fresh messages to enable condensation
|
||||||
|
leafIDs := make([]string, CondensedMinFanout)
|
||||||
|
for i := 0; i < CondensedMinFanout; i++ {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
summary, err := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf summary content " + time.Now().String(),
|
||||||
|
TokenCount: 500,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSummary %d: %v", i, err)
|
||||||
|
}
|
||||||
|
leafIDs[i] = summary.SummaryID
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add enough fresh messages to have a fresh tail (>= FreshTailCount)
|
||||||
|
for i := 0; i < FreshTailCount; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact with force to trigger condensation
|
||||||
|
_, err := ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Compact: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async condensed goroutine to complete
|
||||||
|
if !waitForCondensed(ce, convID, 2*time.Second) {
|
||||||
|
t.Fatal("timeout waiting for condensed compaction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have created a condensed summary in the DB
|
||||||
|
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||||
|
foundCondensed := false
|
||||||
|
for _, sum := range summaries {
|
||||||
|
if sum.Kind == SummaryKindCondensed {
|
||||||
|
foundCondensed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundCondensed {
|
||||||
|
t.Error("expected at least 1 condensed summary")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) {
|
||||||
|
// Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed
|
||||||
|
// from context_items between candidate selection and ordinal range scan.
|
||||||
|
// Use a slow CompleteFn with barrier sync to control timing.
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
// Create leaf summaries with enough tokens for condensation
|
||||||
|
var leafIDs []string
|
||||||
|
for i := 0; i < CondensedMinFanout; i++ {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
sum, err := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("leaf summary %d", i),
|
||||||
|
TokenCount: 500,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateSummary: %v", err)
|
||||||
|
}
|
||||||
|
leafIDs = append(leafIDs, sum.SummaryID)
|
||||||
|
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add fresh tail so leaf summaries are in evictable range
|
||||||
|
for i := 0; i < FreshTailCount+1; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Barrier: CompleteFn waits until test removes context_items, then returns
|
||||||
|
var barrier1, barrier2 sync.WaitGroup
|
||||||
|
barrier1.Add(1) // CompleteFn signals when called
|
||||||
|
barrier2.Add(1) // test signals when context_items removed
|
||||||
|
|
||||||
|
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
barrier1.Done() // signal: LLM called, candidates selected
|
||||||
|
barrier2.Wait() // wait: test removes context_items
|
||||||
|
return "Condensed summary.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cancel()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Run compactCondensed in background
|
||||||
|
type compactResult struct {
|
||||||
|
summaryID *string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultCh := make(chan compactResult, 1)
|
||||||
|
go func() {
|
||||||
|
sid, err := ce.compactCondensed(context.Background(), convID)
|
||||||
|
resultCh <- compactResult{summaryID: sid, err: err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for CompleteFn to be called (candidates selected)
|
||||||
|
barrier1.Wait()
|
||||||
|
|
||||||
|
// Remove leaf summaries from context_items (simulating concurrent replacement)
|
||||||
|
items, _ := s.GetContextItems(ctx, convID)
|
||||||
|
var preserved []ContextItem
|
||||||
|
for _, item := range items {
|
||||||
|
isLeaf := false
|
||||||
|
for _, lid := range leafIDs {
|
||||||
|
if item.SummaryID == lid {
|
||||||
|
isLeaf = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !isLeaf {
|
||||||
|
preserved = append(preserved, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.UpsertContextItems(ctx, convID, preserved)
|
||||||
|
|
||||||
|
// Let CompleteFn return
|
||||||
|
barrier2.Done()
|
||||||
|
|
||||||
|
// Get result
|
||||||
|
res := <-resultCh
|
||||||
|
if res.err != nil {
|
||||||
|
t.Fatalf("compactCondensed: %v", res.err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the bug: returns non-nil summaryID even though context_items has no matching ordinals
|
||||||
|
// The fix: should return nil when startOrd == -1
|
||||||
|
if res.summaryID != nil {
|
||||||
|
t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID)
|
||||||
|
|
||||||
|
// Verify the orphan exists in DB
|
||||||
|
summary, _ := s.GetSummary(context.Background(), *res.summaryID)
|
||||||
|
if summary != nil && summary.Kind == SummaryKindCondensed {
|
||||||
|
// Check it's NOT in context_items (orphan)
|
||||||
|
items2, _ := s.GetContextItems(context.Background(), convID)
|
||||||
|
found := false
|
||||||
|
for _, item := range items2 {
|
||||||
|
if item.SummaryID == *res.summaryID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("condensed summary exists in DB but not in context_items — orphan confirmed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactUntilUnder(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create many leaf summaries to ensure we can condense
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf summary for condensation test",
|
||||||
|
TokenCount: 500,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force compact until under budget
|
||||||
|
result, err := ce.CompactUntilUnder(ctx, convID, 2000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CompactUntilUnder: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectShallowestCondensationCandidate(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create enough leaf summaries + fresh messages for candidates
|
||||||
|
for i := 0; i < LeafMinFanout; i++ {
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf",
|
||||||
|
TokenCount: 100,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add fresh tail messages so summaries are in evictable range
|
||||||
|
for i := 0; i < FreshTailCount+1; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should find leaf summaries at depth 0
|
||||||
|
if len(candidates) < CondensedMinFanout {
|
||||||
|
t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) {
|
||||||
|
ce, _, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
|
||||||
|
}
|
||||||
|
if len(candidates) != 0 {
|
||||||
|
t.Errorf("candidates = %d, want 0 for empty context", len(candidates))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) {
|
||||||
|
// Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth
|
||||||
|
// rather than just grouping by depth without regard to order
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create interleaved summaries at depth 0 with a message in between:
|
||||||
|
// sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300)
|
||||||
|
|
||||||
|
for i := 0; i < LeafMinFanout+2; i++ {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("leaf summary %d", i),
|
||||||
|
TokenCount: 100,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert a message between first two summaries to break contiguity
|
||||||
|
// for selectShallowestCondensationCandidate but would still find all 3
|
||||||
|
// but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3)
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5)
|
||||||
|
s.AppendContextMessage(ctx, convID, msg.ID)
|
||||||
|
|
||||||
|
// Run compactCondensed
|
||||||
|
result, err := ce.compactCondensed(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("compactCondensed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The result should have merged the two summaries at the start
|
||||||
|
// (skipping the message in between), This proves ordinal-aware selection works.
|
||||||
|
|
||||||
|
_ = result // verify summary was created
|
||||||
|
|
||||||
|
if result != nil {
|
||||||
|
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||||
|
found := false
|
||||||
|
for _, sum := range summaries {
|
||||||
|
if sum.Kind == SummaryKindCondensed {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Error("expected condensed summary to be created via ordinal-aware selection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("leaf summary %d", i),
|
||||||
|
TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000)
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add fresh tail
|
||||||
|
for i := 0; i < FreshTailCount+1; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||||
|
}
|
||||||
|
if len(chunk) < 2 {
|
||||||
|
t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk))
|
||||||
|
}
|
||||||
|
for _, s := range chunk {
|
||||||
|
if s.Depth != 0 {
|
||||||
|
t.Errorf("got depth %d, want 0", s.Depth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 3 summaries, then a message, then 3 more summaries
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("leaf %d", i),
|
||||||
|
TokenCount: 100,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "break", 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, msg.ID)
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("leaf-after %d", i),
|
||||||
|
TokenCount: 100,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
for i := 0; i < FreshTailCount+1; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||||
|
if len(chunk) > 3 {
|
||||||
|
t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create summaries with very low token counts (total < 2000)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("tiny summary %d", i),
|
||||||
|
TokenCount: 50, // very small
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add fresh tail to protect from compaction
|
||||||
|
for i := 0; i < FreshTailCount+1; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return nil because total tokens (250) < 2000 minimum
|
||||||
|
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||||
|
}
|
||||||
|
if len(chunk) > 0 {
|
||||||
|
t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) {
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create summaries with enough tokens (total >= 2000)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf(
|
||||||
|
"substantial summary with enough content to meet minimum token threshold for condensation candidate %d",
|
||||||
|
i,
|
||||||
|
),
|
||||||
|
TokenCount: 500, // 5 × 500 = 2500 >= 2000
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add fresh tail
|
||||||
|
for i := 0; i < FreshTailCount+1; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return chunk because total tokens (2500) >= 2000
|
||||||
|
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||||
|
}
|
||||||
|
if len(chunk) == 0 {
|
||||||
|
t.Error("expected non-empty chunk when tokens >= 2000")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateLeafSummary(t *testing.T) {
|
||||||
|
ce, _, _ := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msgs := []Message{
|
||||||
|
{Role: "user", Content: "hello world", TokenCount: 5},
|
||||||
|
{Role: "assistant", Content: "hi there", TokenCount: 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ce.generateLeafSummary(ctx, msgs, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateLeafSummary: %v", err)
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty summary content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) {
|
||||||
|
// Level 1 returns summary that's too large (tokens >= input), should escalate to level 2
|
||||||
|
var calls []string
|
||||||
|
escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
if contains(prompt, "Aggressive summary policy") {
|
||||||
|
calls = append(calls, "aggressive")
|
||||||
|
return "Short aggressive summary.", nil
|
||||||
|
}
|
||||||
|
calls = append(calls, "normal")
|
||||||
|
// Return a very long summary to trigger escalation
|
||||||
|
longContent := make([]byte, 5000)
|
||||||
|
for i := range longContent {
|
||||||
|
longContent[i] = 'x'
|
||||||
|
}
|
||||||
|
return string(longContent), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := openTestStore(t)
|
||||||
|
ce, _ := newTestCompactionEngineWithStore(s, escalateComplete)
|
||||||
|
|
||||||
|
msgs := []Message{
|
||||||
|
{Role: "user", Content: "hello world", TokenCount: 10},
|
||||||
|
{Role: "assistant", Content: "response", TokenCount: 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateLeafSummary: %v", err)
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty summary content")
|
||||||
|
}
|
||||||
|
// Should have called both normal and aggressive
|
||||||
|
foundNormal := false
|
||||||
|
foundAggressive := false
|
||||||
|
for _, c := range calls {
|
||||||
|
if c == "normal" {
|
||||||
|
foundNormal = true
|
||||||
|
}
|
||||||
|
if c == "aggressive" {
|
||||||
|
foundAggressive = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundNormal {
|
||||||
|
t.Error("expected normal LLM call")
|
||||||
|
}
|
||||||
|
if !foundAggressive {
|
||||||
|
t.Error("expected aggressive LLM call (level 2 escalation)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) {
|
||||||
|
// Both normal and aggressive return empty, should escalate to level 3 truncation
|
||||||
|
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := openTestStore(t)
|
||||||
|
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
|
||||||
|
|
||||||
|
msgs := []Message{
|
||||||
|
{Role: "user", Content: "hello world from test", TokenCount: 10},
|
||||||
|
{Role: "assistant", Content: "response text here", TokenCount: 10},
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateLeafSummary: %v", err)
|
||||||
|
}
|
||||||
|
// Level 3 truncation should have produced something
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty content from level 3 truncation fallback")
|
||||||
|
}
|
||||||
|
if !contains(content, "Truncated from") {
|
||||||
|
t.Errorf("expected truncation marker in content: %q", content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCondensedSummary(t *testing.T) {
|
||||||
|
ce, _, _ := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
summaries := []Summary{
|
||||||
|
{SummaryID: "sum_a", Content: "first summary", TokenCount: 100},
|
||||||
|
{SummaryID: "sum_b", Content: "second summary", TokenCount: 100},
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ce.generateCondensedSummary(ctx, summaries)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateCondensedSummary: %v", err)
|
||||||
|
}
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty condensed summary content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCondensedSummaryEscalation(t *testing.T) {
|
||||||
|
// When LLM returns empty, should fall back to deterministic concatenation
|
||||||
|
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := openTestStore(t)
|
||||||
|
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
|
||||||
|
|
||||||
|
summaries := []Summary{
|
||||||
|
{SummaryID: "sum_a", Content: "first summary text", TokenCount: 50},
|
||||||
|
{SummaryID: "sum_b", Content: "second summary text", TokenCount: 50},
|
||||||
|
}
|
||||||
|
|
||||||
|
content, err := ce.generateCondensedSummary(context.Background(), summaries)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generateCondensedSummary: %v", err)
|
||||||
|
}
|
||||||
|
// Should fall back to concatenation
|
||||||
|
if content == "" {
|
||||||
|
t.Error("expected non-empty content from fallback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Async Condensed Compaction (Phase 2) ---
|
||||||
|
|
||||||
|
func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) {
|
||||||
|
// Use a slow CompleteFn to verify Compact returns before condensed finishes
|
||||||
|
var callCount int32
|
||||||
|
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
atomic.AddInt32(&callCount, 1)
|
||||||
|
time.Sleep(500 * time.Millisecond) // simulate slow LLM
|
||||||
|
return "Slow condensed summary.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:async")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cancel()
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create enough leaf summaries for condensation + fresh tail
|
||||||
|
for i := 0; i < CondensedMinFanout; i++ {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf for async test",
|
||||||
|
TokenCount: 500,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
for i := 0; i < FreshTailCount; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact with force — should return quickly, condensed runs async
|
||||||
|
start := time.Now()
|
||||||
|
result, err := ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||||
|
elapsed := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Compact: %v", err)
|
||||||
|
}
|
||||||
|
if result == nil {
|
||||||
|
t.Fatal("expected non-nil result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return well before the 500ms LLM call
|
||||||
|
if elapsed > 200*time.Millisecond {
|
||||||
|
t.Errorf("Compact took %v, should return before async condensed finishes", elapsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async to complete
|
||||||
|
time.Sleep(800 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify condensed summary was created by background goroutine
|
||||||
|
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||||
|
foundCondensed := false
|
||||||
|
for _, sum := range summaries {
|
||||||
|
if sum.Kind == SummaryKindCondensed {
|
||||||
|
foundCondensed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundCondensed {
|
||||||
|
t.Error("expected at least one condensed summary from async Phase 2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactAsyncDedup(t *testing.T) {
|
||||||
|
var callCount int32
|
||||||
|
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||||
|
atomic.AddInt32(&callCount, 1)
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
return "Slow condensed summary.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:dedup")
|
||||||
|
convID := conv.ConversationID
|
||||||
|
|
||||||
|
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cancel()
|
||||||
|
waitForCondensed(ce, convID, 2*time.Second)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create conditions for condensed compaction
|
||||||
|
for i := 0; i < CondensedMinFanout; i++ {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "leaf for dedup",
|
||||||
|
TokenCount: 500,
|
||||||
|
EarliestAt: &now,
|
||||||
|
LatestAt: &now,
|
||||||
|
})
|
||||||
|
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||||
|
}
|
||||||
|
for i := 0; i < FreshTailCount; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Compact twice rapidly
|
||||||
|
ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||||
|
ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||||
|
|
||||||
|
// Wait for async to finish
|
||||||
|
time.Sleep(600 * time.Millisecond)
|
||||||
|
|
||||||
|
// LLM should only be called once for condensed (dedup)
|
||||||
|
// callCount may be 0 if no leaf was created (only condensed in goroutine)
|
||||||
|
// The key is that we don't get 2+ condensed calls
|
||||||
|
if atomic.LoadInt32(&callCount) > 1 {
|
||||||
|
t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactLeafForceBypassesFreshTail(t *testing.T) {
|
||||||
|
// Spec: compactLeaf with force=true should bypass FreshTailCount protection
|
||||||
|
// so CompactUntilUnder can compress messages inside the fresh tail
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create exactly FreshTailCount+4 messages (36 total)
|
||||||
|
// Without force: all messages are in fresh tail → no candidate
|
||||||
|
// With force: should compact the oldest messages
|
||||||
|
total := FreshTailCount + 4
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Without force: should return nil (all in fresh tail)
|
||||||
|
summaryID, err := ce.compactLeaf(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("compactLeaf no-force: %v", err)
|
||||||
|
}
|
||||||
|
if summaryID != nil {
|
||||||
|
t.Error("expected nil without force (all messages in fresh tail)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// With force: should compact despite fresh tail protection
|
||||||
|
summaryID, err = ce.compactLeaf(ctx, convID, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("compactLeaf force: %v", err)
|
||||||
|
}
|
||||||
|
if summaryID == nil {
|
||||||
|
t.Error("expected summary with force=true (bypasses fresh tail)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) {
|
||||||
|
// Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping
|
||||||
|
// It should NOT take the entire contiguous chunk regardless of token count
|
||||||
|
ce, s, convID := newTestCompactionEngine(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create messages totaling far more than LeafChunkTokens (20000)
|
||||||
|
// Each message is ~500 tokens, create 80 messages = 40000 tokens
|
||||||
|
for i := 0; i < 80; i++ {
|
||||||
|
m, _ := s.AddMessage(
|
||||||
|
ctx,
|
||||||
|
convID,
|
||||||
|
"user",
|
||||||
|
fmt.Sprintf(
|
||||||
|
"message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn",
|
||||||
|
i,
|
||||||
|
),
|
||||||
|
500,
|
||||||
|
)
|
||||||
|
s.AppendContextMessage(ctx, convID, m.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
summaryID, err := ce.compactLeaf(ctx, convID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("compactLeaf: %v", err)
|
||||||
|
}
|
||||||
|
if summaryID == nil {
|
||||||
|
t.Fatal("expected a summary to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The source messages that were compacted should total roughly LeafChunkTokens (20000),
|
||||||
|
// not the entire 40000 tokens worth of messages
|
||||||
|
summary, _ := s.GetSummary(ctx, *summaryID)
|
||||||
|
if summary == nil {
|
||||||
|
t.Fatal("summary not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source message tokens should be roughly <= LeafChunkTokens (20000)
|
||||||
|
// Spec says: "Stop when accumulated tokens >= LeafChunkTokens"
|
||||||
|
if summary.SourceMessageTokenCount > LeafChunkTokens {
|
||||||
|
t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)",
|
||||||
|
summary.SourceMessageTokenCount, LeafChunkTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
// Short-term memory configuration constants — all are experience-based defaults.
|
||||||
|
|
||||||
|
const (
|
||||||
|
// OrdinalStep is the gap between ordinals in context_items.
|
||||||
|
// Insert at midpoint; resequence only when precision exhausted.
|
||||||
|
OrdinalStep = 100
|
||||||
|
|
||||||
|
// ContextThreshold is the compaction trigger for the context window.
|
||||||
|
ContextThreshold float64 = 0.75 // Compact at 75% of context window
|
||||||
|
FreshTailCount int = 32 // Recent messages protected from compaction
|
||||||
|
|
||||||
|
// LeafMinFanout is the fanout parameter.
|
||||||
|
LeafMinFanout int = 8 // Min messages per leaf summary
|
||||||
|
CondensedMinFanout int = 4 // Min summaries per condensed
|
||||||
|
CondensedMinFanoutHard int = 2 // Min for forced compaction
|
||||||
|
|
||||||
|
// LeafChunkTokens is the token target.
|
||||||
|
LeafChunkTokens int = 20000 // Max tokens per leaf chunk
|
||||||
|
LeafTargetTokens int = 1200 // Target tokens for leaf summaries
|
||||||
|
CondensedTargetTokens int = 2000 // Target tokens for condensed summaries
|
||||||
|
MaxExpandTokens int = 4000 // Token cap for expansion queries
|
||||||
|
|
||||||
|
// MaxCompactIterations caps CompactUntilUnder to prevent infinite loops.
|
||||||
|
// Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction.
|
||||||
|
// With a 200k token context window and 75% threshold, ~20 iterations is enough
|
||||||
|
// for any realistic scenario. If exceeded, the issue is logged as a warning.
|
||||||
|
MaxCompactIterations int = 20
|
||||||
|
)
|
||||||
@@ -0,0 +1,568 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds engine configuration.
|
||||||
|
type Config struct {
|
||||||
|
DBPath string `json:"dbPath"`
|
||||||
|
IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"`
|
||||||
|
StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteFn is the LLM completion function type.
|
||||||
|
type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error)
|
||||||
|
|
||||||
|
// CompleteOptions holds LLM completion parameters.
|
||||||
|
type CompleteOptions struct {
|
||||||
|
Model string
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IngestResult is the result of message ingestion.
|
||||||
|
type IngestResult struct {
|
||||||
|
MessageCount int `json:"messageCount"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssembleInput controls context assembly.
|
||||||
|
type AssembleInput struct {
|
||||||
|
Budget int `json:"budget"`
|
||||||
|
Query string `json:"query,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssembleResult contains assembled context.
|
||||||
|
type AssembleResult struct {
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
Summary string `json:"summary"` // formatted XML summaries + system prompt addition
|
||||||
|
}
|
||||||
|
|
||||||
|
const numSessionShards = 256
|
||||||
|
|
||||||
|
// Engine is the main short-term memory engine.
|
||||||
|
type Engine struct {
|
||||||
|
store *Store
|
||||||
|
compaction *CompactionEngine
|
||||||
|
compactionMu sync.Mutex
|
||||||
|
assembler *Assembler
|
||||||
|
assemblerMu sync.Mutex
|
||||||
|
retrieval *RetrievalEngine
|
||||||
|
config Config
|
||||||
|
complete CompleteFn
|
||||||
|
ignorePatterns []*regexp.Regexp
|
||||||
|
statelessPatterns []*regexp.Regexp
|
||||||
|
sessionShards [numSessionShards]struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompactionEngine handles LLM-based summarization (defined in short_compaction.go).
|
||||||
|
type CompactionEngine struct {
|
||||||
|
store *Store
|
||||||
|
config Config
|
||||||
|
complete CompleteFn
|
||||||
|
condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines
|
||||||
|
shutdownCtx context.Context
|
||||||
|
shutdownCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assembler handles budget-aware context assembly (defined in short_assembler.go).
|
||||||
|
type Assembler struct {
|
||||||
|
store *Store
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetrievalEngine handles search and expansion (defined in short_retrieval.go).
|
||||||
|
type RetrievalEngine struct {
|
||||||
|
store *Store
|
||||||
|
config Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store returns the underlying store for direct access.
|
||||||
|
func (r *RetrievalEngine) Store() *Store {
|
||||||
|
return r.store
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEngine creates a new short-term memory engine.
|
||||||
|
func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) {
|
||||||
|
dir := filepath.Dir(config.DBPath)
|
||||||
|
if dir != "" && dir != "." {
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
return nil, fmt.Errorf("create db directory: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := sql.Open("sqlite", config.DBPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("open db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure SQLite for concurrent access
|
||||||
|
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("enable WAL: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("set busy_timeout: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("set synchronous: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := runSchema(db); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("migrations: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store := &Store{db: db}
|
||||||
|
|
||||||
|
// Prepend hardcoded ignore patterns (spec lines 1326-1328)
|
||||||
|
ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns))
|
||||||
|
ignorePatterns = append(ignorePatterns, "heartbeat")
|
||||||
|
ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...)
|
||||||
|
|
||||||
|
retrieval := &RetrievalEngine{store: store, config: config}
|
||||||
|
|
||||||
|
return &Engine{
|
||||||
|
store: store,
|
||||||
|
compaction: nil,
|
||||||
|
assembler: nil,
|
||||||
|
retrieval: retrieval,
|
||||||
|
config: config,
|
||||||
|
complete: completeFn,
|
||||||
|
ignorePatterns: compileSessionPatterns(ignorePatterns),
|
||||||
|
statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compileSessionPattern converts a glob pattern to a compiled regex.
|
||||||
|
// Pattern rules:
|
||||||
|
// - * matches any sequence of non-colon characters ([^:]*)
|
||||||
|
// - ** matches any sequence of characters including colons (.*)
|
||||||
|
// - All other characters are treated literally
|
||||||
|
// - Pattern is anchored (^...$)
|
||||||
|
func compileSessionPattern(pattern string) *regexp.Regexp {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteByte('^')
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(pattern) {
|
||||||
|
if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' {
|
||||||
|
b.WriteString(".*")
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pattern[i] == '*' {
|
||||||
|
b.WriteString("[^:]*")
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteString(regexp.QuoteMeta(string(pattern[i])))
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteByte('$')
|
||||||
|
return regexp.MustCompile(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// compileSessionPatterns compiles multiple glob patterns into regex patterns.
|
||||||
|
func compileSessionPatterns(patterns []string) []*regexp.Regexp {
|
||||||
|
result := make([]*regexp.Regexp, 0, len(patterns))
|
||||||
|
for _, p := range patterns {
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, compileSessionPattern(p))
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldIgnoreSession returns true if the session key matches any ignore pattern.
|
||||||
|
func (e *Engine) shouldIgnoreSession(sessionKey string) bool {
|
||||||
|
for _, p := range e.ignorePatterns {
|
||||||
|
if p.MatchString(sessionKey) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// isStatelessSession returns true if the session key matches any stateless pattern.
|
||||||
|
func (e *Engine) isStatelessSession(sessionKey string) bool {
|
||||||
|
for _, p := range e.statelessPatterns {
|
||||||
|
if p.MatchString(sessionKey) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// fnv32 computes FNV-1a 32-bit hash for session key sharding.
|
||||||
|
func fnv32(key string) uint32 {
|
||||||
|
h := uint32(2166136261)
|
||||||
|
for _, c := range key {
|
||||||
|
h ^= uint32(c)
|
||||||
|
h *= 16777619
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSessionMutex returns the sharded mutex for a session key.
|
||||||
|
func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex {
|
||||||
|
h := fnv32(sessionKey)
|
||||||
|
shard := h % numSessionShards
|
||||||
|
return &e.sessionShards[shard].mu
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ingest adds messages to a conversation identified by sessionKey.
|
||||||
|
func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
|
||||||
|
if e.shouldIgnoreSession(sessionKey) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if e.isStatelessSession(sessionKey) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
mu := e.getSessionMutex(sessionKey)
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var totalTokens int
|
||||||
|
var msgIDs []int64
|
||||||
|
for _, msg := range messages {
|
||||||
|
var added *Message
|
||||||
|
var err error
|
||||||
|
if len(msg.Parts) > 0 {
|
||||||
|
added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount)
|
||||||
|
} else {
|
||||||
|
added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("add message: %w", err)
|
||||||
|
}
|
||||||
|
totalTokens += msg.TokenCount
|
||||||
|
msgIDs = append(msgIDs, added.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append to context_items using actual inserted IDs
|
||||||
|
if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil {
|
||||||
|
return nil, fmt.Errorf("append context: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.InfoCF("seahorse", "ingest", map[string]any{
|
||||||
|
"conv_id": conv.ConversationID,
|
||||||
|
"messages": len(messages),
|
||||||
|
"tokens": totalTokens,
|
||||||
|
})
|
||||||
|
return &IngestResult{
|
||||||
|
MessageCount: len(messages),
|
||||||
|
TokenCount: totalTokens,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases resources.
|
||||||
|
func (e *Engine) Close() error {
|
||||||
|
// Signal compaction goroutines to stop
|
||||||
|
if e.compaction != nil {
|
||||||
|
e.compaction.Close()
|
||||||
|
}
|
||||||
|
if e.store != nil && e.store.db != nil {
|
||||||
|
return e.store.db.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRetrieval returns the retrieval engine for tool implementations.
|
||||||
|
func (e *Engine) GetRetrieval() *RetrievalEngine {
|
||||||
|
return e.retrieval
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assemble builds budget-constrained context for a session.
|
||||||
|
func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) {
|
||||||
|
if e.shouldIgnoreSession(sessionKey) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initAssemblerOnce()
|
||||||
|
return e.assembler.Assemble(ctx, conv.ConversationID, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compact compresses conversation history for a session.
|
||||||
|
func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) {
|
||||||
|
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
|
||||||
|
return &CompactResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initCompactionOnce()
|
||||||
|
return e.compaction.Compact(ctx, conv.ConversationID, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompactUntilUnder aggressively compacts until context is under budget.
|
||||||
|
// Used for emergency compaction after LLM overflow (retry reason).
|
||||||
|
func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) {
|
||||||
|
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
|
||||||
|
return &CompactResult{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initCompactionOnce()
|
||||||
|
return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget)
|
||||||
|
}
|
||||||
|
|
||||||
|
// initCompactionOnce lazily initializes the compaction engine.
|
||||||
|
func (e *Engine) initCompactionOnce() {
|
||||||
|
if e.compaction == nil {
|
||||||
|
e.compactionMu.Lock()
|
||||||
|
defer e.compactionMu.Unlock()
|
||||||
|
if e.compaction == nil {
|
||||||
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
e.compaction = &CompactionEngine{
|
||||||
|
store: e.store,
|
||||||
|
config: e.config,
|
||||||
|
complete: e.complete,
|
||||||
|
shutdownCtx: shutdownCtx,
|
||||||
|
shutdownCancel: shutdownCancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// initAssemblerOnce lazily initializes the assembler.
|
||||||
|
func (e *Engine) initAssemblerOnce() {
|
||||||
|
if e.assembler == nil {
|
||||||
|
e.assemblerMu.Lock()
|
||||||
|
defer e.assemblerMu.Unlock()
|
||||||
|
if e.assembler == nil {
|
||||||
|
e.assembler = &Assembler{store: e.store, config: e.config}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IngestMessages is an alias for Ingest.
|
||||||
|
func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
|
||||||
|
return e.Ingest(ctx, sessionKey, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bootstrap reconciles a session's messages with the database.
|
||||||
|
// Called once at startup for each known session.
|
||||||
|
// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta.
|
||||||
|
// Simple approach: find longest matching prefix and append delta.
|
||||||
|
// If any mismatch is detected, clear and rebuild.
|
||||||
|
func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error {
|
||||||
|
if e.shouldIgnoreSession(sessionKey) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if e.isStatelessSession(sessionKey) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: get conversation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get messages already in DB
|
||||||
|
dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: get messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fast path: DB has same count and exact match → no-op
|
||||||
|
if len(dbMsgs) == len(messages) {
|
||||||
|
matched := true
|
||||||
|
for i := 0; i < len(messages); i++ {
|
||||||
|
if !messageMatches(dbMsgs[i], messages[i]) {
|
||||||
|
matched = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matched {
|
||||||
|
return nil // DB is up to date
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find longest matching prefix from the start
|
||||||
|
anchor := -1
|
||||||
|
compareLen := len(dbMsgs)
|
||||||
|
if compareLen > len(messages) {
|
||||||
|
compareLen = len(messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < compareLen; i++ {
|
||||||
|
if messageMatches(dbMsgs[i], messages[i]) {
|
||||||
|
anchor = i
|
||||||
|
} else {
|
||||||
|
// Mismatch detected - log details and rebuild
|
||||||
|
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
|
||||||
|
"conv_id": conv.ConversationID,
|
||||||
|
"index": i,
|
||||||
|
"db_role": dbMsgs[i].Role,
|
||||||
|
"db_content": truncate(dbMsgs[i].Content, 50),
|
||||||
|
"db_parts": len(dbMsgs[i].Parts),
|
||||||
|
"msg_role": messages[i].Role,
|
||||||
|
"msg_content": truncate(messages[i].Content, 50),
|
||||||
|
"msg_parts": len(messages[i].Parts),
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest
|
||||||
|
// Note: anchor can be -1 if first message didn't match (history completely changed)
|
||||||
|
if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 {
|
||||||
|
anchorID := dbMsgs[anchor].ID
|
||||||
|
logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{
|
||||||
|
"conv_id": conv.ConversationID,
|
||||||
|
"db_count": len(dbMsgs),
|
||||||
|
"anchor": anchor,
|
||||||
|
"anchor_id": anchorID,
|
||||||
|
"msg_count": len(messages),
|
||||||
|
"delta_start": anchor + 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Delete messages after anchor (also clears context_items)
|
||||||
|
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: delete messages: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-ingest from anchor+1 to end
|
||||||
|
delta := messages[anchor+1:]
|
||||||
|
if len(delta) > 0 {
|
||||||
|
_, err := e.Ingest(ctx, sessionKey, delta)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: re-ingest: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal case: append delta after anchor
|
||||||
|
if anchor >= 0 && anchor < len(messages)-1 {
|
||||||
|
delta := messages[anchor+1:]
|
||||||
|
if len(delta) > 0 {
|
||||||
|
_, err := e.Ingest(ctx, sessionKey, delta)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: ingest delta: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if anchor == -1 && len(dbMsgs) > 0 {
|
||||||
|
// First message changed (history completely different) - rebuild from scratch
|
||||||
|
logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{
|
||||||
|
"conv_id": conv.ConversationID,
|
||||||
|
"db_count": len(dbMsgs),
|
||||||
|
"msg_count": len(messages),
|
||||||
|
})
|
||||||
|
// Delete all existing messages
|
||||||
|
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: delete all messages: %w", err)
|
||||||
|
}
|
||||||
|
// Re-ingest everything
|
||||||
|
if len(messages) > 0 {
|
||||||
|
_, err := e.Ingest(ctx, sessionKey, messages)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: re-ingest all: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if anchor == -1 && len(dbMsgs) == 0 {
|
||||||
|
// DB is empty, ingest everything
|
||||||
|
_, err := e.Ingest(ctx, sessionKey, messages)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bootstrap: ingest all: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncate shortens a string for logging.
|
||||||
|
func truncate(s string, maxLen int) string {
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
// messageMatches compares two messages using (role, content) or (role, parts).
|
||||||
|
// TokenCount is NOT compared because it may be re-estimated differently
|
||||||
|
// during bootstrap (e.g., via tokenizer.EstimateMessageTokens).
|
||||||
|
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
|
||||||
|
// since AddMessageWithParts stores empty Content in DB.
|
||||||
|
func messageMatches(a, b Message) bool {
|
||||||
|
if a.Role != b.Role {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// If either message has Parts, compare Parts
|
||||||
|
if len(a.Parts) > 0 || len(b.Parts) > 0 {
|
||||||
|
return partsMatch(a.Parts, b.Parts)
|
||||||
|
}
|
||||||
|
// Simple text messages: compare Content
|
||||||
|
return a.Content == b.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
// partsMatch compares two slices of MessagePart for equality.
|
||||||
|
func partsMatch(a, b []MessagePart) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := range a {
|
||||||
|
if a[i].Type != b[i].Type {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch a[i].Type {
|
||||||
|
case "text":
|
||||||
|
if a[i].Text != b[i].Text {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
case "media":
|
||||||
|
if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,212 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m".
|
||||||
|
// Returns the duration and nil error, or zero and error if invalid.
|
||||||
|
func ParseLastDuration(s string) (time.Duration, error) {
|
||||||
|
if s == "" {
|
||||||
|
return 0, fmt.Errorf("empty duration")
|
||||||
|
}
|
||||||
|
|
||||||
|
re := regexp.MustCompile(`^(\d+)([hdwm])$`)
|
||||||
|
matches := re.FindStringSubmatch(s)
|
||||||
|
if matches == nil {
|
||||||
|
return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, _ := strconv.Atoi(matches[1])
|
||||||
|
unit := matches[2]
|
||||||
|
|
||||||
|
switch unit {
|
||||||
|
case "h":
|
||||||
|
return time.Duration(value) * time.Hour, nil
|
||||||
|
case "d":
|
||||||
|
return time.Duration(value) * 24 * time.Hour, nil
|
||||||
|
case "w":
|
||||||
|
return time.Duration(value) * 7 * 24 * time.Hour, nil
|
||||||
|
case "m":
|
||||||
|
return time.Duration(value) * 30 * 24 * time.Hour, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unknown unit: %q", unit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrepInput controls search across summaries and messages.
|
||||||
|
type GrepInput struct {
|
||||||
|
Pattern string `json:"pattern"`
|
||||||
|
Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message"
|
||||||
|
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
|
||||||
|
AllConversations bool `json:"allConversations,omitempty"`
|
||||||
|
Since *time.Time `json:"since,omitempty"`
|
||||||
|
Before *time.Time `json:"before,omitempty"`
|
||||||
|
Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m"
|
||||||
|
Limit int `json:"limit,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrepResult contains search results.
|
||||||
|
type GrepResult struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Summaries []GrepSummaryResult `json:"summaries"`
|
||||||
|
Messages []GrepMessageResult `json:"messages"`
|
||||||
|
TotalSummaries int `json:"totalSummaries"`
|
||||||
|
TotalMessages int `json:"totalMessages"`
|
||||||
|
Hint string `json:"hint,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrepSummaryResult is a summary match from grep.
|
||||||
|
type GrepSummaryResult struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
Depth int `json:"depth"`
|
||||||
|
Kind SummaryKind `json:"kind"`
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
// Rank is the bm25 relevance score (negative value, closer to 0 = better match).
|
||||||
|
// Examples: -0.5 = excellent match, -2.0 = good match, -10.0 = partial match.
|
||||||
|
Rank float64 `json:"rank,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrepMessageResult is a message match from grep.
|
||||||
|
type GrepMessageResult struct {
|
||||||
|
ID int64 `json:"id,string"`
|
||||||
|
Snippet string `json:"snippet"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
Rank float64 `json:"rank,omitempty"` // Relevance score (lower = better match)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandMessagesResult contains expanded messages.
|
||||||
|
type ExpandMessagesResult struct {
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grep searches summaries and messages for matching content.
|
||||||
|
func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) {
|
||||||
|
if input.Pattern == "" {
|
||||||
|
return nil, fmt.Errorf("grep: pattern is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := input.Limit
|
||||||
|
if limit == 0 {
|
||||||
|
limit = 20
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Last parameter: convert to Since
|
||||||
|
since := input.Since
|
||||||
|
if input.Last != "" {
|
||||||
|
dur, err := ParseLastDuration(input.Last)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("grep: invalid last: %w", err)
|
||||||
|
}
|
||||||
|
t := time.Now().UTC().Add(-dur)
|
||||||
|
since = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-detect mode: use LIKE if pattern contains %, otherwise full-text
|
||||||
|
mode := ""
|
||||||
|
if strings.Contains(input.Pattern, "%") {
|
||||||
|
mode = "like"
|
||||||
|
}
|
||||||
|
|
||||||
|
searchInput := SearchInput{
|
||||||
|
Pattern: input.Pattern,
|
||||||
|
Mode: mode,
|
||||||
|
Role: input.Role,
|
||||||
|
AllConversations: input.AllConversations,
|
||||||
|
Since: since,
|
||||||
|
Before: input.Before,
|
||||||
|
Limit: limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &GrepResult{
|
||||||
|
Success: true,
|
||||||
|
Summaries: make([]GrepSummaryResult, 0),
|
||||||
|
Messages: make([]GrepMessageResult, 0),
|
||||||
|
TotalSummaries: 0,
|
||||||
|
TotalMessages: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine scope
|
||||||
|
scope := input.Scope
|
||||||
|
if scope == "" {
|
||||||
|
scope = "both"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search summaries if requested
|
||||||
|
if scope == "both" || scope == "summary" {
|
||||||
|
sumResults, err := r.store.SearchSummaries(ctx, searchInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("search summaries: %w", err)
|
||||||
|
}
|
||||||
|
for _, sr := range sumResults {
|
||||||
|
if sr.SummaryID != "" {
|
||||||
|
result.Summaries = append(result.Summaries, GrepSummaryResult{
|
||||||
|
ID: sr.SummaryID,
|
||||||
|
Content: sr.Content,
|
||||||
|
Depth: sr.Depth,
|
||||||
|
Kind: sr.Kind,
|
||||||
|
ConversationID: sr.ConversationID,
|
||||||
|
Rank: sr.Rank,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(sumResults) > 0 {
|
||||||
|
result.TotalSummaries = sumResults[0].TotalCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search messages if requested
|
||||||
|
if scope == "both" || scope == "message" {
|
||||||
|
msgResults, err := r.store.SearchMessages(ctx, searchInput)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("search messages: %w", err)
|
||||||
|
}
|
||||||
|
for _, sr := range msgResults {
|
||||||
|
if sr.MessageID > 0 {
|
||||||
|
result.Messages = append(result.Messages, GrepMessageResult{
|
||||||
|
ID: sr.MessageID,
|
||||||
|
Snippet: sr.Snippet,
|
||||||
|
Role: sr.Role,
|
||||||
|
ConversationID: sr.ConversationID,
|
||||||
|
Rank: sr.Rank,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(msgResults) > 0 {
|
||||||
|
result.TotalMessages = msgResults[0].TotalCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add hint if no results
|
||||||
|
if len(result.Summaries) == 0 && len(result.Messages) == 0 {
|
||||||
|
result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true"
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExpandMessages retrieves full message content by IDs.
|
||||||
|
func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) {
|
||||||
|
result := &ExpandMessagesResult{
|
||||||
|
Messages: make([]Message, 0, len(messageIDs)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msgID := range messageIDs {
|
||||||
|
msg, err := r.store.GetMessageByID(ctx, msgID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result.Messages = append(result.Messages, *msg)
|
||||||
|
result.TokenCount += msg.TokenCount
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,362 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- Retrieval Tests ---
|
||||||
|
|
||||||
|
func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) {
|
||||||
|
t.Helper()
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval")
|
||||||
|
return &RetrievalEngine{store: s}, s, conv.ConversationID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrievalGrepSummaries(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "数据库连接配置说明",
|
||||||
|
TokenCount: 50,
|
||||||
|
})
|
||||||
|
s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "API endpoint documentation",
|
||||||
|
TokenCount: 50,
|
||||||
|
})
|
||||||
|
|
||||||
|
// FTS5 search (trigram, needs >= 3 chars)
|
||||||
|
results, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "数据库连",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(results.Summaries) == 0 {
|
||||||
|
t.Error("expected at least 1 FTS result")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LIKE search with wildcard
|
||||||
|
results, err = r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "%endpoint%",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep LIKE: %v", err)
|
||||||
|
}
|
||||||
|
if len(results.Summaries) == 0 {
|
||||||
|
t.Error("expected at least 1 LIKE result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrievalGrepMessages(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
s.AddMessage(ctx, convID, "user", "find this message about testing", 5)
|
||||||
|
s.AddMessage(ctx, convID, "user", "unrelated content here", 5)
|
||||||
|
|
||||||
|
results, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "testing",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(results.Messages) == 0 {
|
||||||
|
t.Error("expected at least 1 result for 'testing'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrievalExpandMessages(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10)
|
||||||
|
|
||||||
|
result, err := r.ExpandMessages(ctx, []int64{msg.ID})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandMessages: %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Messages) != 1 {
|
||||||
|
t.Errorf("Messages = %d, want 1", len(result.Messages))
|
||||||
|
}
|
||||||
|
if result.Messages[0].Content != "expand this message" {
|
||||||
|
t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrievalExpandMultipleMessages(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10)
|
||||||
|
msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10)
|
||||||
|
msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10)
|
||||||
|
|
||||||
|
result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExpandMessages: %v", err)
|
||||||
|
}
|
||||||
|
if len(result.Messages) != 3 {
|
||||||
|
t.Errorf("Messages = %d, want 3", len(result.Messages))
|
||||||
|
}
|
||||||
|
if result.TokenCount != 30 {
|
||||||
|
t.Errorf("TokenCount = %d, want 30", result.TokenCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrievalGrepWithTimeFilter(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
before := now.Add(-2 * time.Hour)
|
||||||
|
|
||||||
|
// Create messages at different times
|
||||||
|
s.AddMessage(ctx, convID, "user", "old message about auth", 5)
|
||||||
|
s.AddMessage(ctx, convID, "user", "recent message about auth", 5)
|
||||||
|
|
||||||
|
// Search with time filter
|
||||||
|
results, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "auth",
|
||||||
|
Since: &before,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
_ = results // Just verify no error
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetrievalGrepAllConversations(t *testing.T) {
|
||||||
|
r, s, _ := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create another conversation
|
||||||
|
conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2")
|
||||||
|
|
||||||
|
// Add messages to both
|
||||||
|
s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5)
|
||||||
|
|
||||||
|
// Search all conversations
|
||||||
|
results, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "xyz",
|
||||||
|
AllConversations: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(results.Messages) == 0 {
|
||||||
|
t.Error("expected to find message in other conversation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Last Duration Parsing Tests ---
|
||||||
|
|
||||||
|
func TestParseLastDuration(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
wantDur time.Duration
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"6h", 6 * time.Hour, false},
|
||||||
|
{"1d", 24 * time.Hour, false},
|
||||||
|
{"7d", 7 * 24 * time.Hour, false},
|
||||||
|
{"2w", 14 * 24 * time.Hour, false},
|
||||||
|
{"1m", 30 * 24 * time.Hour, false}, // month = 30 days
|
||||||
|
{"3m", 90 * 24 * time.Hour, false},
|
||||||
|
{"", 0, true},
|
||||||
|
{"invalid", 0, true},
|
||||||
|
{"5x", 0, true}, // unknown unit
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
got, err := ParseLastDuration(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error, got nil")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if got != tt.wantDur {
|
||||||
|
t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Role Filter Tests ---
|
||||||
|
|
||||||
|
func TestRetrievalGrepRoleFilter(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
s.AddMessage(ctx, convID, "user", "user message about alpha", 5)
|
||||||
|
s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5)
|
||||||
|
s.AddMessage(ctx, convID, "user", "another user message", 5)
|
||||||
|
|
||||||
|
// Search all roles
|
||||||
|
allResults, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "alpha",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(allResults.Messages) != 2 {
|
||||||
|
t.Errorf("expected 2 messages, got %d", len(allResults.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search user only
|
||||||
|
userResults, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "alpha",
|
||||||
|
Role: "user",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(userResults.Messages) != 1 {
|
||||||
|
t.Errorf("expected 1 user message, got %d", len(userResults.Messages))
|
||||||
|
}
|
||||||
|
if userResults.Messages[0].Role != "user" {
|
||||||
|
t.Errorf("expected role=user, got %s", userResults.Messages[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search assistant only
|
||||||
|
assistantResults, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "alpha",
|
||||||
|
Role: "assistant",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(assistantResults.Messages) != 1 {
|
||||||
|
t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Last Parameter Tests ---
|
||||||
|
|
||||||
|
func TestRetrievalGrepWithLast(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Add messages (we can't control timestamps in SQLite easily,
|
||||||
|
// but we can verify the parameter is parsed correctly)
|
||||||
|
s.AddMessage(ctx, convID, "user", "recent message about testing", 5)
|
||||||
|
|
||||||
|
// Test that Last parameter is converted to Since
|
||||||
|
results, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "testing",
|
||||||
|
Last: "1d", // last 1 day
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
// Should still find the message since it's recent
|
||||||
|
if len(results.Messages) == 0 {
|
||||||
|
t.Error("expected to find recent message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when
|
||||||
|
// searching both summaries and messages (summaries don't have role column).
|
||||||
|
func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a summary (no role column)
|
||||||
|
s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "summary about testing",
|
||||||
|
TokenCount: 50,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add messages with different roles
|
||||||
|
s.AddMessage(ctx, convID, "user", "user message about testing", 5)
|
||||||
|
s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5)
|
||||||
|
|
||||||
|
// Search with role filter and scope=both (default), using LIKE mode (%)
|
||||||
|
// This should NOT error even though summaries don't have role column
|
||||||
|
bothResults, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "%testing%", // LIKE mode to trigger the bug
|
||||||
|
Role: "user",
|
||||||
|
Scope: "both",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep with role and scope=both: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should only return user messages, not summaries or assistant messages
|
||||||
|
if len(bothResults.Messages) != 1 {
|
||||||
|
t.Errorf("expected 1 user message, got %d", len(bothResults.Messages))
|
||||||
|
}
|
||||||
|
if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" {
|
||||||
|
t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summaries should be empty since they don't have roles to filter
|
||||||
|
// (or we could return all summaries - either is acceptable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRetrievalGrepTotalCounts tests that grep returns total counts.
|
||||||
|
func TestRetrievalGrepTotalCounts(t *testing.T) {
|
||||||
|
r, s, convID := newTestRetrieval(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create 3 summaries
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: convID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: fmt.Sprintf("summary about testing %d", i),
|
||||||
|
TokenCount: 50,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add 5 messages
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Search with limit smaller than total
|
||||||
|
results, err := r.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "%testing%", // LIKE mode
|
||||||
|
Scope: "both",
|
||||||
|
Limit: 2,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return limited results
|
||||||
|
if len(results.Summaries) > 2 {
|
||||||
|
t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries))
|
||||||
|
}
|
||||||
|
if len(results.Messages) > 2 {
|
||||||
|
t.Errorf("expected at most 2 messages, got %d", len(results.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
// But total counts should reflect all matches
|
||||||
|
if results.TotalSummaries != 3 {
|
||||||
|
t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries)
|
||||||
|
}
|
||||||
|
if results.TotalMessages != 5 {
|
||||||
|
t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages)
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,129 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExpandTool recovers full message content by ID.
|
||||||
|
type ExpandTool struct {
|
||||||
|
engine *RetrievalEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewExpandTool(engine *RetrievalEngine) *ExpandTool {
|
||||||
|
return &ExpandTool{engine: engine}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ExpandTool) Name() string {
|
||||||
|
return "short_expand"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ExpandTool) Description() string {
|
||||||
|
return `Get full message content by ID.
|
||||||
|
|
||||||
|
Use when short_grep returns messages and you need complete content (not just snippet).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- message_ids (required): Array of message ID strings (from short_grep results)
|
||||||
|
|
||||||
|
Returns message with:
|
||||||
|
- content: Full text content
|
||||||
|
- parts: Structured content
|
||||||
|
- text: Full text
|
||||||
|
- tool_use: name, arguments, toolCallId
|
||||||
|
- tool_result: toolCallId only (content omitted - re-run tool if needed)
|
||||||
|
- media: mediaUri (file path), mimeType
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- tool_result content is not returned (can be large). Re-run the tool if you need the result.
|
||||||
|
- Media files are stored on disk at mediaUri path, use bash to access.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
{"message_ids": ["10", "25"]}`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ExpandTool) Parameters() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"message_ids": map[string]any{
|
||||||
|
"type": "array",
|
||||||
|
"items": map[string]any{"type": "string"},
|
||||||
|
"description": "Message IDs to expand (from short_grep results, e.g., [\"10\", \"25\"])",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"message_ids"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ExpandTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||||
|
idsRaw, ok := args["message_ids"].([]any)
|
||||||
|
if !ok || len(idsRaw) == 0 {
|
||||||
|
return tools.ErrorResult(
|
||||||
|
"Missing required 'message_ids' argument. " +
|
||||||
|
"Example: {\"message_ids\": [\"10\", \"25\"]}")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse message IDs
|
||||||
|
messageIDs := make([]int64, 0, len(idsRaw))
|
||||||
|
for _, id := range idsRaw {
|
||||||
|
switch v := id.(type) {
|
||||||
|
case string:
|
||||||
|
var n int64
|
||||||
|
if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
|
||||||
|
return tools.ErrorResult(fmt.Sprintf("Invalid message_id %q: %v", v, err))
|
||||||
|
}
|
||||||
|
messageIDs = append(messageIDs, n)
|
||||||
|
case float64:
|
||||||
|
messageIDs = append(messageIDs, int64(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := t.engine.ExpandMessages(ctx, messageIDs)
|
||||||
|
if err != nil {
|
||||||
|
return tools.ErrorResult("Expand failed: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build response with filtered parts
|
||||||
|
messages := make([]map[string]any, 0, len(result.Messages))
|
||||||
|
for _, msg := range result.Messages {
|
||||||
|
parts := make([]map[string]any, 0, len(msg.Parts))
|
||||||
|
for _, p := range msg.Parts {
|
||||||
|
part := map[string]any{"type": p.Type}
|
||||||
|
switch p.Type {
|
||||||
|
case "text":
|
||||||
|
part["text"] = p.Text
|
||||||
|
case "tool_use":
|
||||||
|
part["name"] = p.Name
|
||||||
|
part["arguments"] = p.Arguments
|
||||||
|
part["toolCallId"] = p.ToolCallID
|
||||||
|
case "tool_result":
|
||||||
|
// Omit content - can be large, re-run tool if needed
|
||||||
|
part["toolCallId"] = p.ToolCallID
|
||||||
|
case "media":
|
||||||
|
part["mediaUri"] = p.MediaURI
|
||||||
|
part["mimeType"] = p.MimeType
|
||||||
|
}
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = append(messages, map[string]any{
|
||||||
|
"id": fmt.Sprintf("%d", msg.ID),
|
||||||
|
"role": msg.Role,
|
||||||
|
"content": msg.Content,
|
||||||
|
"parts": parts,
|
||||||
|
"conversationId": msg.ConversationID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
output := map[string]any{
|
||||||
|
"success": true,
|
||||||
|
"tokenCount": result.TokenCount,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(output)
|
||||||
|
return tools.NewToolResult(string(data))
|
||||||
|
}
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExpandToolByMessageIDs(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:expand-tool")
|
||||||
|
|
||||||
|
msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "first message", 10)
|
||||||
|
msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "second message", 10)
|
||||||
|
|
||||||
|
re := &RetrievalEngine{store: s}
|
||||||
|
tool := NewExpandTool(re)
|
||||||
|
|
||||||
|
result := tool.Execute(ctx, map[string]any{
|
||||||
|
"message_ids": []any{fmt.Sprintf("%d", msg1.ID), fmt.Sprintf("%d", msg2.ID)},
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.IsError {
|
||||||
|
t.Fatalf("Expand failed: %s", result.ForLLM)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse result
|
||||||
|
var output struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
Messages []map[string]any `json:"messages"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
|
||||||
|
t.Fatalf("Parse result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !output.Success {
|
||||||
|
t.Error("expected success=true")
|
||||||
|
}
|
||||||
|
if len(output.Messages) != 2 {
|
||||||
|
t.Errorf("Messages = %d, want 2", len(output.Messages))
|
||||||
|
}
|
||||||
|
if output.TokenCount != 20 {
|
||||||
|
t.Errorf("TokenCount = %d, want 20", output.TokenCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandToolMissingIDs(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
re := &RetrievalEngine{store: s}
|
||||||
|
tool := NewExpandTool(re)
|
||||||
|
|
||||||
|
result := tool.Execute(context.Background(), map[string]any{})
|
||||||
|
|
||||||
|
if !result.IsError {
|
||||||
|
t.Error("expected error for missing message_ids")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpandToolWithParts(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:expand-parts")
|
||||||
|
|
||||||
|
// Create message with parts
|
||||||
|
parts := []MessagePart{
|
||||||
|
{Type: "text", Text: "Hello"},
|
||||||
|
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls"}`, ToolCallID: "call_123"},
|
||||||
|
{Type: "tool_result", ToolCallID: "call_123", Text: "file1.txt\nfile2.txt"},
|
||||||
|
}
|
||||||
|
msg, _ := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 50)
|
||||||
|
|
||||||
|
re := &RetrievalEngine{store: s}
|
||||||
|
tool := NewExpandTool(re)
|
||||||
|
|
||||||
|
result := tool.Execute(ctx, map[string]any{
|
||||||
|
"message_ids": []any{fmt.Sprintf("%d", msg.ID)},
|
||||||
|
})
|
||||||
|
|
||||||
|
if result.IsError {
|
||||||
|
t.Fatalf("Expand failed: %s", result.ForLLM)
|
||||||
|
}
|
||||||
|
|
||||||
|
var output struct {
|
||||||
|
Messages []struct {
|
||||||
|
Parts []map[string]any `json:"parts"`
|
||||||
|
} `json:"messages"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
|
||||||
|
t.Fatalf("Parse result: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(output.Messages) != 1 {
|
||||||
|
t.Fatalf("Messages = %d, want 1", len(output.Messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify parts are filtered correctly
|
||||||
|
foundText := false
|
||||||
|
foundToolUse := false
|
||||||
|
foundToolResult := false
|
||||||
|
for _, p := range output.Messages[0].Parts {
|
||||||
|
switch p["type"].(string) {
|
||||||
|
case "text":
|
||||||
|
foundText = true
|
||||||
|
if p["text"] != "Hello" {
|
||||||
|
t.Errorf("text = %v, want Hello", p["text"])
|
||||||
|
}
|
||||||
|
case "tool_use":
|
||||||
|
foundToolUse = true
|
||||||
|
if p["name"] != "bash" {
|
||||||
|
t.Errorf("name = %v, want bash", p["name"])
|
||||||
|
}
|
||||||
|
case "tool_result":
|
||||||
|
foundToolResult = true
|
||||||
|
// tool_result should NOT have content
|
||||||
|
if _, hasContent := p["content"]; hasContent {
|
||||||
|
t.Error("tool_result should not have content field")
|
||||||
|
}
|
||||||
|
if p["toolCallId"] != "call_123" {
|
||||||
|
t.Errorf("toolCallId = %v, want call_123", p["toolCallId"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !foundText {
|
||||||
|
t.Error("missing text part")
|
||||||
|
}
|
||||||
|
if !foundToolUse {
|
||||||
|
t.Error("missing tool_use part")
|
||||||
|
}
|
||||||
|
if !foundToolResult {
|
||||||
|
t.Error("missing tool_result part")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GrepTool searches summaries and messages for matching content.
|
||||||
|
type GrepTool struct {
|
||||||
|
engine *RetrievalEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGrepTool(engine *RetrievalEngine) *GrepTool {
|
||||||
|
return &GrepTool{engine: engine}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *GrepTool) Name() string {
|
||||||
|
return "short_grep"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *GrepTool) Description() string {
|
||||||
|
return `Search summaries and messages for matching content.
|
||||||
|
|
||||||
|
Pattern syntax:
|
||||||
|
- Words: "authentication" - matches content containing this word
|
||||||
|
- AND: "auth AND login" - matches content with both words
|
||||||
|
- OR: "auth OR signin" - matches content with either word
|
||||||
|
- NOT: "bug NOT fixed" - matches "bug" but excludes "fixed"
|
||||||
|
- Wildcard: "%auth%" - matches any text containing "auth" (e.g., "auth", "authentication")
|
||||||
|
|
||||||
|
Each summary has a "depth" field:
|
||||||
|
- depth 0: Created from messages, most detailed
|
||||||
|
- depth 1+: Created from other summaries, more compressed but covers longer time
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- pattern (required): Search pattern
|
||||||
|
- scope: "both" (default), "summary", or "message" - what to search
|
||||||
|
- role: "user", "assistant", or omit for all - filter by message role
|
||||||
|
- last: Time shortcut like "6h", "7d", "2w", "1m" (hours/days/weeks/months)
|
||||||
|
- all_conversations: Search all conversations (default: current only)
|
||||||
|
- since: ISO8601 timestamp, content after this time
|
||||||
|
- before: ISO8601 timestamp, content before this time
|
||||||
|
- limit: Max results (default: 20)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"summaries": [{"id": "sum_abc", "content": "...", "depth": 0, "kind": "leaf", "conversationId": 1, "rank": -0.5}],
|
||||||
|
"messages": [{"id": "10", "snippet": "...matched...", "role": "user", "conversationId": 1, "rank": -1.2}],
|
||||||
|
"totalSummaries": 5,
|
||||||
|
"totalMessages": 10,
|
||||||
|
"hint": "No matches. Try: %keyword% for fuzzy search"
|
||||||
|
}
|
||||||
|
|
||||||
|
Rank field (FTS5 mode only): bm25 relevance score, negative value where closer to 0 = better match.
|
||||||
|
Examples: -0.5=excellent, -2=good, -5=partial, -10=weak. LIKE mode (%pattern%) has no rank.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
{"pattern": "authentication"}
|
||||||
|
{"pattern": "bug AND login"}
|
||||||
|
{"pattern": "%snake%"}
|
||||||
|
{"pattern": "project", "scope": "summary"}
|
||||||
|
{"pattern": "error", "role": "assistant", "last": "7d"}
|
||||||
|
{"pattern": "error", "all_conversations": true}`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *GrepTool) Parameters() map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"pattern": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search pattern. Supports: words, AND/OR/NOT operators, % wildcard",
|
||||||
|
},
|
||||||
|
"scope": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"both", "summary", "message"},
|
||||||
|
"description": "What to search: 'both' (default), 'summary', or 'message'",
|
||||||
|
},
|
||||||
|
"role": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"user", "assistant"},
|
||||||
|
"description": "Filter by message role (default: all roles)",
|
||||||
|
},
|
||||||
|
"last": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Time shortcut: '6h' (6 hours), '7d' (7 days), '2w' (2 weeks), '1m' (1 month)",
|
||||||
|
},
|
||||||
|
"all_conversations": map[string]any{
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Search across all conversations (default: searches current conversation only)",
|
||||||
|
},
|
||||||
|
"since": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "ISO8601 timestamp, only return content after this time",
|
||||||
|
},
|
||||||
|
"before": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "ISO8601 timestamp, only return content before this time",
|
||||||
|
},
|
||||||
|
"limit": map[string]any{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of results (default: 20)",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"pattern"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *GrepTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||||
|
pattern, ok := args["pattern"].(string)
|
||||||
|
if !ok || pattern == "" {
|
||||||
|
return tools.ErrorResult("Missing required 'pattern' argument. Example: {\"pattern\": \"authentication\"}")
|
||||||
|
}
|
||||||
|
|
||||||
|
input := GrepInput{Pattern: pattern}
|
||||||
|
|
||||||
|
if scope, ok := args["scope"].(string); ok && scope != "" {
|
||||||
|
input.Scope = scope
|
||||||
|
}
|
||||||
|
if role, ok := args["role"].(string); ok && role != "" {
|
||||||
|
input.Role = role
|
||||||
|
}
|
||||||
|
if last, ok := args["last"].(string); ok && last != "" {
|
||||||
|
input.Last = last
|
||||||
|
}
|
||||||
|
if allConv, ok := args["all_conversations"].(bool); ok {
|
||||||
|
input.AllConversations = allConv
|
||||||
|
}
|
||||||
|
if limit, ok := args["limit"].(float64); ok {
|
||||||
|
input.Limit = int(limit)
|
||||||
|
}
|
||||||
|
if sinceStr, ok := args["since"].(string); ok && sinceStr != "" {
|
||||||
|
parsed, err := time.Parse(time.RFC3339, sinceStr)
|
||||||
|
if err != nil {
|
||||||
|
return tools.ErrorResult(fmt.Sprintf(
|
||||||
|
"Invalid 'since' timestamp. Use RFC3339 format like '2024-01-15T10:00:00Z'. Error: %v", err))
|
||||||
|
}
|
||||||
|
input.Since = &parsed
|
||||||
|
}
|
||||||
|
if beforeStr, ok := args["before"].(string); ok && beforeStr != "" {
|
||||||
|
parsed, err := time.Parse(time.RFC3339, beforeStr)
|
||||||
|
if err != nil {
|
||||||
|
return tools.ErrorResult(fmt.Sprintf("Invalid 'before' timestamp format: %v", err))
|
||||||
|
}
|
||||||
|
input.Before = &parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := t.engine.Grep(ctx, input)
|
||||||
|
if err != nil {
|
||||||
|
return tools.ErrorResult("Grep failed: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build response
|
||||||
|
output := map[string]any{
|
||||||
|
"success": result.Success,
|
||||||
|
"summaries": result.Summaries,
|
||||||
|
"messages": result.Messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add hint if provided
|
||||||
|
if result.Hint != "" {
|
||||||
|
output["hint"] = result.Hint
|
||||||
|
}
|
||||||
|
|
||||||
|
data, _ := json.Marshal(output)
|
||||||
|
return tools.NewToolResult(string(data))
|
||||||
|
}
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGrepSearchSummaries(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:grep-tool")
|
||||||
|
|
||||||
|
s.CreateSummary(ctx, CreateSummaryInput{
|
||||||
|
ConversationID: conv.ConversationID,
|
||||||
|
Kind: SummaryKindLeaf,
|
||||||
|
Depth: 0,
|
||||||
|
Content: "database connection pool configuration",
|
||||||
|
TokenCount: 50,
|
||||||
|
})
|
||||||
|
|
||||||
|
re := &RetrievalEngine{store: s}
|
||||||
|
results, err := re.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "database",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep: %v", err)
|
||||||
|
}
|
||||||
|
if len(results.Summaries) == 0 {
|
||||||
|
t.Error("expected at least 1 summary result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrepSearchMessages(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
conv, _ := s.GetOrCreateConversation(ctx, "test:grep-msg")
|
||||||
|
|
||||||
|
s.AddMessage(ctx, conv.ConversationID, "user", "find this message about testing", 5)
|
||||||
|
s.AddMessage(ctx, conv.ConversationID, "user", "unrelated content", 3)
|
||||||
|
|
||||||
|
re := &RetrievalEngine{store: s}
|
||||||
|
results, err := re.Grep(ctx, GrepInput{
|
||||||
|
Pattern: "testing",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Grep messages: %v", err)
|
||||||
|
}
|
||||||
|
if len(results.Messages) == 0 {
|
||||||
|
t.Error("expected at least 1 message result")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrepMissingPattern(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
re := &RetrievalEngine{store: s}
|
||||||
|
_, err := re.Grep(context.Background(), GrepInput{})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing pattern")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrepToolSupportsAllConversations(t *testing.T) {
|
||||||
|
s := openTestStore(t)
|
||||||
|
tool := NewGrepTool(&RetrievalEngine{store: s})
|
||||||
|
params := tool.Parameters()
|
||||||
|
props := params["properties"].(map[string]any)
|
||||||
|
|
||||||
|
// GrepTool should accept all_conversations parameter
|
||||||
|
if _, ok := props["all_conversations"]; !ok {
|
||||||
|
t.Error("Parameters missing 'all_conversations' field")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SummaryKind distinguishes leaf summaries (from raw messages) vs condensed
|
||||||
|
// summaries (from other summaries).
|
||||||
|
type SummaryKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
SummaryKindLeaf SummaryKind = "leaf"
|
||||||
|
SummaryKindCondensed SummaryKind = "condensed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Message represents a single chat message with role and content.
|
||||||
|
type Message struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
Parts []MessagePart `json:"parts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessagePart holds structured content (tool calls, media, etc.)
|
||||||
|
type MessagePart struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
MessageID int64 `json:"messageId"`
|
||||||
|
Type string `json:"type"` // "text", "tool_use", "tool_result", "media"
|
||||||
|
Text string `json:"text"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments string `json:"arguments"`
|
||||||
|
ToolCallID string `json:"toolCallId"`
|
||||||
|
MediaURI string `json:"mediaUri"`
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary represents a compressed representation of messages or other summaries.
|
||||||
|
type Summary struct {
|
||||||
|
SummaryID string `json:"summaryId"`
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
Kind SummaryKind `json:"kind"`
|
||||||
|
Depth int `json:"depth"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
EarliestAt *time.Time `json:"earliestAt,omitempty"`
|
||||||
|
LatestAt *time.Time `json:"latestAt,omitempty"`
|
||||||
|
DescendantCount int `json:"descendantCount"`
|
||||||
|
DescendantTokenCount int `json:"descendantTokenCount"`
|
||||||
|
SourceMessageTokenCount int `json:"sourceMessageTokenCount"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SummaryNode is a Summary with graph relationships for tree traversal.
|
||||||
|
type SummaryNode struct {
|
||||||
|
Summary
|
||||||
|
Children []string `json:"children"` // Child summary IDs
|
||||||
|
Expanded bool `json:"expanded"` // UI state for expansion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conversation represents a session's conversation with metadata.
|
||||||
|
type Conversation struct {
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
SessionKey string `json:"sessionKey"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionStatus contains status information for a session.
|
||||||
|
type SessionStatus struct {
|
||||||
|
SessionKey string `json:"sessionKey"`
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
Messages int `json:"messages"`
|
||||||
|
TotalTokens int `json:"totalTokens"`
|
||||||
|
Summaries int `json:"summaries"`
|
||||||
|
OldestAt time.Time `json:"oldestAt"`
|
||||||
|
NewestAt time.Time `json:"newestAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextItem represents one item in the assembled context window.
|
||||||
|
type ContextItem struct {
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
Ordinal int `json:"ordinal"`
|
||||||
|
ItemType string `json:"itemType"` // "summary" or "message"
|
||||||
|
SummaryID string `json:"summaryId,omitempty"`
|
||||||
|
MessageID int64 `json:"messageId,omitempty"`
|
||||||
|
TokenCount int `json:"tokenCount"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SummarySubtreeNode is a node in a summary DAG subtree.
|
||||||
|
type SummarySubtreeNode struct {
|
||||||
|
SummaryID string `json:"summaryId"`
|
||||||
|
DepthFromRoot int `json:"depthFromRoot"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchInput controls summary search.
|
||||||
|
type SearchInput struct {
|
||||||
|
Pattern string `json:"pattern"`
|
||||||
|
Mode string `json:"mode"` // "like" (LIKE search) or "full_text" (FTS5, default)
|
||||||
|
Scope string `json:"scope,omitempty"` // "messages", "summaries", "both"
|
||||||
|
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
|
||||||
|
Since *time.Time `json:"since,omitempty"`
|
||||||
|
Before *time.Time `json:"before,omitempty"`
|
||||||
|
Limit int `json:"limit,omitempty"`
|
||||||
|
ConversationID int64 `json:"conversationId,omitempty"`
|
||||||
|
AllConversations bool `json:"allConversations,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchResult is a search match.
|
||||||
|
type SearchResult struct {
|
||||||
|
SummaryID string `json:"summaryId,omitempty"`
|
||||||
|
MessageID int64 `json:"messageId,omitempty"`
|
||||||
|
ConversationID int64 `json:"conversationId"`
|
||||||
|
Kind SummaryKind `json:"kind,omitempty"`
|
||||||
|
Depth int `json:"depth,omitempty"`
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
Content string `json:"content,omitempty"` // Full content for summaries
|
||||||
|
Snippet string `json:"snippet"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
Rank float64 `json:"rank,omitempty"`
|
||||||
|
TotalCount int `json:"totalCount,omitempty"` // Total matching rows (from window function)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EstimateMessageTokens estimates token count for a full message using the
|
||||||
|
// shared tokenizer package for consistency with agent.context_budget.
|
||||||
|
func EstimateMessageTokens(msg Message) int {
|
||||||
|
pm := providers.Message{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: msg.Content,
|
||||||
|
ReasoningContent: msg.ReasoningContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert MessageParts to ToolCalls / ToolCallID / Media
|
||||||
|
for _, part := range msg.Parts {
|
||||||
|
switch part.Type {
|
||||||
|
case "tool_use":
|
||||||
|
pm.ToolCalls = append(pm.ToolCalls, providers.ToolCall{
|
||||||
|
ID: part.ToolCallID,
|
||||||
|
Type: "function",
|
||||||
|
Function: &providers.FunctionCall{
|
||||||
|
Name: part.Name,
|
||||||
|
Arguments: part.Arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case "tool_result":
|
||||||
|
pm.ToolCallID = part.ToolCallID
|
||||||
|
case "media":
|
||||||
|
pm.Media = append(pm.Media, part.MediaURI)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenizer.EstimateMessageTokens(pm)
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package seahorse
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSummaryKindValues(t *testing.T) {
|
||||||
|
if SummaryKindLeaf != "leaf" {
|
||||||
|
t.Errorf("expected SummaryKindLeaf = 'leaf', got %q", SummaryKindLeaf)
|
||||||
|
}
|
||||||
|
if SummaryKindCondensed != "condensed" {
|
||||||
|
t.Errorf("expected SummaryKindCondensed = 'condensed', got %q", SummaryKindCondensed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstants(t *testing.T) {
|
||||||
|
// Ordinal gap step
|
||||||
|
if OrdinalStep != 100 {
|
||||||
|
t.Errorf("expected OrdinalStep = 100, got %d", OrdinalStep)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compaction triggers
|
||||||
|
if ContextThreshold != 0.75 {
|
||||||
|
t.Errorf("expected ContextThreshold = 0.75, got %f", ContextThreshold)
|
||||||
|
}
|
||||||
|
if FreshTailCount != 32 {
|
||||||
|
t.Errorf("expected FreshTailCount = 32, got %d", FreshTailCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fanout
|
||||||
|
if LeafMinFanout != 8 {
|
||||||
|
t.Errorf("expected LeafMinFanout = 8, got %d", LeafMinFanout)
|
||||||
|
}
|
||||||
|
if CondensedMinFanout != 4 {
|
||||||
|
t.Errorf("expected CondensedMinFanout = 4, got %d", CondensedMinFanout)
|
||||||
|
}
|
||||||
|
if CondensedMinFanoutHard != 2 {
|
||||||
|
t.Errorf("expected CondensedMinFanoutHard = 2, got %d", CondensedMinFanoutHard)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token targets
|
||||||
|
if LeafChunkTokens != 20000 {
|
||||||
|
t.Errorf("expected LeafChunkTokens = 20000, got %d", LeafChunkTokens)
|
||||||
|
}
|
||||||
|
if LeafTargetTokens != 1200 {
|
||||||
|
t.Errorf("expected LeafTargetTokens = 1200, got %d", LeafTargetTokens)
|
||||||
|
}
|
||||||
|
if CondensedTargetTokens != 2000 {
|
||||||
|
t.Errorf("expected CondensedTargetTokens = 2000, got %d", CondensedTargetTokens)
|
||||||
|
}
|
||||||
|
if MaxExpandTokens != 4000 {
|
||||||
|
t.Errorf("expected MaxExpandTokens = 4000, got %d", MaxExpandTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -79,3 +79,8 @@ func (b *JSONLBackend) Save(key string) error {
|
|||||||
func (b *JSONLBackend) Close() error {
|
func (b *JSONLBackend) Close() error {
|
||||||
return b.store.Close()
|
return b.store.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListSessions returns all known session keys.
|
||||||
|
func (b *JSONLBackend) ListSessions() []string {
|
||||||
|
return b.store.ListSessions()
|
||||||
|
}
|
||||||
|
|||||||
@@ -145,6 +145,16 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
|||||||
session.Updated = time.Now()
|
session.Updated = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (sm *SessionManager) ListSessions() []string {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
keys := make([]string, 0, len(sm.sessions))
|
||||||
|
for k := range sm.sessions {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
// sanitizeFilename converts a session key into a cross-platform safe filename.
|
// sanitizeFilename converts a session key into a cross-platform safe filename.
|
||||||
// Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so
|
// Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so
|
||||||
// composite IDs (e.g. Telegram forum "chatID/threadID") do not create
|
// composite IDs (e.g. Telegram forum "chatID/threadID") do not create
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ type SessionStore interface {
|
|||||||
TruncateHistory(key string, keepLast int)
|
TruncateHistory(key string, keepLast int)
|
||||||
// Save persists any pending state to durable storage.
|
// Save persists any pending state to durable storage.
|
||||||
Save(key string) error
|
Save(key string) error
|
||||||
|
// ListSessions returns all known session keys.
|
||||||
|
ListSessions() []string
|
||||||
// Close releases resources held by the store.
|
// Close releases resources held by the store.
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package tokenizer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EstimateMessageTokens estimates the token count for a single message,
|
||||||
|
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||||
|
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||||
|
func EstimateMessageTokens(msg providers.Message) int {
|
||||||
|
contentChars := utf8.RuneCountInString(msg.Content)
|
||||||
|
|
||||||
|
// SystemParts are structured system blocks used for cache-aware adapters.
|
||||||
|
// They carry the same content as Content, but in multiple blocks.
|
||||||
|
// We estimate them as an alternative representation, not additive.
|
||||||
|
systemPartsChars := 0
|
||||||
|
if len(msg.SystemParts) > 0 {
|
||||||
|
for _, part := range msg.SystemParts {
|
||||||
|
systemPartsChars += utf8.RuneCountInString(part.Text)
|
||||||
|
}
|
||||||
|
// Per-part overhead for JSON structure (type, text, cache_control).
|
||||||
|
const perPartOverhead = 20
|
||||||
|
systemPartsChars += len(msg.SystemParts) * perPartOverhead
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the larger of the two representations to stay conservative.
|
||||||
|
chars := contentChars
|
||||||
|
if systemPartsChars > chars {
|
||||||
|
chars = systemPartsChars
|
||||||
|
}
|
||||||
|
|
||||||
|
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||||
|
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
chars += len(tc.ID) + len(tc.Type)
|
||||||
|
if tc.Function != nil {
|
||||||
|
// Count function name + arguments (the wire format for most providers).
|
||||||
|
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||||
|
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||||
|
} else {
|
||||||
|
// Fallback: some provider formats use top-level Name without Function.
|
||||||
|
chars += len(tc.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
chars += len(msg.ToolCallID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-message overhead for role label, JSON structure, separators.
|
||||||
|
const messageOverhead = 12
|
||||||
|
chars += messageOverhead
|
||||||
|
|
||||||
|
tokens := chars * 2 / 5
|
||||||
|
|
||||||
|
// Media items (images, files) are serialized by provider adapters into
|
||||||
|
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||||
|
// directly (not through the chars heuristic) since actual cost depends
|
||||||
|
// on resolution and provider-specific image tokenization.
|
||||||
|
const mediaTokensPerItem = 256
|
||||||
|
tokens += len(msg.Media) * mediaTokensPerItem
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
// EstimateToolDefsTokens estimates the total token cost of tool definitions
|
||||||
|
// as they appear in the LLM request.
|
||||||
|
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||||
|
if len(defs) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
totalChars := 0
|
||||||
|
for _, d := range defs {
|
||||||
|
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||||
|
|
||||||
|
if d.Function.Parameters != nil {
|
||||||
|
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||||
|
totalChars += len(paramJSON)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-tool overhead: type field, JSON structure, separators.
|
||||||
|
totalChars += 20
|
||||||
|
}
|
||||||
|
|
||||||
|
return totalChars * 2 / 5
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user