mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
351 lines
9.4 KiB
Go
351 lines
9.4 KiB
Go
package seahorse
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
)
|
|
|
|
// escapeXML escapes special characters for safe inclusion in XML content.
|
|
func escapeXML(s string) string {
|
|
s = strings.ReplaceAll(s, "&", "&")
|
|
s = strings.ReplaceAll(s, "<", "<")
|
|
s = strings.ReplaceAll(s, ">", ">")
|
|
s = strings.ReplaceAll(s, "\"", """)
|
|
s = strings.ReplaceAll(s, "'", "'")
|
|
return s
|
|
}
|
|
|
|
// resolvedItem is a context item resolved to its full content with token count.
|
|
type resolvedItem struct {
|
|
ordinal int
|
|
itemType string // "message" or "summary"
|
|
message *Message
|
|
summary *Summary
|
|
tokenCount int
|
|
}
|
|
|
|
// Assemble builds budget-constrained context from summaries + messages.
|
|
//
|
|
// Algorithm:
|
|
// 1. Fetch context_items, resolve to full content
|
|
// 2. Split into evictable prefix + protected fresh tail
|
|
// 3. If evictable fits in remaining budget → include all
|
|
// 4. Else walk evictable from newest to oldest, keep while fits
|
|
func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleInput) (*AssembleResult, error) {
|
|
items, err := a.store.GetContextItems(ctx, convID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get context items: %w", err)
|
|
}
|
|
if len(items) == 0 {
|
|
return &AssembleResult{}, nil
|
|
}
|
|
|
|
// Resolve all items
|
|
resolved := make([]resolvedItem, len(items))
|
|
for i, item := range items {
|
|
r, err := a.resolveItem(ctx, item)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
resolved[i] = r
|
|
}
|
|
|
|
// Split into evictable prefix and protected fresh tail
|
|
tailStart := len(resolved) - FreshTailCount
|
|
if tailStart < 0 {
|
|
tailStart = 0
|
|
}
|
|
evictable := resolved[:tailStart]
|
|
freshTail := resolved[tailStart:]
|
|
|
|
// Calculate fresh tail tokens
|
|
freshTailTokens := 0
|
|
for _, r := range freshTail {
|
|
freshTailTokens += r.tokenCount
|
|
}
|
|
|
|
// If the protected tail alone exceeds budget, trim from the oldest end at
|
|
// provider-safe boundaries. The rebuild path later sanitizes leading
|
|
// assistant(tool_calls)/tool messages, so splitting the active turn here can
|
|
// silently discard the very context we are trying to protect.
|
|
if freshTailTokens > input.Budget {
|
|
originalTailCount := len(freshTail)
|
|
originalFreshTailTokens := freshTailTokens
|
|
var preservedActiveTurn bool
|
|
freshTail, freshTailTokens, preservedActiveTurn = trimFreshTailToSafeBudget(freshTail, input.Budget)
|
|
logFields := map[string]any{
|
|
"budget": input.Budget,
|
|
"fresh_tail_tokens": freshTailTokens,
|
|
"fresh_tail_count": len(freshTail),
|
|
"trimmed_fresh_items": originalTailCount - len(freshTail),
|
|
"original_fresh_tokens": originalFreshTailTokens,
|
|
"preserved_active_turn": preservedActiveTurn,
|
|
}
|
|
if preservedActiveTurn {
|
|
logger.WarnCF("seahorse", "assemble: preserving active turn over budget", logFields)
|
|
} else {
|
|
logger.InfoCF("seahorse", "assemble: trimmed fresh tail to safe boundary", logFields)
|
|
}
|
|
}
|
|
remainingBudget := input.Budget - freshTailTokens
|
|
if remainingBudget < 0 {
|
|
remainingBudget = 0
|
|
}
|
|
|
|
var selected []resolvedItem
|
|
evictableTokens := 0
|
|
for _, r := range evictable {
|
|
evictableTokens += r.tokenCount
|
|
}
|
|
|
|
if evictableTokens <= remainingBudget {
|
|
// All evictable fit
|
|
selected = append(selected, evictable...)
|
|
} else {
|
|
// Walk from newest to oldest, keep while fits
|
|
var kept []resolvedItem
|
|
accum := 0
|
|
for i := len(evictable) - 1; i >= 0; i-- {
|
|
if accum+evictable[i].tokenCount <= remainingBudget {
|
|
kept = append(kept, evictable[i])
|
|
accum += evictable[i].tokenCount
|
|
} else {
|
|
break
|
|
}
|
|
}
|
|
// Reverse to restore chronological order
|
|
for i, j := 0, len(kept)-1; i < j; i, j = i+1, j-1 {
|
|
kept[i], kept[j] = kept[j], kept[i]
|
|
}
|
|
selected = append(selected, kept...)
|
|
}
|
|
|
|
// Combine: selected evictable + fresh tail
|
|
final := append(selected, freshTail...)
|
|
|
|
// Build result
|
|
var messages []Message
|
|
var summaries []Summary
|
|
var sourceIDs []string
|
|
totalTokens := 0
|
|
maxDepth := 0
|
|
condensedCount := 0
|
|
|
|
for _, r := range final {
|
|
totalTokens += r.tokenCount
|
|
if r.itemType == "message" && r.message != nil {
|
|
messages = append(messages, *r.message)
|
|
sourceIDs = append(sourceIDs, fmt.Sprintf("msg:%d", r.message.ID))
|
|
} else if r.itemType == "summary" && r.summary != nil {
|
|
summaries = append(summaries, *r.summary)
|
|
if r.summary.Depth > maxDepth {
|
|
maxDepth = r.summary.Depth
|
|
}
|
|
if r.summary.Kind == SummaryKindCondensed {
|
|
condensedCount++
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build depth-aware system prompt addition
|
|
systemPromptAddition := ""
|
|
if len(summaries) > 0 {
|
|
if maxDepth >= 2 || condensedCount >= 2 {
|
|
systemPromptAddition = "Your context has been heavily compressed through multi-level summarization.\n" +
|
|
"- Do NOT assert specific facts (commands, SHAs, paths, timestamps) from summaries without expanding.\n" +
|
|
"- When uncertain, use expand to recover original detail before making claims.\n" +
|
|
"- Tool escalation: grep \xe2\x86\x92 describe \xe2\x86\x92 expand"
|
|
} else {
|
|
systemPromptAddition = "Some earlier messages have been summarized. Use expand tools to recover details if needed."
|
|
}
|
|
}
|
|
|
|
// Build Summary field: all XML summaries + system prompt addition
|
|
var summaryParts []string
|
|
for _, sum := range summaries {
|
|
if sum.Content == "" {
|
|
continue
|
|
}
|
|
// Load parent IDs for XML formatting
|
|
parentSummaries, err := a.store.GetSummaryParents(ctx, sum.SummaryID)
|
|
if err != nil {
|
|
logger.WarnCF("seahorse", "assemble: get summary parents", map[string]any{
|
|
"summary_id": sum.SummaryID,
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
var parentIDs []string
|
|
for _, ps := range parentSummaries {
|
|
parentIDs = append(parentIDs, ps.SummaryID)
|
|
}
|
|
summaryParts = append(summaryParts, FormatSummaryXML(&sum, parentIDs))
|
|
}
|
|
summary := strings.Join(summaryParts, "\n\n")
|
|
if systemPromptAddition != "" {
|
|
if summary != "" {
|
|
summary += "\n\n"
|
|
}
|
|
summary += systemPromptAddition
|
|
}
|
|
|
|
return &AssembleResult{
|
|
Messages: messages,
|
|
Summary: summary,
|
|
}, nil
|
|
}
|
|
|
|
func trimFreshTailToSafeBudget(tail []resolvedItem, budget int) ([]resolvedItem, int, bool) {
|
|
tailTokens := resolvedItemsTokenCount(tail)
|
|
if tailTokens <= budget {
|
|
return tail, tailTokens, false
|
|
}
|
|
|
|
latestTurnStart := lastUserMessageIndex(tail)
|
|
if latestTurnStart >= 0 {
|
|
latestTurnTokens := resolvedItemsTokenCount(tail[latestTurnStart:])
|
|
if latestTurnTokens > budget {
|
|
return tail[latestTurnStart:], latestTurnTokens, true
|
|
}
|
|
}
|
|
|
|
start := 0
|
|
for tailTokens > budget && start < len(tail) {
|
|
tailTokens -= tail[start].tokenCount
|
|
start++
|
|
}
|
|
for start < len(tail) && !isProviderSafeHistoryStart(tail[start:]) {
|
|
tailTokens -= tail[start].tokenCount
|
|
start++
|
|
}
|
|
|
|
return tail[start:], tailTokens, false
|
|
}
|
|
|
|
func resolvedItemsTokenCount(items []resolvedItem) int {
|
|
total := 0
|
|
for _, item := range items {
|
|
total += item.tokenCount
|
|
}
|
|
return total
|
|
}
|
|
|
|
func lastUserMessageIndex(items []resolvedItem) int {
|
|
for i := len(items) - 1; i >= 0; i-- {
|
|
if items[i].itemType != "message" || items[i].message == nil {
|
|
continue
|
|
}
|
|
if items[i].message.Role == "user" {
|
|
return i
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func isProviderSafeHistoryStart(items []resolvedItem) bool {
|
|
for _, item := range items {
|
|
if item.itemType != "message" || item.message == nil {
|
|
continue
|
|
}
|
|
if item.message.Role == "tool" {
|
|
return false
|
|
}
|
|
if item.message.Role == "assistant" && messageHasToolUse(item.message) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
return true
|
|
}
|
|
|
|
func messageHasToolUse(msg *Message) bool {
|
|
if msg == nil {
|
|
return false
|
|
}
|
|
for _, part := range msg.Parts {
|
|
if part.Type == "tool_use" {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// resolveItem loads the full message or summary for a context item.
|
|
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
|
|
if item.ItemType == "message" {
|
|
msg, err := a.store.GetMessageByID(ctx, item.MessageID)
|
|
if err != nil {
|
|
return resolvedItem{}, err
|
|
}
|
|
tokens := item.TokenCount
|
|
if tokens == 0 {
|
|
tokens = msg.TokenCount
|
|
}
|
|
return resolvedItem{
|
|
ordinal: item.Ordinal,
|
|
itemType: "message",
|
|
message: msg,
|
|
tokenCount: tokens,
|
|
}, nil
|
|
}
|
|
|
|
if item.ItemType == "summary" {
|
|
sum, err := a.store.GetSummary(ctx, item.SummaryID)
|
|
if err != nil {
|
|
return resolvedItem{}, err
|
|
}
|
|
tokens := item.TokenCount
|
|
if tokens == 0 {
|
|
tokens = sum.TokenCount
|
|
}
|
|
return resolvedItem{
|
|
ordinal: item.Ordinal,
|
|
itemType: "summary",
|
|
summary: sum,
|
|
tokenCount: tokens,
|
|
}, nil
|
|
}
|
|
|
|
return resolvedItem{
|
|
ordinal: item.Ordinal,
|
|
itemType: item.ItemType,
|
|
tokenCount: item.TokenCount,
|
|
}, nil
|
|
}
|
|
|
|
// FormatSummaryXML formats a summary as XML for LLM context.
|
|
// This is exported so context managers can format summaries consistently.
|
|
func FormatSummaryXML(s *Summary, parentIDs []string) string {
|
|
// Build time attributes if available
|
|
var attrs string
|
|
if s.EarliestAt != nil {
|
|
attrs += fmt.Sprintf(` earliest_at="%s"`, s.EarliestAt.Format(time.RFC3339))
|
|
}
|
|
if s.LatestAt != nil {
|
|
attrs += fmt.Sprintf(` latest_at="%s"`, s.LatestAt.Format(time.RFC3339))
|
|
}
|
|
|
|
var parentsSection string
|
|
if s.Kind == SummaryKindCondensed && len(parentIDs) > 0 {
|
|
parents := "<parents>\n"
|
|
for _, pid := range parentIDs {
|
|
parents += fmt.Sprintf(" <summary_ref id=\"%s\" />\n", pid)
|
|
}
|
|
parents += " </parents>\n"
|
|
parentsSection = parents
|
|
}
|
|
return fmt.Sprintf(
|
|
"<summary id=\"%s\" kind=\"%s\" depth=\"%d\" descendant_count=\"%d\"%s>\n <content>\n %s\n </content>\n%s</summary>",
|
|
s.SummaryID,
|
|
string(s.Kind),
|
|
s.Depth,
|
|
s.DescendantCount,
|
|
attrs,
|
|
escapeXML(s.Content),
|
|
parentsSection,
|
|
)
|
|
}
|