mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-05-25 16:00:35 +00:00
887 lines
24 KiB
Go
887 lines
24 KiB
Go
// PicoClaw - Ultra-lightweight personal AI agent
|
|
|
|
package agent
|
|
|
|
import (
|
|
"context"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/session"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
)
|
|
|
|
// =============================================================================
|
|
// TurnPhase - represents the current phase of a turn
|
|
// =============================================================================
|
|
|
|
// TurnPhase is the coarse phase of a turn as reported through
// ActiveTurnInfo.Phase. newTurnState starts every turn in TurnPhaseSetup and
// turnState.setPhase moves it forward. Values are plain strings.
type TurnPhase string

const (
	TurnPhaseSetup      TurnPhase = "setup"
	TurnPhaseRunning    TurnPhase = "running"
	TurnPhaseTools      TurnPhase = "tools"
	TurnPhaseFinalizing TurnPhase = "finalizing"
	TurnPhaseCompleted  TurnPhase = "completed"
	TurnPhaseAborted    TurnPhase = "aborted"
)
|
|
|
|
// =============================================================================
|
|
// Control signals - returned from Pipeline methods to drive runTurn's coordinator loop
|
|
// =============================================================================
|
|
|
|
// Control is the signal a Pipeline method returns to steer runTurn's
// coordinator loop. Values are assigned with iota, so declaration order
// defines their numeric values.
type Control int

const (
	// ControlContinue tells the coordinator to jump back to the top of the turn loop
	// (equivalent to the original "goto turnLoop").
	ControlContinue Control = iota
	// ControlBreak tells the coordinator to exit the turn loop and proceed to Finalize.
	ControlBreak
	// ControlToolLoop tells the coordinator to execute the tool loop.
	ControlToolLoop
)
|
|
|
|
// ToolControl signals returned from ExecuteTools to drive tool loop iteration.
// Values are assigned with iota, so declaration order defines their numeric
// values.
type ToolControl int

const (
	// ToolControlContinue tells the tool loop to jump to the next iteration
	// (pendingMessages arrived, SubTurn results, etc.).
	ToolControlContinue ToolControl = iota
	// ToolControlBreak tells the tool loop to exit and return to the coordinator.
	ToolControlBreak
	// ToolControlFinalize tells the coordinator that all tool responses were
	// handled and the turn should finalize without another LLM call.
	ToolControlFinalize
)
|
|
|
|
// LLMPhase indicates which phase the turn is executing in.
//
// It is the finer-grained phase stored in turnExecution.phase, whereas the
// coarser TurnPhase is what ActiveTurnInfo exposes. Values are assigned with
// iota, so declaration order defines their numeric values.
type LLMPhase int

const (
	LLMPhaseSetup LLMPhase = iota
	LLMPhasePreLLM
	LLMPhaseLLMCall
	LLMPhaseProcessing
	LLMPhaseToolLoop
	LLMPhaseTools
	LLMPhaseFinalizing
	LLMPhaseCompleted
	LLMPhaseAborted
)
|
|
|
|
// =============================================================================
|
|
// turnResult - returned from runTurn
|
|
// =============================================================================
|
|
|
|
// turnResult is the outcome returned from runTurn.
type turnResult struct {
	// finalContent is the content the turn finished with.
	finalContent string
	// status records how the turn ended.
	status TurnEndStatus
	// followUps holds inbound messages to handle after the turn
	// (presumably queued follow-ups — confirm against runTurn).
	followUps []bus.InboundMessage
}
|
|
|
|
// =============================================================================
|
|
// ActiveTurnInfo - public info about an active turn
|
|
// =============================================================================
|
|
|
|
// ActiveTurnInfo is a point-in-time copy of an active turn's public state, as
// produced by turnState.snapshot. ChildTurnIDs is a freshly allocated slice,
// so callers may keep or modify it without affecting the live turn.
type ActiveTurnInfo struct {
	TurnID      string
	AgentID     string
	SessionKey  string
	Channel     string
	ChatID      string
	UserMessage string
	Phase       TurnPhase
	Iteration   int
	StartedAt   time.Time
	// Depth is the SubTurn depth (0 for a root turn).
	Depth int
	// ParentTurnID is empty for a root turn.
	ParentTurnID string
	ChildTurnIDs []string
}
|
|
|
|
// =============================================================================
|
|
// turnExecution - mutable state that persists across turn loop iterations
|
|
// =============================================================================
|
|
|
|
// turnExecution holds the mutable working state of a single turn that is
// carried across turn-loop iterations. It is created by newTurnExecution and
// has no mutex of its own. NOTE(review): presumably it is only touched by the
// goroutine running the turn — confirm.
type turnExecution struct {
	// Core message state (accumulates throughout the turn)
	messages        []providers.Message // built from ContextBuilder, grows per-iteration
	pendingMessages []providers.Message // steering/SubTurn messages awaiting injection
	history         []providers.Message // from ContextManager.Assemble
	summary         string

	// Turn output
	finalContent string

	// Iteration tracking
	iteration int

	// Per-iteration state set by Pipeline.PreLLM
	activeCandidates []providers.FallbackCandidate
	activeModel      string
	activeProvider   providers.LLMProvider
	usedLight        bool

	// LLM call per-iteration state
	response            *providers.LLMResponse
	normalizedToolCalls []providers.ToolCall
	allResponsesHandled bool
	callMessages        []providers.Message
	providerToolDefs    []providers.ToolDefinition
	llmModel            string
	llmOpts             map[string]any
	gracefulTerminal    bool
	useNativeSearch     bool

	// Phase tracking
	phase LLMPhase

	// Abort signaling for coordinator (set by Pipeline methods)
	abortedByHardAbort bool // true when hard abort triggered during LLM/tools
	abortedByHook      bool // true when HookActionAbortTurn triggered
}
|
|
|
|
// newTurnExecution creates a turnExecution initialized from turnState and options.
|
|
func newTurnExecution(
|
|
agent *AgentInstance,
|
|
opts processOptions,
|
|
history []providers.Message,
|
|
summary string,
|
|
messages []providers.Message,
|
|
) *turnExecution {
|
|
return &turnExecution{
|
|
history: history,
|
|
summary: summary,
|
|
messages: messages,
|
|
pendingMessages: append([]providers.Message(nil), opts.InitialSteeringMessages...),
|
|
iteration: 0,
|
|
phase: LLMPhaseSetup,
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// turnState - the full state for a turn, constructed once per turn
|
|
// =============================================================================
|
|
|
|
// turnState is the full state of one turn (root or SubTurn). It is built
// once per turn by newTurnState.
//
// Concurrency: mu guards the mutable fields that the accessor methods read
// and write (phase, iteration, finalContent, tool/skill records, interrupt
// and abort flags, cancel funcs, restore point, persisted messages,
// childTurnIDs, finishedChan, lastFinishReason, lastUsage). isFinished and
// parentEnded are atomics and are accessed without mu.
type turnState struct {
	mu sync.RWMutex

	// Construction inputs.
	agent *AgentInstance
	opts  processOptions
	scope turnEventScope

	// Identity and skill/tool bookkeeping for this turn.
	turnID            string
	agentID           string
	sessionKey        string
	activeSkills      []string
	attemptedSkills   []string               // de-duplicated, in first-seen order
	skillContextTrace []SkillContextSnapshot // appended by recordSkillContextSnapshot
	toolKinds         []string               // de-duplicated tool names
	toolExecutions    []ToolExecutionRecord
	turnCtx           *TurnContext

	// Inbound dispatch details copied at construction.
	channel     string
	chatID      string
	workspace   string
	userMessage string
	media       []string

	phase        TurnPhase
	iteration    int
	startedAt    time.Time
	finalContent string

	followUps []bus.InboundMessage

	// Interrupt / abort state. A hard abort blocks later graceful interrupt
	// requests (see requestGracefulInterrupt).
	gracefulInterrupt     bool
	gracefulInterruptHint string
	gracefulTerminalUsed  bool
	hardAbort             bool
	providerCancel        context.CancelFunc
	turnCancel            context.CancelFunc

	// Session state that restoreSession writes back; persistedMessages are
	// the messages this turn recorded as persisted, used to strip this turn's
	// tail when refreshing the restore point.
	restorePointHistory []providers.Message
	restorePointSummary string
	persistedMessages   []providers.Message

	// SubTurn support
	depth                int                    // SubTurn depth (0 for root turn)
	parentTurnID         string                 // Parent turn ID (empty for root turn)
	childTurnIDs         []string               // Child turn IDs
	pendingResults       chan *tools.ToolResult // Channel for SubTurn results; closed by Finish
	concurrencySem       chan struct{}          // Semaphore for limiting concurrent SubTurns
	isFinished           atomic.Bool            // Whether this turn has finished
	session              session.SessionStore   // Session store reference (nil when the agent has none)
	initialHistoryLength int                    // Snapshot of history length at turn start

	// Additional SubTurn fields
	ctx             context.Context    // Context for this turn
	cancelFunc      context.CancelFunc // Cancel function for this turn's context
	critical        bool               // Whether this SubTurn should continue after parent ends
	parentTurnState *turnState         // Reference to parent turnState
	parentEnded     atomic.Bool        // Set by Finish on a graceful finish; children read it via IsParentEnded
	closeOnce       sync.Once          // Ensures pendingResults and finishedChan are closed once
	finishedChan    chan struct{}      // Closed when turn finishes; created lazily under mu

	// Token budget tracking
	tokenBudget      *atomic.Int64        // Shared token budget counter
	lastFinishReason string               // Last LLM finish_reason
	lastUsage        *providers.UsageInfo // Last LLM usage info

	// Back-reference to the owning AgentLoop (set for SubTurns only, used for hard abort cascade)
	al *AgentLoop
}
|
|
|
|
// =============================================================================
|
|
// turnState constructors and active turn management
|
|
// =============================================================================
|
|
|
|
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
|
ts := &turnState{
|
|
agent: agent,
|
|
opts: opts,
|
|
scope: scope,
|
|
turnID: scope.turnID,
|
|
agentID: agent.ID,
|
|
sessionKey: opts.Dispatch.SessionKey,
|
|
activeSkills: activeSkillNames(agent, opts),
|
|
turnCtx: cloneTurnContext(scope.context),
|
|
channel: opts.Dispatch.Channel(),
|
|
chatID: opts.Dispatch.ChatID(),
|
|
workspace: agent.Workspace,
|
|
userMessage: opts.Dispatch.UserMessage,
|
|
media: append([]string(nil), opts.Dispatch.Media...),
|
|
phase: TurnPhaseSetup,
|
|
startedAt: time.Now(),
|
|
}
|
|
|
|
// Bind session store and capture initial history length for rollback logic
|
|
if agent != nil && agent.Sessions != nil {
|
|
ts.session = agent.Sessions
|
|
history := agent.Sessions.GetHistory(opts.Dispatch.SessionKey)
|
|
ts.initialHistoryLength = len(history)
|
|
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
|
ts.restorePointSummary = agent.Sessions.GetSummary(opts.Dispatch.SessionKey)
|
|
}
|
|
|
|
return ts
|
|
}
|
|
|
|
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
|
|
al.activeTurnStates.Store(ts.sessionKey, ts)
|
|
}
|
|
|
|
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
|
|
al.activeTurnStates.Delete(ts.sessionKey)
|
|
}
|
|
|
|
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
|
|
if val, ok := al.activeTurnStates.Load(sessionKey); ok {
|
|
if ts, ok := val.(*turnState); ok {
|
|
return ts
|
|
}
|
|
// Unexpected non-*turnState value — treat as "no active turn" to avoid
|
|
// panics. This should not happen under normal operation.
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getAnyActiveTurnState returns any active turn state (for backward compatibility)
|
|
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
|
|
var firstTS *turnState
|
|
al.activeTurnStates.Range(func(key, value any) bool {
|
|
if ts, ok := value.(*turnState); ok {
|
|
firstTS = ts
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
return firstTS
|
|
}
|
|
|
|
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
|
|
// For backward compatibility, return the first active turn found
|
|
// In the new architecture, there can be multiple concurrent turns
|
|
var firstTS *turnState
|
|
al.activeTurnStates.Range(func(key, value any) bool {
|
|
if ts, ok := value.(*turnState); ok {
|
|
firstTS = ts
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
if firstTS == nil {
|
|
return nil
|
|
}
|
|
info := firstTS.snapshot()
|
|
return &info
|
|
}
|
|
|
|
func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
|
|
ts := al.getActiveTurnState(sessionKey)
|
|
if ts == nil {
|
|
return nil
|
|
}
|
|
info := ts.snapshot()
|
|
return &info
|
|
}
|
|
|
|
// =============================================================================
|
|
// turnState - getters and setters
|
|
// =============================================================================
|
|
|
|
func (ts *turnState) snapshot() ActiveTurnInfo {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
|
|
return ActiveTurnInfo{
|
|
TurnID: ts.turnID,
|
|
AgentID: ts.agentID,
|
|
SessionKey: ts.sessionKey,
|
|
Channel: ts.channel,
|
|
ChatID: ts.chatID,
|
|
UserMessage: ts.userMessage,
|
|
Phase: ts.phase,
|
|
Iteration: ts.iteration,
|
|
StartedAt: ts.startedAt,
|
|
Depth: ts.depth,
|
|
ParentTurnID: ts.parentTurnID,
|
|
ChildTurnIDs: append([]string(nil), ts.childTurnIDs...),
|
|
}
|
|
}
|
|
|
|
func (ts *turnState) setPhase(phase TurnPhase) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.phase = phase
|
|
}
|
|
|
|
func (ts *turnState) setIteration(iteration int) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.iteration = iteration
|
|
}
|
|
|
|
func (ts *turnState) currentIteration() int {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.iteration
|
|
}
|
|
|
|
func (ts *turnState) setFinalContent(content string) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.finalContent = content
|
|
}
|
|
|
|
func (ts *turnState) finalContentLen() int {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return len(ts.finalContent)
|
|
}
|
|
|
|
func (ts *turnState) finalContentSnapshot() string {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.finalContent
|
|
}
|
|
|
|
func (ts *turnState) recordToolKind(tool string) {
|
|
tool = strings.TrimSpace(tool)
|
|
if tool == "" {
|
|
return
|
|
}
|
|
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
|
|
for _, existing := range ts.toolKinds {
|
|
if existing == tool {
|
|
return
|
|
}
|
|
}
|
|
ts.toolKinds = append(ts.toolKinds, tool)
|
|
}
|
|
|
|
func (ts *turnState) toolKindsSnapshot() []string {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return append([]string(nil), ts.toolKinds...)
|
|
}
|
|
|
|
func (ts *turnState) recordToolExecution(tool string, success bool, errorSummary string, skillNames []string) {
|
|
tool = strings.TrimSpace(tool)
|
|
if tool == "" {
|
|
return
|
|
}
|
|
|
|
ts.recordToolKind(tool)
|
|
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.toolExecutions = append(ts.toolExecutions, ToolExecutionRecord{
|
|
Name: tool,
|
|
Success: success,
|
|
ErrorSummary: strings.TrimSpace(errorSummary),
|
|
SkillNames: append([]string(nil), skillNames...),
|
|
})
|
|
}
|
|
|
|
func (ts *turnState) toolExecutionsSnapshot() []ToolExecutionRecord {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
if len(ts.toolExecutions) == 0 {
|
|
return nil
|
|
}
|
|
|
|
out := make([]ToolExecutionRecord, 0, len(ts.toolExecutions))
|
|
for _, exec := range ts.toolExecutions {
|
|
out = append(out, ToolExecutionRecord{
|
|
Name: exec.Name,
|
|
Success: exec.Success,
|
|
ErrorSummary: exec.ErrorSummary,
|
|
SkillNames: append([]string(nil), exec.SkillNames...),
|
|
})
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (ts *turnState) recordAttemptedSkills(skillNames []string) {
|
|
if len(skillNames) == 0 {
|
|
return
|
|
}
|
|
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
|
|
for _, skillName := range skillNames {
|
|
skillName = strings.TrimSpace(skillName)
|
|
if skillName == "" {
|
|
continue
|
|
}
|
|
seen := false
|
|
for _, existing := range ts.attemptedSkills {
|
|
if existing == skillName {
|
|
seen = true
|
|
break
|
|
}
|
|
}
|
|
if seen {
|
|
continue
|
|
}
|
|
ts.attemptedSkills = append(ts.attemptedSkills, skillName)
|
|
}
|
|
}
|
|
|
|
func (ts *turnState) attemptedSkillsSnapshot() []string {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return append([]string(nil), ts.attemptedSkills...)
|
|
}
|
|
|
|
func (ts *turnState) recordSkillContextSnapshot(trigger string, skillNames []string) {
|
|
if len(skillNames) == 0 {
|
|
return
|
|
}
|
|
|
|
filtered := make([]string, 0, len(skillNames))
|
|
for _, skillName := range skillNames {
|
|
skillName = strings.TrimSpace(skillName)
|
|
if skillName == "" {
|
|
continue
|
|
}
|
|
filtered = append(filtered, skillName)
|
|
}
|
|
if len(filtered) == 0 {
|
|
return
|
|
}
|
|
|
|
ts.recordAttemptedSkills(filtered)
|
|
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.skillContextTrace = append(ts.skillContextTrace, SkillContextSnapshot{
|
|
Sequence: len(ts.skillContextTrace) + 1,
|
|
Trigger: trigger,
|
|
SkillNames: append([]string(nil), filtered...),
|
|
})
|
|
}
|
|
|
|
func (ts *turnState) latestSkillContextSnapshot() []string {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
if len(ts.skillContextTrace) == 0 {
|
|
return nil
|
|
}
|
|
return append([]string(nil), ts.skillContextTrace[len(ts.skillContextTrace)-1].SkillNames...)
|
|
}
|
|
|
|
func (ts *turnState) skillContextSnapshotsSnapshot() []SkillContextSnapshot {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
if len(ts.skillContextTrace) == 0 {
|
|
return nil
|
|
}
|
|
|
|
snapshots := make([]SkillContextSnapshot, 0, len(ts.skillContextTrace))
|
|
for _, snapshot := range ts.skillContextTrace {
|
|
snapshots = append(snapshots, SkillContextSnapshot{
|
|
Sequence: snapshot.Sequence,
|
|
Trigger: snapshot.Trigger,
|
|
SkillNames: append([]string(nil), snapshot.SkillNames...),
|
|
})
|
|
}
|
|
return snapshots
|
|
}
|
|
|
|
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.turnCancel = cancel
|
|
}
|
|
|
|
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.providerCancel = cancel
|
|
}
|
|
|
|
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.providerCancel = nil
|
|
}
|
|
|
|
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
if ts.hardAbort {
|
|
return false
|
|
}
|
|
ts.gracefulInterrupt = true
|
|
ts.gracefulInterruptHint = hint
|
|
return true
|
|
}
|
|
|
|
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
|
|
}
|
|
|
|
func (ts *turnState) markGracefulTerminalUsed() {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.gracefulTerminalUsed = true
|
|
}
|
|
|
|
func (ts *turnState) requestHardAbort() bool {
|
|
ts.mu.Lock()
|
|
if ts.hardAbort {
|
|
ts.mu.Unlock()
|
|
return false
|
|
}
|
|
ts.hardAbort = true
|
|
turnCancel := ts.turnCancel
|
|
providerCancel := ts.providerCancel
|
|
ts.mu.Unlock()
|
|
|
|
if providerCancel != nil {
|
|
providerCancel()
|
|
}
|
|
if turnCancel != nil {
|
|
turnCancel()
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (ts *turnState) hardAbortRequested() bool {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.hardAbort
|
|
}
|
|
|
|
func (ts *turnState) eventMeta(source, tracePath string) HookMeta {
|
|
snap := ts.snapshot()
|
|
return HookMeta{
|
|
AgentID: snap.AgentID,
|
|
TurnID: snap.TurnID,
|
|
SessionKey: snap.SessionKey,
|
|
Iteration: snap.Iteration,
|
|
Source: source,
|
|
TracePath: tracePath,
|
|
turnContext: cloneTurnContext(ts.turnCtx),
|
|
}
|
|
}
|
|
|
|
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
|
ts.restorePointSummary = summary
|
|
}
|
|
|
|
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.persistedMessages = append(ts.persistedMessages, msg)
|
|
}
|
|
|
|
func (ts *turnState) persistedMessagesSnapshot() []providers.Message {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return append([]providers.Message(nil), ts.persistedMessages...)
|
|
}
|
|
|
|
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
|
history := agent.Sessions.GetHistory(ts.sessionKey)
|
|
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
|
|
|
persisted := ts.persistedMessagesSnapshot()
|
|
|
|
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
|
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
|
}
|
|
|
|
ts.captureRestorePoint(history, summary)
|
|
}
|
|
|
|
// ingestMessage calls the ContextManager's Ingest method for a persisted message.
|
|
// Errors are logged but never block the turn.
|
|
func (ts *turnState) ingestMessage(ctx context.Context, al *AgentLoop, msg providers.Message) {
|
|
if al.contextManager == nil {
|
|
return
|
|
}
|
|
if err := al.contextManager.Ingest(ctx, &IngestRequest{
|
|
SessionKey: ts.sessionKey,
|
|
Message: msg,
|
|
}); err != nil {
|
|
logger.WarnCF("agent", "Context manager ingest failed", map[string]any{
|
|
"session_key": ts.sessionKey,
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
}
|
|
|
|
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
|
ts.mu.RLock()
|
|
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
|
summary := ts.restorePointSummary
|
|
ts.mu.RUnlock()
|
|
|
|
agent.Sessions.SetHistory(ts.sessionKey, history)
|
|
agent.Sessions.SetSummary(ts.sessionKey, summary)
|
|
return agent.Sessions.Save(ts.sessionKey)
|
|
}
|
|
|
|
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
|
maxMatch := min(len(history), len(persisted))
|
|
for size := maxMatch; size > 0; size-- {
|
|
if messageSlicesEquivalent(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
|
return size
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func splitHistoryForActiveTurn(
|
|
history []providers.Message,
|
|
persisted []providers.Message,
|
|
) ([]providers.Message, []providers.Message) {
|
|
matched := matchingTurnMessageTail(history, persisted)
|
|
if matched <= 0 {
|
|
return append([]providers.Message(nil), history...), nil
|
|
}
|
|
|
|
stable := append([]providers.Message(nil), history[:len(history)-matched]...)
|
|
protected := append([]providers.Message(nil), history[len(history)-matched:]...)
|
|
return stable, protected
|
|
}
|
|
|
|
func messageSlicesEquivalent(a, b []providers.Message) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i := range a {
|
|
if !messagesEquivalent(a[i], b[i]) {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
func messagesEquivalent(a, b providers.Message) bool {
|
|
return reflect.DeepEqual(normalizeMessageForComparison(a), normalizeMessageForComparison(b))
|
|
}
|
|
|
|
func normalizeMessageForComparison(msg providers.Message) providers.Message {
|
|
msg.PromptLayer = ""
|
|
msg.PromptSlot = ""
|
|
msg.PromptSource = ""
|
|
|
|
if len(msg.Media) == 0 {
|
|
msg.Media = nil
|
|
}
|
|
if len(msg.Attachments) == 0 {
|
|
msg.Attachments = nil
|
|
}
|
|
if len(msg.SystemParts) == 0 {
|
|
msg.SystemParts = nil
|
|
} else {
|
|
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
|
for i := range msg.SystemParts {
|
|
msg.SystemParts[i].PromptLayer = ""
|
|
msg.SystemParts[i].PromptSlot = ""
|
|
msg.SystemParts[i].PromptSource = ""
|
|
}
|
|
}
|
|
if len(msg.ToolCalls) == 0 {
|
|
msg.ToolCalls = nil
|
|
} else {
|
|
msg.ToolCalls = append([]providers.ToolCall(nil), msg.ToolCalls...)
|
|
for i := range msg.ToolCalls {
|
|
msg.ToolCalls[i].Name = ""
|
|
msg.ToolCalls[i].Arguments = nil
|
|
msg.ToolCalls[i].ThoughtSignature = ""
|
|
if msg.ToolCalls[i].Function != nil {
|
|
fn := *msg.ToolCalls[i].Function
|
|
fn.ThoughtSignature = ""
|
|
msg.ToolCalls[i].Function = &fn
|
|
}
|
|
}
|
|
}
|
|
|
|
return msg
|
|
}
|
|
|
|
func (ts *turnState) interruptHintMessage() providers.Message {
|
|
_, hint := ts.gracefulInterruptRequested()
|
|
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
|
if hint != "" {
|
|
content += "\n\nInterrupt hint: " + hint
|
|
}
|
|
return interruptPromptMessage(content)
|
|
}
|
|
|
|
// =============================================================================
|
|
// SubTurn-related methods
|
|
// =============================================================================
|
|
|
|
// Finish marks the turn as finished, closes its pendingResults and
// finishedChan channels (each at most once), and cancels the turn context.
// On hard abort it also cascades Finish(true) to child turns.
//
// isHardAbort selects the shutdown mode:
//   - false (graceful): parentEnded is set so direct children can see, via
//     IsParentEnded, that this turn ended. Children are not finished here.
//   - true (hard abort): parentEnded is left unset. When ts.al is non-nil,
//     every child found in ts.al.activeTurnStates is finished recursively
//     with Finish(true).
//
// Finish may be called more than once. closeOnce prevents double-closing the
// channels, and the remaining steps (flag stores, cancelFunc, cascade) simply
// run again. Repeated context.CancelFunc calls are no-ops.
func (ts *turnState) Finish(isHardAbort bool) {
	ts.isFinished.Store(true)

	// Close pendingResults channel exactly once.
	// After this, a send on pendingResults panics (Go close semantics).
	// NOTE(review): presumably producers check isFinished/Finished() before
	// sending — confirm.
	ts.closeOnce.Do(func() {
		if ts.pendingResults != nil {
			close(ts.pendingResults)
		}
		// finishedChan may not exist yet if Finished() was never called;
		// create it so waiters that arrive later get an already-closed channel.
		ts.mu.Lock()
		if ts.finishedChan == nil {
			ts.finishedChan = make(chan struct{})
		}
		close(ts.finishedChan)
		ts.mu.Unlock()
	})

	// Any graceful finish must signal direct children so nested SubTurns can
	// observe parent completion and decide whether to stop or continue.
	if !isHardAbort {
		ts.parentEnded.Store(true)
	}

	// Cancel the turn context
	if ts.cancelFunc != nil {
		ts.cancelFunc()
	}

	// Hard abort cascades to all child turns
	if isHardAbort && ts.al != nil {
		// Copy child IDs under the read lock so the recursive Finish calls
		// below run without holding ts.mu.
		ts.mu.RLock()
		children := append([]string(nil), ts.childTurnIDs...)
		ts.mu.RUnlock()
		// NOTE(review): registerActiveTurn keys activeTurnStates by
		// sessionKey, but this lookup uses child turn IDs. Children are only
		// found if they were stored under their turn ID (or their sessionKey
		// equals it) — confirm how SubTurns are registered.
		for _, childID := range children {
			if val, ok := ts.al.activeTurnStates.Load(childID); ok {
				if child, ok := val.(*turnState); ok {
					child.Finish(true)
				}
			}
		}
	}
}
|
|
|
|
// Finished returns whether the turn has finished
|
|
func (ts *turnState) Finished() chan struct{} {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
if ts.finishedChan == nil {
|
|
ts.finishedChan = make(chan struct{})
|
|
}
|
|
return ts.finishedChan
|
|
}
|
|
|
|
// IsParentEnded checks if the parent turn has ended
|
|
func (ts *turnState) IsParentEnded() bool {
|
|
if ts.parentTurnState == nil {
|
|
return false
|
|
}
|
|
return ts.parentTurnState.parentEnded.Load()
|
|
}
|
|
|
|
// GetLastFinishReason returns the last LLM finish_reason
|
|
func (ts *turnState) GetLastFinishReason() string {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.lastFinishReason
|
|
}
|
|
|
|
// SetLastFinishReason sets the last LLM finish_reason
|
|
func (ts *turnState) SetLastFinishReason(reason string) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.lastFinishReason = reason
|
|
}
|
|
|
|
// GetLastUsage returns the last LLM usage info
|
|
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.lastUsage
|
|
}
|
|
|
|
// SetLastUsage sets the last LLM usage info
|
|
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.lastUsage = usage
|
|
}
|
|
|
|
// =============================================================================
|
|
// Context helper functions for turnState
|
|
// =============================================================================
|
|
|
|
// turnStateKeyType is the unexported context-key type for a *turnState. Using
// a private type means no other package can collide with this key.
type turnStateKeyType struct{}

// turnStateKey is the key used by withTurnState and turnStateFromContext.
var turnStateKey = turnStateKeyType{}
|
|
|
|
func withTurnState(ctx context.Context, ts *turnState) context.Context {
|
|
return context.WithValue(ctx, turnStateKey, ts)
|
|
}
|
|
|
|
func turnStateFromContext(ctx context.Context) *turnState {
|
|
ts, _ := ctx.Value(turnStateKey).(*turnState)
|
|
return ts
|
|
}
|
|
|
|
// TurnStateFromContext retrieves turnState from context (exported for tools)
|
|
func TurnStateFromContext(ctx context.Context) *turnState {
|
|
return turnStateFromContext(ctx)
|
|
}
|