mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
feat(agent): add hook manager foundation
This commit is contained in:
@@ -0,0 +1,476 @@
|
||||
# PicoClaw Hook 系统设计(基于 `refactor/agent`)
|
||||
|
||||
## 背景
|
||||
|
||||
本设计围绕两个议题展开:
|
||||
|
||||
- `#1316`:把 agent loop 重构为事件驱动、可中断、可追加、可观测
|
||||
- `#1796`:在 EventBus 稳定后,把 hooks 设计为 EventBus 的 consumer,而不是重新发明一套事件模型
|
||||
|
||||
当前分支已经完成了第一步里的“事件系统基础”,但还没有真正的 hook 挂载层。因此这里的目标不是重新设计 event,而是在已有实现上补出一层可扩展、可拦截、可外挂的 HookManager。
|
||||
|
||||
## 外部项目对比
|
||||
|
||||
### OpenClaw
|
||||
|
||||
OpenClaw 的扩展能力分成三层:
|
||||
|
||||
- Internal hooks:目录发现,运行在 Gateway 进程内
|
||||
- Plugin hooks:插件在运行时注册 hook,也在进程内
|
||||
- Webhooks:外部系统通过 HTTP 触发 Gateway 动作,属于进程外
|
||||
|
||||
值得借鉴的点:
|
||||
|
||||
- 有“项目内挂载”和“项目外挂载”两种路径
|
||||
- hook 是配置驱动,可启停
|
||||
- 外部入口有明确的安全边界和映射层
|
||||
|
||||
不建议直接照搬的点:
|
||||
|
||||
- OpenClaw 的 hooks / plugin hooks / webhooks 是三套路由,PicoClaw 当前体量下会偏重
|
||||
- HTTP webhook 更适合“事件进入系统”,不适合作为“可同步拦截 agent loop”的基础机制
|
||||
|
||||
### pi-mono
|
||||
|
||||
pi-mono 的核心思路更接近当前分支:
|
||||
|
||||
- 扩展统一为 extension API
|
||||
- 事件分为观察型和可变更型
|
||||
- 某些阶段允许 `transform` / `block` / `replace`
|
||||
- 扩展代码主要是进程内执行
|
||||
- RPC mode 把 UI 交互桥接到进程外客户端
|
||||
|
||||
值得借鉴的点:
|
||||
|
||||
- 不把“观察”和“拦截”混成一个接口
|
||||
- 允许返回结构化动作,而不是只有回调
|
||||
- 进程外通信只暴露必要协议,不把整个内部对象图泄露出去
|
||||
|
||||
## 当前分支现状
|
||||
|
||||
### 已有能力
|
||||
|
||||
当前分支已经具备 hook 系统的地基:
|
||||
|
||||
- `pkg/agent/events.go` 定义了稳定的 `EventKind`、`EventMeta` 和 payload
|
||||
- `pkg/agent/eventbus.go` 提供了非阻塞 fan-out 的 `EventBus`
|
||||
- `pkg/agent/loop.go` 中的 `runTurn()` 已在 turn、llm、tool、interrupt、follow-up、summary 等节点发射事件
|
||||
- `pkg/agent/steering.go` 已支持 steering、graceful interrupt、hard abort
|
||||
- `pkg/agent/turn.go` 已维护 turn phase、恢复点、active turn、abort 状态
|
||||
|
||||
### 现有缺口
|
||||
|
||||
当前分支还缺四件事:
|
||||
|
||||
- 没有 HookManager,只有 EventBus
|
||||
- 没有 Before/After LLM、Before/After Tool 这种同步拦截点
|
||||
- 没有审批型 hook
|
||||
- 子 agent 仍走 `pkg/tools/SubagentManager + RunToolLoop`,没有接入 `pkg/agent` 的 turn tree 和事件流
|
||||
|
||||
### 一个关键现实
|
||||
|
||||
`#1316` 文案里提到“只读并行、写入串行”的工具执行策略,但当前 `runTurn()` 实现已经先收敛成“顺序执行 + 每个工具后检查 steering / interrupt”。因此 hook 设计不应依赖未来的并行模型,而应该先兼容当前顺序执行,再为以后增加 `ReadOnlyIndicator` 留口子。
|
||||
|
||||
## 设计原则
|
||||
|
||||
- Hook 必须建立在 `pkg/agent` 的 EventBus 和 turn 上下文之上
|
||||
- EventBus 负责广播,HookManager 负责拦截,两者职责分离
|
||||
- 项目内挂载要简单,项目外挂载必须走 IPC
|
||||
- 观察型 hook 不能阻塞 loop;拦截型 hook 必须有超时
|
||||
- 先覆盖主 turn,不把 sub-turn 一次做满
|
||||
- 不新增第二套用户事件命名系统,优先复用 `EventKind.String()`
|
||||
|
||||
## 总体架构
|
||||
|
||||
分成三层:
|
||||
|
||||
1. `EventBus`
|
||||
负责广播只读事件,现有实现直接复用
|
||||
|
||||
2. `HookManager`
|
||||
负责管理 hook、排序、超时、错误隔离,并在 `runTurn()` 的明确检查点执行同步拦截
|
||||
|
||||
3. `HookMount`
|
||||
负责两种挂载方式:
|
||||
- 进程内 Go hook
|
||||
- 进程外 IPC hook
|
||||
|
||||
换句话说:
|
||||
|
||||
- EventBus 是“发生了什么”
|
||||
- HookManager 是“谁能介入”
|
||||
- HookMount 是“这些 hook 从哪里来”
|
||||
|
||||
## Hook 分类
|
||||
|
||||
不建议把所有 hook 都设计成 `OnEvent(evt)`。
|
||||
|
||||
建议拆成两类。
|
||||
|
||||
### 1. 观察型
|
||||
|
||||
只消费事件,不修改流程:
|
||||
|
||||
```go
|
||||
type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt agent.Event) error
|
||||
}
|
||||
```
|
||||
|
||||
这类 hook 直接订阅 EventBus 即可。
|
||||
|
||||
适用场景:
|
||||
|
||||
- 审计日志
|
||||
- 指标上报
|
||||
- 调试 trace
|
||||
- 将事件转发给外部 UI / TUI / Web 面板
|
||||
|
||||
### 2. 拦截型
|
||||
|
||||
只在少数明确节点触发,允许返回动作:
|
||||
|
||||
```go
|
||||
type LLMInterceptor interface {
|
||||
BeforeLLM(ctx context.Context, req *LLMRequest) HookDecision[*LLMRequest]
|
||||
AfterLLM(ctx context.Context, resp *LLMResponse) HookDecision[*LLMResponse]
|
||||
}
|
||||
|
||||
type ToolInterceptor interface {
|
||||
BeforeTool(ctx context.Context, call *ToolCall) HookDecision[*ToolCall]
|
||||
AfterTool(ctx context.Context, result *ToolResultView) HookDecision[*ToolResultView]
|
||||
}
|
||||
|
||||
type ToolApprover interface {
|
||||
ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision
|
||||
}
|
||||
```
|
||||
|
||||
这里的 `HookDecision` 统一支持:
|
||||
|
||||
- `continue`
|
||||
- `modify`
|
||||
- `deny_tool`
|
||||
- `abort_turn`
|
||||
- `hard_abort`
|
||||
|
||||
## 对外暴露的最小 hook 面
|
||||
|
||||
V1 不需要把所有 EventKind 都变成可拦截点。
|
||||
|
||||
建议只开放这些同步 hook:
|
||||
|
||||
- `before_llm`
|
||||
- `after_llm`
|
||||
- `before_tool`
|
||||
- `after_tool`
|
||||
- `approve_tool`
|
||||
|
||||
其余节点继续作为只读事件暴露:
|
||||
|
||||
- `turn_start`
|
||||
- `turn_end`
|
||||
- `llm_request`
|
||||
- `llm_response`
|
||||
- `tool_exec_start`
|
||||
- `tool_exec_end`
|
||||
- `tool_exec_skipped`
|
||||
- `steering_injected`
|
||||
- `follow_up_queued`
|
||||
- `interrupt_received`
|
||||
- `context_compress`
|
||||
- `session_summarize`
|
||||
- `error`
|
||||
|
||||
`subturn_*` 在 V1 中保留名字,但不承诺一定触发,直到子 turn 迁移完成。
|
||||
|
||||
## 项目内挂载
|
||||
|
||||
内部挂载必须尽量低摩擦。
|
||||
|
||||
建议提供两种等价方式,底层都走 HookManager。
|
||||
|
||||
### 方式 A:代码显式挂载
|
||||
|
||||
```go
|
||||
al.MountHook(hooks.Named("audit", &AuditHook{}))
|
||||
```
|
||||
|
||||
适用于:
|
||||
|
||||
- 仓内内建 hook
|
||||
- 单元测试
|
||||
- feature flag 控制
|
||||
|
||||
### 方式 B:内建 registry
|
||||
|
||||
```go
|
||||
func init() {
|
||||
hooks.RegisterBuiltin("audit", func() hooks.Hook {
|
||||
return &AuditHook{}
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
启动时根据配置启用:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"builtins": {
|
||||
"audit": { "enabled": true }
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
这比 OpenClaw 的目录扫描更轻,也更贴合 Go 项目。
|
||||
|
||||
## 项目外挂载
|
||||
|
||||
这是本设计的硬要求。
|
||||
|
||||
建议 V1 采用:
|
||||
|
||||
- `JSON-RPC over stdio`
|
||||
|
||||
原因:
|
||||
|
||||
- 跨平台最简单
|
||||
- 不依赖额外端口
|
||||
- 非常适合“由 PicoClaw 启动一个外部 hook 进程”
|
||||
- 比 HTTP webhook 更适合同步拦截
|
||||
|
||||
### 外部 hook 进程模型
|
||||
|
||||
PicoClaw 启动外部进程,并在其 stdin/stdout 上跑协议。
|
||||
|
||||
配置示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"processes": {
|
||||
"review-gate": {
|
||||
"enabled": true,
|
||||
"transport": "stdio",
|
||||
"command": ["uvx", "picoclaw-hook-reviewer"],
|
||||
"observe": ["turn_start", "turn_end", "tool_exec_end"],
|
||||
"intercept": ["before_tool", "approve_tool"],
|
||||
"timeout_ms": 5000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 协议边界
|
||||
|
||||
不要把内部 Go 结构体直接暴露给 IPC。
|
||||
|
||||
建议定义稳定的协议对象:
|
||||
|
||||
- `HookHandshake`
|
||||
- `HookEventNotification`
|
||||
- `BeforeLLMRequest`
|
||||
- `AfterLLMRequest`
|
||||
- `BeforeToolRequest`
|
||||
- `AfterToolRequest`
|
||||
- `ApproveToolRequest`
|
||||
- `HookDecision`
|
||||
|
||||
其中:
|
||||
|
||||
- 观察型事件用 notification,fire-and-forget
|
||||
- 拦截型事件用 request/response,同步等待
|
||||
|
||||
### 为什么是 stdio,而不是直接用 HTTP webhook
|
||||
|
||||
因为两者用途不同:
|
||||
|
||||
- HTTP webhook 更适合“外部系统向 PicoClaw 投递事件”
|
||||
- stdio/RPC 更适合“PicoClaw 在 turn 内同步询问外部 hook 是否改写 / 放行 / 拒绝”
|
||||
|
||||
如果未来需要 OpenClaw 式 webhook,可以作为独立入口层,再把外部事件转成 inbound message 或 steering,而不是直接替代 hook IPC。
|
||||
|
||||
## Hook 执行顺序
|
||||
|
||||
建议统一排序规则:
|
||||
|
||||
- 先内建 in-process hook
|
||||
- 再外部 IPC hook
|
||||
- 同组内按 `priority` 从小到大执行
|
||||
|
||||
原因:
|
||||
|
||||
- 内建 hook 延迟更低,适合做基础规范化
|
||||
- 外部 hook 更适合做审批、审计、组织级策略
|
||||
|
||||
## 超时与错误策略
|
||||
|
||||
### 观察型
|
||||
|
||||
- 默认超时:`500ms`
|
||||
- 超时或报错:记录日志,继续主流程
|
||||
|
||||
### 拦截型
|
||||
|
||||
- `before_llm` / `after_llm` / `before_tool` / `after_tool`:默认 `5s`
|
||||
- `approve_tool`:默认 `60s`
|
||||
|
||||
超时行为:
|
||||
|
||||
- 普通拦截:`continue`
|
||||
- 审批:`deny`
|
||||
|
||||
这点应直接沿用 `#1316` 的安全倾向。
|
||||
|
||||
## 与当前分支的对接点
|
||||
|
||||
### 直接复用
|
||||
|
||||
- 事件定义:`pkg/agent/events.go`
|
||||
- 事件广播:`pkg/agent/eventbus.go`
|
||||
- 活跃 turn / interrupt / rollback:`pkg/agent/turn.go`
|
||||
- 事件发射点:`pkg/agent/loop.go`
|
||||
|
||||
### 需要新增
|
||||
|
||||
- `pkg/agent/hooks.go`
|
||||
- Hook 接口
|
||||
- HookDecision / ApprovalDecision
|
||||
- HookManager
|
||||
|
||||
- `pkg/agent/hook_mount.go`
|
||||
- 内建 hook 注册
|
||||
- 外部进程 hook 注册
|
||||
|
||||
- `pkg/agent/hook_ipc.go`
|
||||
- stdio JSON-RPC bridge
|
||||
|
||||
- `pkg/agent/hook_types.go`
|
||||
- IPC 稳定载荷
|
||||
|
||||
### 需要改造
|
||||
|
||||
- `pkg/agent/loop.go`
|
||||
- 在 LLM 和 tool 关键路径前后插入 HookManager 调用
|
||||
|
||||
- `pkg/tools/base.go`
|
||||
- 可选新增 `ReadOnlyIndicator`
|
||||
|
||||
- `pkg/tools/spawn.go`
|
||||
- `pkg/tools/subagent.go`
|
||||
- 先保留现状
|
||||
- 等 sub-turn 迁移后再接入 `subturn_*` hook
|
||||
|
||||
## 一个更贴合当前分支的数据流
|
||||
|
||||
### 观察链路
|
||||
|
||||
```text
|
||||
runTurn() -> emitEvent() -> EventBus -> observers
|
||||
```
|
||||
|
||||
### 拦截链路
|
||||
|
||||
```text
|
||||
runTurn()
|
||||
-> HookManager.BeforeLLM()
|
||||
-> Provider.Chat()
|
||||
-> HookManager.AfterLLM()
|
||||
-> HookManager.BeforeTool()
|
||||
-> HookManager.ApproveTool()
|
||||
-> tool.Execute()
|
||||
-> HookManager.AfterTool()
|
||||
```
|
||||
|
||||
也就是说:
|
||||
|
||||
- observer 不改变现有 `emitEvent()`
|
||||
- interceptor 直接插在 `runTurn()` 热路径
|
||||
|
||||
## 用户可见配置
|
||||
|
||||
建议新增:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"enabled": true,
|
||||
"builtins": {},
|
||||
"processes": {},
|
||||
"defaults": {
|
||||
"observer_timeout_ms": 500,
|
||||
"interceptor_timeout_ms": 5000,
|
||||
"approval_timeout_ms": 60000
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
V1 不做复杂自动发现。
|
||||
|
||||
原因:
|
||||
|
||||
- 当前分支重点是把地基打稳
|
||||
- 目录扫描、安装器、脚手架可以后置
|
||||
- 先让仓内和仓外都能挂上去,比“管理体验完整”更重要
|
||||
|
||||
## 推荐的 V1 范围
|
||||
|
||||
### 必做
|
||||
|
||||
- HookManager
|
||||
- in-process 挂载
|
||||
- stdio IPC 挂载
|
||||
- observer hooks
|
||||
- `before_tool` / `after_tool` / `approve_tool`
|
||||
- `before_llm` / `after_llm`
|
||||
|
||||
### 可后置
|
||||
|
||||
- hook CLI 管理命令
|
||||
- hook 自动发现
|
||||
- Unix socket / named pipe transport
|
||||
- sub-turn hook 生命周期
|
||||
- read-only 并行分组
|
||||
- webhook 到 inbound message 的映射入口
|
||||
|
||||
## 分阶段落地
|
||||
|
||||
### Phase 1
|
||||
|
||||
- 引入 HookManager
|
||||
- 支持 in-process observer + interceptor
|
||||
- 先只接主 turn
|
||||
|
||||
### Phase 2
|
||||
|
||||
- 引入 `stdio` 外部 hook 进程桥
|
||||
- 支持组织级审批 / 审计 / 参数改写
|
||||
|
||||
### Phase 3
|
||||
|
||||
- 把 `SubagentManager` 迁移到 `runTurn/sub-turn`
|
||||
- 接通 `subturn_spawn` / `subturn_end` / `subturn_result_delivered`
|
||||
|
||||
### Phase 4
|
||||
|
||||
- 视需求补 `ReadOnlyIndicator`
|
||||
- 在主 turn 和 sub-turn 上统一只读并行策略
|
||||
|
||||
## 最终结论
|
||||
|
||||
最适合 PicoClaw 当前分支的方案,不是直接复制 OpenClaw 的 hooks,也不是完整照搬 pi-mono 的 extension system,而是:
|
||||
|
||||
- 以现有 `EventBus` 为只读观察面
|
||||
- 以新增 `HookManager` 为同步拦截面
|
||||
- 项目内通过 Go 对象直接挂载
|
||||
- 项目外通过 `stdio JSON-RPC` 进程通信挂载
|
||||
|
||||
这样做有三个好处:
|
||||
|
||||
- 和 `#1796` 一致,hooks 只是 EventBus 之上的消费层
|
||||
- 和当前 `refactor/agent` 实现一致,不需要推翻已有事件系统
|
||||
- 同时满足“仓内简单挂载”和“仓外进程通信挂载”两个硬需求
|
||||
@@ -0,0 +1,751 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHookObserverTimeout = 500 * time.Millisecond
|
||||
defaultHookInterceptorTimeout = 5 * time.Second
|
||||
defaultHookApprovalTimeout = 60 * time.Second
|
||||
hookObserverBufferSize = 64
|
||||
)
|
||||
|
||||
type HookAction string
|
||||
|
||||
const (
|
||||
HookActionContinue HookAction = "continue"
|
||||
HookActionModify HookAction = "modify"
|
||||
HookActionDenyTool HookAction = "deny_tool"
|
||||
HookActionAbortTurn HookAction = "abort_turn"
|
||||
HookActionHardAbort HookAction = "hard_abort"
|
||||
)
|
||||
|
||||
type HookDecision struct {
|
||||
Action HookAction
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (d HookDecision) normalizedAction() HookAction {
|
||||
if d.Action == "" {
|
||||
return HookActionContinue
|
||||
}
|
||||
return d.Action
|
||||
}
|
||||
|
||||
type ApprovalDecision struct {
|
||||
Approved bool
|
||||
Reason string
|
||||
}
|
||||
|
||||
type HookRegistration struct {
|
||||
Name string
|
||||
Priority int
|
||||
Hook any
|
||||
}
|
||||
|
||||
func NamedHook(name string, hook any) HookRegistration {
|
||||
return HookRegistration{
|
||||
Name: name,
|
||||
Hook: hook,
|
||||
}
|
||||
}
|
||||
|
||||
type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt Event) error
|
||||
}
|
||||
|
||||
type LLMInterceptor interface {
|
||||
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
|
||||
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolInterceptor interface {
|
||||
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
|
||||
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolApprover interface {
|
||||
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
|
||||
}
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta
|
||||
Model string
|
||||
Messages []providers.Message
|
||||
Tools []providers.ToolDefinition
|
||||
Options map[string]any
|
||||
Channel string
|
||||
ChatID string
|
||||
GracefulTerminal bool
|
||||
}
|
||||
|
||||
func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
cloned.Options = cloneStringAnyMap(r.Options)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta
|
||||
Model string
|
||||
Response *providers.LLMResponse
|
||||
Channel string
|
||||
ChatID string
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta
|
||||
Tool string
|
||||
Arguments map[string]any
|
||||
Channel string
|
||||
ChatID string
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta
|
||||
Tool string
|
||||
Arguments map[string]any
|
||||
Channel string
|
||||
ChatID string
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta
|
||||
Tool string
|
||||
Arguments map[string]any
|
||||
Result *tools.ToolResult
|
||||
Duration time.Duration
|
||||
Channel string
|
||||
ChatID string
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type HookManager struct {
|
||||
eventBus *EventBus
|
||||
observerTimeout time.Duration
|
||||
interceptorTimeout time.Duration
|
||||
approvalTimeout time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
hooks map[string]HookRegistration
|
||||
ordered []HookRegistration
|
||||
|
||||
sub EventSubscription
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHookManager(eventBus *EventBus) *HookManager {
|
||||
hm := &HookManager{
|
||||
eventBus: eventBus,
|
||||
observerTimeout: defaultHookObserverTimeout,
|
||||
interceptorTimeout: defaultHookInterceptorTimeout,
|
||||
approvalTimeout: defaultHookApprovalTimeout,
|
||||
hooks: make(map[string]HookRegistration),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if eventBus == nil {
|
||||
close(hm.done)
|
||||
return hm
|
||||
}
|
||||
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
return hm
|
||||
}
|
||||
|
||||
func (hm *HookManager) Close() {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hm.closeOnce.Do(func() {
|
||||
if hm.eventBus != nil {
|
||||
hm.eventBus.Unsubscribe(hm.sub.ID)
|
||||
}
|
||||
<-hm.done
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) Mount(reg HookRegistration) error {
|
||||
if hm == nil {
|
||||
return fmt.Errorf("hook manager is nil")
|
||||
}
|
||||
if reg.Name == "" {
|
||||
return fmt.Errorf("hook name is required")
|
||||
}
|
||||
if reg.Hook == nil {
|
||||
return fmt.Errorf("hook %q is nil", reg.Name)
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
hm.hooks[reg.Name] = reg
|
||||
hm.rebuildOrdered()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) Unmount(name string) {
|
||||
if hm == nil || name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
delete(hm.hooks, name)
|
||||
hm.rebuildOrdered()
|
||||
}
|
||||
|
||||
func (hm *HookManager) dispatchEvents() {
|
||||
defer close(hm.done)
|
||||
|
||||
for evt := range hm.sub.C {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
observer, ok := reg.Hook.(EventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hm.runObserver(reg.Name, observer, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
||||
if hm == nil || req == nil {
|
||||
return req, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := req.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
|
||||
if hm == nil || resp == nil {
|
||||
return resp, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := resp.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision) {
|
||||
if hm == nil || call == nil {
|
||||
return call, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := call.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision) {
|
||||
if hm == nil || result == nil {
|
||||
return result, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := result.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
|
||||
if hm == nil || req == nil {
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
approver, ok := reg.Hook.(ToolApprover)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
|
||||
if !ok {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name),
|
||||
}
|
||||
}
|
||||
if !decision.Approved {
|
||||
return decision
|
||||
}
|
||||
}
|
||||
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
func (hm *HookManager) rebuildOrdered() {
|
||||
hm.ordered = hm.ordered[:0]
|
||||
for _, reg := range hm.hooks {
|
||||
hm.ordered = append(hm.ordered, reg)
|
||||
}
|
||||
sort.SliceStable(hm.ordered, func(i, j int) bool {
|
||||
if hm.ordered[i].Priority == hm.ordered[j].Priority {
|
||||
return hm.ordered[i].Name < hm.ordered[j].Name
|
||||
}
|
||||
return hm.ordered[i].Priority < hm.ordered[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) snapshotHooks() []HookRegistration {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
snapshot := make([]HookRegistration, len(hm.ordered))
|
||||
copy(snapshot, hm.ordered)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- observer.OnEvent(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Event observer failed", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Event observer timed out", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_llm",
|
||||
func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeLLM(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_llm",
|
||||
func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterLLM(ctx, resp)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_tool",
|
||||
func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeTool(ctx, call)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
resultView *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_tool",
|
||||
func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterTool(ctx, resultView)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callApproveTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
approver ToolApprover,
|
||||
req *ToolApprovalRequest,
|
||||
) (ApprovalDecision, bool) {
|
||||
return runApprovalHook(
|
||||
parent,
|
||||
hm.approvalTimeout,
|
||||
name,
|
||||
"approve_tool",
|
||||
func(ctx context.Context) (ApprovalDecision, error) {
|
||||
return approver.ApproveTool(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func runInterceptorHook[T any](
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (T, HookDecision, error),
|
||||
) (T, HookDecision, bool) {
|
||||
var zero T
|
||||
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
value T
|
||||
decision HookDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
value, decision, err := fn(ctx)
|
||||
done <- result{value: value, decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
return res.value, res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func runApprovalHook(
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (ApprovalDecision, error),
|
||||
) (ApprovalDecision, bool) {
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
decision ApprovalDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
decision, err := fn(ctx)
|
||||
done <- result{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Approval hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return ApprovalDecision{}, false
|
||||
}
|
||||
return res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q timed out", name),
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
|
||||
logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"action": action,
|
||||
})
|
||||
}
|
||||
|
||||
func cloneProviderMessages(messages []providers.Message) []providers.Message {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.Message, len(messages))
|
||||
for i, msg := range messages {
|
||||
cloned[i] = msg
|
||||
if len(msg.Media) > 0 {
|
||||
cloned[i].Media = append([]string(nil), msg.Media...)
|
||||
}
|
||||
if len(msg.SystemParts) > 0 {
|
||||
cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
cloned[i] = call
|
||||
if call.Function != nil {
|
||||
fn := *call.Function
|
||||
cloned[i].Function = &fn
|
||||
}
|
||||
if call.Arguments != nil {
|
||||
cloned[i].Arguments = cloneStringAnyMap(call.Arguments)
|
||||
}
|
||||
if call.ExtraContent != nil {
|
||||
extra := *call.ExtraContent
|
||||
if call.ExtraContent.Google != nil {
|
||||
google := *call.ExtraContent.Google
|
||||
extra.Google = &google
|
||||
}
|
||||
cloned[i].ExtraContent = &extra
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolDefinition, len(defs))
|
||||
for i, def := range defs {
|
||||
cloned[i] = def
|
||||
cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *resp
|
||||
cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
|
||||
if len(resp.ReasoningDetails) > 0 {
|
||||
cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
usage := *resp.Usage
|
||||
cloned.Usage = &usage
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneStringAnyMap(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string]any, len(src))
|
||||
for k, v := range src {
|
||||
cloned[k] = v
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *result
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
@@ -0,0 +1,312 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func newHookTestLoop(
|
||||
t *testing.T,
|
||||
provider providers.LLMProvider,
|
||||
) (*AgentLoop, *AgentInstance, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
return al, agent, func() {
|
||||
al.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
type llmHookTestProvider struct {
|
||||
mu sync.Mutex
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.lastModel = model
|
||||
p.mu.Unlock()
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "provider content",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
return "llm-hook-provider"
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if evt.Kind == EventKindTurnEnd {
|
||||
select {
|
||||
case h.eventCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
next.Response.Content = "hooked content"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "hooked content" {
|
||||
t.Fatalf("expected hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
}
|
||||
|
||||
type toolHookProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "echo_text",
|
||||
Arguments: map[string]any{"text": "original"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
last := messages[len(messages)-1]
|
||||
return &providers.LLMResponse{
|
||||
Content: last.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) GetDefaultModel() string {
|
||||
return "tool-hook-provider"
|
||||
}
|
||||
|
||||
type echoTextTool struct{}
|
||||
|
||||
func (t *echoTextTool) Name() string {
|
||||
return "echo_text"
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Description() string {
|
||||
return "echo a text argument"
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
return tools.SilentResult(text)
|
||||
}
|
||||
|
||||
type toolRewriteHook struct{}
|
||||
|
||||
func (h *toolRewriteHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
next := call.Clone()
|
||||
next.Arguments["text"] = "modified"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *toolRewriteHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
next := result.Clone()
|
||||
next.Result.ForLLM = "after:" + next.Result.ForLLM
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "after:modified" {
|
||||
t.Fatalf("expected rewritten tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type denyApprovalHook struct{}
|
||||
|
||||
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
expected := "Tool execution denied by approval hook: blocked"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
+262
-47
@@ -40,6 +40,7 @@ type AgentLoop struct {
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
eventBus *EventBus
|
||||
hooks *HookManager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
@@ -108,17 +109,19 @@ func NewAgentLoop(
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
eventBus := NewEventBus()
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
eventBus: NewEventBus(),
|
||||
eventBus: eventBus,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
|
||||
return al
|
||||
}
|
||||
@@ -460,11 +463,30 @@ func (al *AgentLoop) Close() {
|
||||
}
|
||||
|
||||
al.GetRegistry().Close()
|
||||
if al.hooks != nil {
|
||||
al.hooks.Close()
|
||||
}
|
||||
if al.eventBus != nil {
|
||||
al.eventBus.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// MountHook registers an in-process hook on the agent loop.
|
||||
func (al *AgentLoop) MountHook(reg HookRegistration) error {
|
||||
if al == nil || al.hooks == nil {
|
||||
return fmt.Errorf("hook manager is not initialized")
|
||||
}
|
||||
return al.hooks.Mount(reg)
|
||||
}
|
||||
|
||||
// UnmountHook removes a previously registered in-process hook.
|
||||
func (al *AgentLoop) UnmountHook(name string) {
|
||||
if al == nil || al.hooks == nil {
|
||||
return
|
||||
}
|
||||
al.hooks.Unmount(name)
|
||||
}
|
||||
|
||||
// SubscribeEvents registers a subscriber for agent-loop events.
|
||||
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
|
||||
if al == nil || al.eventBus == nil {
|
||||
@@ -544,6 +566,31 @@ func cloneEventArguments(args map[string]any) map[string]any {
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error {
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
|
||||
err := fmt.Errorf("hook aborted turn during %s: %s", stage, reason)
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("hooks", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "hook." + stage,
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func hookDeniedToolContent(prefix, reason string) string {
|
||||
if reason == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + ": " + reason
|
||||
}
|
||||
|
||||
func (al *AgentLoop) logEvent(evt Event) {
|
||||
fields := map[string]any{
|
||||
"event_kind": evt.Kind.String(),
|
||||
@@ -1418,36 +1465,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
ts.markGracefulTerminalUsed()
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: activeModel,
|
||||
MessagesCount: len(callMessages),
|
||||
ToolsCount: len(providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"messages_count": len(callMessages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(callMessages),
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
llmOpts := map[string]any{
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
@@ -1462,6 +1479,66 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
}
|
||||
}
|
||||
|
||||
llmModel := activeModel
|
||||
if al.hooks != nil {
|
||||
llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
Model: llmModel,
|
||||
Messages: callMessages,
|
||||
Tools: providerToolDefs,
|
||||
Options: llmOpts,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
GracefulTerminal: gracefulTerminal,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmReq != nil {
|
||||
llmModel = llmReq.Model
|
||||
callMessages = llmReq.Messages
|
||||
providerToolDefs = llmReq.Tools
|
||||
llmOpts = llmReq.Options
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "before_llm", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: llmModel,
|
||||
MessagesCount: len(callMessages),
|
||||
ToolsCount: len(providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": llmModel,
|
||||
"messages_count": len(callMessages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(callMessages),
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
|
||||
providerCtx, providerCancel := context.WithCancel(turnCtx)
|
||||
ts.setProviderCancel(providerCancel)
|
||||
@@ -1494,7 +1571,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts)
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
}
|
||||
|
||||
var response *providers.LLMResponse
|
||||
@@ -1626,12 +1703,35 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"model": llmModel,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
Model: llmModel,
|
||||
Response: response,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmResp != nil && llmResp.Response != nil {
|
||||
response = llmResp.Response
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "after_llm", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
response.Reasoning,
|
||||
@@ -1728,25 +1828,106 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
toolName := tc.Name
|
||||
toolArgs := cloneStringAnyMap(tc.Arguments)
|
||||
|
||||
if al.hooks != nil {
|
||||
toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.before"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolReq != nil {
|
||||
toolName = toolReq.Tool
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionDenyTool:
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "before_tool", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.approve"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
if !approval.Approved {
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": tc.Name,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: tc.Name,
|
||||
Arguments: cloneEventArguments(tc.Arguments),
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
toolCall := tc
|
||||
toolCallID := tc.ID
|
||||
toolIteration := iteration
|
||||
asyncToolName := toolName
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -1768,7 +1949,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": toolCall.Name,
|
||||
"tool": asyncToolName,
|
||||
"content_len": len(content),
|
||||
"channel": ts.channel,
|
||||
})
|
||||
@@ -1776,7 +1957,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
EventKindFollowUpQueued,
|
||||
ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
|
||||
FollowUpQueuedPayload{
|
||||
SourceTool: toolCall.Name,
|
||||
SourceTool: asyncToolName,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
ContentLen: len(content),
|
||||
@@ -1787,7 +1968,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", toolCall.Name),
|
||||
SenderID: fmt.Sprintf("async:%s", asyncToolName),
|
||||
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
|
||||
Content: content,
|
||||
})
|
||||
@@ -1796,8 +1977,8 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
toolStart := time.Now()
|
||||
toolResult := ts.agent.Tools.ExecuteWithContext(
|
||||
turnCtx,
|
||||
toolCall.Name,
|
||||
toolCall.Arguments,
|
||||
toolName,
|
||||
toolArgs,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
asyncCallback,
|
||||
@@ -1809,6 +1990,40 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.after"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Result: toolResult,
|
||||
Duration: toolDuration,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolResp != nil {
|
||||
if toolResp.Tool != "" {
|
||||
toolName = toolResp.Tool
|
||||
}
|
||||
if toolResp.Result != nil {
|
||||
toolResult = toolResp.Result
|
||||
}
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "after_tool", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
@@ -1817,7 +2032,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": toolCall.Name,
|
||||
"tool": toolName,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
@@ -1850,13 +2065,13 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: toolCall.ID,
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolCall.Name,
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(toolResult.ForUser),
|
||||
|
||||
Reference in New Issue
Block a user