feat(agent): add concurrency semaphore and hard abort for SubTurn

- Add maxConcurrentSubTurns constant (5) and concurrencySem channel to turnState
- Acquire/release semaphore in spawnSubTurn to limit concurrent child turns per parent
- Add activeTurnStates sync.Map to AgentLoop for tracking root turn states by session
- Implement HardAbort(sessionKey) method to trigger cascading cancellation via turnState.Finish()
- Register/unregister root turnState in runAgentLoop for hard abort lookup
- Add TestSubTurnConcurrencySemaphore to verify semaphore capacity enforcement
- Add TestHardAbortCascading to verify context cancellation propagates to child turns
This commit is contained in:
Administrator
2026-03-16 21:03:58 +08:00
parent ceeae15d8a
commit 1236dd9e6d
4 changed files with 190 additions and 13 deletions
+10 -3
View File
@@ -48,9 +48,10 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
steering *steeringQueue
subTurnResults sync.Map
mu sync.RWMutex
steering *steeringQueue
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
mu sync.RWMutex
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
}
@@ -253,6 +254,7 @@ func registerSharedTools(
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
}
}
@@ -969,9 +971,14 @@ func (al *AgentLoop) runAgentLoop(
depth: 0,
session: agent.Sessions,
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
}
ctx = withTurnState(ctx, rootTS)
// Register this root turn state so HardAbort can find it
al.activeTurnStates.Store(opts.SessionKey, rootTS)
defer al.activeTurnStates.Delete(opts.SessionKey)
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)