Files
picoclaw/pkg/agent/turn.go
T
2026-03-20 17:46:31 +08:00

309 lines
6.8 KiB
Go

package agent
import (
"context"
"reflect"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
type TurnPhase string
const (
TurnPhaseSetup TurnPhase = "setup"
TurnPhaseRunning TurnPhase = "running"
TurnPhaseTools TurnPhase = "tools"
TurnPhaseFinalizing TurnPhase = "finalizing"
TurnPhaseCompleted TurnPhase = "completed"
TurnPhaseAborted TurnPhase = "aborted"
)
type ActiveTurnInfo struct {
TurnID string
AgentID string
SessionKey string
Channel string
ChatID string
UserMessage string
Phase TurnPhase
Iteration int
StartedAt time.Time
}
type turnResult struct {
finalContent string
status TurnEndStatus
followUps []bus.InboundMessage
}
type turnState struct {
mu sync.RWMutex
agent *AgentInstance
opts processOptions
scope turnEventScope
turnID string
agentID string
sessionKey string
channel string
chatID string
userMessage string
media []string
phase TurnPhase
iteration int
startedAt time.Time
finalContent string
followUps []bus.InboundMessage
gracefulInterrupt bool
gracefulInterruptHint string
gracefulTerminalUsed bool
hardAbort bool
providerCancel context.CancelFunc
turnCancel context.CancelFunc
restorePointHistory []providers.Message
restorePointSummary string
persistedMessages []providers.Message
}
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
return &turnState{
agent: agent,
opts: opts,
scope: scope,
turnID: scope.turnID,
agentID: agent.ID,
sessionKey: opts.SessionKey,
channel: opts.Channel,
chatID: opts.ChatID,
userMessage: opts.UserMessage,
media: append([]string(nil), opts.Media...),
phase: TurnPhaseSetup,
startedAt: time.Now(),
}
}
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
al.activeTurnMu.Lock()
defer al.activeTurnMu.Unlock()
al.activeTurn = ts
}
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
al.activeTurnMu.Lock()
defer al.activeTurnMu.Unlock()
if al.activeTurn == ts {
al.activeTurn = nil
}
}
func (al *AgentLoop) getActiveTurnState() *turnState {
al.activeTurnMu.RLock()
defer al.activeTurnMu.RUnlock()
return al.activeTurn
}
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
ts := al.getActiveTurnState()
if ts == nil {
return nil
}
info := ts.snapshot()
return &info
}
func (ts *turnState) snapshot() ActiveTurnInfo {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ActiveTurnInfo{
TurnID: ts.turnID,
AgentID: ts.agentID,
SessionKey: ts.sessionKey,
Channel: ts.channel,
ChatID: ts.chatID,
UserMessage: ts.userMessage,
Phase: ts.phase,
Iteration: ts.iteration,
StartedAt: ts.startedAt,
}
}
func (ts *turnState) setPhase(phase TurnPhase) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.phase = phase
}
func (ts *turnState) setIteration(iteration int) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.iteration = iteration
}
func (ts *turnState) currentIteration() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.iteration
}
func (ts *turnState) setFinalContent(content string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.finalContent = content
}
func (ts *turnState) finalContentLen() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return len(ts.finalContent)
}
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.turnCancel = cancel
}
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = cancel
}
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = nil
}
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
ts.mu.Lock()
defer ts.mu.Unlock()
if ts.hardAbort {
return false
}
ts.gracefulInterrupt = true
ts.gracefulInterruptHint = hint
return true
}
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
}
func (ts *turnState) markGracefulTerminalUsed() {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.gracefulTerminalUsed = true
}
func (ts *turnState) requestHardAbort() bool {
ts.mu.Lock()
if ts.hardAbort {
ts.mu.Unlock()
return false
}
ts.hardAbort = true
turnCancel := ts.turnCancel
providerCancel := ts.providerCancel
ts.mu.Unlock()
if providerCancel != nil {
providerCancel()
}
if turnCancel != nil {
turnCancel()
}
return true
}
func (ts *turnState) hardAbortRequested() bool {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.hardAbort
}
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
snap := ts.snapshot()
return EventMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
}
}
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = summary
}
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.persistedMessages = append(ts.persistedMessages, msg)
}
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
history := agent.Sessions.GetHistory(ts.sessionKey)
summary := agent.Sessions.GetSummary(ts.sessionKey)
ts.mu.RLock()
persisted := append([]providers.Message(nil), ts.persistedMessages...)
ts.mu.RUnlock()
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
history = append([]providers.Message(nil), history[:len(history)-matched]...)
}
ts.captureRestorePoint(history, summary)
}
func (ts *turnState) restoreSession(agent *AgentInstance) error {
ts.mu.RLock()
history := append([]providers.Message(nil), ts.restorePointHistory...)
summary := ts.restorePointSummary
ts.mu.RUnlock()
agent.Sessions.SetHistory(ts.sessionKey, history)
agent.Sessions.SetSummary(ts.sessionKey, summary)
return agent.Sessions.Save(ts.sessionKey)
}
func matchingTurnMessageTail(history, persisted []providers.Message) int {
maxMatch := min(len(history), len(persisted))
for size := maxMatch; size > 0; size-- {
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
return size
}
}
return 0
}
func (ts *turnState) interruptHintMessage() providers.Message {
_, hint := ts.gracefulInterruptRequested()
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
if hint != "" {
content += "\n\nInterrupt hint: " + hint
}
return providers.Message{
Role: "user",
Content: content,
}
}