feat(agent): add session state rollback on hard abort

- Add initialHistoryLength field to turnState to snapshot session state at turn start
- Save initial history length in runAgentLoop when creating root turnState
- Implement session rollback in HardAbort via SetHistory, truncating to initial length
- Add TestHardAbortSessionRollback to verify history rollback after abort
- Import providers package in subturn_test.go for Message type

This ensures that when a user triggers hard abort, all messages added during
the aborted turn are discarded, restoring the session to its pre-turn state.
This commit is contained in:
Administrator
2026-03-16 21:49:58 +08:00
parent 1236dd9e6d
commit acd436acfe
5 changed files with 99 additions and 20 deletions
+7
View File
@@ -0,0 +1,7 @@
{
"permissions": {
"allow": [
"Bash(go test:*)"
]
}
}
+7 -6
View File
@@ -966,12 +966,13 @@ func (al *AgentLoop) runAgentLoop(
) (string, error) {
// Initialize a root TurnState for this iteration, allowing sub-turns to be spawned.
rootTS := &turnState{
ctx: ctx,
turnID: opts.SessionKey, // Associate this turn graph with the current session key
depth: 0,
session: agent.Sessions,
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
ctx: ctx,
turnID: opts.SessionKey, // Associate this turn graph with the current session key
depth: 0,
session: agent.Sessions,
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
}
ctx = withTurnState(ctx, rootTS)
+17 -3
View File
@@ -249,11 +249,25 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
}
logger.InfoCF("agent", "Hard abort triggered", map[string]any{
"session_key": sessionKey,
"turn_id": ts.turnID,
"depth": ts.depth,
"session_key": sessionKey,
"turn_id": ts.turnID,
"depth": ts.depth,
"initial_history_length": ts.initialHistoryLength,
})
// Rollback session history to the state before this turn started
if ts.session != nil {
currentHistory := ts.session.GetHistory("")
if len(currentHistory) > ts.initialHistoryLength {
logger.InfoCF("agent", "Rolling back session history", map[string]any{
"from": len(currentHistory),
"to": ts.initialHistoryLength,
})
// SetHistory with the truncated slice to rollback
ts.session.SetHistory("", currentHistory[:ts.initialHistoryLength])
}
}
// Trigger cascading cancellation to all child SubTurns
ts.Finish()
+12 -11
View File
@@ -73,17 +73,18 @@ func turnStateFromContext(ctx context.Context) *turnState {
}
type turnState struct {
ctx context.Context
cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
turnID string
parentTurnID string
depth int
childTurnIDs []string
pendingResults chan *tools.ToolResult
session session.SessionStore
mu sync.Mutex
isFinished bool // Marks if the parent Turn has ended
concurrencySem chan struct{} // Limits concurrent child sub-turns
ctx context.Context
cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
turnID string
parentTurnID string
depth int
childTurnIDs []string
pendingResults chan *tools.ToolResult
session session.SessionStore
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
mu sync.Mutex
isFinished bool // Marks if the parent Turn has ended
concurrencySem chan struct{} // Limits concurrent child sub-turns
}
// ====================== Helper Functions ======================
+56
View File
@@ -5,6 +5,7 @@ import (
"reflect"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -444,3 +445,58 @@ func TestHardAbortCascading(t *testing.T) {
t.Error("expected error for non-existent session")
}
}
// TestHardAbortSessionRollback verifies that HardAbort rolls back session history
// to the state before the turn started, discarding all messages added during the turn.
func TestHardAbortSessionRollback(t *testing.T) {
al, _, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
// Create a session with initial history
sess := &ephemeralSessionStore{
history: []providers.Message{
{Role: "user", Content: "initial message 1"},
{Role: "assistant", Content: "initial response 1"},
},
}
// Create a root turnState with initialHistoryLength = 2
rootTS := &turnState{
ctx: context.Background(),
turnID: "test-session",
depth: 0,
session: sess,
initialHistoryLength: 2, // Snapshot: 2 messages
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
// Register the turn state
al.activeTurnStates.Store("test-session", rootTS)
// Simulate adding messages during the turn (e.g., user input + assistant response)
sess.AddMessage("", "user", "new user message")
sess.AddMessage("", "assistant", "new assistant response")
// Verify history grew to 4 messages
if len(sess.GetHistory("")) != 4 {
t.Fatalf("expected 4 messages before abort, got %d", len(sess.GetHistory("")))
}
// Trigger HardAbort
err := al.HardAbort("test-session")
if err != nil {
t.Fatalf("HardAbort failed: %v", err)
}
// Verify history rolled back to initial 2 messages
finalHistory := sess.GetHistory("")
if len(finalHistory) != 2 {
t.Errorf("expected history to rollback to 2 messages, got %d", len(finalHistory))
}
// Verify the content matches the initial state
if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" {
t.Error("history content does not match initial state after rollback")
}
}