mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
cac4f21746
- centralize web search provider readiness and resolution logic - fall back when the configured provider is unavailable or invalid - allow native-search-capable models to use built-in search without the client tool - simplify the tools page and add direct access to web search settings - add backend, agent, and integration tests for the new selection behavior
616 lines
16 KiB
Go
616 lines
16 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Mock Providers for turn_coord Tests
|
|
// =============================================================================
|
|
|
|
// simpleConvProvider returns a simple text response without tools
|
|
type simpleConvProvider struct{}
|
|
|
|
func (p *simpleConvProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
return &providers.LLMResponse{
|
|
Content: "Hello! How can I help you today?",
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *simpleConvProvider) GetDefaultModel() string {
|
|
return "simple-model"
|
|
}
|
|
|
|
type nativeSearchCaptureProvider struct {
|
|
lastOpts map[string]any
|
|
}
|
|
|
|
func (p *nativeSearchCaptureProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.lastOpts = make(map[string]any, len(opts))
|
|
for k, v := range opts {
|
|
p.lastOpts[k] = v
|
|
}
|
|
return &providers.LLMResponse{
|
|
Content: "Using native search",
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *nativeSearchCaptureProvider) GetDefaultModel() string {
|
|
return "native-search-model"
|
|
}
|
|
|
|
func (p *nativeSearchCaptureProvider) SupportsNativeSearch() bool {
|
|
return true
|
|
}
|
|
|
|
// toolCallRespProvider returns a tool call response
|
|
type toolCallRespProvider struct {
|
|
toolName string
|
|
toolArgs map[string]any
|
|
response string
|
|
callCount int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (p *toolCallRespProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.callCount++
|
|
count := p.callCount
|
|
p.mu.Unlock()
|
|
|
|
// First call returns a tool call, subsequent calls return final response
|
|
if count == 1 {
|
|
return &providers.LLMResponse{
|
|
Content: "Let me search for that information.",
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Name: p.toolName,
|
|
Arguments: p.toolArgs,
|
|
},
|
|
},
|
|
FinishReason: "tool_calls",
|
|
}, nil
|
|
}
|
|
return &providers.LLMResponse{
|
|
Content: p.response,
|
|
FinishReason: "stop",
|
|
}, nil
|
|
}
|
|
|
|
func (p *toolCallRespProvider) GetDefaultModel() string {
|
|
return "tool-model"
|
|
}
|
|
|
|
// errorProvider simulates various error conditions
|
|
type errorProvider struct {
|
|
errType string
|
|
callCount int
|
|
mu sync.Mutex
|
|
}
|
|
|
|
func (p *errorProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.callCount++
|
|
p.mu.Unlock()
|
|
|
|
switch p.errType {
|
|
case "timeout":
|
|
return nil, context.DeadlineExceeded
|
|
case "context_length":
|
|
return nil, errors.New("context_length_exceeded")
|
|
case "vision":
|
|
return nil, errors.New("vision_unsupported")
|
|
default:
|
|
return nil, errors.New("unknown error")
|
|
}
|
|
}
|
|
|
|
func (p *errorProvider) GetDefaultModel() string {
|
|
return "error-model"
|
|
}
|
|
|
|
// =============================================================================
|
|
// Test Helper Functions
|
|
// =============================================================================
|
|
|
|
func newTurnCoordTestLoop(t *testing.T, provider providers.LLMProvider) (*AgentLoop, *AgentInstance, func()) {
|
|
t.Helper()
|
|
tmpDir := t.TempDir()
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
return al, agent, func() {
|
|
al.Close()
|
|
}
|
|
}
|
|
|
|
func makeTestProcessOpts(sessionKey string) processOptions {
|
|
return processOptions{
|
|
SessionKey: sessionKey,
|
|
Channel: "cli",
|
|
ChatID: "test-chat",
|
|
UserMessage: "test message",
|
|
DefaultResponse: "I couldn't process your request.",
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
NoHistory: false,
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Pipeline Method Tests: SetupTurn
|
|
// =============================================================================
|
|
|
|
func TestPipeline_SetupTurn_BasicInitialization(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
if exec == nil {
|
|
t.Fatal("expected non-nil turnExecution")
|
|
}
|
|
if len(exec.messages) == 0 {
|
|
t.Error("expected messages to be populated")
|
|
}
|
|
if exec.iteration != 0 {
|
|
t.Errorf("expected iteration 0, got %d", exec.iteration)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Pipeline Method Tests: CallLLM
|
|
// =============================================================================
|
|
|
|
func TestPipeline_CallLLM_SimpleResponse(t *testing.T) {
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, &simpleConvProvider{})
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlBreak {
|
|
t.Errorf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
if exec.response == nil {
|
|
t.Fatal("expected non-nil response")
|
|
}
|
|
if exec.response.Content == "" {
|
|
t.Error("expected non-empty content")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_WithToolCall(t *testing.T) {
|
|
provider := &toolCallRespProvider{
|
|
toolName: "web_search",
|
|
toolArgs: map[string]any{"query": "test"},
|
|
response: "Found information about test.",
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlToolLoop {
|
|
t.Errorf("expected ControlToolLoop, got %v", ctrl)
|
|
}
|
|
if len(exec.normalizedToolCalls) == 0 {
|
|
t.Fatal("expected tool calls")
|
|
}
|
|
if exec.normalizedToolCalls[0].Name != "web_search" {
|
|
t.Errorf("expected tool name 'web_search', got %q", exec.normalizedToolCalls[0].Name)
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_UsesNativeSearchWithoutClientWebSearchTool(t *testing.T) {
|
|
provider := &nativeSearchCaptureProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
if _, ok := agent.Tools.Get("web_search"); ok {
|
|
t.Fatal("expected no client-side web_search tool to be registered")
|
|
}
|
|
|
|
al.cfg.Tools.Web.Enabled = true
|
|
al.cfg.Tools.Web.PreferNative = true
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
if ctrl != ControlBreak {
|
|
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
if got, _ := provider.lastOpts["native_search"].(bool); !got {
|
|
t.Fatalf("expected native_search=true, got %#v", provider.lastOpts["native_search"])
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_TimeoutRetry(t *testing.T) {
|
|
errorPrv := &errorProvider{errType: "timeout"}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
// Should retry and eventually fail after max retries
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err == nil {
|
|
t.Error("expected error after retries")
|
|
}
|
|
}
|
|
|
|
func TestPipeline_CallLLM_ContextLengthError(t *testing.T) {
|
|
errorPrv := &errorProvider{errType: "context_length"}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, errorPrv)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
// Should trigger context compression and retry
|
|
_, err = pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
// May succeed after compression or fail - either is acceptable
|
|
t.Logf("CallLLM result after context error: err=%v", err)
|
|
}
|
|
|
|
// =============================================================================
|
|
// Pipeline Method Tests: ExecuteTools
|
|
// =============================================================================
|
|
|
|
func TestPipeline_ExecuteTools_NoTools(t *testing.T) {
|
|
// Provider returns no tool calls, so ExecuteTools should not be called
|
|
// This test verifies the ControlBreak path from CallLLM
|
|
provider := &simpleConvProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
ts := newTurnState(agent, makeTestProcessOpts("test-session"), turnEventScope{
|
|
turnID: "turn-1",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
exec, err := pipeline.SetupTurn(context.Background(), ts)
|
|
if err != nil {
|
|
t.Fatalf("SetupTurn failed: %v", err)
|
|
}
|
|
|
|
// First CallLLM returns ControlBreak (no tools)
|
|
ctrl, err := pipeline.CallLLM(context.Background(), context.Background(), ts, exec, 1)
|
|
if err != nil {
|
|
t.Fatalf("CallLLM failed: %v", err)
|
|
}
|
|
|
|
if ctrl != ControlBreak {
|
|
t.Fatalf("expected ControlBreak, got %v", ctrl)
|
|
}
|
|
// No tools to execute, Finalize should be called directly
|
|
}
|
|
|
|
// =============================================================================
|
|
// runTurn Integration Tests
|
|
// =============================================================================
|
|
|
|
func TestRunTurn_SimpleConversation(t *testing.T) {
|
|
provider := &simpleConvProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-simple")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-simple",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
result, err := al.runTurn(context.Background(), ts, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("runTurn failed: %v", err)
|
|
}
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
if result.finalContent == "" {
|
|
t.Error("expected non-empty finalContent")
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_MaxIterations(t *testing.T) {
|
|
// Provider always returns tool calls, should hit max iterations
|
|
provider := &toolCallRespProvider{
|
|
toolName: "search",
|
|
toolArgs: map[string]any{"q": "x"},
|
|
response: "done",
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
// Override max iterations to 2
|
|
agent.MaxIterations = 2
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-maxiter")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-maxiter",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
result, err := al.runTurn(context.Background(), ts, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("runTurn failed: %v", err)
|
|
}
|
|
// Should complete due to max iterations
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_HardAbort(t *testing.T) {
|
|
// Provider simulates a slow response, but we'll abort mid-turn
|
|
slowProvider := &slowMockProvider{delay: 10 * time.Second}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, slowProvider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-abort")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-abort",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
// Run in goroutine with abort after short delay
|
|
done := make(chan struct{})
|
|
|
|
go func() {
|
|
al.runTurn(context.Background(), ts, pipeline)
|
|
close(done)
|
|
}()
|
|
|
|
// Give it a moment to start
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Request hard abort
|
|
ts.requestHardAbort()
|
|
|
|
// Wait for runTurn to complete
|
|
select {
|
|
case <-done:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("runTurn did not complete after abort")
|
|
}
|
|
}
|
|
|
|
func TestRunTurn_SteeringMessageInjection(t *testing.T) {
|
|
provider := &simpleConvProvider{}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-steering")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-steering",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
// Enqueue steering message before runTurn
|
|
steeringMsg := providers.Message{
|
|
Role: "user",
|
|
Content: "Steering message",
|
|
}
|
|
al.Steer(steeringMsg)
|
|
|
|
result, err := al.runTurn(context.Background(), ts, pipeline)
|
|
if err != nil {
|
|
t.Fatalf("runTurn failed: %v", err)
|
|
}
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
// Steering message should have been injected
|
|
}
|
|
|
|
func TestRunTurn_GracefulInterrupt(t *testing.T) {
|
|
provider := &toolCallRespProvider{
|
|
toolName: "search",
|
|
toolArgs: map[string]any{"q": "test"},
|
|
response: "Final response after interrupt",
|
|
}
|
|
al, agent, cleanup := newTurnCoordTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
pipeline := NewPipeline(al)
|
|
opts := makeTestProcessOpts("test-session-graceful")
|
|
|
|
ts := newTurnState(agent, opts, turnEventScope{
|
|
turnID: "turn-graceful",
|
|
context: newTurnContext(nil, nil, nil),
|
|
})
|
|
|
|
// Run in goroutine with graceful interrupt after first iteration
|
|
done := make(chan struct{})
|
|
var result turnResult
|
|
|
|
go func() {
|
|
result, _ = al.runTurn(context.Background(), ts, pipeline)
|
|
close(done)
|
|
}()
|
|
|
|
// Give it a moment to start first iteration
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
// Request graceful interrupt
|
|
ts.requestGracefulInterrupt("Please stop")
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("runTurn did not complete after graceful interrupt")
|
|
}
|
|
|
|
// Should complete gracefully
|
|
if result.status != TurnEndStatusCompleted {
|
|
t.Errorf("expected status Completed, got %v", result.status)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// turnState Tests
|
|
// =============================================================================
|
|
|
|
func TestTurnState_GracefulInterruptRequested(t *testing.T) {
|
|
ts := &turnState{
|
|
gracefulInterrupt: false,
|
|
gracefulInterruptHint: "",
|
|
}
|
|
|
|
// Initially should not be requested
|
|
requested, _ := ts.gracefulInterruptRequested()
|
|
if requested {
|
|
t.Error("expected no interrupt initially")
|
|
}
|
|
|
|
// Request interrupt
|
|
ts.requestGracefulInterrupt("test hint")
|
|
|
|
requested, hint := ts.gracefulInterruptRequested()
|
|
if !requested {
|
|
t.Error("expected interrupt to be requested")
|
|
}
|
|
if hint != "test hint" {
|
|
t.Errorf("expected hint 'test hint', got %q", hint)
|
|
}
|
|
}
|
|
|
|
func TestTurnState_HardAbortRequested(t *testing.T) {
|
|
ts := &turnState{
|
|
hardAbort: false,
|
|
}
|
|
|
|
if ts.hardAbortRequested() {
|
|
t.Error("expected no hard abort initially")
|
|
}
|
|
|
|
ts.requestHardAbort()
|
|
|
|
if !ts.hardAbortRequested() {
|
|
t.Error("expected hard abort to be requested")
|
|
}
|
|
}
|