mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Enhance hooks with respond action and comprehensive documentation (#2215)
* feat(hooks): add respond action for tool execution bypass Add a new HookActionRespond that allows hooks to return tool results directly, skipping actual tool execution. This enables plugin tool injection, caching, and mocking capabilities. - Add HookActionRespond constant and support in HookManager - Extend ToolCallHookRequest with HookResult field - Implement respond action handling in process hooks and agent loop - Add comprehensive tests for respond and deny_tool actions - Update documentation with hook actions table and examples * docs(hooks): add JSON-RPC protocol and plugin tool injection documentation Add comprehensive documentation for hook JSON-RPC protocol and plugin tool injection capabilities: - Add "Hook Actions" section to README.zh.md explaining respond action for tool execution bypass - Create hook-json-protocol.md/.zh.md detailing JSON-RPC 2.0 protocol for all hook methods - Create plugin-tool-injection.md/.zh.md with complete examples for external tool implementation - Document how hooks can inject tool definitions and return results via respond action - Include Python and Go examples for weather query plugin implementation * feat(agent): emit tool events and feedback for hook results Add ToolExecStart event emission and tool feedback for hook results to ensure consistent behavior between normal tool execution and hook bypass scenarios. This maintains parity in event tracking and user feedback when tools are executed via hooks. * style(agent): format whitespace in hook structs and constants Remove trailing whitespace and standardize spacing in JSON struct tags, constants, and test data for improved code consistency. * feat(hooks): add media support for plugin tool injection Extend the hook respond action to support media file handling: - Add `media` field for returning images and files from hooks - Add `response_handled` field to control turn completion behavior - When response_handled=true, media is automatically delivered to user - When response_handled=false, media is passed to LLM for vision requests This enables plugins to directly return generated images, downloaded files, and other media content either to users or for LLM analysis. * docs(hooks): document security implications of respond action Add security boundary documentation explaining that the respond action bypasses ApproveTool checks, allowing hooks to return results for any tool without approval. Include recommendations for secure hook implementation and code comments marking the security considerations. Changes: - Add "Security Boundaries" section to plugin-tool-injection docs - Document bypass of approval checks and associated risks - Provide security recommendations and example code - Add inline security comments in hooks.go and loop.go * refactor(agent): improve completeness of tool result cloning and hook processing Extend cloneToolResult to properly copy ArtifactTags and Messages fields, ensuring deep copies of all ToolResult data. Consolidate event emission and user message handling to match the normal tool execution flow. * fix(agent): align hook respond path with normal tool execution flow The hook respond code path was missing several critical behaviors that existed in normal tool execution: - Add logging for tool calls with arguments preview - Add is_tool_call metadata to user-facing messages - Handle attachment delivery failures by setting error state and notifying LLM - Set ResponseHandled=false when using bus for media delivery - Check for steering messages and graceful interrupts after tool execution, skipping remaining tools when appropriate - Poll for SubTurn results that arrived during tool execution This ensures consistent behavior between hook-responded tool calls and normally executed tool calls. * test(agent): add tests for hook respond media error handling Add comprehensive tests for the hook respond code path when media delivery fails. Tests cover error media channel scenarios and verify proper error state handling. Also document that AfterTool is not called when using respond action, as it provides the final answer directly (design decision).
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -90,7 +91,8 @@ type processHookAfterLLMResponse struct {
|
||||
|
||||
type processHookBeforeToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"` // Result returned directly by hook (for respond action)
|
||||
}
|
||||
|
||||
type processHookAfterToolResponse struct {
|
||||
@@ -241,6 +243,10 @@ func (ph *ProcessHook) BeforeTool(
|
||||
if resp.Call == nil {
|
||||
resp.Call = call
|
||||
}
|
||||
// If hook returned a Result, carry it in ToolCallHookRequest
|
||||
if resp.Result != nil {
|
||||
resp.Call.HookResult = resp.Result
|
||||
}
|
||||
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
|
||||
+19
-5
@@ -25,6 +25,7 @@ type HookAction string
|
||||
const (
|
||||
HookActionContinue HookAction = "continue"
|
||||
HookActionModify HookAction = "modify"
|
||||
HookActionRespond HookAction = "respond" // Return result directly, skip tool execution. SECURITY: This bypasses ApproveTool checks, allowing hooks to return results for any tool (including sensitive ones like bash) without approval. Use with caution.
|
||||
HookActionDenyTool HookAction = "deny_tool"
|
||||
HookActionAbortTurn HookAction = "abort_turn"
|
||||
HookActionHardAbort HookAction = "hard_abort"
|
||||
@@ -127,11 +128,12 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
HookResult *tools.ToolResult `json:"hook_result,omitempty"` // Result returned directly by hook (for respond action). Media is supported - see Media handling section in docs.
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
@@ -140,6 +142,7 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.HookResult = cloneToolResult(r.HookResult)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
@@ -382,6 +385,10 @@ func (hm *HookManager) BeforeTool(
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionRespond:
|
||||
// Hook returns result directly, skip tool execution
|
||||
// Carry HookResult in ToolCallHookRequest and return
|
||||
return next, decision
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
@@ -793,6 +800,13 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
if len(result.ArtifactTags) > 0 {
|
||||
cloned.ArtifactTags = append([]string(nil), result.ArtifactTags...)
|
||||
}
|
||||
if len(result.Messages) > 0 {
|
||||
cloned.Messages = make([]providers.Message, len(result.Messages))
|
||||
copy(cloned.Messages, result.Messages)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/routing"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
@@ -343,3 +345,518 @@ func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
// respondHook is a test hook for testing HookActionRespond functionality
|
||||
type respondHook struct {
|
||||
respondTools map[string]bool // tool names to respond to
|
||||
}
|
||||
|
||||
func (h *respondHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.respondTools[call.Tool] {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: "hook-responded: " + call.Tool,
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, HookDecision{Action: HookActionRespond}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *respondHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
// Should not be called since respond skips tool execution
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolRespondAction(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("respond-hook", &respondHook{
|
||||
respondTools: map[string]bool{"echo_text": true},
|
||||
})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify response comes from hook, not tool
|
||||
expected := "hook-responded: echo_text"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
// Verify event stream has ToolExecEnd, not actual tool execution
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected tool exec end event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
if payload.Tool != "echo_text" {
|
||||
t.Fatalf("expected tool echo_text, got %q", payload.Tool)
|
||||
}
|
||||
if payload.ForLLMLen != len(expected) {
|
||||
t.Fatalf("expected ForLLMLen %d, got %d", len(expected), payload.ForLLMLen)
|
||||
}
|
||||
}
|
||||
|
||||
// denyToolHook tests HookActionDenyTool functionality
|
||||
type denyToolHook struct {
|
||||
denyTools map[string]bool
|
||||
}
|
||||
|
||||
func (h *denyToolHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.denyTools[call.Tool] {
|
||||
return call, HookDecision{Action: HookActionDenyTool, Reason: "tool denied by hook"}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *denyToolHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolDenyAction(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-hook", &denyToolHook{
|
||||
denyTools: map[string]bool{"echo_text": true},
|
||||
})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "Tool execution denied by hook: tool denied by hook"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_BeforeTool_RespondAction(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
defer hm.Close()
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"test_tool": true},
|
||||
}
|
||||
if err := hm.Mount(NamedHook("respond-test", hook)); err != nil {
|
||||
t.Fatalf("mount hook: %v", err)
|
||||
}
|
||||
|
||||
req := &ToolCallHookRequest{
|
||||
Tool: "test_tool",
|
||||
Arguments: map[string]any{"arg": "value"},
|
||||
}
|
||||
result, decision := hm.BeforeTool(context.Background(), req)
|
||||
|
||||
if decision.Action != HookActionRespond {
|
||||
t.Fatalf("expected action %q, got %q", HookActionRespond, decision.Action)
|
||||
}
|
||||
|
||||
if result.HookResult == nil {
|
||||
t.Fatal("expected HookResult to be set")
|
||||
}
|
||||
if result.HookResult.ForLLM != "hook-responded: test_tool" {
|
||||
t.Fatalf("unexpected HookResult.ForLLM: %q", result.HookResult.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
type respondWithMediaHook struct {
|
||||
respondTools map[string]bool
|
||||
media []string
|
||||
responseHandled bool
|
||||
forLLM string
|
||||
sendMediaErr error
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if h.respondTools[call.Tool] {
|
||||
next := call.Clone()
|
||||
next.HookResult = &tools.ToolResult{
|
||||
ForLLM: h.forLLM,
|
||||
ForUser: "media result",
|
||||
Media: h.media,
|
||||
ResponseHandled: h.responseHandled,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
}
|
||||
return next, HookDecision{Action: HookActionRespond}, nil
|
||||
}
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
func (h *respondWithMediaHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
type errorMediaChannel struct {
|
||||
fakeChannel
|
||||
sendErr error
|
||||
}
|
||||
|
||||
func (f *errorMediaChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) ([]string, error) {
|
||||
return nil, f.sendErr
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_MediaError(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &respondWithMediaHook{
|
||||
respondTools: map[string]bool{"media_tool": true},
|
||||
media: []string{"media://test/image.png"},
|
||||
responseHandled: true,
|
||||
forLLM: "media sent successfully",
|
||||
}
|
||||
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
al.channelManager = newStartedTestChannelManager(t, al.bus, al.mediaStore, "discord", &errorMediaChannel{
|
||||
sendErr: errors.New("channel unavailable"),
|
||||
})
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
_, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-media-err",
|
||||
Channel: "discord",
|
||||
ChatID: "chat1",
|
||||
UserMessage: "send media",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
|
||||
if !payload.IsError {
|
||||
t.Fatal("expected IsError=true when SendMedia fails")
|
||||
}
|
||||
|
||||
if payload.ForLLMLen < 30 {
|
||||
t.Fatalf("expected ForLLM to contain error message, got ForLLMLen=%d", payload.ForLLMLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_BusFallback(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "media_tool", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &respondWithMediaHook{
|
||||
respondTools: map[string]bool{"media_tool": true},
|
||||
media: []string{"media://test/image.png"},
|
||||
responseHandled: true,
|
||||
forLLM: "media queued",
|
||||
}
|
||||
if err := al.MountHook(NamedHook("media-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-bus-fallback",
|
||||
Channel: "cli",
|
||||
ChatID: "chat1",
|
||||
UserMessage: "send media",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
endEvt, ok := findEvent(events, EventKindToolExecEnd)
|
||||
if !ok {
|
||||
t.Fatal("expected ToolExecEnd event")
|
||||
}
|
||||
payload, ok := endEvt.Payload.(ToolExecEndPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecEndPayload, got %T", endEvt.Payload)
|
||||
}
|
||||
|
||||
if payload.IsError {
|
||||
t.Fatal("expected IsError=false for bus fallback (media queued, not delivered)")
|
||||
}
|
||||
|
||||
if resp != "done" {
|
||||
t.Fatalf("expected response 'done', got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type multiToolProvider struct {
|
||||
mu sync.Mutex
|
||||
callCount int
|
||||
toolCalls []providers.ToolCall
|
||||
finalContent string
|
||||
}
|
||||
|
||||
func (p *multiToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.callCount++
|
||||
if p.callCount == 1 && len(p.toolCalls) > 0 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: p.toolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: p.finalContent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *multiToolProvider) GetDefaultModel() string {
|
||||
return "multi-tool-provider"
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_InterruptSkipsRemaining(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
||||
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
||||
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, _, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
tool1ExecCh := make(chan struct{}, 1)
|
||||
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond, execCh: tool1ExecCh})
|
||||
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"tool_one": true},
|
||||
}
|
||||
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"run tools",
|
||||
sessionKey,
|
||||
"cli",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if err := al.InterruptGraceful("stop now"); err != nil {
|
||||
t.Fatalf("InterruptGraceful failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after interrupt")
|
||||
}
|
||||
|
||||
for _, evt := range skippedEvts {
|
||||
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
||||
}
|
||||
if payload.Reason != "graceful interrupt requested" {
|
||||
t.Fatalf("expected skip reason 'graceful interrupt requested', got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_HookRespond_SteeringSkipsRemaining(t *testing.T) {
|
||||
provider := &multiToolProvider{
|
||||
toolCalls: []providers.ToolCall{
|
||||
{ID: "call-1", Name: "tool_one", Arguments: map[string]any{}},
|
||||
{ID: "call-2", Name: "tool_two", Arguments: map[string]any{}},
|
||||
{ID: "call-3", Name: "tool_three", Arguments: map[string]any{}},
|
||||
},
|
||||
finalContent: "done",
|
||||
}
|
||||
al, _, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&slowTool{name: "tool_two", duration: 100 * time.Millisecond})
|
||||
al.RegisterTool(&slowTool{name: "tool_three", duration: 100 * time.Millisecond})
|
||||
|
||||
hook := &respondHook{
|
||||
respondTools: map[string]bool{"tool_one": true},
|
||||
}
|
||||
if err := al.MountHook(NamedHook("respond-hook", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(32)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
|
||||
|
||||
type result struct {
|
||||
resp string
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan result, 1)
|
||||
go func() {
|
||||
resp, err := al.ProcessDirectWithChannel(
|
||||
context.Background(),
|
||||
"run tools",
|
||||
sessionKey,
|
||||
"cli",
|
||||
"chat1",
|
||||
)
|
||||
resultCh <- result{resp: resp, err: err}
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
al.Steer(providers.Message{Role: "user", Content: "change direction"})
|
||||
|
||||
select {
|
||||
case r := <-resultCh:
|
||||
if r.err != nil {
|
||||
t.Fatalf("unexpected error: %v", r.err)
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for result")
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
|
||||
skippedEvts := filterEvents(events, EventKindToolExecSkipped)
|
||||
if len(skippedEvts) < 1 {
|
||||
t.Fatal("expected at least one ToolExecSkipped event after steering")
|
||||
}
|
||||
|
||||
for _, evt := range skippedEvts {
|
||||
payload, ok := evt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", evt.Payload)
|
||||
}
|
||||
if payload.Reason != "queued user steering message" {
|
||||
t.Fatalf("expected skip reason 'queued user steering message', got %q", payload.Reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func filterEvents(events []Event, kind EventKind) []Event {
|
||||
var result []Event
|
||||
for _, evt := range events {
|
||||
if evt.Kind == kind {
|
||||
result = append(result, evt)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -2352,6 +2352,236 @@ turnLoop:
|
||||
toolName = toolReq.Tool
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionRespond:
|
||||
// Hook returns result directly, skip tool execution.
|
||||
// SECURITY: This bypasses ApproveTool, allowing hooks to respond
|
||||
// for any tool name without approval. This is intentional for
|
||||
// plugin tools but means a before_tool hook can override even
|
||||
// sensitive tools like bash. Hook configuration should be
|
||||
// carefully reviewed to prevent unauthorized tool execution.
|
||||
if toolReq != nil && toolReq.HookResult != nil {
|
||||
hookResult := toolReq.HookResult
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call (hook respond): %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Emit ToolExecStart event (same as normal tool execution)
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
// Send tool feedback to chat channel if enabled (same as normal tool execution)
|
||||
if al.cfg.Agents.Defaults.IsToolFeedbackEnabled() &&
|
||||
ts.channel != "" &&
|
||||
!ts.opts.SuppressToolFeedback {
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
feedbackPreview := utils.Truncate(
|
||||
string(argsJSON),
|
||||
al.cfg.Agents.Defaults.GetToolFeedbackMaxArgsLength(),
|
||||
)
|
||||
feedbackMsg := fmt.Sprintf("\U0001f527 `%s`\n```\n%s\n```", toolName, feedbackPreview)
|
||||
fbCtx, fbCancel := context.WithTimeout(turnCtx, 3*time.Second)
|
||||
_ = al.bus.PublishOutbound(fbCtx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: feedbackMsg,
|
||||
})
|
||||
fbCancel()
|
||||
}
|
||||
|
||||
toolDuration := time.Duration(0) // Hook execution time unknown
|
||||
|
||||
// Send ForUser content to user
|
||||
// For ResponseHandled results, send regardless of SendResponse setting,
|
||||
// same as normal tool execution path.
|
||||
shouldSendForUser := !hookResult.Silent && hookResult.ForUser != "" &&
|
||||
(ts.opts.SendResponse || hookResult.ResponseHandled)
|
||||
if shouldSendForUser {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Content: hookResult.ForUser,
|
||||
Metadata: map[string]string{
|
||||
"is_tool_call": "true",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Handle media from hook result (same as normal tool execution)
|
||||
if len(hookResult.Media) > 0 && hookResult.ResponseHandled {
|
||||
parts := make([]bus.MediaPart, 0, len(hookResult.Media))
|
||||
for _, ref := range hookResult.Media {
|
||||
part := bus.MediaPart{Ref: ref}
|
||||
if al.mediaStore != nil {
|
||||
if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil {
|
||||
part.Filename = meta.Filename
|
||||
part.ContentType = meta.ContentType
|
||||
part.Type = inferMediaType(meta.Filename, meta.ContentType)
|
||||
}
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
outboundMedia := bus.OutboundMediaMessage{
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
Parts: parts,
|
||||
}
|
||||
if al.channelManager != nil && ts.channel != "" && !constants.IsInternalChannel(ts.channel) {
|
||||
if err := al.channelManager.SendMedia(ctx, outboundMedia); err != nil {
|
||||
logger.WarnCF("agent", "Failed to deliver hook media",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"channel": ts.channel,
|
||||
"chat_id": ts.chatID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
// Same as normal tool execution: notify LLM about delivery failure
|
||||
hookResult.IsError = true
|
||||
hookResult.ForLLM = fmt.Sprintf("failed to deliver attachment: %v", err)
|
||||
}
|
||||
} else if al.bus != nil {
|
||||
al.bus.PublishOutboundMedia(ctx, outboundMedia)
|
||||
// Same as normal tool execution: bus only queues, media not yet delivered
|
||||
hookResult.ResponseHandled = false
|
||||
}
|
||||
}
|
||||
|
||||
// Track response handling status (same as normal tool execution)
|
||||
if !hookResult.ResponseHandled {
|
||||
allResponsesHandled = false
|
||||
}
|
||||
|
||||
// Build tool message
|
||||
contentForLLM := hookResult.ContentForLLM()
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
|
||||
// Handle media for LLM vision (same as normal tool execution)
|
||||
if len(hookResult.Media) > 0 && !hookResult.ResponseHandled {
|
||||
hookResult.ArtifactTags = buildArtifactTags(al.mediaStore, hookResult.Media)
|
||||
// Recalculate contentForLLM after adding ArtifactTags
|
||||
contentForLLM = hookResult.ContentForLLM()
|
||||
if al.cfg.Tools.IsFilterSensitiveDataEnabled() {
|
||||
contentForLLM = al.cfg.FilterSensitiveData(contentForLLM)
|
||||
}
|
||||
toolResultMsg.Content = contentForLLM
|
||||
toolResultMsg.Media = append(toolResultMsg.Media, hookResult.Media...)
|
||||
}
|
||||
|
||||
// Emit ToolExecEnd event (after filtering, same as normal tool execution)
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(hookResult.ForUser),
|
||||
IsError: hookResult.IsError,
|
||||
Async: hookResult.Async,
|
||||
},
|
||||
)
|
||||
|
||||
messages = append(messages, toolResultMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
|
||||
ts.recordPersistedMessage(toolResultMsg)
|
||||
ts.ingestMessage(turnCtx, al, toolResultMsg)
|
||||
}
|
||||
|
||||
// Same as normal tool execution: check for steering/interrupt/SubTurn after each tool
|
||||
if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
|
||||
pendingMessages = append(pendingMessages, steerMsgs...)
|
||||
}
|
||||
|
||||
skipReason := ""
|
||||
skipMessage := ""
|
||||
if len(pendingMessages) > 0 {
|
||||
skipReason = "queued user steering message"
|
||||
skipMessage = "Skipped due to queued user message."
|
||||
} else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending {
|
||||
skipReason = "graceful interrupt requested"
|
||||
skipMessage = "Skipped due to graceful interrupt."
|
||||
}
|
||||
|
||||
if skipReason != "" {
|
||||
remaining := len(normalizedToolCalls) - i - 1
|
||||
if remaining > 0 {
|
||||
logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools after hook respond",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"completed": i + 1,
|
||||
"skipped": remaining,
|
||||
"reason": skipReason,
|
||||
})
|
||||
for j := i + 1; j < len(normalizedToolCalls); j++ {
|
||||
skippedTC := normalizedToolCalls[j]
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: skippedTC.Name,
|
||||
Reason: skipReason,
|
||||
},
|
||||
)
|
||||
skippedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: skipMessage,
|
||||
ToolCallID: skippedTC.ID,
|
||||
}
|
||||
messages = append(messages, skippedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg)
|
||||
ts.recordPersistedMessage(skippedMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Also poll for any SubTurn results that arrived during tool execution.
|
||||
if ts.pendingResults != nil {
|
||||
select {
|
||||
case result, ok := <-ts.pendingResults:
|
||||
if ok && result != nil && result.ForLLM != "" {
|
||||
content := al.cfg.FilterSensitiveData(result.ForLLM)
|
||||
msg := providers.Message{Role: "user", Content: fmt.Sprintf("[SubTurn Result] %s", content)}
|
||||
messages = append(messages, msg)
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, msg)
|
||||
}
|
||||
default:
|
||||
// No results available
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
// If no HookResult, fall back to continue with warning
|
||||
logger.WarnCF("agent", "Hook returned respond action but no HookResult provided",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": toolName,
|
||||
"action": "respond",
|
||||
})
|
||||
case HookActionDenyTool:
|
||||
allResponsesHandled = false
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
|
||||
Reference in New Issue
Block a user