mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
0e075f7300
Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary.
310 lines
6.8 KiB
Go
310 lines
6.8 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"reflect"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
)
|
|
|
|
type TurnPhase string
|
|
|
|
const (
|
|
TurnPhaseSetup TurnPhase = "setup"
|
|
TurnPhaseRunning TurnPhase = "running"
|
|
TurnPhaseTools TurnPhase = "tools"
|
|
TurnPhaseFinalizing TurnPhase = "finalizing"
|
|
TurnPhaseCompleted TurnPhase = "completed"
|
|
TurnPhaseAborted TurnPhase = "aborted"
|
|
)
|
|
|
|
type ActiveTurnInfo struct {
|
|
TurnID string
|
|
AgentID string
|
|
SessionKey string
|
|
Channel string
|
|
ChatID string
|
|
UserMessage string
|
|
Phase TurnPhase
|
|
Iteration int
|
|
StartedAt time.Time
|
|
}
|
|
|
|
type turnResult struct {
|
|
finalContent string
|
|
status TurnEndStatus
|
|
followUps []bus.InboundMessage
|
|
}
|
|
|
|
type turnState struct {
|
|
mu sync.RWMutex
|
|
|
|
agent *AgentInstance
|
|
opts processOptions
|
|
scope turnEventScope
|
|
|
|
turnID string
|
|
agentID string
|
|
sessionKey string
|
|
|
|
channel string
|
|
chatID string
|
|
userMessage string
|
|
media []string
|
|
|
|
phase TurnPhase
|
|
iteration int
|
|
startedAt time.Time
|
|
finalContent string
|
|
|
|
pendingSteering []providers.Message
|
|
followUps []bus.InboundMessage
|
|
|
|
gracefulInterrupt bool
|
|
gracefulInterruptHint string
|
|
gracefulTerminalUsed bool
|
|
hardAbort bool
|
|
providerCancel context.CancelFunc
|
|
turnCancel context.CancelFunc
|
|
|
|
restorePointHistory []providers.Message
|
|
restorePointSummary string
|
|
persistedMessages []providers.Message
|
|
}
|
|
|
|
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
|
return &turnState{
|
|
agent: agent,
|
|
opts: opts,
|
|
scope: scope,
|
|
turnID: scope.turnID,
|
|
agentID: agent.ID,
|
|
sessionKey: opts.SessionKey,
|
|
channel: opts.Channel,
|
|
chatID: opts.ChatID,
|
|
userMessage: opts.UserMessage,
|
|
media: append([]string(nil), opts.Media...),
|
|
phase: TurnPhaseSetup,
|
|
startedAt: time.Now(),
|
|
}
|
|
}
|
|
|
|
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
|
|
al.activeTurnMu.Lock()
|
|
defer al.activeTurnMu.Unlock()
|
|
al.activeTurn = ts
|
|
}
|
|
|
|
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
|
|
al.activeTurnMu.Lock()
|
|
defer al.activeTurnMu.Unlock()
|
|
if al.activeTurn == ts {
|
|
al.activeTurn = nil
|
|
}
|
|
}
|
|
|
|
func (al *AgentLoop) getActiveTurnState() *turnState {
|
|
al.activeTurnMu.RLock()
|
|
defer al.activeTurnMu.RUnlock()
|
|
return al.activeTurn
|
|
}
|
|
|
|
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
|
|
ts := al.getActiveTurnState()
|
|
if ts == nil {
|
|
return nil
|
|
}
|
|
info := ts.snapshot()
|
|
return &info
|
|
}
|
|
|
|
func (ts *turnState) snapshot() ActiveTurnInfo {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
|
|
return ActiveTurnInfo{
|
|
TurnID: ts.turnID,
|
|
AgentID: ts.agentID,
|
|
SessionKey: ts.sessionKey,
|
|
Channel: ts.channel,
|
|
ChatID: ts.chatID,
|
|
UserMessage: ts.userMessage,
|
|
Phase: ts.phase,
|
|
Iteration: ts.iteration,
|
|
StartedAt: ts.startedAt,
|
|
}
|
|
}
|
|
|
|
func (ts *turnState) setPhase(phase TurnPhase) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.phase = phase
|
|
}
|
|
|
|
func (ts *turnState) setIteration(iteration int) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.iteration = iteration
|
|
}
|
|
|
|
func (ts *turnState) currentIteration() int {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.iteration
|
|
}
|
|
|
|
func (ts *turnState) setFinalContent(content string) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.finalContent = content
|
|
}
|
|
|
|
func (ts *turnState) finalContentLen() int {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return len(ts.finalContent)
|
|
}
|
|
|
|
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.turnCancel = cancel
|
|
}
|
|
|
|
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.providerCancel = cancel
|
|
}
|
|
|
|
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.providerCancel = nil
|
|
}
|
|
|
|
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
if ts.hardAbort {
|
|
return false
|
|
}
|
|
ts.gracefulInterrupt = true
|
|
ts.gracefulInterruptHint = hint
|
|
return true
|
|
}
|
|
|
|
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
|
|
}
|
|
|
|
func (ts *turnState) markGracefulTerminalUsed() {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.gracefulTerminalUsed = true
|
|
}
|
|
|
|
func (ts *turnState) requestHardAbort() bool {
|
|
ts.mu.Lock()
|
|
if ts.hardAbort {
|
|
ts.mu.Unlock()
|
|
return false
|
|
}
|
|
ts.hardAbort = true
|
|
turnCancel := ts.turnCancel
|
|
providerCancel := ts.providerCancel
|
|
ts.mu.Unlock()
|
|
|
|
if providerCancel != nil {
|
|
providerCancel()
|
|
}
|
|
if turnCancel != nil {
|
|
turnCancel()
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (ts *turnState) hardAbortRequested() bool {
|
|
ts.mu.RLock()
|
|
defer ts.mu.RUnlock()
|
|
return ts.hardAbort
|
|
}
|
|
|
|
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
|
snap := ts.snapshot()
|
|
return EventMeta{
|
|
AgentID: snap.AgentID,
|
|
TurnID: snap.TurnID,
|
|
SessionKey: snap.SessionKey,
|
|
Iteration: snap.Iteration,
|
|
Source: source,
|
|
TracePath: tracePath,
|
|
}
|
|
}
|
|
|
|
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
|
ts.restorePointSummary = summary
|
|
}
|
|
|
|
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
|
ts.mu.Lock()
|
|
defer ts.mu.Unlock()
|
|
ts.persistedMessages = append(ts.persistedMessages, msg)
|
|
}
|
|
|
|
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
|
history := agent.Sessions.GetHistory(ts.sessionKey)
|
|
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
|
|
|
ts.mu.RLock()
|
|
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
|
ts.mu.RUnlock()
|
|
|
|
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
|
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
|
}
|
|
|
|
ts.captureRestorePoint(history, summary)
|
|
}
|
|
|
|
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
|
ts.mu.RLock()
|
|
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
|
summary := ts.restorePointSummary
|
|
ts.mu.RUnlock()
|
|
|
|
agent.Sessions.SetHistory(ts.sessionKey, history)
|
|
agent.Sessions.SetSummary(ts.sessionKey, summary)
|
|
return agent.Sessions.Save(ts.sessionKey)
|
|
}
|
|
|
|
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
|
maxMatch := min(len(history), len(persisted))
|
|
for size := maxMatch; size > 0; size-- {
|
|
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
|
return size
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func (ts *turnState) interruptHintMessage() providers.Message {
|
|
_, hint := ts.gracefulInterruptRequested()
|
|
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
|
if hint != "" {
|
|
content += "\n\nInterrupt hint: " + hint
|
|
}
|
|
return providers.Message{
|
|
Role: "user",
|
|
Content: content,
|
|
}
|
|
}
|