fix(agent): preserve active turn during context retry rebuild

This commit is contained in:
afjcjsbx
2026-05-19 09:18:39 +02:00
parent 1502636bf0
commit fe7ded5c13
3 changed files with 214 additions and 13 deletions
+24 -9
View File
@@ -369,15 +369,21 @@ func (p *Pipeline) CallLLM(
contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills)
}
ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills)
stableHistory, protectedTurnTail := splitHistoryForActiveTurn(
exec.history,
ts.persistedMessagesSnapshot(),
)
buildMessages := func(trimmedHistory []providers.Message) []providers.Message {
rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, exec.summary, "", nil)
fullHistory := append(append([]providers.Message(nil), trimmedHistory...), protectedTurnTail...)
rebuildPromptReq := promptBuildRequestForTurn(ts, fullHistory, exec.summary, "", nil)
rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...)
return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq)
}
originalHistoryCount := len(exec.history)
var fit bool
exec.history, exec.callMessages, fit = trimHistoryToFitContextWindow(
exec.history,
var trimmedStableHistory []providers.Message
trimmedStableHistory, exec.callMessages, fit = trimHistoryToFitContextWindow(
stableHistory,
func(trimmedHistory []providers.Message) []providers.Message {
rebuilt := buildMessages(trimmedHistory)
if exec.gracefulTerminal {
@@ -389,7 +395,8 @@ func (p *Pipeline) CallLLM(
exec.providerToolDefs,
ts.agent.MaxTokens,
)
exec.messages = buildMessages(exec.history)
exec.history = append(trimmedStableHistory, protectedTurnTail...)
exec.messages = buildMessages(trimmedStableHistory)
if exec.gracefulTerminal {
msgs := append([]providers.Message(nil), exec.messages...)
exec.callMessages = append(msgs, ts.interruptHintMessage())
@@ -406,13 +413,21 @@ func (p *Pipeline) CallLLM(
})
} else if !fit {
logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{
"session_key": ts.sessionKey,
"retry": retry,
"history_msgs": len(exec.history),
"context_window": ts.agent.ContextWindow,
"max_tokens": ts.agent.MaxTokens,
"session_key": ts.sessionKey,
"retry": retry,
"history_msgs": len(exec.history),
"protected_turn_msgs": len(protectedTurnTail),
"context_window": ts.agent.ContextWindow,
"max_tokens": ts.agent.MaxTokens,
})
}
if !fit {
err = fmt.Errorf(
"context window still exceeded after retry compaction; refusing to drop active turn messages: %w",
err,
)
break
}
continue
}
break
+78 -4
View File
@@ -634,13 +634,17 @@ func (ts *turnState) recordPersistedMessage(msg providers.Message) {
ts.persistedMessages = append(ts.persistedMessages, msg)
}
func (ts *turnState) persistedMessagesSnapshot() []providers.Message {
ts.mu.RLock()
defer ts.mu.RUnlock()
return append([]providers.Message(nil), ts.persistedMessages...)
}
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
history := agent.Sessions.GetHistory(ts.sessionKey)
summary := agent.Sessions.GetSummary(ts.sessionKey)
ts.mu.RLock()
persisted := append([]providers.Message(nil), ts.persistedMessages...)
ts.mu.RUnlock()
persisted := ts.persistedMessagesSnapshot()
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
history = append([]providers.Message(nil), history[:len(history)-matched]...)
@@ -680,13 +684,83 @@ func (ts *turnState) restoreSession(agent *AgentInstance) error {
func matchingTurnMessageTail(history, persisted []providers.Message) int {
maxMatch := min(len(history), len(persisted))
for size := maxMatch; size > 0; size-- {
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) {
return size
}
}
return 0
}
func splitHistoryForActiveTurn(
history []providers.Message,
persisted []providers.Message,
) ([]providers.Message, []providers.Message) {
matched := matchingTurnMessageTail(history, persisted)
if matched <= 0 {
return append([]providers.Message(nil), history...), nil
}
stable := append([]providers.Message(nil), history[:len(history)-matched]...)
protected := append([]providers.Message(nil), history[len(history)-matched:]...)
return stable, protected
}
func messageSlicesEquivalent(a, b []providers.Message) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !messagesEquivalent(a[i], b[i]) {
return false
}
}
return true
}
func messagesEquivalent(a, b providers.Message) bool {
return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b))
}
func normalizeMessageForComparison(msg providers.Message) providers.Message {
msg.PromptLayer = ""
msg.PromptSlot = ""
msg.PromptSource = ""
if len(msg.Media) == 0 {
msg.Media = nil
}
if len(msg.Attachments) == 0 {
msg.Attachments = nil
}
if len(msg.SystemParts) == 0 {
msg.SystemParts = nil
} else {
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
for i := range msg.SystemParts {
msg.SystemParts[i].PromptLayer = ""
msg.SystemParts[i].PromptSlot = ""
msg.SystemParts[i].PromptSource = ""
}
}
if len(msg.ToolCalls) == 0 {
msg.ToolCalls = nil
} else {
msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...)
for i := range msg.ToolCalls {
msg.ToolCalls[i].Name = ""
msg.ToolCalls[i].Arguments = nil
msg.ToolCalls[i].ThoughtSignature = ""
if msg.ToolCalls[i].Function != nil {
fn := *msg.ToolCalls[i].Function
fn.ThoughtSignature = ""
msg.ToolCalls[i].Function = &fn
}
}
}
return msg
}
func (ts *turnState) interruptHintMessage() providers.Message {
_, hint := ts.gracefulInterruptRequested()
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
+112
View File
@@ -0,0 +1,112 @@
package agent
import (
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestMatchingTurnMessageTail_IgnoresInternalRuntimeFields(t *testing.T) {
history := []providers.Message{
{Role: "user", Content: "question"},
{
Role: "assistant",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Function: &providers.FunctionCall{
Name: "read_file",
Arguments: `{"path":"/tmp/test"}`,
},
},
},
},
}
persisted := []providers.Message{
userPromptMessage("question", nil),
{
Role: "assistant",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "read_file",
Arguments: map[string]any{"path": "/tmp/test"},
ThoughtSignature: "internal-signature",
Function: &providers.FunctionCall{
Name: "read_file",
Arguments: `{"path":"/tmp/test"}`,
ThoughtSignature: "internal-signature",
},
},
},
},
}
if got := matchingTurnMessageTail(history, persisted); got != 2 {
t.Fatalf("matchingTurnMessageTail() = %d, want 2", got)
}
}
func TestSplitHistoryForActiveTurn_ProtectsPersistedTail(t *testing.T) {
history := []providers.Message{
{Role: "user", Content: "old question"},
{Role: "assistant", Content: "old answer"},
{Role: "user", Content: "current question"},
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
}
persisted := []providers.Message{
userPromptMessage("current question", nil),
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
}
stable, protected := splitHistoryForActiveTurn(history, persisted)
if len(stable) != 2 {
t.Fatalf("stable history len = %d, want 2", len(stable))
}
if len(protected) != 2 {
t.Fatalf("protected tail len = %d, want 2", len(protected))
}
if protected[0].Content != "current question" {
t.Fatalf("protected[0].Content = %q, want current question", protected[0].Content)
}
}
func TestTrimHistoryToFitContextWindow_WithProtectedTurnTailKeepsActiveTurn(t *testing.T) {
current := strings.Repeat("current turn ", 80)
history := []providers.Message{
{Role: "user", Content: strings.Repeat("old turn ", 60)},
{Role: "assistant", Content: strings.Repeat("old reply ", 60)},
{Role: "user", Content: current},
}
stable, protected := splitHistoryForActiveTurn(history, []providers.Message{
userPromptMessage(current, nil),
})
trimmedStable, messages, fit := trimHistoryToFitContextWindow(
stable,
func(trimmedHistory []providers.Message) []providers.Message {
return append(append([]providers.Message(nil), trimmedHistory...), protected...)
},
120,
nil,
0,
)
if fit {
t.Fatal("expected protected active turn alone to remain over budget")
}
if len(trimmedStable) != 0 {
t.Fatalf("trimmed stable history len = %d, want 0", len(trimmedStable))
}
if len(messages) != 1 {
t.Fatalf("messages len = %d, want 1 protected active-turn message", len(messages))
}
if messages[0].Content != current {
t.Fatalf("messages[0].Content = %q, want protected current turn", messages[0].Content)
}
}