From c7ea018a73dae733017ab71a0389c86c6e17725b Mon Sep 17 00:00:00 2001 From: Administrator <1280842908@qq.com> Date: Wed, 18 Mar 2026 12:18:32 +0800 Subject: [PATCH] fix(agent): prevent duplicate history during subturn context recoveries Problem: During subturn context limit or truncation recoveries, the recovery loops repeatedly called `runAgentLoop` with the same or modified `UserMessage`. Because `runAgentLoop` unconditionally adds the `UserMessage` to the session history, this resulted in: 1. Duplicate User Messages polluting the history upon `context_length_exceeded` retries. 2. The possibility of injecting empty User Messages if `opts.UserMessage` was artificially blanked out to work around the duplication. 3. Messy or duplicate entries during `finish_reason="truncated"` recovery injections. Solution: - Introduce `SkipAddUserMessage` boolean to `processOptions` to explicitly control whether the agent loop should write the user prompt to history. - Add an explicit `opts.UserMessage != ""` check in `runAgentLoop` to prevent polluting history with empty message content. - In `subturn.go`'s recovery loop, set `SkipAddUserMessage: contextRetryCount > 0` to skip writing the user message on context --- pkg/agent/loop.go | 14 +- pkg/agent/subturn.go | 181 ++++++++++++- pkg/agent/turn_state.go | 19 ++ pkg/providers/common/common.go | 11 +- pkg/utils/context.go | 173 +++++++++++++ pkg/utils/context_test.go | 450 +++++++++++++++++++++++++++++++++ 6 files changed, 834 insertions(+), 14 deletions(-) create mode 100644 pkg/utils/context.go create mode 100644 pkg/utils/context_test.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index b4a7774c3..d9f9e6371 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -49,8 +49,8 @@ type AgentLoop struct { cmdRegistry *commands.Registry mcp mcpRuntime steering *steeringQueue - subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult - activeTurnStates sync.Map // key: sessionKey (string), value: *turnState + subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult + activeTurnStates sync.Map // key: sessionKey (string), value: *turnState subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs mu sync.RWMutex // Track active requests for safe provider cleanup @@ -69,6 +69,7 @@ type processOptions struct { SendResponse bool // Whether to send response via bus NoHistory bool // If true, don't load session history (for heartbeat) SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue) + SkipAddUserMessage bool // If true, skip adding UserMessage to session history } const ( @@ -1051,7 +1052,9 @@ func (al *AgentLoop) runAgentLoop( messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize) // 2. Save user message to session - agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + if !opts.SkipAddUserMessage && opts.UserMessage != "" { + agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + } // 3. Run LLM iteration loop finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts) @@ -1403,6 +1406,11 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } + // Save finishReason to turnState for SubTurn truncation detection + if ts := turnStateFromContext(ctx); ts != nil { + ts.SetLastFinishReason(response.FinishReason) + } + go al.handleReasoning( ctx, response.Reasoning, diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 4dfed42a0..3c178d9fc 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -4,11 +4,13 @@ import ( "context" "errors" "fmt" + "strings" "time" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/utils" ) // ====================== Config & Constants ====================== @@ -104,6 +106,19 @@ type SubTurnConfig struct { // Default is 5 minutes (defaultSubTurnTimeout) if not specified. Timeout time.Duration + // MaxContextRunes limits the context size (in runes) passed to the SubTurn. + // This prevents context window overflow by truncating message history before LLM calls. + // + // Values: + // 0 = Auto-calculate based on model's ContextWindow * 0.75 (default, recommended) + // -1 = No limit (disable soft truncation, rely only on hard context errors) + // >0 = Use specified rune limit + // + // The soft limit acts as a first line of defense before hitting the provider's + // hard context window limit. When exceeded, older messages are intelligently + // truncated while preserving system messages and recent context. + MaxContextRunes int + // Can be extended with temperature, topP, etc. } @@ -377,6 +392,25 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too // runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to // the real agent loop. The child's ephemeral session is used for history so it // never pollutes the parent session. +// +// This function implements multiple layers of context protection and error recovery: +// +// 1. Soft Context Limit (MaxContextRunes): +// - Proactively truncates message history before LLM calls +// - Default: 75% of model's context window +// - Preserves system messages and recent context +// - First line of defense against context overflow +// +// 2. Hard Context Error Recovery: +// - Detects context_length_exceeded errors from provider +// - Triggers force compression and retries (up to 2 times) +// - Second line of defense when soft limit is insufficient +// +// 3. Truncation Recovery: +// - Detects when LLM response is truncated (finish_reason="truncated") +// - Injects recovery prompt asking for shorter response +// - Retries up to 2 times +// - Handles cases where max_tokens is hit func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfig) (*tools.ToolResult, error) { // Derive candidates from the requested model using the parent loop's provider. defaultProvider := al.GetConfig().Agents.Defaults.Provider @@ -420,17 +454,144 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi childAgent.MaxTokens = parentAgent.MaxTokens } - finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{ - SessionKey: ts.turnID, - UserMessage: cfg.SystemPrompt, - DefaultResponse: "", - EnableSummary: false, - SendResponse: false, - }) - if err != nil { - return nil, err + // Resolve MaxContextRunes configuration + maxContextRunes := utils.ResolveMaxContextRunes(cfg.MaxContextRunes, childAgent.ContextWindow) + + logger.DebugCF("subturn", "Context limit resolved", + map[string]any{ + "turn_id": ts.turnID, + "context_window": childAgent.ContextWindow, + "max_context_runes": maxContextRunes, + "configured_value": cfg.MaxContextRunes, + }) + + // Retry loop for truncation and context errors + const ( + maxTruncationRetries = 2 + maxContextRetries = 2 + ) + + truncationRetryCount := 0 + contextRetryCount := 0 + currentPrompt := cfg.SystemPrompt + + for { + // Soft context limit: check and truncate before LLM call + if maxContextRunes > 0 { + messages := childAgent.Sessions.GetHistory(ts.turnID) + currentRunes := utils.MeasureContextRunes(messages) + + if currentRunes > maxContextRunes { + logger.WarnCF("subturn", "Context exceeds soft limit, truncating", + map[string]any{ + "turn_id": ts.turnID, + "current_runes": currentRunes, + "max_runes": maxContextRunes, + "overflow": currentRunes - maxContextRunes, + }) + + truncatedMessages := utils.TruncateContextSmart(messages, maxContextRunes) + childAgent.Sessions.SetHistory(ts.turnID, truncatedMessages) + + // Log truncation result + newRunes := utils.MeasureContextRunes(truncatedMessages) + logger.InfoCF("subturn", "Context truncated successfully", + map[string]any{ + "turn_id": ts.turnID, + "before_runes": currentRunes, + "after_runes": newRunes, + "saved_runes": currentRunes - newRunes, + }) + } + } + + // Call the agent loop + finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{ + SessionKey: ts.turnID, + UserMessage: currentPrompt, + DefaultResponse: "", + EnableSummary: false, + SendResponse: false, + SkipAddUserMessage: contextRetryCount > 0, + }) + + // 1. Handle context length errors + if err != nil && isContextLengthError(err) { + if contextRetryCount >= maxContextRetries { + logger.ErrorCF("subturn", "Context limit exceeded after max retries", + map[string]any{ + "turn_id": ts.turnID, + "retries": contextRetryCount, + "max_retries": maxContextRetries, + }) + return nil, fmt.Errorf("context limit exceeded after %d retries: %w", maxContextRetries, err) + } + + logger.WarnCF("subturn", "Context length exceeded, compressing and retrying", + map[string]any{ + "turn_id": ts.turnID, + "retry": contextRetryCount + 1, + }) + + // Trigger force compression + al.forceCompression(childAgent, ts.turnID) + + contextRetryCount++ + continue // Retry with compressed history + } + + if err != nil { + return nil, err // Other errors, return immediately + } + + // 2. Check for truncation (retrieve finishReason from turnState) + finishReason := ts.GetLastFinishReason() + + if finishReason == "truncated" && truncationRetryCount < maxTruncationRetries { + logger.WarnCF("subturn", "Response truncated, injecting recovery message", + map[string]any{ + "turn_id": ts.turnID, + "retry": truncationRetryCount + 1, + }) + + // IMPORTANT: Do NOT manually add messages to history here. + // runAgentLoop has already saved both the assistant message (finalContent) + // and will save the next user message (currentPrompt) on the next iteration. + // Manually adding them would cause duplicates. + + // Inject recovery prompt - it will be added by runAgentLoop on next iteration + recoveryPrompt := "Your previous response was truncated due to length. Please provide a shorter, complete response that finishes your thought." + currentPrompt = recoveryPrompt + + truncationRetryCount++ + continue // Retry with recovery prompt + } + + // 3. Success - return result + return &tools.ToolResult{ForLLM: finalContent}, nil } - return &tools.ToolResult{ForLLM: finalContent}, nil +} + +// isContextLengthError checks if the error is due to context length exceeded. +// It excludes timeout errors to avoid false positives. +func isContextLengthError(err error) bool { + if err == nil { + return false + } + errMsg := strings.ToLower(err.Error()) + + // Exclude timeout errors + if strings.Contains(errMsg, "timeout") || strings.Contains(errMsg, "deadline exceeded") { + return false + } + + // Detect context error patterns + return strings.Contains(errMsg, "context_length_exceeded") || + strings.Contains(errMsg, "maximum context length") || + strings.Contains(errMsg, "context window") || + strings.Contains(errMsg, "too many tokens") || + strings.Contains(errMsg, "token limit") || + strings.Contains(errMsg, "prompt is too long") } // ====================== Other Types ====================== diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index 2ca078017..e4bca4f15 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -55,6 +55,11 @@ type turnState struct { // This allows child SubTurns to check if the parent has ended. // Nil for root turns. parentTurnState *turnState + + // lastFinishReason stores the finish_reason from the last LLM call. + // Used by SubTurn to detect truncation and retry. + // MUST be accessed under mu lock. + lastFinishReason string } // ====================== Public API ====================== @@ -136,6 +141,20 @@ func (ts *turnState) IsParentEnded() bool { return ts.parentTurnState.parentEnded.Load() } +// SetLastFinishReason updates the last finish reason (thread-safe). +func (ts *turnState) SetLastFinishReason(reason string) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.lastFinishReason = reason +} + +// GetLastFinishReason retrieves the last finish reason (thread-safe). +func (ts *turnState) GetLastFinishReason() string { + ts.mu.Lock() + defer ts.mu.Unlock() + return ts.lastFinishReason +} + // IsParentEnded is a convenience method to check if parent ended. // It returns the value of the parent's parentEnded atomic flag. diff --git a/pkg/providers/common/common.go b/pkg/providers/common/common.go index 23680a1bf..9dfd7dc1d 100644 --- a/pkg/providers/common/common.go +++ b/pkg/providers/common/common.go @@ -214,11 +214,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) { Reasoning: choice.Message.Reasoning, ReasoningDetails: choice.Message.ReasoningDetails, ToolCalls: toolCalls, - FinishReason: choice.FinishReason, + FinishReason: normalizeFinishReason(choice.FinishReason), Usage: apiResponse.Usage, }, nil } +// normalizeFinishReason normalizes finish_reason values across providers. +// Converts "length" to "truncated" for consistent handling. +func normalizeFinishReason(reason string) string { + if reason == "length" { + return "truncated" + } + return reason +} + // DecodeToolCallArguments decodes a tool call's arguments from raw JSON. func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any { arguments := make(map[string]any) diff --git a/pkg/utils/context.go b/pkg/utils/context.go new file mode 100644 index 000000000..115841dc4 --- /dev/null +++ b/pkg/utils/context.go @@ -0,0 +1,173 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package utils + +import ( + "encoding/json" + "fmt" + "unicode/utf8" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// CalculateDefaultMaxContextRunes computes a default context limit based on the model's context window. +// Strategy: Use 75% of the context window and convert to rune estimate. +// +// Token-to-rune conversion ratios (conservative estimates): +// - English: ~4 chars per token +// - Chinese: ~1.5-2 chars per token +// - Mixed: ~3 chars per token (used here for safety) +func CalculateDefaultMaxContextRunes(contextWindow int) int { + if contextWindow <= 0 { + // Conservative fallback when context window is unknown + return 8000 // ~2000 tokens + } + + // Use 75% of context window to leave headroom + targetTokens := int(float64(contextWindow) * 0.75) + + // Convert tokens to runes using conservative ratio + const avgCharsPerToken = 3 + return targetTokens * avgCharsPerToken +} + +// ResolveMaxContextRunes determines the final MaxContextRunes value to use. +// Priority: explicit config > auto-calculate > conservative default +func ResolveMaxContextRunes(configValue, contextWindow int) int { + switch { + case configValue > 0: + // Explicitly configured, use as-is + return configValue + case configValue == -1: + // Explicitly disabled + return -1 + default: + // 0 or unset: auto-calculate + return CalculateDefaultMaxContextRunes(contextWindow) + } +} + +// MeasureContextRunes calculates the total rune count of a message list. +// Includes content, reasoning content, and estimates for tool calls. +func MeasureContextRunes(messages []providers.Message) int { + totalRunes := 0 + for _, msg := range messages { + totalRunes += utf8.RuneCountInString(msg.Content) + totalRunes += utf8.RuneCountInString(msg.ReasoningContent) + + // Tool calls: serialize to JSON and count + if len(msg.ToolCalls) > 0 { + for _, tc := range msg.ToolCalls { + totalRunes += utf8.RuneCountInString(tc.Name) + // Arguments: serialize and count + if argsJSON, err := json.Marshal(tc.Arguments); err == nil { + totalRunes += utf8.RuneCountInString(string(argsJSON)) + } else { + // Fallback estimate if serialization fails + totalRunes += 100 + } + } + } + + // ToolCallID + totalRunes += utf8.RuneCountInString(msg.ToolCallID) + } + return totalRunes +} + +// TruncateContextSmart intelligently truncates message history to fit within maxRunes. +// +// Strategy: +// 1. Always preserve system messages (they define the agent's behavior) +// 2. Keep the most recent messages (they contain current context) +// 3. Drop older middle messages when necessary +// 4. Insert a truncation notice to inform the LLM +// +// Returns the truncated message list. +func TruncateContextSmart(messages []providers.Message, maxRunes int) []providers.Message { + if len(messages) == 0 { + return messages + } + + // Separate system messages from others + var systemMsgs []providers.Message + var otherMsgs []providers.Message + + for _, msg := range messages { + if msg.Role == "system" { + systemMsgs = append(systemMsgs, msg) + } else { + otherMsgs = append(otherMsgs, msg) + } + } + + // Calculate system message size + systemRunes := 0 + for _, msg := range systemMsgs { + systemRunes += utf8.RuneCountInString(msg.Content) + systemRunes += utf8.RuneCountInString(msg.ReasoningContent) + } + + // Reserve space for truncation notice (estimate ~80 runes) + const truncationNoticeEstimate = 80 + + // Allocate remaining space for other messages + remainingRunes := maxRunes - systemRunes - truncationNoticeEstimate + if remainingRunes <= 0 { + // System messages already exceed limit - return only system messages + return systemMsgs + } + + // Collect recent messages in reverse order until we hit the limit + var keptMsgs []providers.Message + currentRunes := 0 + + for i := len(otherMsgs) - 1; i >= 0; i-- { + msg := otherMsgs[i] + msgRunes := utf8.RuneCountInString(msg.Content) + + utf8.RuneCountInString(msg.ReasoningContent) + + // Estimate tool call size + if len(msg.ToolCalls) > 0 { + for _, tc := range msg.ToolCalls { + msgRunes += utf8.RuneCountInString(tc.Name) + if argsJSON, err := json.Marshal(tc.Arguments); err == nil { + msgRunes += utf8.RuneCountInString(string(argsJSON)) + } else { + msgRunes += 100 + } + } + } + msgRunes += utf8.RuneCountInString(msg.ToolCallID) + + if currentRunes+msgRunes > remainingRunes { + // Would exceed limit, stop collecting + break + } + + // Prepend to maintain chronological order + keptMsgs = append([]providers.Message{msg}, keptMsgs...) + currentRunes += msgRunes + } + + // If we dropped messages, add a truncation notice + result := systemMsgs + if len(keptMsgs) < len(otherMsgs) { + droppedCount := len(otherMsgs) - len(keptMsgs) + truncationNotice := providers.Message{ + Role: "system", + Content: fmt.Sprintf( + "[Context truncated: %d earlier messages omitted to stay within context limits]", + droppedCount, + ), + } + result = append(result, truncationNotice) + } + + result = append(result, keptMsgs...) + return result +} diff --git a/pkg/utils/context_test.go b/pkg/utils/context_test.go new file mode 100644 index 000000000..1b8e26e2f --- /dev/null +++ b/pkg/utils/context_test.go @@ -0,0 +1,450 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package utils + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestCalculateDefaultMaxContextRunes(t *testing.T) { + tests := []struct { + name string + contextWindow int + want int + }{ + { + name: "zero context window uses fallback", + contextWindow: 0, + want: 8000, + }, + { + name: "negative context window uses fallback", + contextWindow: -1, + want: 8000, + }, + { + name: "small context window (4k tokens)", + contextWindow: 4000, + want: 9000, // 4000 * 0.75 * 3 = 9000 + }, + { + name: "medium context window (128k tokens)", + contextWindow: 128000, + want: 288000, // 128000 * 0.75 * 3 = 288000 + }, + { + name: "large context window (1M tokens)", + contextWindow: 1000000, + want: 2250000, // 1000000 * 0.75 * 3 = 2250000 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CalculateDefaultMaxContextRunes(tt.contextWindow) + if got != tt.want { + t.Errorf("CalculateDefaultMaxContextRunes(%d) = %d, want %d", + tt.contextWindow, got, tt.want) + } + }) + } +} + +func TestResolveMaxContextRunes(t *testing.T) { + tests := []struct { + name string + configValue int + contextWindow int + want int + }{ + { + name: "explicit positive value", + configValue: 12000, + contextWindow: 4000, + want: 12000, + }, + { + name: "explicit disable (-1)", + configValue: -1, + contextWindow: 4000, + want: -1, + }, + { + name: "zero uses auto-calculate", + configValue: 0, + contextWindow: 4000, + want: 9000, // 4000 * 0.75 * 3 + }, + { + name: "unset (0) with unknown context window", + configValue: 0, + contextWindow: 0, + want: 8000, // fallback + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResolveMaxContextRunes(tt.configValue, tt.contextWindow) + if got != tt.want { + t.Errorf("ResolveMaxContextRunes(%d, %d) = %d, want %d", + tt.configValue, tt.contextWindow, got, tt.want) + } + }) + } +} + +func TestMeasureContextRunes(t *testing.T) { + tests := []struct { + name string + messages []providers.Message + want int + }{ + { + name: "empty messages", + messages: []providers.Message{}, + want: 0, + }, + { + name: "single simple message", + messages: []providers.Message{ + {Role: "user", Content: "Hello"}, + }, + want: 5, // "Hello" = 5 runes + }, + { + name: "message with reasoning", + messages: []providers.Message{ + { + Role: "assistant", + Content: "Answer", + ReasoningContent: "Thinking", + }, + }, + want: 14, // "Answer" (6) + "Thinking" (8) = 14 + }, + { + name: "message with tool call", + messages: []providers.Message{ + { + Role: "assistant", + Content: "Using tool", + ToolCalls: []providers.ToolCall{ + { + Name: "test_tool", + Arguments: map[string]any{"key": "value"}, + }, + }, + }, + }, + want: 10 + 9 + 15, // "Using tool" + "test_tool" + {"key":"value"} + }, + { + name: "multiple messages", + messages: []providers.Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello!"}, + }, + want: 15 + 2 + 6, // 15 + 2 + 6 = 23 + }, + { + name: "unicode characters", + messages: []providers.Message{ + {Role: "user", Content: "你好世界"}, // 4 Chinese characters + }, + want: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MeasureContextRunes(tt.messages) + if got != tt.want { + t.Errorf("MeasureContextRunes() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestTruncateContextSmart(t *testing.T) { + tests := []struct { + name string + messages []providers.Message + maxRunes int + wantLen int + wantHas []string // Content strings that should be present + wantNot []string // Content strings that should be absent + }{ + { + name: "empty messages", + messages: []providers.Message{}, + maxRunes: 100, + wantLen: 0, + }, + { + name: "no truncation needed", + messages: []providers.Message{ + {Role: "system", Content: "System"}, + {Role: "user", Content: "Hello"}, + }, + maxRunes: 100, + wantLen: 2, + wantHas: []string{"System", "Hello"}, + }, + { + name: "truncate when limit is tight", + messages: []providers.Message{ + {Role: "system", Content: "System"}, + {Role: "user", Content: "Message 1 with some content here"}, + {Role: "assistant", Content: "Response 1 with some content here"}, + {Role: "user", Content: "Message 2 with some content here"}, + {Role: "assistant", Content: "Response 2 with some content here"}, + {Role: "user", Content: "Latest"}, + }, + maxRunes: 120, // Tight limit to force truncation + wantLen: -1, // Don't check exact length, just verify truncation occurred + wantHas: []string{"System", "Latest"}, + wantNot: []string{"Message 1", "Response 1"}, + }, + { + name: "system messages exceed limit", + messages: []providers.Message{ + {Role: "system", Content: "Very long system message"}, + {Role: "user", Content: "User message"}, + }, + maxRunes: 10, // Less than system message + wantLen: 1, // Only system message + wantHas: []string{"Very long system message"}, + wantNot: []string{"User message"}, + }, + { + name: "preserve multiple system messages", + messages: []providers.Message{ + {Role: "system", Content: "Sys1"}, + {Role: "system", Content: "Sys2"}, + {Role: "user", Content: "Old"}, + {Role: "user", Content: "New"}, + }, + maxRunes: 200, // Generous limit + wantLen: 4, // Both system + truncation notice + new + wantHas: []string{"Sys1", "Sys2", "New"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := TruncateContextSmart(tt.messages, tt.maxRunes) + + if tt.wantLen >= 0 && len(got) != tt.wantLen { + t.Errorf("TruncateContextSmart() returned %d messages, want %d", + len(got), tt.wantLen) + } + + // Check for expected content + allContent := "" + for _, msg := range got { + allContent += msg.Content + " " + } + + for _, want := range tt.wantHas { + found := false + for _, msg := range got { + if msg.Content == want || containsSubstring(msg.Content, want) { + found = true + break + } + } + if !found { + t.Errorf("Expected content %q not found in truncated messages", want) + } + } + + for _, notWant := range tt.wantNot { + for _, msg := range got { + if containsSubstring(msg.Content, notWant) { + t.Errorf("Unexpected content %q found in truncated messages", notWant) + } + } + } + }) + } +} + +func containsSubstring(s, substr string) bool { + return len(s) >= len(substr) && findSubstring(s, substr) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// TestSubTurnConfigMaxContextRunes verifies that MaxContextRunes configuration +// is properly integrated into the SubTurn execution flow. +func TestSubTurnConfigMaxContextRunes(t *testing.T) { + tests := []struct { + name string + maxContextRunes int + contextWindow int + wantResolved int + }{ + { + name: "default (0) auto-calculates from context window", + maxContextRunes: 0, + contextWindow: 4000, + wantResolved: 9000, // 4000 * 0.75 * 3 + }, + { + name: "explicit value is used", + maxContextRunes: 12000, + contextWindow: 4000, + wantResolved: 12000, + }, + { + name: "disabled (-1) returns -1", + maxContextRunes: -1, + contextWindow: 4000, + wantResolved: -1, + }, + { + name: "fallback when context window unknown", + maxContextRunes: 0, + contextWindow: 0, + wantResolved: 8000, // conservative fallback + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResolveMaxContextRunes(tt.maxContextRunes, tt.contextWindow) + if got != tt.wantResolved { + t.Errorf("utils.ResolveMaxContextRunes(%d, %d) = %d, want %d", + tt.maxContextRunes, tt.contextWindow, got, tt.wantResolved) + } + }) + } +} + +// TestContextTruncationFlow verifies the complete context truncation flow: +// 1. Messages accumulate beyond soft limit +// 2. Truncation is triggered +// 3. System messages are preserved +// 4. Recent messages are kept +func TestContextTruncationFlow(t *testing.T) { + // Build a message history that exceeds the limit + messages := []providers.Message{ + {Role: "system", Content: "You are a helpful assistant"}, // ~27 runes + {Role: "user", Content: "First question"}, // ~14 runes + {Role: "assistant", Content: "First answer"}, // ~12 runes + {Role: "user", Content: "Second question"}, // ~15 runes + {Role: "assistant", Content: "Second answer"}, // ~13 runes + {Role: "user", Content: "Third question"}, // ~14 runes + {Role: "assistant", Content: "Third answer"}, // ~12 runes + {Role: "user", Content: "Latest question"}, // ~15 runes + } + + // Total: ~122 runes + totalRunes := MeasureContextRunes(messages) + if totalRunes < 100 { + t.Errorf("Expected total runes > 100, got %d", totalRunes) + } + + // Set limit to 150 runes - should force truncation of old messages + // but preserve system + truncation notice + recent messages + maxRunes := 150 + truncated := TruncateContextSmart(messages, maxRunes) + + // Verify truncation occurred + if len(truncated) >= len(messages) { + t.Errorf("Expected truncation, but got %d messages (original: %d)", + len(truncated), len(messages)) + } + + // Verify system message is preserved + foundSystem := false + for _, msg := range truncated { + if msg.Role == "system" && msg.Content == "You are a helpful assistant" { + foundSystem = true + break + } + } + if !foundSystem { + t.Error("System message was not preserved after truncation") + } + + // Verify latest message is preserved + foundLatest := false + for _, msg := range truncated { + if msg.Content == "Latest question" { + foundLatest = true + break + } + } + if !foundLatest { + t.Error("Latest message was not preserved after truncation") + } + + // Verify truncation notice is present + foundNotice := false + for _, msg := range truncated { + if msg.Role == "system" && containsSubstring(msg.Content, "truncated") { + foundNotice = true + break + } + } + if !foundNotice { + t.Error("Truncation notice was not added") + } + + // Verify result is within limit (with some tolerance for estimation) + resultRunes := MeasureContextRunes(truncated) + if resultRunes > maxRunes+20 { // Allow 20 rune tolerance + t.Errorf("Truncated context (%d runes) significantly exceeds limit (%d runes)", + resultRunes, maxRunes) + } +} + +// TestContextTruncationPreservesToolCalls verifies that tool calls are +// properly handled during context truncation. +func TestContextTruncationPreservesToolCalls(t *testing.T) { + messages := []providers.Message{ + {Role: "system", Content: "System"}, + {Role: "user", Content: "Old message that should be dropped"}, + { + Role: "assistant", + Content: "Recent tool use", + ToolCalls: []providers.ToolCall{ + { + Name: "important_tool", + Arguments: map[string]any{"key": "value"}, + }, + }, + }, + } + + // Set a generous limit that should keep the tool call message + maxRunes := 200 + truncated := TruncateContextSmart(messages, maxRunes) + + // Verify tool call message is preserved + foundToolCall := false + for _, msg := range truncated { + if len(msg.ToolCalls) > 0 && msg.ToolCalls[0].Name == "important_tool" { + foundToolCall = true + break + } + } + if !foundToolCall { + t.Error("Tool call message was not preserved during truncation") + } +}