feat(agent): centralize turn lifecycle and continue queued steering

Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary.
This commit is contained in:
Hoshina
2026-03-20 17:28:12 +08:00
parent a65e0e95d6
commit 0e075f7300
6 changed files with 1395 additions and 337 deletions
+3
View File
@@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
if interruptPayload.Role != "user" {
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
}
if interruptPayload.Kind != InterruptKindSteering {
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
}
if interruptPayload.ContentLen != len("change course") {
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
}
+13 -1
View File
@@ -105,6 +105,8 @@ const (
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
@@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct {
ContentLen int
}
// InterruptReceivedPayload describes a queued soft interrupt.
type InterruptKind string
const (
InterruptKindSteering InterruptKind = "steering"
InterruptKindGraceful InterruptKind = "graceful"
InterruptKindHard InterruptKind = "hard_abort"
)
// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
Kind InterruptKind
Role string
ContentLen int
QueueDepth int
HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
+490 -328
View File
File diff suppressed because it is too large Load Diff
+62 -8
View File
@@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
agentID := ""
if registry := al.GetRegistry(); registry != nil {
meta := EventMeta{
Source: "Steer",
TracePath: "turn.interrupt.received",
}
if ts := al.getActiveTurnState(); ts != nil {
meta = ts.eventMeta("Steer", "turn.interrupt.received")
} else if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
agentID = agent.ID
meta.AgentID = agent.ID
}
}
al.emitEvent(
EventKindInterruptReceived,
EventMeta{
AgentID: agentID,
Source: "Steer",
TracePath: "turn.interrupt.received",
},
meta,
InterruptReceivedPayload{
Kind: InterruptKindSteering,
Role: msg.Role,
ContentLen: len(msg.Content),
QueueDepth: al.steering.len(),
@@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
return "", nil
@@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
return "", fmt.Errorf("no default agent available")
}
if tool, ok := agent.Tools.Get("message"); ok {
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
resetter.ResetSentInRound()
}
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
@@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
SkipInitialSteeringPoll: true,
})
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
ts := al.getActiveTurnState()
if ts == nil {
return fmt.Errorf("no active turn")
}
if !ts.requestGracefulInterrupt(hint) {
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
}
al.emitEvent(
EventKindInterruptReceived,
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
InterruptReceivedPayload{
Kind: InterruptKindGraceful,
HintLen: len(hint),
},
)
return nil
}
func (al *AgentLoop) InterruptHard() error {
ts := al.getActiveTurnState()
if ts == nil {
return fmt.Errorf("no active turn")
}
if !ts.requestHardAbort() {
return fmt.Errorf("turn %s is already aborting", ts.turnID)
}
al.emitEvent(
EventKindInterruptReceived,
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
InterruptReceivedPayload{
Kind: InterruptKindHard,
},
)
return nil
}
+518
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"reflect"
"sync"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string {
return "tool-call-mock"
}
type gracefulCaptureProvider struct {
mu sync.Mutex
calls int
toolCalls []providers.ToolCall
finalResp string
terminalMessages []providers.Message
terminalToolsCount int
}
func (p *gracefulCaptureProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
defer p.mu.Unlock()
p.calls++
if p.calls == 1 {
return &providers.LLMResponse{
ToolCalls: p.toolCalls,
}, nil
}
p.terminalMessages = append([]providers.Message(nil), messages...)
p.terminalToolsCount = len(tools)
return &providers.LLMResponse{
Content: p.finalResp,
}, nil
}
func (p *gracefulCaptureProvider) GetDefaultModel() string {
return "graceful-capture-mock"
}
type lateSteeringProvider struct {
mu sync.Mutex
calls int
firstCallStarted chan struct{}
releaseFirstCall chan struct{}
firstStartOnce sync.Once
secondCallMessages []providers.Message
}
func (p *lateSteeringProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
p.mu.Unlock()
if call == 1 {
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
<-p.releaseFirstCall
return &providers.LLMResponse{Content: "first response"}, nil
}
p.mu.Lock()
p.secondCallMessages = append([]providers.Message(nil), messages...)
p.mu.Unlock()
return &providers.LLMResponse{Content: "continued response"}, nil
}
func (p *lateSteeringProvider) GetDefaultModel() string {
return "late-steering-mock"
}
type interruptibleTool struct {
name string
started chan struct{}
once sync.Once
}
func (t *interruptibleTool) Name() string { return t.name }
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
func (t *interruptibleTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
if t.started != nil {
t.once.Do(func() { close(t.started) })
}
<-ctx.Done()
return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
}
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -568,6 +667,425 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
}
}
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &lateSteeringProvider{
firstCallStarted: make(chan struct{}),
releaseFirstCall: make(chan struct{}),
}
al := NewAgentLoop(cfg, msgBus, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "first message",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
late := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "late append",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstCallStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first provider call to start")
}
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
t.Fatalf("publish late inbound: %v", err)
}
close(provider.releaseFirstCall)
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer subCancel()
out1, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
t.Fatal("expected first outbound response")
}
if out1.Content != "first response" {
t.Fatalf("expected first response, got %q", out1.Content)
}
out2, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
t.Fatal("expected continued outbound response")
}
if out2.Content != "continued response" {
t.Fatalf("expected continued response, got %q", out2.Content)
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
provider.mu.Lock()
calls := provider.calls
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
foundLateMessage := false
for _, msg := range secondMessages {
if msg.Role == "user" && msg.Content == "late append" {
foundLateMessage = true
break
}
}
if !foundLateMessage {
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
}
}
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
provider := &gracefulCaptureProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "graceful summary",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do something",
sessionKey,
"test",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
select {
case <-tool1ExecCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
active := al.GetActiveTurn()
if active == nil {
t.Fatal("expected active turn while tool is running")
}
if active.SessionKey != sessionKey {
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
}
if active.Channel != "test" || active.ChatID != "chat1" {
t.Fatalf("unexpected active turn target: %#v", active)
}
if err := al.InterruptGraceful("wrap it up"); err != nil {
t.Fatalf("InterruptGraceful failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "graceful summary" {
t.Fatalf("expected graceful summary, got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for graceful interrupt result")
}
if active := al.GetActiveTurn(); active != nil {
t.Fatalf("expected no active turn after completion, got %#v", active)
}
provider.mu.Lock()
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
terminalToolsCount := provider.terminalToolsCount
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
if terminalToolsCount != 0 {
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
}
foundHint := false
foundSkipped := false
for _, msg := range terminalMessages {
if msg.Role == "user" && msg.Content == "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\nInterrupt hint: wrap it up" {
foundHint = true
}
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
foundSkipped = true
}
}
if !foundHint {
t.Fatal("expected graceful terminal call to include interrupt hint message")
}
if !foundSkipped {
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Kind != InterruptKindGraceful {
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
}
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
}
if turnEndPayload.Status != TurnEndStatusCompleted {
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
}
}
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "cancel_tool",
Function: &providers.FunctionCall{
Name: "cancel_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "should not happen",
}
al := NewAgentLoop(cfg, msgBus, provider)
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
originalHistory := []providers.Message{
{Role: "user", Content: "before"},
{Role: "assistant", Content: "after"},
}
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do work",
sessionKey,
"test",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for interruptible tool to start")
}
if active := al.GetActiveTurn(); active == nil {
t.Fatal("expected active turn before hard abort")
}
if err := al.InterruptHard(); err != nil {
t.Fatalf("InterruptHard failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "" {
t.Fatalf("expected no final response after hard abort, got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for hard abort result")
}
if active := al.GetActiveTurn(); active != nil {
t.Fatalf("expected no active turn after hard abort, got %#v", active)
}
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
if !reflect.DeepEqual(finalHistory, originalHistory) {
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Kind != InterruptKindHard {
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
}
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
}
if turnEndPayload.Status != TurnEndStatusAborted {
t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
+309
View File
@@ -0,0 +1,309 @@
package agent
import (
"context"
"reflect"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
type TurnPhase string
const (
TurnPhaseSetup TurnPhase = "setup"
TurnPhaseRunning TurnPhase = "running"
TurnPhaseTools TurnPhase = "tools"
TurnPhaseFinalizing TurnPhase = "finalizing"
TurnPhaseCompleted TurnPhase = "completed"
TurnPhaseAborted TurnPhase = "aborted"
)
type ActiveTurnInfo struct {
TurnID string
AgentID string
SessionKey string
Channel string
ChatID string
UserMessage string
Phase TurnPhase
Iteration int
StartedAt time.Time
}
type turnResult struct {
finalContent string
status TurnEndStatus
followUps []bus.InboundMessage
}
type turnState struct {
mu sync.RWMutex
agent *AgentInstance
opts processOptions
scope turnEventScope
turnID string
agentID string
sessionKey string
channel string
chatID string
userMessage string
media []string
phase TurnPhase
iteration int
startedAt time.Time
finalContent string
pendingSteering []providers.Message
followUps []bus.InboundMessage
gracefulInterrupt bool
gracefulInterruptHint string
gracefulTerminalUsed bool
hardAbort bool
providerCancel context.CancelFunc
turnCancel context.CancelFunc
restorePointHistory []providers.Message
restorePointSummary string
persistedMessages []providers.Message
}
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
return &turnState{
agent: agent,
opts: opts,
scope: scope,
turnID: scope.turnID,
agentID: agent.ID,
sessionKey: opts.SessionKey,
channel: opts.Channel,
chatID: opts.ChatID,
userMessage: opts.UserMessage,
media: append([]string(nil), opts.Media...),
phase: TurnPhaseSetup,
startedAt: time.Now(),
}
}
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
al.activeTurnMu.Lock()
defer al.activeTurnMu.Unlock()
al.activeTurn = ts
}
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
al.activeTurnMu.Lock()
defer al.activeTurnMu.Unlock()
if al.activeTurn == ts {
al.activeTurn = nil
}
}
func (al *AgentLoop) getActiveTurnState() *turnState {
al.activeTurnMu.RLock()
defer al.activeTurnMu.RUnlock()
return al.activeTurn
}
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
ts := al.getActiveTurnState()
if ts == nil {
return nil
}
info := ts.snapshot()
return &info
}
func (ts *turnState) snapshot() ActiveTurnInfo {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ActiveTurnInfo{
TurnID: ts.turnID,
AgentID: ts.agentID,
SessionKey: ts.sessionKey,
Channel: ts.channel,
ChatID: ts.chatID,
UserMessage: ts.userMessage,
Phase: ts.phase,
Iteration: ts.iteration,
StartedAt: ts.startedAt,
}
}
func (ts *turnState) setPhase(phase TurnPhase) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.phase = phase
}
func (ts *turnState) setIteration(iteration int) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.iteration = iteration
}
func (ts *turnState) currentIteration() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.iteration
}
func (ts *turnState) setFinalContent(content string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.finalContent = content
}
func (ts *turnState) finalContentLen() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return len(ts.finalContent)
}
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.turnCancel = cancel
}
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = cancel
}
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = nil
}
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
ts.mu.Lock()
defer ts.mu.Unlock()
if ts.hardAbort {
return false
}
ts.gracefulInterrupt = true
ts.gracefulInterruptHint = hint
return true
}
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
}
func (ts *turnState) markGracefulTerminalUsed() {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.gracefulTerminalUsed = true
}
func (ts *turnState) requestHardAbort() bool {
ts.mu.Lock()
if ts.hardAbort {
ts.mu.Unlock()
return false
}
ts.hardAbort = true
turnCancel := ts.turnCancel
providerCancel := ts.providerCancel
ts.mu.Unlock()
if providerCancel != nil {
providerCancel()
}
if turnCancel != nil {
turnCancel()
}
return true
}
func (ts *turnState) hardAbortRequested() bool {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.hardAbort
}
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
snap := ts.snapshot()
return EventMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
}
}
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = summary
}
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.persistedMessages = append(ts.persistedMessages, msg)
}
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
history := agent.Sessions.GetHistory(ts.sessionKey)
summary := agent.Sessions.GetSummary(ts.sessionKey)
ts.mu.RLock()
persisted := append([]providers.Message(nil), ts.persistedMessages...)
ts.mu.RUnlock()
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
history = append([]providers.Message(nil), history[:len(history)-matched]...)
}
ts.captureRestorePoint(history, summary)
}
func (ts *turnState) restoreSession(agent *AgentInstance) error {
ts.mu.RLock()
history := append([]providers.Message(nil), ts.restorePointHistory...)
summary := ts.restorePointSummary
ts.mu.RUnlock()
agent.Sessions.SetHistory(ts.sessionKey, history)
agent.Sessions.SetSummary(ts.sessionKey, summary)
return agent.Sessions.Save(ts.sessionKey)
}
func matchingTurnMessageTail(history, persisted []providers.Message) int {
maxMatch := min(len(history), len(persisted))
for size := maxMatch; size > 0; size-- {
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
return size
}
}
return 0
}
func (ts *turnState) interruptHintMessage() providers.Message {
_, hint := ts.gracefulInterruptRequested()
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
if hint != "" {
content += "\n\nInterrupt hint: " + hint
}
return providers.Message{
Role: "user",
Content: content,
}
}