mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
6e8a81bfbf
Use RuntimeEventObserver for the normal in-process hook observer path and make the process-hook helper assert hook.runtime_event notifications. Validation: go test ./pkg/agent; make lint
1597 lines
43 KiB
Go
1597 lines
43 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/routing"
|
|
"github.com/sipeed/picoclaw/pkg/session"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
)
|
|
|
|
func newHookTestLoop(
|
|
t *testing.T,
|
|
provider providers.LLMProvider,
|
|
) (*AgentLoop, *AgentInstance, func()) {
|
|
t.Helper()
|
|
|
|
tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
return al, agent, func() {
|
|
al.Close()
|
|
_ = os.RemoveAll(tmpDir)
|
|
}
|
|
}
|
|
|
|
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
defer hm.Close()
|
|
|
|
if err := hm.Mount(HookRegistration{
|
|
Name: "process",
|
|
Priority: -10,
|
|
Source: HookSourceProcess,
|
|
Hook: struct{}{},
|
|
}); err != nil {
|
|
t.Fatalf("mount process hook: %v", err)
|
|
}
|
|
if err := hm.Mount(HookRegistration{
|
|
Name: "in-process",
|
|
Priority: 100,
|
|
Source: HookSourceInProcess,
|
|
Hook: struct{}{},
|
|
}); err != nil {
|
|
t.Fatalf("mount in-process hook: %v", err)
|
|
}
|
|
|
|
ordered := hm.snapshotHooks()
|
|
if len(ordered) != 2 {
|
|
t.Fatalf("expected 2 hooks, got %d", len(ordered))
|
|
}
|
|
if ordered[0].Name != "in-process" {
|
|
t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
|
|
}
|
|
if ordered[1].Name != "process" {
|
|
t.Fatalf("expected process hook second, got %q", ordered[1].Name)
|
|
}
|
|
}
|
|
|
|
type llmHookTestProvider struct {
|
|
mu sync.Mutex
|
|
lastModel string
|
|
}
|
|
|
|
func (p *llmHookTestProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.lastModel = model
|
|
p.mu.Unlock()
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "provider content",
|
|
}, nil
|
|
}
|
|
|
|
func (p *llmHookTestProvider) GetDefaultModel() string {
|
|
return "llm-hook-provider"
|
|
}
|
|
|
|
type llmObserverHook struct {
|
|
eventCh chan runtimeevents.Event
|
|
lastInbound *bus.InboundContext
|
|
lastRoute *routing.ResolvedRoute
|
|
lastScope *session.SessionScope
|
|
}
|
|
|
|
func (h *llmObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
|
if evt.Kind == runtimeevents.KindAgentTurnEnd {
|
|
select {
|
|
case h.eventCh <- evt:
|
|
default:
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *llmObserverHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
if req.Context != nil {
|
|
h.lastInbound = cloneInboundContext(req.Context.Inbound)
|
|
h.lastRoute = cloneResolvedRoute(req.Context.Route)
|
|
h.lastScope = session.CloneScope(req.Context.Scope)
|
|
}
|
|
next := req.Clone()
|
|
next.Model = "hook-model"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *llmObserverHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
next := resp.Clone()
|
|
next.Response.Content = "hooked content"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
type dualRuntimeObserverHook struct {
|
|
legacyCh chan Event
|
|
runtimeCh chan runtimeevents.Event
|
|
}
|
|
|
|
func (h *dualRuntimeObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
|
if evt.Kind == EventKindTurnEnd {
|
|
select {
|
|
case h.legacyCh <- evt:
|
|
default:
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *dualRuntimeObserverHook) OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error {
|
|
if evt.Kind == runtimeevents.KindAgentTurnEnd {
|
|
select {
|
|
case h.runtimeCh <- evt:
|
|
default:
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type llmSystemRewriteHook struct{}
|
|
|
|
func (h *llmSystemRewriteHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
next := req.Clone()
|
|
next.Model = "changed-model"
|
|
next.Messages[0].Content = "rewritten system"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *llmSystemRewriteHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
type llmUserAppendHook struct{}
|
|
|
|
func (h *llmUserAppendHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
next := req.Clone()
|
|
next.Messages = append(next.Messages, providers.Message{Role: "user", Content: "extra user context"})
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *llmUserAppendHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
type llmJSONRoundTripUserAppendHook struct{}
|
|
|
|
type jsonRoundTripLLMHookRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []providers.Message `json:"messages,omitempty"`
|
|
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
|
}
|
|
|
|
func (h *llmJSONRoundTripUserAppendHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
payload := jsonRoundTripLLMHookRequest{
|
|
Model: req.Model,
|
|
Messages: req.Messages,
|
|
Tools: req.Tools,
|
|
}
|
|
data, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, HookDecision{}, err
|
|
}
|
|
var decoded jsonRoundTripLLMHookRequest
|
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
|
return nil, HookDecision{}, err
|
|
}
|
|
next := req.Clone()
|
|
next.Model = decoded.Model
|
|
next.Messages = decoded.Messages
|
|
next.Tools = decoded.Tools
|
|
next.Messages = append(next.Messages, providers.Message{Role: "user", Content: "json extra user context"})
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *llmJSONRoundTripUserAppendHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
type llmToolRewriteHook struct{}
|
|
|
|
func (h *llmToolRewriteHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
next := req.Clone()
|
|
next.Model = "changed-model"
|
|
next.Tools[0].Function.Description = "rewritten tool"
|
|
next.Tools = append(next.Tools, providers.ToolDefinition{
|
|
Type: "function",
|
|
Function: providers.ToolFunctionDefinition{
|
|
Name: "hook_tool",
|
|
Description: "hook tool",
|
|
Parameters: map[string]any{"type": "object"},
|
|
},
|
|
PromptLayer: string(PromptLayerCapability),
|
|
PromptSlot: string(PromptSlotTooling),
|
|
PromptSource: "hook:test",
|
|
})
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *llmToolRewriteHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
return resp.Clone(), HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func TestHookManager_BeforeLLMControlsSystemPromptMutation(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
if err := hm.Mount(NamedHook("rewrite-system", &llmSystemRewriteHook{})); err != nil {
|
|
t.Fatalf("Mount() error = %v", err)
|
|
}
|
|
|
|
req := &LLMHookRequest{
|
|
Model: "original-model",
|
|
Messages: []providers.Message{
|
|
{
|
|
Role: "system",
|
|
Content: "original system",
|
|
SystemParts: []providers.ContentBlock{
|
|
{Type: "text", Text: "original system"},
|
|
},
|
|
},
|
|
{Role: "user", Content: "hello"},
|
|
},
|
|
}
|
|
|
|
got, decision := hm.BeforeLLM(context.Background(), req)
|
|
if decision.normalizedAction() != HookActionContinue {
|
|
t.Fatalf("decision = %v, want continue", decision)
|
|
}
|
|
if got.Model != "changed-model" {
|
|
t.Fatalf("model = %q, want changed-model", got.Model)
|
|
}
|
|
if got.Messages[0].Content != "original system" {
|
|
t.Fatalf("system content = %q, want original system", got.Messages[0].Content)
|
|
}
|
|
if got.Messages[1].Content != "hello" {
|
|
t.Fatalf("user content = %q, want hello", got.Messages[1].Content)
|
|
}
|
|
}
|
|
|
|
func TestHookManager_BeforeLLMAllowsNonSystemMessageMutation(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
if err := hm.Mount(NamedHook("append-user", &llmUserAppendHook{})); err != nil {
|
|
t.Fatalf("Mount() error = %v", err)
|
|
}
|
|
|
|
req := &LLMHookRequest{
|
|
Model: "model",
|
|
Messages: []providers.Message{
|
|
{Role: "system", Content: "system"},
|
|
{Role: "user", Content: "hello"},
|
|
},
|
|
}
|
|
|
|
got, _ := hm.BeforeLLM(context.Background(), req)
|
|
if len(got.Messages) != 3 {
|
|
t.Fatalf("messages len = %d, want 3", len(got.Messages))
|
|
}
|
|
if got.Messages[2].Role != "user" || got.Messages[2].Content != "extra user context" {
|
|
t.Fatalf("appended message = %#v, want extra user context", got.Messages[2])
|
|
}
|
|
}
|
|
|
|
func TestHookManager_BeforeLLMAllowsJSONRoundTripNonSystemMessageMutation(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
if err := hm.Mount(NamedHook("json-append-user", &llmJSONRoundTripUserAppendHook{})); err != nil {
|
|
t.Fatalf("Mount() error = %v", err)
|
|
}
|
|
|
|
req := &LLMHookRequest{
|
|
Model: "model",
|
|
Messages: []providers.Message{
|
|
{
|
|
Role: "system",
|
|
Content: "system",
|
|
PromptLayer: string(PromptLayerKernel),
|
|
PromptSlot: string(PromptSlotIdentity),
|
|
PromptSource: string(PromptSourceKernel),
|
|
SystemParts: []providers.ContentBlock{
|
|
{
|
|
Type: "text",
|
|
Text: "system",
|
|
CacheControl: &providers.CacheControl{Type: "ephemeral"},
|
|
PromptLayer: string(PromptLayerKernel),
|
|
PromptSlot: string(PromptSlotIdentity),
|
|
PromptSource: string(PromptSourceKernel),
|
|
},
|
|
},
|
|
},
|
|
{Role: "user", Content: "hello"},
|
|
},
|
|
Tools: []providers.ToolDefinition{
|
|
{
|
|
Type: "function",
|
|
Function: providers.ToolFunctionDefinition{
|
|
Name: "mcp_github_create_issue",
|
|
Description: "create issue",
|
|
Parameters: map[string]any{"type": "object"},
|
|
},
|
|
PromptLayer: string(PromptLayerCapability),
|
|
PromptSlot: string(PromptSlotMCP),
|
|
PromptSource: "mcp:github",
|
|
},
|
|
},
|
|
}
|
|
|
|
got, _ := hm.BeforeLLM(context.Background(), req)
|
|
if len(got.Messages) != 3 {
|
|
t.Fatalf("messages len = %d, want 3", len(got.Messages))
|
|
}
|
|
if got.Messages[2].Role != "user" || got.Messages[2].Content != "json extra user context" {
|
|
t.Fatalf("appended message = %#v, want json extra user context", got.Messages[2])
|
|
}
|
|
}
|
|
|
|
func TestHookManager_BeforeLLMControlsToolDefinitionMutation(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
if err := hm.Mount(NamedHook("rewrite-tool", &llmToolRewriteHook{})); err != nil {
|
|
t.Fatalf("Mount() error = %v", err)
|
|
}
|
|
|
|
req := &LLMHookRequest{
|
|
Model: "original-model",
|
|
Messages: []providers.Message{
|
|
{Role: "system", Content: "system"},
|
|
{Role: "user", Content: "hello"},
|
|
},
|
|
Tools: []providers.ToolDefinition{
|
|
{
|
|
Type: "function",
|
|
Function: providers.ToolFunctionDefinition{
|
|
Name: "mcp_github_create_issue",
|
|
Description: "create issue",
|
|
Parameters: map[string]any{"type": "object"},
|
|
},
|
|
PromptLayer: string(PromptLayerCapability),
|
|
PromptSlot: string(PromptSlotMCP),
|
|
PromptSource: "mcp:github",
|
|
},
|
|
},
|
|
}
|
|
|
|
got, decision := hm.BeforeLLM(context.Background(), req)
|
|
if decision.normalizedAction() != HookActionContinue {
|
|
t.Fatalf("decision = %v, want continue", decision)
|
|
}
|
|
if got.Model != "changed-model" {
|
|
t.Fatalf("model = %q, want changed-model", got.Model)
|
|
}
|
|
if len(got.Tools) != 1 {
|
|
t.Fatalf("tools len = %d, want original 1", len(got.Tools))
|
|
}
|
|
if got.Tools[0].Function.Description != "create issue" {
|
|
t.Fatalf("tool description = %q, want original", got.Tools[0].Function.Description)
|
|
}
|
|
if got.Tools[0].PromptSource != "mcp:github" || got.Tools[0].PromptSlot != string(PromptSlotMCP) {
|
|
t.Fatalf("tool prompt metadata = %#v, want original mcp metadata", got.Tools[0])
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
|
provider := &llmHookTestProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &llmObserverHook{eventCh: make(chan runtimeevents.Event, 1)}
|
|
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "hello",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
},
|
|
RouteResult: &routing.ResolvedRoute{
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
AccountID: routing.DefaultAccountID,
|
|
SessionPolicy: routing.SessionPolicy{
|
|
Dimensions: []string{"sender"},
|
|
},
|
|
MatchedBy: "default",
|
|
},
|
|
SessionScope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
Account: routing.DefaultAccountID,
|
|
Dimensions: []string{"sender"},
|
|
Values: map[string]string{
|
|
"sender": "hook-user",
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "hooked content" {
|
|
t.Fatalf("expected hooked content, got %q", resp)
|
|
}
|
|
|
|
provider.mu.Lock()
|
|
lastModel := provider.lastModel
|
|
provider.mu.Unlock()
|
|
if lastModel != "hook-model" {
|
|
t.Fatalf("expected model hook-model, got %q", lastModel)
|
|
}
|
|
if hook.lastInbound == nil {
|
|
t.Fatal("expected hook to receive inbound context")
|
|
}
|
|
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
|
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
|
}
|
|
if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" {
|
|
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
|
}
|
|
|
|
select {
|
|
case evt := <-hook.eventCh:
|
|
if evt.Kind != runtimeevents.KindAgentTurnEnd {
|
|
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
|
}
|
|
if evt.Scope.AgentID != "main" ||
|
|
evt.Scope.SessionKey != "session-1" ||
|
|
evt.Scope.Channel != "cli" ||
|
|
evt.Scope.ChatID != "direct" ||
|
|
evt.Scope.SenderID != "hook-user" {
|
|
t.Fatalf("runtime observer scope = %+v", evt.Scope)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for hook observer event")
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_RuntimeObserverPreferredOverLegacyObserver(t *testing.T) {
|
|
provider := &llmHookTestProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &dualRuntimeObserverHook{
|
|
legacyCh: make(chan Event, 1),
|
|
runtimeCh: make(chan runtimeevents.Event, 1),
|
|
}
|
|
if err := al.MountHook(NamedHook("runtime-observer", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "hello",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "cli",
|
|
Account: "default",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
MessageID: "msg-1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "provider content" {
|
|
t.Fatalf("expected provider content, got %q", resp)
|
|
}
|
|
|
|
select {
|
|
case evt := <-hook.runtimeCh:
|
|
if evt.Kind != runtimeevents.KindAgentTurnEnd {
|
|
t.Fatalf("runtime observer kind = %q", evt.Kind)
|
|
}
|
|
if evt.Scope.SessionKey != "session-1" ||
|
|
evt.Scope.Channel != "cli" ||
|
|
evt.Scope.ChatID != "direct" ||
|
|
evt.Scope.MessageID != "msg-1" {
|
|
t.Fatalf("runtime observer scope = %+v", evt.Scope)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for runtime observer event")
|
|
}
|
|
|
|
select {
|
|
case evt := <-hook.legacyCh:
|
|
t.Fatalf("legacy observer unexpectedly received %v", evt.Kind)
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
|
|
provider := &llmHookTestProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
useTestSideQuestionProvider(al, provider)
|
|
|
|
hook := &llmObserverHook{eventCh: make(chan runtimeevents.Event, 1)}
|
|
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
response, handled := al.handleCommand(context.Background(), bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
},
|
|
Content: "/btw hello",
|
|
}, agent, &processOptions{
|
|
Dispatch: DispatchRequest{
|
|
SessionKey: "session-1",
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
},
|
|
RouteResult: &routing.ResolvedRoute{
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
AccountID: routing.DefaultAccountID,
|
|
SessionPolicy: routing.SessionPolicy{
|
|
Dimensions: []string{"sender"},
|
|
},
|
|
MatchedBy: "default",
|
|
},
|
|
SessionScope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
Account: routing.DefaultAccountID,
|
|
Dimensions: []string{"sender"},
|
|
Values: map[string]string{
|
|
"sender": "hook-user",
|
|
},
|
|
},
|
|
UserMessage: "/btw hello",
|
|
},
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
SenderID: "hook-user",
|
|
SenderDisplayName: "Hook User",
|
|
})
|
|
if !handled {
|
|
t.Fatal("expected /btw command to be handled")
|
|
}
|
|
if response != "hooked content" {
|
|
t.Fatalf("expected hooked content, got %q", response)
|
|
}
|
|
|
|
provider.mu.Lock()
|
|
lastModel := provider.lastModel
|
|
provider.mu.Unlock()
|
|
if lastModel != "hook-model" {
|
|
t.Fatalf("expected model hook-model, got %q", lastModel)
|
|
}
|
|
if hook.lastInbound == nil {
|
|
t.Fatal("expected hook to receive inbound context")
|
|
}
|
|
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
|
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
|
}
|
|
if hook.lastInbound.ChatID != "direct" {
|
|
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
|
}
|
|
if hook.lastRoute == nil || hook.lastRoute.AgentID != "main" {
|
|
t.Fatalf("expected hook route context for /btw, got %+v", hook.lastRoute)
|
|
}
|
|
if hook.lastScope == nil || hook.lastScope.Values["sender"] != "hook-user" {
|
|
t.Fatalf("expected hook session scope for /btw, got %+v", hook.lastScope)
|
|
}
|
|
}
|
|
|
|
type toolHookProvider struct {
|
|
mu sync.Mutex
|
|
calls int
|
|
}
|
|
|
|
func (p *toolHookProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.calls++
|
|
if p.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Name: "echo_text",
|
|
Arguments: map[string]any{"text": "original"},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
last := messages[len(messages)-1]
|
|
return &providers.LLMResponse{
|
|
Content: last.Content,
|
|
}, nil
|
|
}
|
|
|
|
func (p *toolHookProvider) GetDefaultModel() string {
|
|
return "tool-hook-provider"
|
|
}
|
|
|
|
// echoTextTool is a minimal tool named "echo_text" that echoes its "text"
// argument back as the tool result.
type echoTextTool struct{}

// Name returns the tool identifier the scripted test provider calls.
func (t *echoTextTool) Name() string { return "echo_text" }

// Description returns the human-readable tool description.
func (t *echoTextTool) Description() string { return "echo a text argument" }

// Parameters describes a JSON-schema object with a single string property
// named "text". A new map is built on every call.
func (t *echoTextTool) Parameters() map[string]any {
	textProperty := map[string]any{"type": "string"}
	properties := map[string]any{"text": textProperty}
	return map[string]any{
		"type":       "object",
		"properties": properties,
	}
}
|
|
|
|
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
text, _ := args["text"].(string)
|
|
return tools.SilentResult(text)
|
|
}
|
|
|
|
type toolRewriteHook struct{}
|
|
|
|
func (h *toolRewriteHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
next := call.Clone()
|
|
next.Arguments["text"] = "modified"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *toolRewriteHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
next := result.Clone()
|
|
next.Result.ForLLM = "after:" + next.Result.ForLLM
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
type toolRenameHook struct{}
|
|
|
|
func (h *toolRenameHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
next := call.Clone()
|
|
next.Tool = "echo_text_rewritten"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *toolRenameHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
return result.Clone(), HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "after:modified" {
|
|
t.Fatalf("expected rewritten tool result, got %q", resp)
|
|
}
|
|
}
|
|
|
|
// echoTextRewrittenTool is the redirect target used with toolRenameHook. It
// echoes its "text" argument with a "rewritten:" prefix.
type echoTextRewrittenTool struct{}

// Name returns the tool identifier that toolRenameHook redirects calls to.
func (t *echoTextRewrittenTool) Name() string { return "echo_text_rewritten" }

// Description returns the human-readable tool description.
func (t *echoTextRewrittenTool) Description() string { return "echo a rewritten text argument" }

// Parameters describes a JSON-schema object with a single string property
// named "text". A new map is built on every call.
func (t *echoTextRewrittenTool) Parameters() map[string]any {
	textProperty := map[string]any{"type": "string"}
	properties := map[string]any{"text": textProperty}
	return map[string]any{
		"type":       "object",
		"properties": properties,
	}
}
|
|
|
|
func (t *echoTextRewrittenTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
text, _ := args["text"].(string)
|
|
return tools.SilentResult("rewritten:" + text)
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolFeedbackUsesRewrittenToolName(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.cfg.Agents.Defaults.ToolFeedback.Enabled = true
|
|
al.RegisterTool(&echoTextTool{})
|
|
al.RegisterTool(&echoTextRewrittenTool{})
|
|
if err := al.MountHook(NamedHook("tool-rename", &toolRenameHook{})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
msgBus, ok := al.bus.(*bus.MessageBus)
|
|
if !ok {
|
|
t.Fatalf("expected concrete MessageBus, got %T", al.bus)
|
|
}
|
|
|
|
select {
|
|
case outbound := <-msgBus.OutboundChan():
|
|
if !strings.Contains(outbound.Content, "`echo_text_rewritten`") {
|
|
t.Fatalf("tool feedback content = %q, want rewritten tool name", outbound.Content)
|
|
}
|
|
if strings.Contains(outbound.Content, "`echo_text`") {
|
|
t.Fatalf("tool feedback content = %q, want no original tool name", outbound.Content)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("expected outbound tool feedback")
|
|
}
|
|
}
|
|
|
|
type denyApprovalHook struct{}
|
|
|
|
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
|
return ApprovalDecision{
|
|
Approved: false,
|
|
Reason: "blocked",
|
|
}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
|
t,
|
|
al,
|
|
16,
|
|
runtimeevents.KindAgentToolExecSkipped,
|
|
)
|
|
defer closeRuntimeEvents()
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
expected := "Tool execution denied by approval hook: blocked"
|
|
if resp != expected {
|
|
t.Fatalf("expected %q, got %q", expected, resp)
|
|
}
|
|
|
|
events := collectRuntimeEventStream(runtimeCh)
|
|
skippedEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecSkipped)
|
|
if !ok {
|
|
t.Fatal("expected tool skipped event")
|
|
}
|
|
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
|
}
|
|
if payload.Reason != expected {
|
|
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
|
}
|
|
}
|
|
|
|
// respondHook is a test hook for testing HookActionRespond functionality
|
|
type respondHook struct {
|
|
respondTools map[string]bool // tool names to respond to
|
|
}
|
|
|
|
func (h *respondHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
if h.respondTools[call.Tool] {
|
|
next := call.Clone()
|
|
next.HookResult = &tools.ToolResult{
|
|
ForLLM: "hook-responded: " + call.Tool,
|
|
ForUser: "",
|
|
Silent: false,
|
|
IsError: false,
|
|
}
|
|
return next, HookDecision{Action: HookActionRespond}, nil
|
|
}
|
|
return call, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func (h *respondHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
// Should not be called since respond skips tool execution
|
|
return result, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
|
|
respondTools: map[string]bool{"echo_text": true},
|
|
})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
|
t,
|
|
al,
|
|
16,
|
|
runtimeevents.KindAgentToolExecEnd,
|
|
)
|
|
defer closeRuntimeEvents()
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
// Verify response comes from hook, not tool
|
|
expected := "hook-responded: echo_text"
|
|
if resp != expected {
|
|
t.Fatalf("expected %q, got %q", expected, resp)
|
|
}
|
|
|
|
// Verify event stream has ToolExecEnd, not actual tool execution
|
|
events := collectRuntimeEventStream(runtimeCh)
|
|
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
|
|
if !ok {
|
|
t.Fatal("expected tool exec end event")
|
|
}
|
|
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
|
}
|
|
if payload.Tool != "echo_text" {
|
|
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
|
|
}
|
|
if payload.ForLLMLen != len(expected) {
|
|
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
|
|
}
|
|
}
|
|
|
|
// denyToolHook tests HookActionDenyTool functionality
|
|
type denyToolHook struct {
|
|
denyTools map[string]bool
|
|
}
|
|
|
|
func (h *denyToolHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
if h.denyTools[call.Tool] {
|
|
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
|
|
}
|
|
return call, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func (h *denyToolHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
return result, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
|
|
denyTools: map[string]bool{"echo_text": true},
|
|
})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
expected := "Tool execution denied by hook: tool denied by hook"
|
|
if resp != expected {
|
|
t.Fatalf("expected %q, got %q", expected, resp)
|
|
}
|
|
}
|
|
|
|
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
defer hm.Close()
|
|
|
|
hook := &respondHook{
|
|
respondTools: map[string]bool{"test_tool": true},
|
|
}
|
|
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
|
|
t.Fatalf("mount hook: %v", err)
|
|
}
|
|
|
|
req := &ToolCallHookRequest{
|
|
Tool: "test_tool",
|
|
Arguments: map[string]any{"arg": "value"},
|
|
}
|
|
result, decision := hm.BeforeTool(context.Background(), req)
|
|
|
|
if decision.Action != HookActionRespond {
|
|
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
|
|
}
|
|
|
|
if result.HookResult == nil {
|
|
t.Fatal("expected HookResult to be set")
|
|
}
|
|
if result.HookResult.ForLLM != "hook-responded: test_tool" {
|
|
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
|
|
}
|
|
}
|
|
|
|
type respondWithMediaHook struct {
|
|
respondTools map[string]bool
|
|
media []string
|
|
responseHandled bool
|
|
forLLM string
|
|
}
|
|
|
|
func (h *respondWithMediaHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
if h.respondTools[call.Tool] {
|
|
next := call.Clone()
|
|
next.HookResult = &tools.ToolResult{
|
|
ForLLM: h.forLLM,
|
|
ForUser: "media result",
|
|
Media: h.media,
|
|
ResponseHandled: h.responseHandled,
|
|
Silent: false,
|
|
IsError: false,
|
|
}
|
|
return next, HookDecision{Action: HookActionRespond}, nil
|
|
}
|
|
return call, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func (h *respondWithMediaHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
return result, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
type errorMediaChannel struct {
|
|
fakeChannel
|
|
sendErr error
|
|
}
|
|
|
|
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
|
return nil, f.sendErr
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &respondWithMediaHook{
|
|
respondTools: map[string]bool{"media_tool": true},
|
|
media: []string{"media://test/image.png"},
|
|
responseHandled: true,
|
|
forLLM: "media sent successfully",
|
|
}
|
|
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
al.channelManager = newStartedTestChannelManager(t,
|
|
al.bus.(*bus.MessageBus), al.mediaStore, "discord", &errorMediaChannel{
|
|
sendErr: errors.New("channel unavailable"),
|
|
})
|
|
|
|
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
|
t,
|
|
al,
|
|
16,
|
|
runtimeevents.KindAgentToolExecEnd,
|
|
)
|
|
defer closeRuntimeEvents()
|
|
|
|
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-media-err",
|
|
Channel: "discord",
|
|
ChatID: "chat1",
|
|
UserMessage: "send media",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
events := collectRuntimeEventStream(runtimeCh)
|
|
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
|
|
if !ok {
|
|
t.Fatal("expected ToolExecEnd event")
|
|
}
|
|
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
|
}
|
|
|
|
if !payload.IsError {
|
|
t.Fatal("expected IsError=true when SendMedia fails")
|
|
}
|
|
|
|
if payload.ForLLMLen < 30 {
|
|
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &respondWithMediaHook{
|
|
respondTools: map[string]bool{"media_tool": true},
|
|
media: []string{"media://test/image.png"},
|
|
responseHandled: true,
|
|
forLLM: "media queued",
|
|
}
|
|
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
|
t,
|
|
al,
|
|
16,
|
|
runtimeevents.KindAgentToolExecEnd,
|
|
)
|
|
defer closeRuntimeEvents()
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-bus-fallback",
|
|
Channel: "cli",
|
|
ChatID: "chat1",
|
|
UserMessage: "send media",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
events := collectRuntimeEventStream(runtimeCh)
|
|
endEvt, ok := findRuntimeEvent(events, runtimeevents.KindAgentToolExecEnd)
|
|
if !ok {
|
|
t.Fatal("expected ToolExecEnd event")
|
|
}
|
|
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
|
}
|
|
|
|
if payload.IsError {
|
|
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
|
|
}
|
|
|
|
if resp != "done" {
|
|
t.Fatalf("expected response 'done', got %q", resp)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_ResponseHandledMediaPreservesOutboundContext(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &respondWithMediaHook{
|
|
respondTools: map[string]bool{"media_tool": true},
|
|
media: []string{"media://test/image.png"},
|
|
responseHandled: true,
|
|
forLLM: "media sent successfully",
|
|
}
|
|
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
telegramChannel := &fakeMediaChannel{fakeChannel: fakeChannel{id: "rid-telegram"}}
|
|
al.channelManager = newStartedTestChannelManager(t,
|
|
al.bus.(*bus.MessageBus), al.mediaStore, "telegram", telegramChannel)
|
|
|
|
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
Dispatch: DispatchRequest{
|
|
SessionKey: "session-topic-media",
|
|
SessionScope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: agent.ID,
|
|
Channel: "telegram",
|
|
Dimensions: []string{"chat"},
|
|
Values: map[string]string{
|
|
"chat": "forum:-100123/42",
|
|
},
|
|
},
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "telegram",
|
|
ChatID: "-100123",
|
|
TopicID: "42",
|
|
ChatType: "group",
|
|
SenderID: "user1",
|
|
},
|
|
UserMessage: "send media",
|
|
},
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
if len(telegramChannel.sentMedia) != 1 {
|
|
t.Fatalf("expected exactly 1 sent media message, got %d", len(telegramChannel.sentMedia))
|
|
}
|
|
sent := telegramChannel.sentMedia[0]
|
|
if sent.Context.Channel != "telegram" || sent.Context.ChatID != "-100123" || sent.Context.TopicID != "42" {
|
|
t.Fatalf("unexpected media context: %+v", sent.Context)
|
|
}
|
|
if sent.AgentID != agent.ID {
|
|
t.Fatalf("sent media agent_id = %q, want %q", sent.AgentID, agent.ID)
|
|
}
|
|
if sent.SessionKey != "session-topic-media" {
|
|
t.Fatalf("sent media session_key = %q, want session-topic-media", sent.SessionKey)
|
|
}
|
|
if sent.Scope == nil || sent.Scope.Values["chat"] != "forum:-100123/42" {
|
|
t.Fatalf("unexpected sent media scope: %+v", sent.Scope)
|
|
}
|
|
}
|
|
|
|
type multiToolProvider struct {
|
|
mu sync.Mutex
|
|
callCount int
|
|
toolCalls []providers.ToolCall
|
|
finalContent string
|
|
}
|
|
|
|
func (p *multiToolProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.callCount++
|
|
if p.callCount == 1 && len(p.toolCalls) > 0 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: p.toolCalls,
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: p.finalContent,
|
|
}, nil
|
|
}
|
|
|
|
func (p *multiToolProvider) GetDefaultModel() string {
|
|
return "multi-tool-provider"
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
|
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
|
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, _, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
tool1ExecCh := make(chan struct{}, 1)
|
|
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
|
|
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
|
|
|
hook := &respondHook{
|
|
respondTools: map[string]bool{"tool_one": true},
|
|
}
|
|
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
|
t,
|
|
al,
|
|
32,
|
|
runtimeevents.KindAgentToolExecSkipped,
|
|
)
|
|
defer closeRuntimeEvents()
|
|
|
|
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
|
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"run tools",
|
|
sessionKey,
|
|
"cli",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp: resp, err: err}
|
|
}()
|
|
|
|
select {
|
|
case <-tool1ExecCh:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout waiting for tool execution to start")
|
|
}
|
|
|
|
if err := al.InterruptGraceful("stop now"); err != nil {
|
|
t.Fatalf("InterruptGraceful failed: %v", err)
|
|
}
|
|
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout waiting for result")
|
|
}
|
|
|
|
events := collectRuntimeEventStream(runtimeCh)
|
|
|
|
skippedEvts := filterRuntimeEvents(events, runtimeevents.KindAgentToolExecSkipped)
|
|
if len(skippedEvts) < 1 {
|
|
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
|
|
}
|
|
|
|
for _, evt := range skippedEvts {
|
|
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
|
}
|
|
if payload.Reason != "graceful interrupt requested" {
|
|
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
|
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
|
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, _, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
|
|
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
|
|
|
hook := &respondHook{
|
|
respondTools: map[string]bool{"tool_one": true},
|
|
}
|
|
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
runtimeCh, closeRuntimeEvents := subscribeRuntimeEventsForTest(
|
|
t,
|
|
al,
|
|
32,
|
|
runtimeevents.KindAgentToolExecEnd,
|
|
runtimeevents.KindAgentToolExecSkipped,
|
|
)
|
|
defer closeRuntimeEvents()
|
|
|
|
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
|
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"run tools",
|
|
sessionKey,
|
|
"cli",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp: resp, err: err}
|
|
}()
|
|
|
|
collectedEvents := make([]runtimeevents.Event, 0, 8)
|
|
steered := false
|
|
deadline := time.After(3 * time.Second)
|
|
for !steered {
|
|
select {
|
|
case evt := <-runtimeCh:
|
|
collectedEvents = append(collectedEvents, evt)
|
|
if evt.Kind != runtimeevents.KindAgentToolExecEnd {
|
|
continue
|
|
}
|
|
payload, ok := evt.Payload.(ToolExecEndPayload)
|
|
if !ok || payload.Tool != "tool_one" {
|
|
continue
|
|
}
|
|
al.Steer(providers.Message{Role: "user", Content: "change direction"})
|
|
steered = true
|
|
case <-deadline:
|
|
t.Fatal("timeout waiting for tool_one to finish before steering")
|
|
}
|
|
}
|
|
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout waiting for result")
|
|
}
|
|
|
|
events := append(collectedEvents, collectRuntimeEventStream(runtimeCh)...)
|
|
|
|
skippedEvts := filterRuntimeEvents(events, runtimeevents.KindAgentToolExecSkipped)
|
|
if len(skippedEvts) < 1 {
|
|
t.Fatal("expected at least one ToolExecSkipped event after steering")
|
|
}
|
|
|
|
for _, evt := range skippedEvts {
|
|
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
|
}
|
|
if payload.Reason != "queued user steering message" {
|
|
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestCloneStringAnyMap_EmptyMapReturnsNonNil(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input map[string]any
|
|
wantNil bool
|
|
wantLen int
|
|
}{
|
|
{
|
|
name: "nil input returns empty map",
|
|
input: nil,
|
|
wantNil: false,
|
|
wantLen: 0,
|
|
},
|
|
{
|
|
name: "empty map returns empty map",
|
|
input: map[string]any{},
|
|
wantNil: false,
|
|
wantLen: 0,
|
|
},
|
|
{
|
|
name: "populated map is cloned",
|
|
input: map[string]any{"key": "value"},
|
|
wantNil: false,
|
|
wantLen: 1,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := cloneStringAnyMap(tt.input)
|
|
if result == nil {
|
|
t.Fatal("cloneStringAnyMap returned nil — MCP tool calls " +
|
|
"with no arguments would send null instead of {}")
|
|
}
|
|
if len(result) != tt.wantLen {
|
|
t.Fatalf("expected len %d, got %d", tt.wantLen, len(result))
|
|
}
|
|
})
|
|
}
|
|
|
|
t.Run("clone does not share underlying map", func(t *testing.T) {
|
|
src := map[string]any{"a": 1}
|
|
cloned := cloneStringAnyMap(src)
|
|
cloned["b"] = 2
|
|
if _, ok := src["b"]; ok {
|
|
t.Fatal("modifying clone should not affect source")
|
|
}
|
|
})
|
|
}
|