mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): prevent duplicate history during subturn context recoveries
Problem: During subturn context limit or truncation recoveries, the recovery loops repeatedly called `runAgentLoop` with the same or modified `UserMessage`. Because `runAgentLoop` unconditionally adds the `UserMessage` to the session history, this resulted in: 1. Duplicate User Messages polluting the history upon `context_length_exceeded` retries. 2. The possibility of injecting empty User Messages if `opts.UserMessage` was artificially blanked out to work around the duplication. 3. Messy or duplicate entries during `finish_reason="truncated"` recovery injections. Solution: - Introduce `SkipAddUserMessage` boolean to `processOptions` to explicitly control whether the agent loop should write the user prompt to history. - Add an explicit `opts.UserMessage != ""` check in `runAgentLoop` to prevent polluting history with empty message content. - In `subturn.go`'s recovery loop, set `SkipAddUserMessage: contextRetryCount > 0` to skip writing the user message on context-length retries.
This commit is contained in:
+11
-3
@@ -49,8 +49,8 @@ type AgentLoop struct {
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
steering *steeringQueue
|
||||
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
|
||||
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
|
||||
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
|
||||
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
|
||||
subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs
|
||||
mu sync.RWMutex
|
||||
// Track active requests for safe provider cleanup
|
||||
@@ -69,6 +69,7 @@ type processOptions struct {
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
|
||||
SkipAddUserMessage bool // If true, skip adding UserMessage to session history
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -1051,7 +1052,9 @@ func (al *AgentLoop) runAgentLoop(
|
||||
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
|
||||
|
||||
// 2. Save user message to session
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
if !opts.SkipAddUserMessage && opts.UserMessage != "" {
|
||||
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||
}
|
||||
|
||||
// 3. Run LLM iteration loop
|
||||
finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
|
||||
@@ -1403,6 +1406,11 @@ func (al *AgentLoop) runLLMIteration(
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
// Save finishReason to turnState for SubTurn truncation detection
|
||||
if ts := turnStateFromContext(ctx); ts != nil {
|
||||
ts.SetLastFinishReason(response.FinishReason)
|
||||
}
|
||||
|
||||
go al.handleReasoning(
|
||||
ctx,
|
||||
response.Reasoning,
|
||||
|
||||
+171
-10
@@ -4,11 +4,13 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// ====================== Config & Constants ======================
|
||||
@@ -104,6 +106,19 @@ type SubTurnConfig struct {
|
||||
// Default is 5 minutes (defaultSubTurnTimeout) if not specified.
|
||||
Timeout time.Duration
|
||||
|
||||
// MaxContextRunes limits the context size (in runes) passed to the SubTurn.
|
||||
// This prevents context window overflow by truncating message history before LLM calls.
|
||||
//
|
||||
// Values:
|
||||
// 0 = Auto-calculate based on model's ContextWindow * 0.75 (default, recommended)
|
||||
// -1 = No limit (disable soft truncation, rely only on hard context errors)
|
||||
// >0 = Use specified rune limit
|
||||
//
|
||||
// The soft limit acts as a first line of defense before hitting the provider's
|
||||
// hard context window limit. When exceeded, older messages are intelligently
|
||||
// truncated while preserving system messages and recent context.
|
||||
MaxContextRunes int
|
||||
|
||||
// Can be extended with temperature, topP, etc.
|
||||
}
|
||||
|
||||
@@ -377,6 +392,25 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
|
||||
// runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to
|
||||
// the real agent loop. The child's ephemeral session is used for history so it
|
||||
// never pollutes the parent session.
|
||||
//
|
||||
// This function implements multiple layers of context protection and error recovery:
|
||||
//
|
||||
// 1. Soft Context Limit (MaxContextRunes):
|
||||
// - Proactively truncates message history before LLM calls
|
||||
// - Default: 75% of model's context window
|
||||
// - Preserves system messages and recent context
|
||||
// - First line of defense against context overflow
|
||||
//
|
||||
// 2. Hard Context Error Recovery:
|
||||
// - Detects context_length_exceeded errors from provider
|
||||
// - Triggers force compression and retries (up to 2 times)
|
||||
// - Second line of defense when soft limit is insufficient
|
||||
//
|
||||
// 3. Truncation Recovery:
|
||||
// - Detects when LLM response is truncated (finish_reason="truncated")
|
||||
// - Injects recovery prompt asking for shorter response
|
||||
// - Retries up to 2 times
|
||||
// - Handles cases where max_tokens is hit
|
||||
func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfig) (*tools.ToolResult, error) {
|
||||
// Derive candidates from the requested model using the parent loop's provider.
|
||||
defaultProvider := al.GetConfig().Agents.Defaults.Provider
|
||||
@@ -420,17 +454,144 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi
|
||||
childAgent.MaxTokens = parentAgent.MaxTokens
|
||||
}
|
||||
|
||||
finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
|
||||
SessionKey: ts.turnID,
|
||||
UserMessage: cfg.SystemPrompt,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// Resolve MaxContextRunes configuration
|
||||
maxContextRunes := utils.ResolveMaxContextRunes(cfg.MaxContextRunes, childAgent.ContextWindow)
|
||||
|
||||
logger.DebugCF("subturn", "Context limit resolved",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"context_window": childAgent.ContextWindow,
|
||||
"max_context_runes": maxContextRunes,
|
||||
"configured_value": cfg.MaxContextRunes,
|
||||
})
|
||||
|
||||
// Retry loop for truncation and context errors
|
||||
const (
|
||||
maxTruncationRetries = 2
|
||||
maxContextRetries = 2
|
||||
)
|
||||
|
||||
truncationRetryCount := 0
|
||||
contextRetryCount := 0
|
||||
currentPrompt := cfg.SystemPrompt
|
||||
|
||||
for {
|
||||
// Soft context limit: check and truncate before LLM call
|
||||
if maxContextRunes > 0 {
|
||||
messages := childAgent.Sessions.GetHistory(ts.turnID)
|
||||
currentRunes := utils.MeasureContextRunes(messages)
|
||||
|
||||
if currentRunes > maxContextRunes {
|
||||
logger.WarnCF("subturn", "Context exceeds soft limit, truncating",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"current_runes": currentRunes,
|
||||
"max_runes": maxContextRunes,
|
||||
"overflow": currentRunes - maxContextRunes,
|
||||
})
|
||||
|
||||
truncatedMessages := utils.TruncateContextSmart(messages, maxContextRunes)
|
||||
childAgent.Sessions.SetHistory(ts.turnID, truncatedMessages)
|
||||
|
||||
// Log truncation result
|
||||
newRunes := utils.MeasureContextRunes(truncatedMessages)
|
||||
logger.InfoCF("subturn", "Context truncated successfully",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"before_runes": currentRunes,
|
||||
"after_runes": newRunes,
|
||||
"saved_runes": currentRunes - newRunes,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Call the agent loop
|
||||
finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
|
||||
SessionKey: ts.turnID,
|
||||
UserMessage: currentPrompt,
|
||||
DefaultResponse: "",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
SkipAddUserMessage: contextRetryCount > 0,
|
||||
})
|
||||
|
||||
// 1. Handle context length errors
|
||||
if err != nil && isContextLengthError(err) {
|
||||
if contextRetryCount >= maxContextRetries {
|
||||
logger.ErrorCF("subturn", "Context limit exceeded after max retries",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"retries": contextRetryCount,
|
||||
"max_retries": maxContextRetries,
|
||||
})
|
||||
return nil, fmt.Errorf("context limit exceeded after %d retries: %w", maxContextRetries, err)
|
||||
}
|
||||
|
||||
logger.WarnCF("subturn", "Context length exceeded, compressing and retrying",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"retry": contextRetryCount + 1,
|
||||
})
|
||||
|
||||
// Trigger force compression
|
||||
al.forceCompression(childAgent, ts.turnID)
|
||||
|
||||
contextRetryCount++
|
||||
continue // Retry with compressed history
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err // Other errors, return immediately
|
||||
}
|
||||
|
||||
// 2. Check for truncation (retrieve finishReason from turnState)
|
||||
finishReason := ts.GetLastFinishReason()
|
||||
|
||||
if finishReason == "truncated" && truncationRetryCount < maxTruncationRetries {
|
||||
logger.WarnCF("subturn", "Response truncated, injecting recovery message",
|
||||
map[string]any{
|
||||
"turn_id": ts.turnID,
|
||||
"retry": truncationRetryCount + 1,
|
||||
})
|
||||
|
||||
// IMPORTANT: Do NOT manually add messages to history here.
|
||||
// runAgentLoop has already saved both the assistant message (finalContent)
|
||||
// and will save the next user message (currentPrompt) on the next iteration.
|
||||
// Manually adding them would cause duplicates.
|
||||
|
||||
// Inject recovery prompt - it will be added by runAgentLoop on next iteration
|
||||
recoveryPrompt := "Your previous response was truncated due to length. Please provide a shorter, complete response that finishes your thought."
|
||||
currentPrompt = recoveryPrompt
|
||||
|
||||
truncationRetryCount++
|
||||
continue // Retry with recovery prompt
|
||||
}
|
||||
|
||||
// 3. Success - return result
|
||||
return &tools.ToolResult{ForLLM: finalContent}, nil
|
||||
}
|
||||
return &tools.ToolResult{ForLLM: finalContent}, nil
|
||||
}
|
||||
|
||||
// isContextLengthError reports whether err looks like a provider-side
// "context too large" failure. Detection is a case-insensitive substring match
// on the error text. A nil error is never a match, and any error whose text
// mentions a timeout or an exceeded deadline is rejected up front to avoid
// false positives.
func isContextLengthError(err error) bool {
	if err == nil {
		return false
	}

	text := strings.ToLower(err.Error())

	// Timeout-style failures are ruled out before any context marker is tried.
	for _, marker := range []string{"timeout", "deadline exceeded"} {
		if strings.Contains(text, marker) {
			return false
		}
	}

	// Phrases treated as signalling an exceeded context window.
	contextMarkers := []string{
		"context_length_exceeded",
		"maximum context length",
		"context window",
		"too many tokens",
		"token limit",
		"prompt is too long",
	}
	for _, marker := range contextMarkers {
		if strings.Contains(text, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// ====================== Other Types ======================
|
||||
|
||||
@@ -55,6 +55,11 @@ type turnState struct {
|
||||
// This allows child SubTurns to check if the parent has ended.
|
||||
// Nil for root turns.
|
||||
parentTurnState *turnState
|
||||
|
||||
// lastFinishReason stores the finish_reason from the last LLM call.
|
||||
// Used by SubTurn to detect truncation and retry.
|
||||
// MUST be accessed under mu lock.
|
||||
lastFinishReason string
|
||||
}
|
||||
|
||||
// ====================== Public API ======================
|
||||
@@ -136,6 +141,20 @@ func (ts *turnState) IsParentEnded() bool {
|
||||
return ts.parentTurnState.parentEnded.Load()
|
||||
}
|
||||
|
||||
// SetLastFinishReason updates the last finish reason (thread-safe).
|
||||
func (ts *turnState) SetLastFinishReason(reason string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.lastFinishReason = reason
|
||||
}
|
||||
|
||||
// GetLastFinishReason retrieves the last finish reason (thread-safe).
|
||||
func (ts *turnState) GetLastFinishReason() string {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
return ts.lastFinishReason
|
||||
}
|
||||
|
||||
// IsParentEnded is a convenience method to check if parent ended.
|
||||
// It returns the value of the parent's parentEnded atomic flag.
|
||||
|
||||
|
||||
@@ -214,11 +214,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
|
||||
Reasoning: choice.Message.Reasoning,
|
||||
ReasoningDetails: choice.Message.ReasoningDetails,
|
||||
ToolCalls: toolCalls,
|
||||
FinishReason: choice.FinishReason,
|
||||
FinishReason: normalizeFinishReason(choice.FinishReason),
|
||||
Usage: apiResponse.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// normalizeFinishReason maps a raw finish_reason onto the value the rest of
// the code expects: "length" becomes "truncated"; every other value is
// returned unchanged.
func normalizeFinishReason(reason string) string {
	switch reason {
	case "length":
		return "truncated"
	default:
		return reason
	}
}
|
||||
|
||||
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
|
||||
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
|
||||
arguments := make(map[string]any)
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// CalculateDefaultMaxContextRunes derives a default rune budget from a model's
// context window (in tokens).
//
// For a positive contextWindow the budget is 75% of the window (truncated to a
// whole number of tokens) multiplied by 3 runes per token. The ratio of 3 is a
// deliberately conservative middle ground between typical estimates:
//   - English: ~4 chars per token
//   - Chinese: ~1.5-2 chars per token
//   - Mixed:   ~3 chars per token
//
// When contextWindow is zero or negative (window unknown) a fixed fallback of
// 8000 runes is returned.
func CalculateDefaultMaxContextRunes(contextWindow int) int {
	const (
		fallbackRunes  = 8000 // conservative budget when the window is unknown
		windowFraction = 0.75 // leave 25% headroom
		runesPerToken  = 3    // conservative token-to-rune ratio
	)

	if contextWindow > 0 {
		budgetTokens := int(float64(contextWindow) * windowFraction)
		return budgetTokens * runesPerToken
	}
	return fallbackRunes
}
|
||||
|
||||
// ResolveMaxContextRunes determines the final MaxContextRunes value to use.
|
||||
// Priority: explicit config > auto-calculate > conservative default
|
||||
func ResolveMaxContextRunes(configValue, contextWindow int) int {
|
||||
switch {
|
||||
case configValue > 0:
|
||||
// Explicitly configured, use as-is
|
||||
return configValue
|
||||
case configValue == -1:
|
||||
// Explicitly disabled
|
||||
return -1
|
||||
default:
|
||||
// 0 or unset: auto-calculate
|
||||
return CalculateDefaultMaxContextRunes(contextWindow)
|
||||
}
|
||||
}
|
||||
|
||||
// MeasureContextRunes calculates the total rune count of a message list.
|
||||
// Includes content, reasoning content, and estimates for tool calls.
|
||||
func MeasureContextRunes(messages []providers.Message) int {
|
||||
totalRunes := 0
|
||||
for _, msg := range messages {
|
||||
totalRunes += utf8.RuneCountInString(msg.Content)
|
||||
totalRunes += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
// Tool calls: serialize to JSON and count
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
totalRunes += utf8.RuneCountInString(tc.Name)
|
||||
// Arguments: serialize and count
|
||||
if argsJSON, err := json.Marshal(tc.Arguments); err == nil {
|
||||
totalRunes += utf8.RuneCountInString(string(argsJSON))
|
||||
} else {
|
||||
// Fallback estimate if serialization fails
|
||||
totalRunes += 100
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToolCallID
|
||||
totalRunes += utf8.RuneCountInString(msg.ToolCallID)
|
||||
}
|
||||
return totalRunes
|
||||
}
|
||||
|
||||
// TruncateContextSmart intelligently truncates message history to fit within maxRunes.
|
||||
//
|
||||
// Strategy:
|
||||
// 1. Always preserve system messages (they define the agent's behavior)
|
||||
// 2. Keep the most recent messages (they contain current context)
|
||||
// 3. Drop older middle messages when necessary
|
||||
// 4. Insert a truncation notice to inform the LLM
|
||||
//
|
||||
// Returns the truncated message list.
|
||||
func TruncateContextSmart(messages []providers.Message, maxRunes int) []providers.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Separate system messages from others
|
||||
var systemMsgs []providers.Message
|
||||
var otherMsgs []providers.Message
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemMsgs = append(systemMsgs, msg)
|
||||
} else {
|
||||
otherMsgs = append(otherMsgs, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate system message size
|
||||
systemRunes := 0
|
||||
for _, msg := range systemMsgs {
|
||||
systemRunes += utf8.RuneCountInString(msg.Content)
|
||||
systemRunes += utf8.RuneCountInString(msg.ReasoningContent)
|
||||
}
|
||||
|
||||
// Reserve space for truncation notice (estimate ~80 runes)
|
||||
const truncationNoticeEstimate = 80
|
||||
|
||||
// Allocate remaining space for other messages
|
||||
remainingRunes := maxRunes - systemRunes - truncationNoticeEstimate
|
||||
if remainingRunes <= 0 {
|
||||
// System messages already exceed limit - return only system messages
|
||||
return systemMsgs
|
||||
}
|
||||
|
||||
// Collect recent messages in reverse order until we hit the limit
|
||||
var keptMsgs []providers.Message
|
||||
currentRunes := 0
|
||||
|
||||
for i := len(otherMsgs) - 1; i >= 0; i-- {
|
||||
msg := otherMsgs[i]
|
||||
msgRunes := utf8.RuneCountInString(msg.Content) +
|
||||
utf8.RuneCountInString(msg.ReasoningContent)
|
||||
|
||||
// Estimate tool call size
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
msgRunes += utf8.RuneCountInString(tc.Name)
|
||||
if argsJSON, err := json.Marshal(tc.Arguments); err == nil {
|
||||
msgRunes += utf8.RuneCountInString(string(argsJSON))
|
||||
} else {
|
||||
msgRunes += 100
|
||||
}
|
||||
}
|
||||
}
|
||||
msgRunes += utf8.RuneCountInString(msg.ToolCallID)
|
||||
|
||||
if currentRunes+msgRunes > remainingRunes {
|
||||
// Would exceed limit, stop collecting
|
||||
break
|
||||
}
|
||||
|
||||
// Prepend to maintain chronological order
|
||||
keptMsgs = append([]providers.Message{msg}, keptMsgs...)
|
||||
currentRunes += msgRunes
|
||||
}
|
||||
|
||||
// If we dropped messages, add a truncation notice
|
||||
result := systemMsgs
|
||||
if len(keptMsgs) < len(otherMsgs) {
|
||||
droppedCount := len(otherMsgs) - len(keptMsgs)
|
||||
truncationNotice := providers.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf(
|
||||
"[Context truncated: %d earlier messages omitted to stay within context limits]",
|
||||
droppedCount,
|
||||
),
|
||||
}
|
||||
result = append(result, truncationNotice)
|
||||
}
|
||||
|
||||
result = append(result, keptMsgs...)
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package utils
|
||||
|
||||
import (
	"strings"
	"testing"

	"github.com/sipeed/picoclaw/pkg/providers"
)
|
||||
|
||||
func TestCalculateDefaultMaxContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
contextWindow int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "zero context window uses fallback",
|
||||
contextWindow: 0,
|
||||
want: 8000,
|
||||
},
|
||||
{
|
||||
name: "negative context window uses fallback",
|
||||
contextWindow: -1,
|
||||
want: 8000,
|
||||
},
|
||||
{
|
||||
name: "small context window (4k tokens)",
|
||||
contextWindow: 4000,
|
||||
want: 9000, // 4000 * 0.75 * 3 = 9000
|
||||
},
|
||||
{
|
||||
name: "medium context window (128k tokens)",
|
||||
contextWindow: 128000,
|
||||
want: 288000, // 128000 * 0.75 * 3 = 288000
|
||||
},
|
||||
{
|
||||
name: "large context window (1M tokens)",
|
||||
contextWindow: 1000000,
|
||||
want: 2250000, // 1000000 * 0.75 * 3 = 2250000
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := CalculateDefaultMaxContextRunes(tt.contextWindow)
|
||||
if got != tt.want {
|
||||
t.Errorf("CalculateDefaultMaxContextRunes(%d) = %d, want %d",
|
||||
tt.contextWindow, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMaxContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configValue int
|
||||
contextWindow int
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "explicit positive value",
|
||||
configValue: 12000,
|
||||
contextWindow: 4000,
|
||||
want: 12000,
|
||||
},
|
||||
{
|
||||
name: "explicit disable (-1)",
|
||||
configValue: -1,
|
||||
contextWindow: 4000,
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "zero uses auto-calculate",
|
||||
configValue: 0,
|
||||
contextWindow: 4000,
|
||||
want: 9000, // 4000 * 0.75 * 3
|
||||
},
|
||||
{
|
||||
name: "unset (0) with unknown context window",
|
||||
configValue: 0,
|
||||
contextWindow: 0,
|
||||
want: 8000, // fallback
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ResolveMaxContextRunes(tt.configValue, tt.contextWindow)
|
||||
if got != tt.want {
|
||||
t.Errorf("ResolveMaxContextRunes(%d, %d) = %d, want %d",
|
||||
tt.configValue, tt.contextWindow, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeasureContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []providers.Message
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "empty messages",
|
||||
messages: []providers.Message{},
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "single simple message",
|
||||
messages: []providers.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
want: 5, // "Hello" = 5 runes
|
||||
},
|
||||
{
|
||||
name: "message with reasoning",
|
||||
messages: []providers.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Answer",
|
||||
ReasoningContent: "Thinking",
|
||||
},
|
||||
},
|
||||
want: 14, // "Answer" (6) + "Thinking" (8) = 14
|
||||
},
|
||||
{
|
||||
name: "message with tool call",
|
||||
messages: []providers.Message{
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Using tool",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
Name: "test_tool",
|
||||
Arguments: map[string]any{"key": "value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: 10 + 9 + 15, // "Using tool" + "test_tool" + {"key":"value"}
|
||||
},
|
||||
{
|
||||
name: "multiple messages",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "You are helpful"},
|
||||
{Role: "user", Content: "Hi"},
|
||||
{Role: "assistant", Content: "Hello!"},
|
||||
},
|
||||
want: 15 + 2 + 6, // 15 + 2 + 6 = 23
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
messages: []providers.Message{
|
||||
{Role: "user", Content: "你好世界"}, // 4 Chinese characters
|
||||
},
|
||||
want: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := MeasureContextRunes(tt.messages)
|
||||
if got != tt.want {
|
||||
t.Errorf("MeasureContextRunes() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateContextSmart(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []providers.Message
|
||||
maxRunes int
|
||||
wantLen int
|
||||
wantHas []string // Content strings that should be present
|
||||
wantNot []string // Content strings that should be absent
|
||||
}{
|
||||
{
|
||||
name: "empty messages",
|
||||
messages: []providers.Message{},
|
||||
maxRunes: 100,
|
||||
wantLen: 0,
|
||||
},
|
||||
{
|
||||
name: "no truncation needed",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "System"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
maxRunes: 100,
|
||||
wantLen: 2,
|
||||
wantHas: []string{"System", "Hello"},
|
||||
},
|
||||
{
|
||||
name: "truncate when limit is tight",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "System"},
|
||||
{Role: "user", Content: "Message 1 with some content here"},
|
||||
{Role: "assistant", Content: "Response 1 with some content here"},
|
||||
{Role: "user", Content: "Message 2 with some content here"},
|
||||
{Role: "assistant", Content: "Response 2 with some content here"},
|
||||
{Role: "user", Content: "Latest"},
|
||||
},
|
||||
maxRunes: 120, // Tight limit to force truncation
|
||||
wantLen: -1, // Don't check exact length, just verify truncation occurred
|
||||
wantHas: []string{"System", "Latest"},
|
||||
wantNot: []string{"Message 1", "Response 1"},
|
||||
},
|
||||
{
|
||||
name: "system messages exceed limit",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "Very long system message"},
|
||||
{Role: "user", Content: "User message"},
|
||||
},
|
||||
maxRunes: 10, // Less than system message
|
||||
wantLen: 1, // Only system message
|
||||
wantHas: []string{"Very long system message"},
|
||||
wantNot: []string{"User message"},
|
||||
},
|
||||
{
|
||||
name: "preserve multiple system messages",
|
||||
messages: []providers.Message{
|
||||
{Role: "system", Content: "Sys1"},
|
||||
{Role: "system", Content: "Sys2"},
|
||||
{Role: "user", Content: "Old"},
|
||||
{Role: "user", Content: "New"},
|
||||
},
|
||||
maxRunes: 200, // Generous limit
|
||||
wantLen: 4, // Both system + truncation notice + new
|
||||
wantHas: []string{"Sys1", "Sys2", "New"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := TruncateContextSmart(tt.messages, tt.maxRunes)
|
||||
|
||||
if tt.wantLen >= 0 && len(got) != tt.wantLen {
|
||||
t.Errorf("TruncateContextSmart() returned %d messages, want %d",
|
||||
len(got), tt.wantLen)
|
||||
}
|
||||
|
||||
// Check for expected content
|
||||
allContent := ""
|
||||
for _, msg := range got {
|
||||
allContent += msg.Content + " "
|
||||
}
|
||||
|
||||
for _, want := range tt.wantHas {
|
||||
found := false
|
||||
for _, msg := range got {
|
||||
if msg.Content == want || containsSubstring(msg.Content, want) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected content %q not found in truncated messages", want)
|
||||
}
|
||||
}
|
||||
|
||||
for _, notWant := range tt.wantNot {
|
||||
for _, msg := range got {
|
||||
if containsSubstring(msg.Content, notWant) {
|
||||
t.Errorf("Unexpected content %q found in truncated messages", notWant)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// containsSubstring reports whether substr occurs within s. An empty substr
// is contained in every string.
func containsSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||||
|
||||
// findSubstring reports whether substr occurs within s. It returns false when
// substr is longer than s, and true for an empty substr.
func findSubstring(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||||
|
||||
// TestSubTurnConfigMaxContextRunes verifies that MaxContextRunes configuration
|
||||
// is properly integrated into the SubTurn execution flow.
|
||||
func TestSubTurnConfigMaxContextRunes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxContextRunes int
|
||||
contextWindow int
|
||||
wantResolved int
|
||||
}{
|
||||
{
|
||||
name: "default (0) auto-calculates from context window",
|
||||
maxContextRunes: 0,
|
||||
contextWindow: 4000,
|
||||
wantResolved: 9000, // 4000 * 0.75 * 3
|
||||
},
|
||||
{
|
||||
name: "explicit value is used",
|
||||
maxContextRunes: 12000,
|
||||
contextWindow: 4000,
|
||||
wantResolved: 12000,
|
||||
},
|
||||
{
|
||||
name: "disabled (-1) returns -1",
|
||||
maxContextRunes: -1,
|
||||
contextWindow: 4000,
|
||||
wantResolved: -1,
|
||||
},
|
||||
{
|
||||
name: "fallback when context window unknown",
|
||||
maxContextRunes: 0,
|
||||
contextWindow: 0,
|
||||
wantResolved: 8000, // conservative fallback
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ResolveMaxContextRunes(tt.maxContextRunes, tt.contextWindow)
|
||||
if got != tt.wantResolved {
|
||||
t.Errorf("utils.ResolveMaxContextRunes(%d, %d) = %d, want %d",
|
||||
tt.maxContextRunes, tt.contextWindow, got, tt.wantResolved)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextTruncationFlow verifies the complete context truncation flow:
|
||||
// 1. Messages accumulate beyond soft limit
|
||||
// 2. Truncation is triggered
|
||||
// 3. System messages are preserved
|
||||
// 4. Recent messages are kept
|
||||
func TestContextTruncationFlow(t *testing.T) {
|
||||
// Build a message history that exceeds the limit
|
||||
messages := []providers.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant"}, // ~27 runes
|
||||
{Role: "user", Content: "First question"}, // ~14 runes
|
||||
{Role: "assistant", Content: "First answer"}, // ~12 runes
|
||||
{Role: "user", Content: "Second question"}, // ~15 runes
|
||||
{Role: "assistant", Content: "Second answer"}, // ~13 runes
|
||||
{Role: "user", Content: "Third question"}, // ~14 runes
|
||||
{Role: "assistant", Content: "Third answer"}, // ~12 runes
|
||||
{Role: "user", Content: "Latest question"}, // ~15 runes
|
||||
}
|
||||
|
||||
// Total: ~122 runes
|
||||
totalRunes := MeasureContextRunes(messages)
|
||||
if totalRunes < 100 {
|
||||
t.Errorf("Expected total runes > 100, got %d", totalRunes)
|
||||
}
|
||||
|
||||
// Set limit to 150 runes - should force truncation of old messages
|
||||
// but preserve system + truncation notice + recent messages
|
||||
maxRunes := 150
|
||||
truncated := TruncateContextSmart(messages, maxRunes)
|
||||
|
||||
// Verify truncation occurred
|
||||
if len(truncated) >= len(messages) {
|
||||
t.Errorf("Expected truncation, but got %d messages (original: %d)",
|
||||
len(truncated), len(messages))
|
||||
}
|
||||
|
||||
// Verify system message is preserved
|
||||
foundSystem := false
|
||||
for _, msg := range truncated {
|
||||
if msg.Role == "system" && msg.Content == "You are a helpful assistant" {
|
||||
foundSystem = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundSystem {
|
||||
t.Error("System message was not preserved after truncation")
|
||||
}
|
||||
|
||||
// Verify latest message is preserved
|
||||
foundLatest := false
|
||||
for _, msg := range truncated {
|
||||
if msg.Content == "Latest question" {
|
||||
foundLatest = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundLatest {
|
||||
t.Error("Latest message was not preserved after truncation")
|
||||
}
|
||||
|
||||
// Verify truncation notice is present
|
||||
foundNotice := false
|
||||
for _, msg := range truncated {
|
||||
if msg.Role == "system" && containsSubstring(msg.Content, "truncated") {
|
||||
foundNotice = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundNotice {
|
||||
t.Error("Truncation notice was not added")
|
||||
}
|
||||
|
||||
// Verify result is within limit (with some tolerance for estimation)
|
||||
resultRunes := MeasureContextRunes(truncated)
|
||||
if resultRunes > maxRunes+20 { // Allow 20 rune tolerance
|
||||
t.Errorf("Truncated context (%d runes) significantly exceeds limit (%d runes)",
|
||||
resultRunes, maxRunes)
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextTruncationPreservesToolCalls verifies that tool calls are
|
||||
// properly handled during context truncation.
|
||||
func TestContextTruncationPreservesToolCalls(t *testing.T) {
|
||||
messages := []providers.Message{
|
||||
{Role: "system", Content: "System"},
|
||||
{Role: "user", Content: "Old message that should be dropped"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Recent tool use",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
Name: "important_tool",
|
||||
Arguments: map[string]any{"key": "value"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Set a generous limit that should keep the tool call message
|
||||
maxRunes := 200
|
||||
truncated := TruncateContextSmart(messages, maxRunes)
|
||||
|
||||
// Verify tool call message is preserved
|
||||
foundToolCall := false
|
||||
for _, msg := range truncated {
|
||||
if len(msg.ToolCalls) > 0 && msg.ToolCalls[0].Name == "important_tool" {
|
||||
foundToolCall = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundToolCall {
|
||||
t.Error("Tool call message was not preserved during truncation")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user