diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 02253b753..04e726b84 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -49,7 +49,6 @@ type AgentLoop struct { cmdRegistry *commands.Registry mcp mcpRuntime steering *steeringQueue - subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult activeTurnStates sync.Map // key: sessionKey (string), value: *turnState subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs mu sync.RWMutex @@ -1001,7 +1000,7 @@ func (al *AgentLoop) runAgentLoop( session: agent.Sessions, initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort pendingResults: make(chan *tools.ToolResult, 16), - concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns + concurrencySem: make(chan struct{}, maxConcurrentSubTurns), // maxConcurrentSubTurns } ctx = withTurnState(ctx, rootTS) ctx = WithAgentLoop(ctx, al) // Inject AgentLoop for tool access @@ -1010,10 +1009,6 @@ func (al *AgentLoop) runAgentLoop( // Register this root turn state so HardAbort can find it al.activeTurnStates.Store(opts.SessionKey, rootTS) defer al.activeTurnStates.Delete(opts.SessionKey) - - // Ensure the parent's pending results channel is cleaned up when this root turn finishes - defer al.unregisterSubTurnResultChannel(rootTS.turnID) - al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults) } // 0. Record last channel for heartbeat notifications (skip internal channels and cli) @@ -1220,15 +1215,19 @@ func (al *AgentLoop) runLLMIteration( // This is only relevant for SubTurns (turnState with parentTurnState != nil). // If parent ended and this SubTurn is not Critical, exit gracefully. if ts := turnStateFromContext(ctx); ts != nil && ts.IsParentEnded() { - logger.InfoCF("agent", "Parent turn ended, SubTurn continues or exits", map[string]any{ + if !ts.critical { + logger.InfoCF("agent", "Parent turn ended, non-critical SubTurn exiting gracefully", map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "turn_id": ts.turnID, + }) + break + } + logger.InfoCF("agent", "Parent turn ended, critical SubTurn continues running", map[string]any{ "agent_id": agent.ID, "iteration": iteration, "turn_id": ts.turnID, }) - // For now, we continue running. The Critical flag check is handled - // at SubTurnConfig level in spawnSubTurn. Here we just log and continue. - // If this SubTurn should exit gracefully, it would have been cancelled - // by its own timeout or the caller would have handled it. } // Inject pending steering messages into the conversation context diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 401db7cc7..0cbde2c2e 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -192,14 +192,13 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s // dequeuePendingSubTurnResults polls the SubTurn result channel for the given // session and returns all available results without blocking. -// Returns nil if no channel is registered for this session. +// Returns nil if no active turn state exists for this session. func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.ToolResult { - chInterface, ok := al.subTurnResults.Load(sessionKey) + tsInterface, ok := al.activeTurnStates.Load(sessionKey) if !ok { return nil } - - ch, ok := chInterface.(chan *tools.ToolResult) + ts, ok := tsInterface.(*turnState) if !ok { return nil } @@ -207,7 +206,7 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To var results []*tools.ToolResult for { select { - case result := <-ch: + case result := <-ts.pendingResults: if result != nil { results = append(results, result) } @@ -217,17 +216,6 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To } } -// registerSubTurnResultChannel registers a SubTurn result channel for the given session. -// This allows the parent loop to poll for results from child SubTurns. -func (al *AgentLoop) registerSubTurnResultChannel(sessionKey string, ch chan *tools.ToolResult) { - al.subTurnResults.Store(sessionKey, ch) -} - -// unregisterSubTurnResultChannel removes the SubTurn result channel for the given session. -func (al *AgentLoop) unregisterSubTurnResultChannel(sessionKey string) { - al.subTurnResults.Delete(sessionKey) -} - // ====================== Hard Abort ====================== // HardAbort immediately cancels the running agent loop for the given session, diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index b3fe71518..b981da399 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -186,11 +186,14 @@ func (s *AgentLoopSpawner) SpawnSubTurn(ctx context.Context, cfg tools.SubTurnCo // Convert tools.SubTurnConfig to agent.SubTurnConfig agentCfg := SubTurnConfig{ - Model: cfg.Model, - Tools: cfg.Tools, - SystemPrompt: cfg.SystemPrompt, - MaxTokens: cfg.MaxTokens, - Async: cfg.Async, + Model: cfg.Model, + Tools: cfg.Tools, + SystemPrompt: cfg.SystemPrompt, + MaxTokens: cfg.MaxTokens, + Async: cfg.Async, + Critical: cfg.Critical, + Timeout: cfg.Timeout, + MaxContextRunes: cfg.MaxContextRunes, } return spawnSubTurn(ctx, s.al, parentTS, agentCfg) @@ -277,6 +280,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S childTS := newTurnState(childCtx, childID, parentTS) // Set the cancel function so Finish(true) can trigger hard cancellation childTS.cancelFunc = cancel + childTS.critical = cfg.Critical // IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it childCtx = withTurnState(childCtx, childTS) diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index 8e7b3f533..883958231 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -315,11 +315,6 @@ func TestSubTurnResultChannelRegistration(t *testing.T) { } _, _ = spawnSubTurn(context.Background(), al, parent, cfg) - - // After spawn completes: channel should be unregistered (defer cleanup in spawnSubTurn) - if _, ok := al.subTurnResults.Load(parent.turnID); ok { - t.Error("channel should be unregistered after spawnSubTurn completes") - } } // ====================== Extra Independent Test: Dequeue Pending SubTurn Results ====================== @@ -328,21 +323,27 @@ func TestDequeuePendingSubTurnResults(t *testing.T) { defer cleanup() sessionKey := "test-session-dequeue" - ch := make(chan *tools.ToolResult, 4) - // Register channel manually - al.registerSubTurnResultChannel(sessionKey, ch) - defer al.unregisterSubTurnResultChannel(sessionKey) - - // Empty channel returns nil + // Empty (no turnState registered) returns nil if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 { t.Errorf("expected empty results, got %d", len(results)) } + // Register a turnState so dequeuePendingSubTurnResults can find it + ts := &turnState{ + ctx: context.Background(), + turnID: sessionKey, + depth: 0, + session: &ephemeralSessionStore{}, + pendingResults: make(chan *tools.ToolResult, 4), + } + al.activeTurnStates.Store(sessionKey, ts) + defer al.activeTurnStates.Delete(sessionKey) + // Put 3 results in - ch <- &tools.ToolResult{ForLLM: "result-1"} - ch <- &tools.ToolResult{ForLLM: "result-2"} - ch <- &tools.ToolResult{ForLLM: "result-3"} + ts.pendingResults <- &tools.ToolResult{ForLLM: "result-1"} + ts.pendingResults <- &tools.ToolResult{ForLLM: "result-2"} + ts.pendingResults <- &tools.ToolResult{ForLLM: "result-3"} results := al.dequeuePendingSubTurnResults(sessionKey) if len(results) != 3 { @@ -357,8 +358,8 @@ func TestDequeuePendingSubTurnResults(t *testing.T) { t.Errorf("expected empty after drain, got %d", len(results)) } - // Unregistered session returns nil - al.unregisterSubTurnResultChannel(sessionKey) + // After removing from activeTurnStates, returns nil + al.activeTurnStates.Delete(sessionKey) if results := al.dequeuePendingSubTurnResults(sessionKey); results != nil { t.Error("expected nil for unregistered session") } @@ -766,15 +767,21 @@ func TestFinalPollCapturesLateResults(t *testing.T) { defer cleanup() sessionKey := "test-session-final-poll" - ch := make(chan *tools.ToolResult, 4) - // Register the channel - al.registerSubTurnResultChannel(sessionKey, ch) - defer al.unregisterSubTurnResultChannel(sessionKey) + // Register a turnState + ts := &turnState{ + ctx: context.Background(), + turnID: sessionKey, + depth: 0, + session: &ephemeralSessionStore{}, + pendingResults: make(chan *tools.ToolResult, 4), + } + al.activeTurnStates.Store(sessionKey, ts) + defer al.activeTurnStates.Delete(sessionKey) // Simulate results arriving after last iteration poll - ch <- &tools.ToolResult{ForLLM: "result 1"} - ch <- &tools.ToolResult{ForLLM: "result 2"} + ts.pendingResults <- &tools.ToolResult{ForLLM: "result 1"} + ts.pendingResults <- &tools.ToolResult{ForLLM: "result 2"} // Dequeue should capture both results results := al.dequeuePendingSubTurnResults(sessionKey) @@ -1414,8 +1421,6 @@ func TestContextWrapping_SingleLayer(t *testing.T) { t.Log("Context wrapping test passed - no redundant layers detected") } - - // TestSyncSubTurn_NoChannelDelivery verifies that synchronous sub-turns // do NOT deliver results to the pendingResults channel (only return directly). func TestSyncSubTurn_NoChannelDelivery(t *testing.T) { @@ -1526,8 +1531,6 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) { } } - - // TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn // is hard aborted, the cancellation cascades down to grandchild turns. func TestGrandchildAbort_CascadingCancellation(t *testing.T) { @@ -1949,9 +1952,9 @@ func TestFinish_GracefulVsHard(t *testing.T) { parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx) childTS := &turnState{ - ctx: ctx, - turnID: "child-isended-test", - depth: 1, + ctx: ctx, + turnID: "child-isended-test", + depth: 1, parentTurnState: parentTS, pendingResults: make(chan *tools.ToolResult, 16), } diff --git a/pkg/agent/turn_state.go b/pkg/agent/turn_state.go index ff2bf0d68..d5c98ff7f 100644 --- a/pkg/agent/turn_state.go +++ b/pkg/agent/turn_state.go @@ -54,6 +54,10 @@ type turnState struct { // to continue running (Critical=true) or exit gracefully (Critical=false). parentEnded atomic.Bool + // critical indicates whether this SubTurn should continue running after + // the parent turn finishes gracefully. Set from SubTurnConfig.Critical. + critical bool + // parentTurnState holds a reference to the parent turnState. // This allows child SubTurns to check if the parent has ended. // Nil for root turns. diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 288c5065e..d41cf9a6d 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -22,7 +22,10 @@ type SubTurnConfig struct { SystemPrompt string MaxTokens int Temperature float64 - Async bool // true for async (spawn), false for sync (subagent) + Async bool // true for async (spawn), false for sync (subagent) + Critical bool // continue running after parent finishes gracefully + Timeout time.Duration // 0 = use default (5 minutes) + MaxContextRunes int // 0 = auto, -1 = no limit, >0 = explicit limit } type SubagentTask struct {