diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go index 72f80382a..1d99bf190 100644 --- a/pkg/agent/context_budget.go +++ b/pkg/agent/context_budget.go @@ -115,3 +115,55 @@ func isOverContextBudget( return total > contextWindow } + +// trimHistoryToFitContextWindow rebuilds the prompt from progressively newer +// history slices until it fits within the context window. Oldest complete turns +// are dropped first so tool-call sequences remain intact. +func trimHistoryToFitContextWindow( + history []providers.Message, + build func([]providers.Message) []providers.Message, + contextWindow int, + toolDefs []providers.ToolDefinition, + maxTokens int, +) ([]providers.Message, []providers.Message, bool) { + messages := build(history) + if !isOverContextBudget(contextWindow, messages, toolDefs, maxTokens) { + return history, messages, true + } + + trimmedHistory := append([]providers.Message(nil), history...) + for len(trimmedHistory) > 0 { + dropUntil := nextHistoryTrimStart(trimmedHistory) + if dropUntil <= 0 || dropUntil >= len(trimmedHistory) { + trimmedHistory = nil + } else { + trimmedHistory = append([]providers.Message(nil), trimmedHistory[dropUntil:]...) + } + + messages = build(trimmedHistory) + if !isOverContextBudget(contextWindow, messages, toolDefs, maxTokens) { + return trimmedHistory, messages, true + } + } + + return nil, messages, false +} + +func nextHistoryTrimStart(history []providers.Message) int { + if len(history) == 0 { + return 0 + } + + turns := parseTurnBoundaries(history) + if len(turns) >= 2 { + return turns[1] + } + if len(turns) == 1 { + if turns[0] > 0 { + return turns[0] + } + return len(history) + } + + return len(history) +} diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go index 9de1707ec..d7ca2a665 100644 --- a/pkg/agent/context_budget_test.go +++ b/pkg/agent/context_budget_test.go @@ -844,3 +844,64 @@ func TestIsOverContextBudget_RealisticSession(t *testing.T) { t.Error("realistic session should exceed 500 context window") } } + +func TestTrimHistoryToFitContextWindow_DropsOldestTurns(t *testing.T) { + history := []providers.Message{ + msgUser(strings.Repeat("u1 ", 120)), + msgAssistant(strings.Repeat("a1 ", 120)), + msgUser(strings.Repeat("u2 ", 120)), + msgAssistant(strings.Repeat("a2 ", 120)), + msgUser(strings.Repeat("u3 ", 120)), + msgAssistant(strings.Repeat("a3 ", 120)), + } + + build := func(history []providers.Message) []providers.Message { + return append([]providers.Message(nil), history...) + } + + trimmedHistory, messages, fit := trimHistoryToFitContextWindow( + history, + build, + 700, + nil, + 0, + ) + if !fit { + t.Fatal("expected trimmed history to fit context window") + } + if len(trimmedHistory) != 4 { + t.Fatalf("trimmed history len = %d, want 4", len(trimmedHistory)) + } + if trimmedHistory[0].Content != history[2].Content { + t.Fatalf("first kept message = %q, want second turn start", trimmedHistory[0].Content) + } + if isOverContextBudget(700, messages, nil, 0) { + t.Fatal("trimmed messages should be within budget") + } +} + +func TestTrimHistoryToFitContextWindow_ClearsSingleOversizedTurn(t *testing.T) { + history := []providers.Message{ + msgUser(strings.Repeat("oversized ", 200)), + msgAssistant(strings.Repeat("oversized ", 200)), + } + + trimmedHistory, messages, fit := trimHistoryToFitContextWindow( + history, + func(history []providers.Message) []providers.Message { + return append([]providers.Message(nil), history...) + }, + 200, + nil, + 0, + ) + if !fit { + t.Fatal("expected empty history rebuild to fit context window") + } + if len(trimmedHistory) != 0 { + t.Fatalf("trimmed history len = %d, want 0", len(trimmedHistory)) + } + if len(messages) != 0 { + t.Fatalf("messages len = %d, want 0", len(messages)) + } +} diff --git a/pkg/agent/context_seahorse_test.go b/pkg/agent/context_seahorse_test.go index 2a9de3263..101f72ee2 100644 --- a/pkg/agent/context_seahorse_test.go +++ b/pkg/agent/context_seahorse_test.go @@ -288,6 +288,74 @@ func TestSeahorseToProviderMessagesWithToolCalls(t *testing.T) { } } +func TestSeahorseAssemblePreservesActiveToolTurnAcrossSanitization(t *testing.T) { + engine, err := seahorse.NewEngine(seahorse.Config{ + DBPath: t.TempDir() + "/seahorse.db", + }, nil) + if err != nil { + t.Fatalf("NewEngine: %v", err) + } + + ctx := context.Background() + sessionKey := "test:active-tool-turn" + _, err = engine.Ingest(ctx, sessionKey, []seahorse.Message{ + { + Role: "assistant", + Content: "older context", + TokenCount: 20, + }, + { + Role: "user", + Content: "inspect the file", + TokenCount: 5, + }, + { + Role: "assistant", + TokenCount: 5, + Parts: []seahorse.MessagePart{{ + Type: "tool_use", + Name: "read_file", + Arguments: `{"path":"/tmp/test.txt"}`, + ToolCallID: "tc_1", + }}, + }, + { + Role: "tool", + TokenCount: 200, + Parts: []seahorse.MessagePart{{ + Type: "tool_result", + ToolCallID: "tc_1", + Text: "very large tool output", + }}, + }, + { + Role: "assistant", + Content: "done", + TokenCount: 5, + }, + }) + if err != nil { + t.Fatalf("Ingest: %v", err) + } + + result, err := engine.Assemble(ctx, sessionKey, seahorse.AssembleInput{Budget: 210}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + sanitized := sanitizeHistoryForProvider(seahorseToProviderMessages(result)) + if len(sanitized) != 4 { + t.Fatalf("sanitized history len = %d, want 4 protected-turn messages", len(sanitized)) + } + assertRoles(t, sanitized, "user", "assistant", "tool", "assistant") + if len(sanitized[1].ToolCalls) != 1 || sanitized[1].ToolCalls[0].ID != "tc_1" { + t.Fatalf("assistant tool calls = %+v, want preserved tool call tc_1", sanitized[1].ToolCalls) + } + if sanitized[2].ToolCallID != "tc_1" { + t.Fatalf("tool result id = %q, want tc_1", sanitized[2].ToolCallID) + } +} + func TestSeahorseToProviderMessagesToolResult(t *testing.T) { msg := seahorse.Message{ Role: "tool", diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index 0a31048a8..3dda03b3e 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -415,14 +415,65 @@ func (p *Pipeline) CallLLM( contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills) } ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills) - rebuildPromptReq := promptBuildRequestForTurn(ts, exec.history, exec.summary, "", nil, p.Cfg) - rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) - exec.messages = ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) - exec.callMessages = exec.messages + stableHistory, protectedTurnTail := splitHistoryForActiveTurn( + exec.history, + ts.persistedMessagesSnapshot(), + ) + buildMessages := func(trimmedHistory []providers.Message) []providers.Message { + fullHistory := append(append([]providers.Message(nil), trimmedHistory...), protectedTurnTail...) + rebuildPromptReq := promptBuildRequestForTurn(ts, fullHistory, exec.summary, "", nil, p.Cfg) + rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) + return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) + } + originalHistoryCount := len(exec.history) + var fit bool + var trimmedStableHistory []providers.Message + trimmedStableHistory, exec.callMessages, fit = trimHistoryToFitContextWindow( + stableHistory, + func(trimmedHistory []providers.Message) []providers.Message { + rebuilt := buildMessages(trimmedHistory) + if exec.gracefulTerminal { + return append(append([]providers.Message(nil), rebuilt...), ts.interruptHintMessage()) + } + return rebuilt + }, + ts.agent.ContextWindow, + exec.providerToolDefs, + ts.agent.MaxTokens, + ) + exec.history = append(trimmedStableHistory, protectedTurnTail...) + exec.messages = buildMessages(trimmedStableHistory) if exec.gracefulTerminal { msgs := append([]providers.Message(nil), exec.messages...) exec.callMessages = append(msgs, ts.interruptHintMessage()) } + if dropped := originalHistoryCount - len(exec.history); dropped > 0 { + logger.WarnCF("agent", "Trimmed rebuilt history after context retry compaction", map[string]any{ + "session_key": ts.sessionKey, + "retry": retry, + "dropped_msgs": dropped, + "remaining_msgs": len(exec.history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + "still_overlimit": !fit, + }) + } else if !fit { + logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{ + "session_key": ts.sessionKey, + "retry": retry, + "history_msgs": len(exec.history), + "protected_turn_msgs": len(protectedTurnTail), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + }) + } + if !fit { + err = fmt.Errorf( + "context window still exceeded after retry compaction; refusing to drop active turn messages: %w", + err, + ) + break + } continue } break diff --git a/pkg/agent/pipeline_setup.go b/pkg/agent/pipeline_setup.go index f764959e9..9d5033cef 100644 --- a/pkg/agent/pipeline_setup.go +++ b/pkg/agent/pipeline_setup.go @@ -66,10 +66,38 @@ func (p *Pipeline) SetupTurn(ctx context.Context, ts *turnState) (*turnExecution history = resp.History summary = resp.Summary } - rebuildPromptReq := promptBuildRequestForTurn(ts, history, summary, ts.userMessage, ts.media, cfg) - rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) - messages = ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) - messages = resolveMediaRefs(messages, p.MediaStore, maxMediaSize) + originalHistoryCount := len(history) + var fit bool + history, messages, fit = trimHistoryToFitContextWindow( + history, + func(trimmedHistory []providers.Message) []providers.Message { + rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, summary, ts.userMessage, ts.media, cfg) + rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) + rebuilt := ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) + return resolveMediaRefs(rebuilt, p.MediaStore, maxMediaSize) + }, + ts.agent.ContextWindow, + toolDefs, + ts.agent.MaxTokens, + ) + if dropped := originalHistoryCount - len(history); dropped > 0 { + logger.WarnCF("agent", "Trimmed rebuilt history after proactive compaction", map[string]any{ + "session_key": ts.sessionKey, + "dropped_msgs": dropped, + "remaining_msgs": len(history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + "still_overlimit": !fit, + }) + } else if !fit { + logger.WarnCF("agent", "Context still exceeds budget "+ + "after proactive compaction rebuild", map[string]any{ + "session_key": ts.sessionKey, + "history_msgs": len(history), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, + }) + } } } diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index 4bc8f2ee4..5fc4dd40b 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -643,13 +643,17 @@ func (ts *turnState) recordPersistedMessage(msg providers.Message) { ts.persistedMessages = append(ts.persistedMessages, msg) } +func (ts *turnState) persistedMessagesSnapshot() []providers.Message { + ts.mu.RLock() + defer ts.mu.RUnlock() + return append([]providers.Message(nil), ts.persistedMessages...) +} + func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) { history := agent.Sessions.GetHistory(ts.sessionKey) summary := agent.Sessions.GetSummary(ts.sessionKey) - ts.mu.RLock() - persisted := append([]providers.Message(nil), ts.persistedMessages...) - ts.mu.RUnlock() + persisted := ts.persistedMessagesSnapshot() if matched := matchingTurnMessageTail(history, persisted); matched > 0 { history = append([]providers.Message(nil), history[:len(history)-matched]...) @@ -686,29 +690,84 @@ func (ts *turnState) restoreSession(agent *AgentInstance) error { return agent.Sessions.Save(ts.sessionKey) } -// messagesContentEqual compares two message slices by content only, ignoring CreatedAt. -// JSON roundtrip loses the monotonic clock portion of time.Time, so direct -// reflect.DeepEqual would always differ on messages that roundtripped through -// the JSONL store. -func messagesContentEqual(a, b []providers.Message) bool { +func matchingTurnMessageTail(history, persisted []providers.Message) int { + maxMatch := min(len(history), len(persisted)) + for size := maxMatch; size > 0; size-- { + if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) { + return size + } + } + return 0 +} + +func splitHistoryForActiveTurn( + history []providers.Message, + persisted []providers.Message, +) ([]providers.Message, []providers.Message) { + matched := matchingTurnMessageTail(history, persisted) + if matched <= 0 { + return append([]providers.Message(nil), history...), nil + } + + stable := append([]providers.Message(nil), history[:len(history)-matched]...) + protected := append([]providers.Message(nil), history[len(history)-matched:]...) + return stable, protected +} + +func messageSlicesEquivalent(a, b []providers.Message) bool { + if len(a) != len(b) { + return false + } for i := range a { - aCopy, bCopy := a[i], b[i] - aCopy.CreatedAt, bCopy.CreatedAt = nil, nil - if !reflect.DeepEqual(aCopy, bCopy) { + if !messagesEquivalent(a[i], b[i]) { return false } } return true } -func matchingTurnMessageTail(history, persisted []providers.Message) int { - maxMatch := min(len(history), len(persisted)) - for size := maxMatch; size > 0; size-- { - if messagesContentEqual(history[len(history)-size:], persisted[len(persisted)-size:]) { - return size +func messagesEquivalent(a, b providers.Message) bool { + return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b)) +} + +func normalizeMessageForComparison(msg providers.Message) providers.Message { + msg.PromptLayer = "" + msg.PromptSlot = "" + msg.PromptSource = "" + + if len(msg.Media) == 0 { + msg.Media = nil + } + if len(msg.Attachments) == 0 { + msg.Attachments = nil + } + if len(msg.SystemParts) == 0 { + msg.SystemParts = nil + } else { + msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...) + for i := range msg.SystemParts { + msg.SystemParts[i].PromptLayer = "" + msg.SystemParts[i].PromptSlot = "" + msg.SystemParts[i].PromptSource = "" } } - return 0 + if len(msg.ToolCalls) == 0 { + msg.ToolCalls = nil + } else { + msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...) + for i := range msg.ToolCalls { + msg.ToolCalls[i].Name = "" + msg.ToolCalls[i].Arguments = nil + msg.ToolCalls[i].ThoughtSignature = "" + if msg.ToolCalls[i].Function != nil { + fn := *msg.ToolCalls[i].Function + fn.ThoughtSignature = "" + msg.ToolCalls[i].Function = &fn + } + } + } + + return msg } func (ts *turnState) interruptHintMessage() providers.Message { diff --git a/pkg/agent/turn_state_test.go b/pkg/agent/turn_state_test.go new file mode 100644 index 000000000..8ffd40dd5 --- /dev/null +++ b/pkg/agent/turn_state_test.go @@ -0,0 +1,112 @@ +package agent + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestMatchingTurnMessageTail_IgnoresInternalRuntimeFields(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: "question"}, + { + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + }, + }, + }, + }, + } + + persisted := []providers.Message{ + userPromptMessage("question", nil), + { + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "read_file", + Arguments: map[string]any{"path": "/tmp/test"}, + ThoughtSignature: "internal-signature", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + ThoughtSignature: "internal-signature", + }, + }, + }, + }, + } + + if got := matchingTurnMessageTail(history, persisted); got != 2 { + t.Fatalf("matchingTurnMessageTail() = %d, want 2", got) + } +} + +func TestSplitHistoryForActiveTurn_ProtectsPersistedTail(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: "old question"}, + {Role: "assistant", Content: "old answer"}, + {Role: "user", Content: "current question"}, + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + persisted := []providers.Message{ + userPromptMessage("current question", nil), + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + stable, protected := splitHistoryForActiveTurn(history, persisted) + if len(stable) != 2 { + t.Fatalf("stable history len = %d, want 2", len(stable)) + } + if len(protected) != 2 { + t.Fatalf("protected tail len = %d, want 2", len(protected)) + } + if protected[0].Content != "current question" { + t.Fatalf("protected[0].Content = %q, want current question", protected[0].Content) + } +} + +func TestTrimHistoryToFitContextWindow_WithProtectedTurnTailKeepsActiveTurn(t *testing.T) { + current := strings.Repeat("current turn ", 80) + history := []providers.Message{ + {Role: "user", Content: strings.Repeat("old turn ", 60)}, + {Role: "assistant", Content: strings.Repeat("old reply ", 60)}, + {Role: "user", Content: current}, + } + + stable, protected := splitHistoryForActiveTurn(history, []providers.Message{ + userPromptMessage(current, nil), + }) + trimmedStable, messages, fit := trimHistoryToFitContextWindow( + stable, + func(trimmedHistory []providers.Message) []providers.Message { + return append(append([]providers.Message(nil), trimmedHistory...), protected...) + }, + 120, + nil, + 0, + ) + + if fit { + t.Fatal("expected protected active turn alone to remain over budget") + } + if len(trimmedStable) != 0 { + t.Fatalf("trimmed stable history len = %d, want 0", len(trimmedStable)) + } + if len(messages) != 1 { + t.Fatalf("messages len = %d, want 1 protected active-turn message", len(messages)) + } + if messages[0].Content != current { + t.Fatalf("messages[0].Content = %q, want protected current turn", messages[0].Content) + } +} diff --git a/pkg/seahorse/short_assembler.go b/pkg/seahorse/short_assembler.go index f0fd323ba..5533512a1 100644 --- a/pkg/seahorse/short_assembler.go +++ b/pkg/seahorse/short_assembler.go @@ -68,17 +68,31 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn freshTailTokens += r.tokenCount } - // Budget-aware selection of evictable items + // If the protected tail alone exceeds budget, trim from the oldest end at + // provider-safe boundaries. The rebuild path later sanitizes leading + // assistant(tool_calls)/tool messages, so splitting the active turn here can + // silently discard the very context we are trying to protect. + if freshTailTokens > input.Budget { + originalTailCount := len(freshTail) + originalFreshTailTokens := freshTailTokens + var preservedActiveTurn bool + freshTail, freshTailTokens, preservedActiveTurn = trimFreshTailToSafeBudget(freshTail, input.Budget) + logFields := map[string]any{ + "budget": input.Budget, + "fresh_tail_tokens": freshTailTokens, + "fresh_tail_count": len(freshTail), + "trimmed_fresh_items": originalTailCount - len(freshTail), + "original_fresh_tokens": originalFreshTailTokens, + "preserved_active_turn": preservedActiveTurn, + } + if preservedActiveTurn { + logger.WarnCF("seahorse", "assemble: preserving active turn over budget", logFields) + } else { + logger.InfoCF("seahorse", "assemble: trimmed fresh tail to safe boundary", logFields) + } + } remainingBudget := input.Budget - freshTailTokens if remainingBudget < 0 { - // Fresh tail alone exceeds budget - we keep it anyway (design decision) - // Log for debugging retry/overflow issues - logger.InfoCF("seahorse", "assemble: fresh tail exceeds budget", map[string]any{ - "budget": input.Budget, - "fresh_tail_tokens": freshTailTokens, - "fresh_tail_count": len(freshTail), - "over_budget_by": freshTailTokens - input.Budget, - }) remainingBudget = 0 } @@ -184,6 +198,81 @@ func (a *Assembler) Assemble(ctx context.Context, convID int64, input AssembleIn }, nil } +func trimFreshTailToSafeBudget(tail []resolvedItem, budget int) ([]resolvedItem, int, bool) { + tailTokens := resolvedItemsTokenCount(tail) + if tailTokens <= budget { + return tail, tailTokens, false + } + + latestTurnStart := lastUserMessageIndex(tail) + if latestTurnStart >= 0 { + latestTurnTokens := resolvedItemsTokenCount(tail[latestTurnStart:]) + if latestTurnTokens > budget { + return tail[latestTurnStart:], latestTurnTokens, true + } + } + + start := 0 + for tailTokens > budget && start < len(tail) { + tailTokens -= tail[start].tokenCount + start++ + } + for start < len(tail) && !isProviderSafeHistoryStart(tail[start:]) { + tailTokens -= tail[start].tokenCount + start++ + } + + return tail[start:], tailTokens, false +} + +func resolvedItemsTokenCount(items []resolvedItem) int { + total := 0 + for _, item := range items { + total += item.tokenCount + } + return total +} + +func lastUserMessageIndex(items []resolvedItem) int { + for i := len(items) - 1; i >= 0; i-- { + if items[i].itemType != "message" || items[i].message == nil { + continue + } + if items[i].message.Role == "user" { + return i + } + } + return -1 +} + +func isProviderSafeHistoryStart(items []resolvedItem) bool { + for _, item := range items { + if item.itemType != "message" || item.message == nil { + continue + } + if item.message.Role == "tool" { + return false + } + if item.message.Role == "assistant" && messageHasToolUse(item.message) { + return false + } + return true + } + return true +} + +func messageHasToolUse(msg *Message) bool { + if msg == nil { + return false + } + for _, part := range msg.Parts { + if part.Type == "tool_use" { + return true + } + } + return false +} + // resolveItem loads the full message or summary for a context item. func (a *Assembler) resolveItem(ctx context.Context, item ContextItem) (resolvedItem, error) { if item.ItemType == "message" { diff --git a/pkg/seahorse/short_assembler_test.go b/pkg/seahorse/short_assembler_test.go index 88a05e64c..6472bc368 100644 --- a/pkg/seahorse/short_assembler_test.go +++ b/pkg/seahorse/short_assembler_test.go @@ -145,22 +145,91 @@ func TestAssemblerBudgetEvictsOldest(t *testing.T) { s.UpsertContextItems(ctx, convID, items) // Budget of 200 tokens with FreshTailCount=32 - // Fresh tail = last 32 messages (320 tokens, over budget, but always included) + // Fresh tail = last 32 messages (320 tokens, over budget) // Evictable = first 8 messages (80 tokens) - // Budget after tail: max(0, 200-320) = 0 → no evictable items included + // The oldest messages from the fresh tail should be dropped so only the + // newest 20 messages remain within the 200-token budget. a := &Assembler{store: s, config: Config{}} result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 200}) if err != nil { t.Fatalf("Assemble: %v", err) } - // Should only include the 32-item fresh tail - if len(result.Messages) != 32 { - t.Errorf("Messages = %d, want 32 (fresh tail)", len(result.Messages)) + if len(result.Messages) != 20 { + t.Errorf("Messages = %d, want 20", len(result.Messages)) } - // Should be the LAST 32 messages - if result.Messages[0].ID != msgs[8].ID { - t.Errorf("first message ID = %d, want %d (msgs[8])", result.Messages[0].ID, msgs[8].ID) + if result.Messages[0].ID != msgs[20].ID { + t.Errorf("first message ID = %d, want %d (msgs[20])", result.Messages[0].ID, msgs[20].ID) + } + + totalTokens := 0 + for _, msg := range result.Messages { + totalTokens += msg.TokenCount + } + if totalTokens > 200 { + t.Errorf("assembled tokens = %d, want <= 200", totalTokens) + } +} + +func TestAssemblerBudgetPreservesLatestToolTurnWhenItExceedsBudget(t *testing.T) { + s, convID := setupAssemblerStore(t) + ctx := context.Background() + + oldMsg, _ := s.AddMessage(ctx, convID, "assistant", "older context", 20) + userMsg, _ := s.AddMessage(ctx, convID, "user", "inspect the file", 5) + assistantToolMsg, _ := s.AddMessageWithParts(ctx, convID, "assistant", []MessagePart{ + { + Type: "tool_use", + Name: "read_file", + Arguments: `{"path":"/tmp/test.txt"}`, + ToolCallID: "tc_1", + }, + }, 5) + toolResultMsg, _ := s.AddMessageWithParts(ctx, convID, "tool", []MessagePart{ + { + Type: "tool_result", + ToolCallID: "tc_1", + Text: "very large tool output", + }, + }, 200) + finalAssistantMsg, _ := s.AddMessage(ctx, convID, "assistant", "done", 5) + + s.UpsertContextItems(ctx, convID, []ContextItem{ + {Ordinal: 100, ItemType: "message", MessageID: oldMsg.ID, TokenCount: 20}, + {Ordinal: 200, ItemType: "message", MessageID: userMsg.ID, TokenCount: 5}, + {Ordinal: 300, ItemType: "message", MessageID: assistantToolMsg.ID, TokenCount: 5}, + {Ordinal: 400, ItemType: "message", MessageID: toolResultMsg.ID, TokenCount: 200}, + {Ordinal: 500, ItemType: "message", MessageID: finalAssistantMsg.ID, TokenCount: 5}, + }) + + a := &Assembler{store: s, config: Config{}} + result, err := a.Assemble(ctx, convID, AssembleInput{Budget: 210}) + if err != nil { + t.Fatalf("Assemble: %v", err) + } + + if len(result.Messages) != 4 { + t.Fatalf("Messages = %d, want 4 protected-turn messages", len(result.Messages)) + } + if result.Messages[0].ID != userMsg.ID { + t.Fatalf("first message ID = %d, want current user message %d", result.Messages[0].ID, userMsg.ID) + } + if result.Messages[1].ID != assistantToolMsg.ID { + t.Fatalf("second message ID = %d, want assistant tool-call %d", result.Messages[1].ID, assistantToolMsg.ID) + } + if result.Messages[2].ID != toolResultMsg.ID { + t.Fatalf("third message ID = %d, want tool result %d", result.Messages[2].ID, toolResultMsg.ID) + } + if result.Messages[3].ID != finalAssistantMsg.ID { + t.Fatalf("fourth message ID = %d, want final assistant %d", result.Messages[3].ID, finalAssistantMsg.ID) + } + + totalTokens := 0 + for _, msg := range result.Messages { + totalTokens += msg.TokenCount + } + if totalTokens <= 210 { + t.Fatalf("assembled tokens = %d, want protected turn to remain over budget", totalTokens) } }