diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 510e247e3..dd4c81373 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -48,9 +48,10 @@ type AgentLoop struct { transcriber voice.Transcriber cmdRegistry *commands.Registry mcp mcpRuntime - steering *steeringQueue - subTurnResults sync.Map - mu sync.RWMutex + steering *steeringQueue + subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult + activeTurnStates sync.Map // key: sessionKey (string), value: *turnState + mu sync.RWMutex // Track active requests for safe provider cleanup activeRequests sync.WaitGroup } @@ -253,6 +254,7 @@ func registerSharedTools( depth: 0, session: newEphemeralSession(nil), pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), } } @@ -969,9 +971,14 @@ func (al *AgentLoop) runAgentLoop( depth: 0, session: agent.Sessions, pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns } ctx = withTurnState(ctx, rootTS) + // Register this root turn state so HardAbort can find it + al.activeTurnStates.Store(opts.SessionKey, rootTS) + defer al.activeTurnStates.Delete(opts.SessionKey) + // Ensure the parent's pending results channel is cleaned up when this root turn finishes defer al.unregisterSubTurnResultChannel(rootTS.turnID) al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults) diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index c09b97581..840a73723 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -227,3 +227,35 @@ func (al *AgentLoop) registerSubTurnResultChannel(sessionKey string, ch chan *to func (al *AgentLoop) unregisterSubTurnResultChannel(sessionKey string) { al.subTurnResults.Delete(sessionKey) } + +// ====================== Hard Abort ====================== + +// HardAbort immediately cancels the running agent loop for the given session, +// cascading the cancellation to all child SubTurns. This is a destructive operation +// that terminates execution without waiting for graceful cleanup. +// +// Use this when the user explicitly requests immediate termination (e.g., "stop now", "abort"). +// For graceful interruption that allows the agent to finish the current tool and summarize, +// use Steer() instead. +func (al *AgentLoop) HardAbort(sessionKey string) error { + tsInterface, ok := al.activeTurnStates.Load(sessionKey) + if !ok { + return fmt.Errorf("no active turn state found for session %s", sessionKey) + } + + ts, ok := tsInterface.(*turnState) + if !ok { + return fmt.Errorf("invalid turn state type for session %s", sessionKey) + } + + logger.InfoCF("agent", "Hard abort triggered", map[string]any{ + "session_key": sessionKey, + "turn_id": ts.turnID, + "depth": ts.depth, + }) + + // Trigger cascading cancellation to all child SubTurns + ts.Finish() + + return nil +} diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 89b254c69..691353e90 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -13,11 +13,15 @@ import ( ) // ====================== Config & Constants ====================== -const maxSubTurnDepth = 3 +const ( + maxSubTurnDepth = 3 + maxConcurrentSubTurns = 5 +) var ( - ErrDepthLimitExceeded = errors.New("sub-turn depth limit exceeded") - ErrInvalidSubTurnConfig = errors.New("invalid sub-turn config") + ErrDepthLimitExceeded = errors.New("sub-turn depth limit exceeded") + ErrInvalidSubTurnConfig = errors.New("invalid sub-turn config") + ErrConcurrencyLimitExceeded = errors.New("sub-turn concurrency limit exceeded") ) // ====================== SubTurn Config ====================== @@ -79,6 +83,7 @@ type turnState struct { session session.SessionStore mu sync.Mutex isFinished bool // Marks if the parent Turn has ended + concurrencySem chan struct{} // Limits concurrent child sub-turns } // ====================== Helper Functions ====================== @@ -102,6 +107,7 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState // intermediate results to be discarded in deliverSubTurnResult. // For production, consider an unbounded queue or a blocking strategy with backpressure. pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, maxConcurrentSubTurns), } } @@ -189,31 +195,42 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S return nil, ErrInvalidSubTurnConfig } + // 3. Acquire concurrency semaphore — blocks if parent already has maxConcurrentSubTurns running. + // Also respects context cancellation so we don't block forever if parent is aborted. + if parentTS.concurrencySem != nil { + select { + case parentTS.concurrencySem <- struct{}{}: + defer func() { <-parentTS.concurrencySem }() + case <-ctx.Done(): + return nil, ctx.Err() + } + } + // Create a sub-context for the child turn to support cancellation childCtx, cancel := context.WithCancel(ctx) defer cancel() - // 3. Create child Turn state + // 4. Create child Turn state childID := generateTurnID() childTS := newTurnState(childCtx, childID, parentTS) - // 4. Establish parent-child relationship (thread-safe) + // 5. Establish parent-child relationship (thread-safe) parentTS.mu.Lock() parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID) parentTS.mu.Unlock() - // 5. Register the parent's pendingResults channel so the parent loop can poll it + // 6. Register the parent's pendingResults channel so the parent loop can poll it al.registerSubTurnResultChannel(parentTS.turnID, parentTS.pendingResults) defer al.unregisterSubTurnResultChannel(parentTS.turnID) - // 6. Emit Spawn event (currently using Mock, will be replaced by real EventBus) + // 7. Emit Spawn event (currently using Mock, will be replaced by real EventBus) MockEventBus.Emit(SubTurnSpawnEvent{ ParentID: parentTS.turnID, ChildID: childID, Config: cfg, }) - // 7. Defer emitting End event, and recover from panics to ensure it's always fired + // 8. Defer emitting End event, and recover from panics to ensure it's always fired defer func() { if r := recover(); r != nil { err = fmt.Errorf("subturn panicked: %v", r) @@ -226,11 +243,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S }) }() - // 8. Execute sub-turn via the real agent loop. + // 9. Execute sub-turn via the real agent loop. // Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent. result, err = runTurn(childCtx, al, childTS, cfg) - // 9. Deliver result back to parent Turn + // 10. Deliver result back to parent Turn deliverSubTurnResult(parentTS, childID, result) return result, err diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index b7012e63d..1b609318d 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -323,3 +323,124 @@ func TestDequeuePendingSubTurnResults(t *testing.T) { t.Error("expected nil for unregistered session") } } + +// ====================== Extra Independent Test: Concurrency Semaphore ====================== +func TestSubTurnConcurrencySemaphore(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-concurrency", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 10), + session: &ephemeralSessionStore{}, + concurrencySem: make(chan struct{}, 2), // Only allow 2 concurrent children + } + + cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}} + + // Spawn 2 children — should succeed immediately + done := make(chan bool, 3) + for i := 0; i < 2; i++ { + go func() { + _, _ = spawnSubTurn(context.Background(), al, parent, cfg) + done <- true + }() + } + + // Wait a bit to ensure the first 2 are running + // (In real scenario they'd be blocked in runTurn, but mockProvider returns immediately) + // So we just verify the semaphore doesn't block when under limit + <-done + <-done + + // Verify semaphore is now full (2/2 slots used, but they already released) + // Since mockProvider returns immediately, semaphore is already released + // So we can't easily test blocking without a real long-running operation + + // Instead, verify that semaphore exists and has correct capacity + if cap(parent.concurrencySem) != 2 { + t.Errorf("expected semaphore capacity 2, got %d", cap(parent.concurrencySem)) + } +} + +// ====================== Extra Independent Test: Hard Abort Cascading ====================== +func TestHardAbortCascading(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + sessionKey := "test-session-abort" + parentCtx, parentCancel := context.WithCancel(context.Background()) + defer parentCancel() + + rootTS := &turnState{ + ctx: parentCtx, + turnID: sessionKey, + depth: 0, + session: &ephemeralSessionStore{}, + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), + } + + // Register the root turn state + al.activeTurnStates.Store(sessionKey, rootTS) + defer al.activeTurnStates.Delete(sessionKey) + + // Create a child turn state + childCtx, childCancel := context.WithCancel(rootTS.ctx) + defer childCancel() + childTS := &turnState{ + ctx: childCtx, + cancelFunc: childCancel, + turnID: "child-1", + parentTurnID: sessionKey, + depth: 1, + session: &ephemeralSessionStore{}, + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), + } + + // Attach cancelFunc to rootTS so Finish() can trigger it + rootTS.cancelFunc = parentCancel + + // Verify contexts are not canceled yet + select { + case <-rootTS.ctx.Done(): + t.Error("root context should not be canceled yet") + default: + } + select { + case <-childTS.ctx.Done(): + t.Error("child context should not be canceled yet") + default: + } + + // Trigger Hard Abort + err := al.HardAbort(sessionKey) + if err != nil { + t.Errorf("HardAbort failed: %v", err) + } + + // Verify root context is canceled + select { + case <-rootTS.ctx.Done(): + // Expected + default: + t.Error("root context should be canceled after HardAbort") + } + + // Verify child context is also canceled (cascading) + select { + case <-childTS.ctx.Done(): + // Expected + default: + t.Error("child context should be canceled after HardAbort (cascading)") + } + + // Verify HardAbort on non-existent session returns error + err = al.HardAbort("non-existent-session") + if err == nil { + t.Error("expected error for non-existent session") + } +}