Files
picoclaw/pkg/agent/turn_coord_test.go
T
lxowalle 2992eccbf0 feat: add request-scoped context policies (#2914)
* feat: add request-scoped context policies

Add named turn profiles under agents.defaults so callers can opt into
per-request context and tool policies without changing default chat behavior.

Profiles can disable history, system context, skill prompts, or tools, and can
limit skills/tools with allow lists. Wire profile selection through Pico message
payloads, agent turn execution, Web chat selection, and Web visual config.

Reject invalid turn profiles before saving config through Web APIs and document
the new request context policy behavior.

* fix: address turn profile review blockers

* feat: simplify request context policy config

* fix: suppress tool prompt when turn tools are disabled

* fix: enforce turn profile tool restrictions
2026-05-22 10:06:40 +08:00

1042 lines
28 KiB
Go

package agent
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/session"
)
// =============================================================================
// Mock Providers for turn_coord Tests
// =============================================================================
// simpleConvProvider is a stub LLM provider: every Chat call yields the same
// plain-text reply and never includes a tool call.
type simpleConvProvider struct{}

// Chat ignores all of its arguments and returns a fixed greeting with finish
// reason "stop" and a nil error.
func (p *simpleConvProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	reply := new(providers.LLMResponse)
	reply.Content = "Hello! How can I help you today?"
	reply.FinishReason = "stop"
	return reply, nil
}

// GetDefaultModel reports the fixed model name used by this stub.
func (p *simpleConvProvider) GetDefaultModel() string {
	const modelName = "simple-model"
	return modelName
}
// sequenceProvider replays scripted results. For the i-th Chat call
// (zero-based) it returns errors[i] when that entry exists and is non-nil,
// otherwise responses[i] when that entry exists and is non-nil, otherwise a
// generic "ok" reply. callCount is guarded by mu.
type sequenceProvider struct {
	responses []*providers.LLMResponse
	errors    []error
	callCount int
	mu        sync.Mutex
}

// Chat consumes the next position in the script and returns the scripted
// error or response for it; the request arguments are not inspected.
func (p *sequenceProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	call := p.callCount
	p.callCount = call + 1

	// A scripted error takes precedence over a scripted response.
	if call < len(p.errors) {
		if scripted := p.errors[call]; scripted != nil {
			return nil, scripted
		}
	}
	if call < len(p.responses) {
		if scripted := p.responses[call]; scripted != nil {
			return scripted, nil
		}
	}

	// Script exhausted (or both entries nil): fall back to a default reply.
	return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
}

// GetDefaultModel reports the fixed model name used by this stub.
func (p *sequenceProvider) GetDefaultModel() string {
	const modelName = "sequence-model"
	return modelName
}
// nativeSearchCaptureProvider records the arguments of the most recent Chat
// call (copies of messages, tools and opts) so tests can inspect them, and
// reports that it supports native search.
type nativeSearchCaptureProvider struct {
	lastOpts map[string]any
	messages []providers.Message
	tools    []providers.ToolDefinition
}

// Chat stores copies of messages, tools and opts on the provider and returns
// a fixed reply with finish reason "stop".
func (p *nativeSearchCaptureProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	// Copy the slices so later mutation by the caller cannot alter what was captured.
	p.messages = append([]providers.Message(nil), messages...)
	p.tools = append([]providers.ToolDefinition(nil), tools...)

	captured := make(map[string]any, len(opts))
	for key, value := range opts {
		captured[key] = value
	}
	p.lastOpts = captured

	reply := &providers.LLMResponse{}
	reply.Content = "Using native search"
	reply.FinishReason = "stop"
	return reply, nil
}

// GetDefaultModel reports the fixed model name used by this stub.
func (p *nativeSearchCaptureProvider) GetDefaultModel() string {
	const modelName = "native-search-model"
	return modelName
}

// SupportsNativeSearch always reports true.
func (p *nativeSearchCaptureProvider) SupportsNativeSearch() bool {
	const supported = true
	return supported
}
// toolCallRespProvider requests a single tool call on its first Chat
// invocation and answers with plain text on every later invocation.
// callCount is guarded by mu.
type toolCallRespProvider struct {
	toolName  string
	toolArgs  map[string]any
	response  string
	callCount int
	mu        sync.Mutex
}

// Chat increments the call counter under the mutex. Call number one returns a
// tool call built from toolName/toolArgs with finish reason "tool_calls";
// all other calls return the configured response text with finish reason "stop".
func (p *toolCallRespProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.callCount++
	call := p.callCount
	p.mu.Unlock()

	// Anything but the very first call gets the final text answer.
	if call != 1 {
		return &providers.LLMResponse{
			Content:      p.response,
			FinishReason: "stop",
		}, nil
	}

	requested := providers.ToolCall{
		ID:        "call_1",
		Name:      p.toolName,
		Arguments: p.toolArgs,
	}
	return &providers.LLMResponse{
		Content:      "Let me search for that information.",
		ToolCalls:    []providers.ToolCall{requested},
		FinishReason: "tool_calls",
	}, nil
}

// GetDefaultModel reports the fixed model name used by this stub.
func (p *toolCallRespProvider) GetDefaultModel() string {
	const modelName = "tool-model"
	return modelName
}
// errorProvider is a stub whose Chat always fails. The error returned is
// selected by errType; callCount (guarded by mu) counts the invocations.
type errorProvider struct {
	errType   string
	callCount int
	mu        sync.Mutex
}

// Chat records the call and returns a nil response together with the error
// matching errType. "timeout" maps to context.DeadlineExceeded; the other
// known types map to a freshly created error with a fixed message; any
// unrecognized type yields "unknown error".
func (p *errorProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.callCount++
	p.mu.Unlock()

	var text string
	switch p.errType {
	case "timeout":
		// The only case that returns a sentinel rather than a new error value.
		return nil, context.DeadlineExceeded
	case "context_length":
		text = "context_length_exceeded"
	case "vision":
		text = "vision_unsupported"
	case "connection_reset":
		text = "connection reset by peer"
	case "broken_pipe":
		text = "broken pipe"
	case "read_tcp":
		text = "read tcp 127.0.0.1:8080: connection reset"
	case "eof":
		text = "EOF"
	case "connection_refused":
		text = "connection refused"
	default:
		text = "unknown error"
	}
	return nil, errors.New(text)
}

// GetDefaultModel reports the fixed model name used by this stub.
func (p *errorProvider) GetDefaultModel() string {
	const modelName = "error-model"
	return modelName
}
// =============================================================================
// Test Helper Functions
// =============================================================================
// newTurnCoordTestLoop builds an AgentLoop around the given provider using a
// per-test temporary workspace and fixed default settings. It returns the
// loop, the loop's default agent, and a function that closes the loop. The
// test is failed immediately when the registry has no default agent.
func newTurnCoordTestLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, *AgentInstance, func()) {
	t.Helper()

	defaults := config.AgentDefaults{
		Workspace:         t.TempDir(),
		ModelName:         "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	loop := NewAgentLoop(cfg, bus.NewMessageBus(), provider)

	defaultAgent := loop.registry.GetDefaultAgent()
	if defaultAgent == nil {
		t.Fatal("expected default agent")
	}

	closeLoop := func() {
		loop.Close()
	}
	return loop, defaultAgent, closeLoop
}
// makeTestProcessOpts returns the processOptions shared by the tests in this
// file: a "cli" channel, a fixed chat ID, user message and default response,
// with summary, response sending and the no-history flag all switched off.
// Only the session key varies between callers.
func makeTestProcessOpts(sessionKey string) processOptions {
	var opts processOptions

	opts.SessionKey = sessionKey
	opts.Channel = "cli"
	opts.ChatID = "test-chat"
	opts.UserMessage = "test message"
	opts.DefaultResponse = "I couldn't process your request."

	// Spelled out explicitly even though these match the zero value.
	opts.EnableSummary = false
	opts.SendResponse = false
	opts.NoHistory = false

	return opts
}
// saveFailingSessionStore embeds a session.SessionStore and overrides Save so
// that it always reports the configured error; every other method comes from
// the embedded store.
type saveFailingSessionStore struct {
	session.SessionStore

	// err is what Save returns.
	err error
}

// Save ignores its argument and returns the configured error.
func (s *saveFailingSessionStore) Save(_ string) error {
	failure := s.err
	return failure
}
// =============================================================================
// Pipeline Method Tests: SetupTurn
// =============================================================================
// TestPipeline_SetupTurn_BasicInitialization checks that SetupTurn succeeds on
// a fresh turn state and yields an execution that already carries messages
// and starts at iteration 0.
func TestPipeline_SetupTurn_BasicInitialization(t *testing.T) {
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	turnExec, setupErr := pl.SetupTurn(context.Background(), state)
	switch {
	case setupErr != nil:
		t.Fatalf("SetupTurn failed: %v", setupErr)
	case turnExec == nil:
		t.Fatal("expected non-nil turnExecution")
	}

	if len(turnExec.messages) == 0 {
		t.Error("expected messages to be populated")
	}
	if got := turnExec.iteration; got != 0 {
		t.Errorf("expected iteration 0, got %d", got)
	}
}
// =============================================================================
// Pipeline Method Tests: CallLLM
// =============================================================================
// TestPipeline_CallLLM_SimpleResponse checks that a provider reply without
// tool calls makes CallLLM return ControlBreak and leaves a non-empty
// response on the execution.
func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	ctx := context.Background()
	turnExec, err := pl.SetupTurn(ctx, state)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	control, err := pl.CallLLM(ctx, ctx, state, turnExec, 1)
	if err != nil {
		t.Fatalf("CallLLM failed: %v", err)
	}
	if control != ControlBreak {
		t.Errorf("expected ControlBreak, got %v", control)
	}

	resp := turnExec.response
	if resp == nil {
		t.Fatal("expected non-nil response")
	}
	if resp.Content == "" {
		t.Error("expected non-empty content")
	}
}
// TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback checks
// that, right after SetupTurn and before any LLM call, the execution's model
// name is the agent's own Model even though a later candidate carries an
// identity key.
func TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback(t *testing.T) {
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	const wantModel = "primary-model"
	defaultAgent.Model = wantModel
	defaultAgent.Candidates = []providers.FallbackCandidate{
		{Provider: "openai", Model: "gpt-5.4"},
		{Provider: "anthropic", Model: "claude-sonnet", IdentityKey: "model_name:fallback-model"},
	}

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	turnExec, setupErr := pl.SetupTurn(context.Background(), state)
	if setupErr != nil {
		t.Fatalf("SetupTurn failed: %v", setupErr)
	}
	if got := turnExec.llmModelName; got != wantModel {
		t.Fatalf("exec.llmModelName = %q, want %q", got, wantModel)
	}
}
func TestPipeline_CallLLM_UsesSuccessfulFallbackIdentityAlias(t *testing.T) {
provider := &sequenceProvider{
errors: []error{
errors.New("status: 429 - rate limit exceeded"),
nil,
},
responses: []*providers.LLMResponse{
nil,
{Content: "fallback answer", FinishReason: "stop"},
},
}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
agent.Model = "primary-model"
agent.Candidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:secondary"},
}
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err != nil {
t.Fatalf("CallLLM failed: %v", err)
}
if ctrl != ControlBreak {
t.Fatalf("expected ControlBreak, got %v", ctrl)
}
if exec.llmModelName != "secondary" {
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "secondary")
}
}
func TestPipeline_CallLLM_UsesSuccessfulFallbackDisplayNameWithoutAlias(t *testing.T) {
provider := &sequenceProvider{
errors: []error{
errors.New("status: 429 - rate limit exceeded"),
nil,
},
responses: []*providers.LLMResponse{
nil,
{Content: "fallback answer", FinishReason: "stop"},
},
}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
agent.Model = "primary-model"
agent.Candidates = []providers.FallbackCandidate{
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
{Provider: "anthropic", Model: "claude-sonnet", DisplayName: "anthropic/claude-sonnet"},
}
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err != nil {
t.Fatalf("CallLLM failed: %v", err)
}
if ctrl != ControlBreak {
t.Fatalf("expected ControlBreak, got %v", ctrl)
}
if exec.llmModelName != "anthropic/claude-sonnet" {
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "anthropic/claude-sonnet")
}
}
// TestPipeline_SetupTurn_UsesLightCandidateDisplayName configures a router with
// a light model and an empty user message, then checks that SetupTurn marks
// the execution as using the light route and names the model "light-model".
func TestPipeline_SetupTurn_UsesLightCandidateDisplayName(t *testing.T) {
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	defaultAgent.Model = "primary-model"
	defaultAgent.Candidates = []providers.FallbackCandidate{
		{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
	}
	defaultAgent.LightCandidates = []providers.FallbackCandidate{
		{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:light-model", DisplayName: "light-model"},
	}
	routerCfg := routing.RouterConfig{LightModel: "light-model", Threshold: 1}
	defaultAgent.Router = routing.New(routerCfg)

	pl := NewPipeline(loop)

	// The user message is cleared for this scenario.
	opts := makeTestProcessOpts("test-session")
	opts.UserMessage = ""

	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	turnExec, setupErr := pl.SetupTurn(context.Background(), state)
	if setupErr != nil {
		t.Fatalf("SetupTurn failed: %v", setupErr)
	}
	if !turnExec.usedLight {
		t.Fatal("expected light routing to be used")
	}
	if got := turnExec.llmModelName; got != "light-model" {
		t.Fatalf("exec.llmModelName = %q, want %q", got, "light-model")
	}
}
// TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd swaps the agent's session
// store for one whose Save always fails, expects ProcessDirect to return an
// error, and then waits up to two seconds for a turn-end event whose payload
// status is TurnEndStatusError.
func TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd(t *testing.T) {
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	failingStore := &saveFailingSessionStore{
		SessionStore: session.NewSessionManager(""),
		err:          errors.New("session save failed"),
	}
	defaultAgent.Sessions = failingStore

	subscription := loop.SubscribeEvents(8)
	defer loop.UnsubscribeEvents(subscription.ID)

	_, procErr := loop.ProcessDirect(context.Background(), "hello", "session-save-fail")
	if procErr == nil {
		t.Fatal("expected ProcessDirect to fail")
	}

	// A single timer bounds the whole wait; events of other kinds are skipped.
	timeout := time.After(2 * time.Second)
	for {
		select {
		case <-timeout:
			t.Fatal("timed out waiting for turn_end event")
		case event := <-subscription.C:
			if event.Kind == EventKindTurnEnd {
				endPayload, isTurnEnd := event.Payload.(TurnEndPayload)
				if !isTurnEnd {
					t.Fatalf("TurnEnd payload type = %T", event.Payload)
				}
				if endPayload.Status != TurnEndStatusError {
					t.Fatalf("TurnEnd status = %q, want %q", endPayload.Status, TurnEndStatusError)
				}
				return
			}
		}
	}
}
// TestPipeline_CallLLM_WithToolCall checks that a provider reply carrying a
// tool call makes CallLLM return ControlToolLoop and that the first
// normalized tool call on the execution is named "web_search".
func TestPipeline_CallLLM_WithToolCall(t *testing.T) {
	toolProvider := &toolCallRespProvider{
		toolName: "web_search",
		toolArgs: map[string]any{"query": "test"},
		response: "Found information about test.",
	}
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, toolProvider)
	defer closeLoop()

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	ctx := context.Background()
	turnExec, err := pl.SetupTurn(ctx, state)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	control, err := pl.CallLLM(ctx, ctx, state, turnExec, 1)
	if err != nil {
		t.Fatalf("CallLLM failed: %v", err)
	}
	if control != ControlToolLoop {
		t.Errorf("expected ControlToolLoop, got %v", control)
	}

	calls := turnExec.normalizedToolCalls
	if len(calls) == 0 {
		t.Fatal("expected tool calls")
	}
	if name := calls[0].Name; name != "web_search" {
		t.Errorf("expected tool name 'web_search', got %q", name)
	}
}
// TestPipeline_CallLLM_UsesNativeSearchWithoutClientWebSearchTool first asserts
// that the agent has no "web_search" tool registered, then enables the web
// tool with native preference and checks that the options handed to the
// provider contain native_search=true.
func TestPipeline_CallLLM_UsesNativeSearchWithoutClientWebSearchTool(t *testing.T) {
	capture := &nativeSearchCaptureProvider{}
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, capture)
	defer closeLoop()

	_, registered := defaultAgent.Tools.Get("web_search")
	if registered {
		t.Fatal("expected no client-side web_search tool to be registered")
	}

	loop.cfg.Tools.Web.Enabled = true
	loop.cfg.Tools.Web.PreferNative = true

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	ctx := context.Background()
	turnExec, err := pl.SetupTurn(ctx, state)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	control, err := pl.CallLLM(ctx, ctx, state, turnExec, 1)
	if err != nil {
		t.Fatalf("CallLLM failed: %v", err)
	}
	if control != ControlBreak {
		t.Fatalf("expected ControlBreak, got %v", control)
	}

	// A missing key or a non-bool value both leave nativeSearch false.
	nativeSearch, _ := capture.lastOpts["native_search"].(bool)
	if !nativeSearch {
		t.Fatalf("expected native_search=true, got %#v", capture.lastOpts["native_search"])
	}
}
// TestPipeline_CallLLM_TimeoutRetry uses a provider that always returns
// context.DeadlineExceeded and checks that CallLLM ultimately reports an error.
func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) {
	timeoutProvider := &errorProvider{errType: "timeout"}
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, timeoutProvider)
	defer closeLoop()

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	ctx := context.Background()
	turnExec, setupErr := pl.SetupTurn(ctx, state)
	if setupErr != nil {
		t.Fatalf("SetupTurn failed: %v", setupErr)
	}

	// The provider never succeeds, so the call has to end in an error.
	if _, callErr := pl.CallLLM(ctx, ctx, state, turnExec, 1); callErr == nil {
		t.Error("expected error after retries")
	}
}
// TestPipeline_CallLLM_ContextLengthError runs CallLLM against a provider that
// always returns "context_length_exceeded". It makes no assertion about the
// outcome and only logs the resulting error.
func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
	ctxLenProvider := &errorProvider{errType: "context_length"}
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, ctxLenProvider)
	defer closeLoop()

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	ctx := context.Background()
	turnExec, setupErr := pl.SetupTurn(ctx, state)
	if setupErr != nil {
		t.Fatalf("SetupTurn failed: %v", setupErr)
	}

	// Either success or failure is accepted here; the result is only logged.
	_, callErr := pl.CallLLM(ctx, ctx, state, turnExec, 1)
	t.Logf("CallLLM result after context error: err=%v", callErr)
}
// TestPipeline_CallLLM_NetworkErrorRetry runs one subtest per network-style
// error type. In each, the provider fails every call with that error and
// CallLLM is expected to return an error.
func TestPipeline_CallLLM_NetworkErrorRetry(t *testing.T) {
	// The subtest name and the provider's errType are the same string.
	errTypes := []string{
		"connection_reset",
		"broken_pipe",
		"read_tcp",
		"eof",
		"connection_refused",
	}

	for _, errType := range errTypes {
		t.Run(errType, func(t *testing.T) {
			failing := &errorProvider{errType: errType}
			loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, failing)
			defer closeLoop()

			pl := NewPipeline(loop)
			opts := makeTestProcessOpts("test-session")
			scope := turnEventScope{
				turnID:  "turn-1",
				context: newTurnContext(nil, nil, nil),
			}
			state := newTurnState(defaultAgent, opts, scope)

			ctx := context.Background()
			turnExec, setupErr := pl.SetupTurn(ctx, state)
			if setupErr != nil {
				t.Fatalf("SetupTurn failed: %v", setupErr)
			}

			if _, callErr := pl.CallLLM(ctx, ctx, state, turnExec, 1); callErr == nil {
				t.Error("expected error after network error retries")
			}
		})
	}
}
// TestPipeline_CallLLM_RetryConfigRespected configures MaxLLMRetries=3 and
// LLMRetryBackoffSecs=1 against a provider that always fails with a
// connection reset. It expects CallLLM to return an error and to take at
// least three seconds of wall-clock time, so this test is inherently slow.
func TestPipeline_CallLLM_RetryConfigRespected(t *testing.T) {
	defaults := config.AgentDefaults{
		Workspace:           t.TempDir(),
		ModelName:           "test-model",
		MaxTokens:           4096,
		MaxToolIterations:   10,
		MaxLLMRetries:       3,
		LLMRetryBackoffSecs: 1,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	msgBus := bus.NewMessageBus()
	failing := &errorProvider{errType: "connection_reset"}
	loop := NewAgentLoop(cfg, msgBus, failing)
	defer loop.Close()

	defaultAgent := loop.registry.GetDefaultAgent()
	if defaultAgent == nil {
		t.Fatal("expected default agent")
	}

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session")
	scope := turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	ctx := context.Background()
	turnExec, setupErr := pl.SetupTurn(ctx, state)
	if setupErr != nil {
		t.Fatalf("SetupTurn failed: %v", setupErr)
	}

	began := time.Now()
	_, callErr := pl.CallLLM(ctx, ctx, state, turnExec, 1)
	took := time.Since(began)

	if callErr == nil {
		t.Error("expected error after retries")
	}

	// Lower bound on the accumulated backoff the test expects.
	const minBackoff = 3 * time.Second
	if took < minBackoff {
		t.Errorf("expected at least %v of backoff, got %v", minBackoff, took)
	}
}
// TestPipeline_CallLLM_RetryCountLimit configures MaxLLMRetries=2 with zero
// backoff against a provider that fails every call, and checks that CallLLM
// returns an error after exactly three provider invocations (one initial
// attempt plus two retries).
func TestPipeline_CallLLM_RetryCountLimit(t *testing.T) {
	tmpDir := t.TempDir()
	counterPrv := &countingErrorProvider{errType: "connection_reset", targetCalls: 5}
	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:           tmpDir,
				ModelName:           "test-model",
				MaxTokens:           4096,
				MaxToolIterations:   10,
				MaxLLMRetries:       2,
				LLMRetryBackoffSecs: 0,
			},
		},
	}
	msgBus := bus.NewMessageBus()
	al := NewAgentLoop(cfg, msgBus, counterPrv)
	defer al.Close()

	agent := al.registry.GetDefaultAgent()
	if agent == nil {
		t.Fatal("expected default agent")
	}

	pipeline := NewPipeline(al)
	ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})
	exec, err := pipeline.SetupTurn(context.Background(), ts)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
	if err == nil {
		t.Error("expected error after retries")
	}

	// Chat updates callCount while holding mu, so read it under the same
	// lock to keep the access consistent and race-detector clean.
	counterPrv.mu.Lock()
	calls := counterPrv.callCount
	counterPrv.mu.Unlock()
	if calls != 3 {
		t.Errorf("expected exactly 3 calls (1 initial + 2 retries), got %d", calls)
	}
}
// countingErrorProvider fails every Chat call with a connection-reset error
// and counts the calls in callCount (guarded by mu). The errType and
// targetCalls fields are set by the test but are never read by Chat.
type countingErrorProvider struct {
	errType     string
	targetCalls int
	callCount   int
	mu          sync.Mutex
}

// Chat bumps the call counter under the mutex and returns a nil response
// with a "connection reset by peer" error, regardless of its arguments.
func (p *countingErrorProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.callCount++
	p.mu.Unlock()

	failure := errors.New("connection reset by peer")
	return nil, failure
}

// GetDefaultModel reports the fixed model name used by this stub.
func (p *countingErrorProvider) GetDefaultModel() string {
	const modelName = "counting-error-model"
	return modelName
}
// =============================================================================
// Pipeline Method Tests: ExecuteTools
// =============================================================================
func TestPipeline_ExecuteTools_NoTools(t *testing.T) {
// Provider returns no tool calls, so ExecuteTools should not be called
// This test verifies the ControlBreak path from CallLLM
provider := &simpleConvProvider{}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
// First CallLLM returns ControlBreak (no tools)
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err != nil {
t.Fatalf("CallLLM failed: %v", err)
}
if ctrl != ControlBreak {
t.Fatalf("expected ControlBreak, got %v", ctrl)
}
// No tools to execute, Finalize should be called directly
}
// =============================================================================
// runTurn Integration Tests
// =============================================================================
func TestRunTurn_SimpleConversation(t *testing.T) {
provider := &simpleConvProvider{}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session-simple")
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-simple",
context: newTurnContext(nil, nil, nil),
})
result, err := al.runTurn(context.Background(), ts, pipeline)
if err != nil {
t.Fatalf("runTurn failed: %v", err)
}
if result.status != TurnEndStatusCompleted {
t.Errorf("expected status Completed, got %v", result.status)
}
if result.finalContent == "" {
t.Error("expected non-empty finalContent")
}
}
// TestRunTurn_MaxIterations runs a turn with the agent's MaxIterations set to
// 2 against toolCallRespProvider and expects the turn to end with status
// TurnEndStatusCompleted.
//
// NOTE(review): toolCallRespProvider requests a tool only on its first Chat
// call and returns plain text afterwards, so it is not evident from this test
// that the iteration cap is actually reached — confirm against runTurn.
func TestRunTurn_MaxIterations(t *testing.T) {
	toolProvider := &toolCallRespProvider{
		toolName: "search",
		toolArgs: map[string]any{"q": "x"},
		response: "done",
	}
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, toolProvider)
	defer closeLoop()

	// Lower the iteration cap for this turn.
	defaultAgent.MaxIterations = 2

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session-maxiter")
	scope := turnEventScope{
		turnID:  "turn-maxiter",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	outcome, runErr := loop.runTurn(context.Background(), state, pl)
	if runErr != nil {
		t.Fatalf("runTurn failed: %v", runErr)
	}
	if status := outcome.status; status != TurnEndStatusCompleted {
		t.Errorf("expected status Completed, got %v", status)
	}
}
// TestRunTurn_HardAbort starts a turn in a goroutine against a provider
// configured with a 10-second delay, requests a hard abort 50 ms later, and
// fails if runTurn has not returned within three seconds of the request.
// The value returned by runTurn is not inspected.
func TestRunTurn_HardAbort(t *testing.T) {
	slow := &slowMockProvider{delay: 10 * time.Second}
	loop, defaultAgent, closeLoop := newTurnCoordTestLoop(t, slow)
	defer closeLoop()

	pl := NewPipeline(loop)
	opts := makeTestProcessOpts("test-session-abort")
	scope := turnEventScope{
		turnID:  "turn-abort",
		context: newTurnContext(nil, nil, nil),
	}
	state := newTurnState(defaultAgent, opts, scope)

	finished := make(chan struct{})
	go func() {
		// Only completion matters here, so both results are dropped.
		_, _ = loop.runTurn(context.Background(), state, pl)
		close(finished)
	}()

	// Let the turn get under way before aborting it.
	time.Sleep(50 * time.Millisecond)
	state.requestHardAbort()

	select {
	case <-time.After(3 * time.Second):
		t.Fatal("runTurn did not complete after abort")
	case <-finished:
	}
}
func TestRunTurn_SteeringMessageInjection(t *testing.T) {
provider := &simpleConvProvider{}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session-steering")
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-steering",
context: newTurnContext(nil, nil, nil),
})
// Enqueue steering message before runTurn
steeringMsg := providers.Message{
Role: "user",
Content: "Steering message",
}
al.Steer(steeringMsg)
result, err := al.runTurn(context.Background(), ts, pipeline)
if err != nil {
t.Fatalf("runTurn failed: %v", err)
}
if result.status != TurnEndStatusCompleted {
t.Errorf("expected status Completed, got %v", result.status)
}
// Steering message should have been injected
}
// TestRunTurn_GracefulInterrupt starts a turn in a goroutine against a
// provider that requests a tool on its first call, asks for a graceful
// interrupt 50 ms later, and expects runTurn to return within five seconds
// without an error and with status TurnEndStatusCompleted.
func TestRunTurn_GracefulInterrupt(t *testing.T) {
	provider := &toolCallRespProvider{
		toolName: "search",
		toolArgs: map[string]any{"q": "test"},
		response: "Final response after interrupt",
	}
	al, agent, cleanup := newTurnCoordTestLoop(t, provider)
	defer cleanup()

	pipeline := NewPipeline(al)
	opts := makeTestProcessOpts("test-session-graceful")
	ts := newTurnState(agent, opts, turnEventScope{
		turnID:  "turn-graceful",
		context: newTurnContext(nil, nil, nil),
	})

	// result and runErr are written by the goroutine and only read after
	// done is closed, which orders the accesses.
	var (
		result turnResult
		runErr error
	)
	done := make(chan struct{})
	go func() {
		defer close(done)
		result, runErr = al.runTurn(context.Background(), ts, pipeline)
	}()

	// Give the turn a moment to start before interrupting it.
	time.Sleep(50 * time.Millisecond)
	ts.requestGracefulInterrupt("Please stop")

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		t.Fatal("runTurn did not complete after graceful interrupt")
	}

	// Report the error instead of discarding it, so a failing turn is
	// diagnosed by its cause rather than by an unexpected status alone.
	if runErr != nil {
		t.Fatalf("runTurn failed: %v", runErr)
	}
	if result.status != TurnEndStatusCompleted {
		t.Errorf("expected status Completed, got %v", result.status)
	}
}
// =============================================================================
// turnState Tests
// =============================================================================
// TestTurnState_GracefulInterruptRequested checks that a zero-valued turnState
// reports no graceful interrupt, and that after requestGracefulInterrupt the
// state reports the request together with the hint that was passed in.
func TestTurnState_GracefulInterruptRequested(t *testing.T) {
	state := &turnState{}

	if pending, _ := state.gracefulInterruptRequested(); pending {
		t.Error("expected no interrupt initially")
	}

	state.requestGracefulInterrupt("test hint")

	pending, hint := state.gracefulInterruptRequested()
	if !pending {
		t.Error("expected interrupt to be requested")
	}
	if hint != "test hint" {
		t.Errorf("expected hint 'test hint', got %q", hint)
	}
}
// TestTurnState_HardAbortRequested checks that a zero-valued turnState reports
// no hard abort and that hardAbortRequested turns true after requestHardAbort.
func TestTurnState_HardAbortRequested(t *testing.T) {
	state := &turnState{}

	if aborted := state.hardAbortRequested(); aborted {
		t.Error("expected no hard abort initially")
	}

	state.requestHardAbort()

	if aborted := state.hardAbortRequested(); !aborted {
		t.Error("expected hard abort to be requested")
	}
}
// TestTurnState_SkillContextSnapshotsTrackLatestSuccessfulPath records two skill
// context snapshots (initial build, then context-retry rebuild) and checks
// three views of them: the attempted skills list holds all three skills in
// order, the latest snapshot holds only the second set, and the snapshot
// history has two entries with sequence numbers 1 and 2 and matching triggers.
func TestTurnState_SkillContextSnapshotsTrackLatestSuccessfulPath(t *testing.T) {
	state := &turnState{}
	state.recordSkillContextSnapshot(skillContextTriggerInitialBuild, []string{"skill-a"})
	state.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, []string{"skill-b", "skill-c"})

	// The length is checked first so the index expressions are always in range.
	attempted := state.attemptedSkillsSnapshot()
	if len(attempted) != 3 {
		t.Fatalf("attemptedSkillsSnapshot = %v, want [skill-a skill-b skill-c]", attempted)
	}
	if attempted[0] != "skill-a" || attempted[1] != "skill-b" || attempted[2] != "skill-c" {
		t.Fatalf("attemptedSkillsSnapshot = %v, want [skill-a skill-b skill-c]", attempted)
	}

	latest := state.latestSkillContextSnapshot()
	if len(latest) != 2 {
		t.Fatalf("latestSkillContextSnapshot = %v, want [skill-b skill-c]", latest)
	}
	if latest[0] != "skill-b" || latest[1] != "skill-c" {
		t.Fatalf("latestSkillContextSnapshot = %v, want [skill-b skill-c]", latest)
	}

	history := state.skillContextSnapshotsSnapshot()
	if count := len(history); count != 2 {
		t.Fatalf("len(skillContextSnapshotsSnapshot()) = %d, want 2", count)
	}

	first, second := history[0], history[1]
	if first.Sequence != 1 || first.Trigger != skillContextTriggerInitialBuild {
		t.Fatalf("snapshots[0] = %+v, want sequence=1 trigger=%q", first, skillContextTriggerInitialBuild)
	}
	if second.Sequence != 2 || second.Trigger != skillContextTriggerContextRetryRebuild {
		t.Fatalf("snapshots[1] = %+v, want sequence=2 trigger=%q", second, skillContextTriggerContextRetryRebuild)
	}
}