From acd436acfe66dc153443d77abd00673940229ad7 Mon Sep 17 00:00:00 2001 From: Administrator <1280842908@qq.com> Date: Mon, 16 Mar 2026 21:49:58 +0800 Subject: [PATCH] feat(agent): add session state rollback on hard abort - Add initialHistoryLength field to turnState to snapshot session state at turn start - Save initial history length in runAgentLoop when creating root turnState - Implement session rollback in HardAbort via SetHistory, truncating to initial length - Add TestHardAbortSessionRollback to verify history rollback after abort - Import providers package in subturn_test.go for Message type This ensures that when a user triggers hard abort, all messages added during the aborted turn are discarded, restoring the session to its pre-turn state. --- .claude/settings.json | 7 +++++ pkg/agent/loop.go | 13 ++++----- pkg/agent/steering.go | 20 +++++++++++--- pkg/agent/subturn.go | 23 ++++++++-------- pkg/agent/subturn_test.go | 56 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 99 insertions(+), 20 deletions(-) create mode 100644 .claude/settings.json diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 000000000..2df2bfb5b --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,7 @@ +{ + "permissions": { + "allow": [ + "Bash(go test:*)" + ] + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index dd4c81373..3324d56cc 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -966,12 +966,13 @@ func (al *AgentLoop) runAgentLoop( ) (string, error) { // Initialize a root TurnState for this iteration, allowing sub-turns to be spawned. rootTS := &turnState{ - ctx: ctx, - turnID: opts.SessionKey, // Associate this turn graph with the current session key - depth: 0, - session: agent.Sessions, - pendingResults: make(chan *tools.ToolResult, 16), - concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns + ctx: ctx, + turnID: opts.SessionKey, // Associate this turn graph with the current session key + depth: 0, + session: agent.Sessions, + initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns } ctx = withTurnState(ctx, rootTS) diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 840a73723..e67a779a3 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -249,11 +249,25 @@ func (al *AgentLoop) HardAbort(sessionKey string) error { } logger.InfoCF("agent", "Hard abort triggered", map[string]any{ - "session_key": sessionKey, - "turn_id": ts.turnID, - "depth": ts.depth, + "session_key": sessionKey, + "turn_id": ts.turnID, + "depth": ts.depth, + "initial_history_length": ts.initialHistoryLength, }) + // Rollback session history to the state before this turn started + if ts.session != nil { + currentHistory := ts.session.GetHistory("") + if len(currentHistory) > ts.initialHistoryLength { + logger.InfoCF("agent", "Rolling back session history", map[string]any{ + "from": len(currentHistory), + "to": ts.initialHistoryLength, + }) + // SetHistory with the truncated slice to rollback + ts.session.SetHistory("", currentHistory[:ts.initialHistoryLength]) + } + } + // Trigger cascading cancellation to all child SubTurns ts.Finish() diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 691353e90..0135dfc76 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -73,17 +73,18 @@ func turnStateFromContext(ctx context.Context) *turnState { } type turnState struct { - ctx context.Context - cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes - turnID string - parentTurnID string - depth int - childTurnIDs []string - pendingResults chan *tools.ToolResult - session session.SessionStore - mu sync.Mutex - isFinished bool // Marks if the parent Turn has ended - concurrencySem chan struct{} // Limits concurrent child sub-turns + ctx context.Context + cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes + turnID string + parentTurnID string + depth int + childTurnIDs []string + pendingResults chan *tools.ToolResult + session session.SessionStore + initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort + mu sync.Mutex + isFinished bool // Marks if the parent Turn has ended + concurrencySem chan struct{} // Limits concurrent child sub-turns } // ====================== Helper Functions ====================== diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index 1b609318d..5b99ebf9f 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -5,6 +5,7 @@ import ( "reflect" "testing" + "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) @@ -444,3 +445,58 @@ func TestHardAbortCascading(t *testing.T) { t.Error("expected error for non-existent session") } } + +// TestHardAbortSessionRollback verifies that HardAbort rolls back session history +// to the state before the turn started, discarding all messages added during the turn. +func TestHardAbortSessionRollback(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + // Create a session with initial history + sess := &ephemeralSessionStore{ + history: []providers.Message{ + {Role: "user", Content: "initial message 1"}, + {Role: "assistant", Content: "initial response 1"}, + }, + } + + // Create a root turnState with initialHistoryLength = 2 + rootTS := &turnState{ + ctx: context.Background(), + turnID: "test-session", + depth: 0, + session: sess, + initialHistoryLength: 2, // Snapshot: 2 messages + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), + } + + // Register the turn state + al.activeTurnStates.Store("test-session", rootTS) + + // Simulate adding messages during the turn (e.g., user input + assistant response) + sess.AddMessage("", "user", "new user message") + sess.AddMessage("", "assistant", "new assistant response") + + // Verify history grew to 4 messages + if len(sess.GetHistory("")) != 4 { + t.Fatalf("expected 4 messages before abort, got %d", len(sess.GetHistory(""))) + } + + // Trigger HardAbort + err := al.HardAbort("test-session") + if err != nil { + t.Fatalf("HardAbort failed: %v", err) + } + + // Verify history rolled back to initial 2 messages + finalHistory := sess.GetHistory("") + if len(finalHistory) != 2 { + t.Errorf("expected history to rollback to 2 messages, got %d", len(finalHistory)) + } + + // Verify the content matches the initial state + if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" { + t.Error("history content does not match initial state after rollback") + } +}