package agent import ( "context" "errors" "fmt" "reflect" "sync" "testing" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) // ====================== Test Helper: Event Collector ====================== type eventCollector struct { events []any } func (c *eventCollector) collect(e any) { c.events = append(c.events, e) } func (c *eventCollector) hasEventOfType(typ any) bool { targetType := reflect.TypeOf(typ) for _, e := range c.events { if reflect.TypeOf(e) == targetType { return true } } return false } func (c *eventCollector) countOfType(typ any) int { targetType := reflect.TypeOf(typ) count := 0 for _, e := range c.events { if reflect.TypeOf(e) == targetType { count++ } } return count } // ====================== Main Test Function ====================== func TestSpawnSubTurn(t *testing.T) { tests := []struct { name string parentDepth int config SubTurnConfig wantErr error wantSpawn bool wantEnd bool wantDepthFail bool }{ { name: "Basic success path - Single layer sub-turn", parentDepth: 0, config: SubTurnConfig{ Model: "gpt-4o-mini", Tools: []tools.Tool{}, // At least one tool }, wantErr: nil, wantSpawn: true, wantEnd: true, }, { name: "Nested 2 layers - Normal", parentDepth: 1, config: SubTurnConfig{ Model: "gpt-4o-mini", Tools: []tools.Tool{}, }, wantErr: nil, wantSpawn: true, wantEnd: true, }, { name: "Depth limit triggered - 4th layer fails", parentDepth: 3, config: SubTurnConfig{ Model: "gpt-4o-mini", Tools: []tools.Tool{}, }, wantErr: ErrDepthLimitExceeded, wantSpawn: false, wantEnd: false, wantDepthFail: true, }, { name: "Invalid config - Empty Model", parentDepth: 0, config: SubTurnConfig{ Model: "", Tools: []tools.Tool{}, }, wantErr: ErrInvalidSubTurnConfig, wantSpawn: false, wantEnd: false, }, } al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Prepare parent Turn parent := &turnState{ ctx: context.Background(), turnID: "parent-1", depth: tt.parentDepth, childTurnIDs: []string{}, pendingResults: make(chan *tools.ToolResult, 10), session: &ephemeralSessionStore{}, } // Replace mock with test collector collector := &eventCollector{} originalEmit := MockEventBus.Emit MockEventBus.Emit = collector.collect defer func() { MockEventBus.Emit = originalEmit }() // Execute spawnSubTurn result, err := spawnSubTurn(context.Background(), al, parent, tt.config) // Assert errors if tt.wantErr != nil { if err == nil || err != tt.wantErr { t.Errorf("expected error %v, got %v", tt.wantErr, err) } return } if err != nil { t.Errorf("unexpected error: %v", err) return } // Verify result if result == nil { t.Error("expected non-nil result") } // Verify event emission if tt.wantSpawn { if !collector.hasEventOfType(SubTurnSpawnEvent{}) { t.Error("SubTurnSpawnEvent not emitted") } } if tt.wantEnd { if !collector.hasEventOfType(SubTurnEndEvent{}) { t.Error("SubTurnEndEvent not emitted") } } // Verify turn tree if len(parent.childTurnIDs) == 0 && !tt.wantDepthFail { t.Error("child Turn not added to parent.childTurnIDs") } // For synchronous calls (Async=false, the default), result is returned directly // and should NOT be in pendingResults. The result was already verified above. // Only async calls (Async=true) would place results in pendingResults. }) } } // ====================== Extra Independent Test: Ephemeral Session Isolation ====================== func TestSpawnSubTurn_EphemeralSessionIsolation(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() parentSession := &ephemeralSessionStore{} parentSession.AddMessage("", "user", "parent msg") parent := &turnState{ ctx: context.Background(), turnID: "parent-1", depth: 0, pendingResults: make(chan *tools.ToolResult, 1), session: parentSession, } cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}} // Record main session length before execution originalLen := len(parent.session.GetHistory("")) _, _ = spawnSubTurn(context.Background(), al, parent, cfg) // After sub-turn ends, main session must remain unchanged if len(parent.session.GetHistory("")) != originalLen { t.Error("ephemeral session polluted the main session") } } // ====================== Extra Independent Test: Result Delivery Path (Async) ====================== func TestSpawnSubTurn_ResultDelivery(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() parent := &turnState{ ctx: context.Background(), turnID: "parent-1", depth: 0, pendingResults: make(chan *tools.ToolResult, 1), session: &ephemeralSessionStore{}, } // Set Async=true to test async result delivery via pendingResults channel cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: true} _, _ = spawnSubTurn(context.Background(), al, parent, cfg) // Check if pendingResults received the result (only for async calls) select { case res := <-parent.pendingResults: if res == nil { t.Error("received nil result in pendingResults") } default: t.Error("result did not enter pendingResults for async call") } } // ====================== Extra Independent Test: Result Delivery Path (Sync) ====================== func TestSpawnSubTurn_ResultDeliverySync(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() parent := &turnState{ ctx: context.Background(), turnID: "parent-sync-1", depth: 0, pendingResults: make(chan *tools.ToolResult, 1), session: &ephemeralSessionStore{}, } // Sync call (Async=false, the default) - result should be returned directly cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: false} result, err := spawnSubTurn(context.Background(), al, parent, cfg) if err != nil { t.Fatalf("unexpected error: %v", err) } // Result should be returned directly if result == nil { t.Error("expected non-nil result from sync call") } // pendingResults should NOT contain the result (no double delivery) select { case <-parent.pendingResults: t.Error("sync call should not place result in pendingResults (double delivery)") default: // Expected - channel should be empty } } // ====================== Extra Independent Test: Orphan Result Routing ====================== func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) { parentCtx, cancelParent := context.WithCancel(context.Background()) parent := &turnState{ ctx: parentCtx, cancelFunc: cancelParent, turnID: "parent-1", depth: 0, pendingResults: make(chan *tools.ToolResult, 1), session: &ephemeralSessionStore{}, } collector := &eventCollector{} originalEmit := MockEventBus.Emit MockEventBus.Emit = collector.collect defer func() { MockEventBus.Emit = originalEmit }() // Simulate parent finishing before child delivers result parent.Finish(false) // Call deliverSubTurnResult directly to simulate a delayed child deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"}) // Verify Orphan event is emitted if !collector.hasEventOfType(SubTurnOrphanResultEvent{}) { t.Error("SubTurnOrphanResultEvent not emitted for finished parent") } // Verify history is NOT polluted if len(parent.session.GetHistory("")) != 0 { t.Error("Parent history was polluted by orphan result") } } // ====================== Extra Independent Test: Result Channel Registration ====================== func TestSubTurnResultChannelRegistration(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() parent := &turnState{ ctx: context.Background(), turnID: "parent-reg-1", depth: 0, pendingResults: make(chan *tools.ToolResult, 4), session: &ephemeralSessionStore{}, } cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}} // Before spawn: channel should not be registered if results := al.dequeuePendingSubTurnResults(parent.turnID); results != nil { t.Error("expected no channel before spawnSubTurn") } _, _ = spawnSubTurn(context.Background(), al, parent, cfg) } // ====================== Extra Independent Test: Dequeue Pending SubTurn Results ====================== func TestDequeuePendingSubTurnResults(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() sessionKey := "test-session-dequeue" // Empty (no turnState registered) returns nil if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 { t.Errorf("expected empty results, got %d", len(results)) } // Register a turnState so dequeuePendingSubTurnResults can find it ts := &turnState{ ctx: context.Background(), turnID: sessionKey, depth: 0, session: &ephemeralSessionStore{}, pendingResults: make(chan *tools.ToolResult, 4), } al.activeTurnStates.Store(sessionKey, ts) defer al.activeTurnStates.Delete(sessionKey) // Put 3 results in ts.pendingResults <- &tools.ToolResult{ForLLM: "result-1"} ts.pendingResults <- &tools.ToolResult{ForLLM: "result-2"} ts.pendingResults <- &tools.ToolResult{ForLLM: "result-3"} results := al.dequeuePendingSubTurnResults(sessionKey) if len(results) != 3 { t.Errorf("expected 3 results, got %d", len(results)) } if results[0].ForLLM != "result-1" || results[2].ForLLM != "result-3" { t.Error("results order or content mismatch") } // Channel should be drained now if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 { t.Errorf("expected empty after drain, got %d", len(results)) } // After removing from activeTurnStates, returns nil al.activeTurnStates.Delete(sessionKey) if results := al.dequeuePendingSubTurnResults(sessionKey); results != nil { t.Error("expected nil for unregistered session") } } // ====================== Extra Independent Test: Concurrency Semaphore ====================== func TestSubTurnConcurrencySemaphore(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() parent := &turnState{ ctx: context.Background(), turnID: "parent-concurrency", depth: 0, pendingResults: make(chan *tools.ToolResult, 10), session: &ephemeralSessionStore{}, concurrencySem: make(chan struct{}, 2), // Only allow 2 concurrent children } cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}} // Spawn 2 children — should succeed immediately done := make(chan bool, 3) for i := 0; i < 2; i++ { go func() { _, _ = spawnSubTurn(context.Background(), al, parent, cfg) done <- true }() } // Wait a bit to ensure the first 2 are running // (In real scenario they'd be blocked in runTurn, but mockProvider returns immediately) // So we just verify the semaphore doesn't block when under limit <-done <-done // Verify semaphore is now full (2/2 slots used, but they already released) // Since mockProvider returns immediately, semaphore is already released // So we can't easily test blocking without a real long-running operation // Instead, verify that semaphore exists and has correct capacity if cap(parent.concurrencySem) != 2 { t.Errorf("expected semaphore capacity 2, got %d", cap(parent.concurrencySem)) } } // ====================== Extra Independent Test: Hard Abort Cascading ====================== func TestHardAbortCascading(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() sessionKey := "test-session-abort" parentCtx, parentCancel := context.WithCancel(context.Background()) defer parentCancel() rootTS := &turnState{ ctx: parentCtx, turnID: sessionKey, depth: 0, session: &ephemeralSessionStore{}, pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, 5), } // Register the root turn state al.activeTurnStates.Store(sessionKey, rootTS) defer al.activeTurnStates.Delete(sessionKey) // Create a child turn state childCtx, childCancel := context.WithCancel(rootTS.ctx) defer childCancel() childTS := &turnState{ ctx: childCtx, cancelFunc: childCancel, turnID: "child-1", parentTurnID: sessionKey, depth: 1, session: &ephemeralSessionStore{}, pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, 5), } // Attach cancelFunc to rootTS so Finish() can trigger it rootTS.cancelFunc = parentCancel // Verify contexts are not canceled yet select { case <-rootTS.ctx.Done(): t.Error("root context should not be canceled yet") default: } select { case <-childTS.ctx.Done(): t.Error("child context should not be canceled yet") default: } // Trigger Hard Abort err := al.HardAbort(sessionKey) if err != nil { t.Errorf("HardAbort failed: %v", err) } // Verify root context is canceled select { case <-rootTS.ctx.Done(): // Expected default: t.Error("root context should be canceled after HardAbort") } // Verify child context is also canceled (cascading) select { case <-childTS.ctx.Done(): // Expected default: t.Error("child context should be canceled after HardAbort (cascading)") } // Verify HardAbort on non-existent session returns error err = al.HardAbort("non-existent-session") if err == nil { t.Error("expected error for non-existent session") } } // TestHardAbortSessionRollback verifies that HardAbort rolls back session history // to the state before the turn started, discarding all messages added during the turn. func TestHardAbortSessionRollback(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() // Create a session with initial history sess := &ephemeralSessionStore{ history: []providers.Message{ {Role: "user", Content: "initial message 1"}, {Role: "assistant", Content: "initial response 1"}, }, } // Create a root turnState with initialHistoryLength = 2 rootTS := &turnState{ ctx: context.Background(), turnID: "test-session", depth: 0, session: sess, initialHistoryLength: 2, // Snapshot: 2 messages pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, 5), } // Register the turn state al.activeTurnStates.Store("test-session", rootTS) // Simulate adding messages during the turn (e.g., user input + assistant response) sess.AddMessage("", "user", "new user message") sess.AddMessage("", "assistant", "new assistant response") // Verify history grew to 4 messages if len(sess.GetHistory("")) != 4 { t.Fatalf("expected 4 messages before abort, got %d", len(sess.GetHistory(""))) } // Trigger HardAbort err := al.HardAbort("test-session") if err != nil { t.Fatalf("HardAbort failed: %v", err) } // Verify history rolled back to initial 2 messages finalHistory := sess.GetHistory("") if len(finalHistory) != 2 { t.Errorf("expected history to rollback to 2 messages, got %d", len(finalHistory)) } // Verify the content matches the initial state if finalHistory[0].Content != "initial message 1" || finalHistory[1].Content != "initial response 1" { t.Error("history content does not match initial state after rollback") } } // TestNestedSubTurnHierarchy verifies that nested SubTurns maintain correct // parent-child relationships and depth tracking when recursively calling runAgentLoop. func TestNestedSubTurnHierarchy(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() // Track spawned turns and their depths type turnInfo struct { parentID string childID string depth int } var spawnedTurns []turnInfo var mu sync.Mutex // Override MockEventBus to capture spawn events originalEmit := MockEventBus.Emit defer func() { MockEventBus.Emit = originalEmit }() MockEventBus.Emit = func(event any) { if spawnEvent, ok := event.(SubTurnSpawnEvent); ok { mu.Lock() // Extract depth from context (we'll verify this matches expected depth) spawnedTurns = append(spawnedTurns, turnInfo{ parentID: spawnEvent.ParentID, childID: spawnEvent.ChildID, }) mu.Unlock() } } // Create a root turn rootSession := &ephemeralSessionStore{} rootTS := &turnState{ ctx: context.Background(), turnID: "root-turn", depth: 0, session: rootSession, pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, 5), } // Spawn a child (depth 1) childCfg := SubTurnConfig{Model: "gpt-4o-mini"} _, err := spawnSubTurn(context.Background(), al, rootTS, childCfg) if err != nil { t.Fatalf("failed to spawn child: %v", err) } // Verify we captured the spawn event mu.Lock() if len(spawnedTurns) != 1 { t.Fatalf("expected 1 spawn event, got %d", len(spawnedTurns)) } if spawnedTurns[0].parentID != "root-turn" { t.Errorf("expected parent ID 'root-turn', got %s", spawnedTurns[0].parentID) } mu.Unlock() // Verify root turn has the child in its childTurnIDs rootTS.mu.Lock() if len(rootTS.childTurnIDs) != 1 { t.Errorf("expected root to have 1 child, got %d", len(rootTS.childTurnIDs)) } rootTS.mu.Unlock() } // TestDeliverSubTurnResultNoDeadlock verifies that deliverSubTurnResult doesn't // deadlock when multiple goroutines are accessing the parent turnState concurrently. func TestDeliverSubTurnResultNoDeadlock(t *testing.T) { parent := &turnState{ ctx: context.Background(), turnID: "parent-deadlock-test", depth: 0, pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking isFinished: false, } // Simulate multiple child turns delivering results concurrently var wg sync.WaitGroup numChildren := 10 for i := 0; i < numChildren; i++ { wg.Add(1) go func(id int) { defer wg.Done() result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)} deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result) }(i) } // Concurrently read from the channel to prevent blocking // and to actually retrieve the matched number of results go func() { for i := 0; i < numChildren; i++ { select { case <-parent.pendingResults: case <-time.After(5 * time.Second): t.Error("timeout waiting for result") return } } }() // Wait for all deliveries to complete (with timeout) done := make(chan struct{}) go func() { wg.Wait() close(done) }() select { case <-done: // Success - no deadlock case <-time.After(3 * time.Second): t.Fatal("deadlock detected: deliverSubTurnResult blocked") } } // TestHardAbortOrderOfOperations verifies that HardAbort calls Finish() before // rolling back session history, minimizing the race window where new messages // could be added after rollback. func TestHardAbortOrderOfOperations(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() sess := &ephemeralSessionStore{ history: []providers.Message{ {Role: "user", Content: "initial message"}, {Role: "assistant", Content: "response 1"}, {Role: "user", Content: "follow-up"}, }, } ctx, cancel := context.WithCancel(context.Background()) defer cancel() rootTS := &turnState{ ctx: ctx, cancelFunc: cancel, turnID: "test-session-order", depth: 0, session: sess, initialHistoryLength: 1, // Snapshot: 1 message pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, 5), } al.activeTurnStates.Store("test-session-order", rootTS) // Trigger HardAbort err := al.HardAbort("test-session-order") if err != nil { t.Fatalf("HardAbort failed: %v", err) } // Verify context was cancelled (Finish() was called) select { case <-rootTS.ctx.Done(): // Good - context was cancelled default: t.Error("expected context to be cancelled after HardAbort") } // Verify history was rolled back finalHistory := sess.GetHistory("") if len(finalHistory) != 1 { t.Errorf("expected history to rollback to 1 message, got %d", len(finalHistory)) } if finalHistory[0].Content != "initial message" { t.Error("history content does not match initial state after rollback") } } // TestFinishedChannelClosedState verifies that Finish() closes the Finished() channel // so that child turns can safely abort waiting. func TestFinishedChannelClosedState(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ts := &turnState{ ctx: ctx, cancelFunc: cancel, turnID: "test-finished-channel", depth: 0, pendingResults: make(chan *tools.ToolResult, 2), isFinished: false, } // Verify Finished channel is blocking initially select { case <-ts.Finished(): t.Fatal("finished channel should block initially") default: // Good } // Call Finish() with graceful finish ts.Finish(false) // Verify Finished channel is closed select { case _, ok := <-ts.Finished(): if ok { t.Error("expected Finished() channel to be closed after Finish()") } default: t.Fatal("expected <-ts.Finished() to not block") } // Verify Finish() is idempotent ts.Finish(false) // Should not panic // Verify deliverSubTurnResult correctly uses Finished() channel and treats as orphan result := &tools.ToolResult{ForLLM: "late result"} deliverSubTurnResult(ts, "child-1", result) // Will emit orphan due to <-ts.Finished() case } // TestFinalPollCapturesLateResults verifies that the final poll before Finish() // captures results that arrive after the last iteration poll. func TestFinalPollCapturesLateResults(t *testing.T) { al, _, _, _, cleanup := newTestAgentLoop(t) defer cleanup() sessionKey := "test-session-final-poll" // Register a turnState ts := &turnState{ ctx: context.Background(), turnID: sessionKey, depth: 0, session: &ephemeralSessionStore{}, pendingResults: make(chan *tools.ToolResult, 4), } al.activeTurnStates.Store(sessionKey, ts) defer al.activeTurnStates.Delete(sessionKey) // Simulate results arriving after last iteration poll ts.pendingResults <- &tools.ToolResult{ForLLM: "result 1"} ts.pendingResults <- &tools.ToolResult{ForLLM: "result 2"} // Dequeue should capture both results results := al.dequeuePendingSubTurnResults(sessionKey) if len(results) != 2 { t.Errorf("expected 2 results, got %d", len(results)) } // Verify channel is now empty results = al.dequeuePendingSubTurnResults(sessionKey) if len(results) != 0 { t.Errorf("expected 0 results on second poll, got %d", len(results)) } } // TestSpawnSubTurn_PanicRecovery verifies that even if runTurn panics, // the result is still delivered for async calls and SubTurnEndEvent is emitted. func TestSpawnSubTurn_PanicRecovery(t *testing.T) { // Create a panic provider panicProvider := &panicMockProvider{} cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Workspace: t.TempDir(), Model: "test-model", MaxTokens: 4096, MaxToolIterations: 10, }, }, } al := NewAgentLoop(cfg, bus.NewMessageBus(), panicProvider) parent := &turnState{ ctx: context.Background(), turnID: "parent-panic", depth: 0, pendingResults: make(chan *tools.ToolResult, 1), session: &ephemeralSessionStore{}, } collector := &eventCollector{} originalEmit := MockEventBus.Emit MockEventBus.Emit = collector.collect defer func() { MockEventBus.Emit = originalEmit }() // Test async call - result should still be delivered via channel asyncCfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: true} result, err := spawnSubTurn(context.Background(), al, parent, asyncCfg) // Should return error from panic recovery if err == nil { t.Error("expected error from panic recovery") } // Result should be nil because panic occurred before runTurn could return if result != nil { t.Error("expected nil result after panic") } // SubTurnEndEvent should still be emitted if !collector.hasEventOfType(SubTurnEndEvent{}) { t.Error("SubTurnEndEvent not emitted after panic") } // For async call, result should still be delivered to channel (even if nil) select { case res := <-parent.pendingResults: // Result was delivered (nil due to panic) _ = res default: t.Error("async result should be delivered to channel even after panic") } } // panicMockProvider is a mock provider that always panics type panicMockProvider struct{} func (m *panicMockProvider) Chat( ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]any, ) (*providers.LLMResponse, error) { panic("intentional panic for testing") } func (m *panicMockProvider) GetDefaultModel() string { return "panic-model" } // ====================== Public API Tests ====================== // simpleMockProviderAPI for testing public APIs type simpleMockProviderAPI struct { response string } func (m *simpleMockProviderAPI) Chat( ctx context.Context, messages []providers.Message, toolDefs []providers.ToolDefinition, model string, options map[string]any, ) (*providers.LLMResponse, error) { return &providers.LLMResponse{ Content: m.response, }, nil } func (m *simpleMockProviderAPI) GetDefaultModel() string { return "gpt-4o-mini" } // TestGetActiveTurn verifies that GetActiveTurn returns correct turn information func TestGetActiveTurn(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Model: "gpt-4o-mini", Provider: "mock", }, }, } al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"}) // Create a root turn state rootCtx := context.Background() rootTS := &turnState{ ctx: rootCtx, turnID: "root-turn", parentTurnID: "", depth: 0, childTurnIDs: []string{}, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } sessionKey := "test-session" al.activeTurnStates.Store(sessionKey, rootTS) defer al.activeTurnStates.Delete(sessionKey) // Test: GetActiveTurn should return turn info info := al.GetActiveTurn(sessionKey) if info == nil { t.Fatal("GetActiveTurn returned nil for active session") } if info.TurnID != "root-turn" { t.Errorf("Expected TurnID 'root-turn', got %q", info.TurnID) } if info.Depth != 0 { t.Errorf("Expected Depth 0, got %d", info.Depth) } if info.ParentTurnID != "" { t.Errorf("Expected empty ParentTurnID, got %q", info.ParentTurnID) } if len(info.ChildTurnIDs) != 0 { t.Errorf("Expected 0 child turns, got %d", len(info.ChildTurnIDs)) } // Test: GetActiveTurn should return nil for non-existent session nonExistentInfo := al.GetActiveTurn("non-existent-session") if nonExistentInfo != nil { t.Error("GetActiveTurn should return nil for non-existent session") } } // TestGetActiveTurn_WithChildren verifies that child turn IDs are correctly reported func TestGetActiveTurn_WithChildren(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Model: "gpt-4o-mini", Provider: "mock", }, }, } al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"}) rootCtx := context.Background() rootTS := &turnState{ ctx: rootCtx, turnID: "root-turn", parentTurnID: "", depth: 0, childTurnIDs: []string{"child-1", "child-2"}, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } sessionKey := "test-session-with-children" al.activeTurnStates.Store(sessionKey, rootTS) defer al.activeTurnStates.Delete(sessionKey) info := al.GetActiveTurn(sessionKey) if info == nil { t.Fatal("GetActiveTurn returned nil") } if len(info.ChildTurnIDs) != 2 { t.Fatalf("Expected 2 child turns, got %d", len(info.ChildTurnIDs)) } if info.ChildTurnIDs[0] != "child-1" || info.ChildTurnIDs[1] != "child-2" { t.Errorf("Child turn IDs mismatch: got %v", info.ChildTurnIDs) } } // TestTurnStateInfo_ThreadSafety verifies that Info() is thread-safe func TestTurnStateInfo_ThreadSafety(t *testing.T) { rootCtx := context.Background() ts := &turnState{ ctx: rootCtx, turnID: "test-turn", parentTurnID: "parent", depth: 1, childTurnIDs: []string{}, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } // Concurrently read Info() and modify childTurnIDs done := make(chan bool) go func() { for i := 0; i < 100; i++ { ts.mu.Lock() ts.childTurnIDs = append(ts.childTurnIDs, "child") ts.mu.Unlock() } done <- true }() go func() { for i := 0; i < 100; i++ { info := ts.Info() if info == nil { t.Error("Info() returned nil") } } done <- true }() <-done <-done } // TestInjectFollowUp verifies that InjectFollowUp enqueues messages func TestInjectFollowUp(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Model: "gpt-4o-mini", Provider: "mock", }, }, } al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"}) msg := providers.Message{ Role: "user", Content: "Follow-up task", } err := al.InjectFollowUp(msg) if err != nil { t.Fatalf("InjectFollowUp failed: %v", err) } // Verify message was enqueued if al.steering.len() != 1 { t.Errorf("Expected 1 message in queue, got %d", al.steering.len()) } } // TestAPIAliases verifies that API aliases work correctly func TestAPIAliases(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Model: "gpt-4o-mini", Provider: "mock", }, }, } al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"}) msg := providers.Message{ Role: "user", Content: "Test message", } // Test InterruptGraceful (alias for Steer) err := al.InterruptGraceful(msg) if err != nil { t.Errorf("InterruptGraceful failed: %v", err) } // Test InjectSteering (alias for Steer) err = al.InjectSteering(msg) if err != nil { t.Errorf("InjectSteering failed: %v", err) } // Verify both messages were enqueued if al.steering.len() != 2 { t.Errorf("Expected 2 messages in queue, got %d", al.steering.len()) } } // TestInterruptHard_Alias verifies that InterruptHard is an alias for HardAbort func TestInterruptHard_Alias(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Model: "gpt-4o-mini", Provider: "mock", }, }, } al := NewAgentLoop(cfg, nil, &simpleMockProviderAPI{response: "ok"}) rootCtx := context.Background() rootTS := &turnState{ ctx: rootCtx, turnID: "test-turn", depth: 0, session: newEphemeralSession(nil), initialHistoryLength: 0, pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } sessionKey := "test-session-interrupt" al.activeTurnStates.Store(sessionKey, rootTS) // Test InterruptHard (alias for HardAbort) err := al.InterruptHard(sessionKey) if err != nil { t.Errorf("InterruptHard failed: %v", err) } // Verify turn was finished info := al.GetActiveTurn(sessionKey) if info != nil && !info.IsFinished { t.Error("Turn should be finished after InterruptHard") } } // TestFinish_ConcurrentCalls verifies that calling Finish() concurrently from multiple // goroutines is safe and doesn't cause panics or double-close errors. func TestFinish_ConcurrentCalls(t *testing.T) { ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-concurrent-finish", depth: 0, pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) // Launch multiple goroutines that all call Finish() concurrently const numGoroutines = 10 var wg sync.WaitGroup wg.Add(numGoroutines) for i := 0; i < numGoroutines; i++ { go func() { defer wg.Done() // This should not panic, even when called concurrently parentTS.Finish(false) }() } wg.Wait() // Verify the Finished() channel is closed select { case _, ok := <-parentTS.Finished(): if ok { t.Error("Expected Finished() channel to be closed") } default: t.Error("Expected Finished() channel to be closed and readable without blocking") } // Verify isFinished is set parentTS.mu.Lock() if !parentTS.isFinished { t.Error("Expected isFinished to be true") } parentTS.mu.Unlock() } // TestDeliverSubTurnResult_RaceWithFinish verifies that deliverSubTurnResult handles // the race condition where Finish() is called while results are being delivered. func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) { // Save original MockEventBus.Emit originalEmit := MockEventBus.Emit defer func() { MockEventBus.Emit = originalEmit }() // Collect events var mu sync.Mutex var deliveredCount, orphanCount int MockEventBus.Emit = func(e any) { mu.Lock() defer mu.Unlock() switch e.(type) { case SubTurnResultDeliveredEvent: deliveredCount++ case SubTurnOrphanResultEvent: orphanCount++ } } ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-race-test", depth: 0, pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) // Launch goroutines that deliver results while another goroutine calls Finish() const numResults = 20 var wg sync.WaitGroup wg.Add(numResults + 1) // Goroutine that calls Finish() after a short delay go func() { defer wg.Done() time.Sleep(5 * time.Millisecond) parentTS.Finish(false) }() // Goroutines that deliver results for i := 0; i < numResults; i++ { go func(id int) { defer wg.Done() result := &tools.ToolResult{ ForLLM: fmt.Sprintf("result-%d", id), } // This should not panic, even if Finish() is called concurrently deliverSubTurnResult(parentTS, fmt.Sprintf("child-%d", id), result) }(i) } wg.Wait() // Get final counts mu.Lock() finalDelivered := deliveredCount finalOrphan := orphanCount mu.Unlock() t.Logf("Delivered: %d, Orphan: %d, Total: %d", finalDelivered, finalOrphan, finalDelivered+finalOrphan) // With the new drainPendingResults behavior, the total events may be >= numResults // because Finish() drains remaining results from the channel and emits them as orphans. // So we expect: // - Some results were delivered successfully (before Finish()) // - Some results became orphans (after Finish() or channel full) // - Some results were in the channel when Finish() was called and got drained as orphans // The total should be at least numResults (could be more due to drain) if finalDelivered+finalOrphan < numResults { t.Errorf("Expected at least %d total events, got %d delivered + %d orphan = %d", numResults, finalDelivered, finalOrphan, finalDelivered+finalOrphan) } // Should have at least some orphan results (those that arrived after Finish() or were drained) if finalOrphan == 0 { t.Error("Expected at least some orphan results after Finish()") } } // TestConcurrencySemaphore_Timeout verifies that spawning sub-turns times out // when all concurrency slots are occupied for too long. // Note: This test uses a shorter timeout by temporarily modifying the constant. func TestConcurrencySemaphore_Timeout(t *testing.T) { // This test would take 30 seconds with the default timeout. // Instead, we'll test the mechanism by verifying the timeout context is created correctly. // A full integration test with actual timeout would be too slow for unit tests. cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &simpleMockProviderAPI{} al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-timeout-test", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) defer parentTS.Finish(false) // Fill all concurrency slots for i := 0; i < maxConcurrentSubTurns; i++ { parentTS.concurrencySem <- struct{}{} } // Create a context with a very short timeout for testing testCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() // Now try to spawn a sub-turn with the short timeout context subTurnCfg := SubTurnConfig{ Model: "gpt-4o-mini", Async: false, } start := time.Now() _, err := spawnSubTurn(testCtx, al, parentTS, subTurnCfg) elapsed := time.Since(start) // Should get a timeout error (either from our timeout context or the internal one) if err == nil { t.Error("Expected timeout error, got nil") } // The error should be related to context cancellation or timeout if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, ErrConcurrencyTimeout) { t.Logf("Got error: %v (type: %T)", err, err) // This is acceptable - the error might be wrapped } // Should timeout quickly (within a reasonable margin) if elapsed > 2*time.Second { t.Errorf("Timeout took too long: %v", elapsed) } t.Logf("Timeout occurred after %v with error: %v", elapsed, err) // Clean up - drain the semaphore for i := 0; i < maxConcurrentSubTurns; i++ { <-parentTS.concurrencySem } } // TestEphemeralSession_AutoTruncate verifies that ephemeral sessions automatically // truncate their history to prevent memory accumulation. func TestEphemeralSession_AutoTruncate(t *testing.T) { store := newEphemeralSession(nil).(*ephemeralSessionStore) // Add more messages than the limit for i := 0; i < maxEphemeralHistorySize+20; i++ { store.AddMessage("test", "user", fmt.Sprintf("message-%d", i)) } // Verify history is truncated to the limit history := store.GetHistory("test") if len(history) != maxEphemeralHistorySize { t.Errorf("Expected history length %d, got %d", maxEphemeralHistorySize, len(history)) } // Verify we kept the most recent messages lastMsg := history[len(history)-1] expectedContent := fmt.Sprintf("message-%d", maxEphemeralHistorySize+20-1) if lastMsg.Content != expectedContent { t.Errorf("Expected last message to be %q, got %q", expectedContent, lastMsg.Content) } // Verify the oldest messages were discarded firstMsg := history[0] expectedFirstContent := fmt.Sprintf("message-%d", 20) // First 20 were discarded if firstMsg.Content != expectedFirstContent { t.Errorf("Expected first message to be %q, got %q", expectedFirstContent, firstMsg.Content) } } // TestContextWrapping_SingleLayer verifies that we only create one context layer // in spawnSubTurn, not multiple redundant layers. func TestContextWrapping_SingleLayer(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &simpleMockProviderAPI{} al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-context-test", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) defer parentTS.Finish(false) // Spawn a sub-turn subTurnCfg := SubTurnConfig{ Model: "gpt-4o-mini", Async: false, } result, err := spawnSubTurn(ctx, al, parentTS, subTurnCfg) if err != nil { t.Fatalf("spawnSubTurn failed: %v", err) } if result == nil { t.Error("Expected non-nil result") } // Verify the child turn was created with a cancel function // (This is implicit - if the test passes without hanging, the context management is correct) t.Log("Context wrapping test passed - no redundant layers detected") } // TestSyncSubTurn_NoChannelDelivery verifies that synchronous sub-turns // do NOT deliver results to the pendingResults channel (only return directly). func TestSyncSubTurn_NoChannelDelivery(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &simpleMockProviderAPI{} al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-sync-test", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) defer parentTS.Finish(false) // Spawn a SYNCHRONOUS sub-turn (Async=false) subTurnCfg := SubTurnConfig{ Model: "gpt-4o-mini", Async: false, // Synchronous - should NOT deliver to channel } result, err := spawnSubTurn(ctx, al, parentTS, subTurnCfg) if err != nil { t.Fatalf("spawnSubTurn failed: %v", err) } if result == nil { t.Error("Expected non-nil result from synchronous sub-turn") } // Verify the pendingResults channel is EMPTY // (synchronous sub-turns should not deliver to channel) select { case r := <-parentTS.pendingResults: t.Errorf("Expected empty channel for sync sub-turn, but got result: %v", r) default: // Expected: channel is empty t.Log("Verified: synchronous sub-turn did not deliver to channel") } // Verify channel length is 0 if len(parentTS.pendingResults) != 0 { t.Errorf("Expected channel length 0, got %d", len(parentTS.pendingResults)) } } // TestAsyncSubTurn_ChannelDelivery verifies that asynchronous sub-turns // DO deliver results to the pendingResults channel. func TestAsyncSubTurn_ChannelDelivery(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &simpleMockProviderAPI{} al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-async-test", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) defer parentTS.Finish(false) // Spawn an ASYNCHRONOUS sub-turn (Async=true) subTurnCfg := SubTurnConfig{ Model: "gpt-4o-mini", Async: true, // Asynchronous - SHOULD deliver to channel } result, err := spawnSubTurn(ctx, al, parentTS, subTurnCfg) if err != nil { t.Fatalf("spawnSubTurn failed: %v", err) } if result == nil { t.Error("Expected non-nil result from asynchronous sub-turn") } // Verify the pendingResults channel has the result select { case r := <-parentTS.pendingResults: if r == nil { t.Error("Expected non-nil result from channel") } t.Log("Verified: asynchronous sub-turn delivered to channel") case <-time.After(100 * time.Millisecond): t.Error("Expected result in channel for async sub-turn, but channel was empty") } } // TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn // is hard aborted, the cancellation cascades down to grandchild turns. func TestGrandchildAbort_CascadingCancellation(t *testing.T) { ctx := context.Background() // Create grandparent turn (depth 0) grandparentTS := &turnState{ ctx: ctx, turnID: "grandparent", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } grandparentTS.ctx, grandparentTS.cancelFunc = context.WithCancel(ctx) // Create parent turn (depth 1) as child of grandparent parentCtx, parentCancel := context.WithCancel(grandparentTS.ctx) defer parentCancel() parentTS := &turnState{ ctx: parentCtx, turnID: "parent", parentTurnID: "grandparent", depth: 1, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.cancelFunc = parentCancel // Create grandchild turn (depth 2) as child of parent childCtx, childCancel := context.WithCancel(parentTS.ctx) defer childCancel() childTS := &turnState{ ctx: childCtx, turnID: "grandchild", parentTurnID: "parent", depth: 2, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } childTS.cancelFunc = childCancel // Verify all contexts are active select { case <-grandparentTS.ctx.Done(): t.Error("Grandparent context should not be cancelled yet") default: } select { case <-parentTS.ctx.Done(): t.Error("Parent context should not be cancelled yet") default: } select { case <-childTS.ctx.Done(): t.Error("Child context should not be cancelled yet") default: } // Hard abort the grandparent grandparentTS.Finish(true) // Wait a bit for cancellation to propagate time.Sleep(10 * time.Millisecond) // Verify cascading cancellation select { case <-grandparentTS.ctx.Done(): t.Log("Grandparent context cancelled (expected)") default: t.Error("Grandparent context should be cancelled") } select { case <-parentTS.ctx.Done(): t.Log("Parent context cancelled via cascade (expected)") default: t.Error("Parent context should be cancelled via cascade") } select { case <-childTS.ctx.Done(): t.Log("Grandchild context cancelled via cascade (expected)") default: t.Error("Grandchild context should be cancelled via cascade") } } // TestSpawnDuringAbort_RaceCondition verifies behavior when trying to spawn // a sub-turn while the parent is being aborted. func TestSpawnDuringAbort_RaceCondition(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &simpleMockProviderAPI{} al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-abort-race", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) var wg sync.WaitGroup wg.Add(2) var spawnErr error // Goroutine 1: Try to spawn a sub-turn go func() { defer wg.Done() subTurnCfg := SubTurnConfig{ Model: "gpt-4o-mini", Async: false, } _, err := spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg) spawnErr = err }() // Goroutine 2: Abort the parent almost immediately go func() { defer wg.Done() time.Sleep(1 * time.Millisecond) parentTS.Finish(false) }() wg.Wait() // The spawn should either succeed (if it started before abort) // or fail with context cancelled error (if abort happened first) if spawnErr != nil { if errors.Is(spawnErr, context.Canceled) { t.Logf("Spawn failed with expected context cancellation: %v", spawnErr) } else { t.Logf("Spawn failed with error: %v", spawnErr) } } else { t.Log("Spawn succeeded before abort") } // The important thing is that it doesn't panic or deadlock t.Log("Race condition handled gracefully - no panic or deadlock") } // ====================== Slow SubTurn Cancellation Test ====================== // slowMockProvider simulates a slow LLM call that takes a long time to complete. // This is used to test the scenario where a parent turn finishes before the child SubTurn. type slowMockProvider struct { delay time.Duration } func (m *slowMockProvider) Chat( ctx context.Context, messages []providers.Message, toolDefs []providers.ToolDefinition, model string, options map[string]any, ) (*providers.LLMResponse, error) { select { case <-time.After(m.delay): // Completed normally after delay return &providers.LLMResponse{ Content: "slow response completed", }, nil case <-ctx.Done(): // Context was cancelled while waiting return nil, ctx.Err() } } func (m *slowMockProvider) GetDefaultModel() string { return "slow-model" } // TestAsyncSubTurn_ParentFinishesEarly simulates the scenario where: // 1. Parent spawns an async SubTurn that takes a long time // 2. Parent finishes quickly // 3. SubTurn should be cancelled with context canceled error func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) { // Save original MockEventBus.Emit to capture events originalEmit := MockEventBus.Emit defer func() { MockEventBus.Emit = originalEmit }() var mu sync.Mutex var events []any MockEventBus.Emit = func(e any) { mu.Lock() defer mu.Unlock() events = append(events, e) } cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &slowMockProvider{delay: 5 * time.Second} // SubTurn takes 5 seconds al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-fast", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) var subTurnErr error var subTurnResult *tools.ToolResult var wg sync.WaitGroup // Spawn async SubTurn in a goroutine (it will be slow) wg.Add(1) go func() { defer wg.Done() subTurnCfg := SubTurnConfig{ Model: "slow-model", Async: true, // Asynchronous SubTurn } subTurnResult, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg) }() // Parent finishes quickly (after 100ms), while SubTurn is still running time.Sleep(100 * time.Millisecond) t.Log("Parent finishing early...") parentTS.Finish(false) // Wait for SubTurn to complete (or be cancelled) wg.Wait() // Check the result t.Logf("SubTurn error: %v", subTurnErr) t.Logf("SubTurn result: %v", subTurnResult) if subTurnErr != nil { if errors.Is(subTurnErr, context.Canceled) { t.Log("✓ SubTurn was cancelled as expected (context canceled)") } else { t.Logf("SubTurn failed with other error: %v", subTurnErr) } } else { t.Log("SubTurn completed before parent finished (unlikely but possible)") } // Log captured events mu.Lock() t.Logf("Captured %d events:", len(events)) for i, e := range events { t.Logf(" Event %d: %T", i+1, e) } mu.Unlock() } // TestAsyncSubTurn_ParentWaitsForChild simulates the scenario where: // 1. Parent spawns an async SubTurn that takes some time // 2. Parent WAITS for SubTurn to complete before finishing // 3. Both should complete successfully func TestAsyncSubTurn_ParentWaitsForChild(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &slowMockProvider{delay: 200 * time.Millisecond} // SubTurn takes 200ms al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-wait", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) var subTurnErr error var subTurnResult *tools.ToolResult var wg sync.WaitGroup // Spawn async SubTurn in a goroutine wg.Add(1) go func() { defer wg.Done() subTurnCfg := SubTurnConfig{ Model: "slow-model", Async: true, } subTurnResult, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg) }() // Parent WAITS for SubTurn to complete t.Log("Parent waiting for SubTurn...") wg.Wait() t.Log("SubTurn completed, parent now finishing") // Now parent can finish safely parentTS.Finish(false) // Check the result if subTurnErr != nil { if errors.Is(subTurnErr, context.Canceled) { t.Errorf("SubTurn should NOT have been cancelled: %v", subTurnErr) } else { t.Logf("SubTurn failed with error: %v", subTurnErr) } } else { t.Log("✓ SubTurn completed successfully") if subTurnResult != nil { t.Logf("SubTurn result: %s", subTurnResult.ForLLM) } } // Check channel delivery select { case r := <-parentTS.pendingResults: if r != nil { t.Logf("✓ Result delivered to channel: %s", r.ForLLM) } case <-time.After(100 * time.Millisecond): t.Log("No result in channel (expected since we waited)") } } // ====================== Graceful vs Hard Finish Tests ====================== // TestFinish_GracefulVsHard verifies the behavior difference between: // - Finish(false): graceful finish, signals parentEnded but doesn't cancel children // - Finish(true): hard abort, immediately cancels all children func TestFinish_GracefulVsHard(t *testing.T) { // Test 1: Graceful finish should set parentEnded but not cancel context t.Run("Graceful_SetsParentEnded", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() ts := &turnState{ ctx: ctx, turnID: "graceful-test", depth: 0, pendingResults: make(chan *tools.ToolResult, 16), } ts.ctx, ts.cancelFunc = context.WithCancel(ctx) // Finish gracefully ts.Finish(false) // Verify parentEnded is set if !ts.parentEnded.Load() { t.Error("parentEnded should be true after graceful finish") } // Verify context is NOT cancelled (for graceful finish, children continue) // Note: In graceful mode, we don't call cancelFunc() // But since we're using WithCancel on the same ctx, it might be cancelled // Let's check that the context is still valid for a moment time.Sleep(10 * time.Millisecond) // Context might be cancelled by the deferred cancel() in test, which is fine }) // Test 2: Hard abort should cancel context immediately t.Run("Hard_CancelsContext", func(t *testing.T) { ctx := context.Background() ts := &turnState{ ctx: ctx, turnID: "hard-test", depth: 0, pendingResults: make(chan *tools.ToolResult, 16), } ts.ctx, ts.cancelFunc = context.WithCancel(ctx) // Finish with hard abort ts.Finish(true) // Verify context is cancelled select { case <-ts.ctx.Done(): t.Log("✓ Context cancelled after hard abort") default: t.Error("Context should be cancelled after hard abort") } }) // Test 3: IsParentEnded returns correct value t.Run("IsParentEnded", func(t *testing.T) { ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-isended-test", depth: 0, pendingResults: make(chan *tools.ToolResult, 16), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) childTS := &turnState{ ctx: ctx, turnID: "child-isended-test", depth: 1, parentTurnState: parentTS, pendingResults: make(chan *tools.ToolResult, 16), } // Before parent finishes if childTS.IsParentEnded() { t.Error("IsParentEnded should be false before parent finishes") } // Finish parent gracefully parentTS.Finish(false) // After parent finishes if !childTS.IsParentEnded() { t.Error("IsParentEnded should be true after parent finishes gracefully") } }) } // TestSubTurn_IndependentContext verifies that SubTurns use independent contexts // that don't get cancelled when the parent finishes gracefully. func TestSubTurn_IndependentContext(t *testing.T) { cfg := &config.Config{ Agents: config.AgentsConfig{ Defaults: config.AgentDefaults{ Provider: "mock", }, }, } msgBus := bus.NewMessageBus() provider := &slowMockProvider{delay: 500 * time.Millisecond} al := NewAgentLoop(cfg, msgBus, provider) ctx := context.Background() parentTS := &turnState{ ctx: ctx, turnID: "parent-independent", depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) var subTurnErr error var wg sync.WaitGroup // Spawn SubTurn with Critical=true (should continue after parent finishes) wg.Add(1) go func() { defer wg.Done() subTurnCfg := SubTurnConfig{ Model: "slow-model", Async: true, Critical: true, // Critical SubTurn should continue } _, subTurnErr = spawnSubTurn(parentTS.ctx, al, parentTS, subTurnCfg) }() // Let SubTurn start time.Sleep(50 * time.Millisecond) // Parent finishes gracefully (should NOT cancel SubTurn) parentTS.Finish(false) t.Log("Parent finished gracefully, SubTurn should continue") // Wait for SubTurn to complete wg.Wait() // SubTurn should complete without context cancelled error // (because it uses independent context now) if subTurnErr != nil { t.Logf("SubTurn error: %v", subTurnErr) // The error might be context.DeadlineExceeded if timeout is too short // but should NOT be context.Canceled from parent if errors.Is(subTurnErr, context.Canceled) { t.Error("SubTurn should not be cancelled by parent's graceful finish") } } else { t.Log("✓ SubTurn completed successfully (independent context)") } }