mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
fix(agent): implement Critical flag, complete tools.SubTurnConfig, remove redundant subTurnResults
- Critical flag was declared but never acted on; non-critical SubTurns now break out of the iteration loop when IsParentEnded() returns true
- tools.SubTurnConfig was missing Critical/Timeout/MaxContextRunes, making those fields unreachable from the tools layer; added fields and wired them through AgentLoopSpawner.SpawnSubTurn
- Removed subTurnResults sync.Map from AgentLoop — it was a redundant alias for the same channel already stored in turnState.pendingResults; dequeuePendingSubTurnResults now reads directly via activeTurnStates
- Replace hardcoded concurrencySem size 5 with maxConcurrentSubTurns constant
- Update affected tests to match new dequeuePendingSubTurnResults API
This commit is contained in:
+10
-11
@@ -49,7 +49,6 @@ type AgentLoop struct {
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
steering *steeringQueue
|
||||
subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult
|
||||
activeTurnStates sync.Map // key: sessionKey (string), value: *turnState
|
||||
subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs
|
||||
mu sync.RWMutex
|
||||
@@ -1001,7 +1000,7 @@ func (al *AgentLoop) runAgentLoop(
|
||||
session: agent.Sessions,
|
||||
initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns
|
||||
concurrencySem: make(chan struct{}, maxConcurrentSubTurns), // maxConcurrentSubTurns
|
||||
}
|
||||
ctx = withTurnState(ctx, rootTS)
|
||||
ctx = WithAgentLoop(ctx, al) // Inject AgentLoop for tool access
|
||||
@@ -1010,10 +1009,6 @@ func (al *AgentLoop) runAgentLoop(
|
||||
// Register this root turn state so HardAbort can find it
|
||||
al.activeTurnStates.Store(opts.SessionKey, rootTS)
|
||||
defer al.activeTurnStates.Delete(opts.SessionKey)
|
||||
|
||||
// Ensure the parent's pending results channel is cleaned up when this root turn finishes
|
||||
defer al.unregisterSubTurnResultChannel(rootTS.turnID)
|
||||
al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults)
|
||||
}
|
||||
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
|
||||
@@ -1220,15 +1215,19 @@ func (al *AgentLoop) runLLMIteration(
|
||||
// This is only relevant for SubTurns (turnState with parentTurnState != nil).
|
||||
// If parent ended and this SubTurn is not Critical, exit gracefully.
|
||||
if ts := turnStateFromContext(ctx); ts != nil && ts.IsParentEnded() {
|
||||
logger.InfoCF("agent", "Parent turn ended, SubTurn continues or exits", map[string]any{
|
||||
if !ts.critical {
|
||||
logger.InfoCF("agent", "Parent turn ended, non-critical SubTurn exiting gracefully", map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"turn_id": ts.turnID,
|
||||
})
|
||||
break
|
||||
}
|
||||
logger.InfoCF("agent", "Parent turn ended, critical SubTurn continues running", map[string]any{
|
||||
"agent_id": agent.ID,
|
||||
"iteration": iteration,
|
||||
"turn_id": ts.turnID,
|
||||
})
|
||||
// For now, we continue running. The Critical flag check is handled
|
||||
// at SubTurnConfig level in spawnSubTurn. Here we just log and continue.
|
||||
// If this SubTurn should exit gracefully, it would have been cancelled
|
||||
// by its own timeout or the caller would have handled it.
|
||||
}
|
||||
|
||||
// Inject pending steering messages into the conversation context
|
||||
|
||||
+4
-16
@@ -192,14 +192,13 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
|
||||
// dequeuePendingSubTurnResults polls the SubTurn result channel for the given
|
||||
// session and returns all available results without blocking.
|
||||
// Returns nil if no channel is registered for this session.
|
||||
// Returns nil if no active turn state exists for this session.
|
||||
func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.ToolResult {
|
||||
chInterface, ok := al.subTurnResults.Load(sessionKey)
|
||||
tsInterface, ok := al.activeTurnStates.Load(sessionKey)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, ok := chInterface.(chan *tools.ToolResult)
|
||||
ts, ok := tsInterface.(*turnState)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
@@ -207,7 +206,7 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To
|
||||
var results []*tools.ToolResult
|
||||
for {
|
||||
select {
|
||||
case result := <-ch:
|
||||
case result := <-ts.pendingResults:
|
||||
if result != nil {
|
||||
results = append(results, result)
|
||||
}
|
||||
@@ -217,17 +216,6 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To
|
||||
}
|
||||
}
|
||||
|
||||
// registerSubTurnResultChannel registers a SubTurn result channel for the given session.
|
||||
// This allows the parent loop to poll for results from child SubTurns.
|
||||
func (al *AgentLoop) registerSubTurnResultChannel(sessionKey string, ch chan *tools.ToolResult) {
|
||||
al.subTurnResults.Store(sessionKey, ch)
|
||||
}
|
||||
|
||||
// unregisterSubTurnResultChannel removes the SubTurn result channel for the given session.
|
||||
func (al *AgentLoop) unregisterSubTurnResultChannel(sessionKey string) {
|
||||
al.subTurnResults.Delete(sessionKey)
|
||||
}
|
||||
|
||||
// ====================== Hard Abort ======================
|
||||
|
||||
// HardAbort immediately cancels the running agent loop for the given session,
|
||||
|
||||
@@ -186,11 +186,14 @@ func (s *AgentLoopSpawner) SpawnSubTurn(ctx context.Context, cfg tools.SubTurnCo
|
||||
|
||||
// Convert tools.SubTurnConfig to agent.SubTurnConfig
|
||||
agentCfg := SubTurnConfig{
|
||||
Model: cfg.Model,
|
||||
Tools: cfg.Tools,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Async: cfg.Async,
|
||||
Model: cfg.Model,
|
||||
Tools: cfg.Tools,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
MaxTokens: cfg.MaxTokens,
|
||||
Async: cfg.Async,
|
||||
Critical: cfg.Critical,
|
||||
Timeout: cfg.Timeout,
|
||||
MaxContextRunes: cfg.MaxContextRunes,
|
||||
}
|
||||
|
||||
return spawnSubTurn(ctx, s.al, parentTS, agentCfg)
|
||||
@@ -277,6 +280,7 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S
|
||||
childTS := newTurnState(childCtx, childID, parentTS)
|
||||
// Set the cancel function so Finish(true) can trigger hard cancellation
|
||||
childTS.cancelFunc = cancel
|
||||
childTS.critical = cfg.Critical
|
||||
|
||||
// IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it
|
||||
childCtx = withTurnState(childCtx, childTS)
|
||||
|
||||
+32
-29
@@ -315,11 +315,6 @@ func TestSubTurnResultChannelRegistration(t *testing.T) {
|
||||
}
|
||||
|
||||
_, _ = spawnSubTurn(context.Background(), al, parent, cfg)
|
||||
|
||||
// After spawn completes: channel should be unregistered (defer cleanup in spawnSubTurn)
|
||||
if _, ok := al.subTurnResults.Load(parent.turnID); ok {
|
||||
t.Error("channel should be unregistered after spawnSubTurn completes")
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== Extra Independent Test: Dequeue Pending SubTurn Results ======================
|
||||
@@ -328,21 +323,27 @@ func TestDequeuePendingSubTurnResults(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
sessionKey := "test-session-dequeue"
|
||||
ch := make(chan *tools.ToolResult, 4)
|
||||
|
||||
// Register channel manually
|
||||
al.registerSubTurnResultChannel(sessionKey, ch)
|
||||
defer al.unregisterSubTurnResultChannel(sessionKey)
|
||||
|
||||
// Empty channel returns nil
|
||||
// Empty (no turnState registered) returns nil
|
||||
if results := al.dequeuePendingSubTurnResults(sessionKey); len(results) != 0 {
|
||||
t.Errorf("expected empty results, got %d", len(results))
|
||||
}
|
||||
|
||||
// Register a turnState so dequeuePendingSubTurnResults can find it
|
||||
ts := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: sessionKey,
|
||||
depth: 0,
|
||||
session: &ephemeralSessionStore{},
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
}
|
||||
al.activeTurnStates.Store(sessionKey, ts)
|
||||
defer al.activeTurnStates.Delete(sessionKey)
|
||||
|
||||
// Put 3 results in
|
||||
ch <- &tools.ToolResult{ForLLM: "result-1"}
|
||||
ch <- &tools.ToolResult{ForLLM: "result-2"}
|
||||
ch <- &tools.ToolResult{ForLLM: "result-3"}
|
||||
ts.pendingResults <- &tools.ToolResult{ForLLM: "result-1"}
|
||||
ts.pendingResults <- &tools.ToolResult{ForLLM: "result-2"}
|
||||
ts.pendingResults <- &tools.ToolResult{ForLLM: "result-3"}
|
||||
|
||||
results := al.dequeuePendingSubTurnResults(sessionKey)
|
||||
if len(results) != 3 {
|
||||
@@ -357,8 +358,8 @@ func TestDequeuePendingSubTurnResults(t *testing.T) {
|
||||
t.Errorf("expected empty after drain, got %d", len(results))
|
||||
}
|
||||
|
||||
// Unregistered session returns nil
|
||||
al.unregisterSubTurnResultChannel(sessionKey)
|
||||
// After removing from activeTurnStates, returns nil
|
||||
al.activeTurnStates.Delete(sessionKey)
|
||||
if results := al.dequeuePendingSubTurnResults(sessionKey); results != nil {
|
||||
t.Error("expected nil for unregistered session")
|
||||
}
|
||||
@@ -766,15 +767,21 @@ func TestFinalPollCapturesLateResults(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
sessionKey := "test-session-final-poll"
|
||||
ch := make(chan *tools.ToolResult, 4)
|
||||
|
||||
// Register the channel
|
||||
al.registerSubTurnResultChannel(sessionKey, ch)
|
||||
defer al.unregisterSubTurnResultChannel(sessionKey)
|
||||
// Register a turnState
|
||||
ts := &turnState{
|
||||
ctx: context.Background(),
|
||||
turnID: sessionKey,
|
||||
depth: 0,
|
||||
session: &ephemeralSessionStore{},
|
||||
pendingResults: make(chan *tools.ToolResult, 4),
|
||||
}
|
||||
al.activeTurnStates.Store(sessionKey, ts)
|
||||
defer al.activeTurnStates.Delete(sessionKey)
|
||||
|
||||
// Simulate results arriving after last iteration poll
|
||||
ch <- &tools.ToolResult{ForLLM: "result 1"}
|
||||
ch <- &tools.ToolResult{ForLLM: "result 2"}
|
||||
ts.pendingResults <- &tools.ToolResult{ForLLM: "result 1"}
|
||||
ts.pendingResults <- &tools.ToolResult{ForLLM: "result 2"}
|
||||
|
||||
// Dequeue should capture both results
|
||||
results := al.dequeuePendingSubTurnResults(sessionKey)
|
||||
@@ -1414,8 +1421,6 @@ func TestContextWrapping_SingleLayer(t *testing.T) {
|
||||
t.Log("Context wrapping test passed - no redundant layers detected")
|
||||
}
|
||||
|
||||
|
||||
|
||||
// TestSyncSubTurn_NoChannelDelivery verifies that synchronous sub-turns
|
||||
// do NOT deliver results to the pendingResults channel (only return directly).
|
||||
func TestSyncSubTurn_NoChannelDelivery(t *testing.T) {
|
||||
@@ -1526,8 +1531,6 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn
|
||||
// is hard aborted, the cancellation cascades down to grandchild turns.
|
||||
func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
|
||||
@@ -1949,9 +1952,9 @@ func TestFinish_GracefulVsHard(t *testing.T) {
|
||||
parentTS.ctx, parentTS.cancelFunc = context.WithCancel(ctx)
|
||||
|
||||
childTS := &turnState{
|
||||
ctx: ctx,
|
||||
turnID: "child-isended-test",
|
||||
depth: 1,
|
||||
ctx: ctx,
|
||||
turnID: "child-isended-test",
|
||||
depth: 1,
|
||||
parentTurnState: parentTS,
|
||||
pendingResults: make(chan *tools.ToolResult, 16),
|
||||
}
|
||||
|
||||
@@ -54,6 +54,10 @@ type turnState struct {
|
||||
// to continue running (Critical=true) or exit gracefully (Critical=false).
|
||||
parentEnded atomic.Bool
|
||||
|
||||
// critical indicates whether this SubTurn should continue running after
|
||||
// the parent turn finishes gracefully. Set from SubTurnConfig.Critical.
|
||||
critical bool
|
||||
|
||||
// parentTurnState holds a reference to the parent turnState.
|
||||
// This allows child SubTurns to check if the parent has ended.
|
||||
// Nil for root turns.
|
||||
|
||||
@@ -22,7 +22,10 @@ type SubTurnConfig struct {
|
||||
SystemPrompt string
|
||||
MaxTokens int
|
||||
Temperature float64
|
||||
Async bool // true for async (spawn), false for sync (subagent)
|
||||
Async bool // true for async (spawn), false for sync (subagent)
|
||||
Critical bool // continue running after parent finishes gracefully
|
||||
Timeout time.Duration // 0 = use default (5 minutes)
|
||||
MaxContextRunes int // 0 = auto, -1 = no limit, >0 = explicit limit
|
||||
}
|
||||
|
||||
type SubagentTask struct {
|
||||
|
||||
Reference in New Issue
Block a user