fix(agent): resolve critical race conditions and resource leaks in SubTurn

- Fix turnState hierarchy corruption when SubTurns recursively call runAgentLoop
  by checking context for existing turnState before creating new root
- Fix deadlock risk in deliverSubTurnResult by separating lock and channel operations
- Fix session rollback race in HardAbort by calling Finish() before rollback
- Fix resource leak by closing pendingResults channel in Finish() with panic recovery
- Add thread-safety documentation for childTurnIDs and isFinished fields
- Move globalTurnCounter to AgentLoop.subTurnCounter to prevent ID conflicts
- Improve semaphore acquisition to ensure release even on early validation failures
- Document design choice: ephemeral sessions start empty for complete isolation
- Add 5 new tests: hierarchy, deadlock, order, channel close, and semaphore
This commit is contained in:
Administrator
2026-03-16 22:37:21 +08:00
parent 9d761b7f5b
commit 6b5d7e3fd7
5 changed files with 347 additions and 67 deletions
+50 -32
View File
@@ -36,21 +36,22 @@ import (
)
type AgentLoop struct {
bus *bus.MessageBus
cfg *config.Config
registry *AgentRegistry
state *state.Manager
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
bus *bus.MessageBus
cfg *config.Config
registry *AgentRegistry
state *state.Manager
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
channelManager *channels.Manager
mediaStore media.MediaStore
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
@@ -964,25 +965,39 @@ func (al *AgentLoop) runAgentLoop(
agent *AgentInstance,
opts processOptions,
) (string, error) {
// Initialize a root TurnState for this iteration, allowing sub-turns to be spawned.
rootTS := &turnState{
ctx: ctx,
turnID: opts.SessionKey, // Associate this turn graph with the current session key
depth: 0,
session: agent.Sessions,
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
// Check if we're already inside a SubTurn (context already has a turnState).
// If so, reuse it instead of creating a new root turnState.
// This prevents turnState hierarchy corruption when SubTurns recursively call runAgentLoop.
existingTS := turnStateFromContext(ctx)
var rootTS *turnState
var isRootTurn bool
if existingTS != nil {
// We're inside a SubTurn — reuse the existing turnState
rootTS = existingTS
isRootTurn = false
} else {
// This is a top-level turn — initialize a new root TurnState
rootTS = &turnState{
ctx: ctx,
turnID: opts.SessionKey, // Associate this turn graph with the current session key
depth: 0,
session: agent.Sessions,
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
}
ctx = withTurnState(ctx, rootTS)
isRootTurn = true
// Register this root turn state so HardAbort can find it
al.activeTurnStates.Store(opts.SessionKey, rootTS)
defer al.activeTurnStates.Delete(opts.SessionKey)
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)
}
ctx = withTurnState(ctx, rootTS)
// Register this root turn state so HardAbort can find it
al.activeTurnStates.Store(opts.SessionKey, rootTS)
defer al.activeTurnStates.Delete(opts.SessionKey)
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
if opts.Channel != "" && opts.ChatID != "" {
@@ -1028,8 +1043,11 @@ func (al *AgentLoop) runAgentLoop(
return "", err
}
// Signal completion to rootTS so it knows it is finished, terminating any active sub-turns
rootTS.Finish()
// Signal completion to rootTS so it knows it is finished, terminating any active sub-turns.
// Only call Finish() if this is a root turn (not a SubTurn recursively calling runAgentLoop).
if isRootTurn {
rootTS.Finish()
}
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
+7 -4
View File
@@ -255,7 +255,13 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
"initial_history_length": ts.initialHistoryLength,
})
// Rollback session history to the state before this turn started
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
// from adding more messages to the session. This prevents race conditions
// where rollback happens while children are still writing.
ts.Finish()
// Rollback session history to the state before this turn started.
// This must happen AFTER Finish() to ensure no child turns are still writing.
if ts.session != nil {
currentHistory := ts.session.GetHistory("")
if len(currentHistory) > ts.initialHistoryLength {
@@ -268,8 +274,5 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
}
}
// Trigger cascading cancellation to all child SubTurns
ts.Finish()
return nil
}
+67 -31
View File
@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
@@ -14,8 +13,8 @@ import (
// ====================== Config & Constants ======================
const (
maxSubTurnDepth = 3
maxConcurrentSubTurns = 5
maxSubTurnDepth = 3
maxConcurrentSubTurns = 5
)
var (
@@ -78,20 +77,19 @@ type turnState struct {
turnID string
parentTurnID string
depth int
childTurnIDs []string
childTurnIDs []string // MUST be accessed under mu lock or maybe add a getter method
pendingResults chan *tools.ToolResult
session session.SessionStore
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
mu sync.Mutex
isFinished bool // Marks if the parent Turn has ended
isFinished bool // MUST be accessed under mu lock
concurrencySem chan struct{} // Limits concurrent child sub-turns
}
// ====================== Helper Functions ======================
var globalTurnCounter int64
func generateTurnID() string {
return fmt.Sprintf("subturn-%d", atomic.AddInt64(&globalTurnCounter, 1))
func (al *AgentLoop) generateSubTurnID() string {
return fmt.Sprintf("subturn-%d", al.subTurnCounter.Add(1))
}
func newTurnState(ctx context.Context, id string, parent *turnState) *turnState {
@@ -113,13 +111,27 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState
}
// Finish marks the turn as finished and cancels its context, aborting any running sub-turns.
// It also closes the pendingResults channel to signal that no more results will be delivered.
func (ts *turnState) Finish() {
ts.mu.Lock()
defer ts.mu.Unlock()
if ts.isFinished {
// Already finished - avoid double close of channel
return
}
ts.isFinished = true
if ts.cancelFunc != nil {
ts.cancelFunc()
}
// Close the pendingResults channel to signal no more results will arrive.
// This prevents goroutine leaks from readers waiting on the channel.
if ts.pendingResults != nil {
close(ts.pendingResults)
}
}
// ephemeralSessionStore is a pure in-memory SessionStore for SubTurns.
@@ -186,6 +198,24 @@ func newEphemeralSession(_ session.SessionStore) session.SessionStore {
// ====================== Core Function: spawnSubTurn ======================
func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg SubTurnConfig) (result *tools.ToolResult, err error) {
// 0. Acquire concurrency semaphore FIRST to ensure it's released even if early validation fails.
// Blocks if parent already has maxConcurrentSubTurns running.
// Also respects context cancellation so we don't block forever if parent is aborted.
var semAcquired bool
if parentTS.concurrencySem != nil {
select {
case parentTS.concurrencySem <- struct{}{}:
semAcquired = true
defer func() {
if semAcquired {
<-parentTS.concurrencySem
}
}()
case <-ctx.Done():
return nil, ctx.Err()
}
}
// 1. Depth limit check
if parentTS.depth >= maxSubTurnDepth {
return nil, ErrDepthLimitExceeded
@@ -196,42 +226,31 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
return nil, ErrInvalidSubTurnConfig
}
// 3. Acquire concurrency semaphore — blocks if parent already has maxConcurrentSubTurns running.
// Also respects context cancellation so we don't block forever if parent is aborted.
if parentTS.concurrencySem != nil {
select {
case parentTS.concurrencySem <- struct{}{}:
defer func() { <-parentTS.concurrencySem }()
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Create a sub-context for the child turn to support cancellation
childCtx, cancel := context.WithCancel(ctx)
defer cancel()
// 4. Create child Turn state
childID := generateTurnID()
// 3. Create child Turn state
childID := al.generateSubTurnID()
childTS := newTurnState(childCtx, childID, parentTS)
// 5. Establish parent-child relationship (thread-safe)
// 4. Establish parent-child relationship (thread-safe)
parentTS.mu.Lock()
parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
parentTS.mu.Unlock()
// 6. Register the parent's pendingResults channel so the parent loop can poll it
// 5. Register the parent's pendingResults channel so the parent loop can poll it
al.registerSubTurnResultChannel(parentTS.turnID, parentTS.pendingResults)
defer al.unregisterSubTurnResultChannel(parentTS.turnID)
// 7. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
// 6. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
MockEventBus.Emit(SubTurnSpawnEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Config: cfg,
})
// 8. Defer emitting End event, and recover from panics to ensure it's always fired
// 7. Defer emitting End event, and recover from panics to ensure it's always fired
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("subturn panicked: %v", r)
@@ -244,11 +263,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
})
}()
// 9. Execute sub-turn via the real agent loop.
// 8. Execute sub-turn via the real agent loop.
// Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent.
result, err = runTurn(childCtx, al, childTS, cfg)
// 10. Deliver result back to parent Turn
// 9. Deliver result back to parent Turn
deliverSubTurnResult(parentTS, childID, result)
return result, err
@@ -256,8 +275,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
// ====================== Result Delivery ======================
func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) {
// Check parent state under lock, but don't hold lock while sending to channel
parentTS.mu.Lock()
defer parentTS.mu.Unlock()
isFinished := parentTS.isFinished
resultChan := parentTS.pendingResults
parentTS.mu.Unlock()
// Emit ResultDelivered event
MockEventBus.Emit(SubTurnResultDeliveredEvent{
@@ -266,10 +288,24 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
Result: result,
})
if !parentTS.isFinished {
if !isFinished && resultChan != nil {
// Parent Turn is still running → Place in pending queue (handled automatically by parent loop in next round)
// Use defer/recover to handle the case where the channel is closed between our check and the send.
defer func() {
if r := recover(); r != nil {
// Channel was closed - treat as orphan result
if result != nil {
MockEventBus.Emit(SubTurnOrphanResultEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Result: result,
})
}
}
}()
select {
case parentTS.pendingResults <- result:
case resultChan <- result:
default:
fmt.Println("[SubTurn] warning: pendingResults channel full")
}
+221
View File
@@ -2,8 +2,11 @@ package agent
import (
"context"
"fmt"
"reflect"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -500,3 +503,221 @@ func TestHardAbortSessionRollback(t *testing.T) {
t.Error("history content does not match initial state after rollback")
}
}
// TestNestedSubTurnHierarchy verifies that nested SubTurns maintain correct
// parent-child relationships and depth tracking when recursively calling runAgentLoop.
func TestNestedSubTurnHierarchy(t *testing.T) {
al, _, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
// Track spawned turns and their depths
type turnInfo struct {
parentID string
childID string
depth int
}
var spawnedTurns []turnInfo
var mu sync.Mutex
// Override MockEventBus to capture spawn events
originalEmit := MockEventBus.Emit
defer func() { MockEventBus.Emit = originalEmit }()
MockEventBus.Emit = func(event any) {
if spawnEvent, ok := event.(SubTurnSpawnEvent); ok {
mu.Lock()
// Extract depth from context (we'll verify this matches expected depth)
spawnedTurns = append(spawnedTurns, turnInfo{
parentID: spawnEvent.ParentID,
childID: spawnEvent.ChildID,
})
mu.Unlock()
}
}
// Create a root turn
rootSession := &ephemeralSessionStore{}
rootTS := &turnState{
ctx: context.Background(),
turnID: "root-turn",
depth: 0,
session: rootSession,
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
// Spawn a child (depth 1)
childCfg := SubTurnConfig{Model: "gpt-4o-mini"}
_, err := spawnSubTurn(context.Background(), al, rootTS, childCfg)
if err != nil {
t.Fatalf("failed to spawn child: %v", err)
}
// Verify we captured the spawn event
mu.Lock()
if len(spawnedTurns) != 1 {
t.Fatalf("expected 1 spawn event, got %d", len(spawnedTurns))
}
if spawnedTurns[0].parentID != "root-turn" {
t.Errorf("expected parent ID 'root-turn', got %s", spawnedTurns[0].parentID)
}
mu.Unlock()
// Verify root turn has the child in its childTurnIDs
rootTS.mu.Lock()
if len(rootTS.childTurnIDs) != 1 {
t.Errorf("expected root to have 1 child, got %d", len(rootTS.childTurnIDs))
}
rootTS.mu.Unlock()
}
// TestDeliverSubTurnResultNoDeadlock verifies that deliverSubTurnResult doesn't
// deadlock when multiple goroutines are accessing the parent turnState concurrently.
func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
parent := &turnState{
ctx: context.Background(),
turnID: "parent-deadlock-test",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking
isFinished: false,
}
// Simulate multiple child turns delivering results concurrently
var wg sync.WaitGroup
numChildren := 10
for i := 0; i < numChildren; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)}
deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result)
}(i)
}
// Concurrently read from the channel to prevent blocking
go func() {
for i := 0; i < numChildren; i++ {
select {
case <-parent.pendingResults:
case <-time.After(2 * time.Second):
t.Error("timeout waiting for result")
return
}
}
}()
// Wait for all deliveries to complete (with timeout)
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success - no deadlock
case <-time.After(3 * time.Second):
t.Fatal("deadlock detected: deliverSubTurnResult blocked")
}
}
// TestHardAbortOrderOfOperations verifies that HardAbort calls Finish() before
// rolling back session history, minimizing the race window where new messages
// could be added after rollback.
func TestHardAbortOrderOfOperations(t *testing.T) {
al, _, _, _, cleanup := newTestAgentLoop(t)
defer cleanup()
sess := &ephemeralSessionStore{
history: []providers.Message{
{Role: "user", Content: "initial message"},
{Role: "assistant", Content: "response 1"},
{Role: "user", Content: "follow-up"},
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
rootTS := &turnState{
ctx: ctx,
cancelFunc: cancel,
turnID: "test-session-order",
depth: 0,
session: sess,
initialHistoryLength: 1, // Snapshot: 1 message
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
al.activeTurnStates.Store("test-session-order", rootTS)
// Trigger HardAbort
err := al.HardAbort("test-session-order")
if err != nil {
t.Fatalf("HardAbort failed: %v", err)
}
// Verify context was cancelled (Finish() was called)
select {
case <-rootTS.ctx.Done():
// Good - context was cancelled
default:
t.Error("expected context to be cancelled after HardAbort")
}
// Verify history was rolled back
finalHistory := sess.GetHistory("")
if len(finalHistory) != 1 {
t.Errorf("expected history to rollback to 1 message, got %d", len(finalHistory))
}
if finalHistory[0].Content != "initial message" {
t.Error("history content does not match initial state after rollback")
}
}
// TestFinishClosesChannel verifies that Finish() closes the pendingResults channel
// and that deliverSubTurnResult handles closed channels gracefully.
func TestFinishClosesChannel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts := &turnState{
ctx: ctx,
cancelFunc: cancel,
turnID: "test-finish-channel",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 2),
isFinished: false,
}
// Verify channel is open initially
select {
case ts.pendingResults <- &tools.ToolResult{ForLLM: "test"}:
// Good - channel is open
// Drain the message we just sent
<-ts.pendingResults
default:
t.Fatal("channel should be open initially")
}
// Call Finish()
ts.Finish()
// Verify channel is closed
_, ok := <-ts.pendingResults
if ok {
t.Error("expected channel to be closed after Finish()")
}
// Verify Finish() is idempotent (can be called multiple times)
ts.Finish() // Should not panic
// Verify deliverSubTurnResult doesn't panic when sending to closed channel
result := &tools.ToolResult{ForLLM: "late result"}
// This should not panic - it should recover and emit OrphanResultEvent
deliverSubTurnResult(ts, "child-1", result)
}