Merge pull request #1827 from alexhoshina/refactor/agent-loop

Refactor/agent loop
This commit is contained in:
daming大铭
2026-03-20 20:56:53 +08:00
committed by GitHub
6 changed files with 1396 additions and 337 deletions
+3
View File
@@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
if interruptPayload.Role != "user" {
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
}
if interruptPayload.Kind != InterruptKindSteering {
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
}
if interruptPayload.ContentLen != len("change course") {
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
}
+13 -1
View File
@@ -105,6 +105,8 @@ const (
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
@@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct {
ContentLen int
}
// InterruptReceivedPayload describes a queued soft interrupt.
type InterruptKind string
const (
InterruptKindSteering InterruptKind = "steering"
InterruptKindGraceful InterruptKind = "graceful"
InterruptKindHard InterruptKind = "hard_abort"
)
// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
Kind InterruptKind
Role string
ContentLen int
QueueDepth int
HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
+490 -328
View File
File diff suppressed because it is too large Load Diff
+62 -8
View File
@@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
agentID := ""
if registry := al.GetRegistry(); registry != nil {
meta := EventMeta{
Source: "Steer",
TracePath: "turn.interrupt.received",
}
if ts := al.getActiveTurnState(); ts != nil {
meta = ts.eventMeta("Steer", "turn.interrupt.received")
} else if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
agentID = agent.ID
meta.AgentID = agent.ID
}
}
al.emitEvent(
EventKindInterruptReceived,
EventMeta{
AgentID: agentID,
Source: "Steer",
TracePath: "turn.interrupt.received",
},
meta,
InterruptReceivedPayload{
Kind: InterruptKindSteering,
Role: msg.Role,
ContentLen: len(msg.Content),
QueueDepth: al.steering.len(),
@@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
return "", nil
@@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
return "", fmt.Errorf("no default agent available")
}
if tool, ok := agent.Tools.Get("message"); ok {
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
resetter.ResetSentInRound()
}
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
@@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
SkipInitialSteeringPoll: true,
})
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
ts := al.getActiveTurnState()
if ts == nil {
return fmt.Errorf("no active turn")
}
if !ts.requestGracefulInterrupt(hint) {
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
}
al.emitEvent(
EventKindInterruptReceived,
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
InterruptReceivedPayload{
Kind: InterruptKindGraceful,
HintLen: len(hint),
},
)
return nil
}
func (al *AgentLoop) InterruptHard() error {
ts := al.getActiveTurnState()
if ts == nil {
return fmt.Errorf("no active turn")
}
if !ts.requestHardAbort() {
return fmt.Errorf("turn %s is already aborting", ts.turnID)
}
al.emitEvent(
EventKindInterruptReceived,
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
InterruptReceivedPayload{
Kind: InterruptKindHard,
},
)
return nil
}
+520
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"reflect"
"sync"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string {
return "tool-call-mock"
}
type gracefulCaptureProvider struct {
mu sync.Mutex
calls int
toolCalls []providers.ToolCall
finalResp string
terminalMessages []providers.Message
terminalToolsCount int
}
func (p *gracefulCaptureProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
defer p.mu.Unlock()
p.calls++
if p.calls == 1 {
return &providers.LLMResponse{
ToolCalls: p.toolCalls,
}, nil
}
p.terminalMessages = append([]providers.Message(nil), messages...)
p.terminalToolsCount = len(tools)
return &providers.LLMResponse{
Content: p.finalResp,
}, nil
}
func (p *gracefulCaptureProvider) GetDefaultModel() string {
return "graceful-capture-mock"
}
type lateSteeringProvider struct {
mu sync.Mutex
calls int
firstCallStarted chan struct{}
releaseFirstCall chan struct{}
firstStartOnce sync.Once
secondCallMessages []providers.Message
}
func (p *lateSteeringProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
p.mu.Unlock()
if call == 1 {
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
<-p.releaseFirstCall
return &providers.LLMResponse{Content: "first response"}, nil
}
p.mu.Lock()
p.secondCallMessages = append([]providers.Message(nil), messages...)
p.mu.Unlock()
return &providers.LLMResponse{Content: "continued response"}, nil
}
func (p *lateSteeringProvider) GetDefaultModel() string {
return "late-steering-mock"
}
type interruptibleTool struct {
name string
started chan struct{}
once sync.Once
}
func (t *interruptibleTool) Name() string { return t.name }
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
func (t *interruptibleTool) Parameters() map[string]any {
return map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
if t.started != nil {
t.once.Do(func() { close(t.started) })
}
<-ctx.Done()
return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
}
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -568,6 +667,427 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
}
}
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &lateSteeringProvider{
firstCallStarted: make(chan struct{}),
releaseFirstCall: make(chan struct{}),
}
al := NewAgentLoop(cfg, msgBus, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "first message",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
late := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "late append",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
select {
case <-provider.firstCallStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first provider call to start")
}
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
t.Fatalf("publish late inbound: %v", err)
}
close(provider.releaseFirstCall)
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer subCancel()
out1, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
t.Fatal("expected first outbound response")
}
if out1.Content != "first response" {
t.Fatalf("expected first response, got %q", out1.Content)
}
out2, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
t.Fatal("expected continued outbound response")
}
if out2.Content != "continued response" {
t.Fatalf("expected continued response, got %q", out2.Content)
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
provider.mu.Lock()
calls := provider.calls
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
foundLateMessage := false
for _, msg := range secondMessages {
if msg.Role == "user" && msg.Content == "late append" {
foundLateMessage = true
break
}
}
if !foundLateMessage {
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
}
}
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
provider := &gracefulCaptureProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "graceful summary",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do something",
sessionKey,
"test",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
select {
case <-tool1ExecCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
active := al.GetActiveTurn()
if active == nil {
t.Fatal("expected active turn while tool is running")
}
if active.SessionKey != sessionKey {
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
}
if active.Channel != "test" || active.ChatID != "chat1" {
t.Fatalf("unexpected active turn target: %#v", active)
}
if err := al.InterruptGraceful("wrap it up"); err != nil {
t.Fatalf("InterruptGraceful failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "graceful summary" {
t.Fatalf("expected graceful summary, got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for graceful interrupt result")
}
if active := al.GetActiveTurn(); active != nil {
t.Fatalf("expected no active turn after completion, got %#v", active)
}
provider.mu.Lock()
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
terminalToolsCount := provider.terminalToolsCount
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
if terminalToolsCount != 0 {
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
}
foundHint := false
foundSkipped := false
expectedHint := "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\n" +
"Interrupt hint: wrap it up"
for _, msg := range terminalMessages {
if msg.Role == "user" && msg.Content == expectedHint {
foundHint = true
}
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
foundSkipped = true
}
}
if !foundHint {
t.Fatal("expected graceful terminal call to include interrupt hint message")
}
if !foundSkipped {
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Kind != InterruptKindGraceful {
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
}
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
}
if turnEndPayload.Status != TurnEndStatusCompleted {
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
}
}
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "cancel_tool",
Function: &providers.FunctionCall{
Name: "cancel_tool",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "should not happen",
}
al := NewAgentLoop(cfg, msgBus, provider)
started := make(chan struct{})
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
originalHistory := []providers.Message{
{Role: "user", Content: "before"},
{Role: "assistant", Content: "after"},
}
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do work",
sessionKey,
"test",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
select {
case <-started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for interruptible tool to start")
}
if active := al.GetActiveTurn(); active == nil {
t.Fatal("expected active turn before hard abort")
}
if err := al.InterruptHard(); err != nil {
t.Fatalf("InterruptHard failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "" {
t.Fatalf("expected no final response after hard abort, got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for hard abort result")
}
if active := al.GetActiveTurn(); active != nil {
t.Fatalf("expected no active turn after hard abort, got %#v", active)
}
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
if !reflect.DeepEqual(finalHistory, originalHistory) {
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Kind != InterruptKindHard {
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
}
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
}
if turnEndPayload.Status != TurnEndStatusAborted {
t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
+308
View File
@@ -0,0 +1,308 @@
package agent
import (
"context"
"reflect"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
type TurnPhase string
const (
TurnPhaseSetup TurnPhase = "setup"
TurnPhaseRunning TurnPhase = "running"
TurnPhaseTools TurnPhase = "tools"
TurnPhaseFinalizing TurnPhase = "finalizing"
TurnPhaseCompleted TurnPhase = "completed"
TurnPhaseAborted TurnPhase = "aborted"
)
type ActiveTurnInfo struct {
TurnID string
AgentID string
SessionKey string
Channel string
ChatID string
UserMessage string
Phase TurnPhase
Iteration int
StartedAt time.Time
}
type turnResult struct {
finalContent string
status TurnEndStatus
followUps []bus.InboundMessage
}
type turnState struct {
mu sync.RWMutex
agent *AgentInstance
opts processOptions
scope turnEventScope
turnID string
agentID string
sessionKey string
channel string
chatID string
userMessage string
media []string
phase TurnPhase
iteration int
startedAt time.Time
finalContent string
followUps []bus.InboundMessage
gracefulInterrupt bool
gracefulInterruptHint string
gracefulTerminalUsed bool
hardAbort bool
providerCancel context.CancelFunc
turnCancel context.CancelFunc
restorePointHistory []providers.Message
restorePointSummary string
persistedMessages []providers.Message
}
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
return &turnState{
agent: agent,
opts: opts,
scope: scope,
turnID: scope.turnID,
agentID: agent.ID,
sessionKey: opts.SessionKey,
channel: opts.Channel,
chatID: opts.ChatID,
userMessage: opts.UserMessage,
media: append([]string(nil), opts.Media...),
phase: TurnPhaseSetup,
startedAt: time.Now(),
}
}
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
al.activeTurnMu.Lock()
defer al.activeTurnMu.Unlock()
al.activeTurn = ts
}
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
al.activeTurnMu.Lock()
defer al.activeTurnMu.Unlock()
if al.activeTurn == ts {
al.activeTurn = nil
}
}
func (al *AgentLoop) getActiveTurnState() *turnState {
al.activeTurnMu.RLock()
defer al.activeTurnMu.RUnlock()
return al.activeTurn
}
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
ts := al.getActiveTurnState()
if ts == nil {
return nil
}
info := ts.snapshot()
return &info
}
func (ts *turnState) snapshot() ActiveTurnInfo {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ActiveTurnInfo{
TurnID: ts.turnID,
AgentID: ts.agentID,
SessionKey: ts.sessionKey,
Channel: ts.channel,
ChatID: ts.chatID,
UserMessage: ts.userMessage,
Phase: ts.phase,
Iteration: ts.iteration,
StartedAt: ts.startedAt,
}
}
func (ts *turnState) setPhase(phase TurnPhase) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.phase = phase
}
func (ts *turnState) setIteration(iteration int) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.iteration = iteration
}
func (ts *turnState) currentIteration() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.iteration
}
func (ts *turnState) setFinalContent(content string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.finalContent = content
}
func (ts *turnState) finalContentLen() int {
ts.mu.RLock()
defer ts.mu.RUnlock()
return len(ts.finalContent)
}
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.turnCancel = cancel
}
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = cancel
}
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.providerCancel = nil
}
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
ts.mu.Lock()
defer ts.mu.Unlock()
if ts.hardAbort {
return false
}
ts.gracefulInterrupt = true
ts.gracefulInterruptHint = hint
return true
}
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
}
func (ts *turnState) markGracefulTerminalUsed() {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.gracefulTerminalUsed = true
}
func (ts *turnState) requestHardAbort() bool {
ts.mu.Lock()
if ts.hardAbort {
ts.mu.Unlock()
return false
}
ts.hardAbort = true
turnCancel := ts.turnCancel
providerCancel := ts.providerCancel
ts.mu.Unlock()
if providerCancel != nil {
providerCancel()
}
if turnCancel != nil {
turnCancel()
}
return true
}
func (ts *turnState) hardAbortRequested() bool {
ts.mu.RLock()
defer ts.mu.RUnlock()
return ts.hardAbort
}
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
snap := ts.snapshot()
return EventMeta{
AgentID: snap.AgentID,
TurnID: snap.TurnID,
SessionKey: snap.SessionKey,
Iteration: snap.Iteration,
Source: source,
TracePath: tracePath,
}
}
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.restorePointHistory = append([]providers.Message(nil), history...)
ts.restorePointSummary = summary
}
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
ts.mu.Lock()
defer ts.mu.Unlock()
ts.persistedMessages = append(ts.persistedMessages, msg)
}
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
history := agent.Sessions.GetHistory(ts.sessionKey)
summary := agent.Sessions.GetSummary(ts.sessionKey)
ts.mu.RLock()
persisted := append([]providers.Message(nil), ts.persistedMessages...)
ts.mu.RUnlock()
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
history = append([]providers.Message(nil), history[:len(history)-matched]...)
}
ts.captureRestorePoint(history, summary)
}
func (ts *turnState) restoreSession(agent *AgentInstance) error {
ts.mu.RLock()
history := append([]providers.Message(nil), ts.restorePointHistory...)
summary := ts.restorePointSummary
ts.mu.RUnlock()
agent.Sessions.SetHistory(ts.sessionKey, history)
agent.Sessions.SetSummary(ts.sessionKey, summary)
return agent.Sessions.Save(ts.sessionKey)
}
func matchingTurnMessageTail(history, persisted []providers.Message) int {
maxMatch := min(len(history), len(persisted))
for size := maxMatch; size > 0; size-- {
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
return size
}
}
return 0
}
func (ts *turnState) interruptHintMessage() providers.Message {
_, hint := ts.gracefulInterruptRequested()
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
if hint != "" {
content += "\n\nInterrupt hint: " + hint
}
return providers.Message{
Role: "user",
Content: content,
}
}