mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
0e075f7300
Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary.
685 lines
18 KiB
Go
685 lines
18 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"slices"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
)
|
|
|
|
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
|
|
eventBus := NewEventBus()
|
|
sub := eventBus.Subscribe(1)
|
|
|
|
eventBus.Emit(Event{
|
|
Kind: EventKindTurnStart,
|
|
Meta: EventMeta{TurnID: "turn-1"},
|
|
})
|
|
|
|
select {
|
|
case evt := <-sub.C:
|
|
if evt.Kind != EventKindTurnStart {
|
|
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
|
|
}
|
|
if evt.Meta.TurnID != "turn-1" {
|
|
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
|
|
}
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for event")
|
|
}
|
|
|
|
eventBus.Unsubscribe(sub.ID)
|
|
if _, ok := <-sub.C; ok {
|
|
t.Fatal("expected subscriber channel to be closed after unsubscribe")
|
|
}
|
|
|
|
eventBus.Close()
|
|
closedSub := eventBus.Subscribe(1)
|
|
if _, ok := <-closedSub.C; ok {
|
|
t.Fatal("expected closed bus to return a closed subscriber channel")
|
|
}
|
|
}
|
|
|
|
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
|
|
eventBus := NewEventBus()
|
|
sub := eventBus.Subscribe(1)
|
|
defer eventBus.Unsubscribe(sub.ID)
|
|
|
|
start := time.Now()
|
|
for i := 0; i < 1000; i++ {
|
|
eventBus.Emit(Event{Kind: EventKindLLMRequest})
|
|
}
|
|
|
|
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
|
|
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
|
|
}
|
|
|
|
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
|
|
t.Fatalf("expected 999 dropped events, got %d", got)
|
|
}
|
|
}
|
|
|
|
type scriptedToolProvider struct {
|
|
calls int
|
|
}
|
|
|
|
func (m *scriptedToolProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
toolDefs []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
m.calls++
|
|
if m.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Name: "mock_custom",
|
|
Arguments: map[string]any{"task": "ping"},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "done",
|
|
}, nil
|
|
}
|
|
|
|
func (m *scriptedToolProvider) GetDefaultModel() string {
|
|
return "scripted-tool-model"
|
|
}
|
|
|
|
func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-eventbus-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
provider := &scriptedToolProvider{}
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
al.RegisterTool(&mockCustomTool{})
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if response != "done" {
|
|
t.Fatalf("expected final response 'done', got %q", response)
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
if len(events) != 8 {
|
|
t.Fatalf("expected 8 events, got %d", len(events))
|
|
}
|
|
|
|
kinds := make([]EventKind, 0, len(events))
|
|
for _, evt := range events {
|
|
kinds = append(kinds, evt.Kind)
|
|
}
|
|
|
|
expectedKinds := []EventKind{
|
|
EventKindTurnStart,
|
|
EventKindLLMRequest,
|
|
EventKindLLMResponse,
|
|
EventKindToolExecStart,
|
|
EventKindToolExecEnd,
|
|
EventKindLLMRequest,
|
|
EventKindLLMResponse,
|
|
EventKindTurnEnd,
|
|
}
|
|
if !slices.Equal(kinds, expectedKinds) {
|
|
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
|
|
}
|
|
|
|
turnID := events[0].Meta.TurnID
|
|
for i, evt := range events {
|
|
if evt.Meta.TurnID != turnID {
|
|
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
|
|
}
|
|
if evt.Meta.SessionKey != "session-1" {
|
|
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
|
|
}
|
|
}
|
|
|
|
startPayload, ok := events[0].Payload.(TurnStartPayload)
|
|
if !ok {
|
|
t.Fatalf("expected TurnStartPayload, got %T", events[0].Payload)
|
|
}
|
|
if startPayload.UserMessage != "run tool" {
|
|
t.Fatalf("expected user message 'run tool', got %q", startPayload.UserMessage)
|
|
}
|
|
|
|
toolStartPayload, ok := events[3].Payload.(ToolExecStartPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecStartPayload, got %T", events[3].Payload)
|
|
}
|
|
if toolStartPayload.Tool != "mock_custom" {
|
|
t.Fatalf("expected tool name mock_custom, got %q", toolStartPayload.Tool)
|
|
}
|
|
|
|
toolEndPayload, ok := events[4].Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", events[4].Payload)
|
|
}
|
|
if toolEndPayload.Tool != "mock_custom" {
|
|
t.Fatalf("expected tool end payload for mock_custom, got %q", toolEndPayload.Tool)
|
|
}
|
|
if toolEndPayload.IsError {
|
|
t.Fatal("expected mock_custom tool to succeed")
|
|
}
|
|
|
|
turnEndPayload, ok := events[len(events)-1].Payload.(TurnEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected TurnEndPayload, got %T", events[len(events)-1].Payload)
|
|
}
|
|
if turnEndPayload.Status != TurnEndStatusCompleted {
|
|
t.Fatalf("expected completed turn, got %q", turnEndPayload.Status)
|
|
}
|
|
if turnEndPayload.Iterations != 2 {
|
|
t.Fatalf("expected 2 iterations, got %d", turnEndPayload.Iterations)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-eventbus-steering-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
tool1ExecCh := make(chan struct{})
|
|
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
|
|
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
|
|
|
|
provider := &toolCallProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_1",
|
|
Type: "function",
|
|
Name: "tool_one",
|
|
Function: &providers.FunctionCall{
|
|
Name: "tool_one",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
{
|
|
ID: "call_2",
|
|
Type: "function",
|
|
Name: "tool_two",
|
|
Function: &providers.FunctionCall{
|
|
Name: "tool_two",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
},
|
|
finalResp: "steered response",
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
al.RegisterTool(tool1)
|
|
al.RegisterTool(tool2)
|
|
|
|
sub := al.SubscribeEvents(32)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
resultCh := make(chan string, 1)
|
|
go func() {
|
|
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
|
|
resultCh <- resp
|
|
}()
|
|
|
|
select {
|
|
case <-tool1ExecCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for tool_one to start")
|
|
}
|
|
|
|
if err := al.Steer(providers.Message{Role: "user", Content: "change course"}); err != nil {
|
|
t.Fatalf("Steer failed: %v", err)
|
|
}
|
|
|
|
select {
|
|
case resp := <-resultCh:
|
|
if resp != "steered response" {
|
|
t.Fatalf("expected steered response, got %q", resp)
|
|
}
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for steered response")
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
|
|
if !ok {
|
|
t.Fatal("expected steering injected event")
|
|
}
|
|
steeringPayload, ok := steeringEvt.Payload.(SteeringInjectedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected SteeringInjectedPayload, got %T", steeringEvt.Payload)
|
|
}
|
|
if steeringPayload.Count != 1 {
|
|
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
|
|
}
|
|
|
|
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
|
if !ok {
|
|
t.Fatal("expected skipped tool event")
|
|
}
|
|
skippedPayload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
|
}
|
|
if skippedPayload.Tool != "tool_two" {
|
|
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
|
|
}
|
|
|
|
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
|
|
if !ok {
|
|
t.Fatal("expected interrupt received event")
|
|
}
|
|
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
|
|
}
|
|
if interruptPayload.Role != "user" {
|
|
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
|
|
}
|
|
if interruptPayload.Kind != InterruptKindSteering {
|
|
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
|
|
}
|
|
if interruptPayload.ContentLen != len("change course") {
|
|
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-eventbus-compress-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
|
provider := &failFirstMockProvider{
|
|
failures: 1,
|
|
failError: contextErr,
|
|
successResp: "Recovered from context error",
|
|
}
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
|
|
{Role: "user", Content: "Old message 1"},
|
|
{Role: "assistant", Content: "Old response 1"},
|
|
{Role: "user", Content: "Old message 2"},
|
|
{Role: "assistant", Content: "Old response 2"},
|
|
{Role: "user", Content: "Trigger message"},
|
|
})
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "Trigger message",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "Recovered from context error" {
|
|
t.Fatalf("expected retry success, got %q", resp)
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
retryEvt, ok := findEvent(events, EventKindLLMRetry)
|
|
if !ok {
|
|
t.Fatal("expected llm retry event")
|
|
}
|
|
retryPayload, ok := retryEvt.Payload.(LLMRetryPayload)
|
|
if !ok {
|
|
t.Fatalf("expected LLMRetryPayload, got %T", retryEvt.Payload)
|
|
}
|
|
if retryPayload.Reason != "context_limit" {
|
|
t.Fatalf("expected context_limit retry reason, got %q", retryPayload.Reason)
|
|
}
|
|
if retryPayload.Attempt != 1 {
|
|
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
|
|
}
|
|
|
|
compressEvt, ok := findEvent(events, EventKindContextCompress)
|
|
if !ok {
|
|
t.Fatal("expected context compress event")
|
|
}
|
|
payload, ok := compressEvt.Payload.(ContextCompressPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
|
|
}
|
|
if payload.Reason != ContextCompressReasonRetry {
|
|
t.Fatalf("expected retry compress reason, got %q", payload.Reason)
|
|
}
|
|
if payload.DroppedMessages == 0 {
|
|
t.Fatal("expected dropped messages to be recorded")
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-eventbus-summary-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
ContextWindow: 8000,
|
|
SummarizeMessageThreshold: 2,
|
|
SummarizeTokenPercent: 75,
|
|
},
|
|
},
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary text"})
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
|
|
{Role: "user", Content: "Question one"},
|
|
{Role: "assistant", Content: "Answer one"},
|
|
{Role: "user", Content: "Question two"},
|
|
{Role: "assistant", Content: "Answer two"},
|
|
{Role: "user", Content: "Question three"},
|
|
{Role: "assistant", Content: "Answer three"},
|
|
})
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
|
|
al.summarizeSession(defaultAgent, "session-1", turnScope)
|
|
|
|
events := collectEventStream(sub.C)
|
|
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
|
|
if !ok {
|
|
t.Fatal("expected session summarize event")
|
|
}
|
|
payload, ok := summaryEvt.Payload.(SessionSummarizePayload)
|
|
if !ok {
|
|
t.Fatalf("expected SessionSummarizePayload, got %T", summaryEvt.Payload)
|
|
}
|
|
if payload.SummaryLen == 0 {
|
|
t.Fatal("expected non-empty summary length")
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
|
|
tmpDir, err := os.MkdirTemp("", "agent-eventbus-followup-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
Model: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
provider := &toolCallProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call_async_1",
|
|
Type: "function",
|
|
Name: "async_followup",
|
|
Function: &providers.FunctionCall{
|
|
Name: "async_followup",
|
|
Arguments: "{}",
|
|
},
|
|
Arguments: map[string]any{},
|
|
},
|
|
},
|
|
finalResp: "async launched",
|
|
}
|
|
|
|
msgBus := bus.NewMessageBus()
|
|
al := NewAgentLoop(cfg, msgBus, provider)
|
|
doneCh := make(chan struct{})
|
|
al.RegisterTool(&asyncFollowUpTool{
|
|
name: "async_followup",
|
|
followUpText: "background result",
|
|
completionSig: doneCh,
|
|
})
|
|
defaultAgent := al.registry.GetDefaultAgent()
|
|
if defaultAgent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
sub := al.SubscribeEvents(32)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run async tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "async launched" {
|
|
t.Fatalf("expected final response 'async launched', got %q", resp)
|
|
}
|
|
|
|
select {
|
|
case <-doneCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for async tool completion")
|
|
}
|
|
|
|
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
|
|
return evt.Kind == EventKindFollowUpQueued
|
|
})
|
|
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected FollowUpQueuedPayload, got %T", followUpEvt.Payload)
|
|
}
|
|
if payload.SourceTool != "async_followup" {
|
|
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
|
|
}
|
|
if payload.Channel != "cli" {
|
|
t.Fatalf("expected channel cli, got %q", payload.Channel)
|
|
}
|
|
if payload.ChatID != "direct" {
|
|
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
|
|
}
|
|
if payload.ContentLen != len("background result") {
|
|
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
|
|
}
|
|
if followUpEvt.Meta.SessionKey != "session-1" {
|
|
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
|
|
}
|
|
if followUpEvt.Meta.TurnID == "" {
|
|
t.Fatal("expected follow-up event to include turn id")
|
|
}
|
|
}
|
|
|
|
func collectEventStream(ch <-chan Event) []Event {
|
|
var events []Event
|
|
for {
|
|
select {
|
|
case evt, ok := <-ch:
|
|
if !ok {
|
|
return events
|
|
}
|
|
events = append(events, evt)
|
|
default:
|
|
return events
|
|
}
|
|
}
|
|
}
|
|
|
|
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
|
|
t.Helper()
|
|
|
|
timer := time.NewTimer(timeout)
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case evt, ok := <-ch:
|
|
if !ok {
|
|
t.Fatal("event stream closed before expected event arrived")
|
|
}
|
|
if match(evt) {
|
|
return evt
|
|
}
|
|
case <-timer.C:
|
|
t.Fatal("timed out waiting for expected event")
|
|
}
|
|
}
|
|
}
|
|
|
|
func findEvent(events []Event, kind EventKind) (Event, bool) {
|
|
for _, evt := range events {
|
|
if evt.Kind == kind {
|
|
return evt, true
|
|
}
|
|
}
|
|
return Event{}, false
|
|
}
|
|
|
|
type stringError string
|
|
|
|
func (e stringError) Error() string {
|
|
return string(e)
|
|
}
|
|
|
|
type asyncFollowUpTool struct {
|
|
name string
|
|
followUpText string
|
|
completionSig chan struct{}
|
|
}
|
|
|
|
func (t *asyncFollowUpTool) Name() string {
|
|
return t.name
|
|
}
|
|
|
|
func (t *asyncFollowUpTool) Description() string {
|
|
return "async follow-up tool for testing"
|
|
}
|
|
|
|
func (t *asyncFollowUpTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{},
|
|
}
|
|
}
|
|
|
|
func (t *asyncFollowUpTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
return tools.AsyncResult("async follow-up scheduled")
|
|
}
|
|
|
|
func (t *asyncFollowUpTool) ExecuteAsync(
|
|
ctx context.Context,
|
|
args map[string]any,
|
|
cb tools.AsyncCallback,
|
|
) *tools.ToolResult {
|
|
go func() {
|
|
cb(ctx, &tools.ToolResult{ForLLM: t.followUpText})
|
|
if t.completionSig != nil {
|
|
close(t.completionSig)
|
|
}
|
|
}()
|
|
return tools.AsyncResult("async follow-up scheduled")
|
|
}
|
|
|
|
var (
|
|
_ tools.Tool = (*mockCustomTool)(nil)
|
|
_ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
|
|
)
|