fix(agent): implement Critical flag, complete tools.SubTurnConfig, remove redundant subTurnResults

- Critical flag was declared but never acted on; non-critical SubTurns
  now break out of the runLLMIteration loop when IsParentEnded() returns
  true, while critical SubTurns log the event and keep running
- tools.SubTurnConfig was missing Critical/Timeout/MaxContextRunes,
  making those fields unreachable from the tools layer; added fields and
  wired them through AgentLoopSpawner.SpawnSubTurn
- Removed subTurnResults sync.Map from AgentLoop — it was a redundant
  alias for the same channel already stored in turnState.pendingResults;
  dequeuePendingSubTurnResults now reads directly via activeTurnStates
- Replaced the hardcoded concurrencySem size of 5 with the maxConcurrentSubTurns constant
- Updated affected tests to register a turnState in activeTurnStates, which
  dequeuePendingSubTurnResults now reads from (its signature is unchanged)
This commit is contained in:
Administrator
2026-03-18 18:22:06 +08:00
parent 777230dcd1
commit 3611034795
6 changed files with 63 additions and 62 deletions
+10 -11
View File
@@ -49,7 +49,6 @@ type AgentLoop struct {
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs
mu sync.RWMutex
@@ -1001,7 +1000,7 @@ func (al *AgentLoop) runAgentLoop(
session: agent.Sessions,
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
concurrencySem: make(chan struct{}, maxConcurrentSubTurns), // maxConcurrentSubTurns
}
ctx = withTurnState(ctx, rootTS)
ctx = WithAgentLoop(ctx, al) // Inject AgentLoop for tool access
@@ -1010,10 +1009,6 @@ func (al *AgentLoop) runAgentLoop(
// Register this root turn state so HardAbort can find it
al.activeTurnStates.Store(opts.SessionKey, rootTS)
defer al.activeTurnStates.Delete(opts.SessionKey)
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)
}
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
@@ -1220,15 +1215,19 @@ func (al *AgentLoop) runLLMIteration(
// This is only relevant for SubTurns (turnState with parentTurnState != nil).
// If parent ended and this SubTurn is not Critical, exit gracefully.
if ts := turnStateFromContext(ctx); ts != nil && ts.IsParentEnded() {
logger.InfoCF("agent", "Parent turn ended, SubTurn continues or exits", map[string]any{
if !ts.critical {
logger.InfoCF("agent", "Parent turn ended, non-critical SubTurn exiting gracefully", map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"turn_id": ts.turnID,
})
break
}
logger.InfoCF("agent", "Parent turn ended, critical SubTurn continues running", map[string]any{
"agent_id": agent.ID,
"iteration": iteration,
"turn_id": ts.turnID,
})
// For now, we continue running. The Critical flag check is handled
// at SubTurnConfig level in spawnSubTurn. Here we just log and continue.
// If this SubTurn should exit gracefully, it would have been cancelled
// by its own timeout or the caller would have handled it.
}
// Inject pending steering messages into the conversation context
+4 -16
View File
@@ -192,14 +192,13 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
// dequeuePendingSubTurnResults polls the SubTurn result channel for the given
// session and returns all available results without blocking.
// Returns nil if no channel is registered for this session.
// Returns nil if no active turn state exists for this session.
func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.ToolResult {
chInterface, ok := al.subTurnResults.Load(sessionKey)
tsInterface, ok := al.activeTurnStates.Load(sessionKey)
if !ok {
return nil
}
ch, ok := chInterface.(chan *tools.ToolResult)
ts, ok := tsInterface.(*turnState)
if !ok {
return nil
}
@@ -207,7 +206,7 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To
var results []*tools.ToolResult
for {
select {
case result := <-ch:
case result := <-ts.pendingResults:
if result != nil {
results = append(results, result)
}
@@ -217,17 +216,6 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To
}
}
// registerSubTurnResultChannel registers a SubTurn result channel for the given session.
// This allows the parent loop to poll for results from child SubTurns.
func (al *AgentLoop) registerSubTurnResultChannel(sessionKey string, ch chan *tools.ToolResult) {
al.subTurnResults.Store(sessionKey, ch)
}
// unregisterSubTurnResultChannel removes the SubTurn result channel for the given session.
func (al *AgentLoop) unregisterSubTurnResultChannel(sessionKey string) {
al.subTurnResults.Delete(sessionKey)
}
// ====================== Hard Abort ======================
// HardAbort immediately cancels the running agent loop for the given session,
+9 -5
View File
@@ -186,11 +186,14 @@ func (s *AgentLoopSpawner) SpawnSubTurn(ctx context.Context, cfg tools.SubTurnCo
// Convert tools.SubTurnConfig to agent.SubTurnConfig
agentCfg := SubTurnConfig{
Model: cfg.Model,
Tools: cfg.Tools,
SystemPrompt: cfg.SystemPrompt,
MaxTokens: cfg.MaxTokens,
Async: cfg.Async,
Model: cfg.Model,
Tools: cfg.Tools,
SystemPrompt: cfg.SystemPrompt,
MaxTokens: cfg.MaxTokens,
Async: cfg.Async,
Critical: cfg.Critical,
Timeout: cfg.Timeout,
MaxContextRunes: cfg.MaxContextRunes,
}
return spawnSubTurn(ctx, s.al, parentTS, agentCfg)
@@ -277,6 +280,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
childTS := newTurnState(childCtx, childID, parentTS)
// Set the cancel function so Finish(true) can trigger hard cancellation
childTS.cancelFunc = cancel
childTS.critical = cfg.Critical
// IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it
childCtx = withTurnState(childCtx, childTS)
+32 -29
View File
@@ -315,11 +315,6 @@ func TestSubTurnResultChannelRegistration(t *testing.T) {
}
_, _ = spawnSubTurn(context.Background(), al, parent, cfg)
// After spawn completes: channel should be unregistered (defer cleanup in spawnSubTurn)
if _, ok := al.subTurnResults.Load(parent.turnID); ok {
t.Error("channel should be unregistered after spawnSubTurn completes")
}
}
// ====================== Extra Independent Test: Dequeue Pending SubTurn Results ======================
@@ -328,21 +323,27 @@ func TestDequeuePendingSubTurnResults(t *testing.T) {
defer cleanup()
sessionKey := "test-session-dequeue"
ch := make(chan *tools.ToolResult, 4)
// Register channel manually
al.registerSubTurnResultChannel(sessionKey, ch)
defer al.unregisterSubTurnResultChannel(sessionKey)
// Empty channel returns nil
// Empty (no turnState registered) returns nil
if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 {
t.Errorf("expected empty results, got %d", len(results))
}
// Register a turnState so dequeuePendingSubTurnResults can find it
ts := &turnState{
ctx: context.Background(),
turnID: sessionKey,
depth: 0,
session: &ephemeralSessionStore{},
pendingResults: make(chan *tools.ToolResult, 4),
}
al.activeTurnStates.Store(sessionKey, ts)
defer al.activeTurnStates.Delete(sessionKey)
// Put 3 results in
ch <- &tools.ToolResult{ForLLM: "result-1"}
ch <- &tools.ToolResult{ForLLM: "result-2"}
ch <- &tools.ToolResult{ForLLM: "result-3"}
ts.pendingResults <- &tools.ToolResult{ForLLM: "result-1"}
ts.pendingResults <- &tools.ToolResult{ForLLM: "result-2"}
ts.pendingResults <- &tools.ToolResult{ForLLM: "result-3"}
results := al.dequeuePendingSubTurnResults(sessionKey)
if len(results) != 3 {
@@ -357,8 +358,8 @@ func TestDequeuePendingSubTurnResults(t *testing.T) {
t.Errorf("expected empty after drain, got %d", len(results))
}
// Unregistered session returns nil
al.unregisterSubTurnResultChannel(sessionKey)
// After removing from activeTurnStates, returns nil
al.activeTurnStates.Delete(sessionKey)
if results := al.dequeuePendingSubTurnResults(sessionKey); results != nil {
t.Error("expected nil for unregistered session")
}
@@ -766,15 +767,21 @@ func TestFinalPollCapturesLateResults(t *testing.T) {
defer cleanup()
sessionKey := "test-session-final-poll"
ch := make(chan *tools.ToolResult, 4)
// Register the channel
al.registerSubTurnResultChannel(sessionKey, ch)
defer al.unregisterSubTurnResultChannel(sessionKey)
// Register a turnState
ts := &turnState{
ctx: context.Background(),
turnID: sessionKey,
depth: 0,
session: &ephemeralSessionStore{},
pendingResults: make(chan *tools.ToolResult, 4),
}
al.activeTurnStates.Store(sessionKey, ts)
defer al.activeTurnStates.Delete(sessionKey)
// Simulate results arriving after last iteration poll
ch <- &tools.ToolResult{ForLLM: "result 1"}
ch <- &tools.ToolResult{ForLLM: "result 2"}
ts.pendingResults <- &tools.ToolResult{ForLLM: "result 1"}
ts.pendingResults <- &tools.ToolResult{ForLLM: "result 2"}
// Dequeue should capture both results
results := al.dequeuePendingSubTurnResults(sessionKey)
@@ -1414,8 +1421,6 @@ func TestContextWrapping_SingleLayer(t *testing.T) {
t.Log("Context wrapping test passed - no redundant layers detected")
}
// TestSyncSubTurn_NoChannelDelivery verifies that synchronous sub-turns
// do NOT deliver results to the pendingResults channel (only return directly).
func TestSyncSubTurn_NoChannelDelivery(t *testing.T) {
@@ -1526,8 +1531,6 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
}
}
// TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn
// is hard aborted, the cancellation cascades down to grandchild turns.
func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
@@ -1949,9 +1952,9 @@ func TestFinish_GracefulVsHard(t *testing.T) {
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
childTS := &turnState{
ctx: ctx,
turnID: "child-isended-test",
depth: 1,
ctx: ctx,
turnID: "child-isended-test",
depth: 1,
parentTurnState: parentTS,
pendingResults: make(chan *tools.ToolResult, 16),
}
+4
View File
@@ -54,6 +54,10 @@ type turnState struct {
// to continue running (Critical=true) or exit gracefully (Critical=false).
parentEnded atomic.Bool
// critical indicates whether this SubTurn should continue running after
// the parent turn finishes gracefully. Set from SubTurnConfig.Critical.
critical bool
// parentTurnState holds a reference to the parent turnState.
// This allows child SubTurns to check if the parent has ended.
// Nil for root turns.
+4 -1
View File
@@ -22,7 +22,10 @@ type SubTurnConfig struct {
SystemPrompt string
MaxTokens int
Temperature float64
Async bool // true for async (spawn), false for sync (subagent)
Async bool // true for async (spawn), false for sync (subagent)
Critical bool // continue running after parent finishes gracefully
Timeout time.Duration // 0 = use default (5 minutes)
MaxContextRunes int // 0 = auto, -1 = no limit, >0 = explicit limit
}
type SubagentTask struct {