mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge branch 'fix/seahorse-fresh-tail-budget'
# Conflicts: # pkg/agent/pipeline_llm.go # pkg/agent/pipeline_setup.go # pkg/agent/turn_state.go
This commit is contained in:
@@ -115,3 +115,55 @@ func isOverContextBudget(
|
||||
|
||||
return total > contextWindow
|
||||
}
|
||||
|
||||
// trimHistoryToFitContextWindow rebuilds the prompt from progressively newer
|
||||
// history slices until it fits within the context window. Oldest complete turns
|
||||
// are dropped first so tool-call sequences remain intact.
|
||||
func trimHistoryToFitContextWindow(
|
||||
history []providers.Message,
|
||||
build func([]providers.Message) []providers.Message,
|
||||
contextWindow int,
|
||||
toolDefs []providers.ToolDefinition,
|
||||
maxTokens int,
|
||||
) ([]providers.Message, []providers.Message, bool) {
|
||||
messages := build(history)
|
||||
if !isOverContextBudget(contextWindow, messages, toolDefs, maxTokens) {
|
||||
return history, messages, true
|
||||
}
|
||||
|
||||
trimmedHistory := append([]providers.Message(nil), history...)
|
||||
for len(trimmedHistory) > 0 {
|
||||
dropUntil := nextHistoryTrimStart(trimmedHistory)
|
||||
if dropUntil <= 0 || dropUntil >= len(trimmedHistory) {
|
||||
trimmedHistory = nil
|
||||
} else {
|
||||
trimmedHistory = append([]providers.Message(nil), trimmedHistory[dropUntil:]...)
|
||||
}
|
||||
|
||||
messages = build(trimmedHistory)
|
||||
if !isOverContextBudget(contextWindow, messages, toolDefs, maxTokens) {
|
||||
return trimmedHistory, messages, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, messages, false
|
||||
}
|
||||
|
||||
func nextHistoryTrimStart(history []providers.Message) int {
|
||||
if len(history) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
turns := parseTurnBoundaries(history)
|
||||
if len(turns) >= 2 {
|
||||
return turns[1]
|
||||
}
|
||||
if len(turns) == 1 {
|
||||
if turns[0] > 0 {
|
||||
return turns[0]
|
||||
}
|
||||
return len(history)
|
||||
}
|
||||
|
||||
return len(history)
|
||||
}
|
||||
|
||||
@@ -844,3 +844,64 @@ func TestIsOverContextBudget_RealisticSession(t *testing.T) {
|
||||
t.Error("realistic session should exceed 500 context window")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimHistoryToFitContextWindow_DropsOldestTurns(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msgUser(strings.Repeat("u1 ", 120)),
|
||||
msgAssistant(strings.Repeat("a1 ", 120)),
|
||||
msgUser(strings.Repeat("u2 ", 120)),
|
||||
msgAssistant(strings.Repeat("a2 ", 120)),
|
||||
msgUser(strings.Repeat("u3 ", 120)),
|
||||
msgAssistant(strings.Repeat("a3 ", 120)),
|
||||
}
|
||||
|
||||
build := func(history []providers.Message) []providers.Message {
|
||||
return append([]providers.Message(nil), history...)
|
||||
}
|
||||
|
||||
trimmedHistory, messages, fit := trimHistoryToFitContextWindow(
|
||||
history,
|
||||
build,
|
||||
700,
|
||||
nil,
|
||||
0,
|
||||
)
|
||||
if !fit {
|
||||
t.Fatal("expected trimmed history to fit context window")
|
||||
}
|
||||
if len(trimmedHistory) != 4 {
|
||||
t.Fatalf("trimmed history len = %d, want 4", len(trimmedHistory))
|
||||
}
|
||||
if trimmedHistory[0].Content != history[2].Content {
|
||||
t.Fatalf("first kept message = %q, want second turn start", trimmedHistory[0].Content)
|
||||
}
|
||||
if isOverContextBudget(700, messages, nil, 0) {
|
||||
t.Fatal("trimmed messages should be within budget")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimHistoryToFitContextWindow_ClearsSingleOversizedTurn(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
msgUser(strings.Repeat("oversized ", 200)),
|
||||
msgAssistant(strings.Repeat("oversized ", 200)),
|
||||
}
|
||||
|
||||
trimmedHistory, messages, fit := trimHistoryToFitContextWindow(
|
||||
history,
|
||||
func(history []providers.Message) []providers.Message {
|
||||
return append([]providers.Message(nil), history...)
|
||||
},
|
||||
200,
|
||||
nil,
|
||||
0,
|
||||
)
|
||||
if !fit {
|
||||
t.Fatal("expected empty history rebuild to fit context window")
|
||||
}
|
||||
if len(trimmedHistory) != 0 {
|
||||
t.Fatalf("trimmed history len = %d, want 0", len(trimmedHistory))
|
||||
}
|
||||
if len(messages) != 0 {
|
||||
t.Fatalf("messages len = %d, want 0", len(messages))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,6 +288,74 @@ func TestSeahorseToProviderMessagesWithToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeahorseAssemblePreservesActiveToolTurnAcrossSanitization(t *testing.T) {
|
||||
engine, err := seahorse.NewEngine(seahorse.Config{
|
||||
DBPath: t.TempDir() + "/seahorse.db",
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewEngine: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
sessionKey := "test:active-tool-turn"
|
||||
_, err = engine.Ingest(ctx, sessionKey, []seahorse.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "older context",
|
||||
TokenCount: 20,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: "inspect the file",
|
||||
TokenCount: 5,
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
TokenCount: 5,
|
||||
Parts: []seahorse.MessagePart{{
|
||||
Type: "tool_use",
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp/test.txt"}`,
|
||||
ToolCallID: "tc_1",
|
||||
}},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
TokenCount: 200,
|
||||
Parts: []seahorse.MessagePart{{
|
||||
Type: "tool_result",
|
||||
ToolCallID: "tc_1",
|
||||
Text: "very large tool output",
|
||||
}},
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "done",
|
||||
TokenCount: 5,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Ingest: %v", err)
|
||||
}
|
||||
|
||||
result, err := engine.Assemble(ctx, sessionKey, seahorse.AssembleInput{Budget: 210})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
sanitized := sanitizeHistoryForProvider(seahorseToProviderMessages(result))
|
||||
if len(sanitized) != 4 {
|
||||
t.Fatalf("sanitized history len = %d, want 4 protected-turn messages", len(sanitized))
|
||||
}
|
||||
assertRoles(t, sanitized, "user", "assistant", "tool", "assistant")
|
||||
if len(sanitized[1].ToolCalls) != 1 || sanitized[1].ToolCalls[0].ID != "tc_1" {
|
||||
t.Fatalf("assistant tool calls = %+v, want preserved tool call tc_1", sanitized[1].ToolCalls)
|
||||
}
|
||||
if sanitized[2].ToolCallID != "tc_1" {
|
||||
t.Fatalf("tool result id = %q, want tc_1", sanitized[2].ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeahorseToProviderMessagesToolResult(t *testing.T) {
|
||||
msg := seahorse.Message{
|
||||
Role: "tool",
|
||||
|
||||
@@ -415,14 +415,65 @@ func (p *Pipeline) CallLLM(
|
||||
contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills)
|
||||
}
|
||||
ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills)
|
||||
rebuildPromptReq := promptBuildRequestForTurn(ts, exec.history, exec.summary, "", nil, p.Cfg)
|
||||
rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...)
|
||||
exec.messages = ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq)
|
||||
exec.callMessages = exec.messages
|
||||
stableHistory, protectedTurnTail := splitHistoryForActiveTurn(
|
||||
exec.history,
|
||||
ts.persistedMessagesSnapshot(),
|
||||
)
|
||||
buildMessages := func(trimmedHistory []providers.Message) []providers.Message {
|
||||
fullHistory := append(append([]providers.Message(nil), trimmedHistory...), protectedTurnTail...)
|
||||
rebuildPromptReq := promptBuildRequestForTurn(ts, fullHistory, exec.summary, "", nil, p.Cfg)
|
||||
rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...)
|
||||
return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq)
|
||||
}
|
||||
originalHistoryCount := len(exec.history)
|
||||
var fit bool
|
||||
var trimmedStableHistory []providers.Message
|
||||
trimmedStableHistory, exec.callMessages, fit = trimHistoryToFitContextWindow(
|
||||
stableHistory,
|
||||
func(trimmedHistory []providers.Message) []providers.Message {
|
||||
rebuilt := buildMessages(trimmedHistory)
|
||||
if exec.gracefulTerminal {
|
||||
return append(append([]providers.Message(nil), rebuilt...), ts.interruptHintMessage())
|
||||
}
|
||||
return rebuilt
|
||||
},
|
||||
ts.agent.ContextWindow,
|
||||
exec.providerToolDefs,
|
||||
ts.agent.MaxTokens,
|
||||
)
|
||||
exec.history = append(trimmedStableHistory, protectedTurnTail...)
|
||||
exec.messages = buildMessages(trimmedStableHistory)
|
||||
if exec.gracefulTerminal {
|
||||
msgs := append([]providers.Message(nil), exec.messages...)
|
||||
exec.callMessages = append(msgs, ts.interruptHintMessage())
|
||||
}
|
||||
if dropped := originalHistoryCount - len(exec.history); dropped > 0 {
|
||||
logger.WarnCF("agent", "Trimmed rebuilt history after context retry compaction", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"retry": retry,
|
||||
"dropped_msgs": dropped,
|
||||
"remaining_msgs": len(exec.history),
|
||||
"context_window": ts.agent.ContextWindow,
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"still_overlimit": !fit,
|
||||
})
|
||||
} else if !fit {
|
||||
logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"retry": retry,
|
||||
"history_msgs": len(exec.history),
|
||||
"protected_turn_msgs": len(protectedTurnTail),
|
||||
"context_window": ts.agent.ContextWindow,
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
})
|
||||
}
|
||||
if !fit {
|
||||
err = fmt.Errorf(
|
||||
"context window still exceeded after retry compaction; refusing to drop active turn messages: %w",
|
||||
err,
|
||||
)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
|
||||
@@ -66,10 +66,38 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution
|
||||
history = resp.History
|
||||
summary = resp.Summary
|
||||
}
|
||||
rebuildPromptReq := promptBuildRequestForTurn(ts, history, summary, ts.userMessage, ts.media, cfg)
|
||||
rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...)
|
||||
messages = ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq)
|
||||
messages = resolveMediaRefs(messages, p.MediaStore, maxMediaSize)
|
||||
originalHistoryCount := len(history)
|
||||
var fit bool
|
||||
history, messages, fit = trimHistoryToFitContextWindow(
|
||||
history,
|
||||
func(trimmedHistory []providers.Message) []providers.Message {
|
||||
rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, summary, ts.userMessage, ts.media, cfg)
|
||||
rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...)
|
||||
rebuilt := ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq)
|
||||
return resolveMediaRefs(rebuilt, p.MediaStore, maxMediaSize)
|
||||
},
|
||||
ts.agent.ContextWindow,
|
||||
toolDefs,
|
||||
ts.agent.MaxTokens,
|
||||
)
|
||||
if dropped := originalHistoryCount - len(history); dropped > 0 {
|
||||
logger.WarnCF("agent", "Trimmed rebuilt history after proactive compaction", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"dropped_msgs": dropped,
|
||||
"remaining_msgs": len(history),
|
||||
"context_window": ts.agent.ContextWindow,
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"still_overlimit": !fit,
|
||||
})
|
||||
} else if !fit {
|
||||
logger.WarnCF("agent", "Context still exceeds budget "+
|
||||
"after proactive compaction rebuild", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"history_msgs": len(history),
|
||||
"context_window": ts.agent.ContextWindow,
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+76
-17
@@ -643,13 +643,17 @@ func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
||||
ts.persistedMessages = append(ts.persistedMessages, msg)
|
||||
}
|
||||
|
||||
func (ts *turnState) persistedMessagesSnapshot() []providers.Message {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return append([]providers.Message(nil), ts.persistedMessages...)
|
||||
}
|
||||
|
||||
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
history := agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
||||
|
||||
ts.mu.RLock()
|
||||
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
||||
ts.mu.RUnlock()
|
||||
persisted := ts.persistedMessagesSnapshot()
|
||||
|
||||
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
||||
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
@@ -686,29 +690,84 @@ func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
return agent.Sessions.Save(ts.sessionKey)
|
||||
}
|
||||
|
||||
// messagesContentEqual compares two message slices by content only, ignoring CreatedAt.
|
||||
// JSON roundtrip loses the monotonic clock portion of time.Time, so direct
|
||||
// reflect.DeepEqual would always differ on messages that roundtripped through
|
||||
// the JSONL store.
|
||||
func messagesContentEqual(a, b []providers.Message) bool {
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func splitHistoryForActiveTurn(
|
||||
history []providers.Message,
|
||||
persisted []providers.Message,
|
||||
) ([]providers.Message, []providers.Message) {
|
||||
matched := matchingTurnMessageTail(history, persisted)
|
||||
if matched <= 0 {
|
||||
return append([]providers.Message(nil), history...), nil
|
||||
}
|
||||
|
||||
stable := append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
protected := append([]providers.Message(nil), history[len(history)-matched:]...)
|
||||
return stable, protected
|
||||
}
|
||||
|
||||
func messageSlicesEquivalent(a, b []providers.Message) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
aCopy, bCopy := a[i], b[i]
|
||||
aCopy.CreatedAt, bCopy.CreatedAt = nil, nil
|
||||
if !reflect.DeepEqual(aCopy, bCopy) {
|
||||
if !messagesEquivalent(a[i], b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if messagesContentEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
func messagesEquivalent(a, b providers.Message) bool {
|
||||
return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b))
|
||||
}
|
||||
|
||||
func normalizeMessageForComparison(msg providers.Message) providers.Message {
|
||||
msg.PromptLayer = ""
|
||||
msg.PromptSlot = ""
|
||||
msg.PromptSource = ""
|
||||
|
||||
if len(msg.Media) == 0 {
|
||||
msg.Media = nil
|
||||
}
|
||||
if len(msg.Attachments) == 0 {
|
||||
msg.Attachments = nil
|
||||
}
|
||||
if len(msg.SystemParts) == 0 {
|
||||
msg.SystemParts = nil
|
||||
} else {
|
||||
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||
for i := range msg.SystemParts {
|
||||
msg.SystemParts[i].PromptLayer = ""
|
||||
msg.SystemParts[i].PromptSlot = ""
|
||||
msg.SystemParts[i].PromptSource = ""
|
||||
}
|
||||
}
|
||||
return 0
|
||||
if len(msg.ToolCalls) == 0 {
|
||||
msg.ToolCalls = nil
|
||||
} else {
|
||||
msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...)
|
||||
for i := range msg.ToolCalls {
|
||||
msg.ToolCalls[i].Name = ""
|
||||
msg.ToolCalls[i].Arguments = nil
|
||||
msg.ToolCalls[i].ThoughtSignature = ""
|
||||
if msg.ToolCalls[i].Function != nil {
|
||||
fn := *msg.ToolCalls[i].Function
|
||||
fn.ThoughtSignature = ""
|
||||
msg.ToolCalls[i].Function = &fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestMatchingTurnMessageTail_IgnoresInternalRuntimeFields(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "question"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp/test"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
persisted := []providers.Message{
|
||||
userPromptMessage("question", nil),
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "read_file",
|
||||
Arguments: map[string]any{"path": "/tmp/test"},
|
||||
ThoughtSignature: "internal-signature",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp/test"}`,
|
||||
ThoughtSignature: "internal-signature",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if got := matchingTurnMessageTail(history, persisted); got != 2 {
|
||||
t.Fatalf("matchingTurnMessageTail() = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitHistoryForActiveTurn_ProtectsPersistedTail(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "old question"},
|
||||
{Role: "assistant", Content: "old answer"},
|
||||
{Role: "user", Content: "current question"},
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
persisted := []providers.Message{
|
||||
userPromptMessage("current question", nil),
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
stable, protected := splitHistoryForActiveTurn(history, persisted)
|
||||
if len(stable) != 2 {
|
||||
t.Fatalf("stable history len = %d, want 2", len(stable))
|
||||
}
|
||||
if len(protected) != 2 {
|
||||
t.Fatalf("protected tail len = %d, want 2", len(protected))
|
||||
}
|
||||
if protected[0].Content != "current question" {
|
||||
t.Fatalf("protected[0].Content = %q, want current question", protected[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimHistoryToFitContextWindow_WithProtectedTurnTailKeepsActiveTurn(t *testing.T) {
|
||||
current := strings.Repeat("current turn ", 80)
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: strings.Repeat("old turn ", 60)},
|
||||
{Role: "assistant", Content: strings.Repeat("old reply ", 60)},
|
||||
{Role: "user", Content: current},
|
||||
}
|
||||
|
||||
stable, protected := splitHistoryForActiveTurn(history, []providers.Message{
|
||||
userPromptMessage(current, nil),
|
||||
})
|
||||
trimmedStable, messages, fit := trimHistoryToFitContextWindow(
|
||||
stable,
|
||||
func(trimmedHistory []providers.Message) []providers.Message {
|
||||
return append(append([]providers.Message(nil), trimmedHistory...), protected...)
|
||||
},
|
||||
120,
|
||||
nil,
|
||||
0,
|
||||
)
|
||||
|
||||
if fit {
|
||||
t.Fatal("expected protected active turn alone to remain over budget")
|
||||
}
|
||||
if len(trimmedStable) != 0 {
|
||||
t.Fatalf("trimmed stable history len = %d, want 0", len(trimmedStable))
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("messages len = %d, want 1 protected active-turn message", len(messages))
|
||||
}
|
||||
if messages[0].Content != current {
|
||||
t.Fatalf("messages[0].Content = %q, want protected current turn", messages[0].Content)
|
||||
}
|
||||
}
|
||||
@@ -68,17 +68,31 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn
|
||||
freshTailTokens += r.tokenCount
|
||||
}
|
||||
|
||||
// Budget-aware selection of evictable items
|
||||
// If the protected tail alone exceeds budget, trim from the oldest end at
|
||||
// provider-safe boundaries. The rebuild path later sanitizes leading
|
||||
// assistant(tool_calls)/tool messages, so splitting the active turn here can
|
||||
// silently discard the very context we are trying to protect.
|
||||
if freshTailTokens > input.Budget {
|
||||
originalTailCount := len(freshTail)
|
||||
originalFreshTailTokens := freshTailTokens
|
||||
var preservedActiveTurn bool
|
||||
freshTail, freshTailTokens, preservedActiveTurn = trimFreshTailToSafeBudget(freshTail, input.Budget)
|
||||
logFields := map[string]any{
|
||||
"budget": input.Budget,
|
||||
"fresh_tail_tokens": freshTailTokens,
|
||||
"fresh_tail_count": len(freshTail),
|
||||
"trimmed_fresh_items": originalTailCount - len(freshTail),
|
||||
"original_fresh_tokens": originalFreshTailTokens,
|
||||
"preserved_active_turn": preservedActiveTurn,
|
||||
}
|
||||
if preservedActiveTurn {
|
||||
logger.WarnCF("seahorse", "assemble: preserving active turn over budget", logFields)
|
||||
} else {
|
||||
logger.InfoCF("seahorse", "assemble: trimmed fresh tail to safe boundary", logFields)
|
||||
}
|
||||
}
|
||||
remainingBudget := input.Budget - freshTailTokens
|
||||
if remainingBudget < 0 {
|
||||
// Fresh tail alone exceeds budget - we keep it anyway (design decision)
|
||||
// Log for debugging retry/overflow issues
|
||||
logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{
|
||||
"budget": input.Budget,
|
||||
"fresh_tail_tokens": freshTailTokens,
|
||||
"fresh_tail_count": len(freshTail),
|
||||
"over_budget_by": freshTailTokens - input.Budget,
|
||||
})
|
||||
remainingBudget = 0
|
||||
}
|
||||
|
||||
@@ -184,6 +198,81 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn
|
||||
}, nil
|
||||
}
|
||||
|
||||
func trimFreshTailToSafeBudget(tail []resolvedItem, budget int) ([]resolvedItem, int, bool) {
|
||||
tailTokens := resolvedItemsTokenCount(tail)
|
||||
if tailTokens <= budget {
|
||||
return tail, tailTokens, false
|
||||
}
|
||||
|
||||
latestTurnStart := lastUserMessageIndex(tail)
|
||||
if latestTurnStart >= 0 {
|
||||
latestTurnTokens := resolvedItemsTokenCount(tail[latestTurnStart:])
|
||||
if latestTurnTokens > budget {
|
||||
return tail[latestTurnStart:], latestTurnTokens, true
|
||||
}
|
||||
}
|
||||
|
||||
start := 0
|
||||
for tailTokens > budget && start < len(tail) {
|
||||
tailTokens -= tail[start].tokenCount
|
||||
start++
|
||||
}
|
||||
for start < len(tail) && !isProviderSafeHistoryStart(tail[start:]) {
|
||||
tailTokens -= tail[start].tokenCount
|
||||
start++
|
||||
}
|
||||
|
||||
return tail[start:], tailTokens, false
|
||||
}
|
||||
|
||||
func resolvedItemsTokenCount(items []resolvedItem) int {
|
||||
total := 0
|
||||
for _, item := range items {
|
||||
total += item.tokenCount
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func lastUserMessageIndex(items []resolvedItem) int {
|
||||
for i := len(items) - 1; i >= 0; i-- {
|
||||
if items[i].itemType != "message" || items[i].message == nil {
|
||||
continue
|
||||
}
|
||||
if items[i].message.Role == "user" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func isProviderSafeHistoryStart(items []resolvedItem) bool {
|
||||
for _, item := range items {
|
||||
if item.itemType != "message" || item.message == nil {
|
||||
continue
|
||||
}
|
||||
if item.message.Role == "tool" {
|
||||
return false
|
||||
}
|
||||
if item.message.Role == "assistant" && messageHasToolUse(item.message) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func messageHasToolUse(msg *Message) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
for _, part := range msg.Parts {
|
||||
if part.Type == "tool_use" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveItem loads the full message or summary for a context item.
|
||||
func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) {
|
||||
if item.ItemType == "message" {
|
||||
|
||||
@@ -145,22 +145,91 @@ func TestAssemblerBudgetEvictsOldest(t *testing.T) {
|
||||
s.UpsertContextItems(ctx, convID, items)
|
||||
|
||||
// Budget of 200 tokens with FreshTailCount=32
|
||||
// Fresh tail = last 32 messages (320 tokens, over budget, but always included)
|
||||
// Fresh tail = last 32 messages (320 tokens, over budget)
|
||||
// Evictable = first 8 messages (80 tokens)
|
||||
// Budget after tail: max(0, 200-320) = 0 → no evictable items included
|
||||
// The oldest messages from the fresh tail should be dropped so only the
|
||||
// newest 20 messages remain within the 200-token budget.
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
// Should only include the 32-item fresh tail
|
||||
if len(result.Messages) != 32 {
|
||||
t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages))
|
||||
if len(result.Messages) != 20 {
|
||||
t.Errorf("Messages = %d, want 20", len(result.Messages))
|
||||
}
|
||||
// Should be the LAST 32 messages
|
||||
if result.Messages[0].ID != msgs[8].ID {
|
||||
t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID)
|
||||
if result.Messages[0].ID != msgs[20].ID {
|
||||
t.Errorf("first message ID = %d, want %d (msgs[20])", result.Messages[0].ID, msgs[20].ID)
|
||||
}
|
||||
|
||||
totalTokens := 0
|
||||
for _, msg := range result.Messages {
|
||||
totalTokens += msg.TokenCount
|
||||
}
|
||||
if totalTokens > 200 {
|
||||
t.Errorf("assembled tokens = %d, want <= 200", totalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssemblerBudgetPreservesLatestToolTurnWhenItExceedsBudget(t *testing.T) {
|
||||
s, convID := setupAssemblerStore(t)
|
||||
ctx := context.Background()
|
||||
|
||||
oldMsg, _ := s.AddMessage(ctx, convID, "assistant", "older context", 20)
|
||||
userMsg, _ := s.AddMessage(ctx, convID, "user", "inspect the file", 5)
|
||||
assistantToolMsg, _ := s.AddMessageWithParts(ctx, convID, "assistant", []MessagePart{
|
||||
{
|
||||
Type: "tool_use",
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp/test.txt"}`,
|
||||
ToolCallID: "tc_1",
|
||||
},
|
||||
}, 5)
|
||||
toolResultMsg, _ := s.AddMessageWithParts(ctx, convID, "tool", []MessagePart{
|
||||
{
|
||||
Type: "tool_result",
|
||||
ToolCallID: "tc_1",
|
||||
Text: "very large tool output",
|
||||
},
|
||||
}, 200)
|
||||
finalAssistantMsg, _ := s.AddMessage(ctx, convID, "assistant", "done", 5)
|
||||
|
||||
s.UpsertContextItems(ctx, convID, []ContextItem{
|
||||
{Ordinal: 100, ItemType: "message", MessageID: oldMsg.ID, TokenCount: 20},
|
||||
{Ordinal: 200, ItemType: "message", MessageID: userMsg.ID, TokenCount: 5},
|
||||
{Ordinal: 300, ItemType: "message", MessageID: assistantToolMsg.ID, TokenCount: 5},
|
||||
{Ordinal: 400, ItemType: "message", MessageID: toolResultMsg.ID, TokenCount: 200},
|
||||
{Ordinal: 500, ItemType: "message", MessageID: finalAssistantMsg.ID, TokenCount: 5},
|
||||
})
|
||||
|
||||
a := &Assembler{store: s, config: Config{}}
|
||||
result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 210})
|
||||
if err != nil {
|
||||
t.Fatalf("Assemble: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Messages) != 4 {
|
||||
t.Fatalf("Messages = %d, want 4 protected-turn messages", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].ID != userMsg.ID {
|
||||
t.Fatalf("first message ID = %d, want current user message %d", result.Messages[0].ID, userMsg.ID)
|
||||
}
|
||||
if result.Messages[1].ID != assistantToolMsg.ID {
|
||||
t.Fatalf("second message ID = %d, want assistant tool-call %d", result.Messages[1].ID, assistantToolMsg.ID)
|
||||
}
|
||||
if result.Messages[2].ID != toolResultMsg.ID {
|
||||
t.Fatalf("third message ID = %d, want tool result %d", result.Messages[2].ID, toolResultMsg.ID)
|
||||
}
|
||||
if result.Messages[3].ID != finalAssistantMsg.ID {
|
||||
t.Fatalf("fourth message ID = %d, want final assistant %d", result.Messages[3].ID, finalAssistantMsg.ID)
|
||||
}
|
||||
|
||||
totalTokens := 0
|
||||
for _, msg := range result.Messages {
|
||||
totalTokens += msg.TokenCount
|
||||
}
|
||||
if totalTokens <= 210 {
|
||||
t.Fatalf("assembled tokens = %d, want protected turn to remain over budget", totalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user