mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
b954e6b8dc
Drop the old agent EventBus, SubscribeEvents/EventDrops public surface, legacy hook observer dispatch, and hook.event process notification path. Agent observations now flow through pkg/events runtime events. Validation: go test ./pkg/agent; make lint
931 lines
23 KiB
Go
931 lines
23 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"sort"
|
|
"sync"
|
|
"time"
|
|
|
|
runtimeevents "github.com/sipeed/picoclaw/pkg/events"
|
|
"github.com/sipeed/picoclaw/pkg/logger"
|
|
"github.com/sipeed/picoclaw/pkg/providers"
|
|
"github.com/sipeed/picoclaw/pkg/tools"
|
|
)
|
|
|
|
const (
|
|
defaultHookObserverTimeout = 500 * time.Millisecond
|
|
defaultHookInterceptorTimeout = 5 * time.Second
|
|
defaultHookApprovalTimeout = 60 * time.Second
|
|
hookObserverBufferSize = 64
|
|
)
|
|
|
|
type HookAction string
|
|
|
|
const (
|
|
HookActionContinue HookAction = "continue"
|
|
HookActionModify HookAction = "modify"
|
|
HookActionRespond HookAction = "respond" // Return result directly, skip tool execution. SECURITY: This bypasses ApproveTool checks, allowing hooks to return results for any tool (including sensitive ones like bash) without approval. Use with caution.
|
|
HookActionDenyTool HookAction = "deny_tool"
|
|
HookActionAbortTurn HookAction = "abort_turn"
|
|
HookActionHardAbort HookAction = "hard_abort"
|
|
)
|
|
|
|
type HookDecision struct {
|
|
Action HookAction `json:"action"`
|
|
Reason string `json:"reason,omitempty"`
|
|
}
|
|
|
|
func (d HookDecision) normalizedAction() HookAction {
|
|
if d.Action == "" {
|
|
return HookActionContinue
|
|
}
|
|
return d.Action
|
|
}
|
|
|
|
type ApprovalDecision struct {
|
|
Approved bool `json:"approved"`
|
|
Reason string `json:"reason,omitempty"`
|
|
}
|
|
|
|
type HookSource uint8
|
|
|
|
const (
|
|
HookSourceInProcess HookSource = iota
|
|
HookSourceProcess
|
|
)
|
|
|
|
type HookRegistration struct {
|
|
Name string
|
|
Priority int
|
|
Source HookSource
|
|
Hook any
|
|
}
|
|
|
|
func NamedHook(name string, hook any) HookRegistration {
|
|
return HookRegistration{
|
|
Name: name,
|
|
Source: HookSourceInProcess,
|
|
Hook: hook,
|
|
}
|
|
}
|
|
|
|
type RuntimeEventObserver interface {
|
|
OnRuntimeEvent(ctx context.Context, evt runtimeevents.Event) error
|
|
}
|
|
|
|
type LLMInterceptor interface {
|
|
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
|
|
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
|
|
}
|
|
|
|
type ToolInterceptor interface {
|
|
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
|
|
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
|
|
}
|
|
|
|
type ToolApprover interface {
|
|
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
|
|
}
|
|
|
|
type LLMHookRequest struct {
|
|
Meta EventMeta `json:"meta"`
|
|
Context *TurnContext `json:"context,omitempty"`
|
|
Model string `json:"model"`
|
|
Messages []providers.Message `json:"messages,omitempty"`
|
|
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
|
Options map[string]any `json:"options,omitempty"`
|
|
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
|
}
|
|
|
|
func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
cloned := *r
|
|
cloned.Meta = cloneEventMeta(r.Meta)
|
|
cloned.Context = cloneTurnContext(r.Context)
|
|
cloned.Messages = cloneProviderMessages(r.Messages)
|
|
cloned.Tools = cloneToolDefinitions(r.Tools)
|
|
cloned.Options = cloneStringAnyMap(r.Options)
|
|
return &cloned
|
|
}
|
|
|
|
type LLMHookResponse struct {
|
|
Meta EventMeta `json:"meta"`
|
|
Context *TurnContext `json:"context,omitempty"`
|
|
Model string `json:"model"`
|
|
Response *providers.LLMResponse `json:"response,omitempty"`
|
|
}
|
|
|
|
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
cloned := *r
|
|
cloned.Meta = cloneEventMeta(r.Meta)
|
|
cloned.Context = cloneTurnContext(r.Context)
|
|
cloned.Response = cloneLLMResponse(r.Response)
|
|
return &cloned
|
|
}
|
|
|
|
type ToolCallHookRequest struct {
|
|
Meta EventMeta `json:"meta"`
|
|
Context *TurnContext `json:"context,omitempty"`
|
|
Tool string `json:"tool"`
|
|
Arguments map[string]any `json:"arguments,omitempty"`
|
|
Channel string `json:"channel,omitempty"`
|
|
ChatID string `json:"chat_id,omitempty"`
|
|
HookResult *tools.ToolResult `json:"hook_result,omitempty"` // Result returned directly by hook (for respond action). Media is supported - see Media handling section in docs.
|
|
}
|
|
|
|
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
cloned := *r
|
|
cloned.Meta = cloneEventMeta(r.Meta)
|
|
cloned.Context = cloneTurnContext(r.Context)
|
|
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
|
cloned.HookResult = cloneToolResult(r.HookResult)
|
|
return &cloned
|
|
}
|
|
|
|
type ToolApprovalRequest struct {
|
|
Meta EventMeta `json:"meta"`
|
|
Context *TurnContext `json:"context,omitempty"`
|
|
Tool string `json:"tool"`
|
|
Arguments map[string]any `json:"arguments,omitempty"`
|
|
}
|
|
|
|
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
cloned := *r
|
|
cloned.Meta = cloneEventMeta(r.Meta)
|
|
cloned.Context = cloneTurnContext(r.Context)
|
|
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
|
return &cloned
|
|
}
|
|
|
|
type ToolResultHookResponse struct {
|
|
Meta EventMeta `json:"meta"`
|
|
Context *TurnContext `json:"context,omitempty"`
|
|
Tool string `json:"tool"`
|
|
Arguments map[string]any `json:"arguments,omitempty"`
|
|
Result *tools.ToolResult `json:"result,omitempty"`
|
|
Duration time.Duration `json:"duration"`
|
|
}
|
|
|
|
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
cloned := *r
|
|
cloned.Meta = cloneEventMeta(r.Meta)
|
|
cloned.Context = cloneTurnContext(r.Context)
|
|
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
|
cloned.Result = cloneToolResult(r.Result)
|
|
return &cloned
|
|
}
|
|
|
|
type HookManager struct {
|
|
runtimeEvents runtimeevents.EventChannel
|
|
observerTimeout time.Duration
|
|
interceptorTimeout time.Duration
|
|
approvalTimeout time.Duration
|
|
|
|
mu sync.RWMutex
|
|
hooks map[string]HookRegistration
|
|
ordered []HookRegistration
|
|
|
|
runtimeSub runtimeevents.Subscription
|
|
runtimeDone chan struct{}
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
func NewHookManager(runtimeEvents runtimeevents.EventChannel) *HookManager {
|
|
hm := &HookManager{
|
|
runtimeEvents: runtimeEvents,
|
|
observerTimeout: defaultHookObserverTimeout,
|
|
interceptorTimeout: defaultHookInterceptorTimeout,
|
|
approvalTimeout: defaultHookApprovalTimeout,
|
|
hooks: make(map[string]HookRegistration),
|
|
runtimeDone: make(chan struct{}),
|
|
}
|
|
|
|
if runtimeEvents != nil {
|
|
sub, ch, err := runtimeEvents.SubscribeChan(context.Background(), runtimeevents.SubscribeOptions{
|
|
Name: "hook-manager-observer",
|
|
Buffer: hookObserverBufferSize,
|
|
})
|
|
if err != nil {
|
|
logger.WarnCF("hooks", "Failed to subscribe runtime events for hooks", map[string]any{
|
|
"error": err.Error(),
|
|
})
|
|
close(hm.runtimeDone)
|
|
} else {
|
|
hm.runtimeSub = sub
|
|
go hm.dispatchRuntimeEvents(ch)
|
|
}
|
|
} else {
|
|
close(hm.runtimeDone)
|
|
}
|
|
|
|
return hm
|
|
}
|
|
|
|
func (hm *HookManager) Close() {
|
|
if hm == nil {
|
|
return
|
|
}
|
|
|
|
hm.closeOnce.Do(func() {
|
|
if hm.runtimeSub != nil {
|
|
if err := hm.runtimeSub.Close(); err != nil {
|
|
logger.WarnCF("hooks", "Failed to close runtime event hook subscription", map[string]any{
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
}
|
|
<-hm.runtimeDone
|
|
hm.closeAllHooks()
|
|
})
|
|
}
|
|
|
|
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
|
|
if hm == nil {
|
|
return
|
|
}
|
|
if observer > 0 {
|
|
hm.observerTimeout = observer
|
|
}
|
|
if interceptor > 0 {
|
|
hm.interceptorTimeout = interceptor
|
|
}
|
|
if approval > 0 {
|
|
hm.approvalTimeout = approval
|
|
}
|
|
}
|
|
|
|
func (hm *HookManager) Mount(reg HookRegistration) error {
|
|
if hm == nil {
|
|
return fmt.Errorf("hook manager is nil")
|
|
}
|
|
if reg.Name == "" {
|
|
return fmt.Errorf("hook name is required")
|
|
}
|
|
if reg.Hook == nil {
|
|
return fmt.Errorf("hook %q is nil", reg.Name)
|
|
}
|
|
|
|
hm.mu.Lock()
|
|
defer hm.mu.Unlock()
|
|
|
|
if existing, ok := hm.hooks[reg.Name]; ok {
|
|
closeHookIfPossible(existing.Hook)
|
|
}
|
|
hm.hooks[reg.Name] = reg
|
|
hm.rebuildOrdered()
|
|
return nil
|
|
}
|
|
|
|
func (hm *HookManager) Unmount(name string) {
|
|
if hm == nil || name == "" {
|
|
return
|
|
}
|
|
|
|
hm.mu.Lock()
|
|
defer hm.mu.Unlock()
|
|
|
|
if existing, ok := hm.hooks[name]; ok {
|
|
closeHookIfPossible(existing.Hook)
|
|
}
|
|
delete(hm.hooks, name)
|
|
hm.rebuildOrdered()
|
|
}
|
|
|
|
func (hm *HookManager) dispatchRuntimeEvents(ch <-chan runtimeevents.Event) {
|
|
defer close(hm.runtimeDone)
|
|
|
|
for evt := range ch {
|
|
for _, reg := range hm.snapshotHooks() {
|
|
observer, ok := reg.Hook.(RuntimeEventObserver)
|
|
if !ok {
|
|
continue
|
|
}
|
|
hm.runRuntimeObserver(reg.Name, observer, evt)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
|
if hm == nil || req == nil {
|
|
return req, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
current := req.Clone()
|
|
for _, reg := range hm.snapshotHooks() {
|
|
interceptor, ok := reg.Hook.(LLMInterceptor)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
switch decision.normalizedAction() {
|
|
case HookActionContinue, HookActionModify:
|
|
if next != nil {
|
|
next = hm.applyBeforeLLMControls(reg.Name, current, next)
|
|
current = next
|
|
}
|
|
case HookActionAbortTurn, HookActionHardAbort:
|
|
return current, decision
|
|
default:
|
|
hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
|
|
}
|
|
}
|
|
return current, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
|
|
if hm == nil || resp == nil {
|
|
return resp, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
current := resp.Clone()
|
|
for _, reg := range hm.snapshotHooks() {
|
|
interceptor, ok := reg.Hook.(LLMInterceptor)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
switch decision.normalizedAction() {
|
|
case HookActionContinue, HookActionModify:
|
|
if next != nil {
|
|
current = next
|
|
}
|
|
case HookActionAbortTurn, HookActionHardAbort:
|
|
return current, decision
|
|
default:
|
|
hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
|
|
}
|
|
}
|
|
return current, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
func (hm *HookManager) applyBeforeLLMControls(
|
|
hookName string,
|
|
current *LLMHookRequest,
|
|
next *LLMHookRequest,
|
|
) *LLMHookRequest {
|
|
if next == nil || current == nil {
|
|
return next
|
|
}
|
|
if !llmHookSystemMessagesUnchanged(current.Messages, next.Messages) {
|
|
logger.WarnCF("hooks", "Hook attempted to modify system prompt; preserving original messages", map[string]any{
|
|
"hook": hookName,
|
|
})
|
|
next.Messages = cloneProviderMessages(current.Messages)
|
|
}
|
|
if !llmHookToolDefinitionsUnchanged(current.Tools, next.Tools) {
|
|
logger.WarnCF("hooks", "Hook attempted to modify tool definitions; preserving original tools", map[string]any{
|
|
"hook": hookName,
|
|
})
|
|
next.Tools = cloneToolDefinitions(current.Tools)
|
|
}
|
|
return next
|
|
}
|
|
|
|
func llmHookSystemMessagesUnchanged(before, after []providers.Message) bool {
|
|
beforeSystem := systemMessageFingerprints(before)
|
|
afterSystem := systemMessageFingerprints(after)
|
|
return reflect.DeepEqual(beforeSystem, afterSystem)
|
|
}
|
|
|
|
type systemMessageFingerprint struct {
|
|
Index int
|
|
Message providers.Message
|
|
}
|
|
|
|
func systemMessageFingerprints(messages []providers.Message) []systemMessageFingerprint {
|
|
var fingerprints []systemMessageFingerprint
|
|
for i, msg := range messages {
|
|
if msg.Role != "system" {
|
|
continue
|
|
}
|
|
msg = providerVisibleMessage(msg)
|
|
fingerprints = append(fingerprints, systemMessageFingerprint{
|
|
Index: i,
|
|
Message: cloneProviderMessages([]providers.Message{msg})[0],
|
|
})
|
|
}
|
|
return fingerprints
|
|
}
|
|
|
|
func llmHookToolDefinitionsUnchanged(before, after []providers.ToolDefinition) bool {
|
|
return reflect.DeepEqual(providerVisibleToolDefinitions(before), providerVisibleToolDefinitions(after))
|
|
}
|
|
|
|
func providerVisibleMessage(msg providers.Message) providers.Message {
|
|
msg.PromptLayer = ""
|
|
msg.PromptSlot = ""
|
|
msg.PromptSource = ""
|
|
if len(msg.SystemParts) > 0 {
|
|
msg.SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
|
for i := range msg.SystemParts {
|
|
msg.SystemParts[i].PromptLayer = ""
|
|
msg.SystemParts[i].PromptSlot = ""
|
|
msg.SystemParts[i].PromptSource = ""
|
|
}
|
|
}
|
|
return msg
|
|
}
|
|
|
|
func providerVisibleToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
|
cloned := cloneToolDefinitions(defs)
|
|
for i := range cloned {
|
|
cloned[i].PromptLayer = ""
|
|
cloned[i].PromptSlot = ""
|
|
cloned[i].PromptSource = ""
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func (hm *HookManager) BeforeTool(
|
|
ctx context.Context,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision) {
|
|
if hm == nil || call == nil {
|
|
return call, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
current := call.Clone()
|
|
for _, reg := range hm.snapshotHooks() {
|
|
interceptor, ok := reg.Hook.(ToolInterceptor)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
switch decision.normalizedAction() {
|
|
case HookActionContinue, HookActionModify:
|
|
if next != nil {
|
|
current = next
|
|
}
|
|
case HookActionRespond:
|
|
// Hook returns result directly, skip tool execution
|
|
// Carry HookResult in ToolCallHookRequest and return
|
|
return next, decision
|
|
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
|
return current, decision
|
|
default:
|
|
hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
|
|
}
|
|
}
|
|
return current, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
func (hm *HookManager) AfterTool(
|
|
ctx context.Context,
|
|
result *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision) {
|
|
if hm == nil || result == nil {
|
|
return result, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
current := result.Clone()
|
|
for _, reg := range hm.snapshotHooks() {
|
|
interceptor, ok := reg.Hook.(ToolInterceptor)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
switch decision.normalizedAction() {
|
|
case HookActionContinue, HookActionModify:
|
|
if next != nil {
|
|
current = next
|
|
}
|
|
case HookActionAbortTurn, HookActionHardAbort:
|
|
return current, decision
|
|
default:
|
|
hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
|
|
}
|
|
}
|
|
return current, HookDecision{Action: HookActionContinue}
|
|
}
|
|
|
|
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
|
|
if hm == nil || req == nil {
|
|
return ApprovalDecision{Approved: true}
|
|
}
|
|
|
|
for _, reg := range hm.snapshotHooks() {
|
|
approver, ok := reg.Hook.(ToolApprover)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
|
|
if !ok {
|
|
return ApprovalDecision{
|
|
Approved: false,
|
|
Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name),
|
|
}
|
|
}
|
|
if !decision.Approved {
|
|
return decision
|
|
}
|
|
}
|
|
|
|
return ApprovalDecision{Approved: true}
|
|
}
|
|
|
|
func (hm *HookManager) rebuildOrdered() {
|
|
hm.ordered = hm.ordered[:0]
|
|
for _, reg := range hm.hooks {
|
|
hm.ordered = append(hm.ordered, reg)
|
|
}
|
|
sort.SliceStable(hm.ordered, func(i, j int) bool {
|
|
if hm.ordered[i].Source != hm.ordered[j].Source {
|
|
return hm.ordered[i].Source < hm.ordered[j].Source
|
|
}
|
|
if hm.ordered[i].Priority == hm.ordered[j].Priority {
|
|
return hm.ordered[i].Name < hm.ordered[j].Name
|
|
}
|
|
return hm.ordered[i].Priority < hm.ordered[j].Priority
|
|
})
|
|
}
|
|
|
|
func (hm *HookManager) snapshotHooks() []HookRegistration {
|
|
hm.mu.RLock()
|
|
defer hm.mu.RUnlock()
|
|
|
|
snapshot := make([]HookRegistration, len(hm.ordered))
|
|
copy(snapshot, hm.ordered)
|
|
return snapshot
|
|
}
|
|
|
|
func (hm *HookManager) closeAllHooks() {
|
|
hm.mu.Lock()
|
|
defer hm.mu.Unlock()
|
|
|
|
for name, reg := range hm.hooks {
|
|
closeHookIfPossible(reg.Hook)
|
|
delete(hm.hooks, name)
|
|
}
|
|
hm.ordered = nil
|
|
}
|
|
|
|
func (hm *HookManager) runRuntimeObserver(
|
|
name string,
|
|
observer RuntimeEventObserver,
|
|
evt runtimeevents.Event,
|
|
) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
|
defer cancel()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- observer.OnRuntimeEvent(ctx, evt)
|
|
}()
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
logger.WarnCF("hooks", "Runtime event observer failed", map[string]any{
|
|
"hook": name,
|
|
"event": evt.Kind.String(),
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
case <-ctx.Done():
|
|
logger.WarnCF("hooks", "Runtime event observer timed out", map[string]any{
|
|
"hook": name,
|
|
"event": evt.Kind.String(),
|
|
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
|
})
|
|
}
|
|
}
|
|
|
|
func (hm *HookManager) callBeforeLLM(
|
|
parent context.Context,
|
|
name string,
|
|
interceptor LLMInterceptor,
|
|
req *LLMHookRequest,
|
|
) (*LLMHookRequest, HookDecision, bool) {
|
|
return runInterceptorHook(
|
|
parent,
|
|
hm.interceptorTimeout,
|
|
name,
|
|
"before_llm",
|
|
func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
|
|
return interceptor.BeforeLLM(ctx, req)
|
|
},
|
|
)
|
|
}
|
|
|
|
func (hm *HookManager) callAfterLLM(
|
|
parent context.Context,
|
|
name string,
|
|
interceptor LLMInterceptor,
|
|
resp *LLMHookResponse,
|
|
) (*LLMHookResponse, HookDecision, bool) {
|
|
return runInterceptorHook(
|
|
parent,
|
|
hm.interceptorTimeout,
|
|
name,
|
|
"after_llm",
|
|
func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
|
|
return interceptor.AfterLLM(ctx, resp)
|
|
},
|
|
)
|
|
}
|
|
|
|
func (hm *HookManager) callBeforeTool(
|
|
parent context.Context,
|
|
name string,
|
|
interceptor ToolInterceptor,
|
|
call *ToolCallHookRequest,
|
|
) (*ToolCallHookRequest, HookDecision, bool) {
|
|
return runInterceptorHook(
|
|
parent,
|
|
hm.interceptorTimeout,
|
|
name,
|
|
"before_tool",
|
|
func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
|
|
return interceptor.BeforeTool(ctx, call)
|
|
},
|
|
)
|
|
}
|
|
|
|
func (hm *HookManager) callAfterTool(
|
|
parent context.Context,
|
|
name string,
|
|
interceptor ToolInterceptor,
|
|
resultView *ToolResultHookResponse,
|
|
) (*ToolResultHookResponse, HookDecision, bool) {
|
|
return runInterceptorHook(
|
|
parent,
|
|
hm.interceptorTimeout,
|
|
name,
|
|
"after_tool",
|
|
func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
|
|
return interceptor.AfterTool(ctx, resultView)
|
|
},
|
|
)
|
|
}
|
|
|
|
func (hm *HookManager) callApproveTool(
|
|
parent context.Context,
|
|
name string,
|
|
approver ToolApprover,
|
|
req *ToolApprovalRequest,
|
|
) (ApprovalDecision, bool) {
|
|
return runApprovalHook(
|
|
parent,
|
|
hm.approvalTimeout,
|
|
name,
|
|
"approve_tool",
|
|
func(ctx context.Context) (ApprovalDecision, error) {
|
|
return approver.ApproveTool(ctx, req)
|
|
},
|
|
)
|
|
}
|
|
|
|
func runInterceptorHook[T any](
|
|
parent context.Context,
|
|
timeout time.Duration,
|
|
name string,
|
|
stage string,
|
|
fn func(ctx context.Context) (T, HookDecision, error),
|
|
) (T, HookDecision, bool) {
|
|
var zero T
|
|
|
|
ctx, cancel := context.WithTimeout(parent, timeout)
|
|
defer cancel()
|
|
|
|
type result struct {
|
|
value T
|
|
decision HookDecision
|
|
err error
|
|
}
|
|
done := make(chan result, 1)
|
|
go func() {
|
|
value, decision, err := fn(ctx)
|
|
done <- result{value: value, decision: decision, err: err}
|
|
}()
|
|
|
|
select {
|
|
case res := <-done:
|
|
if res.err != nil {
|
|
logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
|
|
"hook": name,
|
|
"stage": stage,
|
|
"error": res.err.Error(),
|
|
})
|
|
return zero, HookDecision{}, false
|
|
}
|
|
return res.value, res.decision, true
|
|
case <-ctx.Done():
|
|
logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
|
|
"hook": name,
|
|
"stage": stage,
|
|
"timeout_ms": timeout.Milliseconds(),
|
|
})
|
|
return zero, HookDecision{}, false
|
|
}
|
|
}
|
|
|
|
func runApprovalHook(
|
|
parent context.Context,
|
|
timeout time.Duration,
|
|
name string,
|
|
stage string,
|
|
fn func(ctx context.Context) (ApprovalDecision, error),
|
|
) (ApprovalDecision, bool) {
|
|
ctx, cancel := context.WithTimeout(parent, timeout)
|
|
defer cancel()
|
|
|
|
type result struct {
|
|
decision ApprovalDecision
|
|
err error
|
|
}
|
|
done := make(chan result, 1)
|
|
go func() {
|
|
decision, err := fn(ctx)
|
|
done <- result{decision: decision, err: err}
|
|
}()
|
|
|
|
select {
|
|
case res := <-done:
|
|
if res.err != nil {
|
|
logger.WarnCF("hooks", "Approval hook failed", map[string]any{
|
|
"hook": name,
|
|
"stage": stage,
|
|
"error": res.err.Error(),
|
|
})
|
|
return ApprovalDecision{}, false
|
|
}
|
|
return res.decision, true
|
|
case <-ctx.Done():
|
|
logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
|
|
"hook": name,
|
|
"stage": stage,
|
|
"timeout_ms": timeout.Milliseconds(),
|
|
})
|
|
return ApprovalDecision{
|
|
Approved: false,
|
|
Reason: fmt.Sprintf("tool approval hook %q timed out", name),
|
|
}, true
|
|
}
|
|
}
|
|
|
|
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
|
|
logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{
|
|
"hook": name,
|
|
"stage": stage,
|
|
"action": action,
|
|
})
|
|
}
|
|
|
|
func cloneProviderMessages(messages []providers.Message) []providers.Message {
|
|
if len(messages) == 0 {
|
|
return nil
|
|
}
|
|
|
|
cloned := make([]providers.Message, len(messages))
|
|
for i, msg := range messages {
|
|
cloned[i] = msg
|
|
if len(msg.Media) > 0 {
|
|
cloned[i].Media = append([]string(nil), msg.Media...)
|
|
}
|
|
if len(msg.SystemParts) > 0 {
|
|
cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
|
}
|
|
if len(msg.ToolCalls) > 0 {
|
|
cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls)
|
|
}
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
|
|
if len(calls) == 0 {
|
|
return nil
|
|
}
|
|
|
|
cloned := make([]providers.ToolCall, len(calls))
|
|
for i, call := range calls {
|
|
cloned[i] = call
|
|
if call.Function != nil {
|
|
fn := *call.Function
|
|
cloned[i].Function = &fn
|
|
}
|
|
if call.Arguments != nil {
|
|
cloned[i].Arguments = cloneStringAnyMap(call.Arguments)
|
|
}
|
|
if call.ExtraContent != nil {
|
|
extra := *call.ExtraContent
|
|
if call.ExtraContent.Google != nil {
|
|
google := *call.ExtraContent.Google
|
|
extra.Google = &google
|
|
}
|
|
cloned[i].ExtraContent = &extra
|
|
}
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
|
if len(defs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
cloned := make([]providers.ToolDefinition, len(defs))
|
|
for i, def := range defs {
|
|
cloned[i] = def
|
|
cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters)
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
|
|
if resp == nil {
|
|
return nil
|
|
}
|
|
cloned := *resp
|
|
cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
|
|
if len(resp.ReasoningDetails) > 0 {
|
|
cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...)
|
|
}
|
|
if resp.Usage != nil {
|
|
usage := *resp.Usage
|
|
cloned.Usage = &usage
|
|
}
|
|
return &cloned
|
|
}
|
|
|
|
func cloneStringAnyMap(src map[string]any) map[string]any {
|
|
if len(src) == 0 {
|
|
return map[string]any{}
|
|
}
|
|
|
|
cloned := make(map[string]any, len(src))
|
|
for k, v := range src {
|
|
cloned[k] = v
|
|
}
|
|
return cloned
|
|
}
|
|
|
|
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
|
if result == nil {
|
|
return nil
|
|
}
|
|
|
|
cloned := *result
|
|
if len(result.Media) > 0 {
|
|
cloned.Media = append([]string(nil), result.Media...)
|
|
}
|
|
if len(result.ArtifactTags) > 0 {
|
|
cloned.ArtifactTags = append([]string(nil), result.ArtifactTags...)
|
|
}
|
|
if len(result.Messages) > 0 {
|
|
cloned.Messages = make([]providers.Message, len(result.Messages))
|
|
copy(cloned.Messages, result.Messages)
|
|
}
|
|
return &cloned
|
|
}
|
|
|
|
func closeHookIfPossible(hook any) {
|
|
closer, ok := hook.(io.Closer)
|
|
if !ok {
|
|
return
|
|
}
|
|
if err := closer.Close(); err != nil {
|
|
logger.WarnCF("hooks", "Failed to close hook", map[string]any{
|
|
"error": err.Error(),
|
|
})
|
|
}
|
|
}
|