mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
fix(agent): preserve active turn during context retry rebuild
This commit is contained in:
@@ -369,15 +369,21 @@ func (p *Pipeline) CallLLM(
|
||||
contextualSkills = ts.agent.ContextBuilder.ResolveActiveSkillsForContext(ts.activeSkills)
|
||||
}
|
||||
ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, contextualSkills)
|
||||
stableHistory, protectedTurnTail := splitHistoryForActiveTurn(
|
||||
exec.history,
|
||||
ts.persistedMessagesSnapshot(),
|
||||
)
|
||||
buildMessages := func(trimmedHistory []providers.Message) []providers.Message {
|
||||
rebuildPromptReq := promptBuildRequestForTurn(ts, trimmedHistory, exec.summary, "", nil)
|
||||
fullHistory := append(append([]providers.Message(nil), trimmedHistory...), protectedTurnTail...)
|
||||
rebuildPromptReq := promptBuildRequestForTurn(ts, fullHistory, exec.summary, "", nil)
|
||||
rebuildPromptReq.ActiveSkills = append([]string(nil), contextualSkills...)
|
||||
return ts.agent.ContextBuilder.BuildMessagesFromPrompt(rebuildPromptReq)
|
||||
}
|
||||
originalHistoryCount := len(exec.history)
|
||||
var fit bool
|
||||
exec.history, exec.callMessages, fit = trimHistoryToFitContextWindow(
|
||||
exec.history,
|
||||
var trimmedStableHistory []providers.Message
|
||||
trimmedStableHistory, exec.callMessages, fit = trimHistoryToFitContextWindow(
|
||||
stableHistory,
|
||||
func(trimmedHistory []providers.Message) []providers.Message {
|
||||
rebuilt := buildMessages(trimmedHistory)
|
||||
if exec.gracefulTerminal {
|
||||
@@ -389,7 +395,8 @@ func (p *Pipeline) CallLLM(
|
||||
exec.providerToolDefs,
|
||||
ts.agent.MaxTokens,
|
||||
)
|
||||
exec.messages = buildMessages(exec.history)
|
||||
exec.history = append(trimmedStableHistory, protectedTurnTail...)
|
||||
exec.messages = buildMessages(trimmedStableHistory)
|
||||
if exec.gracefulTerminal {
|
||||
msgs := append([]providers.Message(nil), exec.messages...)
|
||||
exec.callMessages = append(msgs, ts.interruptHintMessage())
|
||||
@@ -406,13 +413,21 @@ func (p *Pipeline) CallLLM(
|
||||
})
|
||||
} else if !fit {
|
||||
logger.WarnCF("agent", "Context still exceeds budget after retry compaction rebuild", map[string]any{
|
||||
"session_key": ts.sessionKey,
|
||||
"retry": retry,
|
||||
"history_msgs": len(exec.history),
|
||||
"context_window": ts.agent.ContextWindow,
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"session_key": ts.sessionKey,
|
||||
"retry": retry,
|
||||
"history_msgs": len(exec.history),
|
||||
"protected_turn_msgs": len(protectedTurnTail),
|
||||
"context_window": ts.agent.ContextWindow,
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
})
|
||||
}
|
||||
if !fit {
|
||||
err = fmt.Errorf(
|
||||
"context window still exceeded after retry compaction; refusing to drop active turn messages: %w",
|
||||
err,
|
||||
)
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
break
|
||||
|
||||
+78
-4
@@ -634,13 +634,17 @@ func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
||||
ts.persistedMessages = append(ts.persistedMessages, msg)
|
||||
}
|
||||
|
||||
func (ts *turnState) persistedMessagesSnapshot() []providers.Message {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return append([]providers.Message(nil), ts.persistedMessages...)
|
||||
}
|
||||
|
||||
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
history := agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
||||
|
||||
ts.mu.RLock()
|
||||
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
||||
ts.mu.RUnlock()
|
||||
persisted := ts.persistedMessagesSnapshot()
|
||||
|
||||
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
||||
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
@@ -680,13 +684,83 @@ func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func splitHistoryForActiveTurn(
|
||||
history []providers.Message,
|
||||
persisted []providers.Message,
|
||||
) ([]providers.Message, []providers.Message) {
|
||||
matched := matchingTurnMessageTail(history, persisted)
|
||||
if matched <= 0 {
|
||||
return append([]providers.Message(nil), history...), nil
|
||||
}
|
||||
|
||||
stable := append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
protected := append([]providers.Message(nil), history[len(history)-matched:]...)
|
||||
return stable, protected
|
||||
}
|
||||
|
||||
func messageSlicesEquivalent(a, b []providers.Message) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if !messagesEquivalent(a[i], b[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func messagesEquivalent(a, b providers.Message) bool {
|
||||
return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b))
|
||||
}
|
||||
|
||||
func normalizeMessageForComparison(msg providers.Message) providers.Message {
|
||||
msg.PromptLayer = ""
|
||||
msg.PromptSlot = ""
|
||||
msg.PromptSource = ""
|
||||
|
||||
if len(msg.Media) == 0 {
|
||||
msg.Media = nil
|
||||
}
|
||||
if len(msg.Attachments) == 0 {
|
||||
msg.Attachments = nil
|
||||
}
|
||||
if len(msg.SystemParts) == 0 {
|
||||
msg.SystemParts = nil
|
||||
} else {
|
||||
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||
for i := range msg.SystemParts {
|
||||
msg.SystemParts[i].PromptLayer = ""
|
||||
msg.SystemParts[i].PromptSlot = ""
|
||||
msg.SystemParts[i].PromptSource = ""
|
||||
}
|
||||
}
|
||||
if len(msg.ToolCalls) == 0 {
|
||||
msg.ToolCalls = nil
|
||||
} else {
|
||||
msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...)
|
||||
for i := range msg.ToolCalls {
|
||||
msg.ToolCalls[i].Name = ""
|
||||
msg.ToolCalls[i].Arguments = nil
|
||||
msg.ToolCalls[i].ThoughtSignature = ""
|
||||
if msg.ToolCalls[i].Function != nil {
|
||||
fn := *msg.ToolCalls[i].Function
|
||||
fn.ThoughtSignature = ""
|
||||
msg.ToolCalls[i].Function = &fn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
_, hint := ts.gracefulInterruptRequested()
|
||||
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestMatchingTurnMessageTail_IgnoresInternalRuntimeFields(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "question"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp/test"}`,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
persisted := []providers.Message{
|
||||
userPromptMessage("question", nil),
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "read_file",
|
||||
Arguments: map[string]any{"path": "/tmp/test"},
|
||||
ThoughtSignature: "internal-signature",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "read_file",
|
||||
Arguments: `{"path":"/tmp/test"}`,
|
||||
ThoughtSignature: "internal-signature",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if got := matchingTurnMessageTail(history, persisted); got != 2 {
|
||||
t.Fatalf("matchingTurnMessageTail() = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitHistoryForActiveTurn_ProtectsPersistedTail(t *testing.T) {
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: "old question"},
|
||||
{Role: "assistant", Content: "old answer"},
|
||||
{Role: "user", Content: "current question"},
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
persisted := []providers.Message{
|
||||
userPromptMessage("current question", nil),
|
||||
{Role: "tool", Content: "tool output", ToolCallID: "call_1"},
|
||||
}
|
||||
|
||||
stable, protected := splitHistoryForActiveTurn(history, persisted)
|
||||
if len(stable) != 2 {
|
||||
t.Fatalf("stable history len = %d, want 2", len(stable))
|
||||
}
|
||||
if len(protected) != 2 {
|
||||
t.Fatalf("protected tail len = %d, want 2", len(protected))
|
||||
}
|
||||
if protected[0].Content != "current question" {
|
||||
t.Fatalf("protected[0].Content = %q, want current question", protected[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTrimHistoryToFitContextWindow_WithProtectedTurnTailKeepsActiveTurn(t *testing.T) {
|
||||
current := strings.Repeat("current turn ", 80)
|
||||
history := []providers.Message{
|
||||
{Role: "user", Content: strings.Repeat("old turn ", 60)},
|
||||
{Role: "assistant", Content: strings.Repeat("old reply ", 60)},
|
||||
{Role: "user", Content: current},
|
||||
}
|
||||
|
||||
stable, protected := splitHistoryForActiveTurn(history, []providers.Message{
|
||||
userPromptMessage(current, nil),
|
||||
})
|
||||
trimmedStable, messages, fit := trimHistoryToFitContextWindow(
|
||||
stable,
|
||||
func(trimmedHistory []providers.Message) []providers.Message {
|
||||
return append(append([]providers.Message(nil), trimmedHistory...), protected...)
|
||||
},
|
||||
120,
|
||||
nil,
|
||||
0,
|
||||
)
|
||||
|
||||
if fit {
|
||||
t.Fatal("expected protected active turn alone to remain over budget")
|
||||
}
|
||||
if len(trimmedStable) != 0 {
|
||||
t.Fatalf("trimmed stable history len = %d, want 0", len(trimmedStable))
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("messages len = %d, want 1 protected active-turn message", len(messages))
|
||||
}
|
||||
if messages[0].Content != current {
|
||||
t.Fatalf("messages[0].Content = %q, want protected current turn", messages[0].Content)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user