Files
picoclaw/pkg/agent/turn_state.go
T
lxowalle 2992eccbf0 feat: add request-scoped context policies (#2914)
* feat: add request-scoped context policies

Add named turn profiles under agents.defaults so callers can opt into
per-request context and tool policies without changing default chat behavior.

Profiles can disable history, system context, skill prompts, or tools, and can
limit skills/tools with allow lists. Wire profile selection through Pico message
payloads, agent turn execution, Web chat selection, and Web visual config.

Reject invalid turn profiles before saving config through Web APIs and document
the new request context policy behavior.

* fix: address turn profile review blockers

* feat: simplify request context policy config

* fix: suppress tool prompt when turn tools are disabled

* fix: enforce turn profile tool restrictions
2026-05-22 10:06:40 +08:00

822 lines
22 KiB
Go

// PicoClaw - Ultra-lightweight personal AI agent
package agent
import (
"context"
"reflect"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
// =============================================================================
// TurnPhase - represents the current phase of a turn
// =============================================================================
type TurnPhase string
const (
TurnPhaseSetup TurnPhase = "setup"
TurnPhaseRunning TurnPhase = "running"
TurnPhaseTools TurnPhase = "tools"
TurnPhaseFinalizing TurnPhase = "finalizing"
TurnPhaseCompleted TurnPhase = "completed"
TurnPhaseAborted TurnPhase = "aborted"
)
// =============================================================================
// Control signals - returned from Pipeline methods to drive runTurn's coordinator loop
// =============================================================================
type Control int
const (
// ControlContinue tells the coordinator to jump back to the top of the turn loop
// (equivalent to the original "goto turnLoop").
ControlContinue Control = iota
// ControlBreak tells the coordinator to exit the turn loop and proceed to Finalize.
ControlBreak
// ControlToolLoop tells the coordinator to execute the tool loop.
ControlToolLoop
)
// ToolControl signals returned from ExecuteTools to drive tool loop iteration.
type ToolControl int
const (
// ToolControlContinue tells the tool loop to jump to the next iteration
// (pendingMessages arrived, SubTurn results, etc.).
ToolControlContinue ToolControl = iota
// ToolControlBreak tells the tool loop to exit and return to the coordinator.
ToolControlBreak
// ToolControlFinalize tells the coordinator that all tool responses were
// handled and the turn should finalize without another LLM call.
ToolControlFinalize
)
// LLMPhase indicates which phase the turn is executing in.
type LLMPhase int
const (
LLMPhaseSetup LLMPhase = iota
LLMPhasePreLLM
LLMPhaseLLMCall
LLMPhaseProcessing
LLMPhaseToolLoop
LLMPhaseTools
LLMPhaseFinalizing
LLMPhaseCompleted
LLMPhaseAborted
)
// =============================================================================
// turnResult - returned from runTurn
// =============================================================================
type turnResult struct {
finalContent string
modelName string
status TurnEndStatus
followUps []bus.InboundMessage
}
// =============================================================================
// ActiveTurnInfo - public info about an active turn
// =============================================================================
type ActiveTurnInfo struct {
TurnID string
AgentID string
SessionKey string
Channel string
ChatID string
UserMessage string
Phase TurnPhase
Iteration int
StartedAt time.Time
Depth int
ParentTurnID string
ChildTurnIDs []string
}
// =============================================================================
// turnExecution - mutable state that persists across turn loop iterations
// =============================================================================
type turnExecution struct {
// Core message state (accumulates throughout the turn)
messages []providers.Message // built from ContextBuilder, grows per-iteration
pendingMessages []providers.Message // steering/SubTurn messages awaiting injection
history []providers.Message // from ContextManager.Assemble
summary string
// Turn output
finalContent string
// Iteration tracking
iteration int
// Per-iteration state set by Pipeline.PreLLM
activeCandidates []providers.FallbackCandidate
activeModel string
activeModelConfig *config.ModelConfig
activeProvider providers.LLMProvider
usedLight bool
// LLM call per-iteration state
response *providers.LLMResponse
normalizedToolCalls []providers.ToolCall
allResponsesHandled bool
streamingPublisher *streamingChunkPublisher
streamingFallback bool
suppressReasoning bool
callMessages []providers.Message
providerToolDefs []providers.ToolDefinition
llmModel string
llmModelName string
llmOpts map[string]any
gracefulTerminal bool
useNativeSearch bool
// Phase tracking
phase LLMPhase
// Abort signaling for coordinator (set by Pipeline methods)
abortedByHardAbort bool // true when hard abort triggered during LLM/tools
abortedByHook bool // true when HookActionAbortTurn triggered
}
// newTurnExecution creates a turnExecution initialized from turnState and options.
func newTurnExecution(
agent *AgentInstance,
opts processOptions,
history []providers.Message,
summary string,
messages []providers.Message,
) *turnExecution {
return &turnExecution{
history: history,
summary: summary,
messages: messages,
pendingMessages: append([]providers.Message(nil), opts.InitialSteeringMessages...),
iteration: 0,
phase: LLMPhaseSetup,
}
}
// =============================================================================
// turnState - the full state for a turn, constructed once per turn
// =============================================================================
type turnState struct {
mu sync.RWMutex
agent *AgentInstance
opts processOptions
profile config.EffectiveTurnProfile
scope turnEventScope
turnID string
agentID string
sessionKey string
activeSkills []string
attemptedSkills []string
skillContextTrace []SkillContextSnapshot
toolKinds []string
toolExecutions []ToolExecutionRecord
turnCtx *TurnContext
channel string
chatID string
workspace string
userMessage string
media []string
phase TurnPhase
iteration int
startedAt time.Time
finalContent string
followUps []bus.InboundMessage
gracefulInterrupt bool
gracefulInterruptHint string
gracefulTerminalUsed bool
hardAbort bool
providerCancel context.CancelFunc
turnCancel context.CancelFunc
restorePointHistory []providers.Message
restorePointSummary string
persistedMessages []providers.Message
// SubTurn support (from HEAD)
depth int // SubTurn depth (0 for root turn)
parentTurnID string // Parent turn ID (empty for root turn)
childTurnIDs []string // Child turn IDs
pendingResults chan *tools.ToolResult // Channel for SubTurn results
concurrencySem chan struct{} // Semaphore for limiting concurrent SubTurns
isFinished atomic.Bool // Whether this turn has finished
session session.SessionStore // Session store reference
initialHistoryLength int // Snapshot of history length at turn start
// Additional SubTurn fields
ctx context.Context // Context for this turn
cancelFunc context.CancelFunc // Cancel function for this turn's context
critical bool // Whether this SubTurn should continue after parent ends
parentTurnState *turnState // Reference to parent turnState
parentEnded atomic.Bool // Whether parent has ended
closeOnce sync.Once // Ensures pendingResults channel is closed once
finishedChan chan struct{} // Closed when turn finishes
// Token budget tracking
tokenBudget *atomic.Int64 // Shared token budget counter
lastFinishReason string // Last LLM finish_reason
lastUsage *providers.UsageInfo // Last LLM usage info
// Back-reference to the owning AgentLoop (set for SubTurns only, used for hard abort cascade)
al *AgentLoop
}
// =============================================================================
// turnState constructors and active turn management
// =============================================================================
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
ts := &turnState{
agent: agent,
opts: opts,
profile: opts.TurnProfile,
scope: scope,
turnID: scope.turnID,
agentID: agent.ID,
sessionKey: opts.Dispatch.SessionKey,
activeSkills: activeSkillNames(agent, opts),
turnCtx: cloneTurnContext(scope.context),
channel: opts.Dispatch.Channel(),
chatID: opts.Dispatch.ChatID(),
workspace: agent.Workspace,
userMessage: opts.Dispatch.UserMessage,
media: append([]string(nil), opts.Dispatch.Media...),
phase: TurnPhaseSetup,
startedAt: time.Now(),
}
// Bind session store and capture initial history length for rollback logic
if agent != nil && agent.Sessions != nil {
ts.session = agent.Sessions
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
ts.initialHistoryLength = len(history)
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
}
return ts
}
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
al.activeTurnStates.Store(ts.sessionKey, ts)
}
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
al.activeTurnStates.Delete(ts.sessionKey)
}
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
if val, ok := al.activeTurnStates.Load(sessionKey); ok {
if ts, ok := val.(*turnState); ok {
return ts
}
// Unexpected non-*turnState value — treat as "no active turn" to avoid
// panics. This should not happen under normal operation.
}
return nil
}
// getAnyActiveTurnState returns any active turn state (for backward compatibility)
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
var firstTS *turnState
al.activeTurnStates.Range(func(key, value any) bool {
if ts, ok := value.(*turnState); ok {
firstTS = ts
return false
}
return true
})
return firstTS
}
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
// For backward compatibility, return the first active turn found
// In the new architecture, there can be multiple concurrent turns
var firstTS *turnState
al.activeTurnStates.Range(func(key, value any) bool {
if ts, ok := value.(*turnState); ok {
firstTS = ts
return false
}
return true
})
if firstTS == nil {
return nil
}
info := firstTS.snapshot()
return &info
}
func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
ts := al.getActiveTurnState(sessionKey)
if ts == nil {
return nil
}
info := ts.snapshot()
return &info
}
// =============================================================================
// turnState - getters and setters
// =============================================================================
func (ts *turnState) snapshot() ActiveTurnInfo {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ActiveTurnInfo{
TurnID: ts.turnID,
AgentID: ts.agentID,
SessionKey: ts.sessionKey,
Channel: ts.channel,
ChatID: ts.chatID,
UserMessage: ts.userMessage,
Phase: ts.phase,
Iteration: ts.iteration,
StartedAt: ts.startedAt,
Depth: ts.depth,
ParentTurnID: ts.parentTurnID,
ChildTurnIDs: append([]string(nil), ts.childTurnIDs...),
}
}
func (ts *turnState) setPhase(phase TurnPhase) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.phase = phase
}
func (ts *turnState) setIteration(iteration int) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.iteration = iteration
}
func (ts *turnState) currentIteration() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.iteration
}
func (ts *turnState) setFinalContent(content string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.finalContent = content
}
func (ts *turnState) finalContentLen() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return len(ts.finalContent)
}
func (ts *turnState) finalContentSnapshot() string {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.finalContent
}
func (ts *turnState) recordToolKind(tool string) {
tool = strings.TrimSpace(tool)
if tool == "" {
return
}
ts.mu.Lock()
defer ts.mu.Unlock()
for _, existing := range ts.toolKinds {
if existing == tool {
return
}
}
ts.toolKinds = append(ts.toolKinds, tool)
}
func (ts *turnState) toolKindsSnapshot() []string {
ts.mu.RLock()
defer ts.mu.RUnlock()
return append([]string(nil), ts.toolKinds...)
}
func (ts *turnState) recordToolExecution(tool string, success bool, errorSummary string, skillNames []string) {
tool = strings.TrimSpace(tool)
if tool == "" {
return
}
ts.recordToolKind(tool)
ts.mu.Lock()
defer ts.mu.Unlock()
ts.toolExecutions = append(ts.toolExecutions, ToolExecutionRecord{
Name: tool,
Success: success,
ErrorSummary: strings.TrimSpace(errorSummary),
SkillNames: append([]string(nil), skillNames...),
})
}
func (ts *turnState) toolExecutionsSnapshot() []ToolExecutionRecord {
ts.mu.RLock()
defer ts.mu.RUnlock()
if len(ts.toolExecutions) == 0 {
return nil
}
out := make([]ToolExecutionRecord, 0, len(ts.toolExecutions))
for _, exec := range ts.toolExecutions {
out = append(out, ToolExecutionRecord{
Name: exec.Name,
Success: exec.Success,
ErrorSummary: exec.ErrorSummary,
SkillNames: append([]string(nil), exec.SkillNames...),
})
}
return out
}
func (ts *turnState) recordAttemptedSkills(skillNames []string) {
if len(skillNames) == 0 {
return
}
ts.mu.Lock()
defer ts.mu.Unlock()
for _, skillName := range skillNames {
skillName = strings.TrimSpace(skillName)
if skillName == "" {
continue
}
seen := false
for _, existing := range ts.attemptedSkills {
if existing == skillName {
seen = true
break
}
}
if seen {
continue
}
ts.attemptedSkills = append(ts.attemptedSkills, skillName)
}
}
func (ts *turnState) attemptedSkillsSnapshot() []string {
ts.mu.RLock()
defer ts.mu.RUnlock()
return append([]string(nil), ts.attemptedSkills...)
}
func (ts *turnState) recordSkillContextSnapshot(trigger string, skillNames []string) {
if len(skillNames) == 0 {
return
}
filtered := make([]string, 0, len(skillNames))
for _, skillName := range skillNames {
skillName = strings.TrimSpace(skillName)
if skillName == "" {
continue
}
filtered = append(filtered, skillName)
}
if len(filtered) == 0 {
return
}
ts.recordAttemptedSkills(filtered)
ts.mu.Lock()
defer ts.mu.Unlock()
ts.skillContextTrace = append(ts.skillContextTrace, SkillContextSnapshot{
Sequence: len(ts.skillContextTrace) + 1,
Trigger: trigger,
SkillNames: append([]string(nil), filtered...),
})
}
func (ts *turnState) latestSkillContextSnapshot() []string {
ts.mu.RLock()
defer ts.mu.RUnlock()
if len(ts.skillContextTrace) == 0 {
return nil
}
return append([]string(nil), ts.skillContextTrace[len(ts.skillContextTrace)-1].SkillNames...)
}
func (ts *turnState) skillContextSnapshotsSnapshot() []SkillContextSnapshot {
ts.mu.RLock()
defer ts.mu.RUnlock()
if len(ts.skillContextTrace) == 0 {
return nil
}
snapshots := make([]SkillContextSnapshot, 0, len(ts.skillContextTrace))
for _, snapshot := range ts.skillContextTrace {
snapshots = append(snapshots, SkillContextSnapshot{
Sequence: snapshot.Sequence,
Trigger: snapshot.Trigger,
SkillNames: append([]string(nil), snapshot.SkillNames...),
})
}
return snapshots
}
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.turnCancel = cancel
}
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = cancel
}
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = nil
}
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
ts.mu.Lock()
defer ts.mu.Unlock()
if ts.hardAbort {
return false
}
ts.gracefulInterrupt = true
ts.gracefulInterruptHint = hint
return true
}
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
}
func (ts *turnState) markGracefulTerminalUsed() {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.gracefulTerminalUsed = true
}
func (ts *turnState) requestHardAbort() bool {
ts.mu.Lock()
if ts.hardAbort {
ts.mu.Unlock()
return false
}
ts.hardAbort = true
turnCancel := ts.turnCancel
providerCancel := ts.providerCancel
ts.mu.Unlock()
if providerCancel != nil {
providerCancel()
}
if turnCancel != nil {
turnCancel()
}
return true
}
func (ts *turnState) hardAbortRequested() bool {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.hardAbort
}
func (ts *turnState) eventMeta(source, tracePath string) HookMeta {
snap := ts.snapshot()
return HookMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
turnContext: cloneTurnContext(ts.turnCtx),
}
}
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = summary
}
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.persistedMessages = append(ts.persistedMessages, msg)
}
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
history := agent.Sessions.GetHistory(ts.sessionKey)
summary := agent.Sessions.GetSummary(ts.sessionKey)
ts.mu.RLock()
persisted := append([]providers.Message(nil), ts.persistedMessages...)
ts.mu.RUnlock()
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
history = append([]providers.Message(nil), history[:len(history)-matched]...)
}
ts.captureRestorePoint(history, summary)
}
// ingestMessage calls the ContextManager's Ingest method for a persisted message.
// Errors are logged but never block the turn.
func (ts *turnState) ingestMessage(ctx context.Context, al *AgentLoop, msg providers.Message) {
if al.contextManager == nil {
return
}
if err := al.contextManager.Ingest(ctx, &IngestRequest{
SessionKey: ts.sessionKey,
Message: msg,
}); err != nil {
logger.WarnCF("agent", "Context manager ingest failed", map[string]any{
"session_key": ts.sessionKey,
"error": err.Error(),
})
}
}
func (ts *turnState) restoreSession(agent *AgentInstance) error {
ts.mu.RLock()
history := append([]providers.Message(nil), ts.restorePointHistory...)
summary := ts.restorePointSummary
ts.mu.RUnlock()
agent.Sessions.SetHistory(ts.sessionKey, history)
agent.Sessions.SetSummary(ts.sessionKey, summary)
return agent.Sessions.Save(ts.sessionKey)
}
func matchingTurnMessageTail(history, persisted []providers.Message) int {
maxMatch := min(len(history), len(persisted))
for size := maxMatch; size > 0; size-- {
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
return size
}
}
return 0
}
func (ts *turnState) interruptHintMessage() providers.Message {
_, hint := ts.gracefulInterruptRequested()
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
if hint != "" {
content += "\n\nInterrupt hint: " + hint
}
return interruptPromptMessage(content)
}
// =============================================================================
// SubTurn-related methods
// =============================================================================
// Finish marks the turn as finished and closes the pendingResults channel
func (ts *turnState) Finish(isHardAbort bool) {
ts.isFinished.Store(true)
// Close pendingResults channel exactly once
ts.closeOnce.Do(func() {
if ts.pendingResults != nil {
close(ts.pendingResults)
}
ts.mu.Lock()
if ts.finishedChan == nil {
ts.finishedChan = make(chan struct{})
}
close(ts.finishedChan)
ts.mu.Unlock()
})
// Any graceful finish must signal direct children so nested SubTurns can
// observe parent completion and decide whether to stop or continue.
if !isHardAbort {
ts.parentEnded.Store(true)
}
// Cancel the turn context
if ts.cancelFunc != nil {
ts.cancelFunc()
}
// Hard abort cascades to all child turns
if isHardAbort && ts.al != nil {
ts.mu.RLock()
children := append([]string(nil), ts.childTurnIDs...)
ts.mu.RUnlock()
for _, childID := range children {
if val, ok := ts.al.activeTurnStates.Load(childID); ok {
if child, ok := val.(*turnState); ok {
child.Finish(true)
}
}
}
}
}
// Finished returns whether the turn has finished
func (ts *turnState) Finished() chan struct{} {
ts.mu.Lock()
defer ts.mu.Unlock()
if ts.finishedChan == nil {
ts.finishedChan = make(chan struct{})
}
return ts.finishedChan
}
// IsParentEnded checks if the parent turn has ended
func (ts *turnState) IsParentEnded() bool {
if ts.parentTurnState == nil {
return false
}
return ts.parentTurnState.parentEnded.Load()
}
// GetLastFinishReason returns the last LLM finish_reason
func (ts *turnState) GetLastFinishReason() string {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.lastFinishReason
}
// SetLastFinishReason sets the last LLM finish_reason
func (ts *turnState) SetLastFinishReason(reason string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.lastFinishReason = reason
}
// GetLastUsage returns the last LLM usage info
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.lastUsage
}
// SetLastUsage sets the last LLM usage info
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.lastUsage = usage
}
// =============================================================================
// Context helper functions for turnState
// =============================================================================
type turnStateKeyType struct{}
var turnStateKey = turnStateKeyType{}
func withTurnState(ctx context.Context, ts *turnState) context.Context {
return context.WithValue(ctx, turnStateKey, ts)
}
func turnStateFromContext(ctx context.Context) *turnState {
ts, _ := ctx.Value(turnStateKey).(*turnState)
return ts
}
// TurnStateFromContext retrieves turnState from context (exported for tools)
func TurnStateFromContext(ctx context.Context) *turnState {
return turnStateFromContext(ctx)
}