feat(agent): add hook manager foundation

This commit is contained in:
Hoshina
2026-03-21 19:15:10 +08:00
parent 73a683fd16
commit cf68c91eca
4 changed files with 1801 additions and 47 deletions
+751
View File
@@ -0,0 +1,751 @@
package agent
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// Defaults used by NewHookManager when it builds a HookManager.
const (
// defaultHookObserverTimeout is the initial time limit for one observer call.
defaultHookObserverTimeout = 500 * time.Millisecond
// defaultHookInterceptorTimeout is the initial time limit for one interceptor call.
defaultHookInterceptorTimeout = 5 * time.Second
// defaultHookApprovalTimeout is the initial time limit for one approval call.
defaultHookApprovalTimeout = 60 * time.Second
// hookObserverBufferSize is the buffer size passed to EventBus.Subscribe.
hookObserverBufferSize = 64
)
// HookAction names what a hook asks the manager to do after it has run.
type HookAction string

// Actions a hook can return in a HookDecision. Which actions a stage honors
// is decided by the HookManager method for that stage; any other value is
// logged as unsupported and ignored.
const (
	HookActionContinue  HookAction = "continue"
	HookActionModify    HookAction = "modify"
	HookActionDenyTool  HookAction = "deny_tool"
	HookActionAbortTurn HookAction = "abort_turn"
	HookActionHardAbort HookAction = "hard_abort"
)

// HookDecision is the verdict returned by an interceptor hook: the requested
// action plus an optional free-form reason.
type HookDecision struct {
	Action HookAction
	Reason string
}

// normalizedAction returns the decision's action, mapping the empty (zero
// value) action to HookActionContinue.
func (d HookDecision) normalizedAction() HookAction {
	if d.Action != "" {
		return d.Action
	}
	return HookActionContinue
}
// ApprovalDecision is the result of a tool approval check: whether the tool
// call may run, plus an optional reason.
type ApprovalDecision struct {
Approved bool
Reason string
}
// HookRegistration describes one hook to mount on a HookManager.
// Name identifies the hook (mounting the same name again replaces it),
// Priority orders hooks ascending, and Hook is the implementation, which is
// type-asserted against the hook interfaces when a stage runs.
type HookRegistration struct {
	Name     string
	Priority int
	Hook     any
}

// NamedHook builds a HookRegistration for hook under name, leaving Priority
// at its zero value.
func NamedHook(name string, hook any) HookRegistration {
	var reg HookRegistration
	reg.Name = name
	reg.Hook = hook
	return reg
}
// EventObserver is implemented by hooks that want to see events. The
// HookManager calls OnEvent for every event it reads from its event-bus
// subscription; a returned error is only logged.
type EventObserver interface {
OnEvent(ctx context.Context, evt Event) error
}
// LLMInterceptor is implemented by hooks that run around an LLM call.
// Each method receives a clone of the current request/response and returns
// the value to carry forward (nil keeps the current one), a decision, and an
// error; the manager skips a hook whose call returns an error.
type LLMInterceptor interface {
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
}
// ToolInterceptor is implemented by hooks that run around a tool execution.
// It follows the same contract as LLMInterceptor, applied to tool calls and
// tool results.
type ToolInterceptor interface {
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
}
// ToolApprover is implemented by hooks that can approve or deny a tool call.
// The manager treats an error from ApproveTool as a denial.
type ToolApprover interface {
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
}
// LLMHookRequest is the view of a pending LLM call handed to BeforeLLM hooks.
type LLMHookRequest struct {
	Meta             EventMeta
	Model            string
	Messages         []providers.Message
	Tools            []providers.ToolDefinition
	Options          map[string]any
	Channel          string
	ChatID           string
	GracefulTerminal bool
}

// Clone returns a copy of r whose Messages, Tools and Options containers are
// duplicated by the package clone helpers; scalar fields are copied by value.
// A nil receiver yields nil.
func (r *LLMHookRequest) Clone() *LLMHookRequest {
	if r == nil {
		return nil
	}
	dup := new(LLMHookRequest)
	*dup = *r
	dup.Messages = cloneProviderMessages(r.Messages)
	dup.Tools = cloneToolDefinitions(r.Tools)
	dup.Options = cloneStringAnyMap(r.Options)
	return dup
}
// LLMHookResponse is the view of a completed LLM call handed to AfterLLM hooks.
type LLMHookResponse struct {
	Meta     EventMeta
	Model    string
	Response *providers.LLMResponse
	Channel  string
	ChatID   string
}

// Clone returns a copy of r with Response duplicated via cloneLLMResponse.
// A nil receiver yields nil.
func (r *LLMHookResponse) Clone() *LLMHookResponse {
	if r == nil {
		return nil
	}
	dup := new(LLMHookResponse)
	*dup = *r
	dup.Response = cloneLLMResponse(r.Response)
	return dup
}
// ToolCallHookRequest is the view of a pending tool call handed to BeforeTool hooks.
type ToolCallHookRequest struct {
	Meta      EventMeta
	Tool      string
	Arguments map[string]any
	Channel   string
	ChatID    string
}

// Clone returns a copy of r with its own top-level Arguments map.
// A nil receiver yields nil.
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
	if r == nil {
		return nil
	}
	dup := new(ToolCallHookRequest)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	return dup
}
// ToolApprovalRequest is the view of a tool call handed to ToolApprover hooks.
type ToolApprovalRequest struct {
	Meta      EventMeta
	Tool      string
	Arguments map[string]any
	Channel   string
	ChatID    string
}

// Clone returns a copy of r with its own top-level Arguments map.
// A nil receiver yields nil.
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
	if r == nil {
		return nil
	}
	dup := new(ToolApprovalRequest)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	return dup
}
// ToolResultHookResponse is the view of a finished tool execution handed to
// AfterTool hooks.
type ToolResultHookResponse struct {
	Meta      EventMeta
	Tool      string
	Arguments map[string]any
	Result    *tools.ToolResult
	Duration  time.Duration
	Channel   string
	ChatID    string
}

// Clone returns a copy of r with its own top-level Arguments map and a
// Result duplicated via cloneToolResult. A nil receiver yields nil.
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
	if r == nil {
		return nil
	}
	dup := new(ToolResultHookResponse)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	dup.Result = cloneToolResult(r.Result)
	return dup
}
// HookManager holds the mounted hooks and runs them: observers are fed from
// an event-bus subscription, while interceptors and approvers are invoked
// through the BeforeLLM/AfterLLM/BeforeTool/AfterTool/ApproveTool methods.
type HookManager struct {
// eventBus is the bus the manager subscribes to; it may be nil.
eventBus *EventBus
// Time limits applied to a single observer, interceptor and approval call.
observerTimeout time.Duration
interceptorTimeout time.Duration
approvalTimeout time.Duration
// mu guards hooks and ordered.
mu sync.RWMutex
// hooks maps hook name to its registration.
hooks map[string]HookRegistration
// ordered is hooks sorted by Priority, then Name; rebuilt on Mount/Unmount.
ordered []HookRegistration
// sub is the event-bus subscription read by dispatchEvents.
sub EventSubscription
// done is closed when dispatchEvents returns (or immediately when there is no bus).
done chan struct{}
// closeOnce makes Close run its shutdown at most once.
closeOnce sync.Once
}
// NewHookManager creates a HookManager with the default timeouts. With a
// non-nil eventBus it subscribes to the bus and starts the goroutine that
// forwards events to observers. With a nil bus no goroutine is started and
// done is closed up front so Close does not block.
func NewHookManager(eventBus *EventBus) *HookManager {
	manager := new(HookManager)
	manager.eventBus = eventBus
	manager.observerTimeout = defaultHookObserverTimeout
	manager.interceptorTimeout = defaultHookInterceptorTimeout
	manager.approvalTimeout = defaultHookApprovalTimeout
	manager.hooks = make(map[string]HookRegistration)
	manager.done = make(chan struct{})

	if eventBus != nil {
		manager.sub = eventBus.Subscribe(hookObserverBufferSize)
		go manager.dispatchEvents()
		return manager
	}

	close(manager.done)
	return manager
}
// Close unsubscribes the manager from its event bus (if it has one) and then
// blocks until done is closed. The shutdown runs at most once; calling Close
// on a nil manager is a no-op.
func (hm *HookManager) Close() {
	if hm == nil {
		return
	}
	shutdown := func() {
		if bus := hm.eventBus; bus != nil {
			bus.Unsubscribe(hm.sub.ID)
		}
		<-hm.done
	}
	hm.closeOnce.Do(shutdown)
}
// Mount registers reg under reg.Name, replacing any hook already mounted with
// that name, and rebuilds the priority order. It returns an error when the
// manager is nil, the name is empty, or reg.Hook is nil.
func (hm *HookManager) Mount(reg HookRegistration) error {
	switch {
	case hm == nil:
		return fmt.Errorf("hook manager is nil")
	case reg.Name == "":
		return fmt.Errorf("hook name is required")
	case reg.Hook == nil:
		return fmt.Errorf("hook %q is nil", reg.Name)
	}

	hm.mu.Lock()
	defer hm.mu.Unlock()

	hm.hooks[reg.Name] = reg
	hm.rebuildOrdered()
	return nil
}
// Unmount removes the hook registered under name and rebuilds the priority
// order. A nil manager or an empty name is a no-op.
func (hm *HookManager) Unmount(name string) {
	if hm != nil && name != "" {
		hm.mu.Lock()
		defer hm.mu.Unlock()

		delete(hm.hooks, name)
		hm.rebuildOrdered()
	}
}
// dispatchEvents reads events from the subscription until its channel is
// closed, handing each event to every mounted hook that implements
// EventObserver, one after another. It closes done on return.
func (hm *HookManager) dispatchEvents() {
	defer close(hm.done)

	for evt := range hm.sub.C {
		registrations := hm.snapshotHooks()
		for i := range registrations {
			entry := registrations[i]
			if obs, isObserver := entry.Hook.(EventObserver); isObserver {
				hm.runObserver(entry.Name, obs, evt)
			}
		}
	}
}
// BeforeLLM runs every mounted LLMInterceptor's BeforeLLM in priority order.
// Each hook gets a clone of the current request; on continue/modify a non-nil
// returned request becomes the current one. abort_turn and hard_abort stop
// the chain and are returned with the request as it stood. Hooks that fail
// or time out are skipped, and other actions are logged and ignored.
// A nil manager or nil req returns req unchanged with a continue decision.
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || req == nil {
		return req, proceed
	}

	working := req.Clone()
	for _, entry := range hm.snapshotHooks() {
		hook, isLLM := entry.Hook.(LLMInterceptor)
		if !isLLM {
			continue
		}
		updated, verdict, succeeded := hm.callBeforeLLM(ctx, entry.Name, hook, working.Clone())
		if !succeeded {
			continue
		}

		action := verdict.normalizedAction()
		if action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, verdict
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(entry.Name, "before_llm", verdict.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, proceed
}
// AfterLLM runs every mounted LLMInterceptor's AfterLLM in priority order.
// Each hook gets a clone of the current response; on continue/modify a
// non-nil returned response becomes the current one. abort_turn and
// hard_abort stop the chain and are returned with the response as it stood.
// Hooks that fail or time out are skipped, and other actions are logged and
// ignored. A nil manager or nil resp returns resp unchanged with a continue
// decision.
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || resp == nil {
		return resp, proceed
	}

	working := resp.Clone()
	for _, entry := range hm.snapshotHooks() {
		hook, isLLM := entry.Hook.(LLMInterceptor)
		if !isLLM {
			continue
		}
		updated, verdict, succeeded := hm.callAfterLLM(ctx, entry.Name, hook, working.Clone())
		if !succeeded {
			continue
		}

		action := verdict.normalizedAction()
		if action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, verdict
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(entry.Name, "after_llm", verdict.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, proceed
}
// BeforeTool runs every mounted ToolInterceptor's BeforeTool in priority
// order. Each hook gets a clone of the current call; on continue/modify a
// non-nil returned call becomes the current one. deny_tool, abort_turn and
// hard_abort stop the chain and are returned with the call as it stood.
// Hooks that fail or time out are skipped, and other actions are logged and
// ignored. A nil manager or nil call returns call unchanged with a continue
// decision.
func (hm *HookManager) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || call == nil {
		return call, proceed
	}

	working := call.Clone()
	for _, entry := range hm.snapshotHooks() {
		hook, isTool := entry.Hook.(ToolInterceptor)
		if !isTool {
			continue
		}
		updated, verdict, succeeded := hm.callBeforeTool(ctx, entry.Name, hook, working.Clone())
		if !succeeded {
			continue
		}

		action := verdict.normalizedAction()
		if action == HookActionDenyTool || action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, verdict
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(entry.Name, "before_tool", verdict.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, proceed
}
// AfterTool runs every mounted ToolInterceptor's AfterTool in priority order.
// Each hook gets a clone of the current result; on continue/modify a non-nil
// returned result becomes the current one. abort_turn and hard_abort stop the
// chain and are returned with the result as it stood. Hooks that fail or time
// out are skipped, and other actions (including deny_tool) are logged and
// ignored. A nil manager or nil result returns result unchanged with a
// continue decision.
func (hm *HookManager) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || result == nil {
		return result, proceed
	}

	working := result.Clone()
	for _, entry := range hm.snapshotHooks() {
		hook, isTool := entry.Hook.(ToolInterceptor)
		if !isTool {
			continue
		}
		updated, verdict, succeeded := hm.callAfterTool(ctx, entry.Name, hook, working.Clone())
		if !succeeded {
			continue
		}

		action := verdict.normalizedAction()
		if action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, verdict
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(entry.Name, "after_tool", verdict.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, proceed
}
// ApproveTool asks every mounted ToolApprover, in priority order, whether the
// tool call may run. The first denial is returned as-is, and an approver
// whose call fails produces a denial naming that hook. When no approver
// objects (or the manager or req is nil) the call is approved.
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
	approved := ApprovalDecision{Approved: true}
	if hm == nil || req == nil {
		return approved
	}

	for _, entry := range hm.snapshotHooks() {
		approver, isApprover := entry.Hook.(ToolApprover)
		if !isApprover {
			continue
		}
		verdict, succeeded := hm.callApproveTool(ctx, entry.Name, approver, req.Clone())
		switch {
		case !succeeded:
			return ApprovalDecision{
				Approved: false,
				Reason:   fmt.Sprintf("tool approval hook %q failed", entry.Name),
			}
		case !verdict.Approved:
			return verdict
		}
	}
	return approved
}
// rebuildOrdered recomputes hm.ordered from hm.hooks, sorted by ascending
// Priority and then by Name. Mount and Unmount call it while holding hm.mu
// for writing.
//
// A fresh slice is built on every call rather than truncating the old one:
// reusing the backing array would leave registrations of unmounted hooks in
// the slots past the new length, keeping those hooks reachable.
func (hm *HookManager) rebuildOrdered() {
	ordered := make([]HookRegistration, 0, len(hm.hooks))
	for _, reg := range hm.hooks {
		ordered = append(ordered, reg)
	}
	sort.SliceStable(ordered, func(i, j int) bool {
		if ordered[i].Priority != ordered[j].Priority {
			return ordered[i].Priority < ordered[j].Priority
		}
		return ordered[i].Name < ordered[j].Name
	})
	hm.ordered = ordered
}
// snapshotHooks returns a copy of the ordered hook list taken under the read
// lock, so callers can iterate it without holding hm.mu.
func (hm *HookManager) snapshotHooks() []HookRegistration {
	hm.mu.RLock()
	defer hm.mu.RUnlock()

	current := hm.ordered
	out := make([]HookRegistration, len(current))
	copy(out, current)
	return out
}
// runObserver delivers evt to observer in a separate goroutine and waits for
// it to return or for the observer timeout to expire, whichever comes first.
// Errors and timeouts are logged as warnings and never propagated.
//
// A panic inside OnEvent is recovered and reported as an observer failure;
// left unrecovered in the spawned goroutine it would terminate the process.
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
	ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
	defer cancel()

	// Buffered so the goroutine can finish even after a timeout was reported.
	done := make(chan error, 1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				done <- fmt.Errorf("event observer panicked: %v", r)
			}
		}()
		done <- observer.OnEvent(ctx, evt)
	}()

	select {
	case err := <-done:
		if err != nil {
			logger.WarnCF("hooks", "Event observer failed", map[string]any{
				"hook":  name,
				"event": evt.Kind.String(),
				"error": err.Error(),
			})
		}
	case <-ctx.Done():
		logger.WarnCF("hooks", "Event observer timed out", map[string]any{
			"hook":       name,
			"event":      evt.Kind.String(),
			"timeout_ms": hm.observerTimeout.Milliseconds(),
		})
	}
}
// callBeforeLLM invokes interceptor.BeforeLLM through runInterceptorHook with
// the manager's interceptor timeout. The bool result is false when the hook
// call did not complete successfully.
func (hm *HookManager) callBeforeLLM(
	parent context.Context,
	name string,
	interceptor LLMInterceptor,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, bool) {
	invoke := func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
		return interceptor.BeforeLLM(ctx, req)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "before_llm", invoke)
}
// callAfterLLM invokes interceptor.AfterLLM through runInterceptorHook with
// the manager's interceptor timeout. The bool result is false when the hook
// call did not complete successfully.
func (hm *HookManager) callAfterLLM(
	parent context.Context,
	name string,
	interceptor LLMInterceptor,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, bool) {
	invoke := func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
		return interceptor.AfterLLM(ctx, resp)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "after_llm", invoke)
}
// callBeforeTool invokes interceptor.BeforeTool through runInterceptorHook
// with the manager's interceptor timeout. The bool result is false when the
// hook call did not complete successfully.
func (hm *HookManager) callBeforeTool(
	parent context.Context,
	name string,
	interceptor ToolInterceptor,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, bool) {
	invoke := func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
		return interceptor.BeforeTool(ctx, call)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "before_tool", invoke)
}
// callAfterTool invokes interceptor.AfterTool through runInterceptorHook with
// the manager's interceptor timeout. The bool result is false when the hook
// call did not complete successfully.
func (hm *HookManager) callAfterTool(
	parent context.Context,
	name string,
	interceptor ToolInterceptor,
	resultView *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, bool) {
	invoke := func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
		return interceptor.AfterTool(ctx, resultView)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "after_tool", invoke)
}
// callApproveTool invokes approver.ApproveTool through runApprovalHook with
// the manager's approval timeout and returns that helper's result unchanged.
func (hm *HookManager) callApproveTool(
	parent context.Context,
	name string,
	approver ToolApprover,
	req *ToolApprovalRequest,
) (ApprovalDecision, bool) {
	invoke := func(ctx context.Context) (ApprovalDecision, error) {
		return approver.ApproveTool(ctx, req)
	}
	return runApprovalHook(parent, hm.approvalTimeout, name, "approve_tool", invoke)
}
// runInterceptorHook runs fn in its own goroutine under a context derived
// from parent with the given timeout, and waits for whichever finishes first.
//
// It returns fn's value and decision with ok=true when fn returns a nil
// error. When fn returns an error, panics, or the context is done first, it
// logs a warning tagged with name and stage and returns the zero value, an
// empty decision and ok=false.
//
// The panic is recovered inside the spawned goroutine because no caller can
// recover it from outside; unrecovered, a misbehaving hook would terminate
// the whole process.
func runInterceptorHook[T any](
	parent context.Context,
	timeout time.Duration,
	name string,
	stage string,
	fn func(ctx context.Context) (T, HookDecision, error),
) (T, HookDecision, bool) {
	var zero T
	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	type result struct {
		value    T
		decision HookDecision
		err      error
	}
	// Buffered so the goroutine can finish even after a timeout was reported.
	done := make(chan result, 1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				done <- result{err: fmt.Errorf("hook panicked: %v", r)}
			}
		}()
		value, decision, err := fn(ctx)
		done <- result{value: value, decision: decision, err: err}
	}()

	select {
	case res := <-done:
		if res.err != nil {
			logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
				"hook":  name,
				"stage": stage,
				"error": res.err.Error(),
			})
			return zero, HookDecision{}, false
		}
		return res.value, res.decision, true
	case <-ctx.Done():
		logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
			"hook":       name,
			"stage":      stage,
			"timeout_ms": timeout.Milliseconds(),
		})
		return zero, HookDecision{}, false
	}
}
// runApprovalHook runs fn in its own goroutine under a context derived from
// parent with the given timeout, and waits for whichever finishes first.
//
// It returns fn's decision with ok=true when fn returns a nil error. When fn
// returns an error or panics, it logs a warning and returns a zero decision
// with ok=false. When the context is done first, it logs a warning and
// returns an explicit denial with ok=true.
//
// The panic is recovered inside the spawned goroutine because no caller can
// recover it from outside; unrecovered, a misbehaving approver would
// terminate the whole process.
func runApprovalHook(
	parent context.Context,
	timeout time.Duration,
	name string,
	stage string,
	fn func(ctx context.Context) (ApprovalDecision, error),
) (ApprovalDecision, bool) {
	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	type result struct {
		decision ApprovalDecision
		err      error
	}
	// Buffered so the goroutine can finish even after a timeout was reported.
	done := make(chan result, 1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				done <- result{err: fmt.Errorf("hook panicked: %v", r)}
			}
		}()
		decision, err := fn(ctx)
		done <- result{decision: decision, err: err}
	}()

	select {
	case res := <-done:
		if res.err != nil {
			logger.WarnCF("hooks", "Approval hook failed", map[string]any{
				"hook":  name,
				"stage": stage,
				"error": res.err.Error(),
			})
			return ApprovalDecision{}, false
		}
		return res.decision, true
	case <-ctx.Done():
		logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
			"hook":       name,
			"stage":      stage,
			"timeout_ms": timeout.Milliseconds(),
		})
		return ApprovalDecision{
			Approved: false,
			Reason:   fmt.Sprintf("tool approval hook %q timed out", name),
		}, true
	}
}
// logUnsupportedAction writes a warning recording that hook name returned an
// action that the given stage does not handle.
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
	fields := map[string]any{
		"hook":   name,
		"stage":  stage,
		"action": action,
	}
	logger.WarnCF("hooks", "Hook returned unsupported action for stage", fields)
}
// cloneProviderMessages copies messages into a new slice, giving each copy
// its own Media, SystemParts and ToolCalls slices when the source has any.
// An empty or nil input yields nil.
func cloneProviderMessages(messages []providers.Message) []providers.Message {
	if len(messages) == 0 {
		return nil
	}

	out := make([]providers.Message, len(messages))
	copy(out, messages)
	for i := range messages {
		src := &messages[i]
		if len(src.Media) > 0 {
			out[i].Media = append([]string(nil), src.Media...)
		}
		if len(src.SystemParts) > 0 {
			out[i].SystemParts = append([]providers.ContentBlock(nil), src.SystemParts...)
		}
		if len(src.ToolCalls) > 0 {
			out[i].ToolCalls = cloneProviderToolCalls(src.ToolCalls)
		}
	}
	return out
}
// cloneProviderToolCalls copies calls into a new slice, duplicating each
// call's Function, Arguments and ExtraContent (including ExtraContent.Google)
// when they are set. An empty or nil input yields nil.
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
	if len(calls) == 0 {
		return nil
	}

	out := make([]providers.ToolCall, len(calls))
	copy(out, calls)
	for i := range calls {
		src := &calls[i]
		if src.Function != nil {
			fnCopy := *src.Function
			out[i].Function = &fnCopy
		}
		if src.Arguments != nil {
			out[i].Arguments = cloneStringAnyMap(src.Arguments)
		}
		if src.ExtraContent == nil {
			continue
		}
		extraCopy := *src.ExtraContent
		if src.ExtraContent.Google != nil {
			googleCopy := *src.ExtraContent.Google
			extraCopy.Google = &googleCopy
		}
		out[i].ExtraContent = &extraCopy
	}
	return out
}
// cloneToolDefinitions copies defs into a new slice and gives each copy its
// own top-level Function.Parameters map. An empty or nil input yields nil.
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
	if len(defs) == 0 {
		return nil
	}

	out := make([]providers.ToolDefinition, len(defs))
	copy(out, defs)
	for i := range defs {
		out[i].Function.Parameters = cloneStringAnyMap(defs[i].Function.Parameters)
	}
	return out
}
// cloneLLMResponse returns a copy of resp with its own ToolCalls, its own
// ReasoningDetails slice (when non-empty) and its own Usage value (when set).
// A nil input yields nil.
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
	if resp == nil {
		return nil
	}

	out := new(providers.LLMResponse)
	*out = *resp
	out.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
	if len(resp.ReasoningDetails) > 0 {
		// The zero-capacity prefix forces append to allocate a new backing array.
		out.ReasoningDetails = append(resp.ReasoningDetails[:0:0], resp.ReasoningDetails...)
	}
	if resp.Usage != nil {
		usageCopy := *resp.Usage
		out.Usage = &usageCopy
	}
	return out
}
// cloneStringAnyMap returns a shallow copy of src: the returned map is a new
// map, but values that are themselves maps, slices or pointers are shared
// with src.
//
// Only a nil src yields nil. An empty non-nil map is cloned to an empty
// non-nil map, so the clone stays writable (assigning into a nil map panics)
// and still encodes as an empty object rather than null.
func cloneStringAnyMap(src map[string]any) map[string]any {
	if src == nil {
		return nil
	}
	cloned := make(map[string]any, len(src))
	for k, v := range src {
		cloned[k] = v
	}
	return cloned
}
// cloneToolResult returns a copy of result with its own Media slice when the
// source has any media; all other fields are copied by plain assignment.
// A nil input yields nil.
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
	if result == nil {
		return nil
	}

	out := new(tools.ToolResult)
	*out = *result
	if len(result.Media) > 0 {
		out.Media = append([]string(nil), result.Media...)
	}
	return out
}
+312
View File
@@ -0,0 +1,312 @@
package agent
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// newHookTestLoop builds an AgentLoop backed by provider with a temporary
// workspace, and returns the loop, its default agent and a cleanup function
// that closes the loop and removes the workspace.
//
// If the loop has no default agent, the cleanup runs before the test is
// failed, so the fatal path does not leak the loop or the temp directory.
func newHookTestLoop(
	t *testing.T,
	provider providers.LLMProvider,
) (*AgentLoop, *AgentInstance, func()) {
	t.Helper()

	tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}

	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         tmpDir,
				Model:             "test-model",
				MaxTokens:         4096,
				MaxToolIterations: 10,
			},
		},
	}

	al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
	cleanup := func() {
		al.Close()
		_ = os.RemoveAll(tmpDir) // best effort: a leftover temp dir must not fail the test
	}

	agent := al.registry.GetDefaultAgent()
	if agent == nil {
		cleanup()
		t.Fatal("expected default agent")
	}

	return al, agent, cleanup
}
// llmHookTestProvider is a stub LLM provider that records the model of the
// most recent Chat call and always answers with fixed content.
type llmHookTestProvider struct {
	mu        sync.Mutex
	lastModel string
}

// Chat stores model under the mutex and returns the canned response.
func (p *llmHookTestProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.lastModel = model

	reply := &providers.LLMResponse{Content: "provider content"}
	return reply, nil
}

// GetDefaultModel returns the stub's fixed model name.
func (p *llmHookTestProvider) GetDefaultModel() string {
	return "llm-hook-provider"
}
// llmObserverHook is a test hook that both observes events and intercepts
// LLM calls.
type llmObserverHook struct {
	eventCh chan Event
}

// OnEvent forwards turn-end events to eventCh without blocking; an event is
// dropped when the channel cannot accept it. It always returns nil.
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
	if evt.Kind != EventKindTurnEnd {
		return nil
	}
	select {
	case h.eventCh <- evt:
	default:
	}
	return nil
}

// BeforeLLM returns a clone of req with the model replaced by "hook-model".
func (h *llmObserverHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	rewritten := req.Clone()
	rewritten.Model = "hook-model"
	decision := HookDecision{Action: HookActionModify}
	return rewritten, decision, nil
}

// AfterLLM returns a clone of resp with the response content replaced by
// "hooked content".
func (h *llmObserverHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	rewritten := resp.Clone()
	rewritten.Response.Content = "hooked content"
	decision := HookDecision{Action: HookActionModify}
	return rewritten, decision, nil
}
// TestAgentLoop_Hooks_ObserverAndLLMInterceptor checks that a mounted hook
// can rewrite the model before the LLM call, rewrite the content after it,
// and receive the turn-end event as an observer.
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
	llm := &llmHookTestProvider{}
	al, agent, cleanup := newHookTestLoop(t, llm)
	defer cleanup()

	observer := &llmObserverHook{eventCh: make(chan Event, 1)}
	if mountErr := al.MountHook(NamedHook("llm-observer", observer)); mountErr != nil {
		t.Fatalf("MountHook failed: %v", mountErr)
	}

	opts := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "hello",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	got, err := al.runAgentLoop(context.Background(), agent, opts)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if got != "hooked content" {
		t.Fatalf("expected hooked content, got %q", got)
	}

	llm.mu.Lock()
	usedModel := llm.lastModel
	llm.mu.Unlock()
	if usedModel != "hook-model" {
		t.Fatalf("expected model hook-model, got %q", usedModel)
	}

	select {
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for hook observer event")
	case seen := <-observer.eventCh:
		if seen.Kind != EventKindTurnEnd {
			t.Fatalf("expected turn end event, got %v", seen.Kind)
		}
	}
}
// toolHookProvider is a stub LLM provider that requests the echo_text tool on
// its first call and, on later calls, echoes the content of the last message.
type toolHookProvider struct {
	mu    sync.Mutex
	calls int
}

// Chat counts calls under the mutex and answers as described on the type.
func (p *toolHookProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.calls++
	if p.calls != 1 {
		tail := messages[len(messages)-1]
		return &providers.LLMResponse{Content: tail.Content}, nil
	}

	request := providers.ToolCall{
		ID:        "call-1",
		Name:      "echo_text",
		Arguments: map[string]any{"text": "original"},
	}
	return &providers.LLMResponse{ToolCalls: []providers.ToolCall{request}}, nil
}

// GetDefaultModel returns the stub's fixed model name.
func (p *toolHookProvider) GetDefaultModel() string {
	return "tool-hook-provider"
}
// echoTextTool is a test tool that returns its "text" argument.
type echoTextTool struct{}

// Name returns the tool identifier.
func (t *echoTextTool) Name() string { return "echo_text" }

// Description returns the tool's human-readable description.
func (t *echoTextTool) Description() string { return "echo a text argument" }

// Parameters returns the schema map describing a single string property "text".
func (t *echoTextTool) Parameters() map[string]any {
	textProp := map[string]any{"type": "string"}
	props := map[string]any{"text": textProp}
	return map[string]any{
		"type":       "object",
		"properties": props,
	}
}

// Execute returns a silent result carrying the "text" argument; a missing or
// non-string argument yields the empty string.
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	var text string
	if s, isString := args["text"].(string); isString {
		text = s
	}
	return tools.SilentResult(text)
}
// toolRewriteHook is a test tool interceptor that rewrites both the tool
// arguments and the tool result.
type toolRewriteHook struct{}

// BeforeTool returns a clone of call with the "text" argument set to "modified".
func (h *toolRewriteHook) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
	rewritten := call.Clone()
	rewritten.Arguments["text"] = "modified"
	decision := HookDecision{Action: HookActionModify}
	return rewritten, decision, nil
}

// AfterTool returns a clone of result with "after:" prepended to the content
// meant for the LLM.
func (h *toolRewriteHook) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
	rewritten := result.Clone()
	rewritten.Result.ForLLM = "after:" + rewritten.Result.ForLLM
	decision := HookDecision{Action: HookActionModify}
	return rewritten, decision, nil
}
// TestAgentLoop_Hooks_ToolInterceptorCanRewrite checks that a tool
// interceptor's argument rewrite reaches the tool and that its result rewrite
// reaches the final response.
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
	llm := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, llm)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})
	registration := NamedHook("tool-rewrite", &toolRewriteHook{})
	if mountErr := al.MountHook(registration); mountErr != nil {
		t.Fatalf("MountHook failed: %v", mountErr)
	}

	opts := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	got, err := al.runAgentLoop(context.Background(), agent, opts)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if got != "after:modified" {
		t.Fatalf("expected rewritten tool result, got %q", got)
	}
}
// denyApprovalHook is a test approver that rejects every tool call.
type denyApprovalHook struct{}

// ApproveTool always returns a denial with the reason "blocked".
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
	var denial ApprovalDecision
	denial.Approved = false
	denial.Reason = "blocked"
	return denial, nil
}
// TestAgentLoop_Hooks_ToolApproverCanDeny checks that a denying approver
// blocks the tool, that the denial text becomes the response, and that a
// tool-skipped event carrying the same text is emitted.
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
	llm := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, llm)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})
	registration := NamedHook("deny-approval", &denyApprovalHook{})
	if mountErr := al.MountHook(registration); mountErr != nil {
		t.Fatalf("MountHook failed: %v", mountErr)
	}

	sub := al.SubscribeEvents(16)
	defer al.UnsubscribeEvents(sub.ID)

	opts := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	got, err := al.runAgentLoop(context.Background(), agent, opts)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}

	expected := "Tool execution denied by approval hook: blocked"
	if got != expected {
		t.Fatalf("expected %q, got %q", expected, got)
	}

	skippedEvt, found := findEvent(collectEventStream(sub.C), EventKindToolExecSkipped)
	if !found {
		t.Fatal("expected tool skipped event")
	}
	payload, isSkipped := skippedEvt.Payload.(ToolExecSkippedPayload)
	if !isSkipped {
		t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
	}
	if payload.Reason != expected {
		t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
	}
}
+262 -47
View File
@@ -40,6 +40,7 @@ type AgentLoop struct {
registry *AgentRegistry
state *state.Manager
eventBus *EventBus
hooks *HookManager
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
@@ -108,17 +109,19 @@ func NewAgentLoop(
stateManager = state.NewManager(defaultAgent.Workspace)
}
eventBus := NewEventBus()
al := &AgentLoop{
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
eventBus: NewEventBus(),
eventBus: eventBus,
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.hooks = NewHookManager(eventBus)
return al
}
@@ -460,11 +463,30 @@ func (al *AgentLoop) Close() {
}
al.GetRegistry().Close()
if al.hooks != nil {
al.hooks.Close()
}
if al.eventBus != nil {
al.eventBus.Close()
}
}
// MountHook mounts reg on the agent loop's hook manager and returns the
// manager's result. It fails when the loop or its hook manager is nil.
func (al *AgentLoop) MountHook(reg HookRegistration) error {
	if al != nil && al.hooks != nil {
		return al.hooks.Mount(reg)
	}
	return fmt.Errorf("hook manager is not initialized")
}
// UnmountHook removes the hook registered under name from the agent loop's
// hook manager. It does nothing when the loop or its hook manager is nil.
func (al *AgentLoop) UnmountHook(name string) {
	if al != nil && al.hooks != nil {
		al.hooks.Unmount(name)
	}
}
// SubscribeEvents registers a subscriber for agent-loop events.
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
if al == nil || al.eventBus == nil {
@@ -544,6 +566,31 @@ func cloneEventArguments(args map[string]any) map[string]any {
return cloned
}
// hookAbortError builds the error returned when a hook aborts the turn during
// stage, emits an error event with stage "hook.<stage>" carrying the same
// message, and returns the error. An empty decision reason is replaced by a
// generic one.
func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error {
	why := decision.Reason
	if why == "" {
		why = "hook requested turn abort"
	}
	abortErr := fmt.Errorf("hook aborted turn during %s: %s", stage, why)

	meta := ts.eventMeta("hooks", "turn.error")
	payload := ErrorPayload{
		Stage:   "hook." + stage,
		Message: abortErr.Error(),
	}
	al.emitEvent(EventKindError, meta, payload)
	return abortErr
}
// hookDeniedToolContent formats the message used when a hook denies a tool:
// "<prefix>: <reason>", or just prefix when reason is empty.
func hookDeniedToolContent(prefix, reason string) string {
	if reason != "" {
		return prefix + ": " + reason
	}
	return prefix
}
func (al *AgentLoop) logEvent(evt Event) {
fields := map[string]any{
"event_kind": evt.Kind.String(),
@@ -1418,36 +1465,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.markGracefulTerminalUsed()
}
al.emitEvent(
EventKindLLMRequest,
ts.eventMeta("runTurn", "turn.llm.request"),
LLMRequestPayload{
Model: activeModel,
MessagesCount: len(callMessages),
ToolsCount: len(providerToolDefs),
MaxTokens: ts.agent.MaxTokens,
Temperature: ts.agent.Temperature,
},
)
logger.DebugCF("agent", "LLM request",
map[string]any{
"agent_id": ts.agent.ID,
"iteration": iteration,
"model": activeModel,
"messages_count": len(callMessages),
"tools_count": len(providerToolDefs),
"max_tokens": ts.agent.MaxTokens,
"temperature": ts.agent.Temperature,
"system_prompt_len": len(callMessages[0].Content),
})
logger.DebugCF("agent", "Full LLM request",
map[string]any{
"iteration": iteration,
"messages_json": formatMessagesForLog(callMessages),
"tools_json": formatToolsForLog(providerToolDefs),
})
llmOpts := map[string]any{
"max_tokens": ts.agent.MaxTokens,
"temperature": ts.agent.Temperature,
@@ -1462,6 +1479,66 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
}
}
llmModel := activeModel
if al.hooks != nil {
llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{
Meta: ts.eventMeta("runTurn", "turn.llm.request"),
Model: llmModel,
Messages: callMessages,
Tools: providerToolDefs,
Options: llmOpts,
Channel: ts.channel,
ChatID: ts.chatID,
GracefulTerminal: gracefulTerminal,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if llmReq != nil {
llmModel = llmReq.Model
callMessages = llmReq.Messages
providerToolDefs = llmReq.Tools
llmOpts = llmReq.Options
}
case HookActionAbortTurn:
turnStatus = TurnEndStatusError
return turnResult{}, al.hookAbortError(ts, "before_llm", decision)
case HookActionHardAbort:
_ = ts.requestHardAbort()
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
}
al.emitEvent(
EventKindLLMRequest,
ts.eventMeta("runTurn", "turn.llm.request"),
LLMRequestPayload{
Model: llmModel,
MessagesCount: len(callMessages),
ToolsCount: len(providerToolDefs),
MaxTokens: ts.agent.MaxTokens,
Temperature: ts.agent.Temperature,
},
)
logger.DebugCF("agent", "LLM request",
map[string]any{
"agent_id": ts.agent.ID,
"iteration": iteration,
"model": llmModel,
"messages_count": len(callMessages),
"tools_count": len(providerToolDefs),
"max_tokens": ts.agent.MaxTokens,
"temperature": ts.agent.Temperature,
"system_prompt_len": len(callMessages[0].Content),
})
logger.DebugCF("agent", "Full LLM request",
map[string]any{
"iteration": iteration,
"messages_json": formatMessagesForLog(callMessages),
"tools_json": formatToolsForLog(providerToolDefs),
})
callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
providerCtx, providerCancel := context.WithCancel(turnCtx)
ts.setProviderCancel(providerCancel)
@@ -1494,7 +1571,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
}
return fbResult.Response, nil
}
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts)
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
}
var response *providers.LLMResponse
@@ -1626,12 +1703,35 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
map[string]any{
"agent_id": ts.agent.ID,
"iteration": iteration,
"model": activeModel,
"model": llmModel,
"error": err.Error(),
})
return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err)
}
if al.hooks != nil {
llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{
Meta: ts.eventMeta("runTurn", "turn.llm.response"),
Model: llmModel,
Response: response,
Channel: ts.channel,
ChatID: ts.chatID,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if llmResp != nil && llmResp.Response != nil {
response = llmResp.Response
}
case HookActionAbortTurn:
turnStatus = TurnEndStatusError
return turnResult{}, al.hookAbortError(ts, "after_llm", decision)
case HookActionHardAbort:
_ = ts.requestHardAbort()
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
}
go al.handleReasoning(
turnCtx,
response.Reasoning,
@@ -1728,25 +1828,106 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
return al.abortTurn(ts)
}
argsJSON, _ := json.Marshal(tc.Arguments)
toolName := tc.Name
toolArgs := cloneStringAnyMap(tc.Arguments)
if al.hooks != nil {
toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{
Meta: ts.eventMeta("runTurn", "turn.tool.before"),
Tool: toolName,
Arguments: toolArgs,
Channel: ts.channel,
ChatID: ts.chatID,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if toolReq != nil {
toolName = toolReq.Tool
toolArgs = toolReq.Arguments
}
case HookActionDenyTool:
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
al.emitEvent(
EventKindToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: toolName,
Reason: denyContent,
},
)
deniedMsg := providers.Message{
Role: "tool",
Content: denyContent,
ToolCallID: tc.ID,
}
messages = append(messages, deniedMsg)
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
ts.recordPersistedMessage(deniedMsg)
}
continue
case HookActionAbortTurn:
turnStatus = TurnEndStatusError
return turnResult{}, al.hookAbortError(ts, "before_tool", decision)
case HookActionHardAbort:
_ = ts.requestHardAbort()
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
}
if al.hooks != nil {
approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{
Meta: ts.eventMeta("runTurn", "turn.tool.approve"),
Tool: toolName,
Arguments: toolArgs,
Channel: ts.channel,
ChatID: ts.chatID,
})
if !approval.Approved {
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
al.emitEvent(
EventKindToolExecSkipped,
ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: toolName,
Reason: denyContent,
},
)
deniedMsg := providers.Message{
Role: "tool",
Content: denyContent,
ToolCallID: tc.ID,
}
messages = append(messages, deniedMsg)
if !ts.opts.NoHistory {
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
ts.recordPersistedMessage(deniedMsg)
}
continue
}
}
argsJSON, _ := json.Marshal(toolArgs)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview),
map[string]any{
"agent_id": ts.agent.ID,
"tool": tc.Name,
"tool": toolName,
"iteration": iteration,
})
al.emitEvent(
EventKindToolExecStart,
ts.eventMeta("runTurn", "turn.tool.start"),
ToolExecStartPayload{
Tool: tc.Name,
Arguments: cloneEventArguments(tc.Arguments),
Tool: toolName,
Arguments: cloneEventArguments(toolArgs),
},
)
toolCall := tc
toolCallID := tc.ID
toolIteration := iteration
asyncToolName := toolName
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -1768,7 +1949,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
"tool": toolCall.Name,
"tool": asyncToolName,
"content_len": len(content),
"channel": ts.channel,
})
@@ -1776,7 +1957,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
EventKindFollowUpQueued,
ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
FollowUpQueuedPayload{
SourceTool: toolCall.Name,
SourceTool: asyncToolName,
Channel: ts.channel,
ChatID: ts.chatID,
ContentLen: len(content),
@@ -1787,7 +1968,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
SenderID: fmt.Sprintf("async:%s", toolCall.Name),
SenderID: fmt.Sprintf("async:%s", asyncToolName),
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
Content: content,
})
@@ -1796,8 +1977,8 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
toolStart := time.Now()
toolResult := ts.agent.Tools.ExecuteWithContext(
turnCtx,
toolCall.Name,
toolCall.Arguments,
toolName,
toolArgs,
ts.channel,
ts.chatID,
asyncCallback,
@@ -1809,6 +1990,40 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
return al.abortTurn(ts)
}
if al.hooks != nil {
toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{
Meta: ts.eventMeta("runTurn", "turn.tool.after"),
Tool: toolName,
Arguments: toolArgs,
Result: toolResult,
Duration: toolDuration,
Channel: ts.channel,
ChatID: ts.chatID,
})
switch decision.normalizedAction() {
case HookActionContinue, HookActionModify:
if toolResp != nil {
if toolResp.Tool != "" {
toolName = toolResp.Tool
}
if toolResp.Result != nil {
toolResult = toolResp.Result
}
}
case HookActionAbortTurn:
turnStatus = TurnEndStatusError
return turnResult{}, al.hookAbortError(ts, "after_tool", decision)
case HookActionHardAbort:
_ = ts.requestHardAbort()
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
}
}
if toolResult == nil {
toolResult = tools.ErrorResult("hook returned nil tool result")
}
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: ts.channel,
@@ -1817,7 +2032,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
"tool": toolCall.Name,
"tool": toolName,
"content_len": len(toolResult.ForUser),
})
}
@@ -1850,13 +2065,13 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: toolCall.ID,
ToolCallID: toolCallID,
}
al.emitEvent(
EventKindToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
ToolExecEndPayload{
Tool: toolCall.Name,
Tool: toolName,
Duration: toolDuration,
ForLLMLen: len(contentForLLM),
ForUserLen: len(toolResult.ForUser),