mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
b7db059544
* feat(chat,seahorse): persist and display model_name across history * test(seahorse): fix lint regressions in repair coverage * fix(pico): preserve model_name in live updates * fix(pico): preserve model_name through live stream wrappers
1038 lines
28 KiB
Go
1038 lines
28 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/routing"
|
|
"github.com/sipeed/picoclaw/pkg/session"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Mock Providers for turn_coord Tests
|
|
// =============================================================================
|
|
|
|
// simpleConvProvider returns a simple text response without tools
|
|
type simpleConvProvider struct{}
|
|
|
|
func (p *simpleConvProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{
|
|
Content: "Hello! How can I help you today?",
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *simpleConvProvider) GetDefaultModel() string {
|
|
return "simple-model"
|
|
}
|
|
|
|
type sequenceProvider struct {
|
|
responses []*providers.LLMResponse
|
|
errors []error
|
|
callCount int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (p *sequenceProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
idx := p.callCount
|
|
p.callCount++
|
|
|
|
if idx < len(p.errors) && p.errors[idx] != nil {
|
|
return nil, p.errors[idx]
|
|
}
|
|
if idx < len(p.responses) && p.responses[idx] != nil {
|
|
return p.responses[idx], nil
|
|
}
|
|
return &providers.LLMResponse{Content: "ok", FinishReason: "stop"}, nil
|
|
}
|
|
|
|
func (p *sequenceProvider) GetDefaultModel() string {
|
|
return "sequence-model"
|
|
}
|
|
|
|
type nativeSearchCaptureProvider struct {
|
|
lastOpts map[string]any
|
|
}
|
|
|
|
func (p *nativeSearchCaptureProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.lastOpts = make(map[string]any, len(opts))
|
|
for k, v := range opts {
|
|
p.lastOpts[k] = v
|
|
}
|
|
return &providers.LLMResponse{
|
|
Content: "Using native search",
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *nativeSearchCaptureProvider) GetDefaultModel() string {
|
|
return "native-search-model"
|
|
}
|
|
|
|
func (p *nativeSearchCaptureProvider) SupportsNativeSearch() bool {
|
|
return true
|
|
}
|
|
|
|
// toolCallRespProvider returns a tool call response
|
|
type toolCallRespProvider struct {
|
|
toolName string
|
|
toolArgs map[string]any
|
|
response string
|
|
callCount int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (p *toolCallRespProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.callCount++
|
|
count := p.callCount
|
|
p.mu.Unlock()
|
|
|
|
// First call returns a tool call, subsequent calls return final response
|
|
if count == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "Let me search for that information.",
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Name: p.toolName,
|
|
Arguments: p.toolArgs,
|
|
},
|
|
},
|
|
FinishReason: "tool_calls",
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{
|
|
Content: p.response,
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *toolCallRespProvider) GetDefaultModel() string {
|
|
return "tool-model"
|
|
}
|
|
|
|
// errorProvider simulates various error conditions
|
|
type errorProvider struct {
|
|
errType string
|
|
callCount int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (p *errorProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.callCount++
|
|
p.mu.Unlock()
|
|
|
|
switch p.errType {
|
|
case "timeout":
|
|
return nil, context.DeadlineExceeded
|
|
case "context_length":
|
|
return nil, errors.New("context_length_exceeded")
|
|
case "vision":
|
|
return nil, errors.New("vision_unsupported")
|
|
case "connection_reset":
|
|
return nil, errors.New("connection reset by peer")
|
|
case "broken_pipe":
|
|
return nil, errors.New("broken pipe")
|
|
case "read_tcp":
|
|
return nil, errors.New("read tcp 127.0.0.1:8080: connection reset")
|
|
case "eof":
|
|
return nil, errors.New("EOF")
|
|
case "connection_refused":
|
|
return nil, errors.New("connection refused")
|
|
default:
|
|
return nil, errors.New("unknown error")
|
|
}
|
|
}
|
|
|
|
func (p *errorProvider) GetDefaultModel() string {
|
|
return "error-model"
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test Helper Functions
|
|
// =============================================================================
|
|
|
|
func newTurnCoordTestLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, *AgentInstance, func()) {
|
|
t.Helper()
|
|
tmpDir := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
return al, agent, func() {
|
|
al.Close()
|
|
}
|
|
}
|
|
|
|
func makeTestProcessOpts(sessionKey string) processOptions {
|
|
return processOptions{
|
|
SessionKey: sessionKey,
|
|
Channel: "cli",
|
|
ChatID: "test-chat",
|
|
UserMessage: "test message",
|
|
DefaultResponse: "I couldn't process your request.",
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
NoHistory: false,
|
|
}
|
|
}
|
|
|
|
type saveFailingSessionStore struct {
|
|
session.SessionStore
|
|
err error
|
|
}
|
|
|
|
func (s *saveFailingSessionStore) Save(_ string) error {
|
|
return s.err
|
|
}
|
|
|
|
// =============================================================================
|
|
// Pipeline Method Tests: SetupTurn
|
|
// =============================================================================
|
|
|
|
func TestPipeline_SetupTurn_BasicInitialization(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
if exec == nil {
|
|
t.Fatal("expected non-nil turnExecution")
|
|
}
|
|
if len(exec.messages) == 0 {
|
|
t.Error("expected messages to be populated")
|
|
}
|
|
if exec.iteration != 0 {
|
|
t.Errorf("expected iteration 0, got %d", exec.iteration)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Pipeline Method Tests: CallLLM
|
|
// =============================================================================
|
|
|
|
func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlBreak {
|
|
t.Errorf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
if exec.response == nil {
|
|
t.Fatal("expected non-nil response")
|
|
}
|
|
if exec.response.Content == "" {
|
|
t.Error("expected non-empty content")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_SetupTurn_ModelNameDoesNotUseFallbackAliasBeforeFallback(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
agent.Model = "primary-model"
|
|
agent.Candidates = []providers.FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-5.4"},
|
|
{Provider: "anthropic", Model: "claude-sonnet", IdentityKey: "model_name:fallback-model"},
|
|
}
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
if exec.llmModelName != "primary-model" {
|
|
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "primary-model")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_UsesSuccessfulFallbackIdentityAlias(t *testing.T) {
|
|
provider := &sequenceProvider{
|
|
errors: []error{
|
|
errors.New("status: 429 - rate limit exceeded"),
|
|
nil,
|
|
},
|
|
responses: []*providers.LLMResponse{
|
|
nil,
|
|
{Content: "fallback answer", FinishReason: "stop"},
|
|
},
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
agent.Model = "primary-model"
|
|
agent.Candidates = []providers.FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary"},
|
|
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:secondary"},
|
|
}
|
|
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlBreak {
|
|
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
if exec.llmModelName != "secondary" {
|
|
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "secondary")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_UsesSuccessfulFallbackDisplayNameWithoutAlias(t *testing.T) {
|
|
provider := &sequenceProvider{
|
|
errors: []error{
|
|
errors.New("status: 429 - rate limit exceeded"),
|
|
nil,
|
|
},
|
|
responses: []*providers.LLMResponse{
|
|
nil,
|
|
{Content: "fallback answer", FinishReason: "stop"},
|
|
},
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
agent.Model = "primary-model"
|
|
agent.Candidates = []providers.FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
|
|
{Provider: "anthropic", Model: "claude-sonnet", DisplayName: "anthropic/claude-sonnet"},
|
|
}
|
|
al.fallback = providers.NewFallbackChain(providers.NewCooldownTracker(), nil)
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlBreak {
|
|
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
if exec.llmModelName != "anthropic/claude-sonnet" {
|
|
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "anthropic/claude-sonnet")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_SetupTurn_UsesLightCandidateDisplayName(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
agent.Model = "primary-model"
|
|
agent.Candidates = []providers.FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-5.4", IdentityKey: "model_name:primary", DisplayName: "primary-model"},
|
|
}
|
|
agent.LightCandidates = []providers.FallbackCandidate{
|
|
{Provider: "openai", Model: "gpt-5.4-mini", IdentityKey: "model_name:light-model", DisplayName: "light-model"},
|
|
}
|
|
agent.Router = routing.New(routing.RouterConfig{LightModel: "light-model", Threshold: 1})
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session")
|
|
opts.UserMessage = ""
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
if !exec.usedLight {
|
|
t.Fatal("expected light routing to be used")
|
|
}
|
|
if exec.llmModelName != "light-model" {
|
|
t.Fatalf("exec.llmModelName = %q, want %q", exec.llmModelName, "light-model")
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_FinalizeSaveErrorEmitsErrorTurnEnd(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
saveErr := errors.New("session save failed")
|
|
agent.Sessions = &saveFailingSessionStore{
|
|
SessionStore: session.NewSessionManager(""),
|
|
err: saveErr,
|
|
}
|
|
|
|
sub := al.SubscribeEvents(8)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
if _, err := al.ProcessDirect(context.Background(), "hello", "session-save-fail"); err == nil {
|
|
t.Fatal("expected ProcessDirect to fail")
|
|
}
|
|
|
|
deadline := time.After(2 * time.Second)
|
|
for {
|
|
select {
|
|
case evt := <-sub.C:
|
|
if evt.Kind != EventKindTurnEnd {
|
|
continue
|
|
}
|
|
payload, ok := evt.Payload.(TurnEndPayload)
|
|
if !ok {
|
|
t.Fatalf("TurnEnd payload type = %T", evt.Payload)
|
|
}
|
|
if payload.Status != TurnEndStatusError {
|
|
t.Fatalf("TurnEnd status = %q, want %q", payload.Status, TurnEndStatusError)
|
|
}
|
|
return
|
|
case <-deadline:
|
|
t.Fatal("timed out waiting for turn_end event")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_WithToolCall(t *testing.T) {
|
|
provider := &toolCallRespProvider{
|
|
toolName: "web_search",
|
|
toolArgs: map[string]any{"query": "test"},
|
|
response: "Found information about test.",
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlToolLoop {
|
|
t.Errorf("expected ControlToolLoop, got %v", ctrl)
|
|
}
|
|
if len(exec.normalizedToolCalls) == 0 {
|
|
t.Fatal("expected tool calls")
|
|
}
|
|
if exec.normalizedToolCalls[0].Name != "web_search" {
|
|
t.Errorf("expected tool name 'web_search', got %q", exec.normalizedToolCalls[0].Name)
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_UsesNativeSearchWithoutClientWebSearchTool(t *testing.T) {
|
|
provider := &nativeSearchCaptureProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
if _, ok := agent.Tools.Get("web_search"); ok {
|
|
t.Fatal("expected no client-side web_search tool to be registered")
|
|
}
|
|
|
|
al.cfg.Tools.Web.Enabled = true
|
|
al.cfg.Tools.Web.PreferNative = true
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlBreak {
|
|
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
if got, _ := provider.lastOpts["native_search"].(bool); !got {
|
|
t.Fatalf("expected native_search=true, got %#v", provider.lastOpts["native_search"])
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) {
|
|
errorPrv := &errorProvider{errType: "timeout"}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
// Should retry and eventually fail after max retries
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err == nil {
|
|
t.Error("expected error after retries")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
|
|
errorPrv := &errorProvider{errType: "context_length"}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
// Should trigger context compression and retry
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
// May succeed after compression or fail - either is acceptable
|
|
t.Logf("CallLLM result after context error: err=%v", err)
|
|
}
|
|
|
|
func TestPipeline_CallLLM_NetworkErrorRetry(t *testing.T) {
|
|
testCases := []struct {
|
|
name string
|
|
errType string
|
|
}{
|
|
{"connection_reset", "connection_reset"},
|
|
{"broken_pipe", "broken_pipe"},
|
|
{"read_tcp", "read_tcp"},
|
|
{"eof", "eof"},
|
|
{"connection_refused", "connection_refused"},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
errorPrv := &errorProvider{errType: tc.errType}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err == nil {
|
|
t.Error("expected error after network error retries")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_RetryConfigRespected(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
MaxLLMRetries: 3,
|
|
LLMRetryBackoffSecs: 1,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &errorProvider{errType: "connection_reset"}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
defer al.Close()
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
start := time.Now()
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
elapsed := time.Since(start)
|
|
|
|
if err == nil {
|
|
t.Error("expected error after retries")
|
|
}
|
|
|
|
expectedMinTime := 3 * time.Second
|
|
if elapsed < expectedMinTime {
|
|
t.Errorf("expected at least %v of backoff, got %v", expectedMinTime, elapsed)
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_RetryCountLimit(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
counterPrv := &countingErrorProvider{errType: "connection_reset", targetCalls: 5}
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
MaxLLMRetries: 2,
|
|
LLMRetryBackoffSecs: 0,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, counterPrv)
|
|
defer al.Close()
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err == nil {
|
|
t.Error("expected error after retries")
|
|
}
|
|
|
|
if counterPrv.callCount != 3 {
|
|
t.Errorf("expected exactly 3 calls (1 initial + 2 retries), got %d", counterPrv.callCount)
|
|
}
|
|
}
|
|
|
|
type countingErrorProvider struct {
|
|
errType string
|
|
targetCalls int
|
|
callCount int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (p *countingErrorProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.callCount++
|
|
p.mu.Unlock()
|
|
return nil, errors.New("connection reset by peer")
|
|
}
|
|
|
|
func (p *countingErrorProvider) GetDefaultModel() string {
|
|
return "counting-error-model"
|
|
}
|
|
|
|
// =============================================================================
|
|
// Pipeline Method Tests: ExecuteTools
|
|
// =============================================================================
|
|
|
|
func TestPipeline_ExecuteTools_NoTools(t *testing.T) {
|
|
// Provider returns no tool calls, so ExecuteTools should not be called
|
|
// This test verifies the ControlBreak path from CallLLM
|
|
provider := &simpleConvProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
// First CallLLM returns ControlBreak (no tools)
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
|
|
if ctrl != ControlBreak {
|
|
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
// No tools to execute, Finalize should be called directly
|
|
}
|
|
|
|
// =============================================================================
|
|
// runTurn Integration Tests
|
|
// =============================================================================
|
|
|
|
func TestRunTurn_SimpleConversation(t *testing.T) {
|
|
provider := &simpleConvProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-simple")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-simple",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
result, err := al.runTurn(context.Background(), ts, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("runTurn failed: %v", err)
|
|
}
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
if result.finalContent == "" {
|
|
t.Error("expected non-empty finalContent")
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_MaxIterations(t *testing.T) {
|
|
// Provider always returns tool calls, should hit max iterations
|
|
provider := &toolCallRespProvider{
|
|
toolName: "search",
|
|
toolArgs: map[string]any{"q": "x"},
|
|
response: "done",
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
// Override max iterations to 2
|
|
agent.MaxIterations = 2
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-maxiter")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-maxiter",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
result, err := al.runTurn(context.Background(), ts, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("runTurn failed: %v", err)
|
|
}
|
|
// Should complete due to max iterations
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_HardAbort(t *testing.T) {
|
|
// Provider simulates a slow response, but we'll abort mid-turn
|
|
slowProvider := &slowMockProvider{delay: 10 * time.Second}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, slowProvider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-abort")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-abort",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
// Run in goroutine with abort after short delay
|
|
done := make(chan struct{})
|
|
|
|
go func() {
|
|
al.runTurn(context.Background(), ts, pipeline)
|
|
close(done)
|
|
}()
|
|
|
|
// Give it a moment to start
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Request hard abort
|
|
ts.requestHardAbort()
|
|
|
|
// Wait for runTurn to complete
|
|
select {
|
|
case <-done:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("runTurn did not complete after abort")
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_SteeringMessageInjection(t *testing.T) {
|
|
provider := &simpleConvProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-steering")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-steering",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
// Enqueue steering message before runTurn
|
|
steeringMsg := providers.Message{
|
|
Role: "user",
|
|
Content: "Steering message",
|
|
}
|
|
al.Steer(steeringMsg)
|
|
|
|
result, err := al.runTurn(context.Background(), ts, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("runTurn failed: %v", err)
|
|
}
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
// Steering message should have been injected
|
|
}
|
|
|
|
func TestRunTurn_GracefulInterrupt(t *testing.T) {
|
|
provider := &toolCallRespProvider{
|
|
toolName: "search",
|
|
toolArgs: map[string]any{"q": "test"},
|
|
response: "Final response after interrupt",
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-graceful")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-graceful",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
// Run in goroutine with graceful interrupt after first iteration
|
|
done := make(chan struct{})
|
|
var result turnResult
|
|
|
|
go func() {
|
|
result, _ = al.runTurn(context.Background(), ts, pipeline)
|
|
close(done)
|
|
}()
|
|
|
|
// Give it a moment to start first iteration
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Request graceful interrupt
|
|
ts.requestGracefulInterrupt("Please stop")
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("runTurn did not complete after graceful interrupt")
|
|
}
|
|
|
|
// Should complete gracefully
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// turnState Tests
|
|
// =============================================================================
|
|
|
|
func TestTurnState_GracefulInterruptRequested(t *testing.T) {
|
|
ts := &turnState{
|
|
gracefulInterrupt: false,
|
|
gracefulInterruptHint: "",
|
|
}
|
|
|
|
// Initially should not be requested
|
|
requested, _ := ts.gracefulInterruptRequested()
|
|
if requested {
|
|
t.Error("expected no interrupt initially")
|
|
}
|
|
|
|
// Request interrupt
|
|
ts.requestGracefulInterrupt("test hint")
|
|
|
|
requested, hint := ts.gracefulInterruptRequested()
|
|
if !requested {
|
|
t.Error("expected interrupt to be requested")
|
|
}
|
|
if hint != "test hint" {
|
|
t.Errorf("expected hint 'test hint', got %q", hint)
|
|
}
|
|
}
|
|
|
|
func TestTurnState_HardAbortRequested(t *testing.T) {
|
|
ts := &turnState{
|
|
hardAbort: false,
|
|
}
|
|
|
|
if ts.hardAbortRequested() {
|
|
t.Error("expected no hard abort initially")
|
|
}
|
|
|
|
ts.requestHardAbort()
|
|
|
|
if !ts.hardAbortRequested() {
|
|
t.Error("expected hard abort to be requested")
|
|
}
|
|
}
|
|
|
|
func TestTurnState_SkillContextSnapshotsTrackLatestSuccessfulPath(t *testing.T) {
|
|
ts := &turnState{}
|
|
|
|
ts.recordSkillContextSnapshot(skillContextTriggerInitialBuild, []string{"skill-a"})
|
|
ts.recordSkillContextSnapshot(skillContextTriggerContextRetryRebuild, []string{"skill-b", "skill-c"})
|
|
|
|
if got := ts.attemptedSkillsSnapshot(); len(got) != 3 || got[0] != "skill-a" || got[1] != "skill-b" ||
|
|
got[2] != "skill-c" {
|
|
t.Fatalf("attemptedSkillsSnapshot = %v, want [skill-a skill-b skill-c]", got)
|
|
}
|
|
|
|
if got := ts.latestSkillContextSnapshot(); len(got) != 2 || got[0] != "skill-b" || got[1] != "skill-c" {
|
|
t.Fatalf("latestSkillContextSnapshot = %v, want [skill-b skill-c]", got)
|
|
}
|
|
|
|
snapshots := ts.skillContextSnapshotsSnapshot()
|
|
if len(snapshots) != 2 {
|
|
t.Fatalf("len(skillContextSnapshotsSnapshot()) = %d, want 2", len(snapshots))
|
|
}
|
|
if snapshots[0].Sequence != 1 || snapshots[0].Trigger != skillContextTriggerInitialBuild {
|
|
t.Fatalf("snapshots[0] = %+v, want sequence=1 trigger=%q", snapshots[0], skillContextTriggerInitialBuild)
|
|
}
|
|
if snapshots[1].Sequence != 2 || snapshots[1].Trigger != skillContextTriggerContextRetryRebuild {
|
|
t.Fatalf("snapshots[1] = %+v, want sequence=2 trigger=%q", snapshots[1], skillContextTriggerContextRetryRebuild)
|
|
}
|
|
}
|