Files
picoclaw/pkg/agent/steering.go
T
Administrator 6b5d7e3fd7 fix(agent): resolve critical race conditions and resource leaks in SubTurn
- Fix turnState hierarchy corruption when SubTurns recursively call runAgentLoop
  by checking context for existing turnState before creating new root
- Fix deadlock risk in deliverSubTurnResult by separating lock and channel operations
- Fix session rollback race in HardAbort by calling Finish() before rollback
- Fix resource leak by closing pendingResults channel in Finish() with panic recovery
- Add thread-safety documentation for childTurnIDs and isFinished fields
- Move globalTurnCounter to AgentLoop.subTurnCounter to prevent ID conflicts
- Improve semaphore acquisition to ensure release even on early validation failures
- Document design choice: ephemeral sessions start empty for complete isolation
- Add 5 new tests: hierarchy, deadlock, order, channel close, and semaphore
2026-03-16 22:37:21 +08:00

279 lines
8.0 KiB
Go

package agent
import (
"context"
"fmt"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// SteeringMode controls how queued steering messages are dequeued.
type SteeringMode string
const (
// SteeringOneAtATime dequeues only the first queued message per poll.
SteeringOneAtATime SteeringMode = "one-at-a-time"
// SteeringAll drains the entire queue in a single poll.
SteeringAll SteeringMode = "all"
// MaxQueueSize number of possible messages in the Steering Queue
MaxQueueSize = 10
)
// parseSteeringMode normalizes a config string into a SteeringMode.
func parseSteeringMode(s string) SteeringMode {
switch s {
case "all":
return SteeringAll
default:
return SteeringOneAtATime
}
}
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
type steeringQueue struct {
mu sync.Mutex
queue []providers.Message
mode SteeringMode
}
func newSteeringQueue(mode SteeringMode) *steeringQueue {
return &steeringQueue{
mode: mode,
}
}
// push enqueues a steering message.
func (sq *steeringQueue) push(msg providers.Message) error {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) >= MaxQueueSize {
return fmt.Errorf("steering queue is full")
}
sq.queue = append(sq.queue, msg)
return nil
}
// dequeue removes and returns pending steering messages according to the
// configured mode. Returns nil when the queue is empty.
func (sq *steeringQueue) dequeue() []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) == 0 {
return nil
}
switch sq.mode {
case SteeringAll:
msgs := sq.queue
sq.queue = nil
return msgs
default: // one-at-a-time
msg := sq.queue[0]
sq.queue[0] = providers.Message{} // Clear reference for GC
sq.queue = sq.queue[1:]
return []providers.Message{msg}
}
}
// len returns the number of queued messages.
func (sq *steeringQueue) len() int {
sq.mu.Lock()
defer sq.mu.Unlock()
return len(sq.queue)
}
// setMode updates the steering mode.
func (sq *steeringQueue) setMode(mode SteeringMode) {
sq.mu.Lock()
defer sq.mu.Unlock()
sq.mode = mode
}
// getMode returns the current steering mode.
func (sq *steeringQueue) getMode() SteeringMode {
sq.mu.Lock()
defer sq.mu.Unlock()
return sq.mode
}
// --- AgentLoop steering API ---
// Steer enqueues a user message to be injected into the currently running
// agent loop. The message will be picked up after the current tool finishes
// executing, causing any remaining tool calls in the batch to be skipped.
func (al *AgentLoop) Steer(msg providers.Message) error {
if al.steering == nil {
return fmt.Errorf("steering queue is not initialized")
}
if err := al.steering.push(msg); err != nil {
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
"error": err.Error(),
"role": msg.Role,
})
return err
}
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
"role": msg.Role,
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
return nil
}
// SteeringMode returns the current steering mode.
func (al *AgentLoop) SteeringMode() SteeringMode {
if al.steering == nil {
return SteeringOneAtATime
}
return al.steering.getMode()
}
// SetSteeringMode updates the steering mode.
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
if al.steering == nil {
return
}
al.steering.setMode(mode)
}
// dequeueSteeringMessages is the internal method called by the agent loop
// to poll for steering messages. Returns nil when no messages are pending.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
if al.steering == nil {
return nil
}
return al.steering.dequeue()
}
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
// user has since enqueued steering messages.
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
return "", nil
}
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent available")
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
contents = append(contents, msg.Content)
}
combinedContent := strings.Join(contents, "\n")
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
UserMessage: combinedContent,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SkipInitialSteeringPoll: true,
})
}
// ====================== SubTurn Result Polling ======================
// dequeuePendingSubTurnResults polls the SubTurn result channel for the given
// session and returns all available results without blocking.
// Returns nil if no channel is registered for this session.
func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.ToolResult {
chInterface, ok := al.subTurnResults.Load(sessionKey)
if !ok {
return nil
}
ch, ok := chInterface.(chan *tools.ToolResult)
if !ok {
return nil
}
var results []*tools.ToolResult
for {
select {
case result := <-ch:
if result != nil {
results = append(results, result)
}
default:
return results
}
}
}
// registerSubTurnResultChannel registers a SubTurn result channel for the given session.
// This allows the parent loop to poll for results from child SubTurns.
func (al *AgentLoop) registerSubTurnResultChannel(sessionKey string, ch chan *tools.ToolResult) {
al.subTurnResults.Store(sessionKey, ch)
}
// unregisterSubTurnResultChannel removes the SubTurn result channel for the given session.
func (al *AgentLoop) unregisterSubTurnResultChannel(sessionKey string) {
al.subTurnResults.Delete(sessionKey)
}
// ====================== Hard Abort ======================
// HardAbort immediately cancels the running agent loop for the given session,
// cascading the cancellation to all child SubTurns. This is a destructive operation
// that terminates execution without waiting for graceful cleanup.
//
// Use this when the user explicitly requests immediate termination (e.g., "stop now", "abort").
// For graceful interruption that allows the agent to finish the current tool and summarize,
// use Steer() instead.
func (al *AgentLoop) HardAbort(sessionKey string) error {
tsInterface, ok := al.activeTurnStates.Load(sessionKey)
if !ok {
return fmt.Errorf("no active turn state found for session %s", sessionKey)
}
ts, ok := tsInterface.(*turnState)
if !ok {
return fmt.Errorf("invalid turn state type for session %s", sessionKey)
}
logger.InfoCF("agent", "Hard abort triggered", map[string]any{
"session_key": sessionKey,
"turn_id": ts.turnID,
"depth": ts.depth,
"initial_history_length": ts.initialHistoryLength,
})
// IMPORTANT: Trigger cascading cancellation FIRST to stop all child SubTurns
// from adding more messages to the session. This prevents race conditions
// where rollback happens while children are still writing.
ts.Finish()
// Rollback session history to the state before this turn started.
// This must happen AFTER Finish() to ensure no child turns are still writing.
if ts.session != nil {
currentHistory := ts.session.GetHistory("")
if len(currentHistory) > ts.initialHistoryLength {
logger.InfoCF("agent", "Rolling back session history", map[string]any{
"from": len(currentHistory),
"to": ts.initialHistoryLength,
})
// SetHistory with the truncated slice to rollback
ts.session.SetHistory("", currentHistory[:ts.initialHistoryLength])
}
}
return nil
}