mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): centralize turn lifecycle and continue queued steering
Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary.
This commit is contained in:
@@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
if interruptPayload.Role != "user" {
|
||||
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindSteering {
|
||||
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
|
||||
}
|
||||
if interruptPayload.ContentLen != len("change course") {
|
||||
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
|
||||
}
|
||||
|
||||
+13
-1
@@ -105,6 +105,8 @@ const (
|
||||
TurnEndStatusCompleted TurnEndStatus = "completed"
|
||||
// TurnEndStatusError indicates the turn ended because of an error.
|
||||
TurnEndStatusError TurnEndStatus = "error"
|
||||
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
|
||||
TurnEndStatusAborted TurnEndStatus = "aborted"
|
||||
)
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
@@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct {
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
// InterruptReceivedPayload describes a queued soft interrupt.
|
||||
type InterruptKind string
|
||||
|
||||
const (
|
||||
InterruptKindSteering InterruptKind = "steering"
|
||||
InterruptKindGraceful InterruptKind = "graceful"
|
||||
InterruptKindHard InterruptKind = "hard_abort"
|
||||
)
|
||||
|
||||
// InterruptReceivedPayload describes accepted turn-control input.
|
||||
type InterruptReceivedPayload struct {
|
||||
Kind InterruptKind
|
||||
Role string
|
||||
ContentLen int
|
||||
QueueDepth int
|
||||
HintLen int
|
||||
}
|
||||
|
||||
// SubTurnSpawnPayload describes the creation of a child turn.
|
||||
|
||||
+490
-328
File diff suppressed because it is too large
Load Diff
+62
-8
@@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
"content_len": len(msg.Content),
|
||||
"queue_len": al.steering.len(),
|
||||
})
|
||||
agentID := ""
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
|
||||
meta := EventMeta{
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
}
|
||||
if ts := al.getActiveTurnState(); ts != nil {
|
||||
meta = ts.eventMeta("Steer", "turn.interrupt.received")
|
||||
} else if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
agentID = agent.ID
|
||||
meta.AgentID = agent.ID
|
||||
}
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
EventMeta{
|
||||
AgentID: agentID,
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
},
|
||||
meta,
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindSteering,
|
||||
Role: msg.Role,
|
||||
ContentLen: len(msg.Content),
|
||||
QueueDepth: al.steering.len(),
|
||||
@@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
//
|
||||
// If no steering messages are pending, it returns an empty string.
|
||||
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
|
||||
steeringMsgs := al.dequeueSteeringMessages()
|
||||
if len(steeringMsgs) == 0 {
|
||||
return "", nil
|
||||
@@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
return "", fmt.Errorf("no default agent available")
|
||||
}
|
||||
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
||||
resetter.ResetSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
// Build a combined user message from the steering messages.
|
||||
var contents []string
|
||||
for _, msg := range steeringMsgs {
|
||||
@@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
ts := al.getActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestGracefulInterrupt(hint) {
|
||||
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindGraceful,
|
||||
HintLen: len(hint),
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptHard() error {
|
||||
ts := al.getActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestHardAbort() {
|
||||
return fmt.Errorf("turn %s is already aborting", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindHard,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string {
|
||||
return "tool-call-mock"
|
||||
}
|
||||
|
||||
type gracefulCaptureProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
toolCalls []providers.ToolCall
|
||||
finalResp string
|
||||
terminalMessages []providers.Message
|
||||
terminalToolsCount int
|
||||
}
|
||||
|
||||
func (p *gracefulCaptureProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.calls++
|
||||
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: p.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
p.terminalMessages = append([]providers.Message(nil), messages...)
|
||||
p.terminalToolsCount = len(tools)
|
||||
return &providers.LLMResponse{
|
||||
Content: p.finalResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *gracefulCaptureProvider) GetDefaultModel() string {
|
||||
return "graceful-capture-mock"
|
||||
}
|
||||
|
||||
type lateSteeringProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstCallStarted chan struct{}
|
||||
releaseFirstCall chan struct{}
|
||||
firstStartOnce sync.Once
|
||||
secondCallMessages []providers.Message
|
||||
}
|
||||
|
||||
func (p *lateSteeringProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
p.mu.Unlock()
|
||||
|
||||
if call == 1 {
|
||||
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
|
||||
<-p.releaseFirstCall
|
||||
return &providers.LLMResponse{Content: "first response"}, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.secondCallMessages = append([]providers.Message(nil), messages...)
|
||||
p.mu.Unlock()
|
||||
return &providers.LLMResponse{Content: "continued response"}, nil
|
||||
}
|
||||
|
||||
func (p *lateSteeringProvider) GetDefaultModel() string {
|
||||
return "late-steering-mock"
|
||||
}
|
||||
|
||||
type interruptibleTool struct {
|
||||
name string
|
||||
started chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (t *interruptibleTool) Name() string { return t.name }
|
||||
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
|
||||
func (t *interruptibleTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if t.started != nil {
|
||||
t.once.Do(func() { close(t.started) })
|
||||
}
|
||||
<-ctx.Done()
|
||||
return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -568,6 +667,425 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &lateSteeringProvider{
|
||||
firstCallStarted: make(chan struct{}),
|
||||
releaseFirstCall: make(chan struct{}),
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "first message",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
late := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "late append",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstCallStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first provider call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
|
||||
t.Fatalf("publish late inbound: %v", err)
|
||||
}
|
||||
|
||||
close(provider.releaseFirstCall)
|
||||
|
||||
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer subCancel()
|
||||
|
||||
out1, ok := msgBus.SubscribeOutbound(subCtx)
|
||||
if !ok {
|
||||
t.Fatal("expected first outbound response")
|
||||
}
|
||||
if out1.Content != "first response" {
|
||||
t.Fatalf("expected first response, got %q", out1.Content)
|
||||
}
|
||||
|
||||
out2, ok := msgBus.SubscribeOutbound(subCtx)
|
||||
if !ok {
|
||||
t.Fatal("expected continued outbound response")
|
||||
}
|
||||
if out2.Content != "continued response" {
|
||||
t.Fatalf("expected continued response, got %q", out2.Content)
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
|
||||
foundLateMessage := false
|
||||
for _, msg := range secondMessages {
|
||||
if msg.Role == "user" && msg.Content == "late append" {
|
||||
foundLateMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundLateMessage {
|
||||
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool1ExecCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
||||
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
||||
|
||||
provider := &gracefulCaptureProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "tool_one",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_one",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "tool_two",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_two",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "graceful summary",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do something",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tool1ExecCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tool_one to start")
|
||||
}
|
||||
|
||||
active := al.GetActiveTurn()
|
||||
if active == nil {
|
||||
t.Fatal("expected active turn while tool is running")
|
||||
}
|
||||
if active.SessionKey != sessionKey {
|
||||
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
|
||||
}
|
||||
if active.Channel != "test" || active.ChatID != "chat1" {
|
||||
t.Fatalf("unexpected active turn target: %#v", active)
|
||||
}
|
||||
|
||||
if err := al.InterruptGraceful("wrap it up"); err != nil {
|
||||
t.Fatalf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "graceful summary" {
|
||||
t.Fatalf("expected graceful summary, got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for graceful interrupt result")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
t.Fatalf("expected no active turn after completion, got %#v", active)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
|
||||
terminalToolsCount := provider.terminalToolsCount
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
if terminalToolsCount != 0 {
|
||||
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
|
||||
}
|
||||
|
||||
foundHint := false
|
||||
foundSkipped := false
|
||||
for _, msg := range terminalMessages {
|
||||
if msg.Role == "user" && msg.Content == "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\nInterrupt hint: wrap it up" {
|
||||
foundHint = true
|
||||
}
|
||||
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
|
||||
foundSkipped = true
|
||||
}
|
||||
}
|
||||
if !foundHint {
|
||||
t.Fatal("expected graceful terminal call to include interrupt hint message")
|
||||
}
|
||||
if !foundSkipped {
|
||||
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindGraceful {
|
||||
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusCompleted {
|
||||
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "cancel_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "cancel_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "should not happen",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
originalHistory := []providers.Message{
|
||||
{Role: "user", Content: "before"},
|
||||
{Role: "assistant", Content: "after"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do work",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for interruptible tool to start")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active == nil {
|
||||
t.Fatal("expected active turn before hard abort")
|
||||
}
|
||||
|
||||
if err := al.InterruptHard(); err != nil {
|
||||
t.Fatalf("InterruptHard failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "" {
|
||||
t.Fatalf("expected no final response after hard abort, got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for hard abort result")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
t.Fatalf("expected no active turn after hard abort, got %#v", active)
|
||||
}
|
||||
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if !reflect.DeepEqual(finalHistory, originalHistory) {
|
||||
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindHard {
|
||||
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusAborted {
|
||||
t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
|
||||
@@ -0,0 +1,309 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type TurnPhase string
|
||||
|
||||
const (
|
||||
TurnPhaseSetup TurnPhase = "setup"
|
||||
TurnPhaseRunning TurnPhase = "running"
|
||||
TurnPhaseTools TurnPhase = "tools"
|
||||
TurnPhaseFinalizing TurnPhase = "finalizing"
|
||||
TurnPhaseCompleted TurnPhase = "completed"
|
||||
TurnPhaseAborted TurnPhase = "aborted"
|
||||
)
|
||||
|
||||
type ActiveTurnInfo struct {
|
||||
TurnID string
|
||||
AgentID string
|
||||
SessionKey string
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
Phase TurnPhase
|
||||
Iteration int
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
type turnResult struct {
|
||||
finalContent string
|
||||
status TurnEndStatus
|
||||
followUps []bus.InboundMessage
|
||||
}
|
||||
|
||||
type turnState struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
agent *AgentInstance
|
||||
opts processOptions
|
||||
scope turnEventScope
|
||||
|
||||
turnID string
|
||||
agentID string
|
||||
sessionKey string
|
||||
|
||||
channel string
|
||||
chatID string
|
||||
userMessage string
|
||||
media []string
|
||||
|
||||
phase TurnPhase
|
||||
iteration int
|
||||
startedAt time.Time
|
||||
finalContent string
|
||||
|
||||
pendingSteering []providers.Message
|
||||
followUps []bus.InboundMessage
|
||||
|
||||
gracefulInterrupt bool
|
||||
gracefulInterruptHint string
|
||||
gracefulTerminalUsed bool
|
||||
hardAbort bool
|
||||
providerCancel context.CancelFunc
|
||||
turnCancel context.CancelFunc
|
||||
|
||||
restorePointHistory []providers.Message
|
||||
restorePointSummary string
|
||||
persistedMessages []providers.Message
|
||||
}
|
||||
|
||||
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
||||
return &turnState{
|
||||
agent: agent,
|
||||
opts: opts,
|
||||
scope: scope,
|
||||
turnID: scope.turnID,
|
||||
agentID: agent.ID,
|
||||
sessionKey: opts.SessionKey,
|
||||
channel: opts.Channel,
|
||||
chatID: opts.ChatID,
|
||||
userMessage: opts.UserMessage,
|
||||
media: append([]string(nil), opts.Media...),
|
||||
phase: TurnPhaseSetup,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
|
||||
al.activeTurnMu.Lock()
|
||||
defer al.activeTurnMu.Unlock()
|
||||
al.activeTurn = ts
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
|
||||
al.activeTurnMu.Lock()
|
||||
defer al.activeTurnMu.Unlock()
|
||||
if al.activeTurn == ts {
|
||||
al.activeTurn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) getActiveTurnState() *turnState {
|
||||
al.activeTurnMu.RLock()
|
||||
defer al.activeTurnMu.RUnlock()
|
||||
return al.activeTurn
|
||||
}
|
||||
|
||||
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
|
||||
ts := al.getActiveTurnState()
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
info := ts.snapshot()
|
||||
return &info
|
||||
}
|
||||
|
||||
func (ts *turnState) snapshot() ActiveTurnInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
return ActiveTurnInfo{
|
||||
TurnID: ts.turnID,
|
||||
AgentID: ts.agentID,
|
||||
SessionKey: ts.sessionKey,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
UserMessage: ts.userMessage,
|
||||
Phase: ts.phase,
|
||||
Iteration: ts.iteration,
|
||||
StartedAt: ts.startedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) setPhase(phase TurnPhase) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.phase = phase
|
||||
}
|
||||
|
||||
func (ts *turnState) setIteration(iteration int) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.iteration = iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) currentIteration() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) setFinalContent(content string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.finalContent = content
|
||||
}
|
||||
|
||||
func (ts *turnState) finalContentLen() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return len(ts.finalContent)
|
||||
}
|
||||
|
||||
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.turnCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = nil
|
||||
}
|
||||
|
||||
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.hardAbort {
|
||||
return false
|
||||
}
|
||||
ts.gracefulInterrupt = true
|
||||
ts.gracefulInterruptHint = hint
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
|
||||
}
|
||||
|
||||
func (ts *turnState) markGracefulTerminalUsed() {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.gracefulTerminalUsed = true
|
||||
}
|
||||
|
||||
func (ts *turnState) requestHardAbort() bool {
|
||||
ts.mu.Lock()
|
||||
if ts.hardAbort {
|
||||
ts.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
ts.hardAbort = true
|
||||
turnCancel := ts.turnCancel
|
||||
providerCancel := ts.providerCancel
|
||||
ts.mu.Unlock()
|
||||
|
||||
if providerCancel != nil {
|
||||
providerCancel()
|
||||
}
|
||||
if turnCancel != nil {
|
||||
turnCancel()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) hardAbortRequested() bool {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.hardAbort
|
||||
}
|
||||
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = summary
|
||||
}
|
||||
|
||||
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.persistedMessages = append(ts.persistedMessages, msg)
|
||||
}
|
||||
|
||||
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
history := agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
||||
|
||||
ts.mu.RLock()
|
||||
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
||||
ts.mu.RUnlock()
|
||||
|
||||
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
||||
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
}
|
||||
|
||||
ts.captureRestorePoint(history, summary)
|
||||
}
|
||||
|
||||
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
ts.mu.RLock()
|
||||
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
||||
summary := ts.restorePointSummary
|
||||
ts.mu.RUnlock()
|
||||
|
||||
agent.Sessions.SetHistory(ts.sessionKey, history)
|
||||
agent.Sessions.SetSummary(ts.sessionKey, summary)
|
||||
return agent.Sessions.Save(ts.sessionKey)
|
||||
}
|
||||
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
_, hint := ts.gracefulInterruptRequested()
|
||||
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
||||
if hint != "" {
|
||||
content += "\n\nInterrupt hint: " + hint
|
||||
}
|
||||
return providers.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user