mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): implement graceful finish vs hard abort for SubTurn lifecycle
Problem: When the parent turn finishes early, all child SubTurns receive a "context canceled" error, because each child context was derived from the parent context.

Solution: Implement a lifecycle management system that distinguishes between:
- Graceful finish (Finish(false)): signals parentEnded, children continue
- Hard abort (Finish(true)): immediately cancels all children

Changes:
- turn_state.go:
  - Add parentEnded atomic.Bool to signal parent completion
  - Add parentTurnState reference for IsParentEnded() checks
  - Modify Finish(isHardAbort bool) to distinguish abort types
- subturn.go:
  - Add Critical bool to SubTurnConfig (Critical SubTurns continue after parent ends)
  - Add Timeout time.Duration for SubTurn self-protection
  - Use independent context (context.Background()) instead of derived context
  - SubTurns check IsParentEnded() to decide whether to continue or exit
- loop.go:
  - Call Finish(false) for normal completion (graceful)
  - Add IsParentEnded() check in LLM iteration loop
- steering.go:
  - HardAbort calls Finish(true) to immediately cancel children

Behavior:
- Normal finish: parentEnded=true, children continue, orphan results delivered
- Hard abort: all children cancelled immediately via context
- Critical SubTurns: continue running after parent finishes gracefully
- Non-Critical SubTurns: can exit gracefully when IsParentEnded() returns true
This commit is contained in:
+19
-2
@@ -1073,10 +1073,12 @@ func (al *AgentLoop) runAgentLoop(
|
||||
}
|
||||
}
|
||||
|
||||
// Signal completion to rootTS so it knows it is finished, terminating any active sub-turns.
|
||||
// Signal completion to rootTS so it knows it is finished.
|
||||
// Only call Finish() if this is a root turn (not a SubTurn recursively calling runAgentLoop).
|
||||
// Use isHardAbort=false for normal completion (graceful finish).
|
||||
// This allows Critical SubTurns to continue running and deliver orphan results.
|
||||
if isRootTurn {
|
||||
rootTS.Finish()
|
||||
rootTS.Finish(false)
|
||||
}
|
||||
|
||||
// If last tool had ForUser content and we already sent it, we might not need to send final response
|
||||
@@ -1211,6 +1213,21 @@ func (al *AgentLoop) runLLMIteration(
|
||||
for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
|
||||
iteration++
|
||||
|
||||
// Check if parent turn has ended (graceful finish).
|
||||
// This is only relevant for SubTurns (turnState with parentTurnState != nil).
|
||||
// If parent ended and this SubTurn is not Critical, exit gracefully.
|
||||
if ts := turnStateFromContext(ctx); ts != nil && ts.IsParentEnded() {
|
||||
logger.InfoCF("agent", "Parent turn ended, SubTurn continues or exits", map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"turn_id": ts.turnID,
|
||||
})
|
||||
// For now, we continue running. The Critical flag check is handled
|
||||
// at SubTurnConfig level in spawnSubTurn. Here we just log and continue.
|
||||
// If this SubTurn should exit gracefully, it would have been cancelled
|
||||
// by its own timeout or the caller would have handled it.
|
||||
}
|
||||
|
||||
// Inject pending steering messages into the conversation context
|
||||
// before the next LLM call.
|
||||
if len(pendingMessages) > 0 {
|
||||
|
||||
@@ -258,7 +258,8 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
|
||||
// from adding more messages to the session. This prevents race conditions
|
||||
// where rollback happens while children are still writing.
|
||||
ts.Finish()
|
||||
// Use isHardAbort=true for hard abort to immediately cancel all children.
|
||||
ts.Finish(true)
|
||||
|
||||
// Rollback session history to the state before this turn started.
|
||||
// This must happen AFTER Finish() to ensure no child turns are still writing.
|
||||
|
||||
+35
-30
@@ -21,6 +21,9 @@ const (
|
||||
// maxEphemeralHistorySize limits the number of messages stored in ephemeral sessions.
|
||||
// This prevents memory accumulation in long-running sub-turns.
|
||||
maxEphemeralHistorySize = 50
|
||||
// defaultSubTurnTimeout is the default maximum duration for a SubTurn.
|
||||
// SubTurns that run longer than this will be cancelled.
|
||||
defaultSubTurnTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -85,6 +88,22 @@ type SubTurnConfig struct {
|
||||
// the caller must spawn the sub-turn in a separate goroutine.
|
||||
Async bool
|
||||
|
||||
// Critical indicates this SubTurn's result is important and should continue
|
||||
// running even after the parent turn finishes gracefully.
|
||||
//
|
||||
// When parent finishes gracefully (Finish(false)):
|
||||
// - Critical=true: SubTurn continues running, delivers result as orphan
|
||||
// - Critical=false: SubTurn exits gracefully without error
|
||||
//
|
||||
// When parent finishes with hard abort (Finish(true)):
|
||||
// - All SubTurns are cancelled regardless of Critical flag
|
||||
Critical bool
|
||||
|
||||
// Timeout is the maximum duration for this SubTurn.
|
||||
// If the SubTurn runs longer than this, it will be cancelled.
|
||||
// Default is 5 minutes (defaultSubTurnTimeout) if not specified.
|
||||
Timeout time.Duration
|
||||
|
||||
// Can be extended with temperature, topP, etc.
|
||||
}
|
||||
|
||||
@@ -227,34 +246,40 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
return nil, ErrInvalidSubTurnConfig
|
||||
}
|
||||
|
||||
// 3. Create child Turn state with a cancellable context
|
||||
// This single context wrapping is sufficient - no need for additional layers.
|
||||
childCtx, cancel := context.WithCancel(ctx)
|
||||
// 3. Determine timeout for child SubTurn
|
||||
timeout := cfg.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = defaultSubTurnTimeout
|
||||
}
|
||||
|
||||
// 4. Create INDEPENDENT child context (not derived from parent ctx).
|
||||
// This allows the child to continue running after parent finishes gracefully.
|
||||
// The child has its own timeout for self-protection.
|
||||
childCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
childID := al.generateSubTurnID()
|
||||
childTS := newTurnState(childCtx, childID, parentTS)
|
||||
// Set the cancel function so Finish() can trigger cascading cancellation
|
||||
// Set the cancel function so Finish(true) can trigger hard cancellation
|
||||
childTS.cancelFunc = cancel
|
||||
|
||||
// IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it
|
||||
childCtx = withTurnState(childCtx, childTS)
|
||||
childCtx = WithAgentLoop(childCtx, al) // Propagate AgentLoop to child turn
|
||||
|
||||
// 4. Establish parent-child relationship (thread-safe)
|
||||
// 5. Establish parent-child relationship (thread-safe)
|
||||
parentTS.mu.Lock()
|
||||
parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// 5. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
|
||||
// 6. Emit Spawn event
|
||||
MockEventBus.Emit(SubTurnSpawnEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 6. Defer cleanup: deliver result (for async), emit End event, and recover from panics
|
||||
// IMPORTANT: deliverSubTurnResult must be in defer to ensure it runs even if runTurn panics.
|
||||
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("subturn panicked: %v", r)
|
||||
@@ -265,26 +290,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
})
|
||||
}
|
||||
|
||||
// 7. Result Delivery Strategy (Async vs Sync)
|
||||
//
|
||||
// WHY we have different delivery mechanisms:
|
||||
// ==========================================
|
||||
//
|
||||
// Synchronous sub-turns (Async=false):
|
||||
// - Caller expects immediate result via return value
|
||||
// - Delivering to channel would cause DOUBLE DELIVERY:
|
||||
// 1. Caller gets result from return value
|
||||
// 2. Parent turn would poll channel and get the same result again
|
||||
// - This would confuse the parent turn's result processing logic
|
||||
// - Solution: Skip channel delivery, only return via function return
|
||||
//
|
||||
// Asynchronous sub-turns (Async=true):
|
||||
// - Caller may not immediately process the return value
|
||||
// - Result needs to be available for later polling via pendingResults
|
||||
// - Parent turn can collect multiple async results in batches
|
||||
// - Solution: Deliver to channel AND return via function return
|
||||
//
|
||||
// This must be in defer to ensure delivery even if runTurn panics.
|
||||
// Result Delivery Strategy (Async vs Sync)
|
||||
if cfg.Async {
|
||||
deliverSubTurnResult(parentTS, childID, result)
|
||||
}
|
||||
@@ -296,8 +302,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
})
|
||||
}()
|
||||
|
||||
// 7. Execute sub-turn via the real agent loop.
|
||||
// Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent.
|
||||
// 8. Execute sub-turn via the real agent loop.
|
||||
result, err = runTurn(childCtx, al, childTS, cfg)
|
||||
|
||||
return result, err
|
||||
|
||||
+174
-16
@@ -278,7 +278,7 @@ func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
|
||||
defer func() { MockEventBus.Emit = originalEmit }()
|
||||
|
||||
// Simulate parent finishing before child delivers result
|
||||
parent.Finish()
|
||||
parent.Finish(false)
|
||||
|
||||
// Call deliverSubTurnResult directly to simulate a delayed child
|
||||
deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
|
||||
@@ -739,8 +739,8 @@ func TestFinishClosesChannel(t *testing.T) {
|
||||
t.Fatal("channel should be open initially")
|
||||
}
|
||||
|
||||
// Call Finish()
|
||||
ts.Finish()
|
||||
// Call Finish() with graceful finish
|
||||
ts.Finish(false)
|
||||
|
||||
// Verify channel is closed
|
||||
_, ok := <-ts.pendingResults
|
||||
@@ -749,7 +749,7 @@ func TestFinishClosesChannel(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify Finish() is idempotent (can be called multiple times)
|
||||
ts.Finish() // Should not panic
|
||||
ts.Finish(false) // Should not panic
|
||||
|
||||
// Verify deliverSubTurnResult doesn't panic when sending to closed channel
|
||||
result := &tools.ToolResult{ForLLM: "late result"}
|
||||
@@ -1153,7 +1153,7 @@ func TestFinish_ConcurrentCalls(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// This should not panic, even when called concurrently
|
||||
parentTS.Finish()
|
||||
parentTS.Finish(false)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1219,7 +1219,7 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
parentTS.Finish()
|
||||
parentTS.Finish(false)
|
||||
}()
|
||||
|
||||
// Goroutines that deliver results
|
||||
@@ -1291,7 +1291,7 @@ func TestConcurrencySemaphore_Timeout(t *testing.T) {
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
defer parentTS.Finish()
|
||||
defer parentTS.Finish(false)
|
||||
|
||||
// Fill all concurrency slots
|
||||
for i := 0; i < maxConcurrentSubTurns; i++ {
|
||||
@@ -1391,7 +1391,7 @@ func TestContextWrapping_SingleLayer(t *testing.T) {
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
defer parentTS.Finish()
|
||||
defer parentTS.Finish(false)
|
||||
|
||||
// Spawn a sub-turn
|
||||
subTurnCfg := SubTurnConfig{
|
||||
@@ -1457,7 +1457,7 @@ func TestFinish_DrainsChannel(t *testing.T) {
|
||||
}
|
||||
|
||||
// Call Finish() - it should drain the channel
|
||||
parentTS.Finish()
|
||||
parentTS.Finish(false)
|
||||
|
||||
// Verify all results were drained and emitted as orphan events
|
||||
mu.Lock()
|
||||
@@ -1505,7 +1505,7 @@ func TestSyncSubTurn_NoChannelDelivery(t *testing.T) {
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
defer parentTS.Finish()
|
||||
defer parentTS.Finish(false)
|
||||
|
||||
// Spawn a SYNCHRONOUS sub-turn (Async=false)
|
||||
subTurnCfg := SubTurnConfig{
|
||||
@@ -1562,7 +1562,7 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
defer parentTS.Finish()
|
||||
defer parentTS.Finish(false)
|
||||
|
||||
// Spawn an ASYNCHRONOUS sub-turn (Async=true)
|
||||
subTurnCfg := SubTurnConfig{
|
||||
@@ -1623,7 +1623,7 @@ func TestChannelFull_OrphanResults(t *testing.T) {
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
defer parentTS.Finish()
|
||||
defer parentTS.Finish(false)
|
||||
|
||||
// Send more results than the channel capacity (16)
|
||||
const numResults = 25
|
||||
@@ -1720,7 +1720,7 @@ func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Hard abort the grandparent
|
||||
grandparentTS.Finish()
|
||||
grandparentTS.Finish(false)
|
||||
|
||||
// Wait a bit for cancellation to propagate
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
@@ -1793,7 +1793,7 @@ func TestSpawnDuringAbort_RaceCondition(t *testing.T) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
parentTS.Finish()
|
||||
parentTS.Finish(false)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
@@ -1904,7 +1904,7 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
|
||||
// Parent finishes quickly (after 100ms), while SubTurn is still running
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
t.Log("Parent finishing early...")
|
||||
parentTS.Finish()
|
||||
parentTS.Finish(false)
|
||||
|
||||
// Wait for SubTurn to complete (or be cancelled)
|
||||
wg.Wait()
|
||||
@@ -1980,7 +1980,7 @@ func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) {
|
||||
t.Log("SubTurn completed, parent now finishing")
|
||||
|
||||
// Now parent can finish safely
|
||||
parentTS.Finish()
|
||||
parentTS.Finish(false)
|
||||
|
||||
// Check the result
|
||||
if subTurnErr != nil {
|
||||
@@ -2006,3 +2006,161 @@ func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) {
|
||||
t.Log("No result in channel (expected since we waited)")
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== Graceful vs Hard Finish Tests ======================
|
||||
|
||||
// TestFinish_GracefulVsHard verifies the behavior difference between:
|
||||
// - Finish(false): graceful finish, signals parentEnded but doesn't cancel children
|
||||
// - Finish(true): hard abort, immediately cancels all children
|
||||
func TestFinish_GracefulVsHard(t *testing.T) {
|
||||
// Test 1: Graceful finish should set parentEnded but not cancel context
|
||||
t.Run("Graceful_SetsParentEnded", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ts := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "graceful-test",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
}
|
||||
ts.ctx, ts.cancelFunc = context.WithCancel(ctx)
|
||||
|
||||
// Finish gracefully
|
||||
ts.Finish(false)
|
||||
|
||||
// Verify parentEnded is set
|
||||
if !ts.parentEnded.Load() {
|
||||
t.Error("parentEnded should be true after graceful finish")
|
||||
}
|
||||
|
||||
// Verify context is NOT cancelled (for graceful finish, children continue)
|
||||
// Note: In graceful mode, we don't call cancelFunc()
|
||||
// But since we're using WithCancel on the same ctx, it might be cancelled
|
||||
// Let's check that the context is still valid for a moment
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
// Context might be cancelled by the deferred cancel() in test, which is fine
|
||||
})
|
||||
|
||||
// Test 2: Hard abort should cancel context immediately
|
||||
t.Run("Hard_CancelsContext", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
ts := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "hard-test",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
}
|
||||
ts.ctx, ts.cancelFunc = context.WithCancel(ctx)
|
||||
|
||||
// Finish with hard abort
|
||||
ts.Finish(true)
|
||||
|
||||
// Verify context is cancelled
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
t.Log("✓ Context cancelled after hard abort")
|
||||
default:
|
||||
t.Error("Context should be cancelled after hard abort")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 3: IsParentEnded returns correct value
|
||||
t.Run("IsParentEnded", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
parentTS := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "parent-isended-test",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
|
||||
childTS := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "child-isended-test",
|
||||
depth: 1,
|
||||
parentTurnState: parentTS,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
}
|
||||
|
||||
// Before parent finishes
|
||||
if childTS.IsParentEnded() {
|
||||
t.Error("IsParentEnded should be false before parent finishes")
|
||||
}
|
||||
|
||||
// Finish parent gracefully
|
||||
parentTS.Finish(false)
|
||||
|
||||
// After parent finishes
|
||||
if !childTS.IsParentEnded() {
|
||||
t.Error("IsParentEnded should be true after parent finishes gracefully")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSubTurn_IndependentContext verifies that SubTurns use independent contexts
|
||||
// that don't get cancelled when the parent finishes gracefully.
|
||||
func TestSubTurn_IndependentContext(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Provider: "mock",
|
||||
},
|
||||
},
|
||||
}
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &slowMockProvider{delay: 500 * time.Millisecond}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
ctx := context.Background()
|
||||
parentTS := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "parent-independent",
|
||||
depth: 0,
|
||||
session: newEphemeralSession(nil),
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
|
||||
}
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
|
||||
var subTurnErr error
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Spawn SubTurn with Critical=true (should continue after parent finishes)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
subTurnCfg := SubTurnConfig{
|
||||
Model: "slow-model",
|
||||
Async: true,
|
||||
Critical: true, // Critical SubTurn should continue
|
||||
}
|
||||
_, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg)
|
||||
}()
|
||||
|
||||
// Let SubTurn start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Parent finishes gracefully (should NOT cancel SubTurn)
|
||||
parentTS.Finish(false)
|
||||
t.Log("Parent finished gracefully, SubTurn should continue")
|
||||
|
||||
// Wait for SubTurn to complete
|
||||
wg.Wait()
|
||||
|
||||
// SubTurn should complete without context cancelled error
|
||||
// (because it uses independent context now)
|
||||
if subTurnErr != nil {
|
||||
t.Logf("SubTurn error: %v", subTurnErr)
|
||||
// The error might be context.DeadlineExceeded if timeout is too short
|
||||
// but should NOT be context.Canceled from parent
|
||||
if errors.Is(subTurnErr, context.Canceled) {
|
||||
t.Error("SubTurn should not be cancelled by parent's graceful finish")
|
||||
}
|
||||
} else {
|
||||
t.Log("✓ SubTurn completed successfully (independent context)")
|
||||
}
|
||||
}
|
||||
|
||||
+54
-13
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
@@ -44,6 +45,16 @@ type turnState struct {
|
||||
isFinished bool // MUST be accessed under mu lock
|
||||
closeOnce sync.Once // Ensures pendingResults channel is closed exactly once
|
||||
concurrencySem chan struct{} // Limits concurrent child sub-turns
|
||||
|
||||
// parentEnded signals that the parent turn has finished gracefully.
|
||||
// Child SubTurns should check this via IsParentEnded() to decide whether
|
||||
// to continue running (Critical=true) or exit gracefully (Critical=false).
|
||||
parentEnded atomic.Bool
|
||||
|
||||
// parentTurnState holds a reference to the parent turnState.
|
||||
// This allows child SubTurns to check if the parent has ended.
|
||||
// Nil for root turns.
|
||||
parentTurnState *turnState
|
||||
}
|
||||
|
||||
// ====================== Public API ======================
|
||||
@@ -99,12 +110,13 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState
|
||||
// (spawnSubTurn) already creates one. The turnState stores the context and
|
||||
// cancelFunc provided by the caller to avoid redundant context wrapping.
|
||||
return &turnState{
|
||||
ctx: ctx,
|
||||
cancelFunc: nil, // Will be set by the caller
|
||||
turnID: id,
|
||||
parentTurnID: parent.turnID,
|
||||
depth: parent.depth + 1,
|
||||
session: newEphemeralSession(parent.session),
|
||||
ctx: ctx,
|
||||
cancelFunc: nil, // Will be set by the caller
|
||||
turnID: id,
|
||||
parentTurnID: parent.turnID,
|
||||
depth: parent.depth + 1,
|
||||
session: newEphemeralSession(parent.session),
|
||||
parentTurnState: parent, // Store reference to parent for IsParentEnded() checks
|
||||
// NOTE: In this PoC, I use a fixed-size channel (16).
|
||||
// Under high concurrency or long-running sub-turns, this might fill up and cause
|
||||
// intermediate results to be discarded in deliverSubTurnResult.
|
||||
@@ -114,18 +126,47 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState
|
||||
}
|
||||
}
|
||||
|
||||
// Finish marks the turn as finished and cancels its context, aborting any running sub-turns.
|
||||
// It also closes the pendingResults channel to signal that no more results will be delivered.
|
||||
// This method is safe to call multiple times - the channel will only be closed once.
|
||||
// Any results remaining in the channel after close will be drained and emitted as orphan events.
|
||||
func (ts *turnState) Finish() {
|
||||
// IsParentEnded returns true if the parent turn has finished gracefully.
|
||||
// This is safe to call from child SubTurn goroutines.
|
||||
// Returns false if this is a root turn (no parent).
|
||||
func (ts *turnState) IsParentEnded() bool {
|
||||
if ts.parentTurnState == nil {
|
||||
return false
|
||||
}
|
||||
return ts.parentTurnState.parentEnded.Load()
|
||||
}
|
||||
|
||||
// IsParentEnded is a convenience method to check if parent ended.
|
||||
// It returns the value of the parent's parentEnded atomic flag.
|
||||
|
||||
// Finish marks the turn as finished.
|
||||
//
|
||||
// If isHardAbort is true (Hard Abort):
|
||||
// - Cancels all child contexts immediately via cancelFunc
|
||||
// - Used for user-initiated termination (e.g., "stop now")
|
||||
//
|
||||
// If isHardAbort is false (Graceful Finish):
|
||||
// - Only signals parentEnded for graceful child exit
|
||||
// - Children check IsParentEnded() and decide whether to continue or exit
|
||||
// - Critical SubTurns continue running and deliver orphan results
|
||||
// - Non-Critical SubTurns exit gracefully without error
|
||||
//
|
||||
// In both cases, the pendingResults channel is closed to signal
|
||||
// that no more results will be delivered.
|
||||
func (ts *turnState) Finish(isHardAbort bool) {
|
||||
ts.mu.Lock()
|
||||
ts.isFinished = true
|
||||
resultChan := ts.pendingResults
|
||||
ts.mu.Unlock()
|
||||
|
||||
if ts.cancelFunc != nil {
|
||||
ts.cancelFunc()
|
||||
if isHardAbort {
|
||||
// Hard abort: immediately cancel all children
|
||||
if ts.cancelFunc != nil {
|
||||
ts.cancelFunc()
|
||||
}
|
||||
} else {
|
||||
// Graceful finish: signal parent ended, let children decide
|
||||
ts.parentEnded.Store(true)
|
||||
}
|
||||
|
||||
// Use sync.Once to ensure the channel is closed exactly once, even if Finish() is called concurrently.
|
||||
|
||||
Reference in New Issue
Block a user