From fe7ded5c138d62fd76c25bee9110333d08d4c2b2 Mon Sep 17 00:00:00 2001 From: afjcjsbx Date: Tue, 19 May 2026 09:18:39 +0200 Subject: [PATCH] fix(agent): preserve active turn during context retry rebuild --- pkg/agent/pipeline_llm.go | 33 ++++++++--- pkg/agent/turn_state.go | 82 +++++++++++++++++++++++-- pkg/agent/turn_state_test.go | 112 +++++++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 13 deletions(-) create mode 100644 pkg/agent/turn_state_test.go diff --git a/pkg/agent/pipeline_llm.go b/pkg/agent/pipeline_llm.go index 7934b7815..9fbe9b740 100644 --- a/pkg/agent/pipeline_llm.go +++ b/pkg/agent/pipeline_llm.go @@ -369,15 +369,21 @@ func (p *Pipeline) CallLLM( contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills) } ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills) + stableHistory, protectedTurnTail := splitHistoryForActiveTurn( + exec.history, + ts.persistedMessagesSnapshot(), + ) buildMessages := func(trimmedHistory []providers.Message) []providers.Message { - rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, exec.summary, "", nil) + fullHistory := append(append([]providers.Message(nil), trimmedHistory...), protectedTurnTail...) + rebuildPromptReq := promptBuildRequestForTurn(ts, fullHistory, exec.summary, "", nil) rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...) return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq) } originalHistoryCount := len(exec.history) var fit bool - exec.history, exec.callMessages, fit = trimHistoryToFitContextWindow( - exec.history, + var trimmedStableHistory []providers.Message + trimmedStableHistory, exec.callMessages, fit = trimHistoryToFitContextWindow( + stableHistory, func(trimmedHistory []providers.Message) []providers.Message { rebuilt := buildMessages(trimmedHistory) if exec.gracefulTerminal { @@ -389,7 +395,8 @@ func (p *Pipeline) CallLLM( exec.providerToolDefs, ts.agent.MaxTokens, ) - exec.messages = buildMessages(exec.history) + exec.history = append(trimmedStableHistory, protectedTurnTail...) + exec.messages = buildMessages(trimmedStableHistory) if exec.gracefulTerminal { msgs := append([]providers.Message(nil), exec.messages...) exec.callMessages = append(msgs, ts.interruptHintMessage()) @@ -406,13 +413,21 @@ func (p *Pipeline) CallLLM( }) } else if !fit { logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{ - "session_key": ts.sessionKey, - "retry": retry, - "history_msgs": len(exec.history), - "context_window": ts.agent.ContextWindow, - "max_tokens": ts.agent.MaxTokens, + "session_key": ts.sessionKey, + "retry": retry, + "history_msgs": len(exec.history), + "protected_turn_msgs": len(protectedTurnTail), + "context_window": ts.agent.ContextWindow, + "max_tokens": ts.agent.MaxTokens, }) } + if !fit { + err = fmt.Errorf( + "context window still exceeded after retry compaction; refusing to drop active turn messages: %w", + err, + ) + break + } continue } break diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index ae058e49d..2df0c81c2 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -634,13 +634,17 @@ func (ts *turnState) recordPersistedMessage(msg providers.Message) { ts.persistedMessages = append(ts.persistedMessages, msg) } +func (ts *turnState) persistedMessagesSnapshot() []providers.Message { + ts.mu.RLock() + defer ts.mu.RUnlock() + return append([]providers.Message(nil), ts.persistedMessages...) +} + func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) { history := agent.Sessions.GetHistory(ts.sessionKey) summary := agent.Sessions.GetSummary(ts.sessionKey) - ts.mu.RLock() - persisted := append([]providers.Message(nil), ts.persistedMessages...) - ts.mu.RUnlock() + persisted := ts.persistedMessagesSnapshot() if matched := matchingTurnMessageTail(history, persisted); matched > 0 { history = append([]providers.Message(nil), history[:len(history)-matched]...) @@ -680,13 +684,83 @@ func (ts *turnState) restoreSession(agent *AgentInstance) error { func matchingTurnMessageTail(history, persisted []providers.Message) int { maxMatch := min(len(history), len(persisted)) for size := maxMatch; size > 0; size-- { - if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) { + if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) { return size } } return 0 } +func splitHistoryForActiveTurn( + history []providers.Message, + persisted []providers.Message, +) ([]providers.Message, []providers.Message) { + matched := matchingTurnMessageTail(history, persisted) + if matched <= 0 { + return append([]providers.Message(nil), history...), nil + } + + stable := append([]providers.Message(nil), history[:len(history)-matched]...) + protected := append([]providers.Message(nil), history[len(history)-matched:]...) + return stable, protected +} + +func messageSlicesEquivalent(a, b []providers.Message) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if !messagesEquivalent(a[i], b[i]) { + return false + } + } + return true +} + +func messagesEquivalent(a, b providers.Message) bool { + return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b)) +} + +func normalizeMessageForComparison(msg providers.Message) providers.Message { + msg.PromptLayer = "" + msg.PromptSlot = "" + msg.PromptSource = "" + + if len(msg.Media) == 0 { + msg.Media = nil + } + if len(msg.Attachments) == 0 { + msg.Attachments = nil + } + if len(msg.SystemParts) == 0 { + msg.SystemParts = nil + } else { + msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...) + for i := range msg.SystemParts { + msg.SystemParts[i].PromptLayer = "" + msg.SystemParts[i].PromptSlot = "" + msg.SystemParts[i].PromptSource = "" + } + } + if len(msg.ToolCalls) == 0 { + msg.ToolCalls = nil + } else { + msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...) + for i := range msg.ToolCalls { + msg.ToolCalls[i].Name = "" + msg.ToolCalls[i].Arguments = nil + msg.ToolCalls[i].ThoughtSignature = "" + if msg.ToolCalls[i].Function != nil { + fn := *msg.ToolCalls[i].Function + fn.ThoughtSignature = "" + msg.ToolCalls[i].Function = &fn + } + } + } + + return msg +} + func (ts *turnState) interruptHintMessage() providers.Message { _, hint := ts.gracefulInterruptRequested() content := "Interrupt requested. Stop scheduling tools and provide a short final summary." diff --git a/pkg/agent/turn_state_test.go b/pkg/agent/turn_state_test.go new file mode 100644 index 000000000..8ffd40dd5 --- /dev/null +++ b/pkg/agent/turn_state_test.go @@ -0,0 +1,112 @@ +package agent + +import ( + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestMatchingTurnMessageTail_IgnoresInternalRuntimeFields(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: "question"}, + { + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + }, + }, + }, + }, + } + + persisted := []providers.Message{ + userPromptMessage("question", nil), + { + Role: "assistant", + ToolCalls: []providers.ToolCall{ + { + ID: "call_1", + Type: "function", + Name: "read_file", + Arguments: map[string]any{"path": "/tmp/test"}, + ThoughtSignature: "internal-signature", + Function: &providers.FunctionCall{ + Name: "read_file", + Arguments: `{"path":"/tmp/test"}`, + ThoughtSignature: "internal-signature", + }, + }, + }, + }, + } + + if got := matchingTurnMessageTail(history, persisted); got != 2 { + t.Fatalf("matchingTurnMessageTail() = %d, want 2", got) + } +} + +func TestSplitHistoryForActiveTurn_ProtectsPersistedTail(t *testing.T) { + history := []providers.Message{ + {Role: "user", Content: "old question"}, + {Role: "assistant", Content: "old answer"}, + {Role: "user", Content: "current question"}, + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + persisted := []providers.Message{ + userPromptMessage("current question", nil), + {Role: "tool", Content: "tool output", ToolCallID: "call_1"}, + } + + stable, protected := splitHistoryForActiveTurn(history, persisted) + if len(stable) != 2 { + t.Fatalf("stable history len = %d, want 2", len(stable)) + } + if len(protected) != 2 { + t.Fatalf("protected tail len = %d, want 2", len(protected)) + } + if protected[0].Content != "current question" { + t.Fatalf("protected[0].Content = %q, want current question", protected[0].Content) + } +} + +func TestTrimHistoryToFitContextWindow_WithProtectedTurnTailKeepsActiveTurn(t *testing.T) { + current := strings.Repeat("current turn ", 80) + history := []providers.Message{ + {Role: "user", Content: strings.Repeat("old turn ", 60)}, + {Role: "assistant", Content: strings.Repeat("old reply ", 60)}, + {Role: "user", Content: current}, + } + + stable, protected := splitHistoryForActiveTurn(history, []providers.Message{ + userPromptMessage(current, nil), + }) + trimmedStable, messages, fit := trimHistoryToFitContextWindow( + stable, + func(trimmedHistory []providers.Message) []providers.Message { + return append(append([]providers.Message(nil), trimmedHistory...), protected...) + }, + 120, + nil, + 0, + ) + + if fit { + t.Fatal("expected protected active turn alone to remain over budget") + } + if len(trimmedStable) != 0 { + t.Fatalf("trimmed stable history len = %d, want 0", len(trimmedStable)) + } + if len(messages) != 1 { + t.Fatalf("messages len = %d, want 1 protected active-turn message", len(messages)) + } + if messages[0].Content != current { + t.Fatalf("messages[0].Content = %q, want protected current turn", messages[0].Content) + } +}