mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
329e68e017
* refactor(agent): introduce interfaces for MessageBus and ChannelManager Phase 2 of loop.go refactor — dependency inversion using adapter pattern. - Add interfaces.MessageBus and interfaces.ChannelManager interfaces - Create adapters/messagebus.go wrapping *bus.MessageBus - Create adapters/channelmanager.go wrapping *channels.Manager - Update AgentLoop to use interfaces instead of concrete types - Update registerSharedTools to accept interfaces.MessageBus Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(agent): restructure pipeline and rename loop files Pipeline refactoring: - Split pipeline.go (1400 lines) into focused files: - pipeline_setup.go (~115 lines): SetupTurn method - pipeline_llm.go (~519 lines): CallLLM method - pipeline_execute.go (~693 lines): ExecuteTools method - pipeline_finalize.go (~78 lines): Finalize method - Pipeline struct and NewPipeline remain in pipeline.go (~39 lines) Agent file renaming: - Rename loop_*.go to agent_*.go for consistent naming: - loop.go -> agent.go, loop_message.go -> agent_message.go, etc. - Merge turn.go + turn_exec.go into turn_state.go - Rename loop_turn.go -> turn_coord.go Documentation: - Update docs/pipeline-restructuring-plan.md - Add docs/agent-rename-plan.md Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com> * fix(agent): code format fixed * refactor(agent): code test file added/renamed * docs(agent): update agent refactor docs * fix(agent): fix agent hardAbortX --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1017 lines
27 KiB
Go
1017 lines
27 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sipeed/picoclaw/pkg/bus"
|
|
"github.com/sipeed/picoclaw/pkg/config"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/routing"
|
|
"github.com/sipeed/picoclaw/pkg/session"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
)
|
|
|
|
func newHookTestLoop(
|
|
t *testing.T,
|
|
provider providers.LLMProvider,
|
|
) (*AgentLoop, *AgentInstance, func()) {
|
|
t.Helper()
|
|
|
|
tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
|
|
if err != nil {
|
|
t.Fatalf("failed to create temp dir: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Agents: config.AgentsConfig{
|
|
Defaults: config.AgentDefaults{
|
|
Workspace: tmpDir,
|
|
ModelName: "test-model",
|
|
MaxTokens: 4096,
|
|
MaxToolIterations: 10,
|
|
},
|
|
},
|
|
}
|
|
|
|
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
|
agent := al.registry.GetDefaultAgent()
|
|
if agent == nil {
|
|
t.Fatal("expected default agent")
|
|
}
|
|
|
|
return al, agent, func() {
|
|
al.Close()
|
|
_ = os.RemoveAll(tmpDir)
|
|
}
|
|
}
|
|
|
|
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
defer hm.Close()
|
|
|
|
if err := hm.Mount(HookRegistration{
|
|
Name: "process",
|
|
Priority: -10,
|
|
Source: HookSourceProcess,
|
|
Hook: struct{}{},
|
|
}); err != nil {
|
|
t.Fatalf("mount process hook: %v", err)
|
|
}
|
|
if err := hm.Mount(HookRegistration{
|
|
Name: "in-process",
|
|
Priority: 100,
|
|
Source: HookSourceInProcess,
|
|
Hook: struct{}{},
|
|
}); err != nil {
|
|
t.Fatalf("mount in-process hook: %v", err)
|
|
}
|
|
|
|
ordered := hm.snapshotHooks()
|
|
if len(ordered) != 2 {
|
|
t.Fatalf("expected 2 hooks, got %d", len(ordered))
|
|
}
|
|
if ordered[0].Name != "in-process" {
|
|
t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
|
|
}
|
|
if ordered[1].Name != "process" {
|
|
t.Fatalf("expected process hook second, got %q", ordered[1].Name)
|
|
}
|
|
}
|
|
|
|
type llmHookTestProvider struct {
|
|
mu sync.Mutex
|
|
lastModel string
|
|
}
|
|
|
|
func (p *llmHookTestProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
p.lastModel = model
|
|
p.mu.Unlock()
|
|
|
|
return &providers.LLMResponse{
|
|
Content: "provider content",
|
|
}, nil
|
|
}
|
|
|
|
func (p *llmHookTestProvider) GetDefaultModel() string {
|
|
return "llm-hook-provider"
|
|
}
|
|
|
|
type llmObserverHook struct {
|
|
eventCh chan Event
|
|
lastInbound *bus.InboundContext
|
|
lastRoute *routing.ResolvedRoute
|
|
lastScope *session.SessionScope
|
|
}
|
|
|
|
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
|
if evt.Kind == EventKindTurnEnd {
|
|
select {
|
|
case h.eventCh <- evt:
|
|
default:
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (h *llmObserverHook) BeforeLLM(
|
|
ctx context.Context,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, error) {
|
|
if req.Context != nil {
|
|
h.lastInbound = cloneInboundContext(req.Context.Inbound)
|
|
h.lastRoute = cloneResolvedRoute(req.Context.Route)
|
|
h.lastScope = session.CloneScope(req.Context.Scope)
|
|
}
|
|
next := req.Clone()
|
|
next.Model = "hook-model"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *llmObserverHook) AfterLLM(
|
|
ctx context.Context,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, error) {
|
|
next := resp.Clone()
|
|
next.Response.Content = "hooked content"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
|
provider := &llmHookTestProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
|
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "hello",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
},
|
|
RouteResult: &routing.ResolvedRoute{
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
AccountID: routing.DefaultAccountID,
|
|
SessionPolicy: routing.SessionPolicy{
|
|
Dimensions: []string{"sender"},
|
|
},
|
|
MatchedBy: "default",
|
|
},
|
|
SessionScope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
Account: routing.DefaultAccountID,
|
|
Dimensions: []string{"sender"},
|
|
Values: map[string]string{
|
|
"sender": "hook-user",
|
|
},
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "hooked content" {
|
|
t.Fatalf("expected hooked content, got %q", resp)
|
|
}
|
|
|
|
provider.mu.Lock()
|
|
lastModel := provider.lastModel
|
|
provider.mu.Unlock()
|
|
if lastModel != "hook-model" {
|
|
t.Fatalf("expected model hook-model, got %q", lastModel)
|
|
}
|
|
if hook.lastInbound == nil {
|
|
t.Fatal("expected hook to receive inbound context")
|
|
}
|
|
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
|
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
|
}
|
|
if hook.lastInbound != nil && hook.lastInbound.ChatID != "direct" {
|
|
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
|
}
|
|
|
|
select {
|
|
case evt := <-hook.eventCh:
|
|
if evt.Kind != EventKindTurnEnd {
|
|
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
|
}
|
|
if evt.Context == nil || evt.Context.Inbound == nil {
|
|
t.Fatal("expected observer event to carry inbound context")
|
|
}
|
|
if evt.Context.Route == nil || evt.Context.Route.AgentID != "main" {
|
|
t.Fatalf("expected observer event to carry route context, got %+v", evt.Context.Route)
|
|
}
|
|
if evt.Context.Scope == nil || evt.Context.Scope.Values["sender"] != "hook-user" {
|
|
t.Fatalf("expected observer event to carry session scope, got %+v", evt.Context.Scope)
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for hook observer event")
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_BtwCommand_UsesLLMHooks(t *testing.T) {
|
|
provider := &llmHookTestProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
useTestSideQuestionProvider(al, provider)
|
|
|
|
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
|
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
response, handled := al.handleCommand(context.Background(), bus.InboundMessage{
|
|
Context: bus.InboundContext{
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
},
|
|
Content: "/btw hello",
|
|
}, agent, &processOptions{
|
|
Dispatch: DispatchRequest{
|
|
SessionKey: "session-1",
|
|
InboundContext: &bus.InboundContext{
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
ChatType: "direct",
|
|
SenderID: "hook-user",
|
|
},
|
|
RouteResult: &routing.ResolvedRoute{
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
AccountID: routing.DefaultAccountID,
|
|
SessionPolicy: routing.SessionPolicy{
|
|
Dimensions: []string{"sender"},
|
|
},
|
|
MatchedBy: "default",
|
|
},
|
|
SessionScope: &session.SessionScope{
|
|
Version: session.ScopeVersionV1,
|
|
AgentID: "main",
|
|
Channel: "cli",
|
|
Account: routing.DefaultAccountID,
|
|
Dimensions: []string{"sender"},
|
|
Values: map[string]string{
|
|
"sender": "hook-user",
|
|
},
|
|
},
|
|
UserMessage: "/btw hello",
|
|
},
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
SenderID: "hook-user",
|
|
SenderDisplayName: "Hook User",
|
|
})
|
|
if !handled {
|
|
t.Fatal("expected /btw command to be handled")
|
|
}
|
|
if response != "hooked content" {
|
|
t.Fatalf("expected hooked content, got %q", response)
|
|
}
|
|
|
|
provider.mu.Lock()
|
|
lastModel := provider.lastModel
|
|
provider.mu.Unlock()
|
|
if lastModel != "hook-model" {
|
|
t.Fatalf("expected model hook-model, got %q", lastModel)
|
|
}
|
|
if hook.lastInbound == nil {
|
|
t.Fatal("expected hook to receive inbound context")
|
|
}
|
|
if hook.lastInbound.Channel != "cli" || hook.lastInbound.SenderID != "hook-user" {
|
|
t.Fatalf("hook inbound context = %+v", hook.lastInbound)
|
|
}
|
|
if hook.lastInbound.ChatID != "direct" {
|
|
t.Fatalf("hook inbound chat ID = %q, want direct", hook.lastInbound.ChatID)
|
|
}
|
|
if hook.lastRoute == nil || hook.lastRoute.AgentID != "main" {
|
|
t.Fatalf("expected hook route context for /btw, got %+v", hook.lastRoute)
|
|
}
|
|
if hook.lastScope == nil || hook.lastScope.Values["sender"] != "hook-user" {
|
|
t.Fatalf("expected hook session scope for /btw, got %+v", hook.lastScope)
|
|
}
|
|
}
|
|
|
|
type toolHookProvider struct {
|
|
mu sync.Mutex
|
|
calls int
|
|
}
|
|
|
|
func (p *toolHookProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.calls++
|
|
if p.calls == 1 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: []providers.ToolCall{
|
|
{
|
|
ID: "call-1",
|
|
Name: "echo_text",
|
|
Arguments: map[string]any{"text": "original"},
|
|
},
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
last := messages[len(messages)-1]
|
|
return &providers.LLMResponse{
|
|
Content: last.Content,
|
|
}, nil
|
|
}
|
|
|
|
func (p *toolHookProvider) GetDefaultModel() string {
|
|
return "tool-hook-provider"
|
|
}
|
|
|
|
type echoTextTool struct{}
|
|
|
|
func (t *echoTextTool) Name() string {
|
|
return "echo_text"
|
|
}
|
|
|
|
func (t *echoTextTool) Description() string {
|
|
return "echo a text argument"
|
|
}
|
|
|
|
func (t *echoTextTool) Parameters() map[string]any {
|
|
return map[string]any{
|
|
"type": "object",
|
|
"properties": map[string]any{
|
|
"text": map[string]any{
|
|
"type": "string",
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
|
text, _ := args["text"].(string)
|
|
return tools.SilentResult(text)
|
|
}
|
|
|
|
type toolRewriteHook struct{}
|
|
|
|
func (h *toolRewriteHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
next := call.Clone()
|
|
next.Arguments["text"] = "modified"
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func (h *toolRewriteHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
next := result.Clone()
|
|
next.Result.ForLLM = "after:" + next.Result.ForLLM
|
|
return next, HookDecision{Action: HookActionModify}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
if resp != "after:modified" {
|
|
t.Fatalf("expected rewritten tool result, got %q", resp)
|
|
}
|
|
}
|
|
|
|
type denyApprovalHook struct{}
|
|
|
|
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
|
return ApprovalDecision{
|
|
Approved: false,
|
|
Reason: "blocked",
|
|
}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
expected := "Tool execution denied by approval hook: blocked"
|
|
if resp != expected {
|
|
t.Fatalf("expected %q, got %q", expected, resp)
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
|
if !ok {
|
|
t.Fatal("expected tool skipped event")
|
|
}
|
|
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
|
}
|
|
if payload.Reason != expected {
|
|
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
|
}
|
|
}
|
|
|
|
// respondHook is a test hook for testing HookActionRespond functionality
|
|
type respondHook struct {
|
|
respondTools map[string]bool // tool names to respond to
|
|
}
|
|
|
|
func (h *respondHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
if h.respondTools[call.Tool] {
|
|
next := call.Clone()
|
|
next.HookResult = &tools.ToolResult{
|
|
ForLLM: "hook-responded: " + call.Tool,
|
|
ForUser: "",
|
|
Silent: false,
|
|
IsError: false,
|
|
}
|
|
return next, HookDecision{Action: HookActionRespond}, nil
|
|
}
|
|
return call, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func (h *respondHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
// Should not be called since respond skips tool execution
|
|
return result, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
|
|
respondTools: map[string]bool{"echo_text": true},
|
|
})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
// Verify response comes from hook, not tool
|
|
expected := "hook-responded: echo_text"
|
|
if resp != expected {
|
|
t.Fatalf("expected %q, got %q", expected, resp)
|
|
}
|
|
|
|
// Verify event stream has ToolExecEnd, not actual tool execution
|
|
events := collectEventStream(sub.C)
|
|
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
|
if !ok {
|
|
t.Fatal("expected tool exec end event")
|
|
}
|
|
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
|
}
|
|
if payload.Tool != "echo_text" {
|
|
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
|
|
}
|
|
if payload.ForLLMLen != len(expected) {
|
|
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
|
|
}
|
|
}
|
|
|
|
// denyToolHook tests HookActionDenyTool functionality
|
|
type denyToolHook struct {
|
|
denyTools map[string]bool
|
|
}
|
|
|
|
func (h *denyToolHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
if h.denyTools[call.Tool] {
|
|
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
|
|
}
|
|
return call, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func (h *denyToolHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
return result, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
|
|
provider := &toolHookProvider{}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&echoTextTool{})
|
|
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
|
|
denyTools: map[string]bool{"echo_text": true},
|
|
})); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-1",
|
|
Channel: "cli",
|
|
ChatID: "direct",
|
|
UserMessage: "run tool",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
expected := "Tool execution denied by hook: tool denied by hook"
|
|
if resp != expected {
|
|
t.Fatalf("expected %q, got %q", expected, resp)
|
|
}
|
|
}
|
|
|
|
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
|
|
hm := NewHookManager(nil)
|
|
defer hm.Close()
|
|
|
|
hook := &respondHook{
|
|
respondTools: map[string]bool{"test_tool": true},
|
|
}
|
|
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
|
|
t.Fatalf("mount hook: %v", err)
|
|
}
|
|
|
|
req := &ToolCallHookRequest{
|
|
Tool: "test_tool",
|
|
Arguments: map[string]any{"arg": "value"},
|
|
}
|
|
result, decision := hm.BeforeTool(context.Background(), req)
|
|
|
|
if decision.Action != HookActionRespond {
|
|
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
|
|
}
|
|
|
|
if result.HookResult == nil {
|
|
t.Fatal("expected HookResult to be set")
|
|
}
|
|
if result.HookResult.ForLLM != "hook-responded: test_tool" {
|
|
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
|
|
}
|
|
}
|
|
|
|
type respondWithMediaHook struct {
|
|
respondTools map[string]bool
|
|
media []string
|
|
responseHandled bool
|
|
forLLM string
|
|
}
|
|
|
|
func (h *respondWithMediaHook) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, error) {
|
|
if h.respondTools[call.Tool] {
|
|
next := call.Clone()
|
|
next.HookResult = &tools.ToolResult{
|
|
ForLLM: h.forLLM,
|
|
ForUser: "media result",
|
|
Media: h.media,
|
|
ResponseHandled: h.responseHandled,
|
|
Silent: false,
|
|
IsError: false,
|
|
}
|
|
return next, HookDecision{Action: HookActionRespond}, nil
|
|
}
|
|
return call, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
func (h *respondWithMediaHook) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, error) {
|
|
return result, HookDecision{Action: HookActionContinue}, nil
|
|
}
|
|
|
|
type errorMediaChannel struct {
|
|
fakeChannel
|
|
sendErr error
|
|
}
|
|
|
|
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
|
return nil, f.sendErr
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &respondWithMediaHook{
|
|
respondTools: map[string]bool{"media_tool": true},
|
|
media: []string{"media://test/image.png"},
|
|
responseHandled: true,
|
|
forLLM: "media sent successfully",
|
|
}
|
|
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
al.channelManager = newStartedTestChannelManager(t,
|
|
al.bus.(*bus.MessageBus), al.mediaStore, "discord", &errorMediaChannel{
|
|
sendErr: errors.New("channel unavailable"),
|
|
})
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-media-err",
|
|
Channel: "discord",
|
|
ChatID: "chat1",
|
|
UserMessage: "send media",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
|
if !ok {
|
|
t.Fatal("expected ToolExecEnd event")
|
|
}
|
|
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
|
}
|
|
|
|
if !payload.IsError {
|
|
t.Fatal("expected IsError=true when SendMedia fails")
|
|
}
|
|
|
|
if payload.ForLLMLen < 30 {
|
|
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, agent, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
hook := &respondWithMediaHook{
|
|
respondTools: map[string]bool{"media_tool": true},
|
|
media: []string{"media://test/image.png"},
|
|
responseHandled: true,
|
|
forLLM: "media queued",
|
|
}
|
|
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
sub := al.SubscribeEvents(16)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
|
SessionKey: "session-bus-fallback",
|
|
Channel: "cli",
|
|
ChatID: "chat1",
|
|
UserMessage: "send media",
|
|
DefaultResponse: defaultResponse,
|
|
EnableSummary: false,
|
|
SendResponse: false,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("runAgentLoop failed: %v", err)
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
|
if !ok {
|
|
t.Fatal("expected ToolExecEnd event")
|
|
}
|
|
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
|
}
|
|
|
|
if payload.IsError {
|
|
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
|
|
}
|
|
|
|
if resp != "done" {
|
|
t.Fatalf("expected response 'done', got %q", resp)
|
|
}
|
|
}
|
|
|
|
type multiToolProvider struct {
|
|
mu sync.Mutex
|
|
callCount int
|
|
toolCalls []providers.ToolCall
|
|
finalContent string
|
|
}
|
|
|
|
func (p *multiToolProvider) Chat(
|
|
ctx context.Context,
|
|
messages []providers.Message,
|
|
tools []providers.ToolDefinition,
|
|
model string,
|
|
opts map[string]any,
|
|
) (*providers.LLMResponse, error) {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
p.callCount++
|
|
if p.callCount == 1 && len(p.toolCalls) > 0 {
|
|
return &providers.LLMResponse{
|
|
ToolCalls: p.toolCalls,
|
|
}, nil
|
|
}
|
|
|
|
return &providers.LLMResponse{
|
|
Content: p.finalContent,
|
|
}, nil
|
|
}
|
|
|
|
func (p *multiToolProvider) GetDefaultModel() string {
|
|
return "multi-tool-provider"
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
|
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
|
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, _, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
tool1ExecCh := make(chan struct{}, 1)
|
|
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
|
|
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
|
|
|
hook := &respondHook{
|
|
respondTools: map[string]bool{"tool_one": true},
|
|
}
|
|
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
sub := al.SubscribeEvents(32)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
|
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"run tools",
|
|
sessionKey,
|
|
"cli",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp: resp, err: err}
|
|
}()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
if err := al.InterruptGraceful("stop now"); err != nil {
|
|
t.Fatalf("InterruptGraceful failed: %v", err)
|
|
}
|
|
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout waiting for result")
|
|
}
|
|
|
|
events := collectEventStream(sub.C)
|
|
|
|
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
|
if len(skippedEvts) < 1 {
|
|
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
|
|
}
|
|
|
|
for _, evt := range skippedEvts {
|
|
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
|
}
|
|
if payload.Reason != "graceful interrupt requested" {
|
|
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
|
provider := &multiToolProvider{
|
|
toolCalls: []providers.ToolCall{
|
|
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
|
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
|
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
|
},
|
|
finalContent: "done",
|
|
}
|
|
al, _, cleanup := newHookTestLoop(t, provider)
|
|
defer cleanup()
|
|
|
|
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
|
|
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
|
|
|
hook := &respondHook{
|
|
respondTools: map[string]bool{"tool_one": true},
|
|
}
|
|
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
|
t.Fatalf("MountHook failed: %v", err)
|
|
}
|
|
|
|
sub := al.SubscribeEvents(32)
|
|
defer al.UnsubscribeEvents(sub.ID)
|
|
|
|
sessionKey := session.BuildMainSessionKey(routing.DefaultAgentID)
|
|
|
|
type result struct {
|
|
resp string
|
|
err error
|
|
}
|
|
resultCh := make(chan result, 1)
|
|
go func() {
|
|
resp, err := al.ProcessDirectWithChannel(
|
|
context.Background(),
|
|
"run tools",
|
|
sessionKey,
|
|
"cli",
|
|
"chat1",
|
|
)
|
|
resultCh <- result{resp: resp, err: err}
|
|
}()
|
|
|
|
collectedEvents := make([]Event, 0, 8)
|
|
steered := false
|
|
deadline := time.After(3 * time.Second)
|
|
for !steered {
|
|
select {
|
|
case evt := <-sub.C:
|
|
collectedEvents = append(collectedEvents, evt)
|
|
if evt.Kind != EventKindToolExecEnd {
|
|
continue
|
|
}
|
|
payload, ok := evt.Payload.(ToolExecEndPayload)
|
|
if !ok || payload.Tool != "tool_one" {
|
|
continue
|
|
}
|
|
al.Steer(providers.Message{Role: "user", Content: "change direction"})
|
|
steered = true
|
|
case <-deadline:
|
|
t.Fatal("timeout waiting for tool_one to finish before steering")
|
|
}
|
|
}
|
|
|
|
select {
|
|
case r := <-resultCh:
|
|
if r.err != nil {
|
|
t.Fatalf("unexpected error: %v", r.err)
|
|
}
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("timeout waiting for result")
|
|
}
|
|
|
|
events := append(collectedEvents, collectEventStream(sub.C)...)
|
|
|
|
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
|
if len(skippedEvts) < 1 {
|
|
t.Fatal("expected at least one ToolExecSkipped event after steering")
|
|
}
|
|
|
|
for _, evt := range skippedEvts {
|
|
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
|
if !ok {
|
|
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
|
}
|
|
if payload.Reason != "queued user steering message" {
|
|
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
|
|
}
|
|
}
|
|
}
|
|
|
|
func filterEvents(events []Event, kind EventKind) []Event {
|
|
var result []Event
|
|
for _, evt := range events {
|
|
if evt.Kind == kind {
|
|
result = append(result, evt)
|
|
}
|
|
}
|
|
return result
|
|
}
|