Files
picoclaw/pkg/agent/turn_coord_test.go
T
sky5454 329e68e017 refactor(agent): Agent Looper refactor phase2, restructure pipeline and rename loop files to agent (#2585)
* refactor(agent): introduce interfaces for MessageBus and ChannelManager

Phase 2 of loop.go refactor — dependency inversion using adapter pattern.

- Add interfaces.MessageBus and interfaces.ChannelManager interfaces
- Create adapters/messagebus.go wrapping *bus.MessageBus
- Create adapters/channelmanager.go wrapping *channels.Manager
- Update AgentLoop to use interfaces instead of concrete types
- Update registerSharedTools to accept interfaces.MessageBus

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor(agent): restructure pipeline and rename loop files

Pipeline refactoring:
- Split pipeline.go (1400 lines) into focused files:
  - pipeline_setup.go (~115 lines): SetupTurn method
  - pipeline_llm.go (~519 lines): CallLLM method
  - pipeline_execute.go (~693 lines): ExecuteTools method
  - pipeline_finalize.go (~78 lines): Finalize method
- Pipeline struct and NewPipeline remain in pipeline.go (~39 lines)

Agent file renaming:
- Rename loop_*.go to agent_*.go for consistent naming:
  - loop.go -> agent.go, loop_message.go -> agent_message.go, etc.
- Merge turn.go + turn_exec.go into turn_state.go
- Rename loop_turn.go -> turn_coord.go

Documentation:
- Update docs/pipeline-restructuring-plan.md
- Add docs/agent-rename-plan.md

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(agent): fix code formatting

* refactor(agent): add and rename test files

* docs(agent): update agent refactor docs

* fix(agent): fix agent hardAbortX

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-21 10:55:50 +08:00

552 lines
14 KiB
Go

package agent
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
)
// =============================================================================
// Mock Providers for turn_coord Tests
// =============================================================================
// simpleConvProvider is a stub LLM provider that ignores its inputs and always
// answers with the same plain-text greeting. It never requests a tool call.
// Every response carries FinishReason "stop".
type simpleConvProvider struct{}

// Chat returns the canned greeting regardless of the conversation, tools,
// model or options it is given.
func (p *simpleConvProvider) Chat(
	_ context.Context,
	_ []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	resp := &providers.LLMResponse{FinishReason: "stop"}
	resp.Content = "Hello! How can I help you today?"
	return resp, nil
}

// GetDefaultModel reports this stub's fixed model identifier.
func (p *simpleConvProvider) GetDefaultModel() string {
	const model = "simple-model"
	return model
}
// toolCallRespProvider is a stub LLM provider that performs exactly one tool
// round-trip.
//
// The first Chat call returns a single tool call (ID "call_1") built from
// toolName and toolArgs, with FinishReason "tool_calls". Every later call
// returns response as plain text with FinishReason "stop".
//
// callCount records how many times Chat has run. It is guarded by mu so that
// concurrent Chat calls stay consistent.
type toolCallRespProvider struct {
	toolName  string
	toolArgs  map[string]any
	response  string
	callCount int
	mu        sync.Mutex
}

// Chat returns the tool call on its first invocation and the final text on
// every invocation after that.
func (p *toolCallRespProvider) Chat(
	_ context.Context,
	_ []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.callCount++
	n := p.callCount
	p.mu.Unlock()

	if n != 1 {
		// The tool round-trip already happened, so finish with the final text.
		return &providers.LLMResponse{
			Content:      p.response,
			FinishReason: "stop",
		}, nil
	}

	call := providers.ToolCall{
		ID:        "call_1",
		Name:      p.toolName,
		Arguments: p.toolArgs,
	}
	return &providers.LLMResponse{
		Content:      "Let me search for that information.",
		ToolCalls:    []providers.ToolCall{call},
		FinishReason: "tool_calls",
	}, nil
}

// GetDefaultModel reports this stub's fixed model identifier.
func (p *toolCallRespProvider) GetDefaultModel() string {
	const model = "tool-model"
	return model
}
// errorProvider is a stub LLM provider whose Chat always fails. errType
// selects the error that is returned:
//
//   - "timeout":        context.DeadlineExceeded
//   - "context_length": an error with message "context_length_exceeded"
//   - "vision":         an error with message "vision_unsupported"
//   - anything else:    an error with message "unknown error"
//
// callCount counts Chat invocations and is guarded by mu.
type errorProvider struct {
	errType   string
	callCount int
	mu        sync.Mutex
}

// Chat records the invocation and then fails with the error selected by
// errType. It never returns a response.
func (p *errorProvider) Chat(
	_ context.Context,
	_ []providers.Message,
	_ []providers.ToolDefinition,
	_ string,
	_ map[string]any,
) (*providers.LLMResponse, error) {
	func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		p.callCount++
	}()

	var msg string
	switch p.errType {
	case "timeout":
		// Return the sentinel itself so callers can match it with errors.Is.
		return nil, context.DeadlineExceeded
	case "context_length":
		msg = "context_length_exceeded"
	case "vision":
		msg = "vision_unsupported"
	default:
		msg = "unknown error"
	}
	return nil, errors.New(msg)
}

// GetDefaultModel reports this stub's fixed model identifier.
func (p *errorProvider) GetDefaultModel() string {
	const model = "error-model"
	return model
}
// =============================================================================
// Test Helper Functions
// =============================================================================
// newTurnCoordTestLoop builds an AgentLoop backed by provider, with a fresh
// temporary workspace and a new message bus. It returns the loop, the loop's
// default agent, and a func that closes the loop; callers should defer that
// func. The test fails immediately if the registry has no default agent.
func newTurnCoordTestLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, *AgentInstance, func()) {
	t.Helper()

	defaults := config.AgentDefaults{
		Workspace:         t.TempDir(),
		ModelName:         "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
	}

	loop := NewAgentLoop(cfg, bus.NewMessageBus(), provider)

	defaultAgent := loop.registry.GetDefaultAgent()
	if defaultAgent == nil {
		t.Fatal("expected default agent")
	}

	closeLoop := func() {
		loop.Close()
	}
	return loop, defaultAgent, closeLoop
}
// makeTestProcessOpts returns processOptions for a CLI turn in chat
// "test-chat" under sessionKey. It uses a fixed user message and a fixed
// fallback response. EnableSummary, SendResponse and NoHistory are all left at
// false.
func makeTestProcessOpts(sessionKey string) processOptions {
	var opts processOptions
	opts.SessionKey = sessionKey
	opts.Channel = "cli"
	opts.ChatID = "test-chat"
	opts.UserMessage = "test message"
	opts.DefaultResponse = "I couldn't process your request."
	// EnableSummary, SendResponse and NoHistory keep their zero value (false).
	return opts
}
// =============================================================================
// Pipeline Method Tests: SetupTurn
// =============================================================================
// TestPipeline_SetupTurn_BasicInitialization checks that SetupTurn on a fresh
// turn returns an execution whose message list is populated and whose
// iteration counter starts at 0.
func TestPipeline_SetupTurn_BasicInitialization(t *testing.T) {
	loop, ag, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	p := NewPipeline(loop)
	state := newTurnState(ag, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})

	exec, err := p.SetupTurn(context.Background(), state)
	switch {
	case err != nil:
		t.Fatalf("SetupTurn failed: %v", err)
	case exec == nil:
		t.Fatal("expected non-nil turnExecution")
	}

	if n := len(exec.messages); n == 0 {
		t.Error("expected messages to be populated")
	}
	if got := exec.iteration; got != 0 {
		t.Errorf("expected iteration 0, got %d", got)
	}
}
// =============================================================================
// Pipeline Method Tests: CallLLM
// =============================================================================
// TestPipeline_CallLLM_SimpleResponse runs a single CallLLM iteration against
// a provider that answers in plain text. It expects ControlBreak and a
// non-empty response recorded on the execution.
func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
	loop, ag, closeLoop := newTurnCoordTestLoop(t, &simpleConvProvider{})
	defer closeLoop()

	p := NewPipeline(loop)
	state := newTurnState(ag, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})

	bg := context.Background()
	exec, err := p.SetupTurn(bg, state)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	// Both context arguments are the background context; their distinct roles
	// are defined by CallLLM (not visible here).
	ctrl, err := p.CallLLM(bg, bg, state, exec, 1)
	if err != nil {
		t.Fatalf("CallLLM failed: %v", err)
	}
	if ctrl != ControlBreak {
		t.Errorf("expected ControlBreak, got %v", ctrl)
	}

	switch resp := exec.response; {
	case resp == nil:
		t.Fatal("expected non-nil response")
	case resp.Content == "":
		t.Error("expected non-empty content")
	}
}
// TestPipeline_CallLLM_WithToolCall runs the first CallLLM iteration against a
// provider whose first reply is a "web_search" tool call. It expects
// ControlToolLoop, and expects that tool call to appear first among the
// execution's normalized tool calls.
func TestPipeline_CallLLM_WithToolCall(t *testing.T) {
	stub := &toolCallRespProvider{
		toolName: "web_search",
		toolArgs: map[string]any{"query": "test"},
		response: "Found information about test.",
	}
	loop, ag, closeLoop := newTurnCoordTestLoop(t, stub)
	defer closeLoop()

	p := NewPipeline(loop)
	state := newTurnState(ag, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})

	bg := context.Background()
	exec, err := p.SetupTurn(bg, state)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	ctrl, err := p.CallLLM(bg, bg, state, exec, 1)
	if err != nil {
		t.Fatalf("CallLLM failed: %v", err)
	}
	if ctrl != ControlToolLoop {
		t.Errorf("expected ControlToolLoop, got %v", ctrl)
	}

	calls := exec.normalizedToolCalls
	if len(calls) == 0 {
		t.Fatal("expected tool calls")
	}
	if name := calls[0].Name; name != "web_search" {
		t.Errorf("expected tool name 'web_search', got %q", name)
	}
}
// TestPipeline_CallLLM_TimeoutRetry drives CallLLM with a provider that always
// fails with context.DeadlineExceeded. CallLLM must eventually return an
// error. Whether and how often it retries first is not asserted here.
func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) {
	stub := &errorProvider{errType: "timeout"}
	loop, ag, closeLoop := newTurnCoordTestLoop(t, stub)
	defer closeLoop()

	p := NewPipeline(loop)
	state := newTurnState(ag, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})

	bg := context.Background()
	exec, err := p.SetupTurn(bg, state)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	// The provider never succeeds, so once CallLLM gives up it must surface an
	// error.
	if _, err := p.CallLLM(bg, bg, state, exec, 1); err == nil {
		t.Error("expected error after retries")
	}
}
// TestPipeline_CallLLM_ContextLengthError drives CallLLM with a provider that
// always fails with "context_length_exceeded".
//
// CallLLM may react by compressing context and retrying. Because the outcome
// (success or error) is not pinned, the result is only logged. The test does
// require that the provider was actually invoked, so the test cannot pass
// vacuously when CallLLM never reaches the LLM.
func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
	errorPrv := &errorProvider{errType: "context_length"}
	al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
	defer cleanup()

	pipeline := NewPipeline(al)
	ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
		turnID:  "turn-1",
		context: newTurnContext(nil, nil, nil),
	})

	exec, err := pipeline.SetupTurn(context.Background(), ts)
	if err != nil {
		t.Fatalf("SetupTurn failed: %v", err)
	}

	// May succeed after compression or fail; either is acceptable here.
	_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
	t.Logf("CallLLM result after context error: err=%v", err)

	// Read under the provider's mutex, matching how Chat updates it.
	errorPrv.mu.Lock()
	calls := errorPrv.callCount
	errorPrv.mu.Unlock()
	if calls == 0 {
		t.Error("expected CallLLM to invoke the provider at least once")
	}
}
// =============================================================================
// Pipeline Method Tests: ExecuteTools
// =============================================================================
func TestPipeline_ExecuteTools_NoTools(t *testing.T) {
// Provider returns no tool calls, so ExecuteTools should not be called
// This test verifies the ControlBreak path from CallLLM
provider := &simpleConvProvider{}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
pipeline := NewPipeline(al)
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
turnID: "turn-1",
context: newTurnContext(nil, nil, nil),
})
exec, err := pipeline.SetupTurn(context.Background(), ts)
if err != nil {
t.Fatalf("SetupTurn failed: %v", err)
}
// First CallLLM returns ControlBreak (no tools)
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
if err != nil {
t.Fatalf("CallLLM failed: %v", err)
}
if ctrl != ControlBreak {
t.Fatalf("expected ControlBreak, got %v", ctrl)
}
// No tools to execute, Finalize should be called directly
}
// =============================================================================
// runTurn Integration Tests
// =============================================================================
func TestRunTurn_SimpleConversation(t *testing.T) {
provider := &simpleConvProvider{}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session-simple")
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-simple",
context: newTurnContext(nil, nil, nil),
})
result, err := al.runTurn(context.Background(), ts, pipeline)
if err != nil {
t.Fatalf("runTurn failed: %v", err)
}
if result.status != TurnEndStatusCompleted {
t.Errorf("expected status Completed, got %v", result.status)
}
if result.finalContent == "" {
t.Error("expected non-empty finalContent")
}
}
func TestRunTurn_MaxIterations(t *testing.T) {
// Provider always returns tool calls, should hit max iterations
provider := &toolCallRespProvider{
toolName: "search",
toolArgs: map[string]any{"q": "x"},
response: "done",
}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
// Override max iterations to 2
agent.MaxIterations = 2
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session-maxiter")
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-maxiter",
context: newTurnContext(nil, nil, nil),
})
result, err := al.runTurn(context.Background(), ts, pipeline)
if err != nil {
t.Fatalf("runTurn failed: %v", err)
}
// Should complete due to max iterations
if result.status != TurnEndStatusCompleted {
t.Errorf("expected status Completed, got %v", result.status)
}
}
func TestRunTurn_HardAbort(t *testing.T) {
// Provider simulates a slow response, but we'll abort mid-turn
slowProvider := &slowMockProvider{delay: 10 * time.Second}
al, agent, cleanup := newTurnCoordTestLoop(t, slowProvider)
defer cleanup()
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session-abort")
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-abort",
context: newTurnContext(nil, nil, nil),
})
// Run in goroutine with abort after short delay
done := make(chan struct{})
go func() {
al.runTurn(context.Background(), ts, pipeline)
close(done)
}()
// Give it a moment to start
time.Sleep(50 * time.Millisecond)
// Request hard abort
ts.requestHardAbort()
// Wait for runTurn to complete
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatal("runTurn did not complete after abort")
}
}
func TestRunTurn_SteeringMessageInjection(t *testing.T) {
provider := &simpleConvProvider{}
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
defer cleanup()
pipeline := NewPipeline(al)
opts := makeTestProcessOpts("test-session-steering")
ts := newTurnState(agent, opts, turnEventScope{
turnID: "turn-steering",
context: newTurnContext(nil, nil, nil),
})
// Enqueue steering message before runTurn
steeringMsg := providers.Message{
Role: "user",
Content: "Steering message",
}
al.Steer(steeringMsg)
result, err := al.runTurn(context.Background(), ts, pipeline)
if err != nil {
t.Fatalf("runTurn failed: %v", err)
}
if result.status != TurnEndStatusCompleted {
t.Errorf("expected status Completed, got %v", result.status)
}
// Steering message should have been injected
}
// TestRunTurn_GracefulInterrupt starts a tool-using turn in the background and
// requests a graceful interrupt (hint "Please stop") about 50ms later. It
// expects runTurn to return within 5s with TurnEndStatusCompleted.
//
// The stub provider answers immediately, so the turn may already have
// finished before the interrupt is requested. The assertions hold either way.
func TestRunTurn_GracefulInterrupt(t *testing.T) {
	stub := &toolCallRespProvider{
		toolName: "search",
		toolArgs: map[string]any{"q": "test"},
		response: "Final response after interrupt",
	}
	loop, ag, closeLoop := newTurnCoordTestLoop(t, stub)
	defer closeLoop()

	p := NewPipeline(loop)
	state := newTurnState(ag, makeTestProcessOpts("test-session-graceful"), turnEventScope{
		turnID:  "turn-graceful",
		context: newTurnContext(nil, nil, nil),
	})

	// res is written only by the goroutine. Closing finished orders that write
	// before the read below.
	var res turnResult
	finished := make(chan struct{})
	go func() {
		res, _ = loop.runTurn(context.Background(), state, p)
		close(finished)
	}()

	time.Sleep(50 * time.Millisecond)
	state.requestGracefulInterrupt("Please stop")

	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()
	select {
	case <-finished:
	case <-deadline.C:
		t.Fatal("runTurn did not complete after graceful interrupt")
	}

	if got := res.status; got != TurnEndStatusCompleted {
		t.Errorf("expected status Completed, got %v", got)
	}
}
// =============================================================================
// turnState Tests
// =============================================================================
// TestTurnState_GracefulInterruptRequested checks that a zero-value turnState
// reports no graceful interrupt. After requestGracefulInterrupt it must report
// the request together with the supplied hint.
func TestTurnState_GracefulInterruptRequested(t *testing.T) {
	// The zero value leaves gracefulInterrupt false and gracefulInterruptHint empty.
	state := &turnState{}

	if pending, _ := state.gracefulInterruptRequested(); pending {
		t.Error("expected no interrupt initially")
	}

	state.requestGracefulInterrupt("test hint")

	pending, hint := state.gracefulInterruptRequested()
	if !pending {
		t.Error("expected interrupt to be requested")
	}
	if want := "test hint"; hint != want {
		t.Errorf("expected hint 'test hint', got %q", hint)
	}
}
func TestTurnState_HardAbortRequested(t *testing.T) {
ts := &turnState{
hardAbort: false,
}
if ts.hardAbortRequested() {
t.Error("expected no hard abort initially")
}
ts.requestHardAbort()
if !ts.hardAbortRequested() {
t.Error("expected hard abort to be requested")
}
}