From cf68c91ecaa15c3518686e7cfa9c637fabfcbead Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sat, 21 Mar 2026 19:15:10 +0800 Subject: [PATCH] feat(agent): add hook manager foundation --- docs/design/hook-system-design.zh.md | 476 +++++++++++++++++ pkg/agent/hooks.go | 751 +++++++++++++++++++++++++++ pkg/agent/hooks_test.go | 312 +++++++++++ pkg/agent/loop.go | 309 +++++++++-- 4 files changed, 1801 insertions(+), 47 deletions(-) create mode 100644 docs/design/hook-system-design.zh.md create mode 100644 pkg/agent/hooks.go create mode 100644 pkg/agent/hooks_test.go diff --git a/docs/design/hook-system-design.zh.md b/docs/design/hook-system-design.zh.md new file mode 100644 index 000000000..ab5566bec --- /dev/null +++ b/docs/design/hook-system-design.zh.md @@ -0,0 +1,476 @@ +# PicoClaw Hook 系统设计(基于 `refactor/agent`) + +## 背景 + +本设计围绕两个议题展开: + +- `#1316`:把 agent loop 重构为事件驱动、可中断、可追加、可观测 +- `#1796`:在 EventBus 稳定后,把 hooks 设计为 EventBus 的 consumer,而不是重新发明一套事件模型 + +当前分支已经完成了第一步里的“事件系统基础”,但还没有真正的 hook 挂载层。因此这里的目标不是重新设计 event,而是在已有实现上补出一层可扩展、可拦截、可外挂的 HookManager。 + +## 外部项目对比 + +### OpenClaw + +OpenClaw 的扩展能力分成三层: + +- Internal hooks:目录发现,运行在 Gateway 进程内 +- Plugin hooks:插件在运行时注册 hook,也在进程内 +- Webhooks:外部系统通过 HTTP 触发 Gateway 动作,属于进程外 + +值得借鉴的点: + +- 有“项目内挂载”和“项目外挂载”两种路径 +- hook 是配置驱动,可启停 +- 外部入口有明确的安全边界和映射层 + +不建议直接照搬的点: + +- OpenClaw 的 hooks / plugin hooks / webhooks 是三套路由,PicoClaw 当前体量下会偏重 +- HTTP webhook 更适合“事件进入系统”,不适合作为“可同步拦截 agent loop”的基础机制 + +### pi-mono + +pi-mono 的核心思路更接近当前分支: + +- 扩展统一为 extension API +- 事件分为观察型和可变更型 +- 某些阶段允许 `transform` / `block` / `replace` +- 扩展代码主要是进程内执行 +- RPC mode 把 UI 交互桥接到进程外客户端 + +值得借鉴的点: + +- 不把“观察”和“拦截”混成一个接口 +- 允许返回结构化动作,而不是只有回调 +- 进程外通信只暴露必要协议,不把整个内部对象图泄露出去 + +## 当前分支现状 + +### 已有能力 + +当前分支已经具备 hook 系统的地基: + +- `pkg/agent/events.go` 定义了稳定的 `EventKind`、`EventMeta` 和 payload +- `pkg/agent/eventbus.go` 提供了非阻塞 fan-out 的 `EventBus` +- `pkg/agent/loop.go` 中的 `runTurn()` 已在 turn、llm、tool、interrupt、follow-up、summary 等节点发射事件 +- `pkg/agent/steering.go` 已支持 steering、graceful interrupt、hard abort +- `pkg/agent/turn.go` 已维护 turn phase、恢复点、active turn、abort 状态 + +### 现有缺口 + +当前分支还缺四件事: + +- 没有 HookManager,只有 EventBus +- 没有 Before/After LLM、Before/After Tool 这种同步拦截点 +- 没有审批型 hook +- 子 agent 仍走 `pkg/tools/SubagentManager + RunToolLoop`,没有接入 `pkg/agent` 的 turn tree 和事件流 + +### 一个关键现实 + +`#1316` 文案里提到“只读并行、写入串行”的工具执行策略,但当前 `runTurn()` 实现已经先收敛成“顺序执行 + 每个工具后检查 steering / interrupt”。因此 hook 设计不应依赖未来的并行模型,而应该先兼容当前顺序执行,再为以后增加 `ReadOnlyIndicator` 留口子。 + +## 设计原则 + +- Hook 必须建立在 `pkg/agent` 的 EventBus 和 turn 上下文之上 +- EventBus 负责广播,HookManager 负责拦截,两者职责分离 +- 项目内挂载要简单,项目外挂载必须走 IPC +- 观察型 hook 不能阻塞 loop;拦截型 hook 必须有超时 +- 先覆盖主 turn,不把 sub-turn 一次做满 +- 不新增第二套用户事件命名系统,优先复用 `EventKind.String()` + +## 总体架构 + +分成三层: + +1. `EventBus` + 负责广播只读事件,现有实现直接复用 + +2. `HookManager` + 负责管理 hook、排序、超时、错误隔离,并在 `runTurn()` 的明确检查点执行同步拦截 + +3. `HookMount` + 负责两种挂载方式: + - 进程内 Go hook + - 进程外 IPC hook + +换句话说: + +- EventBus 是“发生了什么” +- HookManager 是“谁能介入” +- HookMount 是“这些 hook 从哪里来” + +## Hook 分类 + +不建议把所有 hook 都设计成 `OnEvent(evt)`。 + +建议拆成两类。 + +### 1. 观察型 + +只消费事件,不修改流程: + +```go +type EventObserver interface { + OnEvent(ctx context.Context, evt agent.Event) error +} +``` + +这类 hook 直接订阅 EventBus 即可。 + +适用场景: + +- 审计日志 +- 指标上报 +- 调试 trace +- 将事件转发给外部 UI / TUI / Web 面板 + +### 2. 拦截型 + +只在少数明确节点触发,允许返回动作: + +```go +type LLMInterceptor interface { + BeforeLLM(ctx context.Context, req *LLMRequest) HookDecision[*LLMRequest] + AfterLLM(ctx context.Context, resp *LLMResponse) HookDecision[*LLMResponse] +} + +type ToolInterceptor interface { + BeforeTool(ctx context.Context, call *ToolCall) HookDecision[*ToolCall] + AfterTool(ctx context.Context, result *ToolResultView) HookDecision[*ToolResultView] +} + +type ToolApprover interface { + ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision +} +``` + +这里的 `HookDecision` 统一支持: + +- `continue` +- `modify` +- `deny_tool` +- `abort_turn` +- `hard_abort` + +## 对外暴露的最小 hook 面 + +V1 不需要把所有 EventKind 都变成可拦截点。 + +建议只开放这些同步 hook: + +- `before_llm` +- `after_llm` +- `before_tool` +- `after_tool` +- `approve_tool` + +其余节点继续作为只读事件暴露: + +- `turn_start` +- `turn_end` +- `llm_request` +- `llm_response` +- `tool_exec_start` +- `tool_exec_end` +- `tool_exec_skipped` +- `steering_injected` +- `follow_up_queued` +- `interrupt_received` +- `context_compress` +- `session_summarize` +- `error` + +`subturn_*` 在 V1 中保留名字,但不承诺一定触发,直到子 turn 迁移完成。 + +## 项目内挂载 + +内部挂载必须尽量低摩擦。 + +建议提供两种等价方式,底层都走 HookManager。 + +### 方式 A:代码显式挂载 + +```go +al.MountHook(hooks.Named("audit", &AuditHook{})) +``` + +适用于: + +- 仓内内建 hook +- 单元测试 +- feature flag 控制 + +### 方式 B:内建 registry + +```go +func init() { + hooks.RegisterBuiltin("audit", func() hooks.Hook { + return &AuditHook{} + }) +} +``` + +启动时根据配置启用: + +```json +{ + "hooks": { + "builtins": { + "audit": { "enabled": true } + } + } +} +``` + +这比 OpenClaw 的目录扫描更轻,也更贴合 Go 项目。 + +## 项目外挂载 + +这是本设计的硬要求。 + +建议 V1 采用: + +- `JSON-RPC over stdio` + +原因: + +- 跨平台最简单 +- 不依赖额外端口 +- 非常适合“由 PicoClaw 启动一个外部 hook 进程” +- 比 HTTP webhook 更适合同步拦截 + +### 外部 hook 进程模型 + +PicoClaw 启动外部进程,并在其 stdin/stdout 上跑协议。 + +配置示例: + +```json +{ + "hooks": { + "processes": { + "review-gate": { + "enabled": true, + "transport": "stdio", + "command": ["uvx", "picoclaw-hook-reviewer"], + "observe": ["turn_start", "turn_end", "tool_exec_end"], + "intercept": ["before_tool", "approve_tool"], + "timeout_ms": 5000 + } + } + } +} +``` + +### 协议边界 + +不要把内部 Go 结构体直接暴露给 IPC。 + +建议定义稳定的协议对象: + +- `HookHandshake` +- `HookEventNotification` +- `BeforeLLMRequest` +- `AfterLLMRequest` +- `BeforeToolRequest` +- `AfterToolRequest` +- `ApproveToolRequest` +- `HookDecision` + +其中: + +- 观察型事件用 notification,fire-and-forget +- 拦截型事件用 request/response,同步等待 + +### 为什么是 stdio,而不是直接用 HTTP webhook + +因为两者用途不同: + +- HTTP webhook 更适合“外部系统向 PicoClaw 投递事件” +- stdio/RPC 更适合“PicoClaw 在 turn 内同步询问外部 hook 是否改写 / 放行 / 拒绝” + +如果未来需要 OpenClaw 式 webhook,可以作为独立入口层,再把外部事件转成 inbound message 或 steering,而不是直接替代 hook IPC。 + +## Hook 执行顺序 + +建议统一排序规则: + +- 先内建 in-process hook +- 再外部 IPC hook +- 同组内按 `priority` 从小到大执行 + +原因: + +- 内建 hook 延迟更低,适合做基础规范化 +- 外部 hook 更适合做审批、审计、组织级策略 + +## 超时与错误策略 + +### 观察型 + +- 默认超时:`500ms` +- 超时或报错:记录日志,继续主流程 + +### 拦截型 + +- `before_llm` / `after_llm` / `before_tool` / `after_tool`:默认 `5s` +- `approve_tool`:默认 `60s` + +超时行为: + +- 普通拦截:`continue` +- 审批:`deny` + +这点应直接沿用 `#1316` 的安全倾向。 + +## 与当前分支的对接点 + +### 直接复用 + +- 事件定义:`pkg/agent/events.go` +- 事件广播:`pkg/agent/eventbus.go` +- 活跃 turn / interrupt / rollback:`pkg/agent/turn.go` +- 事件发射点:`pkg/agent/loop.go` + +### 需要新增 + +- `pkg/agent/hooks.go` + - Hook 接口 + - HookDecision / ApprovalDecision + - HookManager + +- `pkg/agent/hook_mount.go` + - 内建 hook 注册 + - 外部进程 hook 注册 + +- `pkg/agent/hook_ipc.go` + - stdio JSON-RPC bridge + +- `pkg/agent/hook_types.go` + - IPC 稳定载荷 + +### 需要改造 + +- `pkg/agent/loop.go` + - 在 LLM 和 tool 关键路径前后插入 HookManager 调用 + +- `pkg/tools/base.go` + - 可选新增 `ReadOnlyIndicator` + +- `pkg/tools/spawn.go` +- `pkg/tools/subagent.go` + - 先保留现状 + - 等 sub-turn 迁移后再接入 `subturn_*` hook + +## 一个更贴合当前分支的数据流 + +### 观察链路 + +```text +runTurn() -> emitEvent() -> EventBus -> observers +``` + +### 拦截链路 + +```text +runTurn() + -> HookManager.BeforeLLM() + -> Provider.Chat() + -> HookManager.AfterLLM() + -> HookManager.BeforeTool() + -> HookManager.ApproveTool() + -> tool.Execute() + -> HookManager.AfterTool() +``` + +也就是说: + +- observer 不改变现有 `emitEvent()` +- interceptor 直接插在 `runTurn()` 热路径 + +## 用户可见配置 + +建议新增: + +```json +{ + "hooks": { + "enabled": true, + "builtins": {}, + "processes": {}, + "defaults": { + "observer_timeout_ms": 500, + "interceptor_timeout_ms": 5000, + "approval_timeout_ms": 60000 + } + } +} +``` + +V1 不做复杂自动发现。 + +原因: + +- 当前分支重点是把地基打稳 +- 目录扫描、安装器、脚手架可以后置 +- 先让仓内和仓外都能挂上去,比“管理体验完整”更重要 + +## 推荐的 V1 范围 + +### 必做 + +- HookManager +- in-process 挂载 +- stdio IPC 挂载 +- observer hooks +- `before_tool` / `after_tool` / `approve_tool` +- `before_llm` / `after_llm` + +### 可后置 + +- hook CLI 管理命令 +- hook 自动发现 +- Unix socket / named pipe transport +- sub-turn hook 生命周期 +- read-only 并行分组 +- webhook 到 inbound message 的映射入口 + +## 分阶段落地 + +### Phase 1 + +- 引入 HookManager +- 支持 in-process observer + interceptor +- 先只接主 turn + +### Phase 2 + +- 引入 `stdio` 外部 hook 进程桥 +- 支持组织级审批 / 审计 / 参数改写 + +### Phase 3 + +- 把 `SubagentManager` 迁移到 `runTurn/sub-turn` +- 接通 `subturn_spawn` / `subturn_end` / `subturn_result_delivered` + +### Phase 4 + +- 视需求补 `ReadOnlyIndicator` +- 在主 turn 和 sub-turn 上统一只读并行策略 + +## 最终结论 + +最适合 PicoClaw 当前分支的方案,不是直接复制 OpenClaw 的 hooks,也不是完整照搬 pi-mono 的 extension system,而是: + +- 以现有 `EventBus` 为只读观察面 +- 以新增 `HookManager` 为同步拦截面 +- 项目内通过 Go 对象直接挂载 +- 项目外通过 `stdio JSON-RPC` 进程通信挂载 + +这样做有三个好处: + +- 和 `#1796` 一致,hooks 只是 EventBus 之上的消费层 +- 和当前 `refactor/agent` 实现一致,不需要推翻已有事件系统 +- 同时满足“仓内简单挂载”和“仓外进程通信挂载”两个硬需求 diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go new file mode 100644 index 000000000..74af542fa --- /dev/null +++ b/pkg/agent/hooks.go @@ -0,0 +1,751 @@ +package agent + +import ( + "context" + "fmt" + "sort" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +const ( + defaultHookObserverTimeout = 500 * time.Millisecond + defaultHookInterceptorTimeout = 5 * time.Second + defaultHookApprovalTimeout = 60 * time.Second + hookObserverBufferSize = 64 +) + +type HookAction string + +const ( + HookActionContinue HookAction = "continue" + HookActionModify HookAction = "modify" + HookActionDenyTool HookAction = "deny_tool" + HookActionAbortTurn HookAction = "abort_turn" + HookActionHardAbort HookAction = "hard_abort" +) + +type HookDecision struct { + Action HookAction + Reason string +} + +func (d HookDecision) normalizedAction() HookAction { + if d.Action == "" { + return HookActionContinue + } + return d.Action +} + +type ApprovalDecision struct { + Approved bool + Reason string +} + +type HookRegistration struct { + Name string + Priority int + Hook any +} + +func NamedHook(name string, hook any) HookRegistration { + return HookRegistration{ + Name: name, + Hook: hook, + } +} + +type EventObserver interface { + OnEvent(ctx context.Context, evt Event) error +} + +type LLMInterceptor interface { + BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error) + AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error) +} + +type ToolInterceptor interface { + BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error) + AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error) +} + +type ToolApprover interface { + ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) +} + +type LLMHookRequest struct { + Meta EventMeta + Model string + Messages []providers.Message + Tools []providers.ToolDefinition + Options map[string]any + Channel string + ChatID string + GracefulTerminal bool +} + +func (r *LLMHookRequest) Clone() *LLMHookRequest { + if r == nil { + return nil + } + cloned := *r + cloned.Messages = cloneProviderMessages(r.Messages) + cloned.Tools = cloneToolDefinitions(r.Tools) + cloned.Options = cloneStringAnyMap(r.Options) + return &cloned +} + +type LLMHookResponse struct { + Meta EventMeta + Model string + Response *providers.LLMResponse + Channel string + ChatID string +} + +func (r *LLMHookResponse) Clone() *LLMHookResponse { + if r == nil { + return nil + } + cloned := *r + cloned.Response = cloneLLMResponse(r.Response) + return &cloned +} + +type ToolCallHookRequest struct { + Meta EventMeta + Tool string + Arguments map[string]any + Channel string + ChatID string +} + +func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest { + if r == nil { + return nil + } + cloned := *r + cloned.Arguments = cloneStringAnyMap(r.Arguments) + return &cloned +} + +type ToolApprovalRequest struct { + Meta EventMeta + Tool string + Arguments map[string]any + Channel string + ChatID string +} + +func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest { + if r == nil { + return nil + } + cloned := *r + cloned.Arguments = cloneStringAnyMap(r.Arguments) + return &cloned +} + +type ToolResultHookResponse struct { + Meta EventMeta + Tool string + Arguments map[string]any + Result *tools.ToolResult + Duration time.Duration + Channel string + ChatID string +} + +func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse { + if r == nil { + return nil + } + cloned := *r + cloned.Arguments = cloneStringAnyMap(r.Arguments) + cloned.Result = cloneToolResult(r.Result) + return &cloned +} + +type HookManager struct { + eventBus *EventBus + observerTimeout time.Duration + interceptorTimeout time.Duration + approvalTimeout time.Duration + + mu sync.RWMutex + hooks map[string]HookRegistration + ordered []HookRegistration + + sub EventSubscription + done chan struct{} + closeOnce sync.Once +} + +func NewHookManager(eventBus *EventBus) *HookManager { + hm := &HookManager{ + eventBus: eventBus, + observerTimeout: defaultHookObserverTimeout, + interceptorTimeout: defaultHookInterceptorTimeout, + approvalTimeout: defaultHookApprovalTimeout, + hooks: make(map[string]HookRegistration), + done: make(chan struct{}), + } + + if eventBus == nil { + close(hm.done) + return hm + } + + hm.sub = eventBus.Subscribe(hookObserverBufferSize) + go hm.dispatchEvents() + return hm +} + +func (hm *HookManager) Close() { + if hm == nil { + return + } + + hm.closeOnce.Do(func() { + if hm.eventBus != nil { + hm.eventBus.Unsubscribe(hm.sub.ID) + } + <-hm.done + }) +} + +func (hm *HookManager) Mount(reg HookRegistration) error { + if hm == nil { + return fmt.Errorf("hook manager is nil") + } + if reg.Name == "" { + return fmt.Errorf("hook name is required") + } + if reg.Hook == nil { + return fmt.Errorf("hook %q is nil", reg.Name) + } + + hm.mu.Lock() + defer hm.mu.Unlock() + + hm.hooks[reg.Name] = reg + hm.rebuildOrdered() + return nil +} + +func (hm *HookManager) Unmount(name string) { + if hm == nil || name == "" { + return + } + + hm.mu.Lock() + defer hm.mu.Unlock() + + delete(hm.hooks, name) + hm.rebuildOrdered() +} + +func (hm *HookManager) dispatchEvents() { + defer close(hm.done) + + for evt := range hm.sub.C { + for _, reg := range hm.snapshotHooks() { + observer, ok := reg.Hook.(EventObserver) + if !ok { + continue + } + hm.runObserver(reg.Name, observer, evt) + } + } +} + +func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) { + if hm == nil || req == nil { + return req, HookDecision{Action: HookActionContinue} + } + + current := req.Clone() + for _, reg := range hm.snapshotHooks() { + interceptor, ok := reg.Hook.(LLMInterceptor) + if !ok { + continue + } + + next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone()) + if !ok { + continue + } + + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if next != nil { + current = next + } + case HookActionAbortTurn, HookActionHardAbort: + return current, decision + default: + hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action) + } + } + return current, HookDecision{Action: HookActionContinue} +} + +func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) { + if hm == nil || resp == nil { + return resp, HookDecision{Action: HookActionContinue} + } + + current := resp.Clone() + for _, reg := range hm.snapshotHooks() { + interceptor, ok := reg.Hook.(LLMInterceptor) + if !ok { + continue + } + + next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone()) + if !ok { + continue + } + + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if next != nil { + current = next + } + case HookActionAbortTurn, HookActionHardAbort: + return current, decision + default: + hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action) + } + } + return current, HookDecision{Action: HookActionContinue} +} + +func (hm *HookManager) BeforeTool( + ctx context.Context, + call *ToolCallHookRequest, +) (*ToolCallHookRequest, HookDecision) { + if hm == nil || call == nil { + return call, HookDecision{Action: HookActionContinue} + } + + current := call.Clone() + for _, reg := range hm.snapshotHooks() { + interceptor, ok := reg.Hook.(ToolInterceptor) + if !ok { + continue + } + + next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone()) + if !ok { + continue + } + + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if next != nil { + current = next + } + case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort: + return current, decision + default: + hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action) + } + } + return current, HookDecision{Action: HookActionContinue} +} + +func (hm *HookManager) AfterTool( + ctx context.Context, + result *ToolResultHookResponse, +) (*ToolResultHookResponse, HookDecision) { + if hm == nil || result == nil { + return result, HookDecision{Action: HookActionContinue} + } + + current := result.Clone() + for _, reg := range hm.snapshotHooks() { + interceptor, ok := reg.Hook.(ToolInterceptor) + if !ok { + continue + } + + next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone()) + if !ok { + continue + } + + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if next != nil { + current = next + } + case HookActionAbortTurn, HookActionHardAbort: + return current, decision + default: + hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action) + } + } + return current, HookDecision{Action: HookActionContinue} +} + +func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision { + if hm == nil || req == nil { + return ApprovalDecision{Approved: true} + } + + for _, reg := range hm.snapshotHooks() { + approver, ok := reg.Hook.(ToolApprover) + if !ok { + continue + } + + decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone()) + if !ok { + return ApprovalDecision{ + Approved: false, + Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name), + } + } + if !decision.Approved { + return decision + } + } + + return ApprovalDecision{Approved: true} +} + +func (hm *HookManager) rebuildOrdered() { + hm.ordered = hm.ordered[:0] + for _, reg := range hm.hooks { + hm.ordered = append(hm.ordered, reg) + } + sort.SliceStable(hm.ordered, func(i, j int) bool { + if hm.ordered[i].Priority == hm.ordered[j].Priority { + return hm.ordered[i].Name < hm.ordered[j].Name + } + return hm.ordered[i].Priority < hm.ordered[j].Priority + }) +} + +func (hm *HookManager) snapshotHooks() []HookRegistration { + hm.mu.RLock() + defer hm.mu.RUnlock() + + snapshot := make([]HookRegistration, len(hm.ordered)) + copy(snapshot, hm.ordered) + return snapshot +} + +func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) { + ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- observer.OnEvent(ctx, evt) + }() + + select { + case err := <-done: + if err != nil { + logger.WarnCF("hooks", "Event observer failed", map[string]any{ + "hook": name, + "event": evt.Kind.String(), + "error": err.Error(), + }) + } + case <-ctx.Done(): + logger.WarnCF("hooks", "Event observer timed out", map[string]any{ + "hook": name, + "event": evt.Kind.String(), + "timeout_ms": hm.observerTimeout.Milliseconds(), + }) + } +} + +func (hm *HookManager) callBeforeLLM( + parent context.Context, + name string, + interceptor LLMInterceptor, + req *LLMHookRequest, +) (*LLMHookRequest, HookDecision, bool) { + return runInterceptorHook( + parent, + hm.interceptorTimeout, + name, + "before_llm", + func(ctx context.Context) (*LLMHookRequest, HookDecision, error) { + return interceptor.BeforeLLM(ctx, req) + }, + ) +} + +func (hm *HookManager) callAfterLLM( + parent context.Context, + name string, + interceptor LLMInterceptor, + resp *LLMHookResponse, +) (*LLMHookResponse, HookDecision, bool) { + return runInterceptorHook( + parent, + hm.interceptorTimeout, + name, + "after_llm", + func(ctx context.Context) (*LLMHookResponse, HookDecision, error) { + return interceptor.AfterLLM(ctx, resp) + }, + ) +} + +func (hm *HookManager) callBeforeTool( + parent context.Context, + name string, + interceptor ToolInterceptor, + call *ToolCallHookRequest, +) (*ToolCallHookRequest, HookDecision, bool) { + return runInterceptorHook( + parent, + hm.interceptorTimeout, + name, + "before_tool", + func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) { + return interceptor.BeforeTool(ctx, call) + }, + ) +} + +func (hm *HookManager) callAfterTool( + parent context.Context, + name string, + interceptor ToolInterceptor, + resultView *ToolResultHookResponse, +) (*ToolResultHookResponse, HookDecision, bool) { + return runInterceptorHook( + parent, + hm.interceptorTimeout, + name, + "after_tool", + func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) { + return interceptor.AfterTool(ctx, resultView) + }, + ) +} + +func (hm *HookManager) callApproveTool( + parent context.Context, + name string, + approver ToolApprover, + req *ToolApprovalRequest, +) (ApprovalDecision, bool) { + return runApprovalHook( + parent, + hm.approvalTimeout, + name, + "approve_tool", + func(ctx context.Context) (ApprovalDecision, error) { + return approver.ApproveTool(ctx, req) + }, + ) +} + +func runInterceptorHook[T any]( + parent context.Context, + timeout time.Duration, + name string, + stage string, + fn func(ctx context.Context) (T, HookDecision, error), +) (T, HookDecision, bool) { + var zero T + + ctx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + + type result struct { + value T + decision HookDecision + err error + } + done := make(chan result, 1) + go func() { + value, decision, err := fn(ctx) + done <- result{value: value, decision: decision, err: err} + }() + + select { + case res := <-done: + if res.err != nil { + logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{ + "hook": name, + "stage": stage, + "error": res.err.Error(), + }) + return zero, HookDecision{}, false + } + return res.value, res.decision, true + case <-ctx.Done(): + logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{ + "hook": name, + "stage": stage, + "timeout_ms": timeout.Milliseconds(), + }) + return zero, HookDecision{}, false + } +} + +func runApprovalHook( + parent context.Context, + timeout time.Duration, + name string, + stage string, + fn func(ctx context.Context) (ApprovalDecision, error), +) (ApprovalDecision, bool) { + ctx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + + type result struct { + decision ApprovalDecision + err error + } + done := make(chan result, 1) + go func() { + decision, err := fn(ctx) + done <- result{decision: decision, err: err} + }() + + select { + case res := <-done: + if res.err != nil { + logger.WarnCF("hooks", "Approval hook failed", map[string]any{ + "hook": name, + "stage": stage, + "error": res.err.Error(), + }) + return ApprovalDecision{}, false + } + return res.decision, true + case <-ctx.Done(): + logger.WarnCF("hooks", "Approval hook timed out", map[string]any{ + "hook": name, + "stage": stage, + "timeout_ms": timeout.Milliseconds(), + }) + return ApprovalDecision{ + Approved: false, + Reason: fmt.Sprintf("tool approval hook %q timed out", name), + }, true + } +} + +func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) { + logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{ + "hook": name, + "stage": stage, + "action": action, + }) +} + +func cloneProviderMessages(messages []providers.Message) []providers.Message { + if len(messages) == 0 { + return nil + } + + cloned := make([]providers.Message, len(messages)) + for i, msg := range messages { + cloned[i] = msg + if len(msg.Media) > 0 { + cloned[i].Media = append([]string(nil), msg.Media...) + } + if len(msg.SystemParts) > 0 { + cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...) + } + if len(msg.ToolCalls) > 0 { + cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls) + } + } + return cloned +} + +func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall { + if len(calls) == 0 { + return nil + } + + cloned := make([]providers.ToolCall, len(calls)) + for i, call := range calls { + cloned[i] = call + if call.Function != nil { + fn := *call.Function + cloned[i].Function = &fn + } + if call.Arguments != nil { + cloned[i].Arguments = cloneStringAnyMap(call.Arguments) + } + if call.ExtraContent != nil { + extra := *call.ExtraContent + if call.ExtraContent.Google != nil { + google := *call.ExtraContent.Google + extra.Google = &google + } + cloned[i].ExtraContent = &extra + } + } + return cloned +} + +func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition { + if len(defs) == 0 { + return nil + } + + cloned := make([]providers.ToolDefinition, len(defs)) + for i, def := range defs { + cloned[i] = def + cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters) + } + return cloned +} + +func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse { + if resp == nil { + return nil + } + cloned := *resp + cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls) + if len(resp.ReasoningDetails) > 0 { + cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...) + } + if resp.Usage != nil { + usage := *resp.Usage + cloned.Usage = &usage + } + return &cloned +} + +func cloneStringAnyMap(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + + cloned := make(map[string]any, len(src)) + for k, v := range src { + cloned[k] = v + } + return cloned +} + +func cloneToolResult(result *tools.ToolResult) *tools.ToolResult { + if result == nil { + return nil + } + + cloned := *result + if len(result.Media) > 0 { + cloned.Media = append([]string(nil), result.Media...) + } + return &cloned +} diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go new file mode 100644 index 000000000..6607b5fe7 --- /dev/null +++ b/pkg/agent/hooks_test.go @@ -0,0 +1,312 @@ +package agent + +import ( + "context" + "os" + "sync" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +func newHookTestLoop( + t *testing.T, + provider providers.LLMProvider, +) (*AgentLoop, *AgentInstance, func()) { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "agent-hooks-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + al := NewAgentLoop(cfg, bus.NewMessageBus(), provider) + agent := al.registry.GetDefaultAgent() + if agent == nil { + t.Fatal("expected default agent") + } + + return al, agent, func() { + al.Close() + _ = os.RemoveAll(tmpDir) + } +} + +type llmHookTestProvider struct { + mu sync.Mutex + lastModel string +} + +func (p *llmHookTestProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + p.lastModel = model + p.mu.Unlock() + + return &providers.LLMResponse{ + Content: "provider content", + }, nil +} + +func (p *llmHookTestProvider) GetDefaultModel() string { + return "llm-hook-provider" +} + +type llmObserverHook struct { + eventCh chan Event +} + +func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error { + if evt.Kind == EventKindTurnEnd { + select { + case h.eventCh <- evt: + default: + } + } + return nil +} + +func (h *llmObserverHook) BeforeLLM( + ctx context.Context, + req *LLMHookRequest, +) (*LLMHookRequest, HookDecision, error) { + next := req.Clone() + next.Model = "hook-model" + return next, HookDecision{Action: HookActionModify}, nil +} + +func (h *llmObserverHook) AfterLLM( + ctx context.Context, + resp *LLMHookResponse, +) (*LLMHookResponse, HookDecision, error) { + next := resp.Clone() + next.Response.Content = "hooked content" + return next, HookDecision{Action: HookActionModify}, nil +} + +func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) { + provider := &llmHookTestProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + hook := &llmObserverHook{eventCh: make(chan Event, 1)} + if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil { + t.Fatalf("MountHook failed: %v", err) + } + + resp, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "hello", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + if resp != "hooked content" { + t.Fatalf("expected hooked content, got %q", resp) + } + + provider.mu.Lock() + lastModel := provider.lastModel + provider.mu.Unlock() + if lastModel != "hook-model" { + t.Fatalf("expected model hook-model, got %q", lastModel) + } + + select { + case evt := <-hook.eventCh: + if evt.Kind != EventKindTurnEnd { + t.Fatalf("expected turn end event, got %v", evt.Kind) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for hook observer event") + } +} + +type toolHookProvider struct { + mu sync.Mutex + calls int +} + +func (p *toolHookProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + + p.calls++ + if p.calls == 1 { + return &providers.LLMResponse{ + ToolCalls: []providers.ToolCall{ + { + ID: "call-1", + Name: "echo_text", + Arguments: map[string]any{"text": "original"}, + }, + }, + }, nil + } + + last := messages[len(messages)-1] + return &providers.LLMResponse{ + Content: last.Content, + }, nil +} + +func (p *toolHookProvider) GetDefaultModel() string { + return "tool-hook-provider" +} + +type echoTextTool struct{} + +func (t *echoTextTool) Name() string { + return "echo_text" +} + +func (t *echoTextTool) Description() string { + return "echo a text argument" +} + +func (t *echoTextTool) Parameters() map[string]any { + return map[string]any{ + "type": "object", + "properties": map[string]any{ + "text": map[string]any{ + "type": "string", + }, + }, + } +} + +func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult { + text, _ := args["text"].(string) + return tools.SilentResult(text) +} + +type toolRewriteHook struct{} + +func (h *toolRewriteHook) BeforeTool( + ctx context.Context, + call *ToolCallHookRequest, +) (*ToolCallHookRequest, HookDecision, error) { + next := call.Clone() + next.Arguments["text"] = "modified" + return next, HookDecision{Action: HookActionModify}, nil +} + +func (h *toolRewriteHook) AfterTool( + ctx context.Context, + result *ToolResultHookResponse, +) (*ToolResultHookResponse, HookDecision, error) { + next := result.Clone() + next.Result.ForLLM = "after:" + next.Result.ForLLM + return next, HookDecision{Action: HookActionModify}, nil +} + +func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) { + provider := &toolHookProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + al.RegisterTool(&echoTextTool{}) + if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil { + t.Fatalf("MountHook failed: %v", err) + } + + resp, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "run tool", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + if resp != "after:modified" { + t.Fatalf("expected rewritten tool result, got %q", resp) + } +} + +type denyApprovalHook struct{} + +func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) { + return ApprovalDecision{ + Approved: false, + Reason: "blocked", + }, nil +} + +func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) { + provider := &toolHookProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + al.RegisterTool(&echoTextTool{}) + if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil { + t.Fatalf("MountHook failed: %v", err) + } + + sub := al.SubscribeEvents(16) + defer al.UnsubscribeEvents(sub.ID) + + resp, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "run tool", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + expected := "Tool execution denied by approval hook: blocked" + if resp != expected { + t.Fatalf("expected %q, got %q", expected, resp) + } + + events := collectEventStream(sub.C) + skippedEvt, ok := findEvent(events, EventKindToolExecSkipped) + if !ok { + t.Fatal("expected tool skipped event") + } + payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload) + if !ok { + t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload) + } + if payload.Reason != expected { + t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason) + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f54482ae8..a85abcb60 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -40,6 +40,7 @@ type AgentLoop struct { registry *AgentRegistry state *state.Manager eventBus *EventBus + hooks *HookManager running atomic.Bool summarizing sync.Map fallback *providers.FallbackChain @@ -108,17 +109,19 @@ func NewAgentLoop( stateManager = state.NewManager(defaultAgent.Workspace) } + eventBus := NewEventBus() al := &AgentLoop{ bus: msgBus, cfg: cfg, registry: registry, state: stateManager, - eventBus: NewEventBus(), + eventBus: eventBus, summarizing: sync.Map{}, fallback: fallbackChain, cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()), steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)), } + al.hooks = NewHookManager(eventBus) return al } @@ -460,11 +463,30 @@ func (al *AgentLoop) Close() { } al.GetRegistry().Close() + if al.hooks != nil { + al.hooks.Close() + } if al.eventBus != nil { al.eventBus.Close() } } +// MountHook registers an in-process hook on the agent loop. +func (al *AgentLoop) MountHook(reg HookRegistration) error { + if al == nil || al.hooks == nil { + return fmt.Errorf("hook manager is not initialized") + } + return al.hooks.Mount(reg) +} + +// UnmountHook removes a previously registered in-process hook. +func (al *AgentLoop) UnmountHook(name string) { + if al == nil || al.hooks == nil { + return + } + al.hooks.Unmount(name) +} + // SubscribeEvents registers a subscriber for agent-loop events. func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription { if al == nil || al.eventBus == nil { @@ -544,6 +566,31 @@ func cloneEventArguments(args map[string]any) map[string]any { return cloned } +func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error { + reason := decision.Reason + if reason == "" { + reason = "hook requested turn abort" + } + + err := fmt.Errorf("hook aborted turn during %s: %s", stage, reason) + al.emitEvent( + EventKindError, + ts.eventMeta("hooks", "turn.error"), + ErrorPayload{ + Stage: "hook." + stage, + Message: err.Error(), + }, + ) + return err +} + +func hookDeniedToolContent(prefix, reason string) string { + if reason == "" { + return prefix + } + return prefix + ": " + reason +} + func (al *AgentLoop) logEvent(evt Event) { fields := map[string]any{ "event_kind": evt.Kind.String(), @@ -1418,36 +1465,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er ts.markGracefulTerminalUsed() } - al.emitEvent( - EventKindLLMRequest, - ts.eventMeta("runTurn", "turn.llm.request"), - LLMRequestPayload{ - Model: activeModel, - MessagesCount: len(callMessages), - ToolsCount: len(providerToolDefs), - MaxTokens: ts.agent.MaxTokens, - Temperature: ts.agent.Temperature, - }, - ) - - logger.DebugCF("agent", "LLM request", - map[string]any{ - "agent_id": ts.agent.ID, - "iteration": iteration, - "model": activeModel, - "messages_count": len(callMessages), - "tools_count": len(providerToolDefs), - "max_tokens": ts.agent.MaxTokens, - "temperature": ts.agent.Temperature, - "system_prompt_len": len(callMessages[0].Content), - }) - logger.DebugCF("agent", "Full LLM request", - map[string]any{ - "iteration": iteration, - "messages_json": formatMessagesForLog(callMessages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - llmOpts := map[string]any{ "max_tokens": ts.agent.MaxTokens, "temperature": ts.agent.Temperature, @@ -1462,6 +1479,66 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er } } + llmModel := activeModel + if al.hooks != nil { + llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{ + Meta: ts.eventMeta("runTurn", "turn.llm.request"), + Model: llmModel, + Messages: callMessages, + Tools: providerToolDefs, + Options: llmOpts, + Channel: ts.channel, + ChatID: ts.chatID, + GracefulTerminal: gracefulTerminal, + }) + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if llmReq != nil { + llmModel = llmReq.Model + callMessages = llmReq.Messages + providerToolDefs = llmReq.Tools + llmOpts = llmReq.Options + } + case HookActionAbortTurn: + turnStatus = TurnEndStatusError + return turnResult{}, al.hookAbortError(ts, "before_llm", decision) + case HookActionHardAbort: + _ = ts.requestHardAbort() + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + } + + al.emitEvent( + EventKindLLMRequest, + ts.eventMeta("runTurn", "turn.llm.request"), + LLMRequestPayload{ + Model: llmModel, + MessagesCount: len(callMessages), + ToolsCount: len(providerToolDefs), + MaxTokens: ts.agent.MaxTokens, + Temperature: ts.agent.Temperature, + }, + ) + + logger.DebugCF("agent", "LLM request", + map[string]any{ + "agent_id": ts.agent.ID, + "iteration": iteration, + "model": llmModel, + "messages_count": len(callMessages), + "tools_count": len(providerToolDefs), + "max_tokens": ts.agent.MaxTokens, + "temperature": ts.agent.Temperature, + "system_prompt_len": len(callMessages[0].Content), + }) + logger.DebugCF("agent", "Full LLM request", + map[string]any{ + "iteration": iteration, + "messages_json": formatMessagesForLog(callMessages), + "tools_json": formatToolsForLog(providerToolDefs), + }) + callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) { providerCtx, providerCancel := context.WithCancel(turnCtx) ts.setProviderCancel(providerCancel) @@ -1494,7 +1571,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er } return fbResult.Response, nil } - return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts) + return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts) } var response *providers.LLMResponse @@ -1626,12 +1703,35 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er map[string]any{ "agent_id": ts.agent.ID, "iteration": iteration, - "model": activeModel, + "model": llmModel, "error": err.Error(), }) return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err) } + if al.hooks != nil { + llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{ + Meta: ts.eventMeta("runTurn", "turn.llm.response"), + Model: llmModel, + Response: response, + Channel: ts.channel, + ChatID: ts.chatID, + }) + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if llmResp != nil && llmResp.Response != nil { + response = llmResp.Response + } + case HookActionAbortTurn: + turnStatus = TurnEndStatusError + return turnResult{}, al.hookAbortError(ts, "after_llm", decision) + case HookActionHardAbort: + _ = ts.requestHardAbort() + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + } + go al.handleReasoning( turnCtx, response.Reasoning, @@ -1728,25 +1828,106 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er return al.abortTurn(ts) } - argsJSON, _ := json.Marshal(tc.Arguments) + toolName := tc.Name + toolArgs := cloneStringAnyMap(tc.Arguments) + + if al.hooks != nil { + toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{ + Meta: ts.eventMeta("runTurn", "turn.tool.before"), + Tool: toolName, + Arguments: toolArgs, + Channel: ts.channel, + ChatID: ts.chatID, + }) + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if toolReq != nil { + toolName = toolReq.Tool + toolArgs = toolReq.Arguments + } + case HookActionDenyTool: + denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason) + al.emitEvent( + EventKindToolExecSkipped, + ts.eventMeta("runTurn", "turn.tool.skipped"), + ToolExecSkippedPayload{ + Tool: toolName, + Reason: denyContent, + }, + ) + deniedMsg := providers.Message{ + Role: "tool", + Content: denyContent, + ToolCallID: tc.ID, + } + messages = append(messages, deniedMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg) + ts.recordPersistedMessage(deniedMsg) + } + continue + case HookActionAbortTurn: + turnStatus = TurnEndStatusError + return turnResult{}, al.hookAbortError(ts, "before_tool", decision) + case HookActionHardAbort: + _ = ts.requestHardAbort() + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + } + + if al.hooks != nil { + approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{ + Meta: ts.eventMeta("runTurn", "turn.tool.approve"), + Tool: toolName, + Arguments: toolArgs, + Channel: ts.channel, + ChatID: ts.chatID, + }) + if !approval.Approved { + denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason) + al.emitEvent( + EventKindToolExecSkipped, + ts.eventMeta("runTurn", "turn.tool.skipped"), + ToolExecSkippedPayload{ + Tool: toolName, + Reason: denyContent, + }, + ) + deniedMsg := providers.Message{ + Role: "tool", + Content: denyContent, + ToolCallID: tc.ID, + } + messages = append(messages, deniedMsg) + if !ts.opts.NoHistory { + ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg) + ts.recordPersistedMessage(deniedMsg) + } + continue + } + } + + argsJSON, _ := json.Marshal(toolArgs) argsPreview := utils.Truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview), map[string]any{ "agent_id": ts.agent.ID, - "tool": tc.Name, + "tool": toolName, "iteration": iteration, }) al.emitEvent( EventKindToolExecStart, ts.eventMeta("runTurn", "turn.tool.start"), ToolExecStartPayload{ - Tool: tc.Name, - Arguments: cloneEventArguments(tc.Arguments), + Tool: toolName, + Arguments: cloneEventArguments(toolArgs), }, ) - toolCall := tc + toolCallID := tc.ID toolIteration := iteration + asyncToolName := toolName asyncCallback := func(_ context.Context, result *tools.ToolResult) { if !result.Silent && result.ForUser != "" { outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -1768,7 +1949,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er logger.InfoCF("agent", "Async tool completed, publishing result", map[string]any{ - "tool": toolCall.Name, + "tool": asyncToolName, "content_len": len(content), "channel": ts.channel, }) @@ -1776,7 +1957,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er EventKindFollowUpQueued, ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"), FollowUpQueuedPayload{ - SourceTool: toolCall.Name, + SourceTool: asyncToolName, Channel: ts.channel, ChatID: ts.chatID, ContentLen: len(content), @@ -1787,7 +1968,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er defer pubCancel() _ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{ Channel: "system", - SenderID: fmt.Sprintf("async:%s", toolCall.Name), + SenderID: fmt.Sprintf("async:%s", asyncToolName), ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID), Content: content, }) @@ -1796,8 +1977,8 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er toolStart := time.Now() toolResult := ts.agent.Tools.ExecuteWithContext( turnCtx, - toolCall.Name, - toolCall.Arguments, + toolName, + toolArgs, ts.channel, ts.chatID, asyncCallback, @@ -1809,6 +1990,40 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er return al.abortTurn(ts) } + if al.hooks != nil { + toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{ + Meta: ts.eventMeta("runTurn", "turn.tool.after"), + Tool: toolName, + Arguments: toolArgs, + Result: toolResult, + Duration: toolDuration, + Channel: ts.channel, + ChatID: ts.chatID, + }) + switch decision.normalizedAction() { + case HookActionContinue, HookActionModify: + if toolResp != nil { + if toolResp.Tool != "" { + toolName = toolResp.Tool + } + if toolResp.Result != nil { + toolResult = toolResp.Result + } + } + case HookActionAbortTurn: + turnStatus = TurnEndStatusError + return turnResult{}, al.hookAbortError(ts, "after_tool", decision) + case HookActionHardAbort: + _ = ts.requestHardAbort() + turnStatus = TurnEndStatusAborted + return al.abortTurn(ts) + } + } + + if toolResult == nil { + toolResult = tools.ErrorResult("hook returned nil tool result") + } + if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse { al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: ts.channel, @@ -1817,7 +2032,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er }) logger.DebugCF("agent", "Sent tool result to user", map[string]any{ - "tool": toolCall.Name, + "tool": toolName, "content_len": len(toolResult.ForUser), }) } @@ -1850,13 +2065,13 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er toolResultMsg := providers.Message{ Role: "tool", Content: contentForLLM, - ToolCallID: toolCall.ID, + ToolCallID: toolCallID, } al.emitEvent( EventKindToolExecEnd, ts.eventMeta("runTurn", "turn.tool.end"), ToolExecEndPayload{ - Tool: toolCall.Name, + Tool: toolName, Duration: toolDuration, ForLLMLen: len(contentForLLM), ForUserLen: len(toolResult.ForUser),