From 1502636bf078b292d82aa05156c3ad217b1e3f4b Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Mon, 18 May 2026 21:11:21 +0200 Subject: [PATCH 1/3] fix(seahorse): enforce budget on fresh tail and rebuild paths --- pkg/agent/context_budget.go | 52 ++++++++++++++++++++++++ pkg/agent/context_budget_test.go | 61 ++++++++++++++++++++++++++++ pkg/agent/pipeline_llm.go | 44 ++++++++++++++++++-- pkg/agent/pipeline_setup.go | 36 ++++++++++++++-- pkg/seahorse/short_assembler.go | 27 +++++++----- pkg/seahorse/short_assembler_test.go | 23 +++++++---- 6 files changed, 216 insertions(+), 27 deletions(-) diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go index 72f80382a..1d99bf190 100644 --- a/pkg/agent/context_budget.go +++ b/pkg/agent/context_budget.go @@ -115,3 +115,55 @@ func isOverContextBudget( return total > contextWindow } + +// trimHistoryToFitContextWindow rebuilds the prompt from progressively newer +// history slices until it fits within the context window. Oldest complete turns +// are dropped first so tool-call sequences remain intact. +func trimHistoryToFitContextWindow( + history []providers.Message, + build func([]providers.Message) []providers.Message, + contextWindow int, + toolDefs []providers.ToolDefinition, + maxTokens int, +) ([]providers.Message, []providers.Message, bool) { + messages := build(history) + if !isOverContextBudget(contextWindow, messages, toolDefs, maxTokens) { + return history, messages, true + } + + trimmedHistory := append([]providers.Message(nil), history...) + for len(trimmedHistory) > 0 { + dropUntil := nextHistoryTrimStart(trimmedHistory) + if dropUntil <= 0 || dropUntil >= len(trimmedHistory) { + trimmedHistory = nil + } else { + trimmedHistory = append([]providers.Message(nil), trimmedHistory[dropUntil:]...) + } + + messages = build(trimmedHistory) + if !isOverContextBudget(contextWindow, messages, toolDefs, maxTokens) { + return trimmedHistory, messages, true + } + } + + return nil, messages, false +} + +func nextHistoryTrimStart(history []providers.Message) int { + if len(history) == 0 { + return 0 + } + + turns := parseTurnBoundaries(history) + if len(turns) >= 2 { + return turns[1] + } + if len(turns) == 1 { + if turns[0] > 0 { + return turns[0] + } + return len(history) + } + + return len(history) +} diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go index 9de1707ec..d7ca2a665 100644 --- a/pkg/agent/context_budget_test.go +++ b/pkg/agent/context_budget_test.go @@ -844,3 +844,64 @@ func TestIsOverContextBudget_RealisticSession(t *testing.T) { t.Error("realistic session should exceed 500 context window") } } + +func TestTrimHistoryToFitContextWindow_DropsOldestTurns(t *testing.T) { + history := []providers.Message{ + msgUser(strings.Repeat("u1 ", 120)), + msgAssistant(strings.Repeat("a1 ", 120)), + msgUser(strings.Repeat("u2 ", 120)), + msgAssistant(strings.Repeat("a2 ", 120)), + msgUser(strings.Repeat("u3 ", 120)), + msgAssistant(strings.Repeat("a3 ", 120)), + } + + build := func(history []providers.Message) []providers.Message { + return append([]providers.Message(nil), history...) + } + + trimmedHistory, messages, fit := trimHistoryToFitContextWindow( + history, + build, + 700, + nil, + 0, + ) + if !fit { + t.Fatal("expected trimmed history to fit context window") + } + if len(trimmedHistory) != 4 { + t.Fatalf("trimmed history len = %d, want 4", len(trimmedHistory)) + } + if trimmedHistory[0].Content != history[2].Content { + t.Fatalf("first kept message = %q, want second turn start", trimmedHistory[0].Content) + } + if isOverContextBudget(700, messages, nil, 0) { + t.Fatal("trimmed messages should be within budget") + } +} + +func TestTrimHistoryToFitContextWindow_ClearsSingleOversizedTurn(t *testing.T) { + history := []providers.Message{ + msgUser(strings.Repeat("oversized ", 200)), + msgAssistant(strings.Repeat("oversized ", 200)), + } + + trimmedHistory, messages, fit := trimHistoryToFitContextWindow( + history, + func(history []providers.Message) []providers.Message { + return append([]providers.Message(nil), history...) + }, + 200, + nil, + 0, + ) + if !fit { + t.Fatal("expected empty history rebuild to fit context window") + } + if len(trimmedHistory) != 0 { + t.Fatalf("trimmed history len = %d, want 0", len(trimmedHistory)) + } + if len(messages) != 0 { + t.Fatalf("messages len = %d, want 0", len(messages)) + } +} diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index 5de590129..7934b7815 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -369,14 +369,50 @@ func (p *Pipeline) CallLLM( contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills) } ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills) - rebuildPromptReq := promptBuildRequestForTurn(ts, exec.history, exec.summary, "", nil) - rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) - exec.messages = ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) - exec.callMessages = exec.messages + buildMessages := func(trimmedHistory []providers.Message) []providers.Message { + rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, exec.summary, "", nil) + rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) + return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) + } + originalHistoryCount := len(exec.history) + var fit bool + exec.history, exec.callMessages, fit = trimHistoryToFitContextWindow( + exec.history, + func(trimmedHistory []providers.Message) []providers.Message { + rebuilt := buildMessages(trimmedHistory) + if exec.gracefulTerminal { + return append(append([]providers.Message(nil), rebuilt...), ts.interruptHintMessage()) + } + return rebuilt + }, + ts.agent.ContextWindow, + exec.providerToolDefs, + ts.agent.MaxTokens, + ) + exec.messages = buildMessages(exec.history) if exec.gracefulTerminal { msgs := append([]providers.Message(nil), exec.messages...) exec.callMessages = append(msgs, ts.interruptHintMessage()) } + if dropped := originalHistoryCount - len(exec.history); dropped > 0 { + logger.WarnCF("agent", "Trimmed rebuilt history after context retry compaction", map[string]any{ + "session_key": ts.sessionKey, + "retry": retry, + "dropped_msgs": dropped, + "remaining_msgs": len(exec.history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + "still_overlimit": !fit, + }) + } else if !fit { + logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{ + "session_key": ts.sessionKey, + "retry": retry, + "history_msgs": len(exec.history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + }) + } continue } break diff --git a/pkg/agent/pipeline_setup.go b/pkg/agent/pipeline_setup.go index f6fed09de..cf03526b1 100644 --- a/pkg/agent/pipeline_setup.go +++ b/pkg/agent/pipeline_setup.go @@ -66,10 +66,38 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution history = resp.History summary = resp.Summary } - rebuildPromptReq := promptBuildRequestForTurn(ts, history, summary, ts.userMessage, ts.media) - rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) - messages = ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) - messages = resolveMediaRefs(messages, p.MediaStore, maxMediaSize) + originalHistoryCount := len(history) + var fit bool + history, messages, fit = trimHistoryToFitContextWindow( + history, + func(trimmedHistory []providers.Message) []providers.Message { + rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, summary, ts.userMessage, ts.media) + rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) + rebuilt := ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) + return resolveMediaRefs(rebuilt, p.MediaStore, maxMediaSize) + }, + ts.agent.ContextWindow, + toolDefs, + ts.agent.MaxTokens, + ) + if dropped := originalHistoryCount - len(history); dropped > 0 { + logger.WarnCF("agent", "Trimmed rebuilt history after proactive compaction", map[string]any{ + "session_key": ts.sessionKey, + "dropped_msgs": dropped, + "remaining_msgs": len(history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + "still_overlimit": !fit, + }) + } else if !fit { + logger.WarnCF("agent", "Context still exceeds budget "+ + "after proactive compaction rebuild", map[string]any{ + "session_key": ts.sessionKey, + "history_msgs": len(history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + }) + } } } diff --git a/pkg/seahorse/short_assembler.go b/pkg/seahorse/short_assembler.go index f0fd323ba..0bfd66a69 100644 --- a/pkg/seahorse/short_assembler.go +++ b/pkg/seahorse/short_assembler.go @@ -68,19 +68,24 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn freshTailTokens += r.tokenCount } - // Budget-aware selection of evictable items - remainingBudget := input.Budget - freshTailTokens - if remainingBudget < 0 { - // Fresh tail alone exceeds budget - we keep it anyway (design decision) - // Log for debugging retry/overflow issues - logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{ - "budget": input.Budget, - "fresh_tail_tokens": freshTailTokens, - "fresh_tail_count": len(freshTail), - "over_budget_by": freshTailTokens - input.Budget, + // If the protected tail alone exceeds budget, trim from the oldest end of + // the tail until the newest items fit within the requested budget. + if freshTailTokens > input.Budget { + originalTailCount := len(freshTail) + originalFreshTailTokens := freshTailTokens + for freshTailTokens > input.Budget && len(freshTail) > 0 { + freshTailTokens -= freshTail[0].tokenCount + freshTail = freshTail[1:] + } + logger.InfoCF("seahorse", "assemble: trimmed fresh tail to budget", map[string]any{ + "budget": input.Budget, + "fresh_tail_tokens": freshTailTokens, + "fresh_tail_count": len(freshTail), + "trimmed_fresh_items": originalTailCount - len(freshTail), + "original_fresh_tokens": originalFreshTailTokens, }) - remainingBudget = 0 } + remainingBudget := input.Budget - freshTailTokens var selected []resolvedItem evictableTokens := 0 diff --git a/pkg/seahorse/short_assembler_test.go b/pkg/seahorse/short_assembler_test.go index 88a05e64c..81918afa0 100644 --- a/pkg/seahorse/short_assembler_test.go +++ b/pkg/seahorse/short_assembler_test.go @@ -145,22 +145,29 @@ func TestAssemblerBudgetEvictsOldest(t *testing.T) { s.UpsertContextItems(ctx, convID, items) // Budget of 200 tokens with FreshTailCount=32 - // Fresh tail = last 32 messages (320 tokens, over budget, but always included) + // Fresh tail = last 32 messages (320 tokens, over budget) // Evictable = first 8 messages (80 tokens) - // Budget after tail: max(0, 200-320) = 0 → no evictable items included + // The oldest messages from the fresh tail should be dropped so only the + // newest 20 messages remain within the 200-token budget. a := &Assembler{store: s, config: Config{}} result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200}) if err != nil { t.Fatalf("Assemble: %v", err) } - // Should only include the 32-item fresh tail - if len(result.Messages) != 32 { - t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages)) + if len(result.Messages) != 20 { + t.Errorf("Messages = %d, want 20", len(result.Messages)) } - // Should be the LAST 32 messages - if result.Messages[0].ID != msgs[8].ID { - t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID) + if result.Messages[0].ID != msgs[20].ID { + t.Errorf("first message ID = %d, want %d (msgs[20])", result.Messages[0].ID, msgs[20].ID) + } + + totalTokens := 0 + for _, msg := range result.Messages { + totalTokens += msg.TokenCount + } + if totalTokens > 200 { + t.Errorf("assembled tokens = %d, want <= 200", totalTokens) } } From fe7ded5c138d62fd76c25bee9110333d08d4c2b2 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Tue, 19 May 2026 09:18:39 +0200 Subject: [PATCH 2/3] fix(agent): preserve active turn during context retry rebuild --- pkg/agent/pipeline_llm.go | 33 ++++++++--- pkg/agent/turn_state.go | 82 +++++++++++++++++++++++-- pkg/agent/turn_state_test.go | 112 +++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 13 deletions(-) create mode 100644 pkg/agent/turn_state_test.go diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index 7934b7815..9fbe9b740 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -369,15 +369,21 @@ func (p *Pipeline) CallLLM( contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills) } ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills) + stableHistory, protectedTurnTail := splitHistoryForActiveTurn( + exec.history, + ts.persistedMessagesSnapshot(), + ) buildMessages := func(trimmedHistory []providers.Message) []providers.Message { - rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, exec.summary, "", nil) + fullHistory := append(append([]providers.Message(nil), trimmedHistory...), protectedTurnTail...) + rebuildPromptReq := promptBuildRequestForTurn(ts, fullHistory, exec.summary, "", nil) rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) } originalHistoryCount := len(exec.history) var fit bool - exec.history, exec.callMessages, fit = trimHistoryToFitContextWindow( - exec.history, + var trimmedStableHistory []providers.Message + trimmedStableHistory, exec.callMessages, fit = trimHistoryToFitContextWindow( + stableHistory, func(trimmedHistory []providers.Message) []providers.Message { rebuilt := buildMessages(trimmedHistory) if exec.gracefulTerminal { @@ -389,7 +395,8 @@ func (p *Pipeline) CallLLM( exec.providerToolDefs, ts.agent.MaxTokens, ) - exec.messages = buildMessages(exec.history) + exec.history = append(trimmedStableHistory, protectedTurnTail...) + exec.messages = buildMessages(trimmedStableHistory) if exec.gracefulTerminal { msgs := append([]providers.Message(nil), exec.messages...) exec.callMessages = append(msgs, ts.interruptHintMessage()) @@ -406,13 +413,21 @@ func (p *Pipeline) CallLLM( }) } else if !fit { logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{ - "session_key": ts.sessionKey, - "retry": retry, - "history_msgs": len(exec.history), - "context_window": ts.agent.ContextWindow, - "max_tokens": ts.agent.MaxTokens, + "session_key": ts.sessionKey, + "retry": retry, + "history_msgs": len(exec.history), + "protected_turn_msgs": len(protectedTurnTail), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, }) } + if !fit { + err = fmt.Errorf( + "context window still exceeded after retry compaction; refusing to drop active turn messages: %w", + err, + ) + break + } continue } break diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index ae058e49d..2df0c81c2 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -634,13 +634,17 @@ func (ts *turnState) recordPersistedMessage(msg providers.Message) { ts.persistedMessages = append(ts.persistedMessages, msg) } +func (ts *turnState) persistedMessagesSnapshot() []providers.Message { + ts.mu.RLock() + defer ts.mu.RUnlock() + return append([]providers.Message(nil), ts.persistedMessages...) +} + func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) { history := agent.Sessions.GetHistory(ts.sessionKey) summary := agent.Sessions.GetSummary(ts.sessionKey) - ts.mu.RLock() - persisted := append([]providers.Message(nil), ts.persistedMessages...) - ts.mu.RUnlock() + persisted := ts.persistedMessagesSnapshot() if matched := matchingTurnMessageTail(history, persisted); matched > 0 { history = append([]providers.Message(nil), history[:len(history)-matched]...) @@ -680,13 +684,83 @@ func (ts *turnState) restoreSession(agent *AgentInstance) error { func matchingTurnMessageTail(history, persisted []providers.Message) int { maxMatch := min(len(history), len(persisted)) for size := maxMatch; size > 0; size-- { - if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) { + if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) { return size } } return 0 } +func splitHistoryForActiveTurn( + history []providers.Message, + persisted []providers.Message, +) ([]providers.Message, []providers.Message) { + matched := matchingTurnMessageTail(history, persisted) + if matched <= 0 { + return append([]providers.Message(nil), history...), nil + } + + stable := append([]providers.Message(nil), history[:len(history)-matched]...) + protected := append([]providers.Message(nil), history[len(history)-matched:]...) + return stable, protected +} + +func messageSlicesEquivalent(a, b []providers.Message) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !messagesEquivalent(a[i], b[i]) { + return false + } + } + return true +} + +func messagesEquivalent(a, b providers.Message) bool { + return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b)) +} + +func normalizeMessageForComparison(msg providers.Message) providers.Message { + msg.PromptLayer = "" + msg.PromptSlot = "" + msg.PromptSource = "" + + if len(msg.Media) == 0 { + msg.Media = nil + } + if len(msg.Attachments) == 0 { + msg.Attachments = nil + } + if len(msg.SystemParts) == 0 { + msg.SystemParts = nil + } else { + msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...) + for i := range msg.SystemParts { + msg.SystemParts[i].PromptLayer = "" + msg.SystemParts[i].PromptSlot = "" + msg.SystemParts[i].PromptSource = "" + } + } + if len(msg.ToolCalls) == 0 { + msg.ToolCalls = nil + } else { + msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...) + for i := range msg.ToolCalls { + msg.ToolCalls[i].Name = "" + msg.ToolCalls[i].Arguments = nil + msg.ToolCalls[i].ThoughtSignature = "" + if msg.ToolCalls[i].Function != nil { + fn := *msg.ToolCalls[i].Function + fn.ThoughtSignature = "" + msg.ToolCalls[i].Function = &fn + } + } + } + + return msg +} + func (ts *turnState) interruptHintMessage() providers.Message { _, hint := ts.gracefulInterruptRequested() content := "Interrupt requested. Stop scheduling tools and provide a short final summary." diff --git a/pkg/agent/turn_state_test.go b/pkg/agent/turn_state_test.go new file mode 100644 index 000000000..8ffd40dd5 --- /dev/null +++ b/pkg/agent/turn_state_test.go @@ -0,0 +1,112 @@ +package agent + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestMatchingTurnMessageTail_IgnoresInternalRuntimeFields(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: "question"}, + { + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + }, + }, + }, + }, + } + + persisted := []providers.Message{ + userPromptMessage("question", nil), + { + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "read_file", + Arguments: map[string]any{"path": "/tmp/test"}, + ThoughtSignature: "internal-signature", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + ThoughtSignature: "internal-signature", + }, + }, + }, + }, + } + + if got := matchingTurnMessageTail(history, persisted); got != 2 { + t.Fatalf("matchingTurnMessageTail() = %d, want 2", got) + } +} + +func TestSplitHistoryForActiveTurn_ProtectsPersistedTail(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: "old question"}, + {Role: "assistant", Content: "old answer"}, + {Role: "user", Content: "current question"}, + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + persisted := []providers.Message{ + userPromptMessage("current question", nil), + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + stable, protected := splitHistoryForActiveTurn(history, persisted) + if len(stable) != 2 { + t.Fatalf("stable history len = %d, want 2", len(stable)) + } + if len(protected) != 2 { + t.Fatalf("protected tail len = %d, want 2", len(protected)) + } + if protected[0].Content != "current question" { + t.Fatalf("protected[0].Content = %q, want current question", protected[0].Content) + } +} + +func TestTrimHistoryToFitContextWindow_WithProtectedTurnTailKeepsActiveTurn(t *testing.T) { + current := strings.Repeat("current turn ", 80) + history := []providers.Message{ + {Role: "user", Content: strings.Repeat("old turn ", 60)}, + {Role: "assistant", Content: strings.Repeat("old reply ", 60)}, + {Role: "user", Content: current}, + } + + stable, protected := splitHistoryForActiveTurn(history, []providers.Message{ + userPromptMessage(current, nil), + }) + trimmedStable, messages, fit := trimHistoryToFitContextWindow( + stable, + func(trimmedHistory []providers.Message) []providers.Message { + return append(append([]providers.Message(nil), trimmedHistory...), protected...) + }, + 120, + nil, + 0, + ) + + if fit { + t.Fatal("expected protected active turn alone to remain over budget") + } + if len(trimmedStable) != 0 { + t.Fatalf("trimmed stable history len = %d, want 0", len(trimmedStable)) + } + if len(messages) != 1 { + t.Fatalf("messages len = %d, want 1 protected active-turn message", len(messages)) + } + if messages[0].Content != current { + t.Fatalf("messages[0].Content = %q, want protected current turn", messages[0].Content) + } +} From f0dcba8c5a47bfa1ede3ad6dacbe25062d081a1f Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Wed, 20 May 2026 09:16:09 +0200 Subject: [PATCH 3/3] fix(seahorse): preserve active tool-call turn when trimming fresh tail --- pkg/agent/context_seahorse_test.go | 68 ++++++++++++++++++ pkg/seahorse/short_assembler.go | 100 ++++++++++++++++++++++++--- pkg/seahorse/short_assembler_test.go | 62 +++++++++++++++++ 3 files changed, 222 insertions(+), 8 deletions(-) diff --git a/pkg/agent/context_seahorse_test.go b/pkg/agent/context_seahorse_test.go index e405ef944..05c831835 100644 --- a/pkg/agent/context_seahorse_test.go +++ b/pkg/agent/context_seahorse_test.go @@ -280,6 +280,74 @@ func TestSeahorseToProviderMessagesWithToolCalls(t *testing.T) { } } +func TestSeahorseAssemblePreservesActiveToolTurnAcrossSanitization(t *testing.T) { + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/seahorse.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + + ctx := context.Background() + sessionKey := "test:active-tool-turn" + _, err = engine.Ingest(ctx, sessionKey, []seahorse.Message{ + { + Role: "assistant", + Content: "older context", + TokenCount: 20, + }, + { + Role: "user", + Content: "inspect the file", + TokenCount: 5, + }, + { + Role: "assistant", + TokenCount: 5, + Parts: []seahorse.MessagePart{{ + Type: "tool_use", + Name: "read_file", + Arguments: `{"path":"/tmp/test.txt"}`, + ToolCallID: "tc_1", + }}, + }, + { + Role: "tool", + TokenCount: 200, + Parts: []seahorse.MessagePart{{ + Type: "tool_result", + ToolCallID: "tc_1", + Text: "very large tool output", + }}, + }, + { + Role: "assistant", + Content: "done", + TokenCount: 5, + }, + }) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + + result, err := engine.Assemble(ctx, sessionKey, seahorse.AssembleInput{Budget: 210}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + sanitized := sanitizeHistoryForProvider(seahorseToProviderMessages(result)) + if len(sanitized) != 4 { + t.Fatalf("sanitized history len = %d, want 4 protected-turn messages", len(sanitized)) + } + assertRoles(t, sanitized, "user", "assistant", "tool", "assistant") + if len(sanitized[1].ToolCalls) != 1 || sanitized[1].ToolCalls[0].ID != "tc_1" { + t.Fatalf("assistant tool calls = %+v, want preserved tool call tc_1", sanitized[1].ToolCalls) + } + if sanitized[2].ToolCallID != "tc_1" { + t.Fatalf("tool result id = %q, want tc_1", sanitized[2].ToolCallID) + } +} + func TestSeahorseToProviderMessagesToolResult(t *testing.T) { msg := seahorse.Message{ Role: "tool", diff --git a/pkg/seahorse/short_assembler.go b/pkg/seahorse/short_assembler.go index 0bfd66a69..5533512a1 100644 --- a/pkg/seahorse/short_assembler.go +++ b/pkg/seahorse/short_assembler.go @@ -68,24 +68,33 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn freshTailTokens += r.tokenCount } - // If the protected tail alone exceeds budget, trim from the oldest end of - // the tail until the newest items fit within the requested budget. + // If the protected tail alone exceeds budget, trim from the oldest end at + // provider-safe boundaries. The rebuild path later sanitizes leading + // assistant(tool_calls)/tool messages, so splitting the active turn here can + // silently discard the very context we are trying to protect. if freshTailTokens > input.Budget { originalTailCount := len(freshTail) originalFreshTailTokens := freshTailTokens - for freshTailTokens > input.Budget && len(freshTail) > 0 { - freshTailTokens -= freshTail[0].tokenCount - freshTail = freshTail[1:] - } - logger.InfoCF("seahorse", "assemble: trimmed fresh tail to budget", map[string]any{ + var preservedActiveTurn bool + freshTail, freshTailTokens, preservedActiveTurn = trimFreshTailToSafeBudget(freshTail, input.Budget) + logFields := map[string]any{ "budget": input.Budget, "fresh_tail_tokens": freshTailTokens, "fresh_tail_count": len(freshTail), "trimmed_fresh_items": originalTailCount - len(freshTail), "original_fresh_tokens": originalFreshTailTokens, - }) + "preserved_active_turn": preservedActiveTurn, + } + if preservedActiveTurn { + logger.WarnCF("seahorse", "assemble: preserving active turn over budget", logFields) + } else { + logger.InfoCF("seahorse", "assemble: trimmed fresh tail to safe boundary", logFields) + } } remainingBudget := input.Budget - freshTailTokens + if remainingBudget < 0 { + remainingBudget = 0 + } var selected []resolvedItem evictableTokens := 0 @@ -189,6 +198,81 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn }, nil } +func trimFreshTailToSafeBudget(tail []resolvedItem, budget int) ([]resolvedItem, int, bool) { + tailTokens := resolvedItemsTokenCount(tail) + if tailTokens <= budget { + return tail, tailTokens, false + } + + latestTurnStart := lastUserMessageIndex(tail) + if latestTurnStart >= 0 { + latestTurnTokens := resolvedItemsTokenCount(tail[latestTurnStart:]) + if latestTurnTokens > budget { + return tail[latestTurnStart:], latestTurnTokens, true + } + } + + start := 0 + for tailTokens > budget && start < len(tail) { + tailTokens -= tail[start].tokenCount + start++ + } + for start < len(tail) && !isProviderSafeHistoryStart(tail[start:]) { + tailTokens -= tail[start].tokenCount + start++ + } + + return tail[start:], tailTokens, false +} + +func resolvedItemsTokenCount(items []resolvedItem) int { + total := 0 + for _, item := range items { + total += item.tokenCount + } + return total +} + +func lastUserMessageIndex(items []resolvedItem) int { + for i := len(items) - 1; i >= 0; i-- { + if items[i].itemType != "message" || items[i].message == nil { + continue + } + if items[i].message.Role == "user" { + return i + } + } + return -1 +} + +func isProviderSafeHistoryStart(items []resolvedItem) bool { + for _, item := range items { + if item.itemType != "message" || item.message == nil { + continue + } + if item.message.Role == "tool" { + return false + } + if item.message.Role == "assistant" && messageHasToolUse(item.message) { + return false + } + return true + } + return true +} + +func messageHasToolUse(msg *Message) bool { + if msg == nil { + return false + } + for _, part := range msg.Parts { + if part.Type == "tool_use" { + return true + } + } + return false +} + // resolveItem loads the full message or summary for a context item. func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) { if item.ItemType == "message" { diff --git a/pkg/seahorse/short_assembler_test.go b/pkg/seahorse/short_assembler_test.go index 81918afa0..6472bc368 100644 --- a/pkg/seahorse/short_assembler_test.go +++ b/pkg/seahorse/short_assembler_test.go @@ -171,6 +171,68 @@ func TestAssemblerBudgetEvictsOldest(t *testing.T) { } } +func TestAssemblerBudgetPreservesLatestToolTurnWhenItExceedsBudget(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + oldMsg, _ := s.AddMessage(ctx, convID, "assistant", "older context", 20) + userMsg, _ := s.AddMessage(ctx, convID, "user", "inspect the file", 5) + assistantToolMsg, _ := s.AddMessageWithParts(ctx, convID, "assistant", []MessagePart{ + { + Type: "tool_use", + Name: "read_file", + Arguments: `{"path":"/tmp/test.txt"}`, + ToolCallID: "tc_1", + }, + }, 5) + toolResultMsg, _ := s.AddMessageWithParts(ctx, convID, "tool", []MessagePart{ + { + Type: "tool_result", + ToolCallID: "tc_1", + Text: "very large tool output", + }, + }, 200) + finalAssistantMsg, _ := s.AddMessage(ctx, convID, "assistant", "done", 5) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: oldMsg.ID, TokenCount: 20}, + {Ordinal: 200, ItemType: "message", MessageID: userMsg.ID, TokenCount: 5}, + {Ordinal: 300, ItemType: "message", MessageID: assistantToolMsg.ID, TokenCount: 5}, + {Ordinal: 400, ItemType: "message", MessageID: toolResultMsg.ID, TokenCount: 200}, + {Ordinal: 500, ItemType: "message", MessageID: finalAssistantMsg.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 210}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + if len(result.Messages) != 4 { + t.Fatalf("Messages = %d, want 4 protected-turn messages", len(result.Messages)) + } + if result.Messages[0].ID != userMsg.ID { + t.Fatalf("first message ID = %d, want current user message %d", result.Messages[0].ID, userMsg.ID) + } + if result.Messages[1].ID != assistantToolMsg.ID { + t.Fatalf("second message ID = %d, want assistant tool-call %d", result.Messages[1].ID, assistantToolMsg.ID) + } + if result.Messages[2].ID != toolResultMsg.ID { + t.Fatalf("third message ID = %d, want tool result %d", result.Messages[2].ID, toolResultMsg.ID) + } + if result.Messages[3].ID != finalAssistantMsg.ID { + t.Fatalf("fourth message ID = %d, want final assistant %d", result.Messages[3].ID, finalAssistantMsg.ID) + } + + totalTokens := 0 + for _, msg := range result.Messages { + totalTokens += msg.TokenCount + } + if totalTokens <= 210 { + t.Fatalf("assembled tokens = %d, want protected turn to remain over budget", totalTokens) + } +} + func TestAssemblerBudgetFitsAll(t *testing.T) { s, convID := setupAssemblerStore(t) ctx := context.Background()