mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
feat(seahorse): implement short-term memory engine (LCM) (#2285)
* feat(seahorse): implement short-term memory engine of seahorse Add pkg/seahorse/ module implementing a SQLite-backed DAG-based summary hierarchy for context management, ported from lossless-claw's LCM design: - types.go + short_constants.go: core types (Message, Summary, Conversation, ContextItem) and configuration constants (fanout, token targets, thresholds) - migration.go: idempotent DB schema with FTS5 trigram tokenizer for CJK - store.go: full SQLite CRUD (conversations, messages, summaries DAG, context_items with ordinal gap numbering, FTS5 search) - short_engine.go: Engine lifecycle (NewEngine, Ingest, Assemble, Compact), session pattern filtering (ignore/stateless glob→regex compilation), per-session mutex via sync.Map - short_assembler.go: budget-aware context assembly with fresh tail protection (32 messages), oldest-first eviction, summary XML formatting, RebuildContextItems - short_compaction.go: leaf compaction (messages→summary) and condensed compaction (summaries→higher-level summary), 3-level LLM escalation, CompactUntilUnder for emergency overflow - short_retrieval.go: lookupByID, FTS5/LIKE search, recursive expand with token cap - context_seahorse.go: agent.ContextManager adapter, registered as "seahorse", provider↔seahorse message type conversion (ToolCalls, tool_result) * fix(seahorse): correct 3 adapter bugs in context management - TokenCount: use full message (Content+ToolCalls+Media) instead of Content-only - Empty Content: rebuild Content from tool_result Parts when stored empty - Duplicate summaries: summaries only in Summary field, not in History messages - Grep: fix SearchResult.Snippet→Content for summaries - Schema: fix FTS5 SQL uses VIRTUAL TABLE not TEMP TABLE - TestFTS5SQLConstants: verify FTS5 SQL syntax correctness - Test: fix flaky TestCompactLeaf * fix(agent): ingest steering messages into seahorse SQLite Steering messages were only persisted to session JSONL but not ingested into seahorse SQLite, causing them to be missing from context assembly. Added `ts.ingestMessage(turnCtx, al, pm)` call in the steering message injection block alongside the existing JSONL persistence. Test: TestSeahorseSteeringMessageIngested verifies steering messages appear in seahorse SQLite DB after being processed. * fix(seahorse): address 3 blocking bugs from code review - Fix resequenceContextItemsTx scan error handling (store.go:850) Changed `return err` to `return scanErr` to properly propagate scan errors instead of returning nil (which silently corrupts data) - Fix sql.NullString for INTEGER column (store.go:847) Changed `mid` from sql.NullString to sql.NullInt64 since message_id is INTEGER in schema. Removed unnecessary strconv.ParseInt call. - Fix compactCondensed fallback deleting non-candidate items Added ReplaceContextItemsWithSummary method for per-item deletion when candidates are not contiguous in ordinal space. Optimized to use range deletion when candidates are consecutive. * fix(seahorse): pass Budget to Compact for correct condensed threshold Issue #4 from PR review: When Budget was not passed to seahorse.Compact, it defaulted to `tokensBefore * 0.75`, making `tokensBefore > budget` always true and causing condensed compaction to trigger unnecessarily. Changes: - context_seahorse.go: Forward Budget from CompactRequest to CompactInput - loop.go: Pass Budget (ContextWindow) in all 3 Compact calls - Add test verifying condensed is skipped when tokens < threshold - Fix lint issues in store.go and store_test.go * fix(seahorse): add mutex for assembler lazy initialization Issue #5 from PR review: The check-then-create pattern for e.assembler was a data race when multiple goroutines called Assemble() concurrently: if e.assembler == nil { e.assembler = &Assembler{...} } Changes: - Add assemblerMu sync.Mutex to Engine struct - Add initAssemblerOnce() using double-checked locking (same pattern as initCompactionOnce) - Add TestAssemblerLazyInitRace to verify thread-safety * fix(seahorse): handle non-consecutive depths in selectShallowestCondensationCandidate Issue #8 from PR review: the loop iterated depth 0, 1, 2... assuming consecutive keys, but break when key was missing caused deeper depths to never be checked. Fix: collect all existing depth keys, sort, then iterate in order. * fix(seahorse): wrap DeleteMessagesAfterID and appendContextItems in transactions - DeleteMessagesAfterID: wrap all DELETE operations in a transaction for atomicity, remove redundant manual FTS delete (handled by trigger) - appendContextItems: use transaction to fix read-then-write race condition - Add GetMaxOrdinalTx and resolveItemTokenCountTx for transaction-scoped queries - Remove unused resolveItemTokenCount function Fixes PR review issues 6 and 7. * fix(seahorse): derive readable content from Parts and cap CompactUntilUnder iterations - Derive readable content from MessageParts in AddMessageWithParts so FTS5 indexing and summary formatting can access tool call information - formatMessagesForSummary and truncateSummary now fall back to Parts when Content is empty, fixing blank summaries for Part-based messages - Add MaxCompactIterations (20) to prevent CompactUntilUnder infinite loops; exceeded iterations are logged as warnings
This commit is contained in:
@@ -67,3 +67,5 @@ web/backend/dist/*
|
||||
.claude/
|
||||
|
||||
docker/data
|
||||
|
||||
.omc/
|
||||
|
||||
@@ -12,6 +12,7 @@ linters:
|
||||
- exhaustruct
|
||||
- funcorder
|
||||
- gochecknoglobals
|
||||
- gosmopolitan # Project legitimately uses CJK text in tests (FTS5, token counting)
|
||||
- godot
|
||||
- intrange
|
||||
- ireturn
|
||||
|
||||
+11
-85
@@ -6,10 +6,8 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// parseTurnBoundaries returns the starting index of each Turn in the history.
|
||||
@@ -86,88 +84,16 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||
func estimateMessageTokens(msg providers.Message) int {
|
||||
contentChars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
// SystemParts are structured system blocks used for cache-aware adapters.
|
||||
// They carry the same content as Content, but in multiple blocks.
|
||||
// We estimate them as an alternative representation, not additive.
|
||||
systemPartsChars := 0
|
||||
if len(msg.SystemParts) > 0 {
|
||||
for _, part := range msg.SystemParts {
|
||||
systemPartsChars += utf8.RuneCountInString(part.Text)
|
||||
}
|
||||
// Per-part overhead for JSON structure (type, text, cache_control).
|
||||
const perPartOverhead = 20
|
||||
systemPartsChars += len(msg.SystemParts) * perPartOverhead
|
||||
}
|
||||
|
||||
// Use the larger of the two representations to stay conservative.
|
||||
chars := contentChars
|
||||
if systemPartsChars > chars {
|
||||
chars = systemPartsChars
|
||||
}
|
||||
|
||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
chars += len(tc.ID) + len(tc.Type)
|
||||
if tc.Function != nil {
|
||||
// Count function name + arguments (the wire format for most providers).
|
||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
} else {
|
||||
// Fallback: some provider formats use top-level Name without Function.
|
||||
chars += len(tc.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
tokens := chars * 2 / 5
|
||||
|
||||
// Media items (images, files) are serialized by provider adapters into
|
||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||
// directly (not through the chars heuristic) since actual cost depends
|
||||
// on resolution and provider-specific image tokenization.
|
||||
const mediaTokensPerItem = 256
|
||||
tokens += len(msg.Media) * mediaTokensPerItem
|
||||
|
||||
return tokens
|
||||
// EstimateMessageTokens estimates the token count for a single message.
|
||||
// Delegates to the shared tokenizer package for consistency across agent and seahorse.
|
||||
func EstimateMessageTokens(msg providers.Message) int {
|
||||
return tokenizer.EstimateMessageTokens(msg)
|
||||
}
|
||||
|
||||
// estimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Each tool's name, description, and
|
||||
// JSON schema parameters contribute to the context window budget.
|
||||
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
// EstimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request. Delegates to the shared tokenizer package.
|
||||
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
return tokenizer.EstimateToolDefsTokens(defs)
|
||||
}
|
||||
|
||||
// isOverContextBudget checks whether the assembled messages plus tool definitions
|
||||
@@ -181,10 +107,10 @@ func isOverContextBudget(
|
||||
) bool {
|
||||
msgTokens := 0
|
||||
for _, m := range messages {
|
||||
msgTokens += estimateMessageTokens(m)
|
||||
msgTokens += EstimateMessageTokens(m)
|
||||
}
|
||||
|
||||
toolTokens := estimateToolDefsTokens(toolDefs)
|
||||
toolTokens := EstimateToolDefsTokens(toolDefs)
|
||||
total := msgTokens + toolTokens + maxTokens
|
||||
|
||||
return total > contextWindow
|
||||
|
||||
@@ -417,9 +417,9 @@ func TestEstimateMessageTokens(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateMessageTokens(tt.msg)
|
||||
got := EstimateMessageTokens(tt.msg)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
t.Errorf("EstimateMessageTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -443,8 +443,8 @@ func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
withTCTokens := estimateMessageTokens(withTC)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
withTCTokens := EstimateMessageTokens(withTC)
|
||||
|
||||
if withTCTokens <= plainTokens {
|
||||
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -457,7 +457,7 @@ func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
|
||||
// but may map to different token counts. The heuristic should still produce
|
||||
// reasonable estimates via RuneCountInString.
|
||||
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
|
||||
}
|
||||
@@ -481,7 +481,7 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
// 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
|
||||
if tokens < 2000 {
|
||||
t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
|
||||
@@ -496,8 +496,8 @@ func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
|
||||
ReasoningContent: strings.Repeat("thinking step ", 200),
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
reasoningTokens := estimateMessageTokens(withReasoning)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
reasoningTokens := EstimateMessageTokens(withReasoning)
|
||||
|
||||
if reasoningTokens <= plainTokens {
|
||||
t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -513,8 +513,8 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
|
||||
Media: []string{"media://img1.png", "media://img2.png"},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
mediaTokens := estimateMessageTokens(withMedia)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
mediaTokens := EstimateMessageTokens(withMedia)
|
||||
|
||||
if mediaTokens <= plainTokens {
|
||||
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
|
||||
@@ -540,8 +540,8 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
plainTokens := estimateMessageTokens(plain)
|
||||
partsTokens := estimateMessageTokens(withParts)
|
||||
plainTokens := EstimateMessageTokens(plain)
|
||||
partsTokens := EstimateMessageTokens(withParts)
|
||||
|
||||
if partsTokens <= plainTokens {
|
||||
t.Errorf("system message with SystemParts (%d) should exceed plain message (%d)",
|
||||
@@ -549,7 +549,7 @@ func TestEstimateMessageTokens_SystemParts(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- estimateToolDefsTokens tests ---
|
||||
// --- EstimateToolDefsTokens tests ---
|
||||
|
||||
func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
tests := []struct {
|
||||
@@ -599,9 +599,9 @@ func TestEstimateToolDefsTokens(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := estimateToolDefsTokens(tt.defs)
|
||||
got := EstimateToolDefsTokens(tt.defs)
|
||||
if got < tt.want {
|
||||
t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
t.Errorf("EstimateToolDefsTokens() = %d, want >= %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -624,8 +624,8 @@ func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := estimateToolDefsTokens([]providers.ToolDefinition{
|
||||
one := EstimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
|
||||
three := EstimateToolDefsTokens([]providers.ToolDefinition{
|
||||
makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
|
||||
})
|
||||
|
||||
@@ -770,7 +770,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tokens := estimateMessageTokens(msg)
|
||||
tokens := EstimateMessageTokens(msg)
|
||||
|
||||
// ReasoningContent alone is ~1700 chars → ~680 tokens.
|
||||
// Content + TC + overhead adds more. Should be well above 500.
|
||||
@@ -781,7 +781,7 @@ func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
|
||||
// Compare without reasoning to ensure it's counted.
|
||||
msgNoReasoning := msg
|
||||
msgNoReasoning.ReasoningContent = ""
|
||||
tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
|
||||
tokensNoReasoning := EstimateMessageTokens(msgNoReasoning)
|
||||
|
||||
if tokens <= tokensNoReasoning {
|
||||
t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
|
||||
|
||||
@@ -373,7 +373,7 @@ func (m *legacyContextManager) summarizeBatch(
|
||||
func (m *legacyContextManager) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
total += estimateMessageTokens(msg)
|
||||
total += EstimateMessageTokens(msg)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ type AssembleResponse struct {
|
||||
type CompactRequest struct {
|
||||
SessionKey string // session identifier
|
||||
Reason ContextCompressReason // proactive_budget | llm_retry | summarize
|
||||
Budget int // context window budget (used for retry aggressive compaction)
|
||||
}
|
||||
|
||||
// IngestRequest is the input to Ingest.
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/providers/protocoltypes"
|
||||
"github.com/sipeed/picoclaw/pkg/seahorse"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// seahorseContextManager adapts seahorse.Engine to agent.ContextManager.
|
||||
type seahorseContextManager struct {
|
||||
engine *seahorse.Engine
|
||||
sessions session.SessionStore // for startup bootstrap
|
||||
}
|
||||
|
||||
// newSeahorseContextManager creates a seahorse-backed ContextManager.
|
||||
func newSeahorseContextManager(_ json.RawMessage, al *AgentLoop) (ContextManager, error) {
|
||||
if al == nil {
|
||||
return nil, fmt.Errorf("seahorse: AgentLoop is required")
|
||||
}
|
||||
|
||||
// Resolve workspace for DB path
|
||||
// DB stores session data, so it goes in sessions/ directory
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
dbPath := agent.Workspace + "/sessions/seahorse.db"
|
||||
|
||||
// Create CompleteFn from provider
|
||||
completeFn := providerToCompleteFn(agent.Provider, agent.Model)
|
||||
|
||||
// Create engine
|
||||
engine, err := seahorse.NewEngine(seahorse.Config{
|
||||
DBPath: dbPath,
|
||||
}, completeFn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("seahorse: create engine: %w", err)
|
||||
}
|
||||
|
||||
mgr := &seahorseContextManager{
|
||||
engine: engine,
|
||||
sessions: agent.Sessions,
|
||||
}
|
||||
|
||||
// Register seahorse tools with the agent's tool registry
|
||||
retrieval := mgr.engine.GetRetrieval()
|
||||
al.RegisterTool(seahorse.NewGrepTool(retrieval))
|
||||
al.RegisterTool(seahorse.NewExpandTool(retrieval))
|
||||
|
||||
// Bootstrap all existing sessions at startup
|
||||
if agent.Sessions != nil {
|
||||
ctx := context.Background()
|
||||
for _, sessionKey := range agent.Sessions.ListSessions() {
|
||||
mgr.bootstrapSession(ctx, sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// providerToCompleteFn wraps providers.LLMProvider as a seahorse.CompleteFn.
|
||||
func providerToCompleteFn(provider providers.LLMProvider, model string) seahorse.CompleteFn {
|
||||
return func(ctx context.Context, prompt string, opts seahorse.CompleteOptions) (string, error) {
|
||||
resp, err := provider.Chat(
|
||||
ctx,
|
||||
[]providers.Message{{Role: "user", Content: prompt}},
|
||||
nil, // no tools for summarization
|
||||
model,
|
||||
map[string]any{
|
||||
"max_tokens": opts.MaxTokens,
|
||||
"temperature": opts.Temperature,
|
||||
"prompt_cache_key": "seahorse",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.Content, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Assemble builds budget-aware context from seahorse SQLite.
|
||||
func (m *seahorseContextManager) Assemble(ctx context.Context, req *AssembleRequest) (*AssembleResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("seahorse assemble: nil request")
|
||||
}
|
||||
|
||||
budget := req.Budget
|
||||
if budget <= 0 {
|
||||
budget = 100000
|
||||
}
|
||||
|
||||
// Reserve space for model response (spec lines 1400-1410)
|
||||
effectiveBudget := budget - req.MaxTokens
|
||||
if effectiveBudget <= 0 {
|
||||
// MaxTokens >= budget is a configuration problem
|
||||
// Use 50% as minimum to avoid guaranteed overflow
|
||||
logger.WarnCF("agent", "MaxTokens >= budget, using 50% fallback",
|
||||
map[string]any{"budget": budget, "max_tokens": req.MaxTokens})
|
||||
effectiveBudget = budget / 2
|
||||
}
|
||||
|
||||
result, err := m.engine.Assemble(ctx, req.SessionKey, seahorse.AssembleInput{
|
||||
Budget: effectiveBudget,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("seahorse assemble: %w", err)
|
||||
}
|
||||
|
||||
history := seahorseToProviderMessages(result)
|
||||
|
||||
// Summary is already formatted as XML with system prompt addition by assembler
|
||||
return &AssembleResponse{
|
||||
History: history,
|
||||
Summary: result.Summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Compact compresses conversation history via seahorse summarization.
|
||||
func (m *seahorseContextManager) Compact(ctx context.Context, req *CompactRequest) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For retry (LLM overflow), use aggressive CompactUntilUnder to guarantee
|
||||
// context shrinks below budget (spec lines ~1410).
|
||||
if req.Reason == ContextCompressReasonRetry && req.Budget > 0 {
|
||||
_, err := m.engine.CompactUntilUnder(ctx, req.SessionKey, req.Budget)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := m.engine.Compact(ctx, req.SessionKey, seahorse.CompactInput{
|
||||
Force: req.Reason == ContextCompressReasonRetry,
|
||||
Budget: &req.Budget,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Ingest records a message into seahorse SQLite.
|
||||
// All existing sessions are bootstrapped at startup, so this only ingests new messages.
|
||||
func (m *seahorseContextManager) Ingest(ctx context.Context, req *IngestRequest) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := providerToSeahorseMessage(req.Message)
|
||||
_, err := m.engine.Ingest(ctx, req.SessionKey, []seahorse.Message{msg})
|
||||
return err
|
||||
}
|
||||
|
||||
// bootstrapSession reconciles JSONL session history into seahorse SQLite.
|
||||
func (m *seahorseContextManager) bootstrapSession(ctx context.Context, sessionKey string) {
|
||||
if m.sessions == nil {
|
||||
return
|
||||
}
|
||||
|
||||
history := m.sessions.GetHistory(sessionKey)
|
||||
if len(history) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Convert provider messages to seahorse messages
|
||||
msgs := make([]seahorse.Message, len(history))
|
||||
for i, h := range history {
|
||||
msgs[i] = providerToSeahorseMessage(h)
|
||||
}
|
||||
|
||||
if err := m.engine.Bootstrap(ctx, sessionKey, msgs); err != nil {
|
||||
logger.WarnCF("seahorse", "bootstrap", map[string]any{
|
||||
"session": sessionKey,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// providerToSeahorseMessage converts a providers.Message to a seahorse.Message.
|
||||
func providerToSeahorseMessage(msg protocoltypes.Message) seahorse.Message {
|
||||
result := seahorse.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
TokenCount: tokenizer.EstimateMessageTokens(msg),
|
||||
}
|
||||
|
||||
// Convert ToolCalls → MessageParts
|
||||
for _, tc := range msg.ToolCalls {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "tool_use",
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
// Convert tool result
|
||||
if msg.ToolCallID != "" {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "tool_result",
|
||||
ToolCallID: msg.ToolCallID,
|
||||
Text: msg.Content,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
// Convert media attachments
|
||||
for _, mediaURI := range msg.Media {
|
||||
part := seahorse.MessagePart{
|
||||
Type: "media",
|
||||
MediaURI: mediaURI,
|
||||
}
|
||||
result.Parts = append(result.Parts, part)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// seahorseToProviderMessages converts a seahorse.AssembleResult to []providers.Message.
|
||||
func seahorseToProviderMessages(result *seahorse.AssembleResult) []protocoltypes.Message {
|
||||
messages := make([]protocoltypes.Message, 0, len(result.Messages))
|
||||
|
||||
// Convert assembled messages (which already include summary XML messages)
|
||||
for _, msg := range result.Messages {
|
||||
pm := protocoltypes.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
// Reconstruct ToolCalls from parts
|
||||
for _, part := range msg.Parts {
|
||||
if part.Type == "tool_use" {
|
||||
pm.ToolCalls = append(pm.ToolCalls, protocoltypes.ToolCall{
|
||||
ID: part.ToolCallID,
|
||||
Type: "function", // Required by OpenAI-compatible APIs (GLM, etc.)
|
||||
Function: &protocoltypes.FunctionCall{
|
||||
Name: part.Name,
|
||||
Arguments: part.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
if part.Type == "tool_result" {
|
||||
pm.ToolCallID = part.ToolCallID
|
||||
if pm.Content == "" && part.Text != "" {
|
||||
pm.Content = part.Text
|
||||
}
|
||||
}
|
||||
if part.Type == "media" && part.MediaURI != "" {
|
||||
pm.Media = append(pm.Media, part.MediaURI)
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, pm)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := RegisterContextManager("seahorse", newSeahorseContextManager); err != nil {
|
||||
panic(fmt.Sprintf("register seahorse context manager: %v", err))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
+5
-1
@@ -1742,6 +1742,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
if err := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonProactive,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); err != nil {
|
||||
logger.WarnCF("agent", "Proactive compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
@@ -1857,6 +1858,7 @@ turnLoop:
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
|
||||
ts.recordPersistedMessage(pm)
|
||||
ts.ingestMessage(turnCtx, al, pm)
|
||||
}
|
||||
logger.InfoCF("agent", "Injected steering message into context",
|
||||
map[string]any{
|
||||
@@ -2128,6 +2130,7 @@ turnLoop:
|
||||
if compactErr := al.contextManager.Compact(turnCtx, &CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonRetry,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
}); compactErr != nil {
|
||||
logger.WarnCF("agent", "Context overflow compact failed", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
@@ -2773,7 +2776,7 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
if ts.opts.EnableSummary {
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize})
|
||||
al.contextManager.Compact(turnCtx, &CompactRequest{SessionKey: ts.sessionKey, Reason: ContextCompressReasonSummarize, Budget: ts.agent.ContextWindow})
|
||||
}
|
||||
|
||||
ts.setPhase(TurnPhaseCompleted)
|
||||
@@ -2849,6 +2852,7 @@ turnLoop:
|
||||
&CompactRequest{
|
||||
SessionKey: ts.sessionKey,
|
||||
Reason: ContextCompressReasonSummarize,
|
||||
Budget: ts.agent.ContextWindow,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -604,6 +604,7 @@ type ephemeralSessionStoreIface interface {
|
||||
SetHistory(key string, history []providers.Message)
|
||||
TruncateHistory(key string, keepLast int)
|
||||
Save(key string) error
|
||||
ListSessions() []string
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -663,8 +664,9 @@ func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
|
||||
e.history = e.history[len(e.history)-keepLast:]
|
||||
}
|
||||
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
func (e *ephemeralSessionStore) Save(_ string) error { return nil }
|
||||
func (e *ephemeralSessionStore) Close() error { return nil }
|
||||
func (e *ephemeralSessionStore) ListSessions() []string { return nil }
|
||||
|
||||
func (e *ephemeralSessionStore) truncateLocked() {
|
||||
if len(e.history) > maxEphemeralHistorySize {
|
||||
|
||||
@@ -455,6 +455,33 @@ func (s *JSONLStore) rewriteJSONL(
|
||||
return fileutil.WriteFileAtomic(s.jsonlPath(sessionKey), buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
// ListSessions returns all known session keys by reading .meta.json files.
|
||||
func (s *JSONLStore) ListSessions() []string {
|
||||
entries, err := os.ReadDir(s.dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var keys []string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".meta.json") {
|
||||
continue
|
||||
}
|
||||
// Read the meta file to get the original key
|
||||
data, err := os.ReadFile(filepath.Join(s.dir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var meta sessionMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
continue
|
||||
}
|
||||
if meta.Key != "" {
|
||||
keys = append(keys, meta.Key)
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
func (s *JSONLStore) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ type Store interface {
|
||||
// data. Backends that do not accumulate dead data may return nil.
|
||||
Compact(ctx context.Context, sessionKey string) error
|
||||
|
||||
// ListSessions returns all known session keys.
|
||||
ListSessions() []string
|
||||
|
||||
// Close releases any resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"tool_name": "Bash",
|
||||
"tool_input_preview": "{\"command\":\"cd /home/yliu/repos/picoclaw && make lint 2>&1\",\"timeout\":120000}",
|
||||
"error": "Exit code 2\npkg/agent/context_seahorse_test.go:1027:1: File is not properly formatted (gci)\n\t\t\tEarliestAt: &now,\n^\n1 issues:\n* gci: 1\nmake: *** [Makefile:264: lint] Error 1",
|
||||
"timestamp": "2026-04-04T02:38:32.067Z",
|
||||
"retry_count": 6
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// CompactUntilUnder iteration cap
|
||||
// =============================================================================
|
||||
|
||||
func TestCompactUntilUnderIterationCap(t *testing.T) {
|
||||
// Setup: create a conversation with so many tokens that compaction
|
||||
// will never reach the budget. The iteration cap prevents infinite loops.
|
||||
//
|
||||
// We use a mock CompleteFn that always returns the same content,
|
||||
// and a budget of 0 which tokens can never reach.
|
||||
// Without the cap, this would loop forever.
|
||||
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
s := &Store{db: db}
|
||||
|
||||
conv, _ := s.GetOrCreateConversation(context.Background(), "agent:iter-cap")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add many messages to ensure there's plenty to compact
|
||||
for i := 0; i < 40; i++ {
|
||||
m, _ := s.AddMessage(context.Background(), convID, "user",
|
||||
"this is a long message with lots of tokens to push context over budget", 100)
|
||||
s.AppendContextMessage(context.Background(), convID, m.ID)
|
||||
}
|
||||
|
||||
// A completeFn that always succeeds but returns non-reducing content
|
||||
mockComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "Summary that doesn't reduce tokens much.", nil
|
||||
}
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, mockComplete)
|
||||
defer cancel()
|
||||
|
||||
// Use budget=1 so tokens can never reach budget
|
||||
// (each message is 100 tokens, so 40 messages = 4000 tokens, budget 1 is unreachable)
|
||||
// The function should stop after maxCompactIterations, not loop forever
|
||||
ce.config = Config{} // ensure defaults
|
||||
|
||||
result, err := ce.CompactUntilUnder(context.Background(), convID, 1)
|
||||
if err != nil {
|
||||
// Should not error — should stop gracefully
|
||||
t.Fatalf("CompactUntilUnder with budget=0: %v", err)
|
||||
}
|
||||
|
||||
// The function should have completed within reasonable time
|
||||
// If it exceeded the cap, it would still return (not hang)
|
||||
_ = result
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Bug 1: formatMessagesForSummary ignores Parts
|
||||
// - formatMessagesForSummary only reads m.Content, empty for Part-based messages
|
||||
// - truncateSummary has same issue
|
||||
// =============================================================================
|
||||
|
||||
func TestFormatMessagesForSummaryIncludesParts(t *testing.T) {
|
||||
ts := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
messages := []Message{
|
||||
{ID: 1, Role: "user", Content: "hello world", CreatedAt: ts},
|
||||
{
|
||||
ID: 2,
|
||||
Role: "assistant",
|
||||
Content: "", // empty — real content is in Parts
|
||||
Parts: []MessagePart{
|
||||
{Type: "text", Text: "I will run a command"},
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls -la"}`, ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: ts.Add(time.Minute),
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Role: "tool",
|
||||
Content: "", // empty — real content is in Parts
|
||||
Parts: []MessagePart{
|
||||
{Type: "tool_result", Text: "file1.txt\nfile2.txt", ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: ts.Add(2 * time.Minute),
|
||||
},
|
||||
}
|
||||
|
||||
result := formatMessagesForSummary(messages)
|
||||
|
||||
// Must contain the plain text message
|
||||
if !contains(result, "hello world") {
|
||||
t.Error("formatMessagesForSummary: missing plain text content")
|
||||
}
|
||||
|
||||
// Must contain tool_use info (not blank)
|
||||
if !contains(result, "bash") || !contains(result, "ls -la") {
|
||||
t.Errorf("formatMessagesForSummary: tool_use info missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
|
||||
// Must contain tool_result info (not blank)
|
||||
if !contains(result, "file1.txt") {
|
||||
t.Errorf("formatMessagesForSummary: tool_result text missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateSummaryIncludesParts(t *testing.T) {
|
||||
messages := []Message{
|
||||
{ID: 1, Role: "user", Content: "run the tests", CreatedAt: time.Now()},
|
||||
{
|
||||
ID: 2,
|
||||
Role: "assistant",
|
||||
Content: "", // empty
|
||||
Parts: []MessagePart{
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"go test ./..."}`, ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Role: "tool",
|
||||
Content: "", // empty
|
||||
Parts: []MessagePart{
|
||||
{Type: "tool_result", Text: "PASS\nok 3.2s", ToolCallID: "call_1"},
|
||||
},
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
}
|
||||
|
||||
result := truncateSummary(messages)
|
||||
|
||||
// Must contain plain text
|
||||
if !contains(result, "run the tests") {
|
||||
t.Error("truncateSummary: missing plain text content")
|
||||
}
|
||||
|
||||
// Must contain tool info from Parts (not blank)
|
||||
if !contains(result, "bash") || !contains(result, "go test") {
|
||||
t.Errorf("truncateSummary: tool_use info missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
|
||||
// Must contain tool_result from Parts
|
||||
if !contains(result, "PASS") {
|
||||
t.Errorf("truncateSummary: tool_result text missing from Parts.\nGot:\n%s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Bug 2: SearchMessages cannot find Part-based messages
|
||||
// - FTS5 indexes empty content, LIKE queries empty content
|
||||
// =============================================================================
|
||||
|
||||
func TestSearchMessagesFindsPartBasedMessages(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "agent:search-parts")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add a plain message (searchable)
|
||||
s.AddMessage(ctx, convID, "user", "list the files please", 5)
|
||||
|
||||
// Add a Part-based message (tool_use) — currently NOT searchable
|
||||
parts := []MessagePart{
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"grep -r TODO ."}`, ToolCallID: "call_1"},
|
||||
}
|
||||
s.AddMessageWithParts(ctx, convID, "assistant", parts, 10)
|
||||
|
||||
// Add a Part-based message (tool_result) — currently NOT searchable
|
||||
resultParts := []MessagePart{
|
||||
{Type: "tool_result", Text: "main.go:42: TODO fix this bug", ToolCallID: "call_1"},
|
||||
}
|
||||
s.AddMessageWithParts(ctx, convID, "tool", resultParts, 10)
|
||||
|
||||
// Search for "grep" — should find the tool_use message
|
||||
results, err := s.SearchMessages(ctx, SearchInput{Pattern: "grep"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchMessages: %v", err)
|
||||
}
|
||||
if len(results) == 0 {
|
||||
t.Error("SearchMessages: 'grep' not found — Part-based messages are invisible to search")
|
||||
}
|
||||
|
||||
// Search for "TODO fix" — should find the tool_result message
|
||||
results2, err := s.SearchMessages(ctx, SearchInput{Pattern: "TODO fix"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchMessages: %v", err)
|
||||
}
|
||||
if len(results2) == 0 {
|
||||
t.Error("SearchMessages: 'TODO fix' not found — tool_result messages are invisible to search")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// SQL statements for FTS5 tables with trigram tokenizer.
|
||||
const (
|
||||
sqlCreateSummariesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS summaries_fts USING fts5(
|
||||
summary_id,
|
||||
content,
|
||||
tokenize="trigram"
|
||||
)`
|
||||
sqlCreateMessagesFTS = `CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
|
||||
message_id,
|
||||
content,
|
||||
tokenize="trigram"
|
||||
)`
|
||||
sqlCheckFTS5Available = `CREATE VIRTUAL TABLE IF NOT EXISTS _fts5_check USING fts5(content)`
|
||||
sqlCheckTrigramAvailable = `CREATE VIRTUAL TABLE IF NOT EXISTS _trigram_check USING fts5(content, tokenize="trigram")`
|
||||
sqlDropFTS5Check = `DROP TABLE IF EXISTS _fts5_check`
|
||||
sqlDropTrigramCheck = `DROP TABLE IF EXISTS _trigram_check`
|
||||
)
|
||||
|
||||
// runSchema creates or upgrades the database schema.
|
||||
// All schemas are idempotent (safe to run multiple times).
|
||||
func runSchema(db *sql.DB) error {
|
||||
// Check FTS5 support before creating tables
|
||||
if err := checkFTS5Support(db); err != nil {
|
||||
return fmt.Errorf("FTS5 check: %w", err)
|
||||
}
|
||||
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS conversations (
|
||||
conversation_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_key TEXT NOT NULL UNIQUE,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS messages (
|
||||
message_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||
role TEXT NOT NULL,
|
||||
content TEXT NOT NULL DEFAULT '',
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS message_parts (
|
||||
part_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(message_id),
|
||||
type TEXT NOT NULL,
|
||||
text TEXT,
|
||||
name TEXT,
|
||||
arguments TEXT,
|
||||
tool_call_id TEXT,
|
||||
media_uri TEXT,
|
||||
mime_type TEXT,
|
||||
ordinal INTEGER NOT NULL DEFAULT 0
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS summaries (
|
||||
summary_id TEXT PRIMARY KEY,
|
||||
conversation_id INTEGER NOT NULL REFERENCES conversations(conversation_id),
|
||||
kind TEXT NOT NULL,
|
||||
depth INTEGER NOT NULL DEFAULT 0,
|
||||
content TEXT NOT NULL,
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
earliest_at TEXT,
|
||||
latest_at TEXT,
|
||||
descendant_count INTEGER NOT NULL DEFAULT 0,
|
||||
descendant_token_count INTEGER NOT NULL DEFAULT 0,
|
||||
source_message_token_count INTEGER NOT NULL DEFAULT 0,
|
||||
model TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS summary_parents (
|
||||
summary_id TEXT NOT NULL,
|
||||
parent_summary_id TEXT NOT NULL,
|
||||
PRIMARY KEY (summary_id, parent_summary_id)
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS summary_messages (
|
||||
summary_id TEXT NOT NULL,
|
||||
message_id INTEGER NOT NULL,
|
||||
ordinal INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY (summary_id, message_id)
|
||||
)`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS context_items (
|
||||
conversation_id INTEGER NOT NULL,
|
||||
ordinal INTEGER NOT NULL,
|
||||
item_type TEXT NOT NULL,
|
||||
summary_id TEXT,
|
||||
message_id INTEGER,
|
||||
token_count INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
PRIMARY KEY (conversation_id, ordinal)
|
||||
)`,
|
||||
|
||||
// FTS5 virtual table with trigram tokenizer for CJK support
|
||||
sqlCreateSummariesFTS,
|
||||
|
||||
// FTS5 virtual table for message search with trigram tokenizer
|
||||
sqlCreateMessagesFTS,
|
||||
|
||||
// Indexes for common query patterns
|
||||
`CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages(conversation_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_messages_created ON messages(conversation_id, created_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_conversation ON summaries(conversation_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summaries_kind_depth ON summaries(conversation_id, kind, depth)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summary_parents_parent ON summary_parents(parent_summary_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_summary_messages_message ON summary_messages(message_id)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_context_items_conv ON context_items(conversation_id, ordinal)`,
|
||||
|
||||
// FTS5 triggers to keep summaries_fts in sync with summaries table
|
||||
`CREATE TRIGGER IF NOT EXISTS summaries_ai AFTER INSERT ON summaries BEGIN
|
||||
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS summaries_ad AFTER DELETE ON summaries BEGIN
|
||||
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS summaries_au AFTER UPDATE ON summaries BEGIN
|
||||
INSERT INTO summaries_fts (summaries_fts, summary_id, content) VALUES ('delete', old.summary_id, old.content);
|
||||
INSERT INTO summaries_fts (summary_id, content) VALUES (new.summary_id, new.content);
|
||||
END`,
|
||||
|
||||
// FTS5 triggers to keep messages_fts in sync with messages table
|
||||
`CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
|
||||
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
|
||||
DELETE FROM messages_fts WHERE message_id = old.message_id;
|
||||
END`,
|
||||
`CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
|
||||
DELETE FROM messages_fts WHERE message_id = old.message_id;
|
||||
INSERT INTO messages_fts (message_id, content) VALUES (new.message_id, new.content);
|
||||
END`,
|
||||
}
|
||||
|
||||
for _, s := range stmts {
|
||||
if _, err := db.Exec(s); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkFTS5Support verifies that SQLite has FTS5 with trigram tokenizer enabled.
|
||||
// This is required for full-text search with CJK (Chinese, Japanese, Korean) support.
|
||||
func checkFTS5Support(db *sql.DB) error {
|
||||
// Check if FTS5 is compiled in
|
||||
var fts5Enabled int
|
||||
err := db.QueryRow(`SELECT sqlite_compileoption_used('ENABLE_FTS5')`).Scan(&fts5Enabled)
|
||||
if err != nil {
|
||||
// sqlite_compileoption_used might not exist in older SQLite
|
||||
// Try a different approach: create a test FTS5 table
|
||||
_, testErr := db.Exec(sqlCheckFTS5Available)
|
||||
if testErr != nil {
|
||||
return fmt.Errorf("SQLite FTS5 not available: %w (required for full-text search)", testErr)
|
||||
}
|
||||
db.Exec(sqlDropFTS5Check)
|
||||
} else if fts5Enabled == 0 {
|
||||
return fmt.Errorf("SQLite was compiled without FTS5 support (required for full-text search)")
|
||||
}
|
||||
|
||||
// Check if trigram tokenizer is available by trying to create a test table
|
||||
// Not all SQLite builds include the trigram tokenizer
|
||||
_, err = db.Exec(sqlCheckTrigramAvailable)
|
||||
if err != nil {
|
||||
logger.WarnCF("seahorse", "SQLite trigram tokenizer not available, CJK search may be limited",
|
||||
map[string]any{"error": err.Error()})
|
||||
// Trigram is not strictly required, just better for CJK
|
||||
// Don't return error, just log warning
|
||||
} else {
|
||||
db.Exec(sqlDropTrigramCheck)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func openTestDB(t *testing.T) *sql.DB {
|
||||
t.Helper()
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("open test db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { db.Close() })
|
||||
return db
|
||||
}
|
||||
|
||||
func TestRunMigrations(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("runSchema: %v", err)
|
||||
}
|
||||
|
||||
// Verify all tables exist
|
||||
tables := []string{
|
||||
"conversations",
|
||||
"messages",
|
||||
"message_parts",
|
||||
"summaries",
|
||||
"summary_parents",
|
||||
"summary_messages",
|
||||
"context_items",
|
||||
}
|
||||
for _, tbl := range tables {
|
||||
var name string
|
||||
err := db.QueryRow(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", tbl,
|
||||
).Scan(&name)
|
||||
if err != nil {
|
||||
t.Errorf("table %q not found: %v", tbl, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify FTS5 virtual table exists
|
||||
var ftsName string
|
||||
err := db.QueryRow(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='summaries_fts'",
|
||||
).Scan(&ftsName)
|
||||
if err != nil {
|
||||
t.Errorf("FTS5 table summaries_fts not found: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunMigrationsIdempotent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
// Run migrations twice — should succeed both times
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("first migration: %v", err)
|
||||
}
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("second migration (idempotent): %v", err)
|
||||
}
|
||||
|
||||
// Verify we can still insert data after double migration
|
||||
res, err := db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"test-session",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert after double migration: %v", err)
|
||||
}
|
||||
id, _ := res.LastInsertId()
|
||||
if id == 0 {
|
||||
t.Error("expected non-zero conversation id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationConversationUnique(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
|
||||
// Insert first
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"unique-key",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("first insert: %v", err)
|
||||
}
|
||||
|
||||
// Duplicate should fail
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"unique-key",
|
||||
)
|
||||
if err == nil {
|
||||
t.Error("expected unique constraint violation for duplicate session_key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationSummaryFTSInsert(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
|
||||
// Insert a conversation first
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO conversations (session_key, created_at, updated_at) VALUES (?, datetime('now'), datetime('now'))",
|
||||
"fts-test",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert conversation: %v", err)
|
||||
}
|
||||
|
||||
// Insert a summary
|
||||
_, err = db.Exec(
|
||||
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
|
||||
VALUES ('sum_test1', 1, 'leaf', 0, '你好世界 hello world', 10, datetime('now'))`)
|
||||
if err != nil {
|
||||
t.Fatalf("insert summary: %v", err)
|
||||
}
|
||||
|
||||
// FTS should find it — trigram tokenizer requires >= 3 chars
|
||||
rows, err := db.Query(
|
||||
"SELECT summary_id FROM summaries_fts WHERE summaries_fts MATCH ?",
|
||||
"你好世",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("FTS query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var found string
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&found); err != nil {
|
||||
t.Fatalf("scan: %v", err)
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err: %v", err)
|
||||
}
|
||||
if found != "sum_test1" {
|
||||
t.Errorf("FTS: expected 'sum_test1', got %q", found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrationSummaryParentsPK(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
|
||||
// Insert two summaries
|
||||
for _, id := range []string{"sum_a", "sum_b"} {
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO summaries (summary_id, conversation_id, kind, depth, content, token_count, created_at)
|
||||
VALUES (?, 1, 'leaf', 0, 'content', 5, datetime('now'))`, id)
|
||||
if err != nil {
|
||||
t.Fatalf("insert summary %s: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Link child to parent
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
|
||||
if err != nil {
|
||||
t.Fatalf("link: %v", err)
|
||||
}
|
||||
|
||||
// Duplicate link should fail (composite PK)
|
||||
_, err = db.Exec(
|
||||
"INSERT INTO summary_parents (summary_id, parent_summary_id) VALUES ('sum_a', 'sum_b')")
|
||||
if err == nil {
|
||||
t.Error("expected unique constraint violation for duplicate summary_parents link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFTS5SQLConstants(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
// Verify FTS5 check SQL executes without error
|
||||
_, err := db.Exec(sqlCheckFTS5Available)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCheckFTS5Available failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify trigram check SQL executes without error
|
||||
_, err = db.Exec(sqlCheckTrigramAvailable)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCheckTrigramAvailable failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify summaries_fts SQL executes without error
|
||||
_, err = db.Exec(sqlCreateSummariesFTS)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCreateSummariesFTS failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify messages_fts SQL executes without error
|
||||
_, err = db.Exec(sqlCreateMessagesFTS)
|
||||
if err != nil {
|
||||
t.Errorf("sqlCreateMessagesFTS failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// escapeXML escapes special characters for safe inclusion in XML content.
|
||||
func escapeXML(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
s = strings.ReplaceAll(s, "\"", """)
|
||||
s = strings.ReplaceAll(s, "'", "'")
|
||||
return s
|
||||
}
|
||||
|
||||
// resolvedItem is a context item resolved to its full content with token count.
|
||||
type resolvedItem struct {
|
||||
ordinal int
|
||||
itemType string // "message" or "summary"
|
||||
message *Message
|
||||
summary *Summary
|
||||
tokenCount int
|
||||
}
|
||||
|
||||
// Assemble builds budget-constrained context from summaries + messages.
|
||||
//
|
||||
// Algorithm:
|
||||
// 1. Fetch context_items, resolve to full content
|
||||
// 2. Split into evictable prefix + protected fresh tail
|
||||
// 3. If evictable fits in remaining budget → include all
|
||||
// 4. Else walk evictable from newest to oldest, keep while fits
|
||||
func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
|
||||
items, err := a.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get context items: %w", err)
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return &AssembleResult{}, nil
|
||||
}
|
||||
|
||||
// Resolve all items
|
||||
resolved := make([]resolvedItem, len(items))
|
||||
for i, item := range items {
|
||||
r, err := a.resolveItem(ctx, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolved[i] = r
|
||||
}
|
||||
|
||||
// Split into evictable prefix and protected fresh tail
|
||||
tailStart := len(resolved) - FreshTailCount
|
||||
if tailStart < 0 {
|
||||
tailStart = 0
|
||||
}
|
||||
evictable := resolved[:tailStart]
|
||||
freshTail := resolved[tailStart:]
|
||||
|
||||
// Calculate fresh tail tokens
|
||||
freshTailTokens := 0
|
||||
for _, r := range freshTail {
|
||||
freshTailTokens += r.tokenCount
|
||||
}
|
||||
|
||||
// Budget-aware selection of evictable items
|
||||
remainingBudget := input.Budget - freshTailTokens
|
||||
if remainingBudget < 0 {
|
||||
// Fresh tail alone exceeds budget - we keep it anyway (design decision)
|
||||
// Log for debugging retry/overflow issues
|
||||
logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
|
||||
"budget": input.Budget,
|
||||
"fresh_tail_tokens": freshTailTokens,
|
||||
"fresh_tail_count": len(freshTail),
|
||||
"over_budget_by": freshTailTokens - input.Budget,
|
||||
})
|
||||
remainingBudget = 0
|
||||
}
|
||||
|
||||
var selected []resolvedItem
|
||||
evictableTokens := 0
|
||||
for _, r := range evictable {
|
||||
evictableTokens += r.tokenCount
|
||||
}
|
||||
|
||||
if evictableTokens <= remainingBudget {
|
||||
// All evictable fit
|
||||
selected = append(selected, evictable...)
|
||||
} else {
|
||||
// Walk from newest to oldest, keep while fits
|
||||
var kept []resolvedItem
|
||||
accum := 0
|
||||
for i := len(evictable) - 1; i >= 0; i-- {
|
||||
if accum+evictable[i].tokenCount <= remainingBudget {
|
||||
kept = append(kept, evictable[i])
|
||||
accum += evictable[i].tokenCount
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Reverse to restore chronological order
|
||||
for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
|
||||
kept[i], kept[j] = kept[j], kept[i]
|
||||
}
|
||||
selected = append(selected, kept...)
|
||||
}
|
||||
|
||||
// Combine: selected evictable + fresh tail
|
||||
final := append(selected, freshTail...)
|
||||
|
||||
// Build result
|
||||
var messages []Message
|
||||
var summaries []Summary
|
||||
var sourceIDs []string
|
||||
totalTokens := 0
|
||||
maxDepth := 0
|
||||
condensedCount := 0
|
||||
|
||||
for _, r := range final {
|
||||
totalTokens += r.tokenCount
|
||||
if r.itemType == "message" && r.message != nil {
|
||||
messages = append(messages, *r.message)
|
||||
sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
|
||||
} else if r.itemType == "summary" && r.summary != nil {
|
||||
summaries = append(summaries, *r.summary)
|
||||
if r.summary.Depth > maxDepth {
|
||||
maxDepth = r.summary.Depth
|
||||
}
|
||||
if r.summary.Kind == SummaryKindCondensed {
|
||||
condensedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build depth-aware system prompt addition
|
||||
systemPromptAddition := ""
|
||||
if len(summaries) > 0 {
|
||||
if maxDepth >= 2 || condensedCount >= 2 {
|
||||
systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
|
||||
"- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
|
||||
"- When uncertain, use expand to recover original detail before making claims.\n" +
|
||||
"- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
|
||||
} else {
|
||||
systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
|
||||
}
|
||||
}
|
||||
|
||||
// Build Summary field: all XML summaries + system prompt addition
|
||||
var summaryParts []string
|
||||
for _, sum := range summaries {
|
||||
if sum.Content == "" {
|
||||
continue
|
||||
}
|
||||
// Load parent IDs for XML formatting
|
||||
parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
|
||||
if err != nil {
|
||||
logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
|
||||
"summary_id": sum.SummaryID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
var parentIDs []string
|
||||
for _, ps := range parentSummaries {
|
||||
parentIDs = append(parentIDs, ps.SummaryID)
|
||||
}
|
||||
summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
|
||||
}
|
||||
summary := strings.Join(summaryParts, "\n\n")
|
||||
if systemPromptAddition != "" {
|
||||
if summary != "" {
|
||||
summary += "\n\n"
|
||||
}
|
||||
summary += systemPromptAddition
|
||||
}
|
||||
|
||||
return &AssembleResult{
|
||||
Messages: messages,
|
||||
Summary: summary,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveItem loads the full message or summary for a context item.
|
||||
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
|
||||
if item.ItemType == "message" {
|
||||
msg, err := a.store.GetMessageByID(ctx, item.MessageID)
|
||||
if err != nil {
|
||||
return resolvedItem{}, err
|
||||
}
|
||||
tokens := item.TokenCount
|
||||
if tokens == 0 {
|
||||
tokens = msg.TokenCount
|
||||
}
|
||||
return resolvedItem{
|
||||
ordinal: item.Ordinal,
|
||||
itemType: "message",
|
||||
message: msg,
|
||||
tokenCount: tokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if item.ItemType == "summary" {
|
||||
sum, err := a.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
return resolvedItem{}, err
|
||||
}
|
||||
tokens := item.TokenCount
|
||||
if tokens == 0 {
|
||||
tokens = sum.TokenCount
|
||||
}
|
||||
return resolvedItem{
|
||||
ordinal: item.Ordinal,
|
||||
itemType: "summary",
|
||||
summary: sum,
|
||||
tokenCount: tokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return resolvedItem{
|
||||
ordinal: item.Ordinal,
|
||||
itemType: item.ItemType,
|
||||
tokenCount: item.TokenCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FormatSummaryXML formats a summary as XML for LLM context.
|
||||
// This is exported so context managers can format summaries consistently.
|
||||
func FormatSummaryXML(s *Summary, parentIDs []string) string {
|
||||
// Build time attributes if available
|
||||
var attrs string
|
||||
if s.EarliestAt != nil {
|
||||
attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
|
||||
}
|
||||
if s.LatestAt != nil {
|
||||
attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
var parentsSection string
|
||||
if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
|
||||
parents := "<parents>\n"
|
||||
for _, pid := range parentIDs {
|
||||
parents += fmt.Sprintf(" <summary_ref id=\"%s\" />\n", pid)
|
||||
}
|
||||
parents += " </parents>\n"
|
||||
parentsSection = parents
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"<summary id=\"%s\" kind=\"%s\" depth=\"%d\" descendant_count=\"%d\"%s>\n <content>\n %s\n </content>\n%s</summary>",
|
||||
s.SummaryID,
|
||||
string(s.Kind),
|
||||
s.Depth,
|
||||
s.DescendantCount,
|
||||
attrs,
|
||||
escapeXML(s.Content),
|
||||
parentsSection,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,536 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Assembler Tests ---
|
||||
|
||||
// helper: create a store with messages and summaries for assembly tests
|
||||
func setupAssemblerStore(t *testing.T) (*Store, int64) {
|
||||
t.Helper()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
conv, err := s.GetOrCreateConversation(ctx, "test:assemble")
|
||||
if err != nil {
|
||||
t.Fatalf("create conversation: %v", err)
|
||||
}
|
||||
|
||||
return s, conv.ConversationID
|
||||
}
|
||||
|
||||
func TestAssemblerAssembleEmpty(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 0 {
|
||||
t.Errorf("Messages = %d, want 0", len(result.Messages))
|
||||
}
|
||||
if result.Summary != "" {
|
||||
t.Errorf("Summary = %q, want empty", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerAssembleMessagesOnly(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create messages
|
||||
msg1, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
|
||||
msg2, _ := s.AddMessage(ctx, convID, "assistant", "world", 5)
|
||||
|
||||
// Create context items
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("Messages = %d, want 2", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Content != "hello" {
|
||||
t.Errorf("Messages[0].Content = %q, want 'hello'", result.Messages[0].Content)
|
||||
}
|
||||
if result.Messages[1].Content != "world" {
|
||||
t.Errorf("Messages[1].Content = %q, want 'world'", result.Messages[1].Content)
|
||||
}
|
||||
// No summaries, so Summary should be empty
|
||||
if result.Summary != "" {
|
||||
t.Errorf("Summary = %q, want empty", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerAssembleWithSummary(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a summary
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "summary of early messages",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
// Create recent messages
|
||||
msg1, _ := s.AddMessage(ctx, convID, "user", "recent", 5)
|
||||
msg2, _ := s.AddMessage(ctx, convID, "assistant", "reply", 5)
|
||||
|
||||
// Context: summary + recent messages
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 50},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg1.ID, TokenCount: 5},
|
||||
{Ordinal: 300, ItemType: "message", MessageID: msg2.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Messages = 2 raw messages (summaries are in Summary field, not Messages)
|
||||
if len(result.Messages) != 2 {
|
||||
t.Errorf("Messages = %d, want 2 (raw messages only)", len(result.Messages))
|
||||
}
|
||||
// Summary should contain XML with summary content
|
||||
if result.Summary == "" {
|
||||
t.Error("Summary should not be empty when summary exists")
|
||||
}
|
||||
if !strings.Contains(result.Summary, summary.Content) {
|
||||
t.Errorf("Summary should contain summary content %q", summary.Content)
|
||||
}
|
||||
if !strings.Contains(result.Summary, "<summary") {
|
||||
t.Error("Summary should contain <summary XML tag")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerBudgetEvictsOldest(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 40 messages, each with 10 tokens = 400 total
|
||||
msgs := make([]*Message, 40)
|
||||
for i := 0; i < 40; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
|
||||
msgs[i] = m
|
||||
}
|
||||
|
||||
// Context items for all messages
|
||||
items := make([]ContextItem, 40)
|
||||
for i := 0; i < 40; i++ {
|
||||
items[i] = ContextItem{
|
||||
Ordinal: (i + 1) * 100,
|
||||
ItemType: "message",
|
||||
MessageID: msgs[i].ID,
|
||||
TokenCount: 10,
|
||||
}
|
||||
}
|
||||
s.UpsertContextItems(ctx, convID, items)
|
||||
|
||||
// Budget of 200 tokens with FreshTailCount=32
|
||||
// Fresh tail = last 32 messages (320 tokens, over budget, but always included)
|
||||
// Evictable = first 8 messages (80 tokens)
|
||||
// Budget after tail: max(0, 200-320) = 0 → no evictable items included
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Should only include the 32-item fresh tail
|
||||
if len(result.Messages) != 32 {
|
||||
t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages))
|
||||
}
|
||||
// Should be the LAST 32 messages
|
||||
if result.Messages[0].ID != msgs[8].ID {
|
||||
t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerBudgetFitsAll(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msgs := make([]*Message, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "msg", 10)
|
||||
msgs[i] = m
|
||||
}
|
||||
|
||||
items := make([]ContextItem, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
items[i] = ContextItem{
|
||||
Ordinal: (i + 1) * 100,
|
||||
ItemType: "message",
|
||||
MessageID: msgs[i].ID,
|
||||
TokenCount: 10,
|
||||
}
|
||||
}
|
||||
s.UpsertContextItems(ctx, convID, items)
|
||||
|
||||
// Budget = 100, total = 50, FreshTailCount=32 → all items in tail
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 100})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 5 {
|
||||
t.Errorf("Messages = %d, want 5", len(result.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLFormat(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "test summary content",
|
||||
TokenCount: 20,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "hello", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Messages should only contain raw messages (no XML summary in Messages)
|
||||
if len(result.Messages) != 1 {
|
||||
t.Errorf("Messages = %d, want 1 (raw message only)", len(result.Messages))
|
||||
}
|
||||
// Summary should contain XML with summary content
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
if !contains(result.Summary, "<summary") {
|
||||
t.Errorf("Summary missing <summary tag: %q", result.Summary)
|
||||
}
|
||||
if !contains(result.Summary, summary.SummaryID) {
|
||||
t.Errorf("Summary missing summary ID: %q", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLEscaping(t *testing.T) {
|
||||
// Summary content with special XML characters should be properly escaped
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create summary with content containing XML special characters
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: `User said: "hello" & asked about <tags>`,
|
||||
TokenCount: 20,
|
||||
})
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: summary.SummaryID, TokenCount: 20},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Summary field should contain XML with escaped special characters
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
|
||||
// Check that special characters are escaped
|
||||
if strings.Contains(result.Summary, "<tags>") {
|
||||
t.Errorf("BUG: unescaped < in summary content: %q", result.Summary)
|
||||
}
|
||||
if strings.Contains(result.Summary, `"hello"`) {
|
||||
t.Errorf("BUG: unescaped \" in summary content: %q", result.Summary)
|
||||
}
|
||||
// & should be escaped as &
|
||||
if strings.Contains(result.Summary, " & ") {
|
||||
t.Errorf("BUG: unescaped & in summary content: %q", result.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLWithParents(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a leaf and a condensed summary (condensed has parent)
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 20,
|
||||
})
|
||||
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: 1,
|
||||
Content: "condensed content",
|
||||
TokenCount: 15,
|
||||
ParentIDs: []string{leaf.SummaryID},
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Summary field should contain XML with parent information
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
xmlContent := result.Summary
|
||||
|
||||
// Should contain <parents> section with parent ID
|
||||
if !contains(xmlContent, "<parents>") {
|
||||
t.Errorf("condensed summary XML missing <parents> section: %q", xmlContent)
|
||||
}
|
||||
if !contains(xmlContent, leaf.SummaryID) {
|
||||
t.Errorf("condensed summary XML missing parent ID %q: %q", leaf.SummaryID, xmlContent)
|
||||
}
|
||||
|
||||
// Should contain kind="condensed"
|
||||
if !contains(xmlContent, `kind="condensed"`) {
|
||||
t.Errorf("condensed summary XML missing kind attribute: %q", xmlContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerSummaryXMLIncludesDescendantCount(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a leaf summary with specific descendant count
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 20,
|
||||
DescendantCount: 8,
|
||||
DescendantTokenCount: 1200,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
xmlContent := result.Summary
|
||||
|
||||
// Should contain descendant_count="8"
|
||||
if !contains(xmlContent, `descendant_count="8"`) {
|
||||
t.Errorf("summary XML missing descendant_count attribute: %q", xmlContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerLeafSummaryNoParents(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Leaf summary has no parents
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 20,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: leaf.SummaryID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if result.Summary == "" {
|
||||
t.Fatal("Summary should not be empty")
|
||||
}
|
||||
xmlContent := result.Summary
|
||||
|
||||
// Leaf summary should NOT have <parents> section
|
||||
if contains(xmlContent, "<parents>") {
|
||||
t.Errorf("leaf summary XML should not have <parents> section: %q", xmlContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerDepthAwarePrompt(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a condensed summary (depth >= 2) to trigger full guidance
|
||||
now := time.Now().UTC()
|
||||
leaf, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf summary",
|
||||
TokenCount: 20,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
condensed, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: 2,
|
||||
Content: "condensed summary",
|
||||
TokenCount: 15,
|
||||
ParentIDs: []string{leaf.SummaryID},
|
||||
DescendantCount: 1,
|
||||
DescendantTokenCount: 20,
|
||||
})
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "summary", SummaryID: condensed.SummaryID, TokenCount: 15},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: msg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 1000})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Should have a depth-aware prompt in Summary field
|
||||
if result.Summary == "" {
|
||||
t.Error("expected non-empty Summary when depth >= 2")
|
||||
}
|
||||
// SystemPromptAddition is embedded in Summary field
|
||||
if !strings.Contains(result.Summary, "multi-level summarization") {
|
||||
t.Error("Summary should contain system prompt addition about multi-level summarization")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryXMLUsesSummaryRef(t *testing.T) {
|
||||
// Spec: condensed summaries use <summary_ref id="parentId" /> not <parent>parentId</parent>
|
||||
now := time.Now().UTC()
|
||||
s := Summary{
|
||||
SummaryID: "sum_condensed1",
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: 1,
|
||||
Content: "condensed content",
|
||||
TokenCount: 50,
|
||||
DescendantCount: 2,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
}
|
||||
parentIDs := []string{"sum_leaf1", "sum_leaf2"}
|
||||
|
||||
xml := FormatSummaryXML(&s, parentIDs)
|
||||
|
||||
// Must use <summary_ref id="..." /> per spec
|
||||
if !contains(xml, `<summary_ref id="sum_leaf1" />`) {
|
||||
t.Errorf("expected <summary_ref id=\"sum_leaf1\" />, got: %s", xml)
|
||||
}
|
||||
if !contains(xml, `<summary_ref id="sum_leaf2" />`) {
|
||||
t.Errorf("expected <summary_ref id=\"sum_leaf2\" />, got: %s", xml)
|
||||
}
|
||||
// Must NOT use old <parent> tag
|
||||
if contains(xml, "<parent>") {
|
||||
t.Errorf("should not use <parent> tag, got: %s", xml)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryXMLIncludesTimestamps(t *testing.T) {
|
||||
// Spec: summary XML includes earliest_at and latest_at attributes
|
||||
earliest := time.Date(2026, 3, 15, 10, 0, 0, 0, time.UTC)
|
||||
latest := time.Date(2026, 3, 15, 14, 30, 0, 0, time.UTC)
|
||||
s := Summary{
|
||||
SummaryID: "sum_leaf1",
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 30,
|
||||
DescendantCount: 0,
|
||||
EarliestAt: &earliest,
|
||||
LatestAt: &latest,
|
||||
}
|
||||
|
||||
xml := FormatSummaryXML(&s, nil)
|
||||
|
||||
if !contains(xml, `earliest_at="2026-03-15T10:00:00Z"`) {
|
||||
t.Errorf("missing earliest_at attribute, got: %s", xml)
|
||||
}
|
||||
if !contains(xml, `latest_at="2026-03-15T14:30:00Z"`) {
|
||||
t.Errorf("missing latest_at attribute, got: %s", xml)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSummaryXMLNoTimestampsWhenNil(t *testing.T) {
|
||||
// When EarliestAt/LatestAt are nil, attributes should be omitted
|
||||
s := Summary{
|
||||
SummaryID: "sum_leaf1",
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf content",
|
||||
TokenCount: 30,
|
||||
DescendantCount: 0,
|
||||
}
|
||||
|
||||
xml := FormatSummaryXML(&s, nil)
|
||||
|
||||
if contains(xml, "earliest_at=") {
|
||||
t.Errorf("should not have earliest_at when nil, got: %s", xml)
|
||||
}
|
||||
if contains(xml, "latest_at=") {
|
||||
t.Errorf("should not have latest_at when nil, got: %s", xml)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,336 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
// newBenchStore creates a test store for benchmarks.
|
||||
func newBenchStore(b *testing.B) (*Store, func()) {
|
||||
b.Helper()
|
||||
db, err := sql.Open("sqlite", ":memory:")
|
||||
if err != nil {
|
||||
b.Fatalf("open test db: %v", err)
|
||||
}
|
||||
if err := runSchema(db); err != nil {
|
||||
db.Close()
|
||||
b.Fatalf("migration: %v", err)
|
||||
}
|
||||
return &Store{db: db}, func() { db.Close() }
|
||||
}
|
||||
|
||||
// --- Ingest benchmarks ---
|
||||
|
||||
func BenchmarkIngest_SingleMessage(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:ingest")
|
||||
convID := conv.ConversationID
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.AddMessage(ctx, convID, "user", "Test message content", 15)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIngest_BatchMessages(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:ingest-batch:%d", i))
|
||||
convID := conv.ConversationID
|
||||
|
||||
for j := 0; j < 10; j++ {
|
||||
added, err := s.AddMessage(ctx, convID, "user",
|
||||
fmt.Sprintf("Message %d in batch", j), 10)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
s.AppendContextMessage(ctx, convID, added.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Assemble benchmarks ---
|
||||
|
||||
func BenchmarkAssemble_MessagesOnly(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-msgs")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add 100 messages
|
||||
for i := 0; i < 100; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user",
|
||||
fmt.Sprintf("Message content %d with some text", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
a := &Assembler{store: s}
|
||||
input := AssembleInput{Budget: 50000}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := a.Assemble(ctx, convID, input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAssemble_WithSummaries(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-sums")
|
||||
convID := conv.ConversationID
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Add 10 leaf summaries
|
||||
for i := 0; i < 10; i++ {
|
||||
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("Leaf summary %d", i),
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||
}
|
||||
|
||||
// Add 20 fresh messages
|
||||
for i := 0; i < 20; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("Fresh message %d", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
a := &Assembler{store: s}
|
||||
input := AssembleInput{Budget: 10000}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := a.Assemble(ctx, convID, input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAssemble_BudgetEviction(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:assemble-evict")
|
||||
convID := conv.ConversationID
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// Add 50 leaf summaries (more than budget can hold)
|
||||
for i := 0; i < 50; i++ {
|
||||
sum, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("Summary %d", i),
|
||||
TokenCount: 300,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
a := &Assembler{store: s}
|
||||
input := AssembleInput{Budget: 5000} // Force eviction
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := a.Assemble(ctx, convID, input)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Search (FTS5) benchmarks ---
|
||||
|
||||
// benchSeedSummaries adds n summaries to a conversation for search benchmarks.
|
||||
func benchSeedSummaries(b *testing.B, s *Store, convID int64, n int, contentTpl string) {
|
||||
b.Helper()
|
||||
now := time.Now().UTC()
|
||||
for i := 0; i < n; i++ {
|
||||
sum, err := s.CreateSummary(context.Background(), CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf(contentTpl, i),
|
||||
TokenCount: 200,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("create summary: %v", err)
|
||||
}
|
||||
s.AppendContextSummary(context.Background(), convID, sum.SummaryID)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSearchSummaries_FTS5(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-fts")
|
||||
convID := conv.ConversationID
|
||||
|
||||
benchSeedSummaries(b, s, convID, 100, "Summary about database configuration and API endpoints %d")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.SearchSummaries(ctx, SearchInput{
|
||||
Pattern: "database",
|
||||
Mode: "full_text",
|
||||
ConversationID: convID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSearchSummaries_Like(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-like")
|
||||
convID := conv.ConversationID
|
||||
|
||||
benchSeedSummaries(b, s, convID, 100, "Summary about configuration %d")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.SearchSummaries(ctx, SearchInput{
|
||||
Pattern: "config",
|
||||
Mode: "like",
|
||||
ConversationID: convID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSearchMessages_FTS5(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "bench:search-msg-fts")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Add 500 messages
|
||||
for i := 0; i < 500; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user",
|
||||
fmt.Sprintf("User message about API and database integration %d", i), 20)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := s.SearchMessages(ctx, SearchInput{
|
||||
Pattern: "API database",
|
||||
Mode: "full_text",
|
||||
ConversationID: convID,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Bootstrap benchmarks ---
|
||||
|
||||
func BenchmarkBootstrap_Empty(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-empty:%d", i))
|
||||
convID := conv.ConversationID
|
||||
_ = convID // Bootstrap with empty history
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBootstrap_100Messages(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
// Prepare 100 messages
|
||||
msgs := make([]Message, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
msgs[i] = Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Bootstrap message %d", i),
|
||||
TokenCount: 15,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-100:%d", i))
|
||||
convID := conv.ConversationID
|
||||
|
||||
for _, m := range msgs {
|
||||
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
|
||||
s.AppendContextMessage(ctx, convID, added.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBootstrap_500Messages(b *testing.B) {
|
||||
s, cleanup := newBenchStore(b)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
msgs := make([]Message, 500)
|
||||
for i := 0; i < 500; i++ {
|
||||
msgs[i] = Message{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Bootstrap message %d", i),
|
||||
TokenCount: 15,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
conv, _ := s.GetOrCreateConversation(ctx, fmt.Sprintf("bench:bootstrap-500:%d", i))
|
||||
convID := conv.ConversationID
|
||||
|
||||
for _, m := range msgs {
|
||||
added, _ := s.AddMessage(ctx, convID, m.Role, m.Content, m.TokenCount)
|
||||
s.AppendContextMessage(ctx, convID, added.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,898 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// CompactInput controls compaction behavior.
|
||||
type CompactInput struct {
|
||||
Budget *int // Token budget override
|
||||
Force bool // Force compaction even if below threshold
|
||||
}
|
||||
|
||||
// CompactResult describes what was compacted.
|
||||
type CompactResult struct {
|
||||
SummariesCreated []string `json:"summariesCreated"`
|
||||
TokensSaved int `json:"tokensSaved"`
|
||||
LeafSummaries int `json:"leafSummaries"`
|
||||
CondensedSummaries int `json:"condensedSummaries"`
|
||||
}
|
||||
|
||||
// NeedsCompaction returns true if context tokens >= ContextThreshold × contextWindow.
|
||||
func (e *CompactionEngine) NeedsCompaction(ctx context.Context, convID int64, contextWindow int) (bool, error) {
|
||||
tokens, err := e.store.GetContextTokenCount(ctx, convID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get token count: %w", err)
|
||||
}
|
||||
threshold := int(float64(contextWindow) * ContextThreshold)
|
||||
return tokens >= threshold, nil
|
||||
}
|
||||
|
||||
// Close cancels the shutdown context, stopping async goroutines.
|
||||
func (e *CompactionEngine) Close() {
|
||||
if e.shutdownCancel != nil {
|
||||
e.shutdownCancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Compact runs leaf compaction (sync) and optionally condensed compaction.
|
||||
func (e *CompactionEngine) Compact(ctx context.Context, convID int64, input CompactInput) (*CompactResult, error) {
|
||||
result := &CompactResult{}
|
||||
|
||||
// Phase 1: leaf compaction (synchronous, every turn)
|
||||
summaryID, err := e.compactLeaf(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compact leaf: %w", err)
|
||||
}
|
||||
if summaryID != nil {
|
||||
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
|
||||
result.LeafSummaries++
|
||||
logger.InfoCF("seahorse", "compact: leaf", map[string]any{
|
||||
"conv_id": convID,
|
||||
"summary_id": *summaryID,
|
||||
})
|
||||
}
|
||||
|
||||
// Phase 2: condensed compaction if over threshold
|
||||
tokensBefore, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
var budget int
|
||||
if input.Budget != nil {
|
||||
budget = *input.Budget
|
||||
if budget == 0 {
|
||||
logger.ErrorCF("seahorse", "Compact: budget is 0, this should not happen", map[string]any{
|
||||
"conv_id": convID,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
budget = int(float64(tokensBefore) * ContextThreshold)
|
||||
}
|
||||
|
||||
if input.Force || (tokensBefore > budget && budget > 0) {
|
||||
// Launch async condensed compaction with dedup
|
||||
if _, loaded := e.condensing.LoadOrStore(convID, struct{}{}); !loaded {
|
||||
go func() {
|
||||
defer e.condensing.Delete(convID)
|
||||
e.runCondensedLoop(e.shutdownCtx, convID)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
if tokensAfter < tokensBefore {
|
||||
result.TokensSaved = tokensBefore - tokensAfter
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CompactUntilUnder aggressively compacts until context is under budget.
|
||||
func (e *CompactionEngine) CompactUntilUnder(ctx context.Context, convID int64, budget int) (*CompactResult, error) {
|
||||
result := &CompactResult{}
|
||||
prevTokens := 0
|
||||
logger.InfoCF("seahorse", "compact_until_under: start", map[string]any{"conv_id": convID, "budget": budget})
|
||||
|
||||
for iter := 0; iter < MaxCompactIterations; iter++ {
|
||||
tokens, err := e.store.GetContextTokenCount(ctx, convID)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("get tokens: %w", err)
|
||||
}
|
||||
if tokens <= budget {
|
||||
logger.InfoCF("seahorse", "compact_until_under: done", map[string]any{
|
||||
"conv_id": convID,
|
||||
"budget": budget,
|
||||
"tokens": tokens,
|
||||
"leaf": result.LeafSummaries,
|
||||
"condensed": result.CondensedSummaries,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Try leaf first
|
||||
summaryID, err := e.compactLeaf(ctx, convID, true)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if summaryID != nil {
|
||||
result.SummariesCreated = append(result.SummariesCreated, *summaryID)
|
||||
result.LeafSummaries++
|
||||
logger.InfoCF("seahorse", "compact_until_under: leaf", map[string]any{
|
||||
"conv_id": convID,
|
||||
"summary_id": *summaryID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Try condensed with forced fanout
|
||||
condensedID, err := e.compactCondensed(ctx, convID)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if condensedID != nil {
|
||||
result.SummariesCreated = append(result.SummariesCreated, *condensedID)
|
||||
result.CondensedSummaries++
|
||||
logger.InfoCF("seahorse", "compact_until_under: condensed", map[string]any{
|
||||
"conv_id": convID,
|
||||
"summary_id": *condensedID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// No progress
|
||||
newTokens, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
if newTokens >= prevTokens {
|
||||
logger.WarnCF("seahorse", "compact_until_under: no progress", map[string]any{
|
||||
"conv_id": convID,
|
||||
"tokens": newTokens,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
prevTokens = newTokens
|
||||
}
|
||||
|
||||
// Safety cap exceeded — see MaxCompactIterations doc for rationale.
|
||||
logger.WarnCF("seahorse", "compact_until_under: exceeded max iterations", map[string]any{
|
||||
"conv_id": convID,
|
||||
"budget": budget,
|
||||
"iterations": MaxCompactIterations,
|
||||
"tokens": prevTokens,
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// compactLeaf compresses the oldest contiguous message chunk into a leaf summary.
|
||||
// When force is true, FreshTailCount protection is bypassed (used by CompactUntilUnder).
|
||||
func (e *CompactionEngine) compactLeaf(ctx context.Context, convID int64, force ...bool) (*string, error) {
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find oldest contiguous message chunk outside fresh tail
|
||||
msgCount := 0
|
||||
msgTokens := 0
|
||||
for _, item := range items {
|
||||
if item.ItemType == "message" {
|
||||
msgCount++
|
||||
msgTokens += item.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger if either message count or token threshold is met
|
||||
if msgCount < LeafMinFanout && msgTokens < LeafChunkTokens {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Calculate fresh tail boundary (bypass when forced)
|
||||
useForce := len(force) > 0 && force[0]
|
||||
tailStartIdx := len(items) - FreshTailCount
|
||||
if useForce {
|
||||
tailStartIdx = len(items) // allow compacting everything
|
||||
}
|
||||
if tailStartIdx < 0 {
|
||||
tailStartIdx = 0
|
||||
}
|
||||
|
||||
// Find oldest contiguous message chunk, accumulating up to LeafChunkTokens
|
||||
var chunk []ContextItem
|
||||
chunkStart := -1
|
||||
chunkEnd := -1
|
||||
accumTokens := 0
|
||||
for i := 0; i < tailStartIdx; i++ {
|
||||
if items[i].ItemType == "message" {
|
||||
if chunkStart == -1 {
|
||||
chunkStart = i
|
||||
}
|
||||
chunkEnd = i
|
||||
accumTokens += items[i].TokenCount
|
||||
// Stop accumulating once we reach the token budget
|
||||
if accumTokens >= LeafChunkTokens {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
// Non-message breaks the chunk
|
||||
if chunkStart != -1 && (chunkEnd-chunkStart+1) >= LeafMinFanout {
|
||||
break
|
||||
}
|
||||
chunkStart = -1
|
||||
chunkEnd = -1
|
||||
accumTokens = 0
|
||||
}
|
||||
}
|
||||
|
||||
if chunkStart == -1 || (chunkEnd-chunkStart+1) < LeafMinFanout {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chunk = items[chunkStart : chunkEnd+1]
|
||||
|
||||
// Collect messages for the chunk
|
||||
var messages []Message
|
||||
for _, item := range chunk {
|
||||
msg, innerErr := e.store.GetMessageByID(ctx, item.MessageID)
|
||||
if innerErr != nil {
|
||||
return nil, innerErr
|
||||
}
|
||||
messages = append(messages, *msg)
|
||||
}
|
||||
|
||||
// Get prior summaries for context
|
||||
priorSummary := ""
|
||||
priorCount := 0
|
||||
for i := chunkStart - 1; i >= 0 && priorCount < 2; i-- {
|
||||
if items[i].ItemType == "summary" {
|
||||
sum, innerErr2 := e.store.GetSummary(ctx, items[i].SummaryID)
|
||||
if innerErr2 == nil {
|
||||
priorSummary = sum.Content + "\n" + priorSummary
|
||||
priorCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate summary
|
||||
content, err := e.generateLeafSummary(ctx, messages, priorSummary)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create summary in store
|
||||
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
|
||||
|
||||
var earliestAt, latestAt *time.Time
|
||||
if len(messages) > 0 {
|
||||
earliestAt = &messages[0].CreatedAt
|
||||
latestAt = &messages[len(messages)-1].CreatedAt
|
||||
}
|
||||
|
||||
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: content,
|
||||
TokenCount: tokenCount,
|
||||
EarliestAt: earliestAt,
|
||||
LatestAt: latestAt,
|
||||
SourceMessageTokens: sumMessageTokens(messages),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Link to source messages
|
||||
msgIDs := make([]int64, len(messages))
|
||||
for i, m := range messages {
|
||||
msgIDs[i] = m.ID
|
||||
}
|
||||
if err := e.store.LinkSummaryToMessages(ctx, summary.SummaryID, msgIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Replace context range with summary
|
||||
if err := e.store.ReplaceContextRangeWithSummary(
|
||||
ctx, convID, chunk[0].Ordinal, chunk[len(chunk)-1].Ordinal, summary.SummaryID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &summary.SummaryID, nil
|
||||
}
|
||||
|
||||
// compactCondensed compresses multiple summaries into one higher-level summary.
|
||||
func (e *CompactionEngine) compactCondensed(ctx context.Context, convID int64) (*string, error) {
|
||||
// Try ordinal-aware selection first (respects consecutive ordering)
|
||||
var candidates []Summary
|
||||
|
||||
depths, err := e.store.GetDistinctDepthsInContext(ctx, convID, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, depth := range depths {
|
||||
var chunkAtDepth []Summary
|
||||
var err2 error
|
||||
chunkAtDepth, err2 = e.selectOldestChunkAtDepth(ctx, convID, depth)
|
||||
if err2 != nil {
|
||||
continue
|
||||
}
|
||||
if len(chunkAtDepth) > 0 {
|
||||
candidates = chunkAtDepth
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to depth-grouping selection
|
||||
if len(candidates) == 0 {
|
||||
candidates, err = e.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Generate condensed summary
|
||||
content, err := e.generateCondensedSummary(ctx, candidates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Merge metadata
|
||||
maxDepth := 0
|
||||
descendantCount := 0
|
||||
descendantTokenCount := 0
|
||||
sourceMessageTokens := 0
|
||||
var earliestAt, latestAt *time.Time
|
||||
|
||||
parentIDs := make([]string, len(candidates))
|
||||
for i, c := range candidates {
|
||||
parentIDs[i] = c.SummaryID
|
||||
if c.Depth > maxDepth {
|
||||
maxDepth = c.Depth
|
||||
}
|
||||
descendantCount += c.DescendantCount + 1
|
||||
descendantTokenCount += c.TokenCount + c.DescendantTokenCount
|
||||
sourceMessageTokens += c.SourceMessageTokenCount
|
||||
if c.EarliestAt != nil {
|
||||
if earliestAt == nil || c.EarliestAt.Before(*earliestAt) {
|
||||
earliestAt = c.EarliestAt
|
||||
}
|
||||
}
|
||||
if c.LatestAt != nil {
|
||||
if latestAt == nil || c.LatestAt.After(*latestAt) {
|
||||
latestAt = c.LatestAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenCount := tokenizer.EstimateMessageTokens(providers.Message{Content: content})
|
||||
|
||||
summary, err := e.store.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindCondensed,
|
||||
Depth: maxDepth + 1,
|
||||
Content: content,
|
||||
TokenCount: tokenCount,
|
||||
EarliestAt: earliestAt,
|
||||
LatestAt: latestAt,
|
||||
DescendantCount: descendantCount,
|
||||
DescendantTokenCount: descendantTokenCount,
|
||||
SourceMessageTokens: sourceMessageTokens,
|
||||
ParentIDs: parentIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Find the ordinal range for the candidate summaries in context
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
candidateSet := make(map[string]bool)
|
||||
for _, c := range candidates {
|
||||
candidateSet[c.SummaryID] = true
|
||||
}
|
||||
|
||||
startOrd := -1
|
||||
endOrd := -1
|
||||
hasNonCandidate := false
|
||||
for _, item := range items {
|
||||
if item.ItemType == "summary" && candidateSet[item.SummaryID] {
|
||||
if startOrd == -1 {
|
||||
startOrd, endOrd = item.Ordinal, item.Ordinal
|
||||
} else {
|
||||
// Check for non-candidate items between endOrd and current ordinal
|
||||
for _, it := range items {
|
||||
if it.Ordinal > endOrd && it.Ordinal <= item.Ordinal {
|
||||
if it.ItemType != "summary" || !candidateSet[it.SummaryID] {
|
||||
hasNonCandidate = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasNonCandidate {
|
||||
break
|
||||
}
|
||||
if item.Ordinal < startOrd {
|
||||
startOrd = item.Ordinal
|
||||
}
|
||||
if item.Ordinal > endOrd {
|
||||
endOrd = item.Ordinal
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if startOrd == -1 || endOrd == -1 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Collect candidate summary IDs
|
||||
candidateIDs := make([]string, 0, len(candidates))
|
||||
for _, c := range candidates {
|
||||
candidateIDs = append(candidateIDs, c.SummaryID)
|
||||
}
|
||||
|
||||
if hasNonCandidate {
|
||||
// Use safe per-item deletion to avoid deleting non-candidate items
|
||||
if err := e.store.ReplaceContextItemsWithSummary(ctx, convID, candidateIDs, summary.SummaryID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Candidates are consecutive, use efficient range deletion
|
||||
if err := e.store.ReplaceContextRangeWithSummary(ctx, convID, startOrd, endOrd, summary.SummaryID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &summary.SummaryID, nil
|
||||
}
|
||||
|
||||
// selectShallowestCondensationCandidate finds the shallowest consecutive summary group.
|
||||
func (e *CompactionEngine) selectShallowestCondensationCandidate(
|
||||
ctx context.Context, convID int64, forced bool,
|
||||
) ([]Summary, error) {
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Group by depth, find consecutive runs
|
||||
tailStartIdx := len(items) - FreshTailCount
|
||||
if tailStartIdx < 0 {
|
||||
tailStartIdx = 0
|
||||
}
|
||||
|
||||
minFanout := CondensedMinFanout
|
||||
if forced {
|
||||
minFanout = CondensedMinFanoutHard
|
||||
}
|
||||
|
||||
// Track depth groups
|
||||
depthGroups := make(map[int][]ContextItem)
|
||||
for i := 0; i < tailStartIdx; i++ {
|
||||
item := items[i]
|
||||
if item.ItemType != "summary" {
|
||||
continue
|
||||
}
|
||||
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
depthGroups[sum.Depth] = append(depthGroups[sum.Depth], item)
|
||||
}
|
||||
|
||||
// Find shallowest depth with enough candidates
|
||||
// Collect all depths and sort to handle non-consecutive depths
|
||||
var depths []int
|
||||
for depth := range depthGroups {
|
||||
depths = append(depths, depth)
|
||||
}
|
||||
sort.Ints(depths)
|
||||
|
||||
for _, depth := range depths {
|
||||
group := depthGroups[depth]
|
||||
if len(group) >= minFanout {
|
||||
// Load summaries
|
||||
var result []Summary
|
||||
for _, item := range group[:minFanout] {
|
||||
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, *sum)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// selectOldestChunkAtDepth scans context_items from oldest ordinal, collecting consecutive
|
||||
// summaries at the given depth. Stops at non-summary items, different depth, fresh tail, or
|
||||
// token overflow. Returns contiguous chunk of summaries.
|
||||
func (e *CompactionEngine) selectOldestChunkAtDepth(
|
||||
ctx context.Context, convID int64, targetDepth int,
|
||||
) ([]Summary, error) {
|
||||
items, err := e.store.GetContextItems(ctx, convID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tailStartIdx := len(items) - FreshTailCount
|
||||
if tailStartIdx < 0 {
|
||||
tailStartIdx = 0
|
||||
}
|
||||
|
||||
var chunk []Summary
|
||||
accumTokens := 0
|
||||
|
||||
for i := 0; i < tailStartIdx; i++ {
|
||||
item := items[i]
|
||||
if item.ItemType != "summary" {
|
||||
// Non-summary breaks the chunk
|
||||
break
|
||||
}
|
||||
sum, err := e.store.GetSummary(ctx, item.SummaryID)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if sum.Depth != targetDepth {
|
||||
// Different depth breaks the chunk
|
||||
break
|
||||
}
|
||||
if accumTokens+sum.TokenCount > LeafChunkTokens {
|
||||
// Token overflow stops collection
|
||||
break
|
||||
}
|
||||
chunk = append(chunk, *sum)
|
||||
accumTokens += sum.TokenCount
|
||||
}
|
||||
|
||||
// Min tokens check: spec line 808
|
||||
// chunk tokens must be >= max(CondensedTargetTokens, LeafChunkTokens × 0.1) = 2000
|
||||
minTokens := CondensedTargetTokens // 2000
|
||||
if accumTokens < minTokens {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
// generateLeafSummary calls the LLM to generate a leaf summary with 3-level escalation.
|
||||
// Level 1: normal LLM prompt. Level 2: aggressive prompt. Level 3: deterministic truncation.
|
||||
func (e *CompactionEngine) generateLeafSummary(
|
||||
ctx context.Context,
|
||||
messages []Message,
|
||||
previousSummary string,
|
||||
) (string, error) {
|
||||
if e.complete == nil {
|
||||
return truncateSummary(messages), nil
|
||||
}
|
||||
|
||||
sourceText := formatMessagesForSummary(messages)
|
||||
inputTokens := sumMessageTokens(messages)
|
||||
targetTokens := minInt(LeafTargetTokens, int(float64(inputTokens)*0.35))
|
||||
|
||||
// Level 1: normal prompt
|
||||
prompt := buildLeafSummaryPrompt(sourceText, previousSummary, targetTokens)
|
||||
content, err := e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: LeafTargetTokens * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content == "" {
|
||||
// Retry with temperature=0
|
||||
content, err = e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: LeafTargetTokens * 2,
|
||||
Temperature: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if level 1 succeeded
|
||||
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 2: aggressive prompt
|
||||
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
|
||||
aggressivePrompt := buildAggressiveLeafSummaryPrompt(sourceText, previousSummary, aggressiveTarget)
|
||||
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||
MaxTokens: aggressiveTarget * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content == "" {
|
||||
// Retry with temperature=0
|
||||
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||
MaxTokens: aggressiveTarget * 2,
|
||||
Temperature: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if content != "" && tokenizer.EstimateMessageTokens(providers.Message{Content: content}) < inputTokens {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 3: deterministic truncation
|
||||
return truncateSummary(messages), nil
|
||||
}
|
||||
|
||||
// generateCondensedSummary calls the LLM to generate a condensed summary with 3-level escalation.
|
||||
func (e *CompactionEngine) generateCondensedSummary(ctx context.Context, summaries []Summary) (string, error) {
|
||||
if e.complete == nil {
|
||||
return truncateCondensedSummaries(summaries), nil
|
||||
}
|
||||
|
||||
sourceText := formatSummariesForCondensation(summaries)
|
||||
inputTokens := sumSummaryTokens(summaries)
|
||||
targetTokens := minInt(CondensedTargetTokens, int(float64(inputTokens)*0.35))
|
||||
|
||||
// Level 1: normal prompt
|
||||
prompt := buildCondensedSummaryPrompt(sourceText, targetTokens)
|
||||
content, err := e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: CondensedTargetTokens * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content == "" {
|
||||
content, err = e.complete(ctx, prompt, CompleteOptions{
|
||||
MaxTokens: CondensedTargetTokens * 2,
|
||||
Temperature: 0,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if content != "" {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 2: aggressive prompt
|
||||
aggressiveTarget := minInt(640, int(float64(inputTokens)*0.20))
|
||||
aggressivePrompt := buildCondensedSummaryPrompt(sourceText, aggressiveTarget)
|
||||
content, err = e.complete(ctx, aggressivePrompt, CompleteOptions{
|
||||
MaxTokens: aggressiveTarget * 2,
|
||||
Temperature: 0.3,
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if content != "" {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// Level 3: deterministic fallback
|
||||
return truncateCondensedSummaries(summaries), nil
|
||||
}
|
||||
|
||||
// runCondensedLoop runs condensed compaction in a loop until:
|
||||
// a) context tokens <= threshold (success), OR
|
||||
// b) No candidate found (nothing to condense), OR
|
||||
// c) tokensAfter >= tokensBefore (no progress this iteration), OR
|
||||
// d) tokensAfter >= previousTokens (no improvement over last iteration)
|
||||
func (e *CompactionEngine) runCondensedLoop(ctx context.Context, convID int64) {
|
||||
var prevTokens int
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
tokensBefore, err := e.store.GetContextTokenCount(ctx, convID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("seahorse", "condensed: get tokens", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
condensedID, err := e.compactCondensed(ctx, convID)
|
||||
if err != nil {
|
||||
logger.ErrorCF("seahorse", "condensed: compact", map[string]any{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if condensedID == nil {
|
||||
// No candidate found
|
||||
logger.DebugCF("seahorse", "condensed: no candidate", map[string]any{"conv_id": convID})
|
||||
return
|
||||
}
|
||||
|
||||
tokensAfter, _ := e.store.GetContextTokenCount(ctx, convID)
|
||||
|
||||
if tokensAfter >= tokensBefore {
|
||||
// No progress this iteration
|
||||
logger.DebugCF(
|
||||
"seahorse",
|
||||
"condensed: no progress",
|
||||
map[string]any{"conv_id": convID, "tokens_before": tokensBefore, "tokens_after": tokensAfter},
|
||||
)
|
||||
return
|
||||
}
|
||||
if tokensAfter >= prevTokens && prevTokens > 0 {
|
||||
// No improvement over last iteration
|
||||
logger.DebugCF(
|
||||
"seahorse",
|
||||
"condensed: no improvement",
|
||||
map[string]any{"conv_id": convID, "tokens": tokensAfter},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
prevTokens = tokensAfter
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
func formatMessagesForSummary(messages []Message) string {
|
||||
var result string
|
||||
for _, m := range messages {
|
||||
ts := m.CreatedAt.Format("2006-01-02 15:04 MST")
|
||||
content := m.Content
|
||||
if content == "" && len(m.Parts) > 0 {
|
||||
content = partsToReadableContent(m.Parts)
|
||||
}
|
||||
result += fmt.Sprintf("[%s]\n%s\n\n", ts, content)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func formatSummariesForCondensation(summaries []Summary) string {
|
||||
var result string
|
||||
for _, s := range summaries {
|
||||
earliest := ""
|
||||
if s.EarliestAt != nil {
|
||||
earliest = s.EarliestAt.Format("2006-01-02")
|
||||
}
|
||||
latest := ""
|
||||
if s.LatestAt != nil {
|
||||
latest = s.LatestAt.Format("2006-01-02")
|
||||
}
|
||||
result += fmt.Sprintf("[%s - %s]\n%s\n\n", earliest, latest, s.Content)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
|
||||
prev := "(none)"
|
||||
if previousSummary != "" {
|
||||
prev = previousSummary
|
||||
}
|
||||
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
|
||||
Treat this as incremental memory compaction input, not a full-conversation summary.
|
||||
|
||||
Normal summary policy:
|
||||
- Preserve key decisions, rationale, constraints, and active tasks.
|
||||
- Keep essential technical details needed to continue work safely.
|
||||
- Remove obvious repetition and conversational filler.
|
||||
|
||||
Output requirements:
|
||||
- Plain text only.
|
||||
- No preamble, headings, or markdown formatting.
|
||||
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
|
||||
- If no file operations appear, include exactly: "Files: none".
|
||||
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
|
||||
- Target length: about %d tokens or less.
|
||||
|
||||
<previous_context>
|
||||
%s
|
||||
</previous_context>
|
||||
|
||||
<conversation_segment>
|
||||
%s
|
||||
</conversation_segment>`, targetTokens, prev, sourceText)
|
||||
}
|
||||
|
||||
func buildCondensedSummaryPrompt(sourceText string, targetTokens int) string {
|
||||
return fmt.Sprintf(`You condense multiple summaries into a single higher-level summary.
|
||||
Preserve all important decisions, constraints, and outcomes.
|
||||
Merge overlapping topics. Keep technical details intact.
|
||||
|
||||
Output requirements:
|
||||
- Plain text only.
|
||||
- No preamble, headings, or markdown formatting.
|
||||
- End with exactly: "Expand for details about: <comma-separated list>".
|
||||
- Target length: about %d tokens or less.
|
||||
|
||||
<summaries>
|
||||
%s
|
||||
</summaries>`, targetTokens, sourceText)
|
||||
}
|
||||
|
||||
func buildAggressiveLeafSummaryPrompt(sourceText, previousSummary string, targetTokens int) string {
|
||||
prev := "(none)"
|
||||
if previousSummary != "" {
|
||||
prev = previousSummary
|
||||
}
|
||||
return fmt.Sprintf(`You summarize a SEGMENT of a conversation for future model turns.
|
||||
Aggressive summary policy:
|
||||
- Keep only durable facts and current task state.
|
||||
- Remove examples, repetition, and low-value narrative details.
|
||||
- Preserve explicit TODOs, blockers, decisions, and constraints.
|
||||
|
||||
Output requirements:
|
||||
- Plain text only.
|
||||
- No preamble, headings, or markdown formatting.
|
||||
- Track file operations (created, modified, deleted, renamed) with file paths and current status.
|
||||
- If no file operations appear, include exactly: "Files: none".
|
||||
- End with exactly: "Expand for details about: <comma-separated list of what was dropped or compressed>".
|
||||
- Target length: about %d tokens or less.
|
||||
|
||||
<previous_context>
|
||||
%s
|
||||
</previous_context>
|
||||
|
||||
<conversation_segment>
|
||||
%s
|
||||
</conversation_segment>`, targetTokens, prev, sourceText)
|
||||
}
|
||||
|
||||
func truncateSummary(messages []Message) string {
|
||||
content := ""
|
||||
for _, m := range messages {
|
||||
c := m.Content
|
||||
if c == "" && len(m.Parts) > 0 {
|
||||
c = partsToReadableContent(m.Parts)
|
||||
}
|
||||
content += c + "\n"
|
||||
}
|
||||
if len(content) > 2048 {
|
||||
content = content[:2048]
|
||||
}
|
||||
content += fmt.Sprintf("\n[Truncated from %d messages]", len(messages))
|
||||
return content
|
||||
}
|
||||
|
||||
func truncateCondensedSummaries(summaries []Summary) string {
|
||||
content := ""
|
||||
for _, s := range summaries {
|
||||
content += s.Content + "\n"
|
||||
}
|
||||
if len(content) > 2048 {
|
||||
content = content[:2048]
|
||||
}
|
||||
content += fmt.Sprintf("\n[Condensed from %d summaries]", len(summaries))
|
||||
return content
|
||||
}
|
||||
|
||||
func sumMessageTokens(messages []Message) int {
|
||||
total := 0
|
||||
for _, m := range messages {
|
||||
total += m.TokenCount
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func sumSummaryTokens(summaries []Summary) int {
|
||||
total := 0
|
||||
for _, s := range summaries {
|
||||
total += s.TokenCount
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -0,0 +1,974 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Test Helpers ---
|
||||
|
||||
// waitForCondensed blocks until the async condensed goroutine for convID finishes.
|
||||
// Returns false if timeout is reached.
|
||||
func waitForCondensed(ce *CompactionEngine, convID int64, timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, exists := ce.condensing.Load(convID); !exists {
|
||||
return true
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// --- Compaction Tests ---
|
||||
|
||||
func newTestCompactionEngine(t *testing.T) (*CompactionEngine, *Store, int64) {
|
||||
t.Helper()
|
||||
db := openTestDB(t)
|
||||
if err := runSchema(db); err != nil {
|
||||
t.Fatalf("migration: %v", err)
|
||||
}
|
||||
s := &Store{db: db}
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:compact")
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
ce := &CompactionEngine{
|
||||
store: s,
|
||||
config: Config{},
|
||||
complete: mockCompleteFn,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
}
|
||||
convID := conv.ConversationID
|
||||
// Ensure async goroutines are stopped before database is closed.
|
||||
// Register cleanup here (after openTestDB) so it runs BEFORE openTestDB's db.Close().
|
||||
t.Cleanup(func() {
|
||||
shutdownCancel()
|
||||
// Wait for async condensed goroutine to finish (poll condensing map)
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, exists := ce.condensing.Load(convID); !exists {
|
||||
break
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
})
|
||||
return ce, s, conv.ConversationID
|
||||
}
|
||||
|
||||
// newTestCompactionEngineWithStore creates a CompactionEngine with existing store.
|
||||
// Note: Caller is responsible for calling shutdownCancel when test ends.
|
||||
func newTestCompactionEngineWithStore(
|
||||
s *Store, complete CompleteFn,
|
||||
) (ce *CompactionEngine, shutdownCancel context.CancelFunc) {
|
||||
shutdownCtx, cancel := context.WithCancel(context.Background())
|
||||
return &CompactionEngine{
|
||||
store: s,
|
||||
config: Config{},
|
||||
complete: complete,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: cancel,
|
||||
}, cancel
|
||||
}
|
||||
|
||||
// mockCompleteFn returns a simple summary for testing
|
||||
var mockCompleteFn CompleteFn = func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "Mock summary of the conversation segment.", nil
|
||||
}
|
||||
|
||||
func TestNeedsCompaction(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Empty context — no compaction needed
|
||||
needed, err := ce.NeedsCompaction(ctx, convID, 10000)
|
||||
if err != nil {
|
||||
t.Fatalf("NeedsCompaction: %v", err)
|
||||
}
|
||||
if needed {
|
||||
t.Error("expected no compaction for empty context")
|
||||
}
|
||||
|
||||
// Add messages to context, total tokens = 8000
|
||||
for i := 0; i < 8; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "test message content", 1000)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Threshold = 0.75 × 10000 = 7500. We have 8000 tokens → needs compaction
|
||||
needed, err = ce.NeedsCompaction(ctx, convID, 10000)
|
||||
if err != nil {
|
||||
t.Fatalf("NeedsCompaction: %v", err)
|
||||
}
|
||||
if !needed {
|
||||
t.Error("expected compaction needed at 8000/10000 tokens (threshold 75%)")
|
||||
}
|
||||
|
||||
// Below threshold: 5000 / 10000 → no compaction
|
||||
s.UpsertContextItems(ctx, convID, nil) // clear
|
||||
for i := 0; i < 5; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "test", 1000)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
needed, _ = ce.NeedsCompaction(ctx, convID, 10000)
|
||||
if needed {
|
||||
t.Error("expected no compaction at 5000/10000 tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeaf(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create enough messages to trigger leaf compaction:
|
||||
// Need > FreshTailCount(32) evictable messages with >= LeafMinFanout(8) contiguous
|
||||
for i := 0; i < 40; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "message content for compaction test", 100)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Compact
|
||||
result, err := ce.Compact(ctx, convID, CompactInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Should have created at least one leaf summary
|
||||
if result.LeafSummaries == 0 {
|
||||
t.Error("expected at least 1 leaf summary")
|
||||
}
|
||||
|
||||
// Context should now contain a summary item
|
||||
items, _ := s.GetContextItems(ctx, convID)
|
||||
foundSummary := false
|
||||
for _, item := range items {
|
||||
if item.ItemType == "summary" {
|
||||
foundSummary = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundSummary {
|
||||
t.Error("expected a summary in context_items after leaf compaction")
|
||||
}
|
||||
|
||||
// Some messages should have been replaced
|
||||
if len(result.SummariesCreated) == 0 {
|
||||
t.Error("expected at least 1 summary created")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeafNoCandidate(t *testing.T) {
|
||||
ce, _, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Too few messages to trigger leaf compaction
|
||||
m, _ := ce.store.AddMessage(ctx, convID, "user", "short", 10)
|
||||
ce.store.AppendContextMessage(ctx, convID, m.ID)
|
||||
|
||||
result, err := ce.Compact(ctx, convID, CompactInput{})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result even with no candidate")
|
||||
}
|
||||
if result.LeafSummaries != 0 {
|
||||
t.Errorf("LeafSummaries = %d, want 0 (too few messages)", result.LeafSummaries)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensed(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create enough leaf summaries and fresh messages to enable condensation
|
||||
leafIDs := make([]string, CondensedMinFanout)
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, err := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf summary content " + time.Now().String(),
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSummary %d: %v", i, err)
|
||||
}
|
||||
leafIDs[i] = summary.SummaryID
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add enough fresh messages to have a fresh tail (>= FreshTailCount)
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh message", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Compact with force to trigger condensation
|
||||
_, err := ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
|
||||
// Wait for async condensed goroutine to complete
|
||||
if !waitForCondensed(ce, convID, 2*time.Second) {
|
||||
t.Fatal("timeout waiting for condensed compaction")
|
||||
}
|
||||
|
||||
// Should have created a condensed summary in the DB
|
||||
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||
foundCondensed := false
|
||||
for _, sum := range summaries {
|
||||
if sum.Kind == SummaryKindCondensed {
|
||||
foundCondensed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundCondensed {
|
||||
t.Error("expected at least 1 condensed summary")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensedDoesNotOrphanSummaryWhenCandidatesRemovedConcurrently(t *testing.T) {
|
||||
// Reproduce orphan bug: candidates found by selectOldestChunkAtDepth are removed
|
||||
// from context_items between candidate selection and ordinal range scan.
|
||||
// Use a slow CompleteFn with barrier sync to control timing.
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:orphan-race")
|
||||
convID := conv.ConversationID
|
||||
|
||||
// Create leaf summaries with enough tokens for condensation
|
||||
var leafIDs []string
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
sum, err := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf summary %d", i),
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSummary: %v", err)
|
||||
}
|
||||
leafIDs = append(leafIDs, sum.SummaryID)
|
||||
s.AppendContextSummary(ctx, convID, sum.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail so leaf summaries are in evictable range
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Barrier: CompleteFn waits until test removes context_items, then returns
|
||||
var barrier1, barrier2 sync.WaitGroup
|
||||
barrier1.Add(1) // CompleteFn signals when called
|
||||
barrier2.Add(1) // test signals when context_items removed
|
||||
|
||||
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
barrier1.Done() // signal: LLM called, candidates selected
|
||||
barrier2.Wait() // wait: test removes context_items
|
||||
return "Condensed summary.", nil
|
||||
}
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
// Run compactCondensed in background
|
||||
type compactResult struct {
|
||||
summaryID *string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan compactResult, 1)
|
||||
go func() {
|
||||
sid, err := ce.compactCondensed(context.Background(), convID)
|
||||
resultCh <- compactResult{summaryID: sid, err: err}
|
||||
}()
|
||||
|
||||
// Wait for CompleteFn to be called (candidates selected)
|
||||
barrier1.Wait()
|
||||
|
||||
// Remove leaf summaries from context_items (simulating concurrent replacement)
|
||||
items, _ := s.GetContextItems(ctx, convID)
|
||||
var preserved []ContextItem
|
||||
for _, item := range items {
|
||||
isLeaf := false
|
||||
for _, lid := range leafIDs {
|
||||
if item.SummaryID == lid {
|
||||
isLeaf = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLeaf {
|
||||
preserved = append(preserved, item)
|
||||
}
|
||||
}
|
||||
s.UpsertContextItems(ctx, convID, preserved)
|
||||
|
||||
// Let CompleteFn return
|
||||
barrier2.Done()
|
||||
|
||||
// Get result
|
||||
res := <-resultCh
|
||||
if res.err != nil {
|
||||
t.Fatalf("compactCondensed: %v", res.err)
|
||||
}
|
||||
|
||||
// With the bug: returns non-nil summaryID even though context_items has no matching ordinals
|
||||
// The fix: should return nil when startOrd == -1
|
||||
if res.summaryID != nil {
|
||||
t.Errorf("compactCondensed returned summaryID=%s, want nil (orphan created)", *res.summaryID)
|
||||
|
||||
// Verify the orphan exists in DB
|
||||
summary, _ := s.GetSummary(context.Background(), *res.summaryID)
|
||||
if summary != nil && summary.Kind == SummaryKindCondensed {
|
||||
// Check it's NOT in context_items (orphan)
|
||||
items2, _ := s.GetContextItems(context.Background(), convID)
|
||||
found := false
|
||||
for _, item := range items2 {
|
||||
if item.SummaryID == *res.summaryID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("condensed summary exists in DB but not in context_items — orphan confirmed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactUntilUnder(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create many leaf summaries to ensure we can condense
|
||||
for i := 0; i < 8; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf summary for condensation test",
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Force compact until under budget
|
||||
result, err := ce.CompactUntilUnder(ctx, convID, 2000)
|
||||
if err != nil {
|
||||
t.Fatalf("CompactUntilUnder: %v", err)
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectShallowestCondensationCandidate(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create enough leaf summaries + fresh messages for candidates
|
||||
for i := 0; i < LeafMinFanout; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf",
|
||||
TokenCount: 100,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail messages so summaries are in evictable range
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||
if err != nil {
|
||||
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
|
||||
}
|
||||
|
||||
// Should find leaf summaries at depth 0
|
||||
if len(candidates) < CondensedMinFanout {
|
||||
t.Errorf("candidates = %d, want >= %d", len(candidates), CondensedMinFanout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectShallowestCondensationCandidateEmpty(t *testing.T) {
|
||||
ce, _, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
candidates, err := ce.selectShallowestCondensationCandidate(ctx, convID, false)
|
||||
if err != nil {
|
||||
t.Fatalf("selectShallowestCondensationCandidate: %v", err)
|
||||
}
|
||||
if len(candidates) != 0 {
|
||||
t.Errorf("candidates = %d, want 0 for empty context", len(candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensedUsesSelectOldestChunk(t *testing.T) {
|
||||
// Verify that compactCondensed prefers ordinal-ordered chunks via selectOldestChunkAtDepth
|
||||
// rather than just grouping by depth without regard to order
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create interleaved summaries at depth 0 with a message in between:
|
||||
// sum1 (ordinal 100), msg (ordinal 200), sum2 (ordinal 300)
|
||||
|
||||
for i := 0; i < LeafMinFanout+2; i++ {
|
||||
now := time.Now().UTC()
|
||||
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf summary %d", i),
|
||||
TokenCount: 100,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
}
|
||||
|
||||
// Insert a message between first two summaries to break contiguity
|
||||
// for selectShallowestCondensationCandidate but would still find all 3
|
||||
// but selectOldestChunkAtDepth should only find sum1 + sum2 (not sum3)
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "interrupting message", 5)
|
||||
s.AppendContextMessage(ctx, convID, msg.ID)
|
||||
|
||||
// Run compactCondensed
|
||||
result, err := ce.compactCondensed(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("compactCondensed: %v", err)
|
||||
}
|
||||
|
||||
// The result should have merged the two summaries at the start
|
||||
// (skipping the message in between), This proves ordinal-aware selection works.
|
||||
|
||||
_ = result // verify summary was created
|
||||
|
||||
if result != nil {
|
||||
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||
found := false
|
||||
for _, sum := range summaries {
|
||||
if sum.Kind == SummaryKindCondensed {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected condensed summary to be created via ordinal-aware selection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactCondensedUsesOrdinalAwareSelection(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create leaf summaries at depth 0 (total tokens >= CondensedTargetTokens)
|
||||
for i := 0; i < 5; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf summary %d", i),
|
||||
TokenCount: 500, // 5 × 500 = 2500 >= CondensedTargetTokens (2000)
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||
}
|
||||
if len(chunk) < 2 {
|
||||
t.Errorf("chunk length = %d, want >= 2 contiguous summaries", len(chunk))
|
||||
}
|
||||
for _, s := range chunk {
|
||||
if s.Depth != 0 {
|
||||
t.Errorf("got depth %d, want 0", s.Depth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOldestChunkAtDepthBreaksOnMessage(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 3 summaries, then a message, then 3 more summaries
|
||||
for i := 0; i < 3; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf %d", i),
|
||||
TokenCount: 100,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "break", 10)
|
||||
s.AppendContextMessage(ctx, convID, msg.ID)
|
||||
for i := 0; i < 3; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("leaf-after %d", i),
|
||||
TokenCount: 100,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 5)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
chunk, _ := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if len(chunk) > 3 {
|
||||
t.Errorf("chunk length = %d, want <= 3 (message breaks chain)", len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOldestChunkAtDepthMinTokens(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create summaries with very low token counts (total < 2000)
|
||||
for i := 0; i < 5; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("tiny summary %d", i),
|
||||
TokenCount: 50, // very small
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail to protect from compaction
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Should return nil because total tokens (250) < 2000 minimum
|
||||
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||
}
|
||||
if len(chunk) > 0 {
|
||||
t.Errorf("expected empty chunk when tokens < 2000, got %d summaries", len(chunk))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectOldestChunkAtDepthPassesMinTokens(t *testing.T) {
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create summaries with enough tokens (total >= 2000)
|
||||
for i := 0; i < 5; i++ {
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf(
|
||||
"substantial summary with enough content to meet minimum token threshold for condensation candidate %d",
|
||||
i,
|
||||
),
|
||||
TokenCount: 500, // 5 × 500 = 2500 >= 2000
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
|
||||
// Add fresh tail
|
||||
for i := 0; i < FreshTailCount+1; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("tail %d", i), 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Should return chunk because total tokens (2500) >= 2000
|
||||
chunk, err := ce.selectOldestChunkAtDepth(ctx, convID, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("selectOldestChunkAtDepth: %v", err)
|
||||
}
|
||||
if len(chunk) == 0 {
|
||||
t.Error("expected non-empty chunk when tokens >= 2000")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLeafSummary(t *testing.T) {
|
||||
ce, _, _ := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msgs := []Message{
|
||||
{Role: "user", Content: "hello world", TokenCount: 5},
|
||||
{Role: "assistant", Content: "hi there", TokenCount: 5},
|
||||
}
|
||||
|
||||
content, err := ce.generateLeafSummary(ctx, msgs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("generateLeafSummary: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty summary content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLeafSummaryEscalationToAggressive(t *testing.T) {
|
||||
// Level 1 returns summary that's too large (tokens >= input), should escalate to level 2
|
||||
var calls []string
|
||||
escalateComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
if contains(prompt, "Aggressive summary policy") {
|
||||
calls = append(calls, "aggressive")
|
||||
return "Short aggressive summary.", nil
|
||||
}
|
||||
calls = append(calls, "normal")
|
||||
// Return a very long summary to trigger escalation
|
||||
longContent := make([]byte, 5000)
|
||||
for i := range longContent {
|
||||
longContent[i] = 'x'
|
||||
}
|
||||
return string(longContent), nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ce, _ := newTestCompactionEngineWithStore(s, escalateComplete)
|
||||
|
||||
msgs := []Message{
|
||||
{Role: "user", Content: "hello world", TokenCount: 10},
|
||||
{Role: "assistant", Content: "response", TokenCount: 10},
|
||||
}
|
||||
|
||||
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("generateLeafSummary: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty summary content")
|
||||
}
|
||||
// Should have called both normal and aggressive
|
||||
foundNormal := false
|
||||
foundAggressive := false
|
||||
for _, c := range calls {
|
||||
if c == "normal" {
|
||||
foundNormal = true
|
||||
}
|
||||
if c == "aggressive" {
|
||||
foundAggressive = true
|
||||
}
|
||||
}
|
||||
if !foundNormal {
|
||||
t.Error("expected normal LLM call")
|
||||
}
|
||||
if !foundAggressive {
|
||||
t.Error("expected aggressive LLM call (level 2 escalation)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLeafSummaryEscalationToTruncation(t *testing.T) {
|
||||
// Both normal and aggressive return empty, should escalate to level 3 truncation
|
||||
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
|
||||
|
||||
msgs := []Message{
|
||||
{Role: "user", Content: "hello world from test", TokenCount: 10},
|
||||
{Role: "assistant", Content: "response text here", TokenCount: 10},
|
||||
}
|
||||
|
||||
content, err := ce.generateLeafSummary(context.Background(), msgs, "")
|
||||
if err != nil {
|
||||
t.Fatalf("generateLeafSummary: %v", err)
|
||||
}
|
||||
// Level 3 truncation should have produced something
|
||||
if content == "" {
|
||||
t.Error("expected non-empty content from level 3 truncation fallback")
|
||||
}
|
||||
if !contains(content, "Truncated from") {
|
||||
t.Errorf("expected truncation marker in content: %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCondensedSummary(t *testing.T) {
|
||||
ce, _, _ := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
summaries := []Summary{
|
||||
{SummaryID: "sum_a", Content: "first summary", TokenCount: 100},
|
||||
{SummaryID: "sum_b", Content: "second summary", TokenCount: 100},
|
||||
}
|
||||
|
||||
content, err := ce.generateCondensedSummary(ctx, summaries)
|
||||
if err != nil {
|
||||
t.Fatalf("generateCondensedSummary: %v", err)
|
||||
}
|
||||
if content == "" {
|
||||
t.Error("expected non-empty condensed summary content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCondensedSummaryEscalation(t *testing.T) {
|
||||
// When LLM returns empty, should fall back to deterministic concatenation
|
||||
emptyComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ce, _ := newTestCompactionEngineWithStore(s, emptyComplete)
|
||||
|
||||
summaries := []Summary{
|
||||
{SummaryID: "sum_a", Content: "first summary text", TokenCount: 50},
|
||||
{SummaryID: "sum_b", Content: "second summary text", TokenCount: 50},
|
||||
}
|
||||
|
||||
content, err := ce.generateCondensedSummary(context.Background(), summaries)
|
||||
if err != nil {
|
||||
t.Fatalf("generateCondensedSummary: %v", err)
|
||||
}
|
||||
// Should fall back to concatenation
|
||||
if content == "" {
|
||||
t.Error("expected non-empty content from fallback")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Async Condensed Compaction (Phase 2) ---
|
||||
|
||||
func TestCompactAsyncReturnsBeforeCondensed(t *testing.T) {
|
||||
// Use a slow CompleteFn to verify Compact returns before condensed finishes
|
||||
var callCount int32
|
||||
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
time.Sleep(500 * time.Millisecond) // simulate slow LLM
|
||||
return "Slow condensed summary.", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:async")
|
||||
convID := conv.ConversationID
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
// Create enough leaf summaries for condensation + fresh tail
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf for async test",
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Compact with force — should return quickly, condensed runs async
|
||||
start := time.Now()
|
||||
result, err := ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Compact: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
// Should return well before the 500ms LLM call
|
||||
if elapsed > 200*time.Millisecond {
|
||||
t.Errorf("Compact took %v, should return before async condensed finishes", elapsed)
|
||||
}
|
||||
|
||||
// Wait for async to complete
|
||||
time.Sleep(800 * time.Millisecond)
|
||||
|
||||
// Verify condensed summary was created by background goroutine
|
||||
summaries, _ := s.GetSummariesByConversation(ctx, convID)
|
||||
foundCondensed := false
|
||||
for _, sum := range summaries {
|
||||
if sum.Kind == SummaryKindCondensed {
|
||||
foundCondensed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundCondensed {
|
||||
t.Error("expected at least one condensed summary from async Phase 2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactAsyncDedup(t *testing.T) {
|
||||
var callCount int32
|
||||
slowComplete := func(ctx context.Context, prompt string, opts CompleteOptions) (string, error) {
|
||||
atomic.AddInt32(&callCount, 1)
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
return "Slow condensed summary.", nil
|
||||
}
|
||||
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:dedup")
|
||||
convID := conv.ConversationID
|
||||
|
||||
ce, cancel := newTestCompactionEngineWithStore(s, slowComplete)
|
||||
t.Cleanup(func() {
|
||||
cancel()
|
||||
waitForCondensed(ce, convID, 2*time.Second)
|
||||
})
|
||||
|
||||
// Create conditions for condensed compaction
|
||||
for i := 0; i < CondensedMinFanout; i++ {
|
||||
now := time.Now().UTC()
|
||||
summary, _ := s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "leaf for dedup",
|
||||
TokenCount: 500,
|
||||
EarliestAt: &now,
|
||||
LatestAt: &now,
|
||||
})
|
||||
s.AppendContextSummary(ctx, convID, summary.SummaryID)
|
||||
}
|
||||
for i := 0; i < FreshTailCount; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", "fresh", 10)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Call Compact twice rapidly
|
||||
ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
ce.Compact(ctx, convID, CompactInput{Force: true})
|
||||
|
||||
// Wait for async to finish
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// LLM should only be called once for condensed (dedup)
|
||||
// callCount may be 0 if no leaf was created (only condensed in goroutine)
|
||||
// The key is that we don't get 2+ condensed calls
|
||||
if atomic.LoadInt32(&callCount) > 1 {
|
||||
t.Errorf("LLM called %d times, expected at most 1 (dedup)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeafForceBypassesFreshTail(t *testing.T) {
|
||||
// Spec: compactLeaf with force=true should bypass FreshTailCount protection
|
||||
// so CompactUntilUnder can compress messages inside the fresh tail
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create exactly FreshTailCount+4 messages (36 total)
|
||||
// Without force: all messages are in fresh tail → no candidate
|
||||
// With force: should compact the oldest messages
|
||||
total := FreshTailCount + 4
|
||||
for i := 0; i < total; i++ {
|
||||
m, _ := s.AddMessage(ctx, convID, "user", fmt.Sprintf("message %d for force test", i), 100)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
// Without force: should return nil (all in fresh tail)
|
||||
summaryID, err := ce.compactLeaf(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("compactLeaf no-force: %v", err)
|
||||
}
|
||||
if summaryID != nil {
|
||||
t.Error("expected nil without force (all messages in fresh tail)")
|
||||
}
|
||||
|
||||
// With force: should compact despite fresh tail protection
|
||||
summaryID, err = ce.compactLeaf(ctx, convID, true)
|
||||
if err != nil {
|
||||
t.Fatalf("compactLeaf force: %v", err)
|
||||
}
|
||||
if summaryID == nil {
|
||||
t.Error("expected summary with force=true (bypasses fresh tail)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactLeafAccumulatesUpToLeafChunkTokens(t *testing.T) {
|
||||
// Spec: compactLeaf should accumulate messages up to LeafChunkTokens before stopping
|
||||
// It should NOT take the entire contiguous chunk regardless of token count
|
||||
ce, s, convID := newTestCompactionEngine(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create messages totaling far more than LeafChunkTokens (20000)
|
||||
// Each message is ~500 tokens, create 80 messages = 40000 tokens
|
||||
for i := 0; i < 80; i++ {
|
||||
m, _ := s.AddMessage(
|
||||
ctx,
|
||||
convID,
|
||||
"user",
|
||||
fmt.Sprintf(
|
||||
"message %d with lots of content to make it big enough for token counting purposes and this should be a substantial message body that represents a meaningful conversation turn",
|
||||
i,
|
||||
),
|
||||
500,
|
||||
)
|
||||
s.AppendContextMessage(ctx, convID, m.ID)
|
||||
}
|
||||
|
||||
summaryID, err := ce.compactLeaf(ctx, convID)
|
||||
if err != nil {
|
||||
t.Fatalf("compactLeaf: %v", err)
|
||||
}
|
||||
if summaryID == nil {
|
||||
t.Fatal("expected a summary to be created")
|
||||
}
|
||||
|
||||
// The source messages that were compacted should total roughly LeafChunkTokens (20000),
|
||||
// not the entire 40000 tokens worth of messages
|
||||
summary, _ := s.GetSummary(ctx, *summaryID)
|
||||
if summary == nil {
|
||||
t.Fatal("summary not found")
|
||||
}
|
||||
|
||||
// Source message tokens should be roughly <= LeafChunkTokens (20000)
|
||||
// Spec says: "Stop when accumulated tokens >= LeafChunkTokens"
|
||||
if summary.SourceMessageTokenCount > LeafChunkTokens {
|
||||
t.Errorf("source tokens = %d, should be <= LeafChunkTokens (%d)",
|
||||
summary.SourceMessageTokenCount, LeafChunkTokens)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package seahorse
|
||||
|
||||
// Short-term memory configuration constants — all are experience-based defaults.
|
||||
|
||||
const (
|
||||
// OrdinalStep is the gap between ordinals in context_items.
|
||||
// Insert at midpoint; resequence only when precision exhausted.
|
||||
OrdinalStep = 100
|
||||
|
||||
// ContextThreshold is the compaction trigger for the context window.
|
||||
ContextThreshold float64 = 0.75 // Compact at 75% of context window
|
||||
FreshTailCount int = 32 // Recent messages protected from compaction
|
||||
|
||||
// LeafMinFanout is the fanout parameter.
|
||||
LeafMinFanout int = 8 // Min messages per leaf summary
|
||||
CondensedMinFanout int = 4 // Min summaries per condensed
|
||||
CondensedMinFanoutHard int = 2 // Min for forced compaction
|
||||
|
||||
// LeafChunkTokens is the token target.
|
||||
LeafChunkTokens int = 20000 // Max tokens per leaf chunk
|
||||
LeafTargetTokens int = 1200 // Target tokens for leaf summaries
|
||||
CondensedTargetTokens int = 2000 // Target tokens for condensed summaries
|
||||
MaxExpandTokens int = 4000 // Token cap for expansion queries
|
||||
|
||||
// MaxCompactIterations caps CompactUntilUnder to prevent infinite loops.
|
||||
// Each iteration reduces ~4x tokens via leaf (8:1) or condensed (4:1) compaction.
|
||||
// With a 200k token context window and 75% threshold, ~20 iterations is enough
|
||||
// for any realistic scenario. If exceeded, the issue is logged as a warning.
|
||||
MaxCompactIterations int = 20
|
||||
)
|
||||
@@ -0,0 +1,568 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// Config holds engine configuration.
|
||||
type Config struct {
|
||||
DBPath string `json:"dbPath"`
|
||||
IgnoreSessionPatterns []string `json:"ignoreSessionPatterns,omitempty"`
|
||||
StatelessSessionPatterns []string `json:"statelessSessionPatterns,omitempty"`
|
||||
}
|
||||
|
||||
// CompleteFn is the LLM completion function type.
|
||||
type CompleteFn func(ctx context.Context, prompt string, opts CompleteOptions) (string, error)
|
||||
|
||||
// CompleteOptions holds LLM completion parameters.
|
||||
type CompleteOptions struct {
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
}
|
||||
|
||||
// IngestResult is the result of message ingestion.
|
||||
type IngestResult struct {
|
||||
MessageCount int `json:"messageCount"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// AssembleInput controls context assembly.
|
||||
type AssembleInput struct {
|
||||
Budget int `json:"budget"`
|
||||
Query string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
// AssembleResult contains assembled context.
|
||||
type AssembleResult struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Summary string `json:"summary"` // formatted XML summaries + system prompt addition
|
||||
}
|
||||
|
||||
const numSessionShards = 256
|
||||
|
||||
// Engine is the main short-term memory engine.
|
||||
type Engine struct {
|
||||
store *Store
|
||||
compaction *CompactionEngine
|
||||
compactionMu sync.Mutex
|
||||
assembler *Assembler
|
||||
assemblerMu sync.Mutex
|
||||
retrieval *RetrievalEngine
|
||||
config Config
|
||||
complete CompleteFn
|
||||
ignorePatterns []*regexp.Regexp
|
||||
statelessPatterns []*regexp.Regexp
|
||||
sessionShards [numSessionShards]struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
}
|
||||
|
||||
// CompactionEngine handles LLM-based summarization (defined in short_compaction.go).
|
||||
type CompactionEngine struct {
|
||||
store *Store
|
||||
config Config
|
||||
complete CompleteFn
|
||||
condensing sync.Map // map[int64]struct{} — dedup for async condensed goroutines
|
||||
shutdownCtx context.Context
|
||||
shutdownCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Assembler handles budget-aware context assembly (defined in short_assembler.go).
|
||||
type Assembler struct {
|
||||
store *Store
|
||||
config Config
|
||||
}
|
||||
|
||||
// RetrievalEngine handles search and expansion (defined in short_retrieval.go).
|
||||
type RetrievalEngine struct {
|
||||
store *Store
|
||||
config Config
|
||||
}
|
||||
|
||||
// Store returns the underlying store for direct access.
|
||||
func (r *RetrievalEngine) Store() *Store {
|
||||
return r.store
|
||||
}
|
||||
|
||||
// NewEngine creates a new short-term memory engine.
|
||||
func NewEngine(config Config, completeFn CompleteFn) (*Engine, error) {
|
||||
dir := filepath.Dir(config.DBPath)
|
||||
if dir != "" && dir != "." {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("create db directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", config.DBPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open db: %w", err)
|
||||
}
|
||||
|
||||
// Configure SQLite for concurrent access
|
||||
if _, err := db.Exec("PRAGMA journal_mode = WAL;"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("enable WAL: %w", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA busy_timeout = 5000;"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("set busy_timeout: %w", err)
|
||||
}
|
||||
if _, err := db.Exec("PRAGMA synchronous = NORMAL;"); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("set synchronous: %w", err)
|
||||
}
|
||||
|
||||
if err := runSchema(db); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("migrations: %w", err)
|
||||
}
|
||||
|
||||
store := &Store{db: db}
|
||||
|
||||
// Prepend hardcoded ignore patterns (spec lines 1326-1328)
|
||||
ignorePatterns := make([]string, 0, 1+len(config.IgnoreSessionPatterns))
|
||||
ignorePatterns = append(ignorePatterns, "heartbeat")
|
||||
ignorePatterns = append(ignorePatterns, config.IgnoreSessionPatterns...)
|
||||
|
||||
retrieval := &RetrievalEngine{store: store, config: config}
|
||||
|
||||
return &Engine{
|
||||
store: store,
|
||||
compaction: nil,
|
||||
assembler: nil,
|
||||
retrieval: retrieval,
|
||||
config: config,
|
||||
complete: completeFn,
|
||||
ignorePatterns: compileSessionPatterns(ignorePatterns),
|
||||
statelessPatterns: compileSessionPatterns(config.StatelessSessionPatterns),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// compileSessionPattern converts a glob pattern to a compiled regex.
|
||||
// Pattern rules:
|
||||
// - * matches any sequence of non-colon characters ([^:]*)
|
||||
// - ** matches any sequence of characters including colons (.*)
|
||||
// - All other characters are treated literally
|
||||
// - Pattern is anchored (^...$)
|
||||
func compileSessionPattern(pattern string) *regexp.Regexp {
|
||||
var b strings.Builder
|
||||
b.WriteByte('^')
|
||||
|
||||
i := 0
|
||||
for i < len(pattern) {
|
||||
if i+1 < len(pattern) && pattern[i] == '*' && pattern[i+1] == '*' {
|
||||
b.WriteString(".*")
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if pattern[i] == '*' {
|
||||
b.WriteString("[^:]*")
|
||||
i++
|
||||
continue
|
||||
}
|
||||
b.WriteString(regexp.QuoteMeta(string(pattern[i])))
|
||||
i++
|
||||
}
|
||||
|
||||
b.WriteByte('$')
|
||||
return regexp.MustCompile(b.String())
|
||||
}
|
||||
|
||||
// compileSessionPatterns compiles multiple glob patterns into regex patterns.
|
||||
func compileSessionPatterns(patterns []string) []*regexp.Regexp {
|
||||
result := make([]*regexp.Regexp, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, compileSessionPattern(p))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// shouldIgnoreSession returns true if the session key matches any ignore pattern.
|
||||
func (e *Engine) shouldIgnoreSession(sessionKey string) bool {
|
||||
for _, p := range e.ignorePatterns {
|
||||
if p.MatchString(sessionKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isStatelessSession returns true if the session key matches any stateless pattern.
|
||||
func (e *Engine) isStatelessSession(sessionKey string) bool {
|
||||
for _, p := range e.statelessPatterns {
|
||||
if p.MatchString(sessionKey) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// fnv32 computes FNV-1a 32-bit hash for session key sharding.
|
||||
func fnv32(key string) uint32 {
|
||||
h := uint32(2166136261)
|
||||
for _, c := range key {
|
||||
h ^= uint32(c)
|
||||
h *= 16777619
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// getSessionMutex returns the sharded mutex for a session key.
|
||||
func (e *Engine) getSessionMutex(sessionKey string) *sync.Mutex {
|
||||
h := fnv32(sessionKey)
|
||||
shard := h % numSessionShards
|
||||
return &e.sessionShards[shard].mu
|
||||
}
|
||||
|
||||
// Ingest adds messages to a conversation identified by sessionKey.
|
||||
func (e *Engine) Ingest(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) {
|
||||
return nil, nil
|
||||
}
|
||||
if e.isStatelessSession(sessionKey) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mu := e.getSessionMutex(sessionKey)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
var totalTokens int
|
||||
var msgIDs []int64
|
||||
for _, msg := range messages {
|
||||
var added *Message
|
||||
var err error
|
||||
if len(msg.Parts) > 0 {
|
||||
added, err = e.store.AddMessageWithParts(ctx, conv.ConversationID, msg.Role, msg.Parts, msg.TokenCount)
|
||||
} else {
|
||||
added, err = e.store.AddMessage(ctx, conv.ConversationID, msg.Role, msg.Content, msg.TokenCount)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add message: %w", err)
|
||||
}
|
||||
totalTokens += msg.TokenCount
|
||||
msgIDs = append(msgIDs, added.ID)
|
||||
}
|
||||
|
||||
// Append to context_items using actual inserted IDs
|
||||
if err := e.store.AppendContextMessages(ctx, conv.ConversationID, msgIDs); err != nil {
|
||||
return nil, fmt.Errorf("append context: %w", err)
|
||||
}
|
||||
|
||||
logger.InfoCF("seahorse", "ingest", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"messages": len(messages),
|
||||
"tokens": totalTokens,
|
||||
})
|
||||
return &IngestResult{
|
||||
MessageCount: len(messages),
|
||||
TokenCount: totalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases resources.
|
||||
func (e *Engine) Close() error {
|
||||
// Signal compaction goroutines to stop
|
||||
if e.compaction != nil {
|
||||
e.compaction.Close()
|
||||
}
|
||||
if e.store != nil && e.store.db != nil {
|
||||
return e.store.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRetrieval returns the retrieval engine for tool implementations.
|
||||
func (e *Engine) GetRetrieval() *RetrievalEngine {
|
||||
return e.retrieval
|
||||
}
|
||||
|
||||
// Assemble builds budget-constrained context for a session.
|
||||
func (e *Engine) Assemble(ctx context.Context, sessionKey string, input AssembleInput) (*AssembleResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
e.initAssemblerOnce()
|
||||
return e.assembler.Assemble(ctx, conv.ConversationID, input)
|
||||
}
|
||||
|
||||
// Compact compresses conversation history for a session.
|
||||
func (e *Engine) Compact(ctx context.Context, sessionKey string, input CompactInput) (*CompactResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
|
||||
return &CompactResult{}, nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
e.initCompactionOnce()
|
||||
return e.compaction.Compact(ctx, conv.ConversationID, input)
|
||||
}
|
||||
|
||||
// CompactUntilUnder aggressively compacts until context is under budget.
|
||||
// Used for emergency compaction after LLM overflow (retry reason).
|
||||
func (e *Engine) CompactUntilUnder(ctx context.Context, sessionKey string, budget int) (*CompactResult, error) {
|
||||
if e.shouldIgnoreSession(sessionKey) || e.isStatelessSession(sessionKey) {
|
||||
return &CompactResult{}, nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get conversation: %w", err)
|
||||
}
|
||||
|
||||
e.initCompactionOnce()
|
||||
return e.compaction.CompactUntilUnder(ctx, conv.ConversationID, budget)
|
||||
}
|
||||
|
||||
// initCompactionOnce lazily initializes the compaction engine.
|
||||
func (e *Engine) initCompactionOnce() {
|
||||
if e.compaction == nil {
|
||||
e.compactionMu.Lock()
|
||||
defer e.compactionMu.Unlock()
|
||||
if e.compaction == nil {
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||
e.compaction = &CompactionEngine{
|
||||
store: e.store,
|
||||
config: e.config,
|
||||
complete: e.complete,
|
||||
shutdownCtx: shutdownCtx,
|
||||
shutdownCancel: shutdownCancel,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initAssemblerOnce lazily initializes the assembler.
|
||||
func (e *Engine) initAssemblerOnce() {
|
||||
if e.assembler == nil {
|
||||
e.assemblerMu.Lock()
|
||||
defer e.assemblerMu.Unlock()
|
||||
if e.assembler == nil {
|
||||
e.assembler = &Assembler{store: e.store, config: e.config}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IngestMessages is an alias for Ingest.
|
||||
func (e *Engine) IngestMessages(ctx context.Context, sessionKey string, messages []Message) (*IngestResult, error) {
|
||||
return e.Ingest(ctx, sessionKey, messages)
|
||||
}
|
||||
|
||||
// Bootstrap reconciles a session's messages with the database.
|
||||
// Called once at startup for each known session.
|
||||
// Bootstrap reconciles JSONL history with SQLite by ingesting only the delta.
|
||||
// Simple approach: find longest matching prefix and append delta.
|
||||
// If any mismatch is detected, clear and rebuild.
|
||||
func (e *Engine) Bootstrap(ctx context.Context, sessionKey string, messages []Message) error {
|
||||
if e.shouldIgnoreSession(sessionKey) {
|
||||
return nil
|
||||
}
|
||||
if e.isStatelessSession(sessionKey) {
|
||||
return nil
|
||||
}
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
conv, err := e.store.GetOrCreateConversation(ctx, sessionKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: get conversation: %w", err)
|
||||
}
|
||||
|
||||
// Get messages already in DB
|
||||
dbMsgs, err := e.store.GetMessages(ctx, conv.ConversationID, len(messages), 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: get messages: %w", err)
|
||||
}
|
||||
|
||||
// Fast path: DB has same count and exact match → no-op
|
||||
if len(dbMsgs) == len(messages) {
|
||||
matched := true
|
||||
for i := 0; i < len(messages); i++ {
|
||||
if !messageMatches(dbMsgs[i], messages[i]) {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
return nil // DB is up to date
|
||||
}
|
||||
}
|
||||
|
||||
// Find longest matching prefix from the start
|
||||
anchor := -1
|
||||
compareLen := len(dbMsgs)
|
||||
if compareLen > len(messages) {
|
||||
compareLen = len(messages)
|
||||
}
|
||||
|
||||
for i := 0; i < compareLen; i++ {
|
||||
if messageMatches(dbMsgs[i], messages[i]) {
|
||||
anchor = i
|
||||
} else {
|
||||
// Mismatch detected - log details and rebuild
|
||||
logger.InfoCF("seahorse", "bootstrap: mismatch detected", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"index": i,
|
||||
"db_role": dbMsgs[i].Role,
|
||||
"db_content": truncate(dbMsgs[i].Content, 50),
|
||||
"db_parts": len(dbMsgs[i].Parts),
|
||||
"msg_role": messages[i].Role,
|
||||
"msg_content": truncate(messages[i].Content, 50),
|
||||
"msg_parts": len(messages[i].Parts),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If we hit a mismatch before reaching the end of DB messages, delete delta and re-ingest
|
||||
// Note: anchor can be -1 if first message didn't match (history completely changed)
|
||||
if anchor >= 0 && anchor < len(dbMsgs)-1 && len(dbMsgs) > 0 {
|
||||
anchorID := dbMsgs[anchor].ID
|
||||
logger.InfoCF("seahorse", "bootstrap: history edit detected", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"db_count": len(dbMsgs),
|
||||
"anchor": anchor,
|
||||
"anchor_id": anchorID,
|
||||
"msg_count": len(messages),
|
||||
"delta_start": anchor + 1,
|
||||
})
|
||||
|
||||
// Delete messages after anchor (also clears context_items)
|
||||
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, anchorID); err != nil {
|
||||
return fmt.Errorf("bootstrap: delete messages: %w", err)
|
||||
}
|
||||
|
||||
// Re-ingest from anchor+1 to end
|
||||
delta := messages[anchor+1:]
|
||||
if len(delta) > 0 {
|
||||
_, err := e.Ingest(ctx, sessionKey, delta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: re-ingest: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Normal case: append delta after anchor
|
||||
if anchor >= 0 && anchor < len(messages)-1 {
|
||||
delta := messages[anchor+1:]
|
||||
if len(delta) > 0 {
|
||||
_, err := e.Ingest(ctx, sessionKey, delta)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: ingest delta: %w", err)
|
||||
}
|
||||
}
|
||||
} else if anchor == -1 && len(dbMsgs) > 0 {
|
||||
// First message changed (history completely different) - rebuild from scratch
|
||||
logger.InfoCF("seahorse", "bootstrap: history replaced, rebuilding", map[string]any{
|
||||
"conv_id": conv.ConversationID,
|
||||
"db_count": len(dbMsgs),
|
||||
"msg_count": len(messages),
|
||||
})
|
||||
// Delete all existing messages
|
||||
if err := e.store.DeleteMessagesAfterID(ctx, conv.ConversationID, 0); err != nil {
|
||||
return fmt.Errorf("bootstrap: delete all messages: %w", err)
|
||||
}
|
||||
// Re-ingest everything
|
||||
if len(messages) > 0 {
|
||||
_, err := e.Ingest(ctx, sessionKey, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: re-ingest all: %w", err)
|
||||
}
|
||||
}
|
||||
} else if anchor == -1 && len(dbMsgs) == 0 {
|
||||
// DB is empty, ingest everything
|
||||
_, err := e.Ingest(ctx, sessionKey, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bootstrap: ingest all: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// truncate shortens a string for logging.
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
// messageMatches compares two messages using (role, content) or (role, parts).
|
||||
// TokenCount is NOT compared because it may be re-estimated differently
|
||||
// during bootstrap (e.g., via tokenizer.EstimateMessageTokens).
|
||||
// For messages with Parts (tool_use, tool_result), compare Parts instead of Content
|
||||
// since AddMessageWithParts stores empty Content in DB.
|
||||
func messageMatches(a, b Message) bool {
|
||||
if a.Role != b.Role {
|
||||
return false
|
||||
}
|
||||
// If either message has Parts, compare Parts
|
||||
if len(a.Parts) > 0 || len(b.Parts) > 0 {
|
||||
return partsMatch(a.Parts, b.Parts)
|
||||
}
|
||||
// Simple text messages: compare Content
|
||||
return a.Content == b.Content
|
||||
}
|
||||
|
||||
// partsMatch compares two slices of MessagePart for equality.
|
||||
func partsMatch(a, b []MessagePart) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i].Type != b[i].Type {
|
||||
return false
|
||||
}
|
||||
switch a[i].Type {
|
||||
case "text":
|
||||
if a[i].Text != b[i].Text {
|
||||
return false
|
||||
}
|
||||
case "tool_use":
|
||||
if a[i].Name != b[i].Name || a[i].Arguments != b[i].Arguments || a[i].ToolCallID != b[i].ToolCallID {
|
||||
return false
|
||||
}
|
||||
case "tool_result":
|
||||
if a[i].ToolCallID != b[i].ToolCallID || a[i].Text != b[i].Text {
|
||||
return false
|
||||
}
|
||||
case "media":
|
||||
if a[i].MediaURI != b[i].MediaURI || a[i].MimeType != b[i].MimeType {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,212 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ParseLastDuration parses a "last" duration string like "6h", "7d", "2w", "1m".
|
||||
// Returns the duration and nil error, or zero and error if invalid.
|
||||
func ParseLastDuration(s string) (time.Duration, error) {
|
||||
if s == "" {
|
||||
return 0, fmt.Errorf("empty duration")
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`^(\d+)([hdwm])$`)
|
||||
matches := re.FindStringSubmatch(s)
|
||||
if matches == nil {
|
||||
return 0, fmt.Errorf("invalid duration format: %q (use format like 6h, 7d, 2w, 1m)", s)
|
||||
}
|
||||
|
||||
value, _ := strconv.Atoi(matches[1])
|
||||
unit := matches[2]
|
||||
|
||||
switch unit {
|
||||
case "h":
|
||||
return time.Duration(value) * time.Hour, nil
|
||||
case "d":
|
||||
return time.Duration(value) * 24 * time.Hour, nil
|
||||
case "w":
|
||||
return time.Duration(value) * 7 * 24 * time.Hour, nil
|
||||
case "m":
|
||||
return time.Duration(value) * 30 * 24 * time.Hour, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown unit: %q", unit)
|
||||
}
|
||||
}
|
||||
|
||||
// GrepInput controls search across summaries and messages.
|
||||
type GrepInput struct {
|
||||
Pattern string `json:"pattern"`
|
||||
Scope string `json:"scope,omitempty"` // "both" (default), "summary", or "message"
|
||||
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
|
||||
AllConversations bool `json:"allConversations,omitempty"`
|
||||
Since *time.Time `json:"since,omitempty"`
|
||||
Before *time.Time `json:"before,omitempty"`
|
||||
Last string `json:"last,omitempty"` // shortcut: "6h", "7d", "2w", "1m"
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
// GrepResult contains search results.
|
||||
type GrepResult struct {
|
||||
Success bool `json:"success"`
|
||||
Summaries []GrepSummaryResult `json:"summaries"`
|
||||
Messages []GrepMessageResult `json:"messages"`
|
||||
TotalSummaries int `json:"totalSummaries"`
|
||||
TotalMessages int `json:"totalMessages"`
|
||||
Hint string `json:"hint,omitempty"`
|
||||
}
|
||||
|
||||
// GrepSummaryResult is a summary match from grep.
|
||||
type GrepSummaryResult struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Depth int `json:"depth"`
|
||||
Kind SummaryKind `json:"kind"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
// Rank is the bm25 relevance score (negative value, closer to 0 = better match).
|
||||
// Examples: -0.5 = excellent match, -2.0 = good match, -10.0 = partial match.
|
||||
Rank float64 `json:"rank,omitempty"`
|
||||
}
|
||||
|
||||
// GrepMessageResult is a message match from grep.
|
||||
type GrepMessageResult struct {
|
||||
ID int64 `json:"id,string"`
|
||||
Snippet string `json:"snippet"`
|
||||
Role string `json:"role"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Rank float64 `json:"rank,omitempty"` // Relevance score (lower = better match)
|
||||
}
|
||||
|
||||
// ExpandMessagesResult contains expanded messages.
|
||||
type ExpandMessagesResult struct {
|
||||
Messages []Message `json:"messages"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
}
|
||||
|
||||
// Grep searches summaries and messages for matching content.
|
||||
func (r *RetrievalEngine) Grep(ctx context.Context, input GrepInput) (*GrepResult, error) {
|
||||
if input.Pattern == "" {
|
||||
return nil, fmt.Errorf("grep: pattern is required")
|
||||
}
|
||||
|
||||
limit := input.Limit
|
||||
if limit == 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
// Handle Last parameter: convert to Since
|
||||
since := input.Since
|
||||
if input.Last != "" {
|
||||
dur, err := ParseLastDuration(input.Last)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("grep: invalid last: %w", err)
|
||||
}
|
||||
t := time.Now().UTC().Add(-dur)
|
||||
since = &t
|
||||
}
|
||||
|
||||
// Auto-detect mode: use LIKE if pattern contains %, otherwise full-text
|
||||
mode := ""
|
||||
if strings.Contains(input.Pattern, "%") {
|
||||
mode = "like"
|
||||
}
|
||||
|
||||
searchInput := SearchInput{
|
||||
Pattern: input.Pattern,
|
||||
Mode: mode,
|
||||
Role: input.Role,
|
||||
AllConversations: input.AllConversations,
|
||||
Since: since,
|
||||
Before: input.Before,
|
||||
Limit: limit,
|
||||
}
|
||||
|
||||
result := &GrepResult{
|
||||
Success: true,
|
||||
Summaries: make([]GrepSummaryResult, 0),
|
||||
Messages: make([]GrepMessageResult, 0),
|
||||
TotalSummaries: 0,
|
||||
TotalMessages: 0,
|
||||
}
|
||||
|
||||
// Determine scope
|
||||
scope := input.Scope
|
||||
if scope == "" {
|
||||
scope = "both"
|
||||
}
|
||||
|
||||
// Search summaries if requested
|
||||
if scope == "both" || scope == "summary" {
|
||||
sumResults, err := r.store.SearchSummaries(ctx, searchInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search summaries: %w", err)
|
||||
}
|
||||
for _, sr := range sumResults {
|
||||
if sr.SummaryID != "" {
|
||||
result.Summaries = append(result.Summaries, GrepSummaryResult{
|
||||
ID: sr.SummaryID,
|
||||
Content: sr.Content,
|
||||
Depth: sr.Depth,
|
||||
Kind: sr.Kind,
|
||||
ConversationID: sr.ConversationID,
|
||||
Rank: sr.Rank,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(sumResults) > 0 {
|
||||
result.TotalSummaries = sumResults[0].TotalCount
|
||||
}
|
||||
}
|
||||
|
||||
// Search messages if requested
|
||||
if scope == "both" || scope == "message" {
|
||||
msgResults, err := r.store.SearchMessages(ctx, searchInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search messages: %w", err)
|
||||
}
|
||||
for _, sr := range msgResults {
|
||||
if sr.MessageID > 0 {
|
||||
result.Messages = append(result.Messages, GrepMessageResult{
|
||||
ID: sr.MessageID,
|
||||
Snippet: sr.Snippet,
|
||||
Role: sr.Role,
|
||||
ConversationID: sr.ConversationID,
|
||||
Rank: sr.Rank,
|
||||
})
|
||||
}
|
||||
}
|
||||
if len(msgResults) > 0 {
|
||||
result.TotalMessages = msgResults[0].TotalCount
|
||||
}
|
||||
}
|
||||
|
||||
// Add hint if no results
|
||||
if len(result.Summaries) == 0 && len(result.Messages) == 0 {
|
||||
result.Hint = "No matches. Try: %keyword% for fuzzy search, or all_conversations: true"
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ExpandMessages retrieves full message content by IDs.
|
||||
func (r *RetrievalEngine) ExpandMessages(ctx context.Context, messageIDs []int64) (*ExpandMessagesResult, error) {
|
||||
result := &ExpandMessagesResult{
|
||||
Messages: make([]Message, 0, len(messageIDs)),
|
||||
}
|
||||
|
||||
for _, msgID := range messageIDs {
|
||||
msg, err := r.store.GetMessageByID(ctx, msgID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
result.Messages = append(result.Messages, *msg)
|
||||
result.TokenCount += msg.TokenCount
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- Retrieval Tests ---
|
||||
|
||||
func newTestRetrieval(t *testing.T) (*RetrievalEngine, *Store, int64) {
|
||||
t.Helper()
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:retrieval")
|
||||
return &RetrievalEngine{store: s}, s, conv.ConversationID
|
||||
}
|
||||
|
||||
func TestRetrievalGrepSummaries(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "数据库连接配置说明",
|
||||
TokenCount: 50,
|
||||
})
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "API endpoint documentation",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
// FTS5 search (trigram, needs >= 3 chars)
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "数据库连",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Summaries) == 0 {
|
||||
t.Error("expected at least 1 FTS result")
|
||||
}
|
||||
|
||||
// LIKE search with wildcard
|
||||
results, err = r.Grep(ctx, GrepInput{
|
||||
Pattern: "%endpoint%",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep LIKE: %v", err)
|
||||
}
|
||||
if len(results.Summaries) == 0 {
|
||||
t.Error("expected at least 1 LIKE result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalGrepMessages(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s.AddMessage(ctx, convID, "user", "find this message about testing", 5)
|
||||
s.AddMessage(ctx, convID, "user", "unrelated content here", 5)
|
||||
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "testing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected at least 1 result for 'testing'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalExpandMessages(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg, _ := s.AddMessage(ctx, convID, "user", "expand this message", 10)
|
||||
|
||||
result, err := r.ExpandMessages(ctx, []int64{msg.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("ExpandMessages: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 1 {
|
||||
t.Errorf("Messages = %d, want 1", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Content != "expand this message" {
|
||||
t.Errorf("Content = %q, want 'expand this message'", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalExpandMultipleMessages(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
msg1, _ := s.AddMessage(ctx, convID, "user", "first message", 10)
|
||||
msg2, _ := s.AddMessage(ctx, convID, "assistant", "second message", 10)
|
||||
msg3, _ := s.AddMessage(ctx, convID, "user", "third message", 10)
|
||||
|
||||
result, err := r.ExpandMessages(ctx, []int64{msg1.ID, msg2.ID, msg3.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("ExpandMessages: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 3 {
|
||||
t.Errorf("Messages = %d, want 3", len(result.Messages))
|
||||
}
|
||||
if result.TokenCount != 30 {
|
||||
t.Errorf("TokenCount = %d, want 30", result.TokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrievalGrepWithTimeFilter(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
before := now.Add(-2 * time.Hour)
|
||||
|
||||
// Create messages at different times
|
||||
s.AddMessage(ctx, convID, "user", "old message about auth", 5)
|
||||
s.AddMessage(ctx, convID, "user", "recent message about auth", 5)
|
||||
|
||||
// Search with time filter
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "auth",
|
||||
Since: &before,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
_ = results // Just verify no error
|
||||
}
|
||||
|
||||
func TestRetrievalGrepAllConversations(t *testing.T) {
|
||||
r, s, _ := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create another conversation
|
||||
conv2, _ := s.GetOrCreateConversation(ctx, "test:retrieval2")
|
||||
|
||||
// Add messages to both
|
||||
s.AddMessage(ctx, conv2.ConversationID, "user", "unique keyword xyz", 5)
|
||||
|
||||
// Search all conversations
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "xyz",
|
||||
AllConversations: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected to find message in other conversation")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Last Duration Parsing Tests ---
|
||||
|
||||
func TestParseLastDuration(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantDur time.Duration
|
||||
wantErr bool
|
||||
}{
|
||||
{"6h", 6 * time.Hour, false},
|
||||
{"1d", 24 * time.Hour, false},
|
||||
{"7d", 7 * 24 * time.Hour, false},
|
||||
{"2w", 14 * 24 * time.Hour, false},
|
||||
{"1m", 30 * 24 * time.Hour, false}, // month = 30 days
|
||||
{"3m", 90 * 24 * time.Hour, false},
|
||||
{"", 0, true},
|
||||
{"invalid", 0, true},
|
||||
{"5x", 0, true}, // unknown unit
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
got, err := ParseLastDuration(tt.input)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tt.wantDur {
|
||||
t.Errorf("ParseLastDuration(%q) = %v, want %v", tt.input, got, tt.wantDur)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Role Filter Tests ---
|
||||
|
||||
func TestRetrievalGrepRoleFilter(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
s.AddMessage(ctx, convID, "user", "user message about alpha", 5)
|
||||
s.AddMessage(ctx, convID, "assistant", "assistant reply about alpha", 5)
|
||||
s.AddMessage(ctx, convID, "user", "another user message", 5)
|
||||
|
||||
// Search all roles
|
||||
allResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "alpha",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(allResults.Messages) != 2 {
|
||||
t.Errorf("expected 2 messages, got %d", len(allResults.Messages))
|
||||
}
|
||||
|
||||
// Search user only
|
||||
userResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "alpha",
|
||||
Role: "user",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(userResults.Messages) != 1 {
|
||||
t.Errorf("expected 1 user message, got %d", len(userResults.Messages))
|
||||
}
|
||||
if userResults.Messages[0].Role != "user" {
|
||||
t.Errorf("expected role=user, got %s", userResults.Messages[0].Role)
|
||||
}
|
||||
|
||||
// Search assistant only
|
||||
assistantResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "alpha",
|
||||
Role: "assistant",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(assistantResults.Messages) != 1 {
|
||||
t.Errorf("expected 1 assistant message, got %d", len(assistantResults.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
// --- Last Parameter Tests ---
|
||||
|
||||
func TestRetrievalGrepWithLast(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add messages (we can't control timestamps in SQLite easily,
|
||||
// but we can verify the parameter is parsed correctly)
|
||||
s.AddMessage(ctx, convID, "user", "recent message about testing", 5)
|
||||
|
||||
// Test that Last parameter is converted to Since
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "testing",
|
||||
Last: "1d", // last 1 day
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
// Should still find the message since it's recent
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected to find recent message")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRetrievalGrepRoleFilterWithSummaries tests that role filter works when
|
||||
// searching both summaries and messages (summaries don't have role column).
|
||||
func TestRetrievalGrepRoleFilterWithSummaries(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a summary (no role column)
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "summary about testing",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
// Add messages with different roles
|
||||
s.AddMessage(ctx, convID, "user", "user message about testing", 5)
|
||||
s.AddMessage(ctx, convID, "assistant", "assistant reply about testing", 5)
|
||||
|
||||
// Search with role filter and scope=both (default), using LIKE mode (%)
|
||||
// This should NOT error even though summaries don't have role column
|
||||
bothResults, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "%testing%", // LIKE mode to trigger the bug
|
||||
Role: "user",
|
||||
Scope: "both",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep with role and scope=both: %v", err)
|
||||
}
|
||||
|
||||
// Should only return user messages, not summaries or assistant messages
|
||||
if len(bothResults.Messages) != 1 {
|
||||
t.Errorf("expected 1 user message, got %d", len(bothResults.Messages))
|
||||
}
|
||||
if len(bothResults.Messages) > 0 && bothResults.Messages[0].Role != "user" {
|
||||
t.Errorf("expected role=user, got %s", bothResults.Messages[0].Role)
|
||||
}
|
||||
|
||||
// Summaries should be empty since they don't have roles to filter
|
||||
// (or we could return all summaries - either is acceptable)
|
||||
}
|
||||
|
||||
// TestRetrievalGrepTotalCounts tests that grep returns total counts.
|
||||
func TestRetrievalGrepTotalCounts(t *testing.T) {
|
||||
r, s, convID := newTestRetrieval(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 3 summaries
|
||||
for i := 0; i < 3; i++ {
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: convID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: fmt.Sprintf("summary about testing %d", i),
|
||||
TokenCount: 50,
|
||||
})
|
||||
}
|
||||
|
||||
// Add 5 messages
|
||||
for i := 0; i < 5; i++ {
|
||||
s.AddMessage(ctx, convID, "user", fmt.Sprintf("message about testing %d", i), 5)
|
||||
}
|
||||
|
||||
// Search with limit smaller than total
|
||||
results, err := r.Grep(ctx, GrepInput{
|
||||
Pattern: "%testing%", // LIKE mode
|
||||
Scope: "both",
|
||||
Limit: 2,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
|
||||
// Should return limited results
|
||||
if len(results.Summaries) > 2 {
|
||||
t.Errorf("expected at most 2 summaries, got %d", len(results.Summaries))
|
||||
}
|
||||
if len(results.Messages) > 2 {
|
||||
t.Errorf("expected at most 2 messages, got %d", len(results.Messages))
|
||||
}
|
||||
|
||||
// But total counts should reflect all matches
|
||||
if results.TotalSummaries != 3 {
|
||||
t.Errorf("expected TotalSummaries=3, got %d", results.TotalSummaries)
|
||||
}
|
||||
if results.TotalMessages != 5 {
|
||||
t.Errorf("expected TotalMessages=5, got %d", results.TotalMessages)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,129 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// ExpandTool recovers full message content by ID.
|
||||
type ExpandTool struct {
|
||||
engine *RetrievalEngine
|
||||
}
|
||||
|
||||
func NewExpandTool(engine *RetrievalEngine) *ExpandTool {
|
||||
return &ExpandTool{engine: engine}
|
||||
}
|
||||
|
||||
func (t *ExpandTool) Name() string {
|
||||
return "short_expand"
|
||||
}
|
||||
|
||||
func (t *ExpandTool) Description() string {
|
||||
return `Get full message content by ID.
|
||||
|
||||
Use when short_grep returns messages and you need complete content (not just snippet).
|
||||
|
||||
Parameters:
|
||||
- message_ids (required): Array of message ID strings (from short_grep results)
|
||||
|
||||
Returns message with:
|
||||
- content: Full text content
|
||||
- parts: Structured content
|
||||
- text: Full text
|
||||
- tool_use: name, arguments, toolCallId
|
||||
- tool_result: toolCallId only (content omitted - re-run tool if needed)
|
||||
- media: mediaUri (file path), mimeType
|
||||
|
||||
Notes:
|
||||
- tool_result content is not returned (can be large). Re-run the tool if you need the result.
|
||||
- Media files are stored on disk at mediaUri path, use bash to access.
|
||||
|
||||
Example:
|
||||
{"message_ids": ["10", "25"]}`
|
||||
}
|
||||
|
||||
func (t *ExpandTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"message_ids": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"type": "string"},
|
||||
"description": "Message IDs to expand (from short_grep results, e.g., [\"10\", \"25\"])",
|
||||
},
|
||||
},
|
||||
"required": []string{"message_ids"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExpandTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
idsRaw, ok := args["message_ids"].([]any)
|
||||
if !ok || len(idsRaw) == 0 {
|
||||
return tools.ErrorResult(
|
||||
"Missing required 'message_ids' argument. " +
|
||||
"Example: {\"message_ids\": [\"10\", \"25\"]}")
|
||||
}
|
||||
|
||||
// Parse message IDs
|
||||
messageIDs := make([]int64, 0, len(idsRaw))
|
||||
for _, id := range idsRaw {
|
||||
switch v := id.(type) {
|
||||
case string:
|
||||
var n int64
|
||||
if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Invalid message_id %q: %v", v, err))
|
||||
}
|
||||
messageIDs = append(messageIDs, n)
|
||||
case float64:
|
||||
messageIDs = append(messageIDs, int64(v))
|
||||
}
|
||||
}
|
||||
|
||||
result, err := t.engine.ExpandMessages(ctx, messageIDs)
|
||||
if err != nil {
|
||||
return tools.ErrorResult("Expand failed: " + err.Error())
|
||||
}
|
||||
|
||||
// Build response with filtered parts
|
||||
messages := make([]map[string]any, 0, len(result.Messages))
|
||||
for _, msg := range result.Messages {
|
||||
parts := make([]map[string]any, 0, len(msg.Parts))
|
||||
for _, p := range msg.Parts {
|
||||
part := map[string]any{"type": p.Type}
|
||||
switch p.Type {
|
||||
case "text":
|
||||
part["text"] = p.Text
|
||||
case "tool_use":
|
||||
part["name"] = p.Name
|
||||
part["arguments"] = p.Arguments
|
||||
part["toolCallId"] = p.ToolCallID
|
||||
case "tool_result":
|
||||
// Omit content - can be large, re-run tool if needed
|
||||
part["toolCallId"] = p.ToolCallID
|
||||
case "media":
|
||||
part["mediaUri"] = p.MediaURI
|
||||
part["mimeType"] = p.MimeType
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
|
||||
messages = append(messages, map[string]any{
|
||||
"id": fmt.Sprintf("%d", msg.ID),
|
||||
"role": msg.Role,
|
||||
"content": msg.Content,
|
||||
"parts": parts,
|
||||
"conversationId": msg.ConversationID,
|
||||
})
|
||||
}
|
||||
|
||||
output := map[string]any{
|
||||
"success": true,
|
||||
"tokenCount": result.TokenCount,
|
||||
"messages": messages,
|
||||
}
|
||||
data, _ := json.Marshal(output)
|
||||
return tools.NewToolResult(string(data))
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExpandToolByMessageIDs(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:expand-tool")
|
||||
|
||||
msg1, _ := s.AddMessage(ctx, conv.ConversationID, "user", "first message", 10)
|
||||
msg2, _ := s.AddMessage(ctx, conv.ConversationID, "assistant", "second message", 10)
|
||||
|
||||
re := &RetrievalEngine{store: s}
|
||||
tool := NewExpandTool(re)
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"message_ids": []any{fmt.Sprintf("%d", msg1.ID), fmt.Sprintf("%d", msg2.ID)},
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expand failed: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Parse result
|
||||
var output struct {
|
||||
Success bool `json:"success"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
Messages []map[string]any `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
|
||||
t.Fatalf("Parse result: %v", err)
|
||||
}
|
||||
|
||||
if !output.Success {
|
||||
t.Error("expected success=true")
|
||||
}
|
||||
if len(output.Messages) != 2 {
|
||||
t.Errorf("Messages = %d, want 2", len(output.Messages))
|
||||
}
|
||||
if output.TokenCount != 20 {
|
||||
t.Errorf("TokenCount = %d, want 20", output.TokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandToolMissingIDs(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
re := &RetrievalEngine{store: s}
|
||||
tool := NewExpandTool(re)
|
||||
|
||||
result := tool.Execute(context.Background(), map[string]any{})
|
||||
|
||||
if !result.IsError {
|
||||
t.Error("expected error for missing message_ids")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandToolWithParts(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:expand-parts")
|
||||
|
||||
// Create message with parts
|
||||
parts := []MessagePart{
|
||||
{Type: "text", Text: "Hello"},
|
||||
{Type: "tool_use", Name: "bash", Arguments: `{"command":"ls"}`, ToolCallID: "call_123"},
|
||||
{Type: "tool_result", ToolCallID: "call_123", Text: "file1.txt\nfile2.txt"},
|
||||
}
|
||||
msg, _ := s.AddMessageWithParts(ctx, conv.ConversationID, "assistant", parts, 50)
|
||||
|
||||
re := &RetrievalEngine{store: s}
|
||||
tool := NewExpandTool(re)
|
||||
|
||||
result := tool.Execute(ctx, map[string]any{
|
||||
"message_ids": []any{fmt.Sprintf("%d", msg.ID)},
|
||||
})
|
||||
|
||||
if result.IsError {
|
||||
t.Fatalf("Expand failed: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
var output struct {
|
||||
Messages []struct {
|
||||
Parts []map[string]any `json:"parts"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(result.ForLLM), &output); err != nil {
|
||||
t.Fatalf("Parse result: %v", err)
|
||||
}
|
||||
|
||||
if len(output.Messages) != 1 {
|
||||
t.Fatalf("Messages = %d, want 1", len(output.Messages))
|
||||
}
|
||||
|
||||
// Verify parts are filtered correctly
|
||||
foundText := false
|
||||
foundToolUse := false
|
||||
foundToolResult := false
|
||||
for _, p := range output.Messages[0].Parts {
|
||||
switch p["type"].(string) {
|
||||
case "text":
|
||||
foundText = true
|
||||
if p["text"] != "Hello" {
|
||||
t.Errorf("text = %v, want Hello", p["text"])
|
||||
}
|
||||
case "tool_use":
|
||||
foundToolUse = true
|
||||
if p["name"] != "bash" {
|
||||
t.Errorf("name = %v, want bash", p["name"])
|
||||
}
|
||||
case "tool_result":
|
||||
foundToolResult = true
|
||||
// tool_result should NOT have content
|
||||
if _, hasContent := p["content"]; hasContent {
|
||||
t.Error("tool_result should not have content field")
|
||||
}
|
||||
if p["toolCallId"] != "call_123" {
|
||||
t.Errorf("toolCallId = %v, want call_123", p["toolCallId"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundText {
|
||||
t.Error("missing text part")
|
||||
}
|
||||
if !foundToolUse {
|
||||
t.Error("missing tool_use part")
|
||||
}
|
||||
if !foundToolResult {
|
||||
t.Error("missing tool_result part")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// GrepTool searches summaries and messages for matching content.
|
||||
type GrepTool struct {
|
||||
engine *RetrievalEngine
|
||||
}
|
||||
|
||||
func NewGrepTool(engine *RetrievalEngine) *GrepTool {
|
||||
return &GrepTool{engine: engine}
|
||||
}
|
||||
|
||||
func (t *GrepTool) Name() string {
|
||||
return "short_grep"
|
||||
}
|
||||
|
||||
func (t *GrepTool) Description() string {
|
||||
return `Search summaries and messages for matching content.
|
||||
|
||||
Pattern syntax:
|
||||
- Words: "authentication" - matches content containing this word
|
||||
- AND: "auth AND login" - matches content with both words
|
||||
- OR: "auth OR signin" - matches content with either word
|
||||
- NOT: "bug NOT fixed" - matches "bug" but excludes "fixed"
|
||||
- Wildcard: "%auth%" - matches any text containing "auth" (e.g., "auth", "authentication")
|
||||
|
||||
Each summary has a "depth" field:
|
||||
- depth 0: Created from messages, most detailed
|
||||
- depth 1+: Created from other summaries, more compressed but covers longer time
|
||||
|
||||
Parameters:
|
||||
- pattern (required): Search pattern
|
||||
- scope: "both" (default), "summary", or "message" - what to search
|
||||
- role: "user", "assistant", or omit for all - filter by message role
|
||||
- last: Time shortcut like "6h", "7d", "2w", "1m" (hours/days/weeks/months)
|
||||
- all_conversations: Search all conversations (default: current only)
|
||||
- since: ISO8601 timestamp, content after this time
|
||||
- before: ISO8601 timestamp, content before this time
|
||||
- limit: Max results (default: 20)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"summaries": [{"id": "sum_abc", "content": "...", "depth": 0, "kind": "leaf", "conversationId": 1, "rank": -0.5}],
|
||||
"messages": [{"id": "10", "snippet": "...matched...", "role": "user", "conversationId": 1, "rank": -1.2}],
|
||||
"totalSummaries": 5,
|
||||
"totalMessages": 10,
|
||||
"hint": "No matches. Try: %keyword% for fuzzy search"
|
||||
}
|
||||
|
||||
Rank field (FTS5 mode only): bm25 relevance score, negative value where closer to 0 = better match.
|
||||
Examples: -0.5=excellent, -2=good, -5=partial, -10=weak. LIKE mode (%pattern%) has no rank.
|
||||
|
||||
Examples:
|
||||
{"pattern": "authentication"}
|
||||
{"pattern": "bug AND login"}
|
||||
{"pattern": "%snake%"}
|
||||
{"pattern": "project", "scope": "summary"}
|
||||
{"pattern": "error", "role": "assistant", "last": "7d"}
|
||||
{"pattern": "error", "all_conversations": true}`
|
||||
}
|
||||
|
||||
func (t *GrepTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pattern": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Search pattern. Supports: words, AND/OR/NOT operators, % wildcard",
|
||||
},
|
||||
"scope": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"both", "summary", "message"},
|
||||
"description": "What to search: 'both' (default), 'summary', or 'message'",
|
||||
},
|
||||
"role": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []string{"user", "assistant"},
|
||||
"description": "Filter by message role (default: all roles)",
|
||||
},
|
||||
"last": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Time shortcut: '6h' (6 hours), '7d' (7 days), '2w' (2 weeks), '1m' (1 month)",
|
||||
},
|
||||
"all_conversations": map[string]any{
|
||||
"type": "boolean",
|
||||
"description": "Search across all conversations (default: searches current conversation only)",
|
||||
},
|
||||
"since": map[string]any{
|
||||
"type": "string",
|
||||
"description": "ISO8601 timestamp, only return content after this time",
|
||||
},
|
||||
"before": map[string]any{
|
||||
"type": "string",
|
||||
"description": "ISO8601 timestamp, only return content before this time",
|
||||
},
|
||||
"limit": map[string]any{
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results (default: 20)",
|
||||
},
|
||||
},
|
||||
"required": []string{"pattern"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *GrepTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
pattern, ok := args["pattern"].(string)
|
||||
if !ok || pattern == "" {
|
||||
return tools.ErrorResult("Missing required 'pattern' argument. Example: {\"pattern\": \"authentication\"}")
|
||||
}
|
||||
|
||||
input := GrepInput{Pattern: pattern}
|
||||
|
||||
if scope, ok := args["scope"].(string); ok && scope != "" {
|
||||
input.Scope = scope
|
||||
}
|
||||
if role, ok := args["role"].(string); ok && role != "" {
|
||||
input.Role = role
|
||||
}
|
||||
if last, ok := args["last"].(string); ok && last != "" {
|
||||
input.Last = last
|
||||
}
|
||||
if allConv, ok := args["all_conversations"].(bool); ok {
|
||||
input.AllConversations = allConv
|
||||
}
|
||||
if limit, ok := args["limit"].(float64); ok {
|
||||
input.Limit = int(limit)
|
||||
}
|
||||
if sinceStr, ok := args["since"].(string); ok && sinceStr != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, sinceStr)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf(
|
||||
"Invalid 'since' timestamp. Use RFC3339 format like '2024-01-15T10:00:00Z'. Error: %v", err))
|
||||
}
|
||||
input.Since = &parsed
|
||||
}
|
||||
if beforeStr, ok := args["before"].(string); ok && beforeStr != "" {
|
||||
parsed, err := time.Parse(time.RFC3339, beforeStr)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Invalid 'before' timestamp format: %v", err))
|
||||
}
|
||||
input.Before = &parsed
|
||||
}
|
||||
|
||||
result, err := t.engine.Grep(ctx, input)
|
||||
if err != nil {
|
||||
return tools.ErrorResult("Grep failed: " + err.Error())
|
||||
}
|
||||
|
||||
// Build response
|
||||
output := map[string]any{
|
||||
"success": result.Success,
|
||||
"summaries": result.Summaries,
|
||||
"messages": result.Messages,
|
||||
}
|
||||
|
||||
// Add hint if provided
|
||||
if result.Hint != "" {
|
||||
output["hint"] = result.Hint
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(output)
|
||||
return tools.NewToolResult(string(data))
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGrepSearchSummaries(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:grep-tool")
|
||||
|
||||
s.CreateSummary(ctx, CreateSummaryInput{
|
||||
ConversationID: conv.ConversationID,
|
||||
Kind: SummaryKindLeaf,
|
||||
Depth: 0,
|
||||
Content: "database connection pool configuration",
|
||||
TokenCount: 50,
|
||||
})
|
||||
|
||||
re := &RetrievalEngine{store: s}
|
||||
results, err := re.Grep(ctx, GrepInput{
|
||||
Pattern: "database",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep: %v", err)
|
||||
}
|
||||
if len(results.Summaries) == 0 {
|
||||
t.Error("expected at least 1 summary result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepSearchMessages(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
ctx := context.Background()
|
||||
conv, _ := s.GetOrCreateConversation(ctx, "test:grep-msg")
|
||||
|
||||
s.AddMessage(ctx, conv.ConversationID, "user", "find this message about testing", 5)
|
||||
s.AddMessage(ctx, conv.ConversationID, "user", "unrelated content", 3)
|
||||
|
||||
re := &RetrievalEngine{store: s}
|
||||
results, err := re.Grep(ctx, GrepInput{
|
||||
Pattern: "testing",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Grep messages: %v", err)
|
||||
}
|
||||
if len(results.Messages) == 0 {
|
||||
t.Error("expected at least 1 message result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepMissingPattern(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
re := &RetrievalEngine{store: s}
|
||||
_, err := re.Grep(context.Background(), GrepInput{})
|
||||
if err == nil {
|
||||
t.Error("expected error for missing pattern")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrepToolSupportsAllConversations(t *testing.T) {
|
||||
s := openTestStore(t)
|
||||
tool := NewGrepTool(&RetrievalEngine{store: s})
|
||||
params := tool.Parameters()
|
||||
props := params["properties"].(map[string]any)
|
||||
|
||||
// GrepTool should accept all_conversations parameter
|
||||
if _, ok := props["all_conversations"]; !ok {
|
||||
t.Error("Parameters missing 'all_conversations' field")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tokenizer"
|
||||
)
|
||||
|
||||
// SummaryKind distinguishes leaf summaries (from raw messages) vs condensed
|
||||
// summaries (from other summaries).
|
||||
type SummaryKind string
|
||||
|
||||
const (
|
||||
SummaryKindLeaf SummaryKind = "leaf"
|
||||
SummaryKindCondensed SummaryKind = "condensed"
|
||||
)
|
||||
|
||||
// Message represents a single chat message with role and content.
|
||||
type Message struct {
|
||||
ID int64 `json:"id"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Parts []MessagePart `json:"parts,omitempty"`
|
||||
}
|
||||
|
||||
// MessagePart holds structured content (tool calls, media, etc.)
|
||||
type MessagePart struct {
|
||||
ID int64 `json:"id"`
|
||||
MessageID int64 `json:"messageId"`
|
||||
Type string `json:"type"` // "text", "tool_use", "tool_result", "media"
|
||||
Text string `json:"text"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
ToolCallID string `json:"toolCallId"`
|
||||
MediaURI string `json:"mediaUri"`
|
||||
MimeType string `json:"mimeType"`
|
||||
}
|
||||
|
||||
// Summary represents a compressed representation of messages or other summaries.
|
||||
type Summary struct {
|
||||
SummaryID string `json:"summaryId"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Kind SummaryKind `json:"kind"`
|
||||
Depth int `json:"depth"`
|
||||
Content string `json:"content"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
EarliestAt *time.Time `json:"earliestAt,omitempty"`
|
||||
LatestAt *time.Time `json:"latestAt,omitempty"`
|
||||
DescendantCount int `json:"descendantCount"`
|
||||
DescendantTokenCount int `json:"descendantTokenCount"`
|
||||
SourceMessageTokenCount int `json:"sourceMessageTokenCount"`
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// SummaryNode is a Summary with graph relationships for tree traversal.
|
||||
type SummaryNode struct {
|
||||
Summary
|
||||
Children []string `json:"children"` // Child summary IDs
|
||||
Expanded bool `json:"expanded"` // UI state for expansion
|
||||
}
|
||||
|
||||
// Conversation represents a session's conversation with metadata.
|
||||
type Conversation struct {
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
SessionKey string `json:"sessionKey"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// SessionStatus contains status information for a session.
|
||||
type SessionStatus struct {
|
||||
SessionKey string `json:"sessionKey"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Messages int `json:"messages"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
Summaries int `json:"summaries"`
|
||||
OldestAt time.Time `json:"oldestAt"`
|
||||
NewestAt time.Time `json:"newestAt"`
|
||||
}
|
||||
|
||||
// ContextItem represents one item in the assembled context window.
|
||||
type ContextItem struct {
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Ordinal int `json:"ordinal"`
|
||||
ItemType string `json:"itemType"` // "summary" or "message"
|
||||
SummaryID string `json:"summaryId,omitempty"`
|
||||
MessageID int64 `json:"messageId,omitempty"`
|
||||
TokenCount int `json:"tokenCount"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// SummarySubtreeNode is a node in a summary DAG subtree.
|
||||
type SummarySubtreeNode struct {
|
||||
SummaryID string `json:"summaryId"`
|
||||
DepthFromRoot int `json:"depthFromRoot"`
|
||||
}
|
||||
|
||||
// SearchInput controls summary search.
|
||||
type SearchInput struct {
|
||||
Pattern string `json:"pattern"`
|
||||
Mode string `json:"mode"` // "like" (LIKE search) or "full_text" (FTS5, default)
|
||||
Scope string `json:"scope,omitempty"` // "messages", "summaries", "both"
|
||||
Role string `json:"role,omitempty"` // "user", "assistant", or "" (all)
|
||||
Since *time.Time `json:"since,omitempty"`
|
||||
Before *time.Time `json:"before,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
ConversationID int64 `json:"conversationId,omitempty"`
|
||||
AllConversations bool `json:"allConversations,omitempty"`
|
||||
}
|
||||
|
||||
// SearchResult is a search match.
|
||||
type SearchResult struct {
|
||||
SummaryID string `json:"summaryId,omitempty"`
|
||||
MessageID int64 `json:"messageId,omitempty"`
|
||||
ConversationID int64 `json:"conversationId"`
|
||||
Kind SummaryKind `json:"kind,omitempty"`
|
||||
Depth int `json:"depth,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"` // Full content for summaries
|
||||
Snippet string `json:"snippet"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Rank float64 `json:"rank,omitempty"`
|
||||
TotalCount int `json:"totalCount,omitempty"` // Total matching rows (from window function)
|
||||
}
|
||||
|
||||
// EstimateMessageTokens estimates token count for a full message using the
|
||||
// shared tokenizer package for consistency with agent.context_budget.
|
||||
func EstimateMessageTokens(msg Message) int {
|
||||
pm := providers.Message{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
ReasoningContent: msg.ReasoningContent,
|
||||
}
|
||||
|
||||
// Convert MessageParts to ToolCalls / ToolCallID / Media
|
||||
for _, part := range msg.Parts {
|
||||
switch part.Type {
|
||||
case "tool_use":
|
||||
pm.ToolCalls = append(pm.ToolCalls, providers.ToolCall{
|
||||
ID: part.ToolCallID,
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: part.Name,
|
||||
Arguments: part.Arguments,
|
||||
},
|
||||
})
|
||||
case "tool_result":
|
||||
pm.ToolCallID = part.ToolCallID
|
||||
case "media":
|
||||
pm.Media = append(pm.Media, part.MediaURI)
|
||||
}
|
||||
}
|
||||
|
||||
return tokenizer.EstimateMessageTokens(pm)
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package seahorse
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSummaryKindValues(t *testing.T) {
|
||||
if SummaryKindLeaf != "leaf" {
|
||||
t.Errorf("expected SummaryKindLeaf = 'leaf', got %q", SummaryKindLeaf)
|
||||
}
|
||||
if SummaryKindCondensed != "condensed" {
|
||||
t.Errorf("expected SummaryKindCondensed = 'condensed', got %q", SummaryKindCondensed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstants(t *testing.T) {
|
||||
// Ordinal gap step
|
||||
if OrdinalStep != 100 {
|
||||
t.Errorf("expected OrdinalStep = 100, got %d", OrdinalStep)
|
||||
}
|
||||
|
||||
// Compaction triggers
|
||||
if ContextThreshold != 0.75 {
|
||||
t.Errorf("expected ContextThreshold = 0.75, got %f", ContextThreshold)
|
||||
}
|
||||
if FreshTailCount != 32 {
|
||||
t.Errorf("expected FreshTailCount = 32, got %d", FreshTailCount)
|
||||
}
|
||||
|
||||
// Fanout
|
||||
if LeafMinFanout != 8 {
|
||||
t.Errorf("expected LeafMinFanout = 8, got %d", LeafMinFanout)
|
||||
}
|
||||
if CondensedMinFanout != 4 {
|
||||
t.Errorf("expected CondensedMinFanout = 4, got %d", CondensedMinFanout)
|
||||
}
|
||||
if CondensedMinFanoutHard != 2 {
|
||||
t.Errorf("expected CondensedMinFanoutHard = 2, got %d", CondensedMinFanoutHard)
|
||||
}
|
||||
|
||||
// Token targets
|
||||
if LeafChunkTokens != 20000 {
|
||||
t.Errorf("expected LeafChunkTokens = 20000, got %d", LeafChunkTokens)
|
||||
}
|
||||
if LeafTargetTokens != 1200 {
|
||||
t.Errorf("expected LeafTargetTokens = 1200, got %d", LeafTargetTokens)
|
||||
}
|
||||
if CondensedTargetTokens != 2000 {
|
||||
t.Errorf("expected CondensedTargetTokens = 2000, got %d", CondensedTargetTokens)
|
||||
}
|
||||
if MaxExpandTokens != 4000 {
|
||||
t.Errorf("expected MaxExpandTokens = 4000, got %d", MaxExpandTokens)
|
||||
}
|
||||
}
|
||||
@@ -79,3 +79,8 @@ func (b *JSONLBackend) Save(key string) error {
|
||||
func (b *JSONLBackend) Close() error {
|
||||
return b.store.Close()
|
||||
}
|
||||
|
||||
// ListSessions returns all known session keys.
|
||||
func (b *JSONLBackend) ListSessions() []string {
|
||||
return b.store.ListSessions()
|
||||
}
|
||||
|
||||
@@ -145,6 +145,16 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) ListSessions() []string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
keys := make([]string, 0, len(sm.sessions))
|
||||
for k := range sm.sessions {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// sanitizeFilename converts a session key into a cross-platform safe filename.
|
||||
// Replaces ':' with '_' (session key separator) and '/' and '\' with '_' so
|
||||
// composite IDs (e.g. Telegram forum "chatID/threadID") do not create
|
||||
|
||||
@@ -27,6 +27,8 @@ type SessionStore interface {
|
||||
TruncateHistory(key string, keepLast int)
|
||||
// Save persists any pending state to durable storage.
|
||||
Save(key string) error
|
||||
// ListSessions returns all known session keys.
|
||||
ListSessions() []string
|
||||
// Close releases resources held by the store.
|
||||
Close() error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// EstimateMessageTokens estimates the token count for a single message,
|
||||
// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
|
||||
// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
|
||||
func EstimateMessageTokens(msg providers.Message) int {
|
||||
contentChars := utf8.RuneCountInString(msg.Content)
|
||||
|
||||
// SystemParts are structured system blocks used for cache-aware adapters.
|
||||
// They carry the same content as Content, but in multiple blocks.
|
||||
// We estimate them as an alternative representation, not additive.
|
||||
systemPartsChars := 0
|
||||
if len(msg.SystemParts) > 0 {
|
||||
for _, part := range msg.SystemParts {
|
||||
systemPartsChars += utf8.RuneCountInString(part.Text)
|
||||
}
|
||||
// Per-part overhead for JSON structure (type, text, cache_control).
|
||||
const perPartOverhead = 20
|
||||
systemPartsChars += len(msg.SystemParts) * perPartOverhead
|
||||
}
|
||||
|
||||
// Use the larger of the two representations to stay conservative.
|
||||
chars := contentChars
|
||||
if systemPartsChars > chars {
|
||||
chars = systemPartsChars
|
||||
}
|
||||
|
||||
chars += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
for _, tc := range msg.ToolCalls {
|
||||
chars += len(tc.ID) + len(tc.Type)
|
||||
if tc.Function != nil {
|
||||
// Count function name + arguments (the wire format for most providers).
|
||||
// tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
|
||||
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
} else {
|
||||
// Fallback: some provider formats use top-level Name without Function.
|
||||
chars += len(tc.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if msg.ToolCallID != "" {
|
||||
chars += len(msg.ToolCallID)
|
||||
}
|
||||
|
||||
// Per-message overhead for role label, JSON structure, separators.
|
||||
const messageOverhead = 12
|
||||
chars += messageOverhead
|
||||
|
||||
tokens := chars * 2 / 5
|
||||
|
||||
// Media items (images, files) are serialized by provider adapters into
|
||||
// multipart or image_url payloads. Add a fixed per-item token estimate
|
||||
// directly (not through the chars heuristic) since actual cost depends
|
||||
// on resolution and provider-specific image tokenization.
|
||||
const mediaTokensPerItem = 256
|
||||
tokens += len(msg.Media) * mediaTokensPerItem
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// EstimateToolDefsTokens estimates the total token cost of tool definitions
|
||||
// as they appear in the LLM request.
|
||||
func EstimateToolDefsTokens(defs []providers.ToolDefinition) int {
|
||||
if len(defs) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
totalChars := 0
|
||||
for _, d := range defs {
|
||||
totalChars += len(d.Function.Name) + len(d.Function.Description)
|
||||
|
||||
if d.Function.Parameters != nil {
|
||||
if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
|
||||
totalChars += len(paramJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Per-tool overhead: type field, JSON structure, separators.
|
||||
totalChars += 20
|
||||
}
|
||||
|
||||
return totalChars * 2 / 5
|
||||
}
|
||||
Reference in New Issue
Block a user