feat(agent): implement graceful finish vs hard abort for SubTurn lifecycle

Problem:
When parent turn finishes early, all child SubTurns receive "context canceled"
error,because child context was derived from parent context.

Solution:
Implement a lifecycle management system that distinguishes between:
- Graceful finish (Finish(false)): signals parentEnded, children continue
- Hard abort (Finish(true)): immediately cancels all children

Changes:
- turn_state.go:
  - Add parentEnded atomic.Bool to signal parent completion
  - Add parentTurnState reference for IsParentEnded() checks
  - Modify Finish(isHardAbort bool) to distinguish abort types

- subturn.go:
  - Add Critical bool to SubTurnConfig (Critical SubTurns continue after parent ends)
  - Add Timeout time.Duration for SubTurn self-protection
  - Use independent context (context.Background()) instead of derived context
  - SubTurns check IsParentEnded() to decide whether to continue or exit

- loop.go:
  - Call Finish(false) for normal completion (graceful)
  - Add IsParentEnded() check in LLM iteration loop

- steering.go:
  - HardAbort calls Finish(true) to immediately cancel children

Behavior:
- Normal finish: parentEnded=true, children continue, orphan results delivered
- Hard abort: all children cancelled immediately via context
- Critical SubTurns: continue running after parent finishes gracefully
- Non-Critical SubTurns: can exit gracefully when IsParentEnded() returns true
This commit is contained in:
Administrator
2026-03-17 23:06:16 +08:00
parent e05d2620e1
commit f8defe3ae1
5 changed files with 284 additions and 62 deletions
+19 -2
View File
@@ -1073,10 +1073,12 @@ func (al *AgentLoop) runAgentLoop(
}
}
// Signal completion to rootTS so it knows it is finished, terminating any active sub-turns.
// Signal completion to rootTS so it knows it is finished.
// Only call Finish() if this is a root turn (not a SubTurn recursively calling runAgentLoop).
// Use isHardAbort=false for normal completion (graceful finish).
// This allows Critical SubTurns to continue running and deliver orphan results.
if isRootTurn {
rootTS.Finish()
rootTS.Finish(false)
}
// If last tool had ForUser content and we already sent it, we might not need to send final response
@@ -1211,6 +1213,21 @@ func (al *AgentLoop) runLLMIteration(
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
iteration++
// Check if parent turn has ended (graceful finish).
// This is only relevant for SubTurns (turnState with parentTurnState != nil).
// If parent ended and this SubTurn is not Critical, exit gracefully.
if ts := turnStateFromContext(ctx); ts != nil && ts.IsParentEnded() {
logger.InfoCF("agent", "Parent turn ended, SubTurn continues or exits", map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"turn_id": ts.turnID,
})
// For now, we continue running. The Critical flag check is handled
// at SubTurnConfig level in spawnSubTurn. Here we just log and continue.
// If this SubTurn should exit gracefully, it would have been cancelled
// by its own timeout or the caller would have handled it.
}
// Inject pending steering messages into the conversation context
// before the next LLM call.
if len(pendingMessages) > 0 {
+2 -1
View File
@@ -258,7 +258,8 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
// from adding more messages to the session. This prevents race conditions
// where rollback happens while children are still writing.
ts.Finish()
// Use isHardAbort=true for hard abort to immediately cancel all children.
ts.Finish(true)
// Rollback session history to the state before this turn started.
// This must happen AFTER Finish() to ensure no child turns are still writing.
+35 -30
View File
@@ -21,6 +21,9 @@ const (
// maxEphemeralHistorySize limits the number of messages stored in ephemeral sessions.
// This prevents memory accumulation in long-running sub-turns.
maxEphemeralHistorySize = 50
// defaultSubTurnTimeout is the default maximum duration for a SubTurn.
// SubTurns that run longer than this will be cancelled.
defaultSubTurnTimeout = 5 * time.Minute
)
var (
@@ -85,6 +88,22 @@ type SubTurnConfig struct {
// the caller must spawn the sub-turn in a separate goroutine.
Async bool
// Critical indicates this SubTurn's result is important and should continue
// running even after the parent turn finishes gracefully.
//
// When parent finishes gracefully (Finish(false)):
// - Critical=true: SubTurn continues running, delivers result as orphan
// - Critical=false: SubTurn exits gracefully without error
//
// When parent finishes with hard abort (Finish(true)):
// - All SubTurns are cancelled regardless of Critical flag
Critical bool
// Timeout is the maximum duration for this SubTurn.
// If the SubTurn runs longer than this, it will be cancelled.
// Default is 5 minutes (defaultSubTurnTimeout) if not specified.
Timeout time.Duration
// Can be extended with temperature, topP, etc.
}
@@ -227,34 +246,40 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
return nil, ErrInvalidSubTurnConfig
}
// 3. Create child Turn state with a cancellable context
// This single context wrapping is sufficient - no need for additional layers.
childCtx, cancel := context.WithCancel(ctx)
// 3. Determine timeout for child SubTurn
timeout := cfg.Timeout
if timeout <= 0 {
timeout = defaultSubTurnTimeout
}
// 4. Create INDEPENDENT child context (not derived from parent ctx).
// This allows the child to continue running after parent finishes gracefully.
// The child has its own timeout for self-protection.
childCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
childID := al.generateSubTurnID()
childTS := newTurnState(childCtx, childID, parentTS)
// Set the cancel function so Finish() can trigger cascading cancellation
// Set the cancel function so Finish(true) can trigger hard cancellation
childTS.cancelFunc = cancel
// IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it
childCtx = withTurnState(childCtx, childTS)
childCtx = WithAgentLoop(childCtx, al) // Propagate AgentLoop to child turn
// 4. Establish parent-child relationship (thread-safe)
// 5. Establish parent-child relationship (thread-safe)
parentTS.mu.Lock()
parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
parentTS.mu.Unlock()
// 5. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
// 6. Emit Spawn event
MockEventBus.Emit(SubTurnSpawnEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Config: cfg,
})
// 6. Defer cleanup: deliver result (for async), emit End event, and recover from panics
// IMPORTANT: deliverSubTurnResult must be in defer to ensure it runs even if runTurn panics.
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("subturn panicked: %v", r)
@@ -265,26 +290,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
})
}
// 7. Result Delivery Strategy (Async vs Sync)
//
// WHY we have different delivery mechanisms:
// ==========================================
//
// Synchronous sub-turns (Async=false):
// - Caller expects immediate result via return value
// - Delivering to channel would cause DOUBLE DELIVERY:
// 1. Caller gets result from return value
// 2. Parent turn would poll channel and get the same result again
// - This would confuse the parent turn's result processing logic
// - Solution: Skip channel delivery, only return via function return
//
// Asynchronous sub-turns (Async=true):
// - Caller may not immediately process the return value
// - Result needs to be available for later polling via pendingResults
// - Parent turn can collect multiple async results in batches
// - Solution: Deliver to channel AND return via function return
//
// This must be in defer to ensure delivery even if runTurn panics.
// Result Delivery Strategy (Async vs Sync)
if cfg.Async {
deliverSubTurnResult(parentTS, childID, result)
}
@@ -296,8 +302,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
})
}()
// 7. Execute sub-turn via the real agent loop.
// Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent.
// 8. Execute sub-turn via the real agent loop.
result, err = runTurn(childCtx, al, childTS, cfg)
return result, err
+174 -16
View File
@@ -278,7 +278,7 @@ func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
defer func() { MockEventBus.Emit = originalEmit }()
// Simulate parent finishing before child delivers result
parent.Finish()
parent.Finish(false)
// Call deliverSubTurnResult directly to simulate a delayed child
deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
@@ -739,8 +739,8 @@ func TestFinishClosesChannel(t *testing.T) {
t.Fatal("channel should be open initially")
}
// Call Finish()
ts.Finish()
// Call Finish() with graceful finish
ts.Finish(false)
// Verify channel is closed
_, ok := <-ts.pendingResults
@@ -749,7 +749,7 @@ func TestFinishClosesChannel(t *testing.T) {
}
// Verify Finish() is idempotent (can be called multiple times)
ts.Finish() // Should not panic
ts.Finish(false) // Should not panic
// Verify deliverSubTurnResult doesn't panic when sending to closed channel
result := &tools.ToolResult{ForLLM: "late result"}
@@ -1153,7 +1153,7 @@ func TestFinish_ConcurrentCalls(t *testing.T) {
go func() {
defer wg.Done()
// This should not panic, even when called concurrently
parentTS.Finish()
parentTS.Finish(false)
}()
}
@@ -1219,7 +1219,7 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
go func() {
defer wg.Done()
time.Sleep(5 * time.Millisecond)
parentTS.Finish()
parentTS.Finish(false)
}()
// Goroutines that deliver results
@@ -1291,7 +1291,7 @@ func TestConcurrencySemaphore_Timeout(t *testing.T) {
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish()
defer parentTS.Finish(false)
// Fill all concurrency slots
for i := 0; i < maxConcurrentSubTurns; i++ {
@@ -1391,7 +1391,7 @@ func TestContextWrapping_SingleLayer(t *testing.T) {
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish()
defer parentTS.Finish(false)
// Spawn a sub-turn
subTurnCfg := SubTurnConfig{
@@ -1457,7 +1457,7 @@ func TestFinish_DrainsChannel(t *testing.T) {
}
// Call Finish() - it should drain the channel
parentTS.Finish()
parentTS.Finish(false)
// Verify all results were drained and emitted as orphan events
mu.Lock()
@@ -1505,7 +1505,7 @@ func TestSyncSubTurn_NoChannelDelivery(t *testing.T) {
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish()
defer parentTS.Finish(false)
// Spawn a SYNCHRONOUS sub-turn (Async=false)
subTurnCfg := SubTurnConfig{
@@ -1562,7 +1562,7 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish()
defer parentTS.Finish(false)
// Spawn an ASYNCHRONOUS sub-turn (Async=true)
subTurnCfg := SubTurnConfig{
@@ -1623,7 +1623,7 @@ func TestChannelFull_OrphanResults(t *testing.T) {
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish()
defer parentTS.Finish(false)
// Send more results than the channel capacity (16)
const numResults = 25
@@ -1720,7 +1720,7 @@ func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
}
// Hard abort the grandparent
grandparentTS.Finish()
grandparentTS.Finish(false)
// Wait a bit for cancellation to propagate
time.Sleep(10 * time.Millisecond)
@@ -1793,7 +1793,7 @@ func TestSpawnDuringAbort_RaceCondition(t *testing.T) {
go func() {
defer wg.Done()
time.Sleep(1 * time.Millisecond)
parentTS.Finish()
parentTS.Finish(false)
}()
wg.Wait()
@@ -1904,7 +1904,7 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
// Parent finishes quickly (after 100ms), while SubTurn is still running
time.Sleep(100 * time.Millisecond)
t.Log("Parent finishing early...")
parentTS.Finish()
parentTS.Finish(false)
// Wait for SubTurn to complete (or be cancelled)
wg.Wait()
@@ -1980,7 +1980,7 @@ func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) {
t.Log("SubTurn completed, parent now finishing")
// Now parent can finish safely
parentTS.Finish()
parentTS.Finish(false)
// Check the result
if subTurnErr != nil {
@@ -2006,3 +2006,161 @@ func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) {
t.Log("No result in channel (expected since we waited)")
}
}
// ====================== Graceful vs Hard Finish Tests ======================
// TestFinish_GracefulVsHard verifies the behavior difference between:
// - Finish(false): graceful finish, signals parentEnded but doesn't cancel children
// - Finish(true): hard abort, immediately cancels all children
func TestFinish_GracefulVsHard(t *testing.T) {
// Test 1: Graceful finish should set parentEnded but not cancel context
t.Run("Graceful_SetsParentEnded", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts := &turnState{
ctx: ctx,
turnID: "graceful-test",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 16),
}
ts.ctx, ts.cancelFunc = context.WithCancel(ctx)
// Finish gracefully
ts.Finish(false)
// Verify parentEnded is set
if !ts.parentEnded.Load() {
t.Error("parentEnded should be true after graceful finish")
}
// Verify context is NOT cancelled (for graceful finish, children continue)
// Note: In graceful mode, we don't call cancelFunc()
// But since we're using WithCancel on the same ctx, it might be cancelled
// Let's check that the context is still valid for a moment
time.Sleep(10 * time.Millisecond)
// Context might be cancelled by the deferred cancel() in test, which is fine
})
// Test 2: Hard abort should cancel context immediately
t.Run("Hard_CancelsContext", func(t *testing.T) {
ctx := context.Background()
ts := &turnState{
ctx: ctx,
turnID: "hard-test",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 16),
}
ts.ctx, ts.cancelFunc = context.WithCancel(ctx)
// Finish with hard abort
ts.Finish(true)
// Verify context is cancelled
select {
case <-ts.ctx.Done():
t.Log("✓ Context cancelled after hard abort")
default:
t.Error("Context should be cancelled after hard abort")
}
})
// Test 3: IsParentEnded returns correct value
t.Run("IsParentEnded", func(t *testing.T) {
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-isended-test",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 16),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
childTS := &turnState{
ctx: ctx,
turnID: "child-isended-test",
depth: 1,
parentTurnState: parentTS,
pendingResults: make(chan *tools.ToolResult, 16),
}
// Before parent finishes
if childTS.IsParentEnded() {
t.Error("IsParentEnded should be false before parent finishes")
}
// Finish parent gracefully
parentTS.Finish(false)
// After parent finishes
if !childTS.IsParentEnded() {
t.Error("IsParentEnded should be true after parent finishes gracefully")
}
})
}
// TestSubTurn_IndependentContext verifies that SubTurns use independent contexts
// that don't get cancelled when the parent finishes gracefully.
func TestSubTurn_IndependentContext(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &slowMockProvider{delay: 500 * time.Millisecond}
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-independent",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
var subTurnErr error
var wg sync.WaitGroup
// Spawn SubTurn with Critical=true (should continue after parent finishes)
wg.Add(1)
go func() {
defer wg.Done()
subTurnCfg := SubTurnConfig{
Model: "slow-model",
Async: true,
Critical: true, // Critical SubTurn should continue
}
_, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg)
}()
// Let SubTurn start
time.Sleep(50 * time.Millisecond)
// Parent finishes gracefully (should NOT cancel SubTurn)
parentTS.Finish(false)
t.Log("Parent finished gracefully, SubTurn should continue")
// Wait for SubTurn to complete
wg.Wait()
// SubTurn should complete without context cancelled error
// (because it uses independent context now)
if subTurnErr != nil {
t.Logf("SubTurn error: %v", subTurnErr)
// The error might be context.DeadlineExceeded if timeout is too short
// but should NOT be context.Canceled from parent
if errors.Is(subTurnErr, context.Canceled) {
t.Error("SubTurn should not be cancelled by parent's graceful finish")
}
} else {
t.Log("✓ SubTurn completed successfully (independent context)")
}
}
+54 -13
View File
@@ -3,6 +3,7 @@ package agent
import (
"context"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
@@ -44,6 +45,16 @@ type turnState struct {
isFinished bool // MUST be accessed under mu lock
closeOnce sync.Once // Ensures pendingResults channel is closed exactly once
concurrencySem chan struct{} // Limits concurrent child sub-turns
// parentEnded signals that the parent turn has finished gracefully.
// Child SubTurns should check this via IsParentEnded() to decide whether
// to continue running (Critical=true) or exit gracefully (Critical=false).
parentEnded atomic.Bool
// parentTurnState holds a reference to the parent turnState.
// This allows child SubTurns to check if the parent has ended.
// Nil for root turns.
parentTurnState *turnState
}
// ====================== Public API ======================
@@ -99,12 +110,13 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState
// (spawnSubTurn) already creates one. The turnState stores the context and
// cancelFunc provided by the caller to avoid redundant context wrapping.
return &turnState{
ctx: ctx,
cancelFunc: nil, // Will be set by the caller
turnID: id,
parentTurnID: parent.turnID,
depth: parent.depth + 1,
session: newEphemeralSession(parent.session),
ctx: ctx,
cancelFunc: nil, // Will be set by the caller
turnID: id,
parentTurnID: parent.turnID,
depth: parent.depth + 1,
session: newEphemeralSession(parent.session),
parentTurnState: parent, // Store reference to parent for IsParentEnded() checks
// NOTE: In this PoC, I use a fixed-size channel (16).
// Under high concurrency or long-running sub-turns, this might fill up and cause
// intermediate results to be discarded in deliverSubTurnResult.
@@ -114,18 +126,47 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState
}
}
// Finish marks the turn as finished and cancels its context, aborting any running sub-turns.
// It also closes the pendingResults channel to signal that no more results will be delivered.
// This method is safe to call multiple times - the channel will only be closed once.
// Any results remaining in the channel after close will be drained and emitted as orphan events.
func (ts *turnState) Finish() {
// IsParentEnded returns true if the parent turn has finished gracefully.
// This is safe to call from child SubTurn goroutines.
// Returns false if this is a root turn (no parent).
func (ts *turnState) IsParentEnded() bool {
if ts.parentTurnState == nil {
return false
}
return ts.parentTurnState.parentEnded.Load()
}
// IsParentEnded is a convenience method to check if parent ended.
// It returns the value of the parent's parentEnded atomic flag.
// Finish marks the turn as finished.
//
// If isHardAbort is true (Hard Abort):
// - Cancels all child contexts immediately via cancelFunc
// - Used for user-initiated termination (e.g., "stop now")
//
// If isHardAbort is false (Graceful Finish):
// - Only signals parentEnded for graceful child exit
// - Children check IsParentEnded() and decide whether to continue or exit
// - Critical SubTurns continue running and deliver orphan results
// - Non-Critical SubTurns exit gracefully without error
//
// In both cases, the pendingResults channel is closed to signal
// that no more results will be delivered.
func (ts *turnState) Finish(isHardAbort bool) {
ts.mu.Lock()
ts.isFinished = true
resultChan := ts.pendingResults
ts.mu.Unlock()
if ts.cancelFunc != nil {
ts.cancelFunc()
if isHardAbort {
// Hard abort: immediately cancel all children
if ts.cancelFunc != nil {
ts.cancelFunc()
}
} else {
// Graceful finish: signal parent ended, let children decide
ts.parentEnded.Store(true)
}
// Use sync.Once to ensure the channel is closed exactly once, even if Finish() is called concurrently.