diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index a2d7120dd..e690fa544 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -1813,3 +1813,196 @@ func TestSpawnDuringAbort_RaceCondition(t *testing.T) { // The important thing is that it doesn't panic or deadlock t.Log("Race condition handled gracefully - no panic or deadlock") } + +// ====================== Slow SubTurn Cancellation Test ====================== + +// slowMockProvider simulates a slow LLM call that takes a long time to complete. +// This is used to test the scenario where a parent turn finishes before the child SubTurn. +type slowMockProvider struct { + delay time.Duration +} + +func (m *slowMockProvider) Chat( + ctx context.Context, + messages []providers.Message, + toolDefs []providers.ToolDefinition, + model string, + options map[string]any, +) (*providers.LLMResponse, error) { + select { + case <-time.After(m.delay): + // Completed normally after delay + return &providers.LLMResponse{ + Content: "slow response completed", + }, nil + case <-ctx.Done(): + // Context was cancelled while waiting + return nil, ctx.Err() + } +} + +func (m *slowMockProvider) GetDefaultModel() string { + return "slow-model" +} + +// TestAsyncSubTurn_ParentFinishesEarly simulates the scenario where: +// 1. Parent spawns an async SubTurn that takes a long time +// 2. Parent finishes quickly +// 3. SubTurn should be cancelled with context canceled error +func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) { + // Save original MockEventBus.Emit to capture events + originalEmit := MockEventBus.Emit + defer func() { + MockEventBus.Emit = originalEmit + }() + + var mu sync.Mutex + var events []any + MockEventBus.Emit = func(e any) { + mu.Lock() + defer mu.Unlock() + events = append(events, e) + } + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Provider: "mock", + }, + }, + } + msgBus := bus.NewMessageBus() + provider := &slowMockProvider{delay: 5 * time.Second} // SubTurn takes 5 seconds + al := NewAgentLoop(cfg, msgBus, provider) + + ctx := context.Background() + parentTS := &turnState{ + ctx: ctx, + turnID: "parent-fast", + depth: 0, + session: newEphemeralSession(nil), + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, maxConcurrentSubTurns), + } + parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) + + var subTurnErr error + var subTurnResult *tools.ToolResult + var wg sync.WaitGroup + + // Spawn async SubTurn in a goroutine (it will be slow) + wg.Add(1) + go func() { + defer wg.Done() + subTurnCfg := SubTurnConfig{ + Model: "slow-model", + Async: true, // Asynchronous SubTurn + } + subTurnResult, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg) + }() + + // Parent finishes quickly (after 100ms), while SubTurn is still running + time.Sleep(100 * time.Millisecond) + t.Log("Parent finishing early...") + parentTS.Finish() + + // Wait for SubTurn to complete (or be cancelled) + wg.Wait() + + // Check the result + t.Logf("SubTurn error: %v", subTurnErr) + t.Logf("SubTurn result: %v", subTurnResult) + + if subTurnErr != nil { + if errors.Is(subTurnErr, context.Canceled) { + t.Log("✓ SubTurn was cancelled as expected (context canceled)") + } else { + t.Logf("SubTurn failed with other error: %v", subTurnErr) + } + } else { + t.Log("SubTurn completed before parent finished (unlikely but possible)") + } + + // Log captured events + mu.Lock() + t.Logf("Captured %d events:", len(events)) + for i, e := range events { + t.Logf(" Event %d: %T", i+1, e) + } + mu.Unlock() +} + +// TestAsyncSubTurn_ParentWaitsForChild simulates the scenario where: +// 1. Parent spawns an async SubTurn that takes some time +// 2. Parent WAITS for SubTurn to complete before finishing +// 3. Both should complete successfully +func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) { + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Provider: "mock", + }, + }, + } + msgBus := bus.NewMessageBus() + provider := &slowMockProvider{delay: 200 * time.Millisecond} // SubTurn takes 200ms + al := NewAgentLoop(cfg, msgBus, provider) + + ctx := context.Background() + parentTS := &turnState{ + ctx: ctx, + turnID: "parent-wait", + depth: 0, + session: newEphemeralSession(nil), + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, maxConcurrentSubTurns), + } + parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) + + var subTurnErr error + var subTurnResult *tools.ToolResult + var wg sync.WaitGroup + + // Spawn async SubTurn in a goroutine + wg.Add(1) + go func() { + defer wg.Done() + subTurnCfg := SubTurnConfig{ + Model: "slow-model", + Async: true, + } + subTurnResult, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg) + }() + + // Parent WAITS for SubTurn to complete + t.Log("Parent waiting for SubTurn...") + wg.Wait() + t.Log("SubTurn completed, parent now finishing") + + // Now parent can finish safely + parentTS.Finish() + + // Check the result + if subTurnErr != nil { + if errors.Is(subTurnErr, context.Canceled) { + t.Errorf("SubTurn should NOT have been cancelled: %v", subTurnErr) + } else { + t.Logf("SubTurn failed with error: %v", subTurnErr) + } + } else { + t.Log("✓ SubTurn completed successfully") + if subTurnResult != nil { + t.Logf("SubTurn result: %s", subTurnResult.ForLLM) + } + } + + // Check channel delivery + select { + case r := <-parentTS.pendingResults: + if r != nil { + t.Logf("✓ Result delivered to channel: %s", r.ForLLM) + } + case <-time.After(100 * time.Millisecond): + t.Log("No result in channel (expected since we waited)") + } +}