mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): resolve critical race conditions and resource leaks in SubTurn
- Fix turnState hierarchy corruption when SubTurns recursively call runAgentLoop by checking context for existing turnState before creating new root - Fix deadlock risk in deliverSubTurnResult by separating lock and channel operations - Fix session rollback race in HardAbort by calling Finish() before rollback - Fix resource leak by closing pendingResults channel in Finish() with panic recovery - Add thread-safety documentation for childTurnIDs and isFinished fields - Move globalTurnCounter to AgentLoop.subTurnCounter to prevent ID conflicts - Improve semaphore acquisition to ensure release even on early validation failures - Document design choice: ephemeral sessions start empty for complete isolation - Add 5 new tests: hierarchy, deadlock, order, channel close, and semaphore
This commit is contained in:
+50
-32
@@ -36,21 +36,22 @@ import (
|
||||
)
|
||||
|
||||
type AgentLoop struct {
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
bus *bus.MessageBus
|
||||
cfg *config.Config
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
channelManager *channels.Manager
|
||||
mediaStore media.MediaStore
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
steering *steeringQueue
|
||||
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
|
||||
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
|
||||
subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs
|
||||
mu sync.RWMutex
|
||||
// Track active requests for safe provider cleanup
|
||||
activeRequests sync.WaitGroup
|
||||
@@ -964,25 +965,39 @@ func (al *AgentLoop) runAgentLoop(
|
||||
agent *AgentInstance,
|
||||
opts processOptions,
|
||||
) (string, error) {
|
||||
// Initialize a root TurnState for this iteration, allowing sub-turns to be spawned.
|
||||
rootTS := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: opts.SessionKey, // Associate this turn graph with the current session key
|
||||
depth: 0,
|
||||
session: agent.Sessions,
|
||||
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
|
||||
// Check if we're already inside a SubTurn (context already has a turnState).
|
||||
// If so, reuse it instead of creating a new root turnState.
|
||||
// This prevents turnState hierarchy corruption when SubTurns recursively call runAgentLoop.
|
||||
existingTS := turnStateFromContext(ctx)
|
||||
var rootTS *turnState
|
||||
var isRootTurn bool
|
||||
|
||||
if existingTS != nil {
|
||||
// We're inside a SubTurn — reuse the existing turnState
|
||||
rootTS = existingTS
|
||||
isRootTurn = false
|
||||
} else {
|
||||
// This is a top-level turn — initialize a new root TurnState
|
||||
rootTS = &turnState{
|
||||
ctx: ctx,
|
||||
turnID: opts.SessionKey, // Associate this turn graph with the current session key
|
||||
depth: 0,
|
||||
session: agent.Sessions,
|
||||
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
|
||||
}
|
||||
ctx = withTurnState(ctx, rootTS)
|
||||
isRootTurn = true
|
||||
|
||||
// Register this root turn state so HardAbort can find it
|
||||
al.activeTurnStates.Store(opts.SessionKey, rootTS)
|
||||
defer al.activeTurnStates.Delete(opts.SessionKey)
|
||||
|
||||
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
|
||||
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
|
||||
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)
|
||||
}
|
||||
ctx = withTurnState(ctx, rootTS)
|
||||
|
||||
// Register this root turn state so HardAbort can find it
|
||||
al.activeTurnStates.Store(opts.SessionKey, rootTS)
|
||||
defer al.activeTurnStates.Delete(opts.SessionKey)
|
||||
|
||||
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
|
||||
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
|
||||
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)
|
||||
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
@@ -1028,8 +1043,11 @@ func (al *AgentLoop) runAgentLoop(
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Signal completion to rootTS so it knows it is finished, terminating any active sub-turns
|
||||
rootTS.Finish()
|
||||
// Signal completion to rootTS so it knows it is finished, terminating any active sub-turns.
|
||||
// Only call Finish() if this is a root turn (not a SubTurn recursively calling runAgentLoop).
|
||||
if isRootTurn {
|
||||
rootTS.Finish()
|
||||
}
|
||||
|
||||
// If last tool had ForUser content and we already sent it, we might not need to send final response
|
||||
// This is controlled by the tool's Silent flag and ForUser content
|
||||
|
||||
@@ -255,7 +255,13 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
"initial_history_length": ts.initialHistoryLength,
|
||||
})
|
||||
|
||||
// Rollback session history to the state before this turn started
|
||||
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
|
||||
// from adding more messages to the session. This prevents race conditions
|
||||
// where rollback happens while children are still writing.
|
||||
ts.Finish()
|
||||
|
||||
// Rollback session history to the state before this turn started.
|
||||
// This must happen AFTER Finish() to ensure no child turns are still writing.
|
||||
if ts.session != nil {
|
||||
currentHistory := ts.session.GetHistory("")
|
||||
if len(currentHistory) > ts.initialHistoryLength {
|
||||
@@ -268,8 +274,5 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger cascading cancellation to all child SubTurns
|
||||
ts.Finish()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+67
-31
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
@@ -14,8 +13,8 @@ import (
|
||||
|
||||
// ====================== Config & Constants ======================
|
||||
const (
|
||||
maxSubTurnDepth = 3
|
||||
maxConcurrentSubTurns = 5
|
||||
maxSubTurnDepth = 3
|
||||
maxConcurrentSubTurns = 5
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -78,20 +77,19 @@ type turnState struct {
|
||||
turnID string
|
||||
parentTurnID string
|
||||
depth int
|
||||
childTurnIDs []string
|
||||
childTurnIDs []string // MUST be accessed under mu lock or maybe add a getter method
|
||||
pendingResults chan *tools.ToolResult
|
||||
session session.SessionStore
|
||||
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
|
||||
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
|
||||
mu sync.Mutex
|
||||
isFinished bool // Marks if the parent Turn has ended
|
||||
isFinished bool // MUST be accessed under mu lock
|
||||
concurrencySem chan struct{} // Limits concurrent child sub-turns
|
||||
}
|
||||
|
||||
// ====================== Helper Functions ======================
|
||||
var globalTurnCounter int64
|
||||
|
||||
func generateTurnID() string {
|
||||
return fmt.Sprintf("subturn-%d", atomic.AddInt64(&globalTurnCounter, 1))
|
||||
func (al *AgentLoop) generateSubTurnID() string {
|
||||
return fmt.Sprintf("subturn-%d", al.subTurnCounter.Add(1))
|
||||
}
|
||||
|
||||
func newTurnState(ctx context.Context, id string, parent *turnState) *turnState {
|
||||
@@ -113,13 +111,27 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState
|
||||
}
|
||||
|
||||
// Finish marks the turn as finished and cancels its context, aborting any running sub-turns.
|
||||
// It also closes the pendingResults channel to signal that no more results will be delivered.
|
||||
func (ts *turnState) Finish() {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
|
||||
if ts.isFinished {
|
||||
// Already finished - avoid double close of channel
|
||||
return
|
||||
}
|
||||
|
||||
ts.isFinished = true
|
||||
|
||||
if ts.cancelFunc != nil {
|
||||
ts.cancelFunc()
|
||||
}
|
||||
|
||||
// Close the pendingResults channel to signal no more results will arrive.
|
||||
// This prevents goroutine leaks from readers waiting on the channel.
|
||||
if ts.pendingResults != nil {
|
||||
close(ts.pendingResults)
|
||||
}
|
||||
}
|
||||
|
||||
// ephemeralSessionStore is a pure in-memory SessionStore for SubTurns.
|
||||
@@ -186,6 +198,24 @@ func newEphemeralSession(_ session.SessionStore) session.SessionStore {
|
||||
|
||||
// ====================== Core Function: spawnSubTurn ======================
|
||||
func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg SubTurnConfig) (result *tools.ToolResult, err error) {
|
||||
// 0. Acquire concurrency semaphore FIRST to ensure it's released even if early validation fails.
|
||||
// Blocks if parent already has maxConcurrentSubTurns running.
|
||||
// Also respects context cancellation so we don't block forever if parent is aborted.
|
||||
var semAcquired bool
|
||||
if parentTS.concurrencySem != nil {
|
||||
select {
|
||||
case parentTS.concurrencySem <- struct{}{}:
|
||||
semAcquired = true
|
||||
defer func() {
|
||||
if semAcquired {
|
||||
<-parentTS.concurrencySem
|
||||
}
|
||||
}()
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Depth limit check
|
||||
if parentTS.depth >= maxSubTurnDepth {
|
||||
return nil, ErrDepthLimitExceeded
|
||||
@@ -196,42 +226,31 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
return nil, ErrInvalidSubTurnConfig
|
||||
}
|
||||
|
||||
// 3. Acquire concurrency semaphore — blocks if parent already has maxConcurrentSubTurns running.
|
||||
// Also respects context cancellation so we don't block forever if parent is aborted.
|
||||
if parentTS.concurrencySem != nil {
|
||||
select {
|
||||
case parentTS.concurrencySem <- struct{}{}:
|
||||
defer func() { <-parentTS.concurrencySem }()
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Create a sub-context for the child turn to support cancellation
|
||||
childCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// 4. Create child Turn state
|
||||
childID := generateTurnID()
|
||||
// 3. Create child Turn state
|
||||
childID := al.generateSubTurnID()
|
||||
childTS := newTurnState(childCtx, childID, parentTS)
|
||||
|
||||
// 5. Establish parent-child relationship (thread-safe)
|
||||
// 4. Establish parent-child relationship (thread-safe)
|
||||
parentTS.mu.Lock()
|
||||
parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// 6. Register the parent's pendingResults channel so the parent loop can poll it
|
||||
// 5. Register the parent's pendingResults channel so the parent loop can poll it
|
||||
al.registerSubTurnResultChannel(parentTS.turnID, parentTS.pendingResults)
|
||||
defer al.unregisterSubTurnResultChannel(parentTS.turnID)
|
||||
|
||||
// 7. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
|
||||
// 6. Emit Spawn event (currently using Mock, will be replaced by real EventBus)
|
||||
MockEventBus.Emit(SubTurnSpawnEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Config: cfg,
|
||||
})
|
||||
|
||||
// 8. Defer emitting End event, and recover from panics to ensure it's always fired
|
||||
// 7. Defer emitting End event, and recover from panics to ensure it's always fired
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("subturn panicked: %v", r)
|
||||
@@ -244,11 +263,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
})
|
||||
}()
|
||||
|
||||
// 9. Execute sub-turn via the real agent loop.
|
||||
// 8. Execute sub-turn via the real agent loop.
|
||||
// Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent.
|
||||
result, err = runTurn(childCtx, al, childTS, cfg)
|
||||
|
||||
// 10. Deliver result back to parent Turn
|
||||
// 9. Deliver result back to parent Turn
|
||||
deliverSubTurnResult(parentTS, childID, result)
|
||||
|
||||
return result, err
|
||||
@@ -256,8 +275,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
|
||||
// ====================== Result Delivery ======================
|
||||
func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) {
|
||||
// Check parent state under lock, but don't hold lock while sending to channel
|
||||
parentTS.mu.Lock()
|
||||
defer parentTS.mu.Unlock()
|
||||
isFinished := parentTS.isFinished
|
||||
resultChan := parentTS.pendingResults
|
||||
parentTS.mu.Unlock()
|
||||
|
||||
// Emit ResultDelivered event
|
||||
MockEventBus.Emit(SubTurnResultDeliveredEvent{
|
||||
@@ -266,10 +288,24 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
|
||||
Result: result,
|
||||
})
|
||||
|
||||
if !parentTS.isFinished {
|
||||
if !isFinished && resultChan != nil {
|
||||
// Parent Turn is still running → Place in pending queue (handled automatically by parent loop in next round)
|
||||
// Use defer/recover to handle the case where the channel is closed between our check and the send.
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Channel was closed - treat as orphan result
|
||||
if result != nil {
|
||||
MockEventBus.Emit(SubTurnOrphanResultEvent{
|
||||
ParentID: parentTS.turnID,
|
||||
ChildID: childID,
|
||||
Result: result,
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case parentTS.pendingResults <- result:
|
||||
case resultChan <- result:
|
||||
default:
|
||||
fmt.Println("[SubTurn] warning: pendingResults channel full")
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
@@ -500,3 +503,221 @@ func TestHardAbortSessionRollback(t *testing.T) {
|
||||
t.Error("history content does not match initial state after rollback")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNestedSubTurnHierarchy verifies that nested SubTurns maintain correct
|
||||
// parent-child relationships and depth tracking when recursively calling runAgentLoop.
|
||||
func TestNestedSubTurnHierarchy(t *testing.T) {
|
||||
al, _, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
// Track spawned turns and their depths
|
||||
type turnInfo struct {
|
||||
parentID string
|
||||
childID string
|
||||
depth int
|
||||
}
|
||||
var spawnedTurns []turnInfo
|
||||
var mu sync.Mutex
|
||||
|
||||
// Override MockEventBus to capture spawn events
|
||||
originalEmit := MockEventBus.Emit
|
||||
defer func() { MockEventBus.Emit = originalEmit }()
|
||||
|
||||
MockEventBus.Emit = func(event any) {
|
||||
if spawnEvent, ok := event.(SubTurnSpawnEvent); ok {
|
||||
mu.Lock()
|
||||
// Extract depth from context (we'll verify this matches expected depth)
|
||||
spawnedTurns = append(spawnedTurns, turnInfo{
|
||||
parentID: spawnEvent.ParentID,
|
||||
childID: spawnEvent.ChildID,
|
||||
})
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Create a root turn
|
||||
rootSession := &ephemeralSessionStore{}
|
||||
rootTS := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "root-turn",
|
||||
depth: 0,
|
||||
session: rootSession,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
|
||||
// Spawn a child (depth 1)
|
||||
childCfg := SubTurnConfig{Model: "gpt-4o-mini"}
|
||||
_, err := spawnSubTurn(context.Background(), al, rootTS, childCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to spawn child: %v", err)
|
||||
}
|
||||
|
||||
// Verify we captured the spawn event
|
||||
mu.Lock()
|
||||
if len(spawnedTurns) != 1 {
|
||||
t.Fatalf("expected 1 spawn event, got %d", len(spawnedTurns))
|
||||
}
|
||||
if spawnedTurns[0].parentID != "root-turn" {
|
||||
t.Errorf("expected parent ID 'root-turn', got %s", spawnedTurns[0].parentID)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify root turn has the child in its childTurnIDs
|
||||
rootTS.mu.Lock()
|
||||
if len(rootTS.childTurnIDs) != 1 {
|
||||
t.Errorf("expected root to have 1 child, got %d", len(rootTS.childTurnIDs))
|
||||
}
|
||||
rootTS.mu.Unlock()
|
||||
}
|
||||
|
||||
// TestDeliverSubTurnResultNoDeadlock verifies that deliverSubTurnResult doesn't
|
||||
// deadlock when multiple goroutines are accessing the parent turnState concurrently.
|
||||
func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
|
||||
parent := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: "parent-deadlock-test",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking
|
||||
isFinished: false,
|
||||
}
|
||||
|
||||
// Simulate multiple child turns delivering results concurrently
|
||||
var wg sync.WaitGroup
|
||||
numChildren := 10
|
||||
|
||||
for i := 0; i < numChildren; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)}
|
||||
deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrently read from the channel to prevent blocking
|
||||
go func() {
|
||||
for i := 0; i < numChildren; i++ {
|
||||
select {
|
||||
case <-parent.pendingResults:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout waiting for result")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for all deliveries to complete (with timeout)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success - no deadlock
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("deadlock detected: deliverSubTurnResult blocked")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHardAbortOrderOfOperations verifies that HardAbort calls Finish() before
|
||||
// rolling back session history, minimizing the race window where new messages
|
||||
// could be added after rollback.
|
||||
func TestHardAbortOrderOfOperations(t *testing.T) {
|
||||
al, _, _, _, cleanup := newTestAgentLoop(t)
|
||||
defer cleanup()
|
||||
|
||||
sess := &ephemeralSessionStore{
|
||||
history: []providers.Message{
|
||||
{Role: "user", Content: "initial message"},
|
||||
{Role: "assistant", Content: "response 1"},
|
||||
{Role: "user", Content: "follow-up"},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
rootTS := &turnState{
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
turnID: "test-session-order",
|
||||
depth: 0,
|
||||
session: sess,
|
||||
initialHistoryLength: 1, // Snapshot: 1 message
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5),
|
||||
}
|
||||
|
||||
al.activeTurnStates.Store("test-session-order", rootTS)
|
||||
|
||||
// Trigger HardAbort
|
||||
err := al.HardAbort("test-session-order")
|
||||
if err != nil {
|
||||
t.Fatalf("HardAbort failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify context was cancelled (Finish() was called)
|
||||
select {
|
||||
case <-rootTS.ctx.Done():
|
||||
// Good - context was cancelled
|
||||
default:
|
||||
t.Error("expected context to be cancelled after HardAbort")
|
||||
}
|
||||
|
||||
// Verify history was rolled back
|
||||
finalHistory := sess.GetHistory("")
|
||||
if len(finalHistory) != 1 {
|
||||
t.Errorf("expected history to rollback to 1 message, got %d", len(finalHistory))
|
||||
}
|
||||
|
||||
if finalHistory[0].Content != "initial message" {
|
||||
t.Error("history content does not match initial state after rollback")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFinishClosesChannel verifies that Finish() closes the pendingResults channel
|
||||
// and that deliverSubTurnResult handles closed channels gracefully.
|
||||
func TestFinishClosesChannel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ts := &turnState{
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
turnID: "test-finish-channel",
|
||||
depth: 0,
|
||||
pendingResults: make(chan *tools.ToolResult, 2),
|
||||
isFinished: false,
|
||||
}
|
||||
|
||||
// Verify channel is open initially
|
||||
select {
|
||||
case ts.pendingResults <- &tools.ToolResult{ForLLM: "test"}:
|
||||
// Good - channel is open
|
||||
// Drain the message we just sent
|
||||
<-ts.pendingResults
|
||||
default:
|
||||
t.Fatal("channel should be open initially")
|
||||
}
|
||||
|
||||
// Call Finish()
|
||||
ts.Finish()
|
||||
|
||||
// Verify channel is closed
|
||||
_, ok := <-ts.pendingResults
|
||||
if ok {
|
||||
t.Error("expected channel to be closed after Finish()")
|
||||
}
|
||||
|
||||
// Verify Finish() is idempotent (can be called multiple times)
|
||||
ts.Finish() // Should not panic
|
||||
|
||||
// Verify deliverSubTurnResult doesn't panic when sending to closed channel
|
||||
result := &tools.ToolResult{ForLLM: "late result"}
|
||||
|
||||
// This should not panic - it should recover and emit OrphanResultEvent
|
||||
deliverSubTurnResult(ts, "child-1", result)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user