mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1827 from alexhoshina/refactor/agent-loop
Refactor/agent loop
This commit is contained in:
@@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
||||
if interruptPayload.Role != "user" {
|
||||
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindSteering {
|
||||
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
|
||||
}
|
||||
if interruptPayload.ContentLen != len("change course") {
|
||||
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
|
||||
}
|
||||
|
||||
+13
-1
@@ -105,6 +105,8 @@ const (
|
||||
TurnEndStatusCompleted TurnEndStatus = "completed"
|
||||
// TurnEndStatusError indicates the turn ended because of an error.
|
||||
TurnEndStatusError TurnEndStatus = "error"
|
||||
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
|
||||
TurnEndStatusAborted TurnEndStatus = "aborted"
|
||||
)
|
||||
|
||||
// TurnStartPayload describes the start of a turn.
|
||||
@@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct {
|
||||
ContentLen int
|
||||
}
|
||||
|
||||
// InterruptReceivedPayload describes a queued soft interrupt.
|
||||
type InterruptKind string
|
||||
|
||||
const (
|
||||
InterruptKindSteering InterruptKind = "steering"
|
||||
InterruptKindGraceful InterruptKind = "graceful"
|
||||
InterruptKindHard InterruptKind = "hard_abort"
|
||||
)
|
||||
|
||||
// InterruptReceivedPayload describes accepted turn-control input.
|
||||
type InterruptReceivedPayload struct {
|
||||
Kind InterruptKind
|
||||
Role string
|
||||
ContentLen int
|
||||
QueueDepth int
|
||||
HintLen int
|
||||
}
|
||||
|
||||
// SubTurnSpawnPayload describes the creation of a child turn.
|
||||
|
||||
+490
-328
File diff suppressed because it is too large
Load Diff
+62
-8
@@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
|
||||
"content_len": len(msg.Content),
|
||||
"queue_len": al.steering.len(),
|
||||
})
|
||||
agentID := ""
|
||||
if registry := al.GetRegistry(); registry != nil {
|
||||
|
||||
meta := EventMeta{
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
}
|
||||
if ts := al.getActiveTurnState(); ts != nil {
|
||||
meta = ts.eventMeta("Steer", "turn.interrupt.received")
|
||||
} else if registry := al.GetRegistry(); registry != nil {
|
||||
if agent := registry.GetDefaultAgent(); agent != nil {
|
||||
agentID = agent.ID
|
||||
meta.AgentID = agent.ID
|
||||
}
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
EventMeta{
|
||||
AgentID: agentID,
|
||||
Source: "Steer",
|
||||
TracePath: "turn.interrupt.received",
|
||||
},
|
||||
meta,
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindSteering,
|
||||
Role: msg.Role,
|
||||
ContentLen: len(msg.Content),
|
||||
QueueDepth: al.steering.len(),
|
||||
@@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
|
||||
//
|
||||
// If no steering messages are pending, it returns an empty string.
|
||||
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
|
||||
steeringMsgs := al.dequeueSteeringMessages()
|
||||
if len(steeringMsgs) == 0 {
|
||||
return "", nil
|
||||
@@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
return "", fmt.Errorf("no default agent available")
|
||||
}
|
||||
|
||||
if tool, ok := agent.Tools.Get("message"); ok {
|
||||
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
|
||||
resetter.ResetSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
// Build a combined user message from the steering messages.
|
||||
var contents []string
|
||||
for _, msg := range steeringMsgs {
|
||||
@@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
SkipInitialSteeringPoll: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptGraceful(hint string) error {
|
||||
ts := al.getActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestGracefulInterrupt(hint) {
|
||||
return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindGraceful,
|
||||
HintLen: len(hint),
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (al *AgentLoop) InterruptHard() error {
|
||||
ts := al.getActiveTurnState()
|
||||
if ts == nil {
|
||||
return fmt.Errorf("no active turn")
|
||||
}
|
||||
if !ts.requestHardAbort() {
|
||||
return fmt.Errorf("turn %s is already aborting", ts.turnID)
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindInterruptReceived,
|
||||
ts.eventMeta("InterruptHard", "turn.interrupt.received"),
|
||||
InterruptReceivedPayload{
|
||||
Kind: InterruptKindHard,
|
||||
},
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string {
|
||||
return "tool-call-mock"
|
||||
}
|
||||
|
||||
type gracefulCaptureProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
toolCalls []providers.ToolCall
|
||||
finalResp string
|
||||
terminalMessages []providers.Message
|
||||
terminalToolsCount int
|
||||
}
|
||||
|
||||
func (p *gracefulCaptureProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.calls++
|
||||
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: p.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
p.terminalMessages = append([]providers.Message(nil), messages...)
|
||||
p.terminalToolsCount = len(tools)
|
||||
return &providers.LLMResponse{
|
||||
Content: p.finalResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *gracefulCaptureProvider) GetDefaultModel() string {
|
||||
return "graceful-capture-mock"
|
||||
}
|
||||
|
||||
type lateSteeringProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
firstCallStarted chan struct{}
|
||||
releaseFirstCall chan struct{}
|
||||
firstStartOnce sync.Once
|
||||
secondCallMessages []providers.Message
|
||||
}
|
||||
|
||||
func (p *lateSteeringProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.calls++
|
||||
call := p.calls
|
||||
p.mu.Unlock()
|
||||
|
||||
if call == 1 {
|
||||
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
|
||||
<-p.releaseFirstCall
|
||||
return &providers.LLMResponse{Content: "first response"}, nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.secondCallMessages = append([]providers.Message(nil), messages...)
|
||||
p.mu.Unlock()
|
||||
return &providers.LLMResponse{Content: "continued response"}, nil
|
||||
}
|
||||
|
||||
func (p *lateSteeringProvider) GetDefaultModel() string {
|
||||
return "late-steering-mock"
|
||||
}
|
||||
|
||||
type interruptibleTool struct {
|
||||
name string
|
||||
started chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (t *interruptibleTool) Name() string { return t.name }
|
||||
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
|
||||
func (t *interruptibleTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
if t.started != nil {
|
||||
t.once.Do(func() { close(t.started) })
|
||||
}
|
||||
<-ctx.Done()
|
||||
return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
|
||||
}
|
||||
|
||||
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
@@ -568,6 +667,427 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &lateSteeringProvider{
|
||||
firstCallStarted: make(chan struct{}),
|
||||
releaseFirstCall: make(chan struct{}),
|
||||
}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
runCtx, cancelRun := context.WithCancel(context.Background())
|
||||
defer cancelRun()
|
||||
|
||||
runErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
runErrCh <- al.Run(runCtx)
|
||||
}()
|
||||
|
||||
first := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "first message",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
late := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "late append",
|
||||
Peer: bus.Peer{
|
||||
Kind: "direct",
|
||||
ID: "user1",
|
||||
},
|
||||
}
|
||||
|
||||
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer pubCancel()
|
||||
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
||||
t.Fatalf("publish first inbound: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-provider.firstCallStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for first provider call to start")
|
||||
}
|
||||
|
||||
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
|
||||
t.Fatalf("publish late inbound: %v", err)
|
||||
}
|
||||
|
||||
close(provider.releaseFirstCall)
|
||||
|
||||
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer subCancel()
|
||||
|
||||
out1, ok := msgBus.SubscribeOutbound(subCtx)
|
||||
if !ok {
|
||||
t.Fatal("expected first outbound response")
|
||||
}
|
||||
if out1.Content != "first response" {
|
||||
t.Fatalf("expected first response, got %q", out1.Content)
|
||||
}
|
||||
|
||||
out2, ok := msgBus.SubscribeOutbound(subCtx)
|
||||
if !ok {
|
||||
t.Fatal("expected continued outbound response")
|
||||
}
|
||||
if out2.Content != "continued response" {
|
||||
t.Fatalf("expected continued response, got %q", out2.Content)
|
||||
}
|
||||
|
||||
cancelRun()
|
||||
select {
|
||||
case err := <-runErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("Run returned error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for Run to stop")
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
calls := provider.calls
|
||||
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
|
||||
foundLateMessage := false
|
||||
for _, msg := range secondMessages {
|
||||
if msg.Role == "user" && msg.Content == "late append" {
|
||||
foundLateMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundLateMessage {
|
||||
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tool1ExecCh := make(chan struct{})
|
||||
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
||||
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
||||
|
||||
provider := &gracefulCaptureProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "tool_one",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_one",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Type: "function",
|
||||
Name: "tool_two",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "tool_two",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "graceful summary",
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
al.RegisterTool(tool1)
|
||||
al.RegisterTool(tool2)
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do something",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tool1ExecCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for tool_one to start")
|
||||
}
|
||||
|
||||
active := al.GetActiveTurn()
|
||||
if active == nil {
|
||||
t.Fatal("expected active turn while tool is running")
|
||||
}
|
||||
if active.SessionKey != sessionKey {
|
||||
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
|
||||
}
|
||||
if active.Channel != "test" || active.ChatID != "chat1" {
|
||||
t.Fatalf("unexpected active turn target: %#v", active)
|
||||
}
|
||||
|
||||
if err := al.InterruptGraceful("wrap it up"); err != nil {
|
||||
t.Fatalf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "graceful summary" {
|
||||
t.Fatalf("expected graceful summary, got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for graceful interrupt result")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
t.Fatalf("expected no active turn after completion, got %#v", active)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
|
||||
terminalToolsCount := provider.terminalToolsCount
|
||||
calls := provider.calls
|
||||
provider.mu.Unlock()
|
||||
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 provider calls, got %d", calls)
|
||||
}
|
||||
if terminalToolsCount != 0 {
|
||||
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
|
||||
}
|
||||
|
||||
foundHint := false
|
||||
foundSkipped := false
|
||||
expectedHint := "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\n" +
|
||||
"Interrupt hint: wrap it up"
|
||||
for _, msg := range terminalMessages {
|
||||
if msg.Role == "user" && msg.Content == expectedHint {
|
||||
foundHint = true
|
||||
}
|
||||
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
|
||||
foundSkipped = true
|
||||
}
|
||||
}
|
||||
if !foundHint {
|
||||
t.Fatal("expected graceful terminal call to include interrupt hint message")
|
||||
}
|
||||
if !foundSkipped {
|
||||
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindGraceful {
|
||||
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusCompleted {
|
||||
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &toolCallProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Name: "cancel_tool",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: "cancel_tool",
|
||||
Arguments: "{}",
|
||||
},
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
finalResp: "should not happen",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
started := make(chan struct{})
|
||||
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
defaultAgent := al.registry.GetDefaultAgent()
|
||||
if defaultAgent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
originalHistory := []providers.Message{
|
||||
{Role: "user", Content: "before"},
|
||||
{Role: "assistant", Content: "after"},
|
||||
}
|
||||
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"do work",
|
||||
sessionKey,
|
||||
"test",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for interruptible tool to start")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active == nil {
|
||||
t.Fatal("expected active turn before hard abort")
|
||||
}
|
||||
|
||||
if err := al.InterruptHard(); err != nil {
|
||||
t.Fatalf("InterruptHard failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
if r.resp != "" {
|
||||
t.Fatalf("expected no final response after hard abort, got %q", r.resp)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for hard abort result")
|
||||
}
|
||||
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
t.Fatalf("expected no active turn after hard abort, got %#v", active)
|
||||
}
|
||||
|
||||
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
||||
if !reflect.DeepEqual(finalHistory, originalHistory) {
|
||||
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
||||
if !ok {
|
||||
t.Fatal("expected interrupt received event")
|
||||
}
|
||||
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
||||
}
|
||||
if interruptPayload.Kind != InterruptKindHard {
|
||||
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
|
||||
}
|
||||
|
||||
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected turn end event")
|
||||
}
|
||||
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
||||
}
|
||||
if turnEndPayload.Status != TurnEndStatusAborted {
|
||||
t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
|
||||
}
|
||||
}
|
||||
|
||||
// capturingMockProvider captures messages sent to Chat for inspection.
|
||||
type capturingMockProvider struct {
|
||||
response string
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type TurnPhase string
|
||||
|
||||
const (
|
||||
TurnPhaseSetup TurnPhase = "setup"
|
||||
TurnPhaseRunning TurnPhase = "running"
|
||||
TurnPhaseTools TurnPhase = "tools"
|
||||
TurnPhaseFinalizing TurnPhase = "finalizing"
|
||||
TurnPhaseCompleted TurnPhase = "completed"
|
||||
TurnPhaseAborted TurnPhase = "aborted"
|
||||
)
|
||||
|
||||
type ActiveTurnInfo struct {
|
||||
TurnID string
|
||||
AgentID string
|
||||
SessionKey string
|
||||
Channel string
|
||||
ChatID string
|
||||
UserMessage string
|
||||
Phase TurnPhase
|
||||
Iteration int
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
type turnResult struct {
|
||||
finalContent string
|
||||
status TurnEndStatus
|
||||
followUps []bus.InboundMessage
|
||||
}
|
||||
|
||||
type turnState struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
agent *AgentInstance
|
||||
opts processOptions
|
||||
scope turnEventScope
|
||||
|
||||
turnID string
|
||||
agentID string
|
||||
sessionKey string
|
||||
|
||||
channel string
|
||||
chatID string
|
||||
userMessage string
|
||||
media []string
|
||||
|
||||
phase TurnPhase
|
||||
iteration int
|
||||
startedAt time.Time
|
||||
finalContent string
|
||||
|
||||
followUps []bus.InboundMessage
|
||||
|
||||
gracefulInterrupt bool
|
||||
gracefulInterruptHint string
|
||||
gracefulTerminalUsed bool
|
||||
hardAbort bool
|
||||
providerCancel context.CancelFunc
|
||||
turnCancel context.CancelFunc
|
||||
|
||||
restorePointHistory []providers.Message
|
||||
restorePointSummary string
|
||||
persistedMessages []providers.Message
|
||||
}
|
||||
|
||||
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
|
||||
return &turnState{
|
||||
agent: agent,
|
||||
opts: opts,
|
||||
scope: scope,
|
||||
turnID: scope.turnID,
|
||||
agentID: agent.ID,
|
||||
sessionKey: opts.SessionKey,
|
||||
channel: opts.Channel,
|
||||
chatID: opts.ChatID,
|
||||
userMessage: opts.UserMessage,
|
||||
media: append([]string(nil), opts.Media...),
|
||||
phase: TurnPhaseSetup,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
|
||||
al.activeTurnMu.Lock()
|
||||
defer al.activeTurnMu.Unlock()
|
||||
al.activeTurn = ts
|
||||
}
|
||||
|
||||
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
|
||||
al.activeTurnMu.Lock()
|
||||
defer al.activeTurnMu.Unlock()
|
||||
if al.activeTurn == ts {
|
||||
al.activeTurn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) getActiveTurnState() *turnState {
|
||||
al.activeTurnMu.RLock()
|
||||
defer al.activeTurnMu.RUnlock()
|
||||
return al.activeTurn
|
||||
}
|
||||
|
||||
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
|
||||
ts := al.getActiveTurnState()
|
||||
if ts == nil {
|
||||
return nil
|
||||
}
|
||||
info := ts.snapshot()
|
||||
return &info
|
||||
}
|
||||
|
||||
func (ts *turnState) snapshot() ActiveTurnInfo {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
|
||||
return ActiveTurnInfo{
|
||||
TurnID: ts.turnID,
|
||||
AgentID: ts.agentID,
|
||||
SessionKey: ts.sessionKey,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
UserMessage: ts.userMessage,
|
||||
Phase: ts.phase,
|
||||
Iteration: ts.iteration,
|
||||
StartedAt: ts.startedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) setPhase(phase TurnPhase) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.phase = phase
|
||||
}
|
||||
|
||||
func (ts *turnState) setIteration(iteration int) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.iteration = iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) currentIteration() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.iteration
|
||||
}
|
||||
|
||||
func (ts *turnState) setFinalContent(content string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.finalContent = content
|
||||
}
|
||||
|
||||
func (ts *turnState) finalContentLen() int {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return len(ts.finalContent)
|
||||
}
|
||||
|
||||
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.turnCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = cancel
|
||||
}
|
||||
|
||||
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.providerCancel = nil
|
||||
}
|
||||
|
||||
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
if ts.hardAbort {
|
||||
return false
|
||||
}
|
||||
ts.gracefulInterrupt = true
|
||||
ts.gracefulInterruptHint = hint
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
|
||||
}
|
||||
|
||||
func (ts *turnState) markGracefulTerminalUsed() {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.gracefulTerminalUsed = true
|
||||
}
|
||||
|
||||
func (ts *turnState) requestHardAbort() bool {
|
||||
ts.mu.Lock()
|
||||
if ts.hardAbort {
|
||||
ts.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
ts.hardAbort = true
|
||||
turnCancel := ts.turnCancel
|
||||
providerCancel := ts.providerCancel
|
||||
ts.mu.Unlock()
|
||||
|
||||
if providerCancel != nil {
|
||||
providerCancel()
|
||||
}
|
||||
if turnCancel != nil {
|
||||
turnCancel()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (ts *turnState) hardAbortRequested() bool {
|
||||
ts.mu.RLock()
|
||||
defer ts.mu.RUnlock()
|
||||
return ts.hardAbort
|
||||
}
|
||||
|
||||
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
|
||||
snap := ts.snapshot()
|
||||
return EventMeta{
|
||||
AgentID: snap.AgentID,
|
||||
TurnID: snap.TurnID,
|
||||
SessionKey: snap.SessionKey,
|
||||
Iteration: snap.Iteration,
|
||||
Source: source,
|
||||
TracePath: tracePath,
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.restorePointHistory = append([]providers.Message(nil), history...)
|
||||
ts.restorePointSummary = summary
|
||||
}
|
||||
|
||||
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
|
||||
ts.mu.Lock()
|
||||
defer ts.mu.Unlock()
|
||||
ts.persistedMessages = append(ts.persistedMessages, msg)
|
||||
}
|
||||
|
||||
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
|
||||
history := agent.Sessions.GetHistory(ts.sessionKey)
|
||||
summary := agent.Sessions.GetSummary(ts.sessionKey)
|
||||
|
||||
ts.mu.RLock()
|
||||
persisted := append([]providers.Message(nil), ts.persistedMessages...)
|
||||
ts.mu.RUnlock()
|
||||
|
||||
if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
|
||||
history = append([]providers.Message(nil), history[:len(history)-matched]...)
|
||||
}
|
||||
|
||||
ts.captureRestorePoint(history, summary)
|
||||
}
|
||||
|
||||
func (ts *turnState) restoreSession(agent *AgentInstance) error {
|
||||
ts.mu.RLock()
|
||||
history := append([]providers.Message(nil), ts.restorePointHistory...)
|
||||
summary := ts.restorePointSummary
|
||||
ts.mu.RUnlock()
|
||||
|
||||
agent.Sessions.SetHistory(ts.sessionKey, history)
|
||||
agent.Sessions.SetSummary(ts.sessionKey, summary)
|
||||
return agent.Sessions.Save(ts.sessionKey)
|
||||
}
|
||||
|
||||
func matchingTurnMessageTail(history, persisted []providers.Message) int {
|
||||
maxMatch := min(len(history), len(persisted))
|
||||
for size := maxMatch; size > 0; size-- {
|
||||
if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
|
||||
return size
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ts *turnState) interruptHintMessage() providers.Message {
|
||||
_, hint := ts.gracefulInterruptRequested()
|
||||
content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
|
||||
if hint != "" {
|
||||
content += "\n\nInterrupt hint: " + hint
|
||||
}
|
||||
return providers.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user