// Review notes for this patch. These lines precede the first "diff --git" header and
// are not part of any hunk; the diff text below is unchanged.
//
// Purpose: Finish() stops closing turnState.pendingResults. A separate signal channel,
// exposed through the new Finished() method, is closed instead, and
// deliverSubTurnResult selects between "send the result" and "parent finished".
//
// pkg/agent/subturn.go
//   - deliverSubTurnResult: adds a deferred recover that logs a warning and, when
//     result is non-nil, emits SubTurnOrphanResultEvent. The non-blocking `default`
//     branch of the send select is replaced by `case <-parentTS.Finished()`, so a full
//     pendingResults channel now blocks the caller until the parent finishes instead
//     of orphaning the result immediately.
//   - runTurn: SkipAddUserMessage is driven by a new promptAlreadyAdded flag instead of
//     `contextRetryCount > 0`. The flag is set to true after the call that receives
//     SkipAddUserMessage and reset to false when the truncation recovery prompt is
//     assigned to currentPrompt.
//
// pkg/agent/subturn_test.go
//   - TestFinishClosesChannel becomes TestFinishedChannelClosedState and asserts on
//     Finished() rather than on pendingResults.
//   - TestFinish_DrainsChannel and TestChannelFull_OrphanResults are deleted.
//   - TestFinish_ConcurrentCalls checks Finished();
//     TestGrandchildAbort_CascadingCancellation calls Finish(true) instead of
//     Finish(false); the reader timeout in TestDeliverSubTurnResultNoDeadlock goes
//     from 2s to 5s.
//
// pkg/agent/turn_state.go
//   - Adds the finishedChan field and Finished(), which takes ts.mu, creates the
//     channel on first use, and closes it right away if isFinished is already set.
//   - Finish() sets isFinished only on the first call, creates finishedChan if needed,
//     and closes it through closeOnce. drainPendingResults is removed and
//     pendingResults is no longer closed.
//
// NOTE(review): the hunk shows the new Finished() doc comment right after the existing
//   "IsParentEnded is a convenience method..." comment lines. Whether a blank line
//   separates them cannot be recovered from this collapsed text — verify that the doc
//   comment of Finished() does not start with the IsParentEnded sentences.
// NOTE(review): the closeOnce field comment (an unchanged context line) still says
//   "Ensures pendingResults channel is closed exactly once"; after this patch closeOnce
//   guards close(finishedChan) — consider updating it in the same change.
// NOTE(review): with drainPendingResults removed, nothing shown here emits orphan
//   events for results still buffered in pendingResults when Finish() runs; the test
//   that covered this is deleted as well — confirm the loss of those events is intended.
// NOTE(review): the test comment "Will emit orphan due to <-ts.Finished() case" may be
//   inaccurate. deliverSubTurnResult reads isFinished under the lock first, and the
//   branch that handles it lies between hunks (not visible here) — presumably that
//   branch reports the orphan before the select is reached; verify.
// NOTE(review): the deferred recover covers the whole body of deliverSubTurnResult, so
//   it also swallows panics raised by the logging and event calls, not only by a
//   channel send.
// NOTE(review): when both select cases are ready (buffer has room and Finished() is
//   closed), Go picks one at random, so a result may still be sent after the parent
//   finished — confirm a reader or orphan path exists for that case.
diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 3c178d9fc..7a9cb3304 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -344,7 +344,24 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S // - SubTurnResultDeliveredEvent: successful delivery to channel // - SubTurnOrphanResultEvent: delivery failed (parent finished or channel full) func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) { - // Check parent state under lock, but don't hold lock while sending to channel + // Let GC clean up the pendingResults channel; parent Finish will no longer close it. + // We use defer/recover to catch any unlikely channel panics if it were ever closed. + defer func() { + if r := recover(); r != nil { + logger.WarnCF("subturn", "recovered panic sending to pendingResults", map[string]any{ + "parent_id": parentTS.turnID, + "child_id": childID, + "recover": r, + }) + if result != nil { + MockEventBus.Emit(SubTurnOrphanResultEvent{ + ParentID: parentTS.turnID, + ChildID: childID, + Result: result, + }) + } + } + }() parentTS.mu.Lock() isFinished := parentTS.isFinished resultChan := parentTS.pendingResults @@ -363,8 +380,9 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too } // Parent Turn is still running → attempt to deliver result - // Note: There's still a small race window between the isFinished check above and the send below, - // but this is acceptable - worst case the result becomes an orphan, which is handled gracefully. + // We use a select statement with parentTS.Finished() to ensure that if the + // parent turn finishes while we are waiting to send the result (e.g. channel + // is full), we don't leak this goroutine by blocking forever. 
select { case resultChan <- result: // Successfully delivered @@ -373,9 +391,10 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too ChildID: childID, Result: result, }) - default: - // Channel is full - treat as orphan result - logger.WarnCF("subturn", "pendingResults channel full", map[string]any{ + case <-parentTS.Finished(): + // Parent finished while we were waiting to deliver. + // The result cannot be delivered to the LLM, so it becomes an orphan. + logger.WarnCF("subturn", "parent finished before result could be delivered", map[string]any{ "parent_id": parentTS.turnID, "child_id": childID, }) @@ -474,6 +493,7 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi truncationRetryCount := 0 contextRetryCount := 0 currentPrompt := cfg.SystemPrompt + promptAlreadyAdded := false for { // Soft context limit: check and truncate before LLM call @@ -512,9 +532,13 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi DefaultResponse: "", EnableSummary: false, SendResponse: false, - SkipAddUserMessage: contextRetryCount > 0, + SkipAddUserMessage: promptAlreadyAdded, }) + // Mark the prompt as added so subsequent truncation retries + // won't duplicate it in the history. + promptAlreadyAdded = true + // 1. Handle context length errors if err != nil && isContextLengthError(err) { if contextRetryCount >= maxContextRetries { @@ -562,6 +586,7 @@ func runTurn(ctx context.Context, al *AgentLoop, ts *turnState, cfg SubTurnConfi // Inject recovery prompt - it will be added by runAgentLoop on next iteration recoveryPrompt := "Your previous response was truncated due to length. Please provide a shorter, complete response that finishes your thought." 
currentPrompt = recoveryPrompt + promptAlreadyAdded = false // We need this new recovery prompt to be added truncationRetryCount++ continue // Retry with recovery prompt diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index 89e6a993e..8e7b3f533 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -632,11 +632,12 @@ func TestDeliverSubTurnResultNoDeadlock(t *testing.T) { } // Concurrently read from the channel to prevent blocking + // and to actually retrieve the matched number of results go func() { for i := 0; i < numChildren; i++ { select { case <-parent.pendingResults: - case <-time.After(2 * time.Second): + case <-time.After(5 * time.Second): t.Error("timeout waiting for result") return } @@ -714,48 +715,48 @@ func TestHardAbortOrderOfOperations(t *testing.T) { } } -// TestFinishClosesChannel verifies that Finish() closes the pendingResults channel -// and that deliverSubTurnResult handles closed channels gracefully. -func TestFinishClosesChannel(t *testing.T) { +// TestFinishedChannelClosedState verifies that Finish() closes the Finished() channel +// so that child turns can safely abort waiting. 
+func TestFinishedChannelClosedState(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ts := &turnState{ ctx: ctx, cancelFunc: cancel, - turnID: "test-finish-channel", + turnID: "test-finished-channel", depth: 0, pendingResults: make(chan *tools.ToolResult, 2), isFinished: false, } - // Verify channel is open initially + // Verify Finished channel is blocking initially select { - case ts.pendingResults <- &tools.ToolResult{ForLLM: "test"}: - // Good - channel is open - // Drain the message we just sent - <-ts.pendingResults + case <-ts.Finished(): + t.Fatal("finished channel should block initially") default: - t.Fatal("channel should be open initially") + // Good } // Call Finish() with graceful finish ts.Finish(false) - // Verify channel is closed - _, ok := <-ts.pendingResults - if ok { - t.Error("expected channel to be closed after Finish()") + // Verify Finished channel is closed + select { + case _, ok := <-ts.Finished(): + if ok { + t.Error("expected Finished() channel to be closed after Finish()") + } + default: + t.Fatal("expected <-ts.Finished() to not block") } - // Verify Finish() is idempotent (can be called multiple times) + // Verify Finish() is idempotent ts.Finish(false) // Should not panic - // Verify deliverSubTurnResult doesn't panic when sending to closed channel + // Verify deliverSubTurnResult correctly uses Finished() channel and treats as orphan result := &tools.ToolResult{ForLLM: "late result"} - - // This should not panic - it should recover and emit OrphanResultEvent - deliverSubTurnResult(ts, "child-1", result) + deliverSubTurnResult(ts, "child-1", result) // Will emit orphan due to <-ts.Finished() case } // TestFinalPollCapturesLateResults verifies that the final poll before Finish() @@ -1159,14 +1160,14 @@ func TestFinish_ConcurrentCalls(t *testing.T) { wg.Wait() - // Verify the channel is closed + // Verify the Finished() channel is closed select { - case _, ok := <-parentTS.pendingResults: + case 
_, ok := <-parentTS.Finished(): if ok { - t.Error("Expected channel to be closed") + t.Error("Expected Finished() channel to be closed") } default: - t.Error("Expected channel to be closed and readable") + t.Error("Expected Finished() channel to be closed and readable without blocking") } // Verify isFinished is set @@ -1413,73 +1414,7 @@ func TestContextWrapping_SingleLayer(t *testing.T) { t.Log("Context wrapping test passed - no redundant layers detected") } -// TestFinish_DrainsChannel verifies that Finish() drains remaining results -// from the pendingResults channel and emits them as orphan events. -func TestFinish_DrainsChannel(t *testing.T) { - // Save original MockEventBus.Emit - originalEmit := MockEventBus.Emit - defer func() { - MockEventBus.Emit = originalEmit - }() - // Collect orphan events - var mu sync.Mutex - var orphanEvents []SubTurnOrphanResultEvent - MockEventBus.Emit = func(e any) { - mu.Lock() - defer mu.Unlock() - if orphan, ok := e.(SubTurnOrphanResultEvent); ok { - orphanEvents = append(orphanEvents, orphan) - } - } - - ctx := context.Background() - parentTS := &turnState{ - ctx: ctx, - turnID: "parent-drain-test", - depth: 0, - pendingResults: make(chan *tools.ToolResult, 16), - concurrencySem: make(chan struct{}, maxConcurrentSubTurns), - } - parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) - - // Add some results to the channel before calling Finish() - const numResults = 5 - for i := 0; i < numResults; i++ { - parentTS.pendingResults <- &tools.ToolResult{ - ForLLM: fmt.Sprintf("result-%d", i), - } - } - - // Verify results are in the channel - if len(parentTS.pendingResults) != numResults { - t.Errorf("Expected %d results in channel, got %d", numResults, len(parentTS.pendingResults)) - } - - // Call Finish() - it should drain the channel - parentTS.Finish(false) - - // Verify all results were drained and emitted as orphan events - mu.Lock() - drainedCount := len(orphanEvents) - mu.Unlock() - - if drainedCount != numResults { 
- t.Errorf("Expected %d orphan events from drain, got %d", numResults, drainedCount) - } - - // Verify the channel is closed and empty - select { - case _, ok := <-parentTS.pendingResults: - if ok { - t.Error("Expected channel to be closed") - } - default: - t.Error("Expected channel to be closed and readable") - } - - t.Logf("Successfully drained %d results from channel", drainedCount) -} // TestSyncSubTurn_NoChannelDelivery verifies that synchronous sub-turns // do NOT deliver results to the pendingResults channel (only return directly). @@ -1591,72 +1526,7 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) { } } -// TestChannelFull_OrphanResults verifies behavior when the pendingResults channel -// is full (16+ async results). Results that cannot be delivered should become orphans. -func TestChannelFull_OrphanResults(t *testing.T) { - // Save original MockEventBus.Emit - originalEmit := MockEventBus.Emit - defer func() { - MockEventBus.Emit = originalEmit - }() - // Collect events - var mu sync.Mutex - var deliveredCount, orphanCount int - MockEventBus.Emit = func(e any) { - mu.Lock() - defer mu.Unlock() - switch e.(type) { - case SubTurnResultDeliveredEvent: - deliveredCount++ - case SubTurnOrphanResultEvent: - orphanCount++ - } - } - - ctx := context.Background() - parentTS := &turnState{ - ctx: ctx, - turnID: "parent-full-channel", - depth: 0, - pendingResults: make(chan *tools.ToolResult, 16), - concurrencySem: make(chan struct{}, maxConcurrentSubTurns), - } - parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) - defer parentTS.Finish(false) - - // Send more results than the channel capacity (16) - const numResults = 25 - for i := 0; i < numResults; i++ { - result := &tools.ToolResult{ - ForLLM: fmt.Sprintf("result-%d", i), - } - deliverSubTurnResult(parentTS, fmt.Sprintf("child-%d", i), result) - } - - // Get final counts - mu.Lock() - finalDelivered := deliveredCount - finalOrphan := orphanCount - mu.Unlock() - - t.Logf("Delivered: %d, Orphan: 
%d, Total: %d", finalDelivered, finalOrphan, finalDelivered+finalOrphan) - - // Should have delivered exactly 16 (channel capacity) - if finalDelivered != 16 { - t.Errorf("Expected 16 delivered results (channel capacity), got %d", finalDelivered) - } - - // Should have 9 orphan results (25 - 16) - if finalOrphan != 9 { - t.Errorf("Expected 9 orphan results, got %d", finalOrphan) - } - - // Total should equal numResults - if finalDelivered+finalOrphan != numResults { - t.Errorf("Expected %d total events, got %d", numResults, finalDelivered+finalOrphan) - } -} // TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn // is hard aborted, the cancellation cascades down to grandchild turns. @@ -1720,7 +1590,7 @@ func TestGrandchildAbort_CascadingCancellation(t *testing.T) { } // Hard abort the grandparent - grandparentTS.Finish(false) + grandparentTS.Finish(true) // Wait a bit for cancellation to propagate time.Sleep(10 * time.Millisecond) diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index e4bca4f15..62c3cf69b 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -45,6 +45,7 @@ type turnState struct { isFinished bool // MUST be accessed under mu lock closeOnce sync.Once // Ensures pendingResults channel is closed exactly once concurrencySem chan struct{} // Limits concurrent child sub-turns + finishedChan chan struct{} // Lazily initialized, closed when turn finishes // parentEnded signals that the parent turn has finished gracefully. // Child SubTurns should check this via IsParentEnded() to decide whether @@ -158,6 +159,21 @@ func (ts *turnState) GetLastFinishReason() string { // IsParentEnded is a convenience method to check if parent ended. // It returns the value of the parent's parentEnded atomic flag. +// Finished returns a channel that is closed when the turn finishes. +// This allows child turns to safely block on delivering results without leaking +// if the parent finishes before they can deliver. 
+func (ts *turnState) Finished() <-chan struct{} { + ts.mu.Lock() + defer ts.mu.Unlock() + if ts.finishedChan == nil { + ts.finishedChan = make(chan struct{}) + if ts.isFinished { + close(ts.finishedChan) + } + } + return ts.finishedChan +} + // Finish marks the turn as finished. // // If isHardAbort is true (Hard Abort): @@ -170,12 +186,20 @@ func (ts *turnState) GetLastFinishReason() string { // - Critical SubTurns continue running and deliver orphan results // - Non-Critical SubTurns exit gracefully without error // -// In both cases, the pendingResults channel is closed to signal -// that no more results will be delivered. +// In both cases, the pendingResults channel is NOT closed. +// It is left open to be garbage collected when no longer used, avoiding +// "send on closed channel" panics from concurrently finishing async subturns. func (ts *turnState) Finish(isHardAbort bool) { + var fc chan struct{} + ts.mu.Lock() - ts.isFinished = true - resultChan := ts.pendingResults + if !ts.isFinished { + ts.isFinished = true + if ts.finishedChan == nil { + ts.finishedChan = make(chan struct{}) + } + fc = ts.finishedChan + } ts.mu.Unlock() if isHardAbort { @@ -188,30 +212,15 @@ func (ts *turnState) Finish(isHardAbort bool) { ts.parentEnded.Store(true) } - // Use sync.Once to ensure the channel is closed exactly once, even if Finish() is called concurrently. - // This prevents "close of closed channel" panics. - ts.closeOnce.Do(func() { - if resultChan != nil { - close(resultChan) - // Drain any remaining results from the channel and emit them as orphan events. - // This prevents goroutine leaks and ensures all results are accounted for. - ts.drainPendingResults(resultChan) - } - }) -} - -// drainPendingResults drains all remaining results from the closed channel -// and emits them as orphan events. This must be called after the channel is closed. 
-func (ts *turnState) drainPendingResults(ch chan *tools.ToolResult) { - for result := range ch { - if result != nil { - MockEventBus.Emit(SubTurnOrphanResultEvent{ - ParentID: ts.turnID, - ChildID: "unknown", // We don't know which child this came from - Result: result, - }) - } + // Safely close the finishedChan exactly once + if fc != nil { + ts.closeOnce.Do(func() { + close(fc) + }) } + + // We no longer close(ts.pendingResults) here to avoid panicking any + // concurrent deliverSubTurnResult calls. We rely on GC to clean up the channel. } // ====================== Ephemeral Session Store ======================