mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): add session state rollback on hard abort
- Add initialHistoryLength field to turnState to snapshot session state at turn start - Save initial history length in runAgentLoop when creating root turnState - Implement session rollback in HardAbort via SetHistory, truncating to initial length - Add TestHardAbortSessionRollback to verify history rollback after abort - Import providers package in subturn_test.go for Message type This ensures that when a user triggers hard abort, all messages added during the aborted turn are discarded, restoring the session to its pre-turn state.
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(go test:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
+7
-6
@@ -966,12 +966,13 @@ func (al *AgentLoop) runAgentLoop(
|
||||
) (string, error) {
|
||||
// Initialize a root TurnState for this iteration, allowing sub-turns to be spawned.
|
||||
rootTS := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: opts.SessionKey, // Associate this turn graph with the current session key
|
||||
depth: 0,
|
||||
session: agent.Sessions,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
|
||||
ctx: ctx,
|
||||
turnID: opts.SessionKey, // Associate this turn graph with the current session key
|
||||
depth: 0,
|
||||
session: agent.Sessions,
|
||||
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
|
||||
}
|
||||
ctx = withTurnState(ctx, rootTS)
|
||||
|
||||
|
||||
+17
-3
@@ -249,11 +249,25 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
}
|
||||
|
||||
logger.InfoCF("agent", "Hard abort triggered", map[string]any{
|
||||
"session_key": sessionKey,
|
||||
"turn_id": ts.turnID,
|
||||
"depth": ts.depth,
|
||||
"session_key": sessionKey,
|
||||
"turn_id": ts.turnID,
|
||||
"depth": ts.depth,
|
||||
"initial_history_length": ts.initialHistoryLength,
|
||||
})
|
||||
|
||||
// Rollback session history to the state before this turn started
|
||||
if ts.session != nil {
|
||||
currentHistory := ts.session.GetHistory("")
|
||||
if len(currentHistory) > ts.initialHistoryLength {
|
||||
logger.InfoCF("agent", "Rolling back session history", map[string]any{
|
||||
"from": len(currentHistory),
|
||||
"to": ts.initialHistoryLength,
|
||||
})
|
||||
// SetHistory with the truncated slice to rollback
|
||||
ts.session.SetHistory("", currentHistory[:ts.initialHistoryLength])
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger cascading cancellation to all child SubTurns
|
||||
ts.Finish()
|
||||
|
||||
|
||||
+12
-11
@@ -73,17 +73,18 @@ func turnStateFromContext(ctx context.Context) *turnState {
|
||||
}
|
||||
|
||||
type turnState struct {
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
|
||||
turnID string
|
||||
parentTurnID string
|
||||
depth int
|
||||
childTurnIDs []string
|
||||
pendingResults chan *tools.ToolResult
|
||||
session session.SessionStore
|
||||
mu sync.Mutex
|
||||
isFinished bool // Marks if the parent Turn has ended
|
||||
concurrencySem chan struct{} // Limits concurrent child sub-turns
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
|
||||
turnID string
|
||||
parentTurnID string
|
||||
depth int
|
||||
childTurnIDs []string
|
||||
pendingResults chan *tools.ToolResult
|
||||
session session.SessionStore
|
||||
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
|
||||
mu sync.Mutex
|
||||
isFinished bool // Marks if the parent Turn has ended
|
||||
concurrencySem chan struct{} // Limits concurrent child sub-turns
|
||||
}
|
||||
|
||||
// ====================== Helper Functions ======================
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -444,3 +445,58 @@ func TestHardAbortCascading(t *testing.T) {
|
||||
t.Error("expected error for non-existent session")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHardAbortSessionRollback verifies that HardAbort rolls back session history
|
||||
// to the state before the turn started, discarding all messages added during the turn.
|
||||
func TestHardAbortSessionRollback(t *testing.T) {
|
||||
al, _, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a session with initial history
|
||||
sess := &ephemeralSessionStore{
|
||||
history: []providers.Message{
|
||||
{Role: "user", Content: "initial message 1"},
|
||||
{Role: "assistant", Content: "initial response 1"},
|
||||
},
|
||||
}
|
||||
|
||||
// Create a root turnState with initialHistoryLength = 2
|
||||
rootTS := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "test-session",
|
||||
depth: 0,
|
||||
session: sess,
|
||||
initialHistoryLength: 2, // Snapshot: 2 messages
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
|
||||
// Register the turn state
|
||||
al.activeTurnStates.Store("test-session", rootTS)
|
||||
|
||||
// Simulate adding messages during the turn (e.g., user input + assistant response)
|
||||
sess.AddMessage("", "user", "new user message")
|
||||
sess.AddMessage("", "assistant", "new assistant response")
|
||||
|
||||
// Verify history grew to 4 messages
|
||||
if len(sess.GetHistory("")) != 4 {
|
||||
t.Fatalf("expected 4 messages before abort, got %d", len(sess.GetHistory("")))
|
||||
}
|
||||
|
||||
// Trigger HardAbort
|
||||
err := al.HardAbort("test-session")
|
||||
if err != nil {
|
||||
t.Fatalf("HardAbort failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify history rolled back to initial 2 messages
|
||||
finalHistory := sess.GetHistory("")
|
||||
if len(finalHistory) != 2 {
|
||||
t.Errorf("expected history to rollback to 2 messages, got %d", len(finalHistory))
|
||||
}
|
||||
|
||||
// Verify the content matches the initial state
|
||||
if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" {
|
||||
t.Error("history content does not match initial state after rollback")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user