diff --git a/.gitignore b/.gitignore index 61fe494ca..74245a906 100644 --- a/.gitignore +++ b/.gitignore @@ -56,3 +56,5 @@ dist/ !web/backend/dist/ web/backend/dist/* !web/backend/dist/.gitkeep + +.claude/ \ No newline at end of file diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 3324d56cc..b9fa1023a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -36,21 +36,22 @@ import ( ) type AgentLoop struct { - bus *bus.MessageBus - cfg *config.Config - registry *AgentRegistry - state *state.Manager - running atomic.Bool - summarizing sync.Map - fallback *providers.FallbackChain - channelManager *channels.Manager - mediaStore media.MediaStore - transcriber voice.Transcriber - cmdRegistry *commands.Registry - mcp mcpRuntime + bus *bus.MessageBus + cfg *config.Config + registry *AgentRegistry + state *state.Manager + running atomic.Bool + summarizing sync.Map + fallback *providers.FallbackChain + channelManager *channels.Manager + mediaStore media.MediaStore + transcriber voice.Transcriber + cmdRegistry *commands.Registry + mcp mcpRuntime steering *steeringQueue subTurnResults sync.Map // key: sessionKey (string), value: chan *tools.ToolResult activeTurnStates sync.Map // key: sessionKey (string), value: *turnState + subTurnCounter atomic.Int64 // Counter for generating unique SubTurn IDs mu sync.RWMutex // Track active requests for safe provider cleanup activeRequests sync.WaitGroup @@ -964,25 +965,39 @@ func (al *AgentLoop) runAgentLoop( agent *AgentInstance, opts processOptions, ) (string, error) { - // Initialize a root TurnState for this iteration, allowing sub-turns to be spawned. - rootTS := &turnState{ - ctx: ctx, - turnID: opts.SessionKey, // Associate this turn graph with the current session key - depth: 0, - session: agent.Sessions, - initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort - pendingResults: make(chan *tools.ToolResult, 16), - concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns + // Check if we're already inside a SubTurn (context already has a turnState). + // If so, reuse it instead of creating a new root turnState. + // This prevents turnState hierarchy corruption when SubTurns recursively call runAgentLoop. + existingTS := turnStateFromContext(ctx) + var rootTS *turnState + var isRootTurn bool + + if existingTS != nil { + // We're inside a SubTurn — reuse the existing turnState + rootTS = existingTS + isRootTurn = false + } else { + // This is a top-level turn — initialize a new root TurnState + rootTS = &turnState{ + ctx: ctx, + turnID: opts.SessionKey, // Associate this turn graph with the current session key + depth: 0, + session: agent.Sessions, + initialHistoryLength: len(agent.Sessions.GetHistory("")), // Snapshot for rollback on hard abort + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), // maxConcurrentSubTurns + } + ctx = withTurnState(ctx, rootTS) + isRootTurn = true + + // Register this root turn state so HardAbort can find it + al.activeTurnStates.Store(opts.SessionKey, rootTS) + defer al.activeTurnStates.Delete(opts.SessionKey) + + // Ensure the parent's pending results channel is cleaned up when this root turn finishes + defer al.unregisterSubTurnResultChannel(rootTS.turnID) + al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults) } - ctx = withTurnState(ctx, rootTS) - - // Register this root turn state so HardAbort can find it - al.activeTurnStates.Store(opts.SessionKey, rootTS) - defer al.activeTurnStates.Delete(opts.SessionKey) - - // Ensure the parent's pending results channel is cleaned up when this root turn finishes - defer al.unregisterSubTurnResultChannel(rootTS.turnID) - al.registerSubTurnResultChannel(rootTS.turnID, rootTS.pendingResults) // 0. Record last channel for heartbeat notifications (skip internal channels and cli) if opts.Channel != "" && opts.ChatID != "" { @@ -1028,8 +1043,11 @@ func (al *AgentLoop) runAgentLoop( return "", err } - // Signal completion to rootTS so it knows it is finished, terminating any active sub-turns - rootTS.Finish() + // Signal completion to rootTS so it knows it is finished, terminating any active sub-turns. + // Only call Finish() if this is a root turn (not a SubTurn recursively calling runAgentLoop). + if isRootTurn { + rootTS.Finish() + } // If last tool had ForUser content and we already sent it, we might not need to send final response // This is controlled by the tool's Silent flag and ForUser content diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index e67a779a3..97461428d 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -255,7 +255,13 @@ func (al *AgentLoop) HardAbort(sessionKey string) error { "initial_history_length": ts.initialHistoryLength, }) - // Rollback session history to the state before this turn started + // IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns + // from adding more messages to the session. This prevents race conditions + // where rollback happens while children are still writing. + ts.Finish() + + // Rollback session history to the state before this turn started. + // This must happen AFTER Finish() to ensure no child turns are still writing. if ts.session != nil { currentHistory := ts.session.GetHistory("") if len(currentHistory) > ts.initialHistoryLength { @@ -268,8 +274,5 @@ func (al *AgentLoop) HardAbort(sessionKey string) error { } } - // Trigger cascading cancellation to all child SubTurns - ts.Finish() - return nil } diff --git a/pkg/agent/subturn.go b/pkg/agent/subturn.go index 0135dfc76..1d0239c4b 100644 --- a/pkg/agent/subturn.go +++ b/pkg/agent/subturn.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "sync" - "sync/atomic" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" @@ -14,8 +13,8 @@ import ( // ====================== Config & Constants ====================== const ( - maxSubTurnDepth = 3 - maxConcurrentSubTurns = 5 + maxSubTurnDepth = 3 + maxConcurrentSubTurns = 5 ) var ( @@ -78,20 +77,19 @@ type turnState struct { turnID string parentTurnID string depth int - childTurnIDs []string + childTurnIDs []string // MUST be accessed under mu lock or maybe add a getter method pendingResults chan *tools.ToolResult session session.SessionStore - initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort + initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort mu sync.Mutex - isFinished bool // Marks if the parent Turn has ended + isFinished bool // MUST be accessed under mu lock concurrencySem chan struct{} // Limits concurrent child sub-turns } // ====================== Helper Functions ====================== -var globalTurnCounter int64 -func generateTurnID() string { - return fmt.Sprintf("subturn-%d", atomic.AddInt64(&globalTurnCounter, 1)) +func (al *AgentLoop) generateSubTurnID() string { + return fmt.Sprintf("subturn-%d", al.subTurnCounter.Add(1)) } func newTurnState(ctx context.Context, id string, parent *turnState) *turnState { @@ -113,13 +111,27 @@ func newTurnState(ctx context.Context, id string, parent *turnState) *turnState } // Finish marks the turn as finished and cancels its context, aborting any running sub-turns. +// It also closes the pendingResults channel to signal that no more results will be delivered. func (ts *turnState) Finish() { ts.mu.Lock() defer ts.mu.Unlock() + + if ts.isFinished { + // Already finished - avoid double close of channel + return + } + ts.isFinished = true + if ts.cancelFunc != nil { ts.cancelFunc() } + + // Close the pendingResults channel to signal no more results will arrive. + // This prevents goroutine leaks from readers waiting on the channel. + if ts.pendingResults != nil { + close(ts.pendingResults) + } } // ephemeralSessionStore is a pure in-memory SessionStore for SubTurns. @@ -186,6 +198,24 @@ func newEphemeralSession(_ session.SessionStore) session.SessionStore { // ====================== Core Function: spawnSubTurn ====================== func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg SubTurnConfig) (result *tools.ToolResult, err error) { + // 0. Acquire concurrency semaphore FIRST to ensure it's released even if early validation fails. + // Blocks if parent already has maxConcurrentSubTurns running. + // Also respects context cancellation so we don't block forever if parent is aborted. + var semAcquired bool + if parentTS.concurrencySem != nil { + select { + case parentTS.concurrencySem <- struct{}{}: + semAcquired = true + defer func() { + if semAcquired { + <-parentTS.concurrencySem + } + }() + case <-ctx.Done(): + return nil, ctx.Err() + } + } + // 1. Depth limit check if parentTS.depth >= maxSubTurnDepth { return nil, ErrDepthLimitExceeded @@ -196,42 +226,31 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S return nil, ErrInvalidSubTurnConfig } - // 3. Acquire concurrency semaphore — blocks if parent already has maxConcurrentSubTurns running. - // Also respects context cancellation so we don't block forever if parent is aborted. - if parentTS.concurrencySem != nil { - select { - case parentTS.concurrencySem <- struct{}{}: - defer func() { <-parentTS.concurrencySem }() - case <-ctx.Done(): - return nil, ctx.Err() - } - } - // Create a sub-context for the child turn to support cancellation childCtx, cancel := context.WithCancel(ctx) defer cancel() - // 4. Create child Turn state - childID := generateTurnID() + // 3. Create child Turn state + childID := al.generateSubTurnID() childTS := newTurnState(childCtx, childID, parentTS) - // 5. Establish parent-child relationship (thread-safe) + // 4. Establish parent-child relationship (thread-safe) parentTS.mu.Lock() parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID) parentTS.mu.Unlock() - // 6. Register the parent's pendingResults channel so the parent loop can poll it + // 5. Register the parent's pendingResults channel so the parent loop can poll it al.registerSubTurnResultChannel(parentTS.turnID, parentTS.pendingResults) defer al.unregisterSubTurnResultChannel(parentTS.turnID) - // 7. Emit Spawn event (currently using Mock, will be replaced by real EventBus) + // 6. Emit Spawn event (currently using Mock, will be replaced by real EventBus) MockEventBus.Emit(SubTurnSpawnEvent{ ParentID: parentTS.turnID, ChildID: childID, Config: cfg, }) - // 8. Defer emitting End event, and recover from panics to ensure it's always fired + // 7. Defer emitting End event, and recover from panics to ensure it's always fired defer func() { if r := recover(); r != nil { err = fmt.Errorf("subturn panicked: %v", r) @@ -244,11 +263,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S }) }() - // 9. Execute sub-turn via the real agent loop. + // 8. Execute sub-turn via the real agent loop. // Build a child AgentInstance from SubTurnConfig, inheriting defaults from the parent agent. result, err = runTurn(childCtx, al, childTS, cfg) - // 10. Deliver result back to parent Turn + // 9. Deliver result back to parent Turn deliverSubTurnResult(parentTS, childID, result) return result, err @@ -256,8 +275,11 @@ func spawnSubTurn(ctx context.Context, al *AgentLoop, parentTS *turnState, cfg S // ====================== Result Delivery ====================== func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) { + // Check parent state under lock, but don't hold lock while sending to channel parentTS.mu.Lock() - defer parentTS.mu.Unlock() + isFinished := parentTS.isFinished + resultChan := parentTS.pendingResults + parentTS.mu.Unlock() // Emit ResultDelivered event MockEventBus.Emit(SubTurnResultDeliveredEvent{ @@ -266,10 +288,24 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too Result: result, }) - if !parentTS.isFinished { + if !isFinished && resultChan != nil { // Parent Turn is still running → Place in pending queue (handled automatically by parent loop in next round) + // Use defer/recover to handle the case where the channel is closed between our check and the send. + defer func() { + if r := recover(); r != nil { + // Channel was closed - treat as orphan result + if result != nil { + MockEventBus.Emit(SubTurnOrphanResultEvent{ + ParentID: parentTS.turnID, + ChildID: childID, + Result: result, + }) + } + } + }() + select { - case parentTS.pendingResults <- result: + case resultChan <- result: default: fmt.Println("[SubTurn] warning: pendingResults channel full") } diff --git a/pkg/agent/subturn_test.go b/pkg/agent/subturn_test.go index 5b99ebf9f..ac085c28a 100644 --- a/pkg/agent/subturn_test.go +++ b/pkg/agent/subturn_test.go @@ -2,8 +2,11 @@ package agent import ( "context" + "fmt" "reflect" + "sync" "testing" + "time" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" @@ -500,3 +503,221 @@ func TestHardAbortSessionRollback(t *testing.T) { t.Error("history content does not match initial state after rollback") } } + +// TestNestedSubTurnHierarchy verifies that nested SubTurns maintain correct +// parent-child relationships and depth tracking when recursively calling runAgentLoop. +func TestNestedSubTurnHierarchy(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + // Track spawned turns and their depths + type turnInfo struct { + parentID string + childID string + depth int + } + var spawnedTurns []turnInfo + var mu sync.Mutex + + // Override MockEventBus to capture spawn events + originalEmit := MockEventBus.Emit + defer func() { MockEventBus.Emit = originalEmit }() + + MockEventBus.Emit = func(event any) { + if spawnEvent, ok := event.(SubTurnSpawnEvent); ok { + mu.Lock() + // Extract depth from context (we'll verify this matches expected depth) + spawnedTurns = append(spawnedTurns, turnInfo{ + parentID: spawnEvent.ParentID, + childID: spawnEvent.ChildID, + }) + mu.Unlock() + } + } + + // Create a root turn + rootSession := &ephemeralSessionStore{} + rootTS := &turnState{ + ctx: context.Background(), + turnID: "root-turn", + depth: 0, + session: rootSession, + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), + } + + // Spawn a child (depth 1) + childCfg := SubTurnConfig{Model: "gpt-4o-mini"} + _, err := spawnSubTurn(context.Background(), al, rootTS, childCfg) + if err != nil { + t.Fatalf("failed to spawn child: %v", err) + } + + // Verify we captured the spawn event + mu.Lock() + if len(spawnedTurns) != 1 { + t.Fatalf("expected 1 spawn event, got %d", len(spawnedTurns)) + } + if spawnedTurns[0].parentID != "root-turn" { + t.Errorf("expected parent ID 'root-turn', got %s", spawnedTurns[0].parentID) + } + mu.Unlock() + + // Verify root turn has the child in its childTurnIDs + rootTS.mu.Lock() + if len(rootTS.childTurnIDs) != 1 { + t.Errorf("expected root to have 1 child, got %d", len(rootTS.childTurnIDs)) + } + rootTS.mu.Unlock() +} + +// TestDeliverSubTurnResultNoDeadlock verifies that deliverSubTurnResult doesn't +// deadlock when multiple goroutines are accessing the parent turnState concurrently. +func TestDeliverSubTurnResultNoDeadlock(t *testing.T) { + parent := &turnState{ + ctx: context.Background(), + turnID: "parent-deadlock-test", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking + isFinished: false, + } + + // Simulate multiple child turns delivering results concurrently + var wg sync.WaitGroup + numChildren := 10 + + for i := 0; i < numChildren; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)} + deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result) + }(i) + } + + // Concurrently read from the channel to prevent blocking + go func() { + for i := 0; i < numChildren; i++ { + select { + case <-parent.pendingResults: + case <-time.After(2 * time.Second): + t.Error("timeout waiting for result") + return + } + } + }() + + // Wait for all deliveries to complete (with timeout) + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success - no deadlock + case <-time.After(3 * time.Second): + t.Fatal("deadlock detected: deliverSubTurnResult blocked") + } +} + +// TestHardAbortOrderOfOperations verifies that HardAbort calls Finish() before +// rolling back session history, minimizing the race window where new messages +// could be added after rollback. +func TestHardAbortOrderOfOperations(t *testing.T) { + al, _, _, _, cleanup := newTestAgentLoop(t) + defer cleanup() + + sess := &ephemeralSessionStore{ + history: []providers.Message{ + {Role: "user", Content: "initial message"}, + {Role: "assistant", Content: "response 1"}, + {Role: "user", Content: "follow-up"}, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rootTS := &turnState{ + ctx: ctx, + cancelFunc: cancel, + turnID: "test-session-order", + depth: 0, + session: sess, + initialHistoryLength: 1, // Snapshot: 1 message + pendingResults: make(chan *tools.ToolResult, 16), + concurrencySem: make(chan struct{}, 5), + } + + al.activeTurnStates.Store("test-session-order", rootTS) + + // Trigger HardAbort + err := al.HardAbort("test-session-order") + if err != nil { + t.Fatalf("HardAbort failed: %v", err) + } + + // Verify context was cancelled (Finish() was called) + select { + case <-rootTS.ctx.Done(): + // Good - context was cancelled + default: + t.Error("expected context to be cancelled after HardAbort") + } + + // Verify history was rolled back + finalHistory := sess.GetHistory("") + if len(finalHistory) != 1 { + t.Errorf("expected history to rollback to 1 message, got %d", len(finalHistory)) + } + + if finalHistory[0].Content != "initial message" { + t.Error("history content does not match initial state after rollback") + } +} + +// TestFinishClosesChannel verifies that Finish() closes the pendingResults channel +// and that deliverSubTurnResult handles closed channels gracefully. +func TestFinishClosesChannel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ts := &turnState{ + ctx: ctx, + cancelFunc: cancel, + turnID: "test-finish-channel", + depth: 0, + pendingResults: make(chan *tools.ToolResult, 2), + isFinished: false, + } + + // Verify channel is open initially + select { + case ts.pendingResults <- &tools.ToolResult{ForLLM: "test"}: + // Good - channel is open + // Drain the message we just sent + <-ts.pendingResults + default: + t.Fatal("channel should be open initially") + } + + // Call Finish() + ts.Finish() + + // Verify channel is closed + _, ok := <-ts.pendingResults + if ok { + t.Error("expected channel to be closed after Finish()") + } + + // Verify Finish() is idempotent (can be called multiple times) + ts.Finish() // Should not panic + + // Verify deliverSubTurnResult doesn't panic when sending to closed channel + result := &tools.ToolResult{ForLLM: "late result"} + + // This should not panic - it should recover and emit OrphanResultEvent + deliverSubTurnResult(ts, "child-1", result) +}