fix(agent): prevent duplicate history during subturn context recoveries

Problem:
During subturn context limit or truncation recoveries, the recovery loops repeatedly
called `runAgentLoop` with the same or modified `UserMessage`. Because `runAgentLoop`
unconditionally adds the `UserMessage` to the session history, this resulted in:
1. Duplicate User Messages polluting the history upon `context_length_exceeded` retries.
2. The possibility of injecting empty User Messages if `opts.UserMessage` was artificially blanked out to work around the duplication.
3. Messy or duplicate entries during `finish_reason="truncated"` recovery injections.

Solution:
- Introduce `SkipAddUserMessage` boolean to `processOptions` to explicitly control whether the agent loop should write the user prompt to history.
- Add an explicit `opts.UserMessage != ""` check in `runAgentLoop` to prevent polluting history with empty message content.
- In `subturn.go`'s recovery loop, set `SkipAddUserMessage: contextRetryCount > 0` to skip writing the user message on context
This commit is contained in:
Administrator
2026-03-18 12:18:32 +08:00
parent f8defe3ae1
commit c7ea018a73
6 changed files with 834 additions and 14 deletions
+11 -3
View File
@@ -49,8 +49,8 @@ type AgentLoop struct {
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs
mu sync.RWMutex
// Track active requests for safe provider cleanup
@@ -69,6 +69,7 @@ type processOptions struct {
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
SkipAddUserMessage bool // If true, skip adding UserMessage to session history
}
const (
@@ -1051,7 +1052,9 @@ func (al *AgentLoop) runAgentLoop(
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
// 2. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
if !opts.SkipAddUserMessage && opts.UserMessage != "" {
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
}
// 3. Run LLM iteration loop
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
@@ -1403,6 +1406,11 @@ func (al *AgentLoop) runLLMIteration(
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
}
// Save finishReason to turnState for SubTurn truncation detection
if ts := turnStateFromContext(ctx); ts != nil {
ts.SetLastFinishReason(response.FinishReason)
}
go al.handleReasoning(
ctx,
response.Reasoning,
+171 -10
View File
@@ -4,11 +4,13 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ====================== Config & Constants ======================
@@ -104,6 +106,19 @@ type SubTurnConfig struct {
// Default is 5 minutes (defaultSubTurnTimeout) if not specified.
Timeout time.Duration
// MaxContextRunes limits the context size (in runes) passed to the SubTurn.
// This prevents context window overflow by truncating message history before LLM calls.
//
// Values:
// 0 = Auto-calculate based on model's ContextWindow * 0.75 (default, recommended)
// -1 = No limit (disable soft truncation, rely only on hard context errors)
// >0 = Use specified rune limit
//
// The soft limit acts as a first line of defense before hitting the provider's
// hard context window limit. When exceeded, older messages are intelligently
// truncated while preserving system messages and recent context.
MaxContextRunes int
// Can be extended with temperature, topP, etc.
}
@@ -377,6 +392,25 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
// runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to
// the real agent loop. The child's ephemeral session is used for history so it
// never pollutes the parent session.
//
// This function implements multiple layers of context protection and error recovery:
//
// 1. Soft Context Limit (MaxContextRunes):
// - Proactively truncates message history before LLM calls
// - Default: 75% of model's context window
// - Preserves system messages and recent context
// - First line of defense against context overflow
//
// 2. Hard Context Error Recovery:
// - Detects context_length_exceeded errors from provider
// - Triggers force compression and retries (up to 2 times)
// - Second line of defense when soft limit is insufficient
//
// 3. Truncation Recovery:
// - Detects when LLM response is truncated (finish_reason="truncated")
// - Injects recovery prompt asking for shorter response
// - Retries up to 2 times
// - Handles cases where max_tokens is hit
func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfig) (*tools.ToolResult, error) {
// Derive candidates from the requested model using the parent loop's provider.
defaultProvider := al.GetConfig().Agents.Defaults.Provider
@@ -420,17 +454,144 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi
childAgent.MaxTokens = parentAgent.MaxTokens
}
finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
SessionKey: ts.turnID,
UserMessage: cfg.SystemPrompt,
DefaultResponse: "",
EnableSummary: false,
SendResponse: false,
})
if err != nil {
return nil, err
// Resolve MaxContextRunes configuration
maxContextRunes := utils.ResolveMaxContextRunes(cfg.MaxContextRunes, childAgent.ContextWindow)
logger.DebugCF("subturn", "Context limit resolved",
map[string]any{
"turn_id": ts.turnID,
"context_window": childAgent.ContextWindow,
"max_context_runes": maxContextRunes,
"configured_value": cfg.MaxContextRunes,
})
// Retry loop for truncation and context errors
const (
maxTruncationRetries = 2
maxContextRetries = 2
)
truncationRetryCount := 0
contextRetryCount := 0
currentPrompt := cfg.SystemPrompt
for {
// Soft context limit: check and truncate before LLM call
if maxContextRunes > 0 {
messages := childAgent.Sessions.GetHistory(ts.turnID)
currentRunes := utils.MeasureContextRunes(messages)
if currentRunes > maxContextRunes {
logger.WarnCF("subturn", "Context exceeds soft limit, truncating",
map[string]any{
"turn_id": ts.turnID,
"current_runes": currentRunes,
"max_runes": maxContextRunes,
"overflow": currentRunes - maxContextRunes,
})
truncatedMessages := utils.TruncateContextSmart(messages, maxContextRunes)
childAgent.Sessions.SetHistory(ts.turnID, truncatedMessages)
// Log truncation result
newRunes := utils.MeasureContextRunes(truncatedMessages)
logger.InfoCF("subturn", "Context truncated successfully",
map[string]any{
"turn_id": ts.turnID,
"before_runes": currentRunes,
"after_runes": newRunes,
"saved_runes": currentRunes - newRunes,
})
}
}
// Call the agent loop
finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
SessionKey: ts.turnID,
UserMessage: currentPrompt,
DefaultResponse: "",
EnableSummary: false,
SendResponse: false,
SkipAddUserMessage: contextRetryCount > 0,
})
// 1. Handle context length errors
if err != nil && isContextLengthError(err) {
if contextRetryCount >= maxContextRetries {
logger.ErrorCF("subturn", "Context limit exceeded after max retries",
map[string]any{
"turn_id": ts.turnID,
"retries": contextRetryCount,
"max_retries": maxContextRetries,
})
return nil, fmt.Errorf("context limit exceeded after %d retries: %w", maxContextRetries, err)
}
logger.WarnCF("subturn", "Context length exceeded, compressing and retrying",
map[string]any{
"turn_id": ts.turnID,
"retry": contextRetryCount + 1,
})
// Trigger force compression
al.forceCompression(childAgent, ts.turnID)
contextRetryCount++
continue // Retry with compressed history
}
if err != nil {
return nil, err // Other errors, return immediately
}
// 2. Check for truncation (retrieve finishReason from turnState)
finishReason := ts.GetLastFinishReason()
if finishReason == "truncated" && truncationRetryCount < maxTruncationRetries {
logger.WarnCF("subturn", "Response truncated, injecting recovery message",
map[string]any{
"turn_id": ts.turnID,
"retry": truncationRetryCount + 1,
})
// IMPORTANT: Do NOT manually add messages to history here.
// runAgentLoop has already saved both the assistant message (finalContent)
// and will save the next user message (currentPrompt) on the next iteration.
// Manually adding them would cause duplicates.
// Inject recovery prompt - it will be added by runAgentLoop on next iteration
recoveryPrompt := "Your previous response was truncated due to length. Please provide a shorter, complete response that finishes your thought."
currentPrompt = recoveryPrompt
truncationRetryCount++
continue // Retry with recovery prompt
}
// 3. Success - return result
return &tools.ToolResult{ForLLM: finalContent}, nil
}
return &tools.ToolResult{ForLLM: finalContent}, nil
}
// isContextLengthError checks if the error is due to context length exceeded.
// It excludes timeout errors to avoid false positives.
func isContextLengthError(err error) bool {
if err == nil {
return false
}
errMsg := strings.ToLower(err.Error())
// Exclude timeout errors
if strings.Contains(errMsg, "timeout") || strings.Contains(errMsg, "deadline exceeded") {
return false
}
// Detect context error patterns
return strings.Contains(errMsg, "context_length_exceeded") ||
strings.Contains(errMsg, "maximum context length") ||
strings.Contains(errMsg, "context window") ||
strings.Contains(errMsg, "too many tokens") ||
strings.Contains(errMsg, "token limit") ||
strings.Contains(errMsg, "prompt is too long")
}
// ====================== Other Types ======================
+19
View File
@@ -55,6 +55,11 @@ type turnState struct {
// This allows child SubTurns to check if the parent has ended.
// Nil for root turns.
parentTurnState *turnState
// lastFinishReason stores the finish_reason from the last LLM call.
// Used by SubTurn to detect truncation and retry.
// MUST be accessed under mu lock.
lastFinishReason string
}
// ====================== Public API ======================
@@ -136,6 +141,20 @@ func (ts *turnState) IsParentEnded() bool {
return ts.parentTurnState.parentEnded.Load()
}
// SetLastFinishReason updates the last finish reason (thread-safe).
func (ts *turnState) SetLastFinishReason(reason string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.lastFinishReason = reason
}
// GetLastFinishReason retrieves the last finish reason (thread-safe).
func (ts *turnState) GetLastFinishReason() string {
ts.mu.Lock()
defer ts.mu.Unlock()
return ts.lastFinishReason
}
// IsParentEnded is a convenience method to check if parent ended.
// It returns the value of the parent's parentEnded atomic flag.
+10 -1
View File
@@ -214,11 +214,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
FinishReason: normalizeFinishReason(choice.FinishReason),
Usage: apiResponse.Usage,
}, nil
}
// normalizeFinishReason normalizes finish_reason values across providers.
// Converts "length" to "truncated" for consistent handling.
func normalizeFinishReason(reason string) string {
if reason == "length" {
return "truncated"
}
return reason
}
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
+173
View File
@@ -0,0 +1,173 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package utils
import (
"encoding/json"
"fmt"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// CalculateDefaultMaxContextRunes computes a default context limit based on the model's context window.
// Strategy: Use 75% of the context window and convert to rune estimate.
//
// Token-to-rune conversion ratios (conservative estimates):
// - English: ~4 chars per token
// - Chinese: ~1.5-2 chars per token
// - Mixed: ~3 chars per token (used here for safety)
func CalculateDefaultMaxContextRunes(contextWindow int) int {
if contextWindow <= 0 {
// Conservative fallback when context window is unknown
return 8000 // ~2000 tokens
}
// Use 75% of context window to leave headroom
targetTokens := int(float64(contextWindow) * 0.75)
// Convert tokens to runes using conservative ratio
const avgCharsPerToken = 3
return targetTokens * avgCharsPerToken
}
// ResolveMaxContextRunes determines the final MaxContextRunes value to use.
// Priority: explicit config > auto-calculate > conservative default
func ResolveMaxContextRunes(configValue, contextWindow int) int {
switch {
case configValue > 0:
// Explicitly configured, use as-is
return configValue
case configValue == -1:
// Explicitly disabled
return -1
default:
// 0 or unset: auto-calculate
return CalculateDefaultMaxContextRunes(contextWindow)
}
}
// MeasureContextRunes calculates the total rune count of a message list.
// Includes content, reasoning content, and estimates for tool calls.
func MeasureContextRunes(messages []providers.Message) int {
totalRunes := 0
for _, msg := range messages {
totalRunes += utf8.RuneCountInString(msg.Content)
totalRunes += utf8.RuneCountInString(msg.ReasoningContent)
// Tool calls: serialize to JSON and count
if len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
totalRunes += utf8.RuneCountInString(tc.Name)
// Arguments: serialize and count
if argsJSON, err := json.Marshal(tc.Arguments); err == nil {
totalRunes += utf8.RuneCountInString(string(argsJSON))
} else {
// Fallback estimate if serialization fails
totalRunes += 100
}
}
}
// ToolCallID
totalRunes += utf8.RuneCountInString(msg.ToolCallID)
}
return totalRunes
}
// TruncateContextSmart intelligently truncates message history to fit within maxRunes.
//
// Strategy:
// 1. Always preserve system messages (they define the agent's behavior)
// 2. Keep the most recent messages (they contain current context)
// 3. Drop older middle messages when necessary
// 4. Insert a truncation notice to inform the LLM
//
// Returns the truncated message list.
func TruncateContextSmart(messages []providers.Message, maxRunes int) []providers.Message {
if len(messages) == 0 {
return messages
}
// Separate system messages from others
var systemMsgs []providers.Message
var otherMsgs []providers.Message
for _, msg := range messages {
if msg.Role == "system" {
systemMsgs = append(systemMsgs, msg)
} else {
otherMsgs = append(otherMsgs, msg)
}
}
// Calculate system message size
systemRunes := 0
for _, msg := range systemMsgs {
systemRunes += utf8.RuneCountInString(msg.Content)
systemRunes += utf8.RuneCountInString(msg.ReasoningContent)
}
// Reserve space for truncation notice (estimate ~80 runes)
const truncationNoticeEstimate = 80
// Allocate remaining space for other messages
remainingRunes := maxRunes - systemRunes - truncationNoticeEstimate
if remainingRunes <= 0 {
// System messages already exceed limit - return only system messages
return systemMsgs
}
// Collect recent messages in reverse order until we hit the limit
var keptMsgs []providers.Message
currentRunes := 0
for i := len(otherMsgs) - 1; i >= 0; i-- {
msg := otherMsgs[i]
msgRunes := utf8.RuneCountInString(msg.Content) +
utf8.RuneCountInString(msg.ReasoningContent)
// Estimate tool call size
if len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
msgRunes += utf8.RuneCountInString(tc.Name)
if argsJSON, err := json.Marshal(tc.Arguments); err == nil {
msgRunes += utf8.RuneCountInString(string(argsJSON))
} else {
msgRunes += 100
}
}
}
msgRunes += utf8.RuneCountInString(msg.ToolCallID)
if currentRunes+msgRunes > remainingRunes {
// Would exceed limit, stop collecting
break
}
// Prepend to maintain chronological order
keptMsgs = append([]providers.Message{msg}, keptMsgs...)
currentRunes += msgRunes
}
// If we dropped messages, add a truncation notice
result := systemMsgs
if len(keptMsgs) < len(otherMsgs) {
droppedCount := len(otherMsgs) - len(keptMsgs)
truncationNotice := providers.Message{
Role: "system",
Content: fmt.Sprintf(
"[Context truncated: %d earlier messages omitted to stay within context limits]",
droppedCount,
),
}
result = append(result, truncationNotice)
}
result = append(result, keptMsgs...)
return result
}
+450
View File
@@ -0,0 +1,450 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package utils
import (
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestCalculateDefaultMaxContextRunes(t *testing.T) {
tests := []struct {
name string
contextWindow int
want int
}{
{
name: "zero context window uses fallback",
contextWindow: 0,
want: 8000,
},
{
name: "negative context window uses fallback",
contextWindow: -1,
want: 8000,
},
{
name: "small context window (4k tokens)",
contextWindow: 4000,
want: 9000, // 4000 * 0.75 * 3 = 9000
},
{
name: "medium context window (128k tokens)",
contextWindow: 128000,
want: 288000, // 128000 * 0.75 * 3 = 288000
},
{
name: "large context window (1M tokens)",
contextWindow: 1000000,
want: 2250000, // 1000000 * 0.75 * 3 = 2250000
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := CalculateDefaultMaxContextRunes(tt.contextWindow)
if got != tt.want {
t.Errorf("CalculateDefaultMaxContextRunes(%d) = %d, want %d",
tt.contextWindow, got, tt.want)
}
})
}
}
func TestResolveMaxContextRunes(t *testing.T) {
tests := []struct {
name string
configValue int
contextWindow int
want int
}{
{
name: "explicit positive value",
configValue: 12000,
contextWindow: 4000,
want: 12000,
},
{
name: "explicit disable (-1)",
configValue: -1,
contextWindow: 4000,
want: -1,
},
{
name: "zero uses auto-calculate",
configValue: 0,
contextWindow: 4000,
want: 9000, // 4000 * 0.75 * 3
},
{
name: "unset (0) with unknown context window",
configValue: 0,
contextWindow: 0,
want: 8000, // fallback
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ResolveMaxContextRunes(tt.configValue, tt.contextWindow)
if got != tt.want {
t.Errorf("ResolveMaxContextRunes(%d, %d) = %d, want %d",
tt.configValue, tt.contextWindow, got, tt.want)
}
})
}
}
func TestMeasureContextRunes(t *testing.T) {
tests := []struct {
name string
messages []providers.Message
want int
}{
{
name: "empty messages",
messages: []providers.Message{},
want: 0,
},
{
name: "single simple message",
messages: []providers.Message{
{Role: "user", Content: "Hello"},
},
want: 5, // "Hello" = 5 runes
},
{
name: "message with reasoning",
messages: []providers.Message{
{
Role: "assistant",
Content: "Answer",
ReasoningContent: "Thinking",
},
},
want: 14, // "Answer" (6) + "Thinking" (8) = 14
},
{
name: "message with tool call",
messages: []providers.Message{
{
Role: "assistant",
Content: "Using tool",
ToolCalls: []providers.ToolCall{
{
Name: "test_tool",
Arguments: map[string]any{"key": "value"},
},
},
},
},
want: 10 + 9 + 15, // "Using tool" + "test_tool" + {"key":"value"}
},
{
name: "multiple messages",
messages: []providers.Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!"},
},
want: 15 + 2 + 6, // 15 + 2 + 6 = 23
},
{
name: "unicode characters",
messages: []providers.Message{
{Role: "user", Content: "你好世界"}, // 4 Chinese characters
},
want: 4,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := MeasureContextRunes(tt.messages)
if got != tt.want {
t.Errorf("MeasureContextRunes() = %d, want %d", got, tt.want)
}
})
}
}
func TestTruncateContextSmart(t *testing.T) {
tests := []struct {
name string
messages []providers.Message
maxRunes int
wantLen int
wantHas []string // Content strings that should be present
wantNot []string // Content strings that should be absent
}{
{
name: "empty messages",
messages: []providers.Message{},
maxRunes: 100,
wantLen: 0,
},
{
name: "no truncation needed",
messages: []providers.Message{
{Role: "system", Content: "System"},
{Role: "user", Content: "Hello"},
},
maxRunes: 100,
wantLen: 2,
wantHas: []string{"System", "Hello"},
},
{
name: "truncate when limit is tight",
messages: []providers.Message{
{Role: "system", Content: "System"},
{Role: "user", Content: "Message 1 with some content here"},
{Role: "assistant", Content: "Response 1 with some content here"},
{Role: "user", Content: "Message 2 with some content here"},
{Role: "assistant", Content: "Response 2 with some content here"},
{Role: "user", Content: "Latest"},
},
maxRunes: 120, // Tight limit to force truncation
wantLen: -1, // Don't check exact length, just verify truncation occurred
wantHas: []string{"System", "Latest"},
wantNot: []string{"Message 1", "Response 1"},
},
{
name: "system messages exceed limit",
messages: []providers.Message{
{Role: "system", Content: "Very long system message"},
{Role: "user", Content: "User message"},
},
maxRunes: 10, // Less than system message
wantLen: 1, // Only system message
wantHas: []string{"Very long system message"},
wantNot: []string{"User message"},
},
{
name: "preserve multiple system messages",
messages: []providers.Message{
{Role: "system", Content: "Sys1"},
{Role: "system", Content: "Sys2"},
{Role: "user", Content: "Old"},
{Role: "user", Content: "New"},
},
maxRunes: 200, // Generous limit
wantLen: 4, // Both system + truncation notice + new
wantHas: []string{"Sys1", "Sys2", "New"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := TruncateContextSmart(tt.messages, tt.maxRunes)
if tt.wantLen >= 0 && len(got) != tt.wantLen {
t.Errorf("TruncateContextSmart() returned %d messages, want %d",
len(got), tt.wantLen)
}
// Check for expected content
allContent := ""
for _, msg := range got {
allContent += msg.Content + " "
}
for _, want := range tt.wantHas {
found := false
for _, msg := range got {
if msg.Content == want || containsSubstring(msg.Content, want) {
found = true
break
}
}
if !found {
t.Errorf("Expected content %q not found in truncated messages", want)
}
}
for _, notWant := range tt.wantNot {
for _, msg := range got {
if containsSubstring(msg.Content, notWant) {
t.Errorf("Unexpected content %q found in truncated messages", notWant)
}
}
}
})
}
}
func containsSubstring(s, substr string) bool {
return len(s) >= len(substr) && findSubstring(s, substr)
}
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// TestSubTurnConfigMaxContextRunes verifies that MaxContextRunes configuration
// is properly integrated into the SubTurn execution flow.
func TestSubTurnConfigMaxContextRunes(t *testing.T) {
tests := []struct {
name string
maxContextRunes int
contextWindow int
wantResolved int
}{
{
name: "default (0) auto-calculates from context window",
maxContextRunes: 0,
contextWindow: 4000,
wantResolved: 9000, // 4000 * 0.75 * 3
},
{
name: "explicit value is used",
maxContextRunes: 12000,
contextWindow: 4000,
wantResolved: 12000,
},
{
name: "disabled (-1) returns -1",
maxContextRunes: -1,
contextWindow: 4000,
wantResolved: -1,
},
{
name: "fallback when context window unknown",
maxContextRunes: 0,
contextWindow: 0,
wantResolved: 8000, // conservative fallback
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ResolveMaxContextRunes(tt.maxContextRunes, tt.contextWindow)
if got != tt.wantResolved {
t.Errorf("utils.ResolveMaxContextRunes(%d, %d) = %d, want %d",
tt.maxContextRunes, tt.contextWindow, got, tt.wantResolved)
}
})
}
}
// TestContextTruncationFlow verifies the complete context truncation flow:
// 1. Messages accumulate beyond soft limit
// 2. Truncation is triggered
// 3. System messages are preserved
// 4. Recent messages are kept
func TestContextTruncationFlow(t *testing.T) {
// Build a message history that exceeds the limit
messages := []providers.Message{
{Role: "system", Content: "You are a helpful assistant"}, // ~27 runes
{Role: "user", Content: "First question"}, // ~14 runes
{Role: "assistant", Content: "First answer"}, // ~12 runes
{Role: "user", Content: "Second question"}, // ~15 runes
{Role: "assistant", Content: "Second answer"}, // ~13 runes
{Role: "user", Content: "Third question"}, // ~14 runes
{Role: "assistant", Content: "Third answer"}, // ~12 runes
{Role: "user", Content: "Latest question"}, // ~15 runes
}
// Total: ~122 runes
totalRunes := MeasureContextRunes(messages)
if totalRunes < 100 {
t.Errorf("Expected total runes > 100, got %d", totalRunes)
}
// Set limit to 150 runes - should force truncation of old messages
// but preserve system + truncation notice + recent messages
maxRunes := 150
truncated := TruncateContextSmart(messages, maxRunes)
// Verify truncation occurred
if len(truncated) >= len(messages) {
t.Errorf("Expected truncation, but got %d messages (original: %d)",
len(truncated), len(messages))
}
// Verify system message is preserved
foundSystem := false
for _, msg := range truncated {
if msg.Role == "system" && msg.Content == "You are a helpful assistant" {
foundSystem = true
break
}
}
if !foundSystem {
t.Error("System message was not preserved after truncation")
}
// Verify latest message is preserved
foundLatest := false
for _, msg := range truncated {
if msg.Content == "Latest question" {
foundLatest = true
break
}
}
if !foundLatest {
t.Error("Latest message was not preserved after truncation")
}
// Verify truncation notice is present
foundNotice := false
for _, msg := range truncated {
if msg.Role == "system" && containsSubstring(msg.Content, "truncated") {
foundNotice = true
break
}
}
if !foundNotice {
t.Error("Truncation notice was not added")
}
// Verify result is within limit (with some tolerance for estimation)
resultRunes := MeasureContextRunes(truncated)
if resultRunes > maxRunes+20 { // Allow 20 rune tolerance
t.Errorf("Truncated context (%d runes) significantly exceeds limit (%d runes)",
resultRunes, maxRunes)
}
}
// TestContextTruncationPreservesToolCalls verifies that tool calls are
// properly handled during context truncation.
func TestContextTruncationPreservesToolCalls(t *testing.T) {
messages := []providers.Message{
{Role: "system", Content: "System"},
{Role: "user", Content: "Old message that should be dropped"},
{
Role: "assistant",
Content: "Recent tool use",
ToolCalls: []providers.ToolCall{
{
Name: "important_tool",
Arguments: map[string]any{"key": "value"},
},
},
},
}
// Set a generous limit that should keep the tool call message
maxRunes := 200
truncated := TruncateContextSmart(messages, maxRunes)
// Verify tool call message is preserved
foundToolCall := false
for _, msg := range truncated {
if len(msg.ToolCalls) > 0 && msg.ToolCalls[0].Name == "important_tool" {
foundToolCall = true
break
}
}
if !foundToolCall {
t.Error("Tool call message was not preserved during truncation")
}
}