mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
0e075f7300
Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary.
1263 lines
31 KiB
Go
1263 lines
31 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/routing"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
)
|
|
|
|
// --- steeringQueue unit tests ---
|
|
|
|
func TestSteeringQueue_PushDequeue_OneAtATime(t *testing.T) {
|
|
sq := newSteeringQueue(SteeringOneAtATime)
|
|
|
|
sq.push(providers.Message{Role: "user", Content: "msg1"})
|
|
sq.push(providers.Message{Role: "user", Content: "msg2"})
|
|
sq.push(providers.Message{Role: "user", Content: "msg3"})
|
|
|
|
if sq.len() != 3 {
|
|
t.Fatalf("expected 3 messages, got %d", sq.len())
|
|
}
|
|
|
|
msgs := sq.dequeue()
|
|
if len(msgs) != 1 {
|
|
t.Fatalf("expected 1 message in one-at-a-time mode, got %d", len(msgs))
|
|
}
|
|
if msgs[0].Content != "msg1" {
|
|
t.Fatalf("expected 'msg1', got %q", msgs[0].Content)
|
|
}
|
|
if sq.len() != 2 {
|
|
t.Fatalf("expected 2 remaining, got %d", sq.len())
|
|
}
|
|
|
|
msgs = sq.dequeue()
|
|
if len(msgs) != 1 || msgs[0].Content != "msg2" {
|
|
t.Fatalf("expected 'msg2', got %v", msgs)
|
|
}
|
|
|
|
msgs = sq.dequeue()
|
|
if len(msgs) != 1 || msgs[0].Content != "msg3" {
|
|
t.Fatalf("expected 'msg3', got %v", msgs)
|
|
}
|
|
|
|
msgs = sq.dequeue()
|
|
if msgs != nil {
|
|
t.Fatalf("expected nil from empty queue, got %v", msgs)
|
|
}
|
|
}
|
|
|
|
func TestSteeringQueue_PushDequeue_All(t *testing.T) {
|
|
sq := newSteeringQueue(SteeringAll)
|
|
|
|
sq.push(providers.Message{Role: "user", Content: "msg1"})
|
|
sq.push(providers.Message{Role: "user", Content: "msg2"})
|
|
sq.push(providers.Message{Role: "user", Content: "msg3"})
|
|
|
|
msgs := sq.dequeue()
|
|
if len(msgs) != 3 {
|
|
t.Fatalf("expected 3 messages in all mode, got %d", len(msgs))
|
|
}
|
|
if msgs[0].Content != "msg1" || msgs[1].Content != "msg2" || msgs[2].Content != "msg3" {
|
|
t.Fatalf("unexpected messages: %v", msgs)
|
|
}
|
|
|
|
if sq.len() != 0 {
|
|
t.Fatalf("expected 0 remaining, got %d", sq.len())
|
|
}
|
|
|
|
msgs = sq.dequeue()
|
|
if msgs != nil {
|
|
t.Fatalf("expected nil from empty queue, got %v", msgs)
|
|
}
|
|
}
|
|
|
|
func TestSteeringQueue_EmptyDequeue(t *testing.T) {
|
|
sq := newSteeringQueue(SteeringOneAtATime)
|
|
if msgs := sq.dequeue(); msgs != nil {
|
|
t.Fatalf("expected nil, got %v", msgs)
|
|
}
|
|
}
|
|
|
|
func TestSteeringQueue_SetMode(t *testing.T) {
|
|
sq := newSteeringQueue(SteeringOneAtATime)
|
|
if sq.getMode() != SteeringOneAtATime {
|
|
t.Fatalf("expected one-at-a-time, got %v", sq.getMode())
|
|
}
|
|
|
|
sq.setMode(SteeringAll)
|
|
if sq.getMode() != SteeringAll {
|
|
t.Fatalf("expected all, got %v", sq.getMode())
|
|
}
|
|
|
|
// Push two messages and verify all-mode drains them
|
|
sq.push(providers.Message{Role: "user", Content: "a"})
|
|
sq.push(providers.Message{Role: "user", Content: "b"})
|
|
|
|
msgs := sq.dequeue()
|
|
if len(msgs) != 2 {
|
|
t.Fatalf("expected 2 messages after mode switch, got %d", len(msgs))
|
|
}
|
|
}
|
|
|
|
func TestSteeringQueue_ConcurrentAccess(t *testing.T) {
|
|
sq := newSteeringQueue(SteeringOneAtATime)
|
|
|
|
var wg sync.WaitGroup
|
|
const n = MaxQueueSize
|
|
|
|
// Push from multiple goroutines
|
|
for i := 0; i < n; i++ {
|
|
wg.Add(1)
|
|
go func(i int) {
|
|
defer wg.Done()
|
|
sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
if sq.len() != n {
|
|
t.Fatalf("expected %d messages, got %d", n, sq.len())
|
|
}
|
|
|
|
// Drain from multiple goroutines
|
|
var drained int
|
|
var mu sync.Mutex
|
|
for i := 0; i < n; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
if msgs := sq.dequeue(); len(msgs) > 0 {
|
|
mu.Lock()
|
|
drained += len(msgs)
|
|
mu.Unlock()
|
|
}
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
|
|
if drained != n {
|
|
t.Fatalf("expected to drain %d messages, got %d", n, drained)
|
|
}
|
|
}
|
|
|
|
func TestSteeringQueue_Overflow(t *testing.T) {
|
|
sq := newSteeringQueue(SteeringOneAtATime)
|
|
|
|
// Fill the queue up to its maximum capacity
|
|
for i := 0; i < MaxQueueSize; i++ {
|
|
err := sq.push(providers.Message{Role: "user", Content: fmt.Sprintf("msg%d", i)})
|
|
if err != nil {
|
|
t.Fatalf("unexpected error pushing message %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
// Sanity check: ensure the queue is actually full
|
|
if sq.len() != MaxQueueSize {
|
|
t.Fatalf("expected queue length %d, got %d", MaxQueueSize, sq.len())
|
|
}
|
|
|
|
// Attempt to push one more message, which MUST fail
|
|
err := sq.push(providers.Message{Role: "user", Content: "overflow_msg"})
|
|
|
|
// Assert the error happened and is the exact one we expect
|
|
if err == nil {
|
|
t.Fatal("expected an error when pushing to a full queue, but got nil")
|
|
}
|
|
|
|
expectedErr := "steering queue is full"
|
|
if err.Error() != expectedErr {
|
|
t.Errorf("expected error message %q, got %q", expectedErr, err.Error())
|
|
}
|
|
}
|
|
|
|
func TestParseSteeringMode(t *testing.T) {
|
|
tests := []struct {
|
|
input string
|
|
expected SteeringMode
|
|
}{
|
|
{"", SteeringOneAtATime},
|
|
{"one-at-a-time", SteeringOneAtATime},
|
|
{"all", SteeringAll},
|
|
{"unknown", SteeringOneAtATime},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
if got := parseSteeringMode(tt.input); got != tt.expected {
|
|
t.Fatalf("parseSteeringMode(%q) = %v, want %v", tt.input, got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// --- AgentLoop steering integration tests ---
|
|
|
|
func TestAgentLoop_Steer_Enqueues(t *testing.T) {
|
|
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
if cfg == nil {
|
|
t.Fatal("expected config to be initialized")
|
|
}
|
|
if msgBus == nil {
|
|
t.Fatal("expected message bus to be initialized")
|
|
}
|
|
if provider == nil {
|
|
t.Fatal("expected provider to be initialized")
|
|
}
|
|
|
|
al.Steer(providers.Message{Role: "user", Content: "interrupt me"})
|
|
|
|
if al.steering.len() != 1 {
|
|
t.Fatalf("expected 1 steering message, got %d", al.steering.len())
|
|
}
|
|
|
|
msgs := al.dequeueSteeringMessages()
|
|
if len(msgs) != 1 || msgs[0].Content != "interrupt me" {
|
|
t.Fatalf("unexpected dequeued message: %v", msgs)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_SteeringMode_GetSet(t *testing.T) {
|
|
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
if cfg == nil {
|
|
t.Fatal("expected config to be initialized")
|
|
}
|
|
if msgBus == nil {
|
|
t.Fatal("expected message bus to be initialized")
|
|
}
|
|
if provider == nil {
|
|
t.Fatal("expected provider to be initialized")
|
|
}
|
|
|
|
if al.SteeringMode() != SteeringOneAtATime {
|
|
t.Fatalf("expected default mode one-at-a-time, got %v", al.SteeringMode())
|
|
}
|
|
|
|
al.SetSteeringMode(SteeringAll)
|
|
if al.SteeringMode() != SteeringAll {
|
|
t.Fatalf("expected all mode, got %v", al.SteeringMode())
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_SteeringMode_ConfiguredFromConfig(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
SteeringMode: "all",
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &mockProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
if al.SteeringMode() != SteeringAll {
|
|
t.Fatalf("expected 'all' mode from config, got %v", al.SteeringMode())
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_Continue_NoMessages(t *testing.T) {
|
|
al, cfg, msgBus, provider, cleanup := newTestAgentLoop(t)
|
|
defer cleanup()
|
|
|
|
if cfg == nil {
|
|
t.Fatal("expected config to be initialized")
|
|
}
|
|
if msgBus == nil {
|
|
t.Fatal("expected message bus to be initialized")
|
|
}
|
|
if provider == nil {
|
|
t.Fatal("expected provider to be initialized")
|
|
}
|
|
|
|
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp != "" {
|
|
t.Fatalf("expected empty response for no steering messages, got %q", resp)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_Continue_WithMessages(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &simpleMockProvider{response: "continued response"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
al.Steer(providers.Message{Role: "user", Content: "new direction"})
|
|
|
|
resp, err := al.Continue(context.Background(), "test-session", "test", "chat1")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if resp != "continued response" {
|
|
t.Fatalf("expected 'continued response', got %q", resp)
|
|
}
|
|
}
|
|
|
|
// slowTool simulates a tool that takes some time to execute.
|
|
type slowTool struct {
|
|
name string
|
|
duration time.Duration
|
|
execCh chan struct{} // closed when Execute starts
|
|
}
|
|
|
|
func (t *slowTool) Name() string { return t.name }
|
|
func (t *slowTool) Description() string { return "slow tool for testing" }
|
|
func (t *slowTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (t *slowTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
if t.execCh != nil {
|
|
close(t.execCh)
|
|
}
|
|
time.Sleep(t.duration)
|
|
return tools.SilentResult(fmt.Sprintf("executed %s", t.name))
|
|
}
|
|
|
|
// toolCallProvider returns an LLM response with tool calls on the first call,
|
|
// then a direct response on subsequent calls.
|
|
type toolCallProvider struct {
|
|
mu sync.Mutex
|
|
calls int
|
|
toolCalls []providers.ToolCall
|
|
finalResp string
|
|
}
|
|
|
|
func (m *toolCallProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.calls++
|
|
|
|
if m.calls == 1 && len(m.toolCalls) > 0 {
|
|
return &providers.LLMResponse{
|
|
Content: "",
|
|
ToolCalls: m.toolCalls,
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: m.finalResp,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *toolCallProvider) GetDefaultModel() string {
|
|
return "tool-call-mock"
|
|
}
|
|
|
|
type gracefulCaptureProvider struct {
|
|
mu sync.Mutex
|
|
calls int
|
|
toolCalls []providers.ToolCall
|
|
finalResp string
|
|
terminalMessages []providers.Message
|
|
terminalToolsCount int
|
|
}
|
|
|
|
func (p *gracefulCaptureProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
p.calls++
|
|
|
|
if p.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: p.toolCalls,
|
|
}, nil
|
|
}
|
|
|
|
p.terminalMessages = append([]providers.Message(nil), messages...)
|
|
p.terminalToolsCount = len(tools)
|
|
return &providers.LLMResponse{
|
|
Content: p.finalResp,
|
|
}, nil
|
|
}
|
|
|
|
func (p *gracefulCaptureProvider) GetDefaultModel() string {
|
|
return "graceful-capture-mock"
|
|
}
|
|
|
|
type lateSteeringProvider struct {
|
|
mu sync.Mutex
|
|
calls int
|
|
firstCallStarted chan struct{}
|
|
releaseFirstCall chan struct{}
|
|
firstStartOnce sync.Once
|
|
secondCallMessages []providers.Message
|
|
}
|
|
|
|
func (p *lateSteeringProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.calls++
|
|
call := p.calls
|
|
p.mu.Unlock()
|
|
|
|
if call == 1 {
|
|
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
|
|
<-p.releaseFirstCall
|
|
return &providers.LLMResponse{Content: "first response"}, nil
|
|
}
|
|
|
|
p.mu.Lock()
|
|
p.secondCallMessages = append([]providers.Message(nil), messages...)
|
|
p.mu.Unlock()
|
|
return &providers.LLMResponse{Content: "continued response"}, nil
|
|
}
|
|
|
|
func (p *lateSteeringProvider) GetDefaultModel() string {
|
|
return "late-steering-mock"
|
|
}
|
|
|
|
type interruptibleTool struct {
|
|
name string
|
|
started chan struct{}
|
|
once sync.Once
|
|
}
|
|
|
|
func (t *interruptibleTool) Name() string { return t.name }
|
|
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
|
|
func (t *interruptibleTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
if t.started != nil {
|
|
t.once.Do(func() { close(t.started) })
|
|
}
|
|
<-ctx.Done()
|
|
return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
|
|
}
|
|
|
|
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
tool1ExecCh := make(chan struct{})
|
|
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
|
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
|
|
|
provider := &toolCallProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Type: "function",
|
|
Name: "tool_one",
|
|
Function: &providers.FunctionCall{
|
|
Name: "tool_one",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
{
|
|
ID: "call_2",
|
|
Type: "function",
|
|
Name: "tool_two",
|
|
Function: &providers.FunctionCall{
|
|
Name: "tool_two",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
},
|
|
finalResp: "steered response",
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
al.RegisterTool(tool1)
|
|
al.RegisterTool(tool2)
|
|
|
|
// Start processing in a goroutine
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"do something",
|
|
"test-session",
|
|
"test",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp, err}
|
|
}()
|
|
|
|
// Wait for tool_one to start executing, then enqueue a steering message
|
|
select {
|
|
case <-tool1ExecCh:
|
|
// tool_one has started executing
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for tool_one to start")
|
|
}
|
|
|
|
al.Steer(providers.Message{Role: "user", Content: "change course"})
|
|
|
|
// Get the result
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
if r.resp != "steered response" {
|
|
t.Fatalf("expected 'steered response', got %q", r.resp)
|
|
}
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for agent loop to complete")
|
|
}
|
|
|
|
// The provider should have been called twice:
|
|
// 1. first call returned tool calls
|
|
// 2. second call (after steering) returned the final response
|
|
provider.mu.Lock()
|
|
calls := provider.calls
|
|
provider.mu.Unlock()
|
|
if calls != 2 {
|
|
t.Fatalf("expected 2 provider calls, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Provider that captures messages it receives
|
|
var capturedMessages []providers.Message
|
|
var capMu sync.Mutex
|
|
provider := &capturingMockProvider{
|
|
response: "ack",
|
|
captureFn: func(msgs []providers.Message) {
|
|
capMu.Lock()
|
|
capturedMessages = make([]providers.Message, len(msgs))
|
|
copy(capturedMessages, msgs)
|
|
capMu.Unlock()
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
// Enqueue a steering message before processing starts
|
|
al.Steer(providers.Message{Role: "user", Content: "pre-enqueued steering"})
|
|
|
|
// Process a normal message - the initial steering poll should inject the steering message
|
|
_, err = al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"initial message",
|
|
"test-session",
|
|
"test",
|
|
"chat1",
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
// The steering message should have been injected into the conversation
|
|
capMu.Lock()
|
|
msgs := capturedMessages
|
|
capMu.Unlock()
|
|
|
|
// Look for the steering message in the captured messages
|
|
found := false
|
|
for _, m := range msgs {
|
|
if m.Content == "pre-enqueued steering" {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Fatal("expected steering message to be injected into conversation context")
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &lateSteeringProvider{
|
|
firstCallStarted: make(chan struct{}),
|
|
releaseFirstCall: make(chan struct{}),
|
|
}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
|
|
runCtx, cancelRun := context.WithCancel(context.Background())
|
|
defer cancelRun()
|
|
|
|
runErrCh := make(chan error, 1)
|
|
go func() {
|
|
runErrCh <- al.Run(runCtx)
|
|
}()
|
|
|
|
first := bus.InboundMessage{
|
|
Channel: "test",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "first message",
|
|
Peer: bus.Peer{
|
|
Kind: "direct",
|
|
ID: "user1",
|
|
},
|
|
}
|
|
late := bus.InboundMessage{
|
|
Channel: "test",
|
|
SenderID: "user1",
|
|
ChatID: "chat1",
|
|
Content: "late append",
|
|
Peer: bus.Peer{
|
|
Kind: "direct",
|
|
ID: "user1",
|
|
},
|
|
}
|
|
|
|
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer pubCancel()
|
|
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
|
|
t.Fatalf("publish first inbound: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-provider.firstCallStarted:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for first provider call to start")
|
|
}
|
|
|
|
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
|
|
t.Fatalf("publish late inbound: %v", err)
|
|
}
|
|
|
|
close(provider.releaseFirstCall)
|
|
|
|
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer subCancel()
|
|
|
|
out1, ok := msgBus.SubscribeOutbound(subCtx)
|
|
if !ok {
|
|
t.Fatal("expected first outbound response")
|
|
}
|
|
if out1.Content != "first response" {
|
|
t.Fatalf("expected first response, got %q", out1.Content)
|
|
}
|
|
|
|
out2, ok := msgBus.SubscribeOutbound(subCtx)
|
|
if !ok {
|
|
t.Fatal("expected continued outbound response")
|
|
}
|
|
if out2.Content != "continued response" {
|
|
t.Fatalf("expected continued response, got %q", out2.Content)
|
|
}
|
|
|
|
cancelRun()
|
|
select {
|
|
case err := <-runErrCh:
|
|
if err != nil {
|
|
t.Fatalf("Run returned error: %v", err)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for Run to stop")
|
|
}
|
|
|
|
provider.mu.Lock()
|
|
calls := provider.calls
|
|
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
|
|
provider.mu.Unlock()
|
|
|
|
if calls != 2 {
|
|
t.Fatalf("expected 2 provider calls, got %d", calls)
|
|
}
|
|
|
|
foundLateMessage := false
|
|
for _, msg := range secondMessages {
|
|
if msg.Role == "user" && msg.Content == "late append" {
|
|
foundLateMessage = true
|
|
break
|
|
}
|
|
}
|
|
if !foundLateMessage {
|
|
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
tool1ExecCh := make(chan struct{})
|
|
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
|
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
|
|
|
provider := &gracefulCaptureProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Type: "function",
|
|
Name: "tool_one",
|
|
Function: &providers.FunctionCall{
|
|
Name: "tool_one",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
{
|
|
ID: "call_2",
|
|
Type: "function",
|
|
Name: "tool_two",
|
|
Function: &providers.FunctionCall{
|
|
Name: "tool_two",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
},
|
|
finalResp: "graceful summary",
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
al.RegisterTool(tool1)
|
|
al.RegisterTool(tool2)
|
|
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
|
|
|
sub := al.SubscribeEvents(32)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"do something",
|
|
sessionKey,
|
|
"test",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp: resp, err: err}
|
|
}()
|
|
|
|
select {
|
|
case <-tool1ExecCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for tool_one to start")
|
|
}
|
|
|
|
active := al.GetActiveTurn()
|
|
if active == nil {
|
|
t.Fatal("expected active turn while tool is running")
|
|
}
|
|
if active.SessionKey != sessionKey {
|
|
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
|
|
}
|
|
if active.Channel != "test" || active.ChatID != "chat1" {
|
|
t.Fatalf("unexpected active turn target: %#v", active)
|
|
}
|
|
|
|
if err := al.InterruptGraceful("wrap it up"); err != nil {
|
|
t.Fatalf("InterruptGraceful failed: %v", err)
|
|
}
|
|
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
if r.resp != "graceful summary" {
|
|
t.Fatalf("expected graceful summary, got %q", r.resp)
|
|
}
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for graceful interrupt result")
|
|
}
|
|
|
|
if active := al.GetActiveTurn(); active != nil {
|
|
t.Fatalf("expected no active turn after completion, got %#v", active)
|
|
}
|
|
|
|
provider.mu.Lock()
|
|
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
|
|
terminalToolsCount := provider.terminalToolsCount
|
|
calls := provider.calls
|
|
provider.mu.Unlock()
|
|
|
|
if calls != 2 {
|
|
t.Fatalf("expected 2 provider calls, got %d", calls)
|
|
}
|
|
if terminalToolsCount != 0 {
|
|
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
|
|
}
|
|
|
|
foundHint := false
|
|
foundSkipped := false
|
|
for _, msg := range terminalMessages {
|
|
if msg.Role == "user" && msg.Content == "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\nInterrupt hint: wrap it up" {
|
|
foundHint = true
|
|
}
|
|
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
|
|
foundSkipped = true
|
|
}
|
|
}
|
|
if !foundHint {
|
|
t.Fatal("expected graceful terminal call to include interrupt hint message")
|
|
}
|
|
if !foundSkipped {
|
|
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
|
if !ok {
|
|
t.Fatal("expected interrupt received event")
|
|
}
|
|
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
|
}
|
|
if interruptPayload.Kind != InterruptKindGraceful {
|
|
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
|
|
}
|
|
|
|
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
|
if !ok {
|
|
t.Fatal("expected turn end event")
|
|
}
|
|
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
|
}
|
|
if turnEndPayload.Status != TurnEndStatusCompleted {
|
|
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &toolCallProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Type: "function",
|
|
Name: "cancel_tool",
|
|
Function: &providers.FunctionCall{
|
|
Name: "cancel_tool",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
},
|
|
finalResp: "should not happen",
|
|
}
|
|
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
started := make(chan struct{})
|
|
al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
|
|
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
|
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
originalHistory := []providers.Message{
|
|
{Role: "user", Content: "before"},
|
|
{Role: "assistant", Content: "after"},
|
|
}
|
|
defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"do work",
|
|
sessionKey,
|
|
"test",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp: resp, err: err}
|
|
}()
|
|
|
|
select {
|
|
case <-started:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for interruptible tool to start")
|
|
}
|
|
|
|
if active := al.GetActiveTurn(); active == nil {
|
|
t.Fatal("expected active turn before hard abort")
|
|
}
|
|
|
|
if err := al.InterruptHard(); err != nil {
|
|
t.Fatalf("InterruptHard failed: %v", err)
|
|
}
|
|
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
if r.resp != "" {
|
|
t.Fatalf("expected no final response after hard abort, got %q", r.resp)
|
|
}
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for hard abort result")
|
|
}
|
|
|
|
if active := al.GetActiveTurn(); active != nil {
|
|
t.Fatalf("expected no active turn after hard abort, got %#v", active)
|
|
}
|
|
|
|
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
|
|
if !reflect.DeepEqual(finalHistory, originalHistory) {
|
|
t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
|
if !ok {
|
|
t.Fatal("expected interrupt received event")
|
|
}
|
|
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
|
}
|
|
if interruptPayload.Kind != InterruptKindHard {
|
|
t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
|
|
}
|
|
|
|
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
|
|
if !ok {
|
|
t.Fatal("expected turn end event")
|
|
}
|
|
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
|
|
}
|
|
if turnEndPayload.Status != TurnEndStatusAborted {
|
|
t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
|
|
}
|
|
}
|
|
|
|
// capturingMockProvider captures messages sent to Chat for inspection.
|
|
type capturingMockProvider struct {
|
|
response string
|
|
calls int
|
|
captureFn func([]providers.Message)
|
|
}
|
|
|
|
func (m *capturingMockProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.captureFn != nil {
|
|
m.captureFn(messages)
|
|
}
|
|
return &providers.LLMResponse{
|
|
Content: m.response,
|
|
ToolCalls: []providers.ToolCall{},
|
|
}, nil
|
|
}
|
|
|
|
func (m *capturingMockProvider) GetDefaultModel() string {
|
|
return "capturing-mock"
|
|
}
|
|
|
|
func TestAgentLoop_Steering_SkippedToolsHaveErrorResults(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
execCh := make(chan struct{})
|
|
tool1 := &slowTool{name: "slow_tool", duration: 50 * time.Millisecond, execCh: execCh}
|
|
tool2 := &slowTool{name: "skipped_tool", duration: 50 * time.Millisecond}
|
|
|
|
// Provider that captures messages on the second call (after tools)
|
|
var secondCallMessages []providers.Message
|
|
var capMu sync.Mutex
|
|
callCount := 0
|
|
|
|
provider := &toolCallProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Type: "function",
|
|
Name: "slow_tool",
|
|
Function: &providers.FunctionCall{
|
|
Name: "slow_tool",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
{
|
|
ID: "call_2",
|
|
Type: "function",
|
|
Name: "skipped_tool",
|
|
Function: &providers.FunctionCall{
|
|
Name: "skipped_tool",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
},
|
|
finalResp: "done",
|
|
}
|
|
|
|
// Wrap provider to capture messages on second call
|
|
wrappedProvider := &wrappingProvider{
|
|
inner: provider,
|
|
onChat: func(msgs []providers.Message) {
|
|
capMu.Lock()
|
|
callCount++
|
|
if callCount >= 2 {
|
|
secondCallMessages = make([]providers.Message, len(msgs))
|
|
copy(secondCallMessages, msgs)
|
|
}
|
|
capMu.Unlock()
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, wrappedProvider)
|
|
al.RegisterTool(tool1)
|
|
al.RegisterTool(tool2)
|
|
|
|
resultCh := make(chan string, 1)
|
|
go func() {
|
|
resp, _ := al.ProcessDirectWithChannel(
|
|
context.Background(), "go", "test-session", "test", "chat1",
|
|
)
|
|
resultCh <- resp
|
|
}()
|
|
|
|
<-execCh
|
|
al.Steer(providers.Message{Role: "user", Content: "interrupt!"})
|
|
|
|
select {
|
|
case <-resultCh:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
// Check that the skipped tool result message is in the conversation
|
|
capMu.Lock()
|
|
msgs := secondCallMessages
|
|
capMu.Unlock()
|
|
|
|
foundSkipped := false
|
|
for _, m := range msgs {
|
|
if m.Role == "tool" && m.ToolCallID == "call_2" && m.Content == "Skipped due to queued user message." {
|
|
foundSkipped = true
|
|
break
|
|
}
|
|
}
|
|
if !foundSkipped {
|
|
// Log what we actually got
|
|
for i, m := range msgs {
|
|
t.Logf("msg[%d]: role=%s toolCallID=%s content=%s", i, m.Role, m.ToolCallID, truncate(m.Content, 80))
|
|
}
|
|
t.Fatal("expected skipped tool result for call_2")
|
|
}
|
|
}
|
|
|
|
func truncate(s string, n int) string {
|
|
if len(s) <= n {
|
|
return s
|
|
}
|
|
return s[:n] + "..."
|
|
}
|
|
|
|
// wrappingProvider wraps another provider to hook into Chat calls.
|
|
type wrappingProvider struct {
|
|
inner providers.LLMProvider
|
|
onChat func([]providers.Message)
|
|
}
|
|
|
|
func (w *wrappingProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
if w.onChat != nil {
|
|
w.onChat(messages)
|
|
}
|
|
return w.inner.Chat(ctx, messages, tools, model, opts)
|
|
}
|
|
|
|
func (w *wrappingProvider) GetDefaultModel() string {
|
|
return w.inner.GetDefaultModel()
|
|
}
|
|
|
|
// Ensure NormalizeToolCall handles our test tool calls.
|
|
func init() {
|
|
// This is a no-op init; we just need the tool call tests to work
|
|
// with the proper argument serialization.
|
|
_ = json.Marshal
|
|
}
|