Files
picoclaw/pkg/agent/subturn_test.go
T
Administrator e20ff43f8b fix(agent): resolve subturn deadlocks, panics and context retry state
This commit addresses several critical concurrency and state management bugs within the SubTurn execution and delivery logic.

1. Fix Goroutine Leak & Deadlock in deliverSubTurnResult:
   - Replaced non-blocking select with a safe blocking select that listens to `resultChan` and a new `<-parentTS.Finished()` channel.
   - This ensures results are not arbitrarily dropped when the channel is full (preventing orphaned valid results), while also guaranteeing the child goroutine safely unblocks and exits if the parent finishes execution early.

2. Prevent "Send on Closed Channel" Fatal Panics:
   - Removed `close(pendingResults)` and `drainPendingResults` from `turnState.Finish()`.
   - The pendingResults channel is now naturally garbage collected, completely eliminating the race condition panic when a child attempts delivery at the exact moment the parent finishes.
   - Added a `defer recover()` failsafe inside deliverSubTurnResult to gracefully emit Orphan events in extreme edge cases.

3. Fix Truncation Recovery Prompt Drop:
   - Fixed the runTurn truncation retry logic by introducing an explicit `promptAlreadyAdded` boolean.
   - Ensures that the dynamically generated `recoveryPrompt` is correctly injected into the LLM history sequence on subsequent iterations, adhering to API roles without duplicating arrays.

4. Test Suite Stabilization:
   - Fixed TestDeliverSubTurnResultNoDeadlock to accurately wait for deterministic deliveries instead of racing timeouts.
   - Replaced defunct closed-channel tests with TestFinishedChannelClosedState matching the new Finished() mechanism.
   - Fixed the Finish(true) parameter in TestGrandchildAbort_CascadingCancellation to correctly validate Context cascade behavior.
   - All tests now pass cleanly without hanging or emitting false positives.
2026-03-18 13:10:36 +08:00

2037 lines
57 KiB
Go

package agent
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// ====================== Test Helper: Event Collector ======================

// eventCollector records every event handed to collect, in arrival order, so
// a test can later assert on which kinds of events were emitted.
// It does no locking of its own.
type eventCollector struct {
	events []any
}

// collect appends e to the list of recorded events.
func (c *eventCollector) collect(e any) {
	c.events = append(c.events, e)
}

// hasEventOfType reports whether at least one recorded event has exactly the
// same dynamic type as typ. Only the type of typ matters; its value is ignored.
func (c *eventCollector) hasEventOfType(typ any) bool {
	want := reflect.TypeOf(typ)
	for i := range c.events {
		if reflect.TypeOf(c.events[i]) != want {
			continue
		}
		return true
	}
	return false
}

// countOfType returns how many recorded events have exactly the same dynamic
// type as typ. Only the type of typ matters; its value is ignored.
func (c *eventCollector) countOfType(typ any) int {
	want := reflect.TypeOf(typ)
	matches := 0
	for i := range c.events {
		if reflect.TypeOf(c.events[i]) == want {
			matches++
		}
	}
	return matches
}
// ====================== Main Test Function ======================

// TestSpawnSubTurn is a table-driven check of spawnSubTurn. It covers the
// success path at parent depths 0 and 1, the rejection at parent depth 3
// (ErrDepthLimitExceeded), and the rejection of an empty Model
// (ErrInvalidSubTurnConfig). On success it asserts that a result is returned,
// that spawn/end events were emitted, and that the child was recorded in the
// parent's childTurnIDs.
func TestSpawnSubTurn(t *testing.T) {
	tests := []struct {
		name          string
		parentDepth   int
		config        SubTurnConfig
		wantErr       error
		wantSpawn     bool
		wantEnd       bool
		wantDepthFail bool
	}{
		{
			name:        "Basic success path - Single layer sub-turn",
			parentDepth: 0,
			config: SubTurnConfig{
				Model: "gpt-4o-mini",
				Tools: []tools.Tool{}, // empty (but non-nil) tool list
			},
			wantErr:   nil,
			wantSpawn: true,
			wantEnd:   true,
		},
		{
			name:        "Nested 2 layers - Normal",
			parentDepth: 1,
			config: SubTurnConfig{
				Model: "gpt-4o-mini",
				Tools: []tools.Tool{},
			},
			wantErr:   nil,
			wantSpawn: true,
			wantEnd:   true,
		},
		{
			name:        "Depth limit triggered - 4th layer fails",
			parentDepth: 3,
			config: SubTurnConfig{
				Model: "gpt-4o-mini",
				Tools: []tools.Tool{},
			},
			wantErr:       ErrDepthLimitExceeded,
			wantSpawn:     false,
			wantEnd:       false,
			wantDepthFail: true,
		},
		{
			name:        "Invalid config - Empty Model",
			parentDepth: 0,
			config: SubTurnConfig{
				Model: "",
				Tools: []tools.Tool{},
			},
			wantErr:   ErrInvalidSubTurnConfig,
			wantSpawn: false,
			wantEnd:   false,
		},
	}

	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Prepare parent Turn
			parent := &turnState{
				ctx:            context.Background(),
				turnID:         "parent-1",
				depth:          tt.parentDepth,
				childTurnIDs:   []string{},
				pendingResults: make(chan *tools.ToolResult, 10),
				session:        &ephemeralSessionStore{},
			}

			// Route emitted events into a collector for the duration of this
			// subtest, restoring the previous hook afterwards.
			collector := &eventCollector{}
			originalEmit := MockEventBus.Emit
			MockEventBus.Emit = collector.collect
			defer func() { MockEventBus.Emit = originalEmit }()

			// Execute spawnSubTurn
			result, err := spawnSubTurn(context.Background(), al, parent, tt.config)

			// Error cases: match the sentinel with errors.Is so the check keeps
			// working if the error is ever wrapped with extra context.
			if tt.wantErr != nil {
				if !errors.Is(err, tt.wantErr) {
					t.Errorf("expected error %v, got %v", tt.wantErr, err)
				}
				return
			}
			if err != nil {
				t.Errorf("unexpected error: %v", err)
				return
			}

			// Verify result
			if result == nil {
				t.Error("expected non-nil result")
			}

			// Verify event emission
			if tt.wantSpawn && !collector.hasEventOfType(SubTurnSpawnEvent{}) {
				t.Error("SubTurnSpawnEvent not emitted")
			}
			if tt.wantEnd && !collector.hasEventOfType(SubTurnEndEvent{}) {
				t.Error("SubTurnEndEvent not emitted")
			}

			// Verify turn tree
			if len(parent.childTurnIDs) == 0 && !tt.wantDepthFail {
				t.Error("child Turn not added to parent.childTurnIDs")
			}

			// For synchronous calls (Async=false, the default), result is returned directly
			// and should NOT be in pendingResults. The result was already verified above.
			// Only async calls (Async=true) would place results in pendingResults.
		})
	}
}
// ====================== Extra Independent Test: Ephemeral Session Isolation ======================

// TestSpawnSubTurn_EphemeralSessionIsolation seeds the parent session with one
// message, runs a sub-turn, and asserts that the parent's history length did
// not change as a result.
func TestSpawnSubTurn_EphemeralSessionIsolation(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	store := &ephemeralSessionStore{}
	store.AddMessage("", "user", "parent msg")

	parent := &turnState{
		ctx:            context.Background(),
		turnID:         "parent-1",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 1),
		session:        store,
	}

	// Snapshot the parent history size, then run the sub-turn. The outcome of
	// the sub-turn itself is irrelevant here; only the parent history matters.
	before := len(parent.session.GetHistory(""))
	_, _ = spawnSubTurn(
		context.Background(),
		al,
		parent,
		SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}},
	)

	if after := len(parent.session.GetHistory("")); after != before {
		t.Error("ephemeral session polluted the main session")
	}
}
// ====================== Extra Independent Test: Result Delivery Path (Async) ======================

// TestSpawnSubTurn_ResultDelivery spawns a sub-turn with Async set and expects
// a non-nil result to already be waiting in the parent's pendingResults
// channel when spawnSubTurn returns.
func TestSpawnSubTurn_ResultDelivery(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	parent := &turnState{
		ctx:            context.Background(),
		turnID:         "parent-1",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 1),
		session:        &ephemeralSessionStore{},
	}

	// Async=true is what routes the result through pendingResults.
	asyncCfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: true}
	_, _ = spawnSubTurn(context.Background(), al, parent, asyncCfg)

	// Non-blocking receive: the result must already be there.
	var delivered *tools.ToolResult
	select {
	case delivered = <-parent.pendingResults:
	default:
		t.Error("result did not enter pendingResults for async call")
		return
	}
	if delivered == nil {
		t.Error("received nil result in pendingResults")
	}
}
// ====================== Extra Independent Test: Result Delivery Path (Sync) ======================

// TestSpawnSubTurn_ResultDeliverySync spawns a sub-turn with Async unset and
// expects the result to come back as the return value only: pendingResults
// must stay empty so the result is not delivered twice.
func TestSpawnSubTurn_ResultDeliverySync(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	parent := &turnState{
		ctx:            context.Background(),
		turnID:         "parent-sync-1",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 1),
		session:        &ephemeralSessionStore{},
	}

	syncCfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: false}
	got, err := spawnSubTurn(context.Background(), al, parent, syncCfg)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// The direct return value carries the result.
	if got == nil {
		t.Error("expected non-nil result from sync call")
	}

	// Non-blocking receive: anything found here would be a duplicate delivery.
	select {
	case <-parent.pendingResults:
		t.Error("sync call should not place result in pendingResults (double delivery)")
	default:
		// Channel is empty, as expected.
	}
}
// ====================== Extra Independent Test: Orphan Result Routing ======================

// TestSpawnSubTurn_OrphanResultRouting finishes the parent first and then
// delivers a child result directly. It expects a SubTurnOrphanResultEvent to
// be emitted and the parent's session history to remain empty.
func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	parent := &turnState{
		ctx:            ctx,
		cancelFunc:     cancel,
		turnID:         "parent-1",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 1),
		session:        &ephemeralSessionStore{},
	}

	// Capture emitted events for the duration of the test.
	events := &eventCollector{}
	previousEmit := MockEventBus.Emit
	MockEventBus.Emit = events.collect
	defer func() { MockEventBus.Emit = previousEmit }()

	// The parent is done before the child gets a chance to report.
	parent.Finish(false)

	// Simulate the late child by calling the delivery function directly.
	late := &tools.ToolResult{ForLLM: "late result"}
	deliverSubTurnResult(parent, "delayed-child", late)

	if !events.hasEventOfType(SubTurnOrphanResultEvent{}) {
		t.Error("SubTurnOrphanResultEvent not emitted for finished parent")
	}
	if historyLen := len(parent.session.GetHistory("")); historyLen != 0 {
		t.Error("Parent history was polluted by orphan result")
	}
}
// ====================== Extra Independent Test: Result Channel Registration ======================

// TestSubTurnResultChannelRegistration checks that no result channel is
// visible for the parent's turn ID before a sub-turn is spawned, and that none
// is left in al.subTurnResults once spawnSubTurn has returned.
func TestSubTurnResultChannelRegistration(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	parent := &turnState{
		ctx:            context.Background(),
		turnID:         "parent-reg-1",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 4),
		session:        &ephemeralSessionStore{},
	}

	// Before spawn: dequeue must report nothing for this turn ID.
	before := al.dequeuePendingSubTurnResults(parent.turnID)
	if before != nil {
		t.Error("expected no channel before spawnSubTurn")
	}

	_, _ = spawnSubTurn(
		context.Background(),
		al,
		parent,
		SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}},
	)

	// After spawn: no entry may remain under the parent's turn ID.
	_, stillRegistered := al.subTurnResults.Load(parent.turnID)
	if stillRegistered {
		t.Error("channel should be unregistered after spawnSubTurn completes")
	}
}
// ====================== Extra Independent Test: Dequeue Pending SubTurn Results ======================

// TestDequeuePendingSubTurnResults exercises dequeuePendingSubTurnResults on a
// manually registered channel: an empty channel yields no results, queued
// results come back in send order, a second call finds the channel drained,
// and an unregistered key yields nil.
func TestDequeuePendingSubTurnResults(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	sessionKey := "test-session-dequeue"
	ch := make(chan *tools.ToolResult, 4)

	// Register channel manually
	al.registerSubTurnResultChannel(sessionKey, ch)
	defer al.unregisterSubTurnResultChannel(sessionKey)

	// Empty channel returns no results
	if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 {
		t.Errorf("expected empty results, got %d", len(results))
	}

	// Put 3 results in
	ch <- &tools.ToolResult{ForLLM: "result-1"}
	ch <- &tools.ToolResult{ForLLM: "result-2"}
	ch <- &tools.ToolResult{ForLLM: "result-3"}

	results := al.dequeuePendingSubTurnResults(sessionKey)
	// Fatal rather than Error: the element checks below index results[0] and
	// results[2] and would panic on a shorter slice.
	if len(results) != 3 {
		t.Fatalf("expected 3 results, got %d", len(results))
	}
	if results[0].ForLLM != "result-1" || results[2].ForLLM != "result-3" {
		t.Error("results order or content mismatch")
	}

	// Channel should be drained now
	if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 {
		t.Errorf("expected empty after drain, got %d", len(results))
	}

	// Unregistered session returns nil
	al.unregisterSubTurnResultChannel(sessionKey)
	if results := al.dequeuePendingSubTurnResults(sessionKey); results != nil {
		t.Error("expected nil for unregistered session")
	}
}
// ====================== Extra Independent Test: Concurrency Semaphore ======================

// TestSubTurnConcurrencySemaphore runs two sub-turns concurrently against a
// parent whose concurrencySem has two slots and checks that both complete and
// that the semaphore still has capacity 2 afterwards. It does not attempt to
// observe blocking at the limit.
func TestSubTurnConcurrencySemaphore(t *testing.T) {
	const children = 2

	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	parent := &turnState{
		ctx:            context.Background(),
		turnID:         "parent-concurrency",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 10),
		session:        &ephemeralSessionStore{},
		concurrencySem: make(chan struct{}, children), // Only allow 2 concurrent children
	}
	cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}}

	// Start as many children as there are slots; none of them should be held
	// back by the semaphore.
	finished := make(chan bool, 3)
	for started := 0; started < children; started++ {
		go func() {
			_, _ = spawnSubTurn(context.Background(), al, parent, cfg)
			finished <- true
		}()
	}

	// Wait for every child to report completion.
	for seen := 0; seen < children; seen++ {
		<-finished
	}

	// The children have already finished by now, so contention at the limit
	// cannot be observed from here without a long-running operation. Settle for
	// confirming the semaphore kept its configured capacity.
	if got := cap(parent.concurrencySem); got != 2 {
		t.Errorf("expected semaphore capacity 2, got %d", got)
	}
}
// ====================== Extra Independent Test: Hard Abort Cascading ======================

// TestHardAbortCascading registers a root turn whose context is the parent of
// a child turn's context, calls HardAbort on the root's session key, and
// asserts that both contexts end up cancelled. It also checks that HardAbort
// returns an error for a session key that was never registered.
func TestHardAbortCascading(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	// canceled reports, without blocking, whether ctx is already done.
	canceled := func(ctx context.Context) bool {
		select {
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}

	sessionKey := "test-session-abort"
	parentCtx, parentCancel := context.WithCancel(context.Background())
	defer parentCancel()

	// The root carries its own cancel func so Finish() can trigger it.
	rootTS := &turnState{
		ctx:            parentCtx,
		cancelFunc:     parentCancel,
		turnID:         sessionKey,
		depth:          0,
		session:        &ephemeralSessionStore{},
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, 5),
	}

	// Register the root turn state
	al.activeTurnStates.Store(sessionKey, rootTS)
	defer al.activeTurnStates.Delete(sessionKey)

	// The child's context derives from the root's, which is what makes the
	// cancellation cascade.
	childCtx, childCancel := context.WithCancel(rootTS.ctx)
	defer childCancel()
	childTS := &turnState{
		ctx:            childCtx,
		cancelFunc:     childCancel,
		turnID:         "child-1",
		parentTurnID:   sessionKey,
		depth:          1,
		session:        &ephemeralSessionStore{},
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, 5),
	}

	// Neither context may be cancelled before the abort.
	if canceled(rootTS.ctx) {
		t.Error("root context should not be canceled yet")
	}
	if canceled(childTS.ctx) {
		t.Error("child context should not be canceled yet")
	}

	// Trigger Hard Abort
	if err := al.HardAbort(sessionKey); err != nil {
		t.Errorf("HardAbort failed: %v", err)
	}

	// Both contexts must now be cancelled.
	if !canceled(rootTS.ctx) {
		t.Error("root context should be canceled after HardAbort")
	}
	if !canceled(childTS.ctx) {
		t.Error("child context should be canceled after HardAbort (cascading)")
	}

	// An unknown session key must be rejected.
	if err := al.HardAbort("non-existent-session"); err == nil {
		t.Error("expected error for non-existent session")
	}
}
// TestHardAbortSessionRollback verifies that HardAbort rolls back session history
// to the state before the turn started, discarding all messages added during the turn.
// The turn state records initialHistoryLength = 2; two further messages are
// appended before the abort, and the test expects exactly the first two to
// survive.
func TestHardAbortSessionRollback(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	// Create a session with initial history
	sess := &ephemeralSessionStore{
		history: []providers.Message{
			{Role: "user", Content: "initial message 1"},
			{Role: "assistant", Content: "initial response 1"},
		},
	}

	// Create a root turnState with initialHistoryLength = 2
	rootTS := &turnState{
		ctx:                  context.Background(),
		turnID:               "test-session",
		depth:                0,
		session:              sess,
		initialHistoryLength: 2, // Snapshot: 2 messages
		pendingResults:       make(chan *tools.ToolResult, 16),
		concurrencySem:       make(chan struct{}, 5),
	}

	// Register the turn state
	al.activeTurnStates.Store("test-session", rootTS)

	// Simulate adding messages during the turn (e.g., user input + assistant response)
	sess.AddMessage("", "user", "new user message")
	sess.AddMessage("", "assistant", "new assistant response")

	// Verify history grew to 4 messages
	if got := len(sess.GetHistory("")); got != 4 {
		t.Fatalf("expected 4 messages before abort, got %d", got)
	}

	// Trigger HardAbort
	if err := al.HardAbort("test-session"); err != nil {
		t.Fatalf("HardAbort failed: %v", err)
	}

	// Verify history rolled back to initial 2 messages. Fatal rather than
	// Error: the content check below indexes finalHistory[0] and [1] and would
	// panic if the rollback left fewer entries.
	finalHistory := sess.GetHistory("")
	if len(finalHistory) != 2 {
		t.Fatalf("expected history to rollback to 2 messages, got %d", len(finalHistory))
	}

	// Verify the content matches the initial state
	if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" {
		t.Error("history content does not match initial state after rollback")
	}
}
// TestNestedSubTurnHierarchy spawns one child under a root turn and verifies
// the parent-child bookkeeping: exactly one SubTurnSpawnEvent is emitted, it
// names the root as parent, and the root's childTurnIDs holds one entry.
// Depth is not inspected by this test.
func TestNestedSubTurnHierarchy(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	// spawnRecord keeps the IDs carried by one SubTurnSpawnEvent.
	type spawnRecord struct {
		parentID string
		childID  string
	}
	var (
		mu      sync.Mutex
		spawned []spawnRecord
	)

	// Override MockEventBus to capture spawn events
	originalEmit := MockEventBus.Emit
	defer func() { MockEventBus.Emit = originalEmit }()
	MockEventBus.Emit = func(event any) {
		spawnEvent, ok := event.(SubTurnSpawnEvent)
		if !ok {
			return
		}
		mu.Lock()
		defer mu.Unlock()
		spawned = append(spawned, spawnRecord{
			parentID: spawnEvent.ParentID,
			childID:  spawnEvent.ChildID,
		})
	}

	// Create a root turn
	rootTS := &turnState{
		ctx:            context.Background(),
		turnID:         "root-turn",
		depth:          0,
		session:        &ephemeralSessionStore{},
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, 5),
	}

	// Spawn a child (depth 1)
	childCfg := SubTurnConfig{Model: "gpt-4o-mini"}
	if _, err := spawnSubTurn(context.Background(), al, rootTS, childCfg); err != nil {
		t.Fatalf("failed to spawn child: %v", err)
	}

	// Take a snapshot under the lock and assert on the copy, so a fatal
	// assertion never exits the test while the mutex is held.
	mu.Lock()
	captured := append([]spawnRecord(nil), spawned...)
	mu.Unlock()

	if len(captured) != 1 {
		t.Fatalf("expected 1 spawn event, got %d", len(captured))
	}
	if captured[0].parentID != "root-turn" {
		t.Errorf("expected parent ID 'root-turn', got %s", captured[0].parentID)
	}

	// Verify root turn has the child in its childTurnIDs
	rootTS.mu.Lock()
	childCount := len(rootTS.childTurnIDs)
	rootTS.mu.Unlock()
	if childCount != 1 {
		t.Errorf("expected root to have 1 child, got %d", childCount)
	}
}
// TestDeliverSubTurnResultNoDeadlock verifies that deliverSubTurnResult doesn't
// deadlock when multiple goroutines are accessing the parent turnState concurrently.
// Ten children deliver into a channel with only two buffer slots while a
// reader drains it; the test fails if the deliveries, or the drain of all ten
// results, do not complete within three seconds.
func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
	const numChildren = 10

	parent := &turnState{
		ctx:            context.Background(),
		turnID:         "parent-deadlock-test",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking
		isFinished:     false,
	}

	// Simulate multiple child turns delivering results concurrently
	var wg sync.WaitGroup
	for i := 0; i < numChildren; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)}
			deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result)
		}(i)
	}

	// Drain the channel concurrently so the senders can make progress. The
	// reader never touches t: reporting from a goroutine that may outlive the
	// test is not allowed. It signals success by closing drained and is
	// released via stop when the test returns.
	stop := make(chan struct{})
	defer close(stop)
	drained := make(chan struct{})
	go func() {
		for received := 0; received < numChildren; received++ {
			select {
			case <-parent.pendingResults:
			case <-stop:
				return
			}
		}
		close(drained)
	}()

	// Signal once every delivery call has returned.
	delivered := make(chan struct{})
	go func() {
		wg.Wait()
		close(delivered)
	}()

	// One deadline covers both phases.
	deadline := time.After(3 * time.Second)

	select {
	case <-delivered:
		// Success - no deadlock
	case <-deadline:
		t.Fatal("deadlock detected: deliverSubTurnResult blocked")
	}

	// Every delivered result must also have been read back.
	select {
	case <-drained:
	case <-deadline:
		t.Error("timeout waiting for result")
	}
}
// TestHardAbortOrderOfOperations checks the two observable effects of
// HardAbort on a registered root turn: the turn's context is cancelled, and
// the session history is cut back to initialHistoryLength (here 1 of 3
// messages). The test observes only the end state; it does not itself detect
// the order in which the two steps ran.
func TestHardAbortOrderOfOperations(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	sess := &ephemeralSessionStore{
		history: []providers.Message{
			{Role: "user", Content: "initial message"},
			{Role: "assistant", Content: "response 1"},
			{Role: "user", Content: "follow-up"},
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	rootTS := &turnState{
		ctx:                  ctx,
		cancelFunc:           cancel,
		turnID:               "test-session-order",
		depth:                0,
		session:              sess,
		initialHistoryLength: 1, // Snapshot: 1 message
		pendingResults:       make(chan *tools.ToolResult, 16),
		concurrencySem:       make(chan struct{}, 5),
	}
	al.activeTurnStates.Store("test-session-order", rootTS)

	// Trigger HardAbort
	if err := al.HardAbort("test-session-order"); err != nil {
		t.Fatalf("HardAbort failed: %v", err)
	}

	// Verify context was cancelled
	select {
	case <-rootTS.ctx.Done():
		// Good - context was cancelled
	default:
		t.Error("expected context to be cancelled after HardAbort")
	}

	// Verify history was rolled back. Fatal rather than Error: the content
	// check below indexes finalHistory[0] and would panic on an empty history.
	finalHistory := sess.GetHistory("")
	if len(finalHistory) != 1 {
		t.Fatalf("expected history to rollback to 1 message, got %d", len(finalHistory))
	}
	if finalHistory[0].Content != "initial message" {
		t.Error("history content does not match initial state after rollback")
	}
}
// TestFinishedChannelClosedState checks the lifecycle of the channel returned
// by Finished(): a receive blocks before Finish() is called and yields the
// closed-channel zero value afterwards. It also calls Finish() a second time
// and delivers a late result, both of which must complete without panicking.
func TestFinishedChannelClosedState(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	ts := &turnState{
		ctx:            ctx,
		cancelFunc:     cancel,
		turnID:         "test-finished-channel",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 2),
		isFinished:     false,
	}

	// Before Finish(): a receive must not be ready.
	select {
	case <-ts.Finished():
		t.Fatal("finished channel should block initially")
	default:
		// Still open, as expected.
	}

	// Graceful finish.
	ts.Finish(false)

	// After Finish(): the receive must be ready and report a closed channel.
	select {
	case _, open := <-ts.Finished():
		if open {
			t.Error("expected Finished() channel to be closed after Finish()")
		}
	default:
		t.Fatal("expected <-ts.Finished() to not block")
	}

	// A repeated Finish() must be harmless.
	ts.Finish(false) // Should not panic

	// Delivering to the finished turn must return rather than block; no
	// further assertion is made on what it emits.
	late := &tools.ToolResult{ForLLM: "late result"}
	deliverSubTurnResult(ts, "child-1", late)
}
// TestFinalPollCapturesLateResults queues two results on a registered channel
// and checks that one dequeuePendingSubTurnResults call returns both, and that
// a second call then returns none.
func TestFinalPollCapturesLateResults(t *testing.T) {
	al, _, _, _, cleanup := newTestAgentLoop(t)
	defer cleanup()

	const sessionKey = "test-session-final-poll"
	pending := make(chan *tools.ToolResult, 4)

	// Make the channel visible to the dequeue call.
	al.registerSubTurnResultChannel(sessionKey, pending)
	defer al.unregisterSubTurnResultChannel(sessionKey)

	// Results that arrive after the last iteration poll.
	for _, text := range []string{"result 1", "result 2"} {
		pending <- &tools.ToolResult{ForLLM: text}
	}

	// First poll: both results are picked up.
	if first := al.dequeuePendingSubTurnResults(sessionKey); len(first) != 2 {
		t.Errorf("expected 2 results, got %d", len(first))
	}

	// Second poll: nothing is left.
	if second := al.dequeuePendingSubTurnResults(sessionKey); len(second) != 0 {
		t.Errorf("expected 0 results on second poll, got %d", len(second))
	}
}
// TestSpawnSubTurn_PanicRecovery verifies that even if runTurn panics,
// the result is still delivered for async calls and SubTurnEndEvent is emitted.
func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
// Create a panic provider
panicProvider := &panicMockProvider{}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
al := NewAgentLoop(cfg, bus.NewMessageBus(), panicProvider)
parent := &turnState{
ctx: context.Background(),
turnID: "parent-panic",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 1),
session: &ephemeralSessionStore{},
}
collector := &eventCollector{}
originalEmit := MockEventBus.Emit
MockEventBus.Emit = collector.collect
defer func() { MockEventBus.Emit = originalEmit }()
// Test async call - result should still be delivered via channel
asyncCfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: true}
result, err := spawnSubTurn(context.Background(), al, parent, asyncCfg)
// Should return error from panic recovery
if err == nil {
t.Error("expected error from panic recovery")
}
// Result should be nil because panic occurred before runTurn could return
if result != nil {
t.Error("expected nil result after panic")
}
// SubTurnEndEvent should still be emitted
if !collector.hasEventOfType(SubTurnEndEvent{}) {
t.Error("SubTurnEndEvent not emitted after panic")
}
// For async call, result should still be delivered to channel (even if nil)
select {
case res := <-parent.pendingResults:
// Result was delivered (nil due to panic)
_ = res
default:
t.Error("async result should be delivered to channel even after panic")
}
}
// panicMockProvider is a provider stub whose Chat method panics on every call.
type panicMockProvider struct{}

// Chat ignores all of its arguments and panics unconditionally; it never
// returns.
func (*panicMockProvider) Chat(
	_ context.Context,
	_ []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	panic("intentional panic for testing")
}

// GetDefaultModel returns the fixed model name "panic-model".
func (*panicMockProvider) GetDefaultModel() string {
	const model = "panic-model"
	return model
}
// ====================== Public API Tests ======================

// simpleMockProviderAPI is a provider stub for the public API tests. Every
// Chat call answers with the configured response text.
type simpleMockProviderAPI struct {
	response string
}

// Chat ignores all of its arguments and returns a response whose Content is
// m.response, with a nil error.
func (m *simpleMockProviderAPI) Chat(
	_ context.Context,
	_ []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	reply := &providers.LLMResponse{Content: m.response}
	return reply, nil
}

// GetDefaultModel returns the fixed model name "gpt-4o-mini".
func (m *simpleMockProviderAPI) GetDefaultModel() string {
	const model = "gpt-4o-mini"
	return model
}
// TestGetActiveTurn verifies that GetActiveTurn returns correct turn information
func TestGetActiveTurn(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "gpt-4o-mini",
Provider: "mock",
},
},
}
al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"})
// Create a root turn state
rootCtx := context.Background()
rootTS := &turnState{
ctx: rootCtx,
turnID: "root-turn",
parentTurnID: "",
depth: 0,
childTurnIDs: []string{},
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
sessionKey := "test-session"
al.activeTurnStates.Store(sessionKey, rootTS)
defer al.activeTurnStates.Delete(sessionKey)
// Test: GetActiveTurn should return turn info
info := al.GetActiveTurn(sessionKey)
if info == nil {
t.Fatal("GetActiveTurn returned nil for active session")
}
if info.TurnID != "root-turn" {
t.Errorf("Expected TurnID 'root-turn', got %q", info.TurnID)
}
if info.Depth != 0 {
t.Errorf("Expected Depth 0, got %d", info.Depth)
}
if info.ParentTurnID != "" {
t.Errorf("Expected empty ParentTurnID, got %q", info.ParentTurnID)
}
if len(info.ChildTurnIDs) != 0 {
t.Errorf("Expected 0 child turns, got %d", len(info.ChildTurnIDs))
}
// Test: GetActiveTurn should return nil for non-existent session
nonExistentInfo := al.GetActiveTurn("non-existent-session")
if nonExistentInfo != nil {
t.Error("GetActiveTurn should return nil for non-existent session")
}
}
// TestGetActiveTurn_WithChildren verifies that child turn IDs are correctly reported
func TestGetActiveTurn_WithChildren(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "gpt-4o-mini",
Provider: "mock",
},
},
}
al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"})
rootCtx := context.Background()
rootTS := &turnState{
ctx: rootCtx,
turnID: "root-turn",
parentTurnID: "",
depth: 0,
childTurnIDs: []string{"child-1", "child-2"},
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
sessionKey := "test-session-with-children"
al.activeTurnStates.Store(sessionKey, rootTS)
defer al.activeTurnStates.Delete(sessionKey)
info := al.GetActiveTurn(sessionKey)
if info == nil {
t.Fatal("GetActiveTurn returned nil")
}
if len(info.ChildTurnIDs) != 2 {
t.Fatalf("Expected 2 child turns, got %d", len(info.ChildTurnIDs))
}
if info.ChildTurnIDs[0] != "child-1" || info.ChildTurnIDs[1] != "child-2" {
t.Errorf("Child turn IDs mismatch: got %v", info.ChildTurnIDs)
}
}
// TestTurnStateInfo_ThreadSafety calls Info() 100 times from one goroutine
// while another appends to childTurnIDs 100 times under ts.mu, and checks that
// Info() never returns nil. Run under the race detector it also surfaces
// unsynchronized access.
func TestTurnStateInfo_ThreadSafety(t *testing.T) {
	const rounds = 100

	ts := &turnState{
		ctx:            context.Background(),
		turnID:         "test-turn",
		parentTurnID:   "parent",
		depth:          1,
		childTurnIDs:   []string{},
		session:        newEphemeralSession(nil),
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}

	var wg sync.WaitGroup
	wg.Add(2)

	// Writer: grows childTurnIDs while holding the turn's mutex.
	go func() {
		defer wg.Done()
		for n := 0; n < rounds; n++ {
			ts.mu.Lock()
			ts.childTurnIDs = append(ts.childTurnIDs, "child")
			ts.mu.Unlock()
		}
	}()

	// Reader: takes snapshots through Info() at the same time.
	go func() {
		defer wg.Done()
		for n := 0; n < rounds; n++ {
			if ts.Info() == nil {
				t.Error("Info() returned nil")
			}
		}
	}()

	// Both goroutines must be finished before the test returns.
	wg.Wait()
}
// TestInjectFollowUp verifies that InjectFollowUp enqueues messages
func TestInjectFollowUp(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "gpt-4o-mini",
Provider: "mock",
},
},
}
al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"})
msg := providers.Message{
Role: "user",
Content: "Follow-up task",
}
err := al.InjectFollowUp(msg)
if err != nil {
t.Fatalf("InjectFollowUp failed: %v", err)
}
// Verify message was enqueued
if al.steering.len() != 1 {
t.Errorf("Expected 1 message in queue, got %d", al.steering.len())
}
}
// TestAPIAliases verifies that API aliases work correctly
func TestAPIAliases(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "gpt-4o-mini",
Provider: "mock",
},
},
}
al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"})
msg := providers.Message{
Role: "user",
Content: "Test message",
}
// Test InterruptGraceful (alias for Steer)
err := al.InterruptGraceful(msg)
if err != nil {
t.Errorf("InterruptGraceful failed: %v", err)
}
// Test InjectSteering (alias for Steer)
err = al.InjectSteering(msg)
if err != nil {
t.Errorf("InjectSteering failed: %v", err)
}
// Verify both messages were enqueued
if al.steering.len() != 2 {
t.Errorf("Expected 2 messages in queue, got %d", al.steering.len())
}
}
// TestInterruptHard_Alias verifies that InterruptHard is an alias for HardAbort
func TestInterruptHard_Alias(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Model: "gpt-4o-mini",
Provider: "mock",
},
},
}
al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"})
rootCtx := context.Background()
rootTS := &turnState{
ctx: rootCtx,
turnID: "test-turn",
depth: 0,
session: newEphemeralSession(nil),
initialHistoryLength: 0,
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
sessionKey := "test-session-interrupt"
al.activeTurnStates.Store(sessionKey, rootTS)
// Test InterruptHard (alias for HardAbort)
err := al.InterruptHard(sessionKey)
if err != nil {
t.Errorf("InterruptHard failed: %v", err)
}
// Verify turn was finished
info := al.GetActiveTurn(sessionKey)
if info != nil && !info.IsFinished {
t.Error("Turn should be finished after InterruptHard")
}
}
// TestFinish_ConcurrentCalls invokes Finish(false) from ten goroutines at once
// and checks that none of them panics, that the Finished() channel ends up
// closed, and that isFinished is set.
func TestFinish_ConcurrentCalls(t *testing.T) {
	const numGoroutines = 10

	ctx, cancel := context.WithCancel(context.Background())
	parentTS := &turnState{
		ctx:            ctx,
		cancelFunc:     cancel,
		turnID:         "parent-concurrent-finish",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}

	// Race several Finish() calls against each other.
	var wg sync.WaitGroup
	for g := 0; g < numGoroutines; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Must not panic, no matter how many callers get here at once.
			parentTS.Finish(false)
		}()
	}
	wg.Wait()

	// A receive on the Finished() channel must be ready and report "closed".
	select {
	case _, open := <-parentTS.Finished():
		if open {
			t.Error("Expected Finished() channel to be closed")
		}
	default:
		t.Error("Expected Finished() channel to be closed and readable without blocking")
	}

	// Read the flag under the lock, assert outside of it.
	parentTS.mu.Lock()
	finished := parentTS.isFinished
	parentTS.mu.Unlock()
	if !finished {
		t.Error("Expected isFinished to be true")
	}
}
// TestDeliverSubTurnResult_RaceWithFinish verifies that deliverSubTurnResult handles
// the race condition where Finish() is called while results are being delivered.
// It launches 20 concurrent deliveries plus one goroutine that calls Finish(false)
// after 5ms, then checks the delivered/orphan events counted through MockEventBus.
func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
	// Save original MockEventBus.Emit and restore it when the test returns,
	// since the emitter is replaced for the duration of this test.
	originalEmit := MockEventBus.Emit
	defer func() {
		MockEventBus.Emit = originalEmit
	}()
	// Count events by type; mu guards both counters because Emit is called
	// from the delivery goroutines.
	var mu sync.Mutex
	var deliveredCount, orphanCount int
	MockEventBus.Emit = func(e any) {
		mu.Lock()
		defer mu.Unlock()
		switch e.(type) {
		case SubTurnResultDeliveredEvent:
			deliveredCount++
		case SubTurnOrphanResultEvent:
			orphanCount++
		}
	}
	ctx := context.Background()
	parentTS := &turnState{
		ctx:            ctx,
		turnID:         "parent-race-test",
		depth:          0,
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}
	parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
	// Launch goroutines that deliver results while another goroutine calls Finish().
	// numResults (20) is larger than the pendingResults buffer (16) and nothing in
	// this test receives from that channel.
	const numResults = 20
	var wg sync.WaitGroup
	wg.Add(numResults + 1)
	// Goroutine that calls Finish() after a short delay
	go func() {
		defer wg.Done()
		time.Sleep(5 * time.Millisecond)
		parentTS.Finish(false)
	}()
	// Goroutines that deliver results
	for i := 0; i < numResults; i++ {
		go func(id int) {
			defer wg.Done()
			result := &tools.ToolResult{
				ForLLM: fmt.Sprintf("result-%d", id),
			}
			// This should not panic, even if Finish() is called concurrently
			deliverSubTurnResult(parentTS, fmt.Sprintf("child-%d", id), result)
		}(i)
	}
	// If any delivery stayed blocked after Finish(), this Wait would hang the test.
	wg.Wait()
	// Snapshot the counters under the lock.
	mu.Lock()
	finalDelivered := deliveredCount
	finalOrphan := orphanCount
	mu.Unlock()
	t.Logf("Delivered: %d, Orphan: %d, Total: %d", finalDelivered, finalOrphan, finalDelivered+finalOrphan)
	// Every delivery attempt is expected to be accounted for by at least one
	// delivered or orphan event, so the total must not be below numResults.
	// NOTE(review): an earlier version of this comment explained totals above
	// numResults by Finish() draining pendingResults and re-emitting the drained
	// results as orphans. The accompanying change description says that drain was
	// removed from Finish(); if so, the total should equal numResults exactly and
	// this check could be tightened - confirm against turnState.Finish first.
	if finalDelivered+finalOrphan < numResults {
		t.Errorf("Expected at least %d total events, got %d delivered + %d orphan = %d",
			numResults, finalDelivered, finalOrphan, finalDelivered+finalOrphan)
	}
	// With 20 results and a buffer of 16 that nobody reads, presumably at most 16
	// results fit in the channel, so at least some must end up as orphans.
	if finalOrphan == 0 {
		t.Error("Expected at least some orphan results after Finish()")
	}
}
// TestConcurrencySemaphore_Timeout verifies that spawnSubTurn returns an error
// instead of blocking indefinitely when every concurrency slot of the parent
// turn is already occupied.
//
// Waiting for the default timeout would make this test far too slow for a unit
// test (the original note puts it at 30 seconds), so the call is made with a
// context that carries a 100ms deadline. No constant is modified; the test only
// checks that the call fails and that it does so well within two seconds.
func TestConcurrencySemaphore_Timeout(t *testing.T) {
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Provider: "mock",
			},
		},
	}
	msgBus := bus.NewMessageBus()
	provider := &simpleMockProviderAPI{}
	al := NewAgentLoop(cfg, msgBus, provider)

	ctx := context.Background()
	parentTS := &turnState{
		ctx:            ctx,
		turnID:         "parent-timeout-test",
		depth:          0,
		session:        newEphemeralSession(nil),
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}
	parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
	defer parentTS.Finish(false)

	// Fill all concurrency slots so the spawn below cannot acquire one.
	for i := 0; i < maxConcurrentSubTurns; i++ {
		parentTS.concurrencySem <- struct{}{}
	}
	// Empty the semaphore again on exit. The drain is non-blocking so that a
	// slot unexpectedly released by spawnSubTurn makes the test fail on its
	// assertions rather than hang in cleanup. Being deferred after Finish, it
	// still runs before Finish, as the inline drain did.
	defer func() {
		for {
			select {
			case <-parentTS.concurrencySem:
			default:
				return
			}
		}
	}()

	// Short deadline so the test does not depend on the production timeout.
	testCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
	defer cancel()

	subTurnCfg := SubTurnConfig{
		Model: "gpt-4o-mini",
		Async: false,
	}
	start := time.Now()
	_, err := spawnSubTurn(testCtx, al, parentTS, subTurnCfg)
	elapsed := time.Since(start)

	// Without an error nothing below is meaningful, so stop right here.
	if err == nil {
		t.Fatal("Expected timeout error, got nil")
	}

	// The error is expected to relate to the deadline or to the concurrency
	// timeout; anything else is only logged because it may be wrapped differently.
	if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, ErrConcurrencyTimeout) {
		t.Logf("Got error: %v (type: %T)", err, err)
	}

	// Should time out quickly (within a generous margin over the 100ms deadline).
	if elapsed > 2*time.Second {
		t.Errorf("Timeout took too long: %v", elapsed)
	}
	t.Logf("Timeout occurred after %v with error: %v", elapsed, err)
}
// TestEphemeralSession_AutoTruncate verifies that an ephemeral session caps its
// history at maxEphemeralHistorySize and keeps the most recent messages.
//
// It appends maxEphemeralHistorySize+overflow messages and then checks the
// resulting length, the newest retained message and the oldest retained one.
func TestEphemeralSession_AutoTruncate(t *testing.T) {
	// overflow is how many messages beyond the limit are added, and therefore
	// how many of the oldest ones are expected to be dropped.
	const overflow = 20

	store := newEphemeralSession(nil).(*ephemeralSessionStore)

	total := maxEphemeralHistorySize + overflow
	for i := 0; i < total; i++ {
		store.AddMessage("test", "user", fmt.Sprintf("message-%d", i))
	}

	// A wrong length is fatal: the checks below index into history and would
	// panic on an empty slice instead of reporting a clean failure.
	history := store.GetHistory("test")
	if len(history) != maxEphemeralHistorySize {
		t.Fatalf("Expected history length %d, got %d", maxEphemeralHistorySize, len(history))
	}

	// The most recent message must be the last one that was added.
	lastMsg := history[len(history)-1]
	expectedContent := fmt.Sprintf("message-%d", total-1)
	if lastMsg.Content != expectedContent {
		t.Errorf("Expected last message to be %q, got %q", expectedContent, lastMsg.Content)
	}

	// The first `overflow` messages must have been discarded.
	firstMsg := history[0]
	expectedFirstContent := fmt.Sprintf("message-%d", overflow)
	if firstMsg.Content != expectedFirstContent {
		t.Errorf("Expected first message to be %q, got %q", expectedFirstContent, firstMsg.Content)
	}
}
// TestContextWrapping_SingleLayer verifies that we only create one context layer
// in spawnSubTurn, not multiple redundant layers.
func TestContextWrapping_SingleLayer(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProviderAPI{}
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-context-test",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish(false)
// Spawn a sub-turn
subTurnCfg := SubTurnConfig{
Model: "gpt-4o-mini",
Async: false,
}
result, err := spawnSubTurn(ctx, al, parentTS, subTurnCfg)
if err != nil {
t.Fatalf("spawnSubTurn failed: %v", err)
}
if result == nil {
t.Error("Expected non-nil result")
}
// Verify the child turn was created with a cancel function
// (This is implicit - if the test passes without hanging, the context management is correct)
t.Log("Context wrapping test passed - no redundant layers detected")
}
// TestSyncSubTurn_NoChannelDelivery verifies that synchronous sub-turns
// do NOT deliver results to the pendingResults channel (only return directly).
func TestSyncSubTurn_NoChannelDelivery(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProviderAPI{}
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-sync-test",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish(false)
// Spawn a SYNCHRONOUS sub-turn (Async=false)
subTurnCfg := SubTurnConfig{
Model: "gpt-4o-mini",
Async: false, // Synchronous - should NOT deliver to channel
}
result, err := spawnSubTurn(ctx, al, parentTS, subTurnCfg)
if err != nil {
t.Fatalf("spawnSubTurn failed: %v", err)
}
if result == nil {
t.Error("Expected non-nil result from synchronous sub-turn")
}
// Verify the pendingResults channel is EMPTY
// (synchronous sub-turns should not deliver to channel)
select {
case r := <-parentTS.pendingResults:
t.Errorf("Expected empty channel for sync sub-turn, but got result: %v", r)
default:
// Expected: channel is empty
t.Log("Verified: synchronous sub-turn did not deliver to channel")
}
// Verify channel length is 0
if len(parentTS.pendingResults) != 0 {
t.Errorf("Expected channel length 0, got %d", len(parentTS.pendingResults))
}
}
// TestAsyncSubTurn_ChannelDelivery verifies that asynchronous sub-turns
// DO deliver results to the pendingResults channel.
func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProviderAPI{}
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-async-test",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
defer parentTS.Finish(false)
// Spawn an ASYNCHRONOUS sub-turn (Async=true)
subTurnCfg := SubTurnConfig{
Model: "gpt-4o-mini",
Async: true, // Asynchronous - SHOULD deliver to channel
}
result, err := spawnSubTurn(ctx, al, parentTS, subTurnCfg)
if err != nil {
t.Fatalf("spawnSubTurn failed: %v", err)
}
if result == nil {
t.Error("Expected non-nil result from asynchronous sub-turn")
}
// Verify the pendingResults channel has the result
select {
case r := <-parentTS.pendingResults:
if r == nil {
t.Error("Expected non-nil result from channel")
}
t.Log("Verified: asynchronous sub-turn delivered to channel")
case <-time.After(100 * time.Millisecond):
t.Error("Expected result in channel for async sub-turn, but channel was empty")
}
}
// TestGrandchildAbort_CascadingCancellation builds a three-level chain of turns
// (grandparent -> parent -> grandchild) whose contexts are derived from one
// another, hard-aborts the grandparent with Finish(true), and checks that the
// contexts of all three levels end up cancelled.
func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
	root := context.Background()

	// Depth 0: grandparent with its own cancellable context.
	gpCtx, gpCancel := context.WithCancel(root)
	grandparentTS := &turnState{
		ctx:            gpCtx,
		cancelFunc:     gpCancel,
		turnID:         "grandparent",
		depth:          0,
		session:        newEphemeralSession(nil),
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}

	// Depth 1: parent, derived from the grandparent's context.
	parentCtx, parentCancel := context.WithCancel(grandparentTS.ctx)
	defer parentCancel()
	parentTS := &turnState{
		ctx:            parentCtx,
		cancelFunc:     parentCancel,
		turnID:         "parent",
		parentTurnID:   "grandparent",
		depth:          1,
		session:        newEphemeralSession(nil),
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}

	// Depth 2: grandchild, derived from the parent's context.
	childCtx, childCancel := context.WithCancel(parentTS.ctx)
	defer childCancel()
	childTS := &turnState{
		ctx:            childCtx,
		cancelFunc:     childCancel,
		turnID:         "grandchild",
		parentTurnID:   "parent",
		depth:          2,
		session:        newEphemeralSession(nil),
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}

	// isDone reports, without blocking, whether the turn's context is cancelled.
	isDone := func(ts *turnState) bool {
		select {
		case <-ts.ctx.Done():
			return true
		default:
			return false
		}
	}

	// One row per level, ordered top-down, with the messages for each phase.
	chain := []struct {
		ts          *turnState
		tooEarly    string
		cancelledOK string
		stillAlive  string
	}{
		{
			ts:          grandparentTS,
			tooEarly:    "Grandparent context should not be cancelled yet",
			cancelledOK: "Grandparent context cancelled (expected)",
			stillAlive:  "Grandparent context should be cancelled",
		},
		{
			ts:          parentTS,
			tooEarly:    "Parent context should not be cancelled yet",
			cancelledOK: "Parent context cancelled via cascade (expected)",
			stillAlive:  "Parent context should be cancelled via cascade",
		},
		{
			ts:          childTS,
			tooEarly:    "Child context should not be cancelled yet",
			cancelledOK: "Grandchild context cancelled via cascade (expected)",
			stillAlive:  "Grandchild context should be cancelled via cascade",
		},
	}

	// Before the abort every context must still be live.
	for _, link := range chain {
		if isDone(link.ts) {
			t.Error(link.tooEarly)
		}
	}

	// Hard abort the grandparent, then give cancellation a moment to propagate.
	grandparentTS.Finish(true)
	time.Sleep(10 * time.Millisecond)

	// After the abort every context in the chain must be cancelled.
	for _, link := range chain {
		if isDone(link.ts) {
			t.Log(link.cancelledOK)
		} else {
			t.Error(link.stillAlive)
		}
	}
}
// TestSpawnDuringAbort_RaceCondition verifies behavior when trying to spawn
// a sub-turn while the parent is being aborted.
func TestSpawnDuringAbort_RaceCondition(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProviderAPI{}
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-abort-race",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
var wg sync.WaitGroup
wg.Add(2)
var spawnErr error
// Goroutine 1: Try to spawn a sub-turn
go func() {
defer wg.Done()
subTurnCfg := SubTurnConfig{
Model: "gpt-4o-mini",
Async: false,
}
_, err := spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg)
spawnErr = err
}()
// Goroutine 2: Abort the parent almost immediately
go func() {
defer wg.Done()
time.Sleep(1 * time.Millisecond)
parentTS.Finish(false)
}()
wg.Wait()
// The spawn should either succeed (if it started before abort)
// or fail with context cancelled error (if abort happened first)
if spawnErr != nil {
if errors.Is(spawnErr, context.Canceled) {
t.Logf("Spawn failed with expected context cancellation: %v", spawnErr)
} else {
t.Logf("Spawn failed with error: %v", spawnErr)
}
} else {
t.Log("Spawn succeeded before abort")
}
// The important thing is that it doesn't panic or deadlock
t.Log("Race condition handled gracefully - no panic or deadlock")
}
// ====================== Slow SubTurn Cancellation Test ======================
// slowMockProvider simulates a slow LLM call that takes a long time to complete.
// Its Chat method waits for the configured delay before answering, unless the
// call's context is cancelled first.
// This is used to test the scenario where a parent turn finishes before the child SubTurn.
type slowMockProvider struct {
	// delay is how long Chat waits before returning its canned response.
	delay time.Duration
}
// Chat simulates a slow LLM call. It blocks until m.delay has elapsed and then
// returns a fixed response, or returns ctx.Err() as soon as ctx is cancelled,
// whichever happens first. The messages, toolDefs, model and options arguments
// are ignored.
func (m *slowMockProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	toolDefs []providers.ToolDefinition,
	model string,
	options map[string]any,
) (*providers.LLMResponse, error) {
	// An explicit timer (instead of time.After) can be stopped on return, so a
	// cancelled call does not keep a timer pending for the remaining delay.
	timer := time.NewTimer(m.delay)
	defer timer.Stop()

	select {
	case <-timer.C:
		// Completed normally after the delay.
		return &providers.LLMResponse{
			Content: "slow response completed",
		}, nil
	case <-ctx.Done():
		// Context was cancelled while waiting.
		return nil, ctx.Err()
	}
}
// GetDefaultModel returns the fixed model name "slow-model".
func (m *slowMockProvider) GetDefaultModel() string {
	return "slow-model"
}
// TestAsyncSubTurn_ParentFinishesEarly simulates the scenario where:
//  1. the parent spawns an async SubTurn backed by a provider with a 5s delay,
//  2. the parent calls Finish(false) after 100ms,
//  3. the SubTurn then either completes or fails.
//
// The test is observational: it logs the SubTurn outcome and the types of all
// events emitted through MockEventBus, and it only fails if something panics
// or never returns.
func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
	// Record every emitted event; put the original emitter back afterwards.
	var (
		mu     sync.Mutex
		events []any
	)
	originalEmit := MockEventBus.Emit
	defer func() { MockEventBus.Emit = originalEmit }()
	MockEventBus.Emit = func(e any) {
		mu.Lock()
		defer mu.Unlock()
		events = append(events, e)
	}

	al := NewAgentLoop(
		&config.Config{
			Agents: config.AgentsConfig{
				Defaults: config.AgentDefaults{Provider: "mock"},
			},
		},
		bus.NewMessageBus(),
		&slowMockProvider{delay: 5 * time.Second}, // SubTurn takes 5 seconds
	)

	root := context.Background()
	parentTS := &turnState{
		ctx:            root,
		turnID:         "parent-fast",
		depth:          0,
		session:        newEphemeralSession(nil),
		pendingResults: make(chan *tools.ToolResult, 16),
		concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
	}
	parentTS.ctx, parentTS.cancelFunc = context.WithCancel(root)

	var (
		wg            sync.WaitGroup
		subTurnErr    error
		subTurnResult *tools.ToolResult
	)

	// Run the slow async SubTurn in the background.
	wg.Add(1)
	go func() {
		defer wg.Done()
		subTurnResult, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, SubTurnConfig{
			Model: "slow-model",
			Async: true,
		})
	}()

	// The parent finishes after 100ms, while the SubTurn is still running.
	time.Sleep(100 * time.Millisecond)
	t.Log("Parent finishing early...")
	parentTS.Finish(false)

	// Block until the SubTurn has returned, one way or another.
	wg.Wait()

	t.Logf("SubTurn error: %v", subTurnErr)
	t.Logf("SubTurn result: %v", subTurnResult)
	switch {
	case subTurnErr == nil:
		t.Log("SubTurn completed before parent finished (unlikely but possible)")
	case errors.Is(subTurnErr, context.Canceled):
		t.Log("✓ SubTurn was cancelled as expected (context canceled)")
	default:
		t.Logf("SubTurn failed with other error: %v", subTurnErr)
	}

	// Copy the recorded events under the lock, then report them.
	mu.Lock()
	captured := append([]any(nil), events...)
	mu.Unlock()
	t.Logf("Captured %d events:", len(captured))
	for idx, ev := range captured {
		t.Logf(" Event %d: %T", idx+1, ev)
	}
}
// TestAsyncSubTurn_ParentWaitsForChild simulates the scenario where:
// 1. Parent spawns an async SubTurn that takes some time
// 2. Parent WAITS for SubTurn to complete before finishing
// 3. Both should complete successfully
func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &slowMockProvider{delay: 200 * time.Millisecond} // SubTurn takes 200ms
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-wait",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
var subTurnErr error
var subTurnResult *tools.ToolResult
var wg sync.WaitGroup
// Spawn async SubTurn in a goroutine
wg.Add(1)
go func() {
defer wg.Done()
subTurnCfg := SubTurnConfig{
Model: "slow-model",
Async: true,
}
subTurnResult, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg)
}()
// Parent WAITS for SubTurn to complete
t.Log("Parent waiting for SubTurn...")
wg.Wait()
t.Log("SubTurn completed, parent now finishing")
// Now parent can finish safely
parentTS.Finish(false)
// Check the result
if subTurnErr != nil {
if errors.Is(subTurnErr, context.Canceled) {
t.Errorf("SubTurn should NOT have been cancelled: %v", subTurnErr)
} else {
t.Logf("SubTurn failed with error: %v", subTurnErr)
}
} else {
t.Log("✓ SubTurn completed successfully")
if subTurnResult != nil {
t.Logf("SubTurn result: %s", subTurnResult.ForLLM)
}
}
// Check channel delivery
select {
case r := <-parentTS.pendingResults:
if r != nil {
t.Logf("✓ Result delivered to channel: %s", r.ForLLM)
}
case <-time.After(100 * time.Millisecond):
t.Log("No result in channel (expected since we waited)")
}
}
// ====================== Graceful vs Hard Finish Tests ======================
// TestFinish_GracefulVsHard covers the two modes of turnState.Finish:
//   - Finish(false): graceful finish; the test checks that parentEnded is set.
//   - Finish(true): hard abort; the test checks that the turn's context is cancelled.
//
// A third subtest checks that a child's IsParentEnded reflects a graceful
// finish of its parent.
func TestFinish_GracefulVsHard(t *testing.T) {
	// Subtest 1: a graceful finish must flag parentEnded.
	t.Run("Graceful_SetsParentEnded", func(t *testing.T) {
		base, stop := context.WithCancel(context.Background())
		defer stop()

		ts := &turnState{
			ctx:            base,
			turnID:         "graceful-test",
			depth:          0,
			pendingResults: make(chan *tools.ToolResult, 16),
		}
		ts.ctx, ts.cancelFunc = context.WithCancel(base)

		ts.Finish(false)

		if ended := ts.parentEnded.Load(); !ended {
			t.Error("parentEnded should be true after graceful finish")
		}

		// NOTE(review): the state of ts.ctx after a graceful finish is not
		// asserted; the subtest only pauses briefly before returning.
		time.Sleep(10 * time.Millisecond)
	})

	// Subtest 2: a hard abort must cancel the turn's context right away.
	t.Run("Hard_CancelsContext", func(t *testing.T) {
		hardCtx, hardCancel := context.WithCancel(context.Background())
		ts := &turnState{
			ctx:            hardCtx,
			cancelFunc:     hardCancel,
			turnID:         "hard-test",
			depth:          0,
			pendingResults: make(chan *tools.ToolResult, 16),
		}

		ts.Finish(true)

		// Err() is non-nil exactly when the context's Done channel is closed.
		if ts.ctx.Err() == nil {
			t.Error("Context should be cancelled after hard abort")
		} else {
			t.Log("✓ Context cancelled after hard abort")
		}
	})

	// Subtest 3: a child sees its parent's graceful finish via IsParentEnded.
	t.Run("IsParentEnded", func(t *testing.T) {
		root := context.Background()

		parentCtx, parentCancel := context.WithCancel(root)
		parentTS := &turnState{
			ctx:            parentCtx,
			cancelFunc:     parentCancel,
			turnID:         "parent-isended-test",
			depth:          0,
			pendingResults: make(chan *tools.ToolResult, 16),
		}
		childTS := &turnState{
			ctx:             root,
			turnID:          "child-isended-test",
			depth:           1,
			parentTurnState: parentTS,
			pendingResults:  make(chan *tools.ToolResult, 16),
		}

		if childTS.IsParentEnded() {
			t.Error("IsParentEnded should be false before parent finishes")
		}

		parentTS.Finish(false)

		if !childTS.IsParentEnded() {
			t.Error("IsParentEnded should be true after parent finishes gracefully")
		}
	})
}
// TestSubTurn_IndependentContext verifies that SubTurns use independent contexts
// that don't get cancelled when the parent finishes gracefully.
func TestSubTurn_IndependentContext(t *testing.T) {
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Provider: "mock",
},
},
}
msgBus := bus.NewMessageBus()
provider := &slowMockProvider{delay: 500 * time.Millisecond}
al := NewAgentLoop(cfg, msgBus, provider)
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
turnID: "parent-independent",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, maxConcurrentSubTurns),
}
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
var subTurnErr error
var wg sync.WaitGroup
// Spawn SubTurn with Critical=true (should continue after parent finishes)
wg.Add(1)
go func() {
defer wg.Done()
subTurnCfg := SubTurnConfig{
Model: "slow-model",
Async: true,
Critical: true, // Critical SubTurn should continue
}
_, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg)
}()
// Let SubTurn start
time.Sleep(50 * time.Millisecond)
// Parent finishes gracefully (should NOT cancel SubTurn)
parentTS.Finish(false)
t.Log("Parent finished gracefully, SubTurn should continue")
// Wait for SubTurn to complete
wg.Wait()
// SubTurn should complete without context cancelled error
// (because it uses independent context now)
if subTurnErr != nil {
t.Logf("SubTurn error: %v", subTurnErr)
// The error might be context.DeadlineExceeded if timeout is too short
// but should NOT be context.Canceled from parent
if errors.Is(subTurnErr, context.Canceled) {
t.Error("SubTurn should not be cancelled by parent's graceful finish")
}
} else {
t.Log("✓ SubTurn completed successfully (independent context)")
}
}