mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1863 from alexhoshina/feat/hook-manager
Feat/hook manager
This commit is contained in:
@@ -0,0 +1,317 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type hookRuntime struct {
|
||||
initOnce sync.Once
|
||||
mu sync.Mutex
|
||||
initErr error
|
||||
mounted []string
|
||||
}
|
||||
|
||||
func (r *hookRuntime) setInitErr(err error) {
|
||||
r.mu.Lock()
|
||||
r.initErr = err
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *hookRuntime) getInitErr() error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
return r.initErr
|
||||
}
|
||||
|
||||
func (r *hookRuntime) setMounted(names []string) {
|
||||
r.mu.Lock()
|
||||
r.mounted = append([]string(nil), names...)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *hookRuntime) reset(al *AgentLoop) {
|
||||
r.mu.Lock()
|
||||
names := append([]string(nil), r.mounted...)
|
||||
r.mounted = nil
|
||||
r.initErr = nil
|
||||
r.initOnce = sync.Once{}
|
||||
r.mu.Unlock()
|
||||
|
||||
for _, name := range names {
|
||||
al.UnmountHook(name)
|
||||
}
|
||||
}
|
||||
|
||||
// BuiltinHookFactory constructs an in-process hook from config.
|
||||
type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
|
||||
|
||||
var (
|
||||
builtinHookRegistryMu sync.RWMutex
|
||||
builtinHookRegistry = map[string]BuiltinHookFactory{}
|
||||
)
|
||||
|
||||
// RegisterBuiltinHook registers a named in-process hook factory for config-driven mounting.
|
||||
func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("builtin hook name is required")
|
||||
}
|
||||
if factory == nil {
|
||||
return fmt.Errorf("builtin hook %q factory is nil", name)
|
||||
}
|
||||
|
||||
builtinHookRegistryMu.Lock()
|
||||
defer builtinHookRegistryMu.Unlock()
|
||||
|
||||
if _, exists := builtinHookRegistry[name]; exists {
|
||||
return fmt.Errorf("builtin hook %q is already registered", name)
|
||||
}
|
||||
builtinHookRegistry[name] = factory
|
||||
return nil
|
||||
}
|
||||
|
||||
func unregisterBuiltinHook(name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
builtinHookRegistryMu.Lock()
|
||||
delete(builtinHookRegistry, name)
|
||||
builtinHookRegistryMu.Unlock()
|
||||
}
|
||||
|
||||
func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
|
||||
builtinHookRegistryMu.RLock()
|
||||
defer builtinHookRegistryMu.RUnlock()
|
||||
|
||||
factory, ok := builtinHookRegistry[name]
|
||||
return factory, ok
|
||||
}
|
||||
|
||||
func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
|
||||
if hm == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
hm.ConfigureTimeouts(
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS),
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS),
|
||||
hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS),
|
||||
)
|
||||
}
|
||||
|
||||
func hookTimeoutFromMS(ms int) time.Duration {
|
||||
if ms <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(ms) * time.Millisecond
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
|
||||
if al == nil || al.cfg == nil || al.hooks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
al.hookRuntime.initOnce.Do(func() {
|
||||
al.hookRuntime.setInitErr(al.loadConfiguredHooks(ctx))
|
||||
})
|
||||
|
||||
return al.hookRuntime.getInitErr()
|
||||
}
|
||||
|
||||
func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
|
||||
if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
mounted := make([]string, 0)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
for _, name := range mounted {
|
||||
al.UnmountHook(name)
|
||||
}
|
||||
return
|
||||
}
|
||||
al.hookRuntime.setMounted(mounted)
|
||||
}()
|
||||
|
||||
builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
|
||||
for _, name := range builtinNames {
|
||||
spec := al.cfg.Hooks.Builtins[name]
|
||||
factory, ok := lookupBuiltinHook(name)
|
||||
if !ok {
|
||||
return fmt.Errorf("builtin hook %q is not registered", name)
|
||||
}
|
||||
|
||||
hook, factoryErr := factory(ctx, spec)
|
||||
if factoryErr != nil {
|
||||
return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Priority: spec.Priority,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: hook,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("mount builtin hook %q: %w", name, err)
|
||||
}
|
||||
mounted = append(mounted, name)
|
||||
}
|
||||
|
||||
processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
|
||||
for _, name := range processNames {
|
||||
spec := al.cfg.Hooks.Processes[name]
|
||||
opts, buildErr := processHookOptionsFromConfig(spec)
|
||||
if buildErr != nil {
|
||||
return fmt.Errorf("configure process hook %q: %w", name, buildErr)
|
||||
}
|
||||
|
||||
processHook, buildErr := NewProcessHook(ctx, name, opts)
|
||||
if buildErr != nil {
|
||||
return fmt.Errorf("start process hook %q: %w", name, buildErr)
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Priority: spec.Priority,
|
||||
Source: HookSourceProcess,
|
||||
Hook: processHook,
|
||||
}); err != nil {
|
||||
_ = processHook.Close()
|
||||
return fmt.Errorf("mount process hook %q: %w", name, err)
|
||||
}
|
||||
mounted = append(mounted, name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(specs))
|
||||
for name, spec := range specs {
|
||||
if spec.Enabled {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
|
||||
if len(specs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(specs))
|
||||
for name, spec := range specs {
|
||||
if spec.Enabled {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
|
||||
transport := spec.Transport
|
||||
if transport == "" {
|
||||
transport = "stdio"
|
||||
}
|
||||
if transport != "stdio" {
|
||||
return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", transport)
|
||||
}
|
||||
if len(spec.Command) == 0 {
|
||||
return ProcessHookOptions{}, fmt.Errorf("command is required")
|
||||
}
|
||||
|
||||
opts := ProcessHookOptions{
|
||||
Command: append([]string(nil), spec.Command...),
|
||||
Dir: spec.Dir,
|
||||
Env: processHookEnvFromMap(spec.Env),
|
||||
}
|
||||
|
||||
observeKinds, observeEnabled, err := processHookObserveKindsFromConfig(spec.Observe)
|
||||
if err != nil {
|
||||
return ProcessHookOptions{}, err
|
||||
}
|
||||
opts.Observe = observeEnabled
|
||||
opts.ObserveKinds = observeKinds
|
||||
|
||||
for _, intercept := range spec.Intercept {
|
||||
switch intercept {
|
||||
case "before_llm", "after_llm":
|
||||
opts.InterceptLLM = true
|
||||
case "before_tool", "after_tool":
|
||||
opts.InterceptTool = true
|
||||
case "approve_tool":
|
||||
opts.ApproveTool = true
|
||||
case "":
|
||||
continue
|
||||
default:
|
||||
return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", intercept)
|
||||
}
|
||||
}
|
||||
|
||||
if !opts.Observe && !opts.InterceptLLM && !opts.InterceptTool && !opts.ApproveTool {
|
||||
return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled")
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func processHookEnvFromMap(envMap map[string]string) []string {
|
||||
if len(envMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(envMap))
|
||||
for key := range envMap {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
env := make([]string, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
env = append(env, key+"="+envMap[key])
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
|
||||
if len(observe) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
validKinds := validHookEventKinds()
|
||||
normalized := make([]string, 0, len(observe))
|
||||
for _, kind := range observe {
|
||||
switch kind {
|
||||
case "", "*", "all":
|
||||
return nil, true, nil
|
||||
default:
|
||||
if _, ok := validKinds[kind]; !ok {
|
||||
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
|
||||
}
|
||||
normalized = append(normalized, kind)
|
||||
}
|
||||
}
|
||||
|
||||
if len(normalized) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return normalized, true, nil
|
||||
}
|
||||
|
||||
func validHookEventKinds() map[string]struct{} {
|
||||
kinds := make(map[string]struct{}, int(eventKindCount))
|
||||
for kind := EventKind(0); kind < eventKindCount; kind++ {
|
||||
kinds[kind.String()] = struct{}{}
|
||||
}
|
||||
return kinds
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type builtinAutoHookConfig struct {
|
||||
Model string `json:"model"`
|
||||
Suffix string `json:"suffix"`
|
||||
}
|
||||
|
||||
type builtinAutoHook struct {
|
||||
model string
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (h *builtinAutoHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = h.model
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *builtinAutoHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
if next.Response != nil {
|
||||
next.Response.Content += h.suffix
|
||||
}
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: t.TempDir(),
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
Hooks: hooks,
|
||||
}
|
||||
|
||||
return NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
|
||||
const hookName = "test-auto-builtin-hook"
|
||||
|
||||
if err := RegisterBuiltinHook(hookName, func(
|
||||
ctx context.Context,
|
||||
spec config.BuiltinHookConfig,
|
||||
) (any, error) {
|
||||
var hookCfg builtinAutoHookConfig
|
||||
if len(spec.Config) > 0 {
|
||||
if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &builtinAutoHook{
|
||||
model: hookCfg.Model,
|
||||
suffix: hookCfg.Suffix,
|
||||
}, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("RegisterBuiltinHook failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
unregisterBuiltinHook(hookName)
|
||||
})
|
||||
|
||||
rawCfg, err := json.Marshal(builtinAutoHookConfig{
|
||||
Model: "builtin-model",
|
||||
Suffix: "|builtin",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
provider := &llmHookTestProvider{}
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Builtins: map[string]config.BuiltinHookConfig{
|
||||
hookName: {
|
||||
Enabled: true,
|
||||
Config: rawCfg,
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|builtin" {
|
||||
t.Fatalf("expected builtin-hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "builtin-model" {
|
||||
t.Fatalf("expected builtin model, got %q", lastModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
eventLog := filepath.Join(t.TempDir(), "events.log")
|
||||
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Processes: map[string]config.ProcessHookConfig{
|
||||
"ipc-auto": {
|
||||
Enabled: true,
|
||||
Command: processHookHelperCommand(),
|
||||
Env: map[string]string{
|
||||
"PICOCLAW_HOOK_HELPER": "1",
|
||||
"PICOCLAW_HOOK_MODE": "rewrite",
|
||||
"PICOCLAW_HOOK_EVENT_LOG": eventLog,
|
||||
},
|
||||
Observe: []string{"turn_end"},
|
||||
Intercept: []string{"before_llm", "after_llm"},
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
|
||||
Enabled: true,
|
||||
Processes: map[string]config.ProcessHookConfig{
|
||||
"bad-hook": {
|
||||
Enabled: true,
|
||||
Command: processHookHelperCommand(),
|
||||
Intercept: []string{"not_supported"},
|
||||
},
|
||||
},
|
||||
})
|
||||
defer al.Close()
|
||||
|
||||
_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
|
||||
if err == nil {
|
||||
t.Fatal("expected invalid configured hook error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
processHookJSONRPCVersion = "2.0"
|
||||
processHookReadBufferSize = 1024 * 1024
|
||||
processHookCloseTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
type ProcessHookOptions struct {
|
||||
Command []string
|
||||
Dir string
|
||||
Env []string
|
||||
Observe bool
|
||||
ObserveKinds []string
|
||||
InterceptLLM bool
|
||||
InterceptTool bool
|
||||
ApproveTool bool
|
||||
}
|
||||
|
||||
type ProcessHook struct {
|
||||
name string
|
||||
opts ProcessHookOptions
|
||||
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
observeKinds map[string]struct{}
|
||||
|
||||
writeMu sync.Mutex
|
||||
|
||||
pendingMu sync.Mutex
|
||||
pending map[uint64]chan processHookRPCMessage
|
||||
nextID atomic.Uint64
|
||||
|
||||
closed atomic.Bool
|
||||
done chan struct{}
|
||||
closeErr error
|
||||
closeMu sync.Mutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type processHookRPCMessage struct {
|
||||
JSONRPC string `json:"jsonrpc,omitempty"`
|
||||
ID uint64 `json:"id,omitempty"`
|
||||
Method string `json:"method,omitempty"`
|
||||
Params json.RawMessage `json:"params,omitempty"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *processHookRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type processHookRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type processHookHelloParams struct {
|
||||
Name string `json:"name"`
|
||||
Version int `json:"version"`
|
||||
Modes []string `json:"modes,omitempty"`
|
||||
}
|
||||
|
||||
type processHookDecisionResponse struct {
|
||||
Action HookAction `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
type processHookBeforeLLMResponse struct {
|
||||
processHookDecisionResponse
|
||||
Request *LLMHookRequest `json:"request,omitempty"`
|
||||
}
|
||||
|
||||
type processHookAfterLLMResponse struct {
|
||||
processHookDecisionResponse
|
||||
Response *LLMHookResponse `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
type processHookBeforeToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Call *ToolCallHookRequest `json:"call,omitempty"`
|
||||
}
|
||||
|
||||
type processHookAfterToolResponse struct {
|
||||
processHookDecisionResponse
|
||||
Result *ToolResultHookResponse `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
|
||||
if len(opts.Command) == 0 {
|
||||
return nil, fmt.Errorf("process hook command is required")
|
||||
}
|
||||
|
||||
cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
|
||||
cmd.Dir = opts.Dir
|
||||
if len(opts.Env) > 0 {
|
||||
cmd.Env = append(os.Environ(), opts.Env...)
|
||||
}
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stdin: %w", err)
|
||||
}
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stdout: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process hook stderr: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start process hook: %w", err)
|
||||
}
|
||||
|
||||
ph := &ProcessHook{
|
||||
name: name,
|
||||
opts: opts,
|
||||
cmd: cmd,
|
||||
stdin: stdin,
|
||||
observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
|
||||
pending: make(map[uint64]chan processHookRPCMessage),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go ph.readLoop(stdout)
|
||||
go ph.readStderr(stderr)
|
||||
go ph.waitLoop()
|
||||
|
||||
helloCtx := ctx
|
||||
if helloCtx == nil {
|
||||
var cancel context.CancelFunc
|
||||
helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
if err := ph.hello(helloCtx); err != nil {
|
||||
_ = ph.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ph, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) Close() error {
|
||||
if ph == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ph.closeOnce.Do(func() {
|
||||
ph.closed.Store(true)
|
||||
if ph.stdin != nil {
|
||||
_ = ph.stdin.Close()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ph.done:
|
||||
case <-time.After(processHookCloseTimeout):
|
||||
if ph.cmd != nil && ph.cmd.Process != nil {
|
||||
_ = ph.cmd.Process.Kill()
|
||||
}
|
||||
<-ph.done
|
||||
}
|
||||
})
|
||||
|
||||
ph.closeMu.Lock()
|
||||
defer ph.closeMu.Unlock()
|
||||
return ph.closeErr
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if ph == nil || !ph.opts.Observe {
|
||||
return nil
|
||||
}
|
||||
if len(ph.observeKinds) > 0 {
|
||||
if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ph.notify(ctx, "hook.event", evt)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptLLM {
|
||||
return req, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookBeforeLLMResponse
|
||||
if err := ph.call(ctx, "hook.before_llm", req, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Request == nil {
|
||||
resp.Request = req
|
||||
}
|
||||
return resp.Request, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptLLM {
|
||||
return resp, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var result processHookAfterLLMResponse
|
||||
if err := ph.call(ctx, "hook.after_llm", resp, &result); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if result.Response == nil {
|
||||
result.Response = resp
|
||||
}
|
||||
return result.Response, HookDecision{Action: result.Action, Reason: result.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptTool {
|
||||
return call, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookBeforeToolResponse
|
||||
if err := ph.call(ctx, "hook.before_tool", call, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Call == nil {
|
||||
resp.Call = call
|
||||
}
|
||||
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
if ph == nil || !ph.opts.InterceptTool {
|
||||
return result, HookDecision{Action: HookActionContinue}, nil
|
||||
}
|
||||
|
||||
var resp processHookAfterToolResponse
|
||||
if err := ph.call(ctx, "hook.after_tool", result, &resp); err != nil {
|
||||
return nil, HookDecision{}, err
|
||||
}
|
||||
if resp.Result == nil {
|
||||
resp.Result = result
|
||||
}
|
||||
return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
if ph == nil || !ph.opts.ApproveTool {
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
}
|
||||
|
||||
var resp ApprovalDecision
|
||||
if err := ph.call(ctx, "hook.approve_tool", req, &resp); err != nil {
|
||||
return ApprovalDecision{}, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) hello(ctx context.Context) error {
|
||||
modes := make([]string, 0, 4)
|
||||
if ph.opts.Observe {
|
||||
modes = append(modes, "observe")
|
||||
}
|
||||
if ph.opts.InterceptLLM {
|
||||
modes = append(modes, "llm")
|
||||
}
|
||||
if ph.opts.InterceptTool {
|
||||
modes = append(modes, "tool")
|
||||
}
|
||||
if ph.opts.ApproveTool {
|
||||
modes = append(modes, "approve")
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
return ph.call(ctx, "hook.hello", processHookHelloParams{
|
||||
Name: ph.name,
|
||||
Version: 1,
|
||||
Modes: modes,
|
||||
}, &result)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
|
||||
msg := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
Method: method,
|
||||
}
|
||||
if params != nil {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Params = body
|
||||
}
|
||||
return ph.send(ctx, msg)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
|
||||
if ph.closed.Load() {
|
||||
return fmt.Errorf("process hook %q is closed", ph.name)
|
||||
}
|
||||
|
||||
id := ph.nextID.Add(1)
|
||||
respCh := make(chan processHookRPCMessage, 1)
|
||||
ph.pendingMu.Lock()
|
||||
ph.pending[id] = respCh
|
||||
ph.pendingMu.Unlock()
|
||||
|
||||
msg := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
ID: id,
|
||||
Method: method,
|
||||
}
|
||||
if params != nil {
|
||||
body, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
ph.removePending(id)
|
||||
return err
|
||||
}
|
||||
msg.Params = body
|
||||
}
|
||||
|
||||
if err := ph.send(ctx, msg); err != nil {
|
||||
ph.removePending(id)
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case resp, ok := <-respCh:
|
||||
if !ok {
|
||||
return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
|
||||
}
|
||||
if out != nil && len(resp.Result) > 0 {
|
||||
if err := json.Unmarshal(resp.Result, out); err != nil {
|
||||
return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
ph.removePending(id)
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
|
||||
body, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body = append(body, '\n')
|
||||
|
||||
ph.writeMu.Lock()
|
||||
defer ph.writeMu.Unlock()
|
||||
|
||||
if ph.closed.Load() {
|
||||
return fmt.Errorf("process hook %q is closed", ph.name)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, writeErr := ph.stdin.Write(body)
|
||||
done <- writeErr
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("write process hook %q message: %w", ph.name, err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) readLoop(stdout io.Reader) {
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
|
||||
|
||||
for scanner.Scan() {
|
||||
var msg processHookRPCMessage
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
|
||||
"hook": ph.name,
|
||||
"error": err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
if msg.ID == 0 {
|
||||
continue
|
||||
}
|
||||
ph.pendingMu.Lock()
|
||||
respCh, ok := ph.pending[msg.ID]
|
||||
if ok {
|
||||
delete(ph.pending, msg.ID)
|
||||
}
|
||||
ph.pendingMu.Unlock()
|
||||
if ok {
|
||||
respCh <- msg
|
||||
close(respCh)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) readStderr(stderr io.Reader) {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
|
||||
for scanner.Scan() {
|
||||
logger.WarnCF("hooks", "Process hook stderr", map[string]any{
|
||||
"hook": ph.name,
|
||||
"stderr": scanner.Text(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) waitLoop() {
|
||||
err := ph.cmd.Wait()
|
||||
ph.closeMu.Lock()
|
||||
ph.closeErr = err
|
||||
ph.closeMu.Unlock()
|
||||
ph.failPending(err)
|
||||
close(ph.done)
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) failPending(err error) {
|
||||
ph.pendingMu.Lock()
|
||||
defer ph.pendingMu.Unlock()
|
||||
|
||||
msg := processHookRPCMessage{
|
||||
Error: &processHookRPCError{
|
||||
Code: -32000,
|
||||
Message: "process exited",
|
||||
},
|
||||
}
|
||||
if err != nil {
|
||||
msg.Error.Message = err.Error()
|
||||
}
|
||||
|
||||
for id, ch := range ph.pending {
|
||||
delete(ph.pending, id)
|
||||
ch <- msg
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *ProcessHook) removePending(id uint64) {
|
||||
ph.pendingMu.Lock()
|
||||
defer ph.pendingMu.Unlock()
|
||||
|
||||
if ch, ok := ph.pending[id]; ok {
|
||||
delete(ph.pending, id)
|
||||
close(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
|
||||
if al == nil {
|
||||
return fmt.Errorf("agent loop is nil")
|
||||
}
|
||||
processHook, err := NewProcessHook(ctx, name, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := al.MountHook(HookRegistration{
|
||||
Name: name,
|
||||
Source: HookSourceProcess,
|
||||
Hook: processHook,
|
||||
}); err != nil {
|
||||
_ = processHook.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
|
||||
if len(kinds) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := make(map[string]struct{}, len(kinds))
|
||||
for _, kind := range kinds {
|
||||
if kind == "" {
|
||||
continue
|
||||
}
|
||||
normalized[kind] = struct{}{}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
func TestProcessHook_HelperProcess(t *testing.T) {
|
||||
if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
|
||||
return
|
||||
}
|
||||
if err := runProcessHookHelper(); err != nil {
|
||||
fmt.Fprintln(os.Stderr, err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
eventLog := filepath.Join(t.TempDir(), "events.log")
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("rewrite", eventLog),
|
||||
Observe: true,
|
||||
InterceptLLM: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "provider content|ipc" {
|
||||
t.Fatalf("expected process-hooked llm content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "process-model" {
|
||||
t.Fatalf("expected process model, got %q", lastModel)
|
||||
}
|
||||
|
||||
waitForFileContains(t, eventLog, "turn_end")
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-tool", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("rewrite", ""),
|
||||
InterceptTool: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "ipc:ipc" {
|
||||
t.Fatalf("expected rewritten process-hook tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type blockedToolProvider struct {
|
||||
calls int
|
||||
}
|
||||
|
||||
func (p *blockedToolProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "blocked_tool",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: messages[len(messages)-1].Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *blockedToolProvider) GetDefaultModel() string {
|
||||
return "blocked-tool-provider"
|
||||
}
|
||||
|
||||
func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
|
||||
provider := &blockedToolProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
if err := al.MountProcessHook(context.Background(), "ipc-approval", ProcessHookOptions{
|
||||
Command: processHookHelperCommand(),
|
||||
Env: processHookHelperEnv("deny", ""),
|
||||
ApproveTool: true,
|
||||
}); err != nil {
|
||||
t.Fatalf("MountProcessHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run blocked tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
|
||||
expected := "Tool execution denied by approval hook: blocked by ipc hook"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func processHookHelperCommand() []string {
|
||||
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
|
||||
}
|
||||
|
||||
func processHookHelperEnv(mode, eventLog string) []string {
|
||||
env := []string{
|
||||
"PICOCLAW_HOOK_HELPER=1",
|
||||
"PICOCLAW_HOOK_MODE=" + mode,
|
||||
}
|
||||
if eventLog != "" {
|
||||
env = append(env, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func waitForFileContains(t *testing.T, path, substring string) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(3 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err == nil && strings.Contains(string(data), substring) {
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
|
||||
data, _ := os.ReadFile(path)
|
||||
t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(data))
|
||||
}
|
||||
|
||||
func runProcessHookHelper() error {
|
||||
mode := os.Getenv("PICOCLAW_HOOK_MODE")
|
||||
eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")
|
||||
|
||||
scanner := bufio.NewScanner(os.Stdin)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
|
||||
encoder := json.NewEncoder(os.Stdout)
|
||||
|
||||
for scanner.Scan() {
|
||||
var msg processHookRPCMessage
|
||||
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.ID == 0 {
|
||||
if msg.Method == "hook.event" && eventLog != "" {
|
||||
var evt map[string]any
|
||||
if err := json.Unmarshal(msg.Params, &evt); err == nil {
|
||||
if rawKind, ok := evt["Kind"].(float64); ok {
|
||||
kind := EventKind(rawKind)
|
||||
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
|
||||
}
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
result, rpcErr := handleProcessHookRequest(mode, msg)
|
||||
resp := processHookRPCMessage{
|
||||
JSONRPC: processHookJSONRPCVersion,
|
||||
ID: msg.ID,
|
||||
}
|
||||
if rpcErr != nil {
|
||||
resp.Error = rpcErr
|
||||
} else if result != nil {
|
||||
body, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Result = body
|
||||
} else {
|
||||
resp.Result = []byte("{}")
|
||||
}
|
||||
|
||||
if err := encoder.Encode(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
|
||||
switch msg.Method {
|
||||
case "hook.hello":
|
||||
return map[string]any{"ok": true}, nil
|
||||
case "hook.before_llm":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var req map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &req)
|
||||
req["model"] = "process-model"
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"request": req,
|
||||
}, nil
|
||||
case "hook.after_llm":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var resp map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &resp)
|
||||
if rawResponse, ok := resp["response"].(map[string]any); ok {
|
||||
if content, ok := rawResponse["content"].(string); ok {
|
||||
rawResponse["content"] = content + "|ipc"
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"response": resp,
|
||||
}, nil
|
||||
case "hook.before_tool":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var call map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &call)
|
||||
rawArgs, ok := call["arguments"].(map[string]any)
|
||||
if !ok || rawArgs == nil {
|
||||
rawArgs = map[string]any{}
|
||||
}
|
||||
rawArgs["text"] = "ipc"
|
||||
call["arguments"] = rawArgs
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"call": call,
|
||||
}, nil
|
||||
case "hook.after_tool":
|
||||
if mode != "rewrite" {
|
||||
return map[string]any{"action": HookActionContinue}, nil
|
||||
}
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal(msg.Params, &result)
|
||||
if rawResult, ok := result["result"].(map[string]any); ok {
|
||||
if forLLM, ok := rawResult["for_llm"].(string); ok {
|
||||
rawResult["for_llm"] = "ipc:" + forLLM
|
||||
}
|
||||
}
|
||||
return map[string]any{
|
||||
"action": HookActionModify,
|
||||
"result": result,
|
||||
}, nil
|
||||
case "hook.approve_tool":
|
||||
if mode == "deny" {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked by ipc hook",
|
||||
}, nil
|
||||
}
|
||||
return ApprovalDecision{Approved: true}, nil
|
||||
default:
|
||||
return nil, &processHookRPCError{
|
||||
Code: -32601,
|
||||
Message: "method not found",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,809 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHookObserverTimeout = 500 * time.Millisecond
|
||||
defaultHookInterceptorTimeout = 5 * time.Second
|
||||
defaultHookApprovalTimeout = 60 * time.Second
|
||||
hookObserverBufferSize = 64
|
||||
)
|
||||
|
||||
type HookAction string
|
||||
|
||||
const (
|
||||
HookActionContinue HookAction = "continue"
|
||||
HookActionModify HookAction = "modify"
|
||||
HookActionDenyTool HookAction = "deny_tool"
|
||||
HookActionAbortTurn HookAction = "abort_turn"
|
||||
HookActionHardAbort HookAction = "hard_abort"
|
||||
)
|
||||
|
||||
type HookDecision struct {
|
||||
Action HookAction `json:"action"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
func (d HookDecision) normalizedAction() HookAction {
|
||||
if d.Action == "" {
|
||||
return HookActionContinue
|
||||
}
|
||||
return d.Action
|
||||
}
|
||||
|
||||
type ApprovalDecision struct {
|
||||
Approved bool `json:"approved"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
type HookSource uint8
|
||||
|
||||
const (
|
||||
HookSourceInProcess HookSource = iota
|
||||
HookSourceProcess
|
||||
)
|
||||
|
||||
type HookRegistration struct {
|
||||
Name string
|
||||
Priority int
|
||||
Source HookSource
|
||||
Hook any
|
||||
}
|
||||
|
||||
func NamedHook(name string, hook any) HookRegistration {
|
||||
return HookRegistration{
|
||||
Name: name,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: hook,
|
||||
}
|
||||
}
|
||||
|
||||
type EventObserver interface {
|
||||
OnEvent(ctx context.Context, evt Event) error
|
||||
}
|
||||
|
||||
type LLMInterceptor interface {
|
||||
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
|
||||
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolInterceptor interface {
|
||||
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
|
||||
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
|
||||
}
|
||||
|
||||
type ToolApprover interface {
|
||||
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
|
||||
}
|
||||
|
||||
type LLMHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Model string `json:"model"`
|
||||
Messages []providers.Message `json:"messages,omitempty"`
|
||||
Tools []providers.ToolDefinition `json:"tools,omitempty"`
|
||||
Options map[string]any `json:"options,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookRequest) Clone() *LLMHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Messages = cloneProviderMessages(r.Messages)
|
||||
cloned.Tools = cloneToolDefinitions(r.Tools)
|
||||
cloned.Options = cloneStringAnyMap(r.Options)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type LLMHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Model string `json:"model"`
|
||||
Response *providers.LLMResponse `json:"response,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *LLMHookResponse) Clone() *LLMHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Response = cloneLLMResponse(r.Response)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolCallHookRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolApprovalRequest struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type ToolResultHookResponse struct {
|
||||
Meta EventMeta `json:"meta"`
|
||||
Tool string `json:"tool"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
Result *tools.ToolResult `json:"result,omitempty"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
ChatID string `json:"chat_id,omitempty"`
|
||||
}
|
||||
|
||||
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *r
|
||||
cloned.Arguments = cloneStringAnyMap(r.Arguments)
|
||||
cloned.Result = cloneToolResult(r.Result)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
type HookManager struct {
|
||||
eventBus *EventBus
|
||||
observerTimeout time.Duration
|
||||
interceptorTimeout time.Duration
|
||||
approvalTimeout time.Duration
|
||||
|
||||
mu sync.RWMutex
|
||||
hooks map[string]HookRegistration
|
||||
ordered []HookRegistration
|
||||
|
||||
sub EventSubscription
|
||||
done chan struct{}
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewHookManager(eventBus *EventBus) *HookManager {
|
||||
hm := &HookManager{
|
||||
eventBus: eventBus,
|
||||
observerTimeout: defaultHookObserverTimeout,
|
||||
interceptorTimeout: defaultHookInterceptorTimeout,
|
||||
approvalTimeout: defaultHookApprovalTimeout,
|
||||
hooks: make(map[string]HookRegistration),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
if eventBus == nil {
|
||||
close(hm.done)
|
||||
return hm
|
||||
}
|
||||
|
||||
hm.sub = eventBus.Subscribe(hookObserverBufferSize)
|
||||
go hm.dispatchEvents()
|
||||
return hm
|
||||
}
|
||||
|
||||
func (hm *HookManager) Close() {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
hm.closeOnce.Do(func() {
|
||||
if hm.eventBus != nil {
|
||||
hm.eventBus.Unsubscribe(hm.sub.ID)
|
||||
}
|
||||
<-hm.done
|
||||
hm.closeAllHooks()
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
|
||||
if hm == nil {
|
||||
return
|
||||
}
|
||||
if observer > 0 {
|
||||
hm.observerTimeout = observer
|
||||
}
|
||||
if interceptor > 0 {
|
||||
hm.interceptorTimeout = interceptor
|
||||
}
|
||||
if approval > 0 {
|
||||
hm.approvalTimeout = approval
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) Mount(reg HookRegistration) error {
|
||||
if hm == nil {
|
||||
return fmt.Errorf("hook manager is nil")
|
||||
}
|
||||
if reg.Name == "" {
|
||||
return fmt.Errorf("hook name is required")
|
||||
}
|
||||
if reg.Hook == nil {
|
||||
return fmt.Errorf("hook %q is nil", reg.Name)
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if existing, ok := hm.hooks[reg.Name]; ok {
|
||||
closeHookIfPossible(existing.Hook)
|
||||
}
|
||||
hm.hooks[reg.Name] = reg
|
||||
hm.rebuildOrdered()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) Unmount(name string) {
|
||||
if hm == nil || name == "" {
|
||||
return
|
||||
}
|
||||
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
if existing, ok := hm.hooks[name]; ok {
|
||||
closeHookIfPossible(existing.Hook)
|
||||
}
|
||||
delete(hm.hooks, name)
|
||||
hm.rebuildOrdered()
|
||||
}
|
||||
|
||||
func (hm *HookManager) dispatchEvents() {
|
||||
defer close(hm.done)
|
||||
|
||||
for evt := range hm.sub.C {
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
observer, ok := reg.Hook.(EventObserver)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
hm.runObserver(reg.Name, observer, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
|
||||
if hm == nil || req == nil {
|
||||
return req, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := req.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
|
||||
if hm == nil || resp == nil {
|
||||
return resp, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := resp.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(LLMInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision) {
|
||||
if hm == nil || call == nil {
|
||||
return call, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := call.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision) {
|
||||
if hm == nil || result == nil {
|
||||
return result, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
current := result.Clone()
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
interceptor, ok := reg.Hook.(ToolInterceptor)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if next != nil {
|
||||
current = next
|
||||
}
|
||||
case HookActionAbortTurn, HookActionHardAbort:
|
||||
return current, decision
|
||||
default:
|
||||
hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
|
||||
}
|
||||
}
|
||||
return current, HookDecision{Action: HookActionContinue}
|
||||
}
|
||||
|
||||
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
|
||||
if hm == nil || req == nil {
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
for _, reg := range hm.snapshotHooks() {
|
||||
approver, ok := reg.Hook.(ToolApprover)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
|
||||
if !ok {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name),
|
||||
}
|
||||
}
|
||||
if !decision.Approved {
|
||||
return decision
|
||||
}
|
||||
}
|
||||
|
||||
return ApprovalDecision{Approved: true}
|
||||
}
|
||||
|
||||
func (hm *HookManager) rebuildOrdered() {
|
||||
hm.ordered = hm.ordered[:0]
|
||||
for _, reg := range hm.hooks {
|
||||
hm.ordered = append(hm.ordered, reg)
|
||||
}
|
||||
sort.SliceStable(hm.ordered, func(i, j int) bool {
|
||||
if hm.ordered[i].Source != hm.ordered[j].Source {
|
||||
return hm.ordered[i].Source < hm.ordered[j].Source
|
||||
}
|
||||
if hm.ordered[i].Priority == hm.ordered[j].Priority {
|
||||
return hm.ordered[i].Name < hm.ordered[j].Name
|
||||
}
|
||||
return hm.ordered[i].Priority < hm.ordered[j].Priority
|
||||
})
|
||||
}
|
||||
|
||||
func (hm *HookManager) snapshotHooks() []HookRegistration {
|
||||
hm.mu.RLock()
|
||||
defer hm.mu.RUnlock()
|
||||
|
||||
snapshot := make([]HookRegistration, len(hm.ordered))
|
||||
copy(snapshot, hm.ordered)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (hm *HookManager) closeAllHooks() {
|
||||
hm.mu.Lock()
|
||||
defer hm.mu.Unlock()
|
||||
|
||||
for name, reg := range hm.hooks {
|
||||
closeHookIfPossible(reg.Hook)
|
||||
delete(hm.hooks, name)
|
||||
}
|
||||
hm.ordered = nil
|
||||
}
|
||||
|
||||
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- observer.OnEvent(ctx, evt)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.WarnCF("hooks", "Event observer failed", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Event observer timed out", map[string]any{
|
||||
"hook": name,
|
||||
"event": evt.Kind.String(),
|
||||
"timeout_ms": hm.observerTimeout.Milliseconds(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_llm",
|
||||
func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeLLM(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterLLM(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor LLMInterceptor,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_llm",
|
||||
func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterLLM(ctx, resp)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callBeforeTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"before_tool",
|
||||
func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
|
||||
return interceptor.BeforeTool(ctx, call)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callAfterTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
interceptor ToolInterceptor,
|
||||
resultView *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, bool) {
|
||||
return runInterceptorHook(
|
||||
parent,
|
||||
hm.interceptorTimeout,
|
||||
name,
|
||||
"after_tool",
|
||||
func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
|
||||
return interceptor.AfterTool(ctx, resultView)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (hm *HookManager) callApproveTool(
|
||||
parent context.Context,
|
||||
name string,
|
||||
approver ToolApprover,
|
||||
req *ToolApprovalRequest,
|
||||
) (ApprovalDecision, bool) {
|
||||
return runApprovalHook(
|
||||
parent,
|
||||
hm.approvalTimeout,
|
||||
name,
|
||||
"approve_tool",
|
||||
func(ctx context.Context) (ApprovalDecision, error) {
|
||||
return approver.ApproveTool(ctx, req)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func runInterceptorHook[T any](
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (T, HookDecision, error),
|
||||
) (T, HookDecision, bool) {
|
||||
var zero T
|
||||
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
value T
|
||||
decision HookDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
value, decision, err := fn(ctx)
|
||||
done <- result{value: value, decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
return res.value, res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return zero, HookDecision{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func runApprovalHook(
|
||||
parent context.Context,
|
||||
timeout time.Duration,
|
||||
name string,
|
||||
stage string,
|
||||
fn func(ctx context.Context) (ApprovalDecision, error),
|
||||
) (ApprovalDecision, bool) {
|
||||
ctx, cancel := context.WithTimeout(parent, timeout)
|
||||
defer cancel()
|
||||
|
||||
type result struct {
|
||||
decision ApprovalDecision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
decision, err := fn(ctx)
|
||||
done <- result{decision: decision, err: err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case res := <-done:
|
||||
if res.err != nil {
|
||||
logger.WarnCF("hooks", "Approval hook failed", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"error": res.err.Error(),
|
||||
})
|
||||
return ApprovalDecision{}, false
|
||||
}
|
||||
return res.decision, true
|
||||
case <-ctx.Done():
|
||||
logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"timeout_ms": timeout.Milliseconds(),
|
||||
})
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: fmt.Sprintf("tool approval hook %q timed out", name),
|
||||
}, true
|
||||
}
|
||||
}
|
||||
|
||||
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
|
||||
logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{
|
||||
"hook": name,
|
||||
"stage": stage,
|
||||
"action": action,
|
||||
})
|
||||
}
|
||||
|
||||
func cloneProviderMessages(messages []providers.Message) []providers.Message {
|
||||
if len(messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.Message, len(messages))
|
||||
for i, msg := range messages {
|
||||
cloned[i] = msg
|
||||
if len(msg.Media) > 0 {
|
||||
cloned[i].Media = append([]string(nil), msg.Media...)
|
||||
}
|
||||
if len(msg.SystemParts) > 0 {
|
||||
cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls)
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
|
||||
if len(calls) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolCall, len(calls))
|
||||
for i, call := range calls {
|
||||
cloned[i] = call
|
||||
if call.Function != nil {
|
||||
fn := *call.Function
|
||||
cloned[i].Function = &fn
|
||||
}
|
||||
if call.Arguments != nil {
|
||||
cloned[i].Arguments = cloneStringAnyMap(call.Arguments)
|
||||
}
|
||||
if call.ExtraContent != nil {
|
||||
extra := *call.ExtraContent
|
||||
if call.ExtraContent.Google != nil {
|
||||
google := *call.ExtraContent.Google
|
||||
extra.Google = &google
|
||||
}
|
||||
cloned[i].ExtraContent = &extra
|
||||
}
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
|
||||
if len(defs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make([]providers.ToolDefinition, len(defs))
|
||||
for i, def := range defs {
|
||||
cloned[i] = def
|
||||
cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters)
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *resp
|
||||
cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
|
||||
if len(resp.ReasoningDetails) > 0 {
|
||||
cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...)
|
||||
}
|
||||
if resp.Usage != nil {
|
||||
usage := *resp.Usage
|
||||
cloned.Usage = &usage
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneStringAnyMap(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := make(map[string]any, len(src))
|
||||
for k, v := range src {
|
||||
cloned[k] = v
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
|
||||
if result == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cloned := *result
|
||||
if len(result.Media) > 0 {
|
||||
cloned.Media = append([]string(nil), result.Media...)
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func closeHookIfPossible(hook any) {
|
||||
closer, ok := hook.(io.Closer)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := closer.Close(); err != nil {
|
||||
logger.WarnCF("hooks", "Failed to close hook", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func newHookTestLoop(
|
||||
t *testing.T,
|
||||
provider providers.LLMProvider,
|
||||
) (*AgentLoop, *AgentInstance, func()) {
|
||||
t.Helper()
|
||||
|
||||
tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
|
||||
agent := al.registry.GetDefaultAgent()
|
||||
if agent == nil {
|
||||
t.Fatal("expected default agent")
|
||||
}
|
||||
|
||||
return al, agent, func() {
|
||||
al.Close()
|
||||
_ = os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
|
||||
hm := NewHookManager(nil)
|
||||
defer hm.Close()
|
||||
|
||||
if err := hm.Mount(HookRegistration{
|
||||
Name: "process",
|
||||
Priority: -10,
|
||||
Source: HookSourceProcess,
|
||||
Hook: struct{}{},
|
||||
}); err != nil {
|
||||
t.Fatalf("mount process hook: %v", err)
|
||||
}
|
||||
if err := hm.Mount(HookRegistration{
|
||||
Name: "in-process",
|
||||
Priority: 100,
|
||||
Source: HookSourceInProcess,
|
||||
Hook: struct{}{},
|
||||
}); err != nil {
|
||||
t.Fatalf("mount in-process hook: %v", err)
|
||||
}
|
||||
|
||||
ordered := hm.snapshotHooks()
|
||||
if len(ordered) != 2 {
|
||||
t.Fatalf("expected 2 hooks, got %d", len(ordered))
|
||||
}
|
||||
if ordered[0].Name != "in-process" {
|
||||
t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
|
||||
}
|
||||
if ordered[1].Name != "process" {
|
||||
t.Fatalf("expected process hook second, got %q", ordered[1].Name)
|
||||
}
|
||||
}
|
||||
|
||||
type llmHookTestProvider struct {
|
||||
mu sync.Mutex
|
||||
lastModel string
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
p.lastModel = model
|
||||
p.mu.Unlock()
|
||||
|
||||
return &providers.LLMResponse{
|
||||
Content: "provider content",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *llmHookTestProvider) GetDefaultModel() string {
|
||||
return "llm-hook-provider"
|
||||
}
|
||||
|
||||
type llmObserverHook struct {
|
||||
eventCh chan Event
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
|
||||
if evt.Kind == EventKindTurnEnd {
|
||||
select {
|
||||
case h.eventCh <- evt:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) BeforeLLM(
|
||||
ctx context.Context,
|
||||
req *LLMHookRequest,
|
||||
) (*LLMHookRequest, HookDecision, error) {
|
||||
next := req.Clone()
|
||||
next.Model = "hook-model"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *llmObserverHook) AfterLLM(
|
||||
ctx context.Context,
|
||||
resp *LLMHookResponse,
|
||||
) (*LLMHookResponse, HookDecision, error) {
|
||||
next := resp.Clone()
|
||||
next.Response.Content = "hooked content"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
|
||||
provider := &llmHookTestProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
hook := &llmObserverHook{eventCh: make(chan Event, 1)}
|
||||
if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "hello",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "hooked content" {
|
||||
t.Fatalf("expected hooked content, got %q", resp)
|
||||
}
|
||||
|
||||
provider.mu.Lock()
|
||||
lastModel := provider.lastModel
|
||||
provider.mu.Unlock()
|
||||
if lastModel != "hook-model" {
|
||||
t.Fatalf("expected model hook-model, got %q", lastModel)
|
||||
}
|
||||
|
||||
select {
|
||||
case evt := <-hook.eventCh:
|
||||
if evt.Kind != EventKindTurnEnd {
|
||||
t.Fatalf("expected turn end event, got %v", evt.Kind)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for hook observer event")
|
||||
}
|
||||
}
|
||||
|
||||
type toolHookProvider struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) Chat(
|
||||
ctx context.Context,
|
||||
messages []providers.Message,
|
||||
tools []providers.ToolDefinition,
|
||||
model string,
|
||||
opts map[string]any,
|
||||
) (*providers.LLMResponse, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
p.calls++
|
||||
if p.calls == 1 {
|
||||
return &providers.LLMResponse{
|
||||
ToolCalls: []providers.ToolCall{
|
||||
{
|
||||
ID: "call-1",
|
||||
Name: "echo_text",
|
||||
Arguments: map[string]any{"text": "original"},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
last := messages[len(messages)-1]
|
||||
return &providers.LLMResponse{
|
||||
Content: last.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *toolHookProvider) GetDefaultModel() string {
|
||||
return "tool-hook-provider"
|
||||
}
|
||||
|
||||
type echoTextTool struct{}
|
||||
|
||||
func (t *echoTextTool) Name() string {
|
||||
return "echo_text"
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Description() string {
|
||||
return "echo a text argument"
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Parameters() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"text": map[string]any{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
|
||||
text, _ := args["text"].(string)
|
||||
return tools.SilentResult(text)
|
||||
}
|
||||
|
||||
type toolRewriteHook struct{}
|
||||
|
||||
func (h *toolRewriteHook) BeforeTool(
|
||||
ctx context.Context,
|
||||
call *ToolCallHookRequest,
|
||||
) (*ToolCallHookRequest, HookDecision, error) {
|
||||
next := call.Clone()
|
||||
next.Arguments["text"] = "modified"
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func (h *toolRewriteHook) AfterTool(
|
||||
ctx context.Context,
|
||||
result *ToolResultHookResponse,
|
||||
) (*ToolResultHookResponse, HookDecision, error) {
|
||||
next := result.Clone()
|
||||
next.Result.ForLLM = "after:" + next.Result.ForLLM
|
||||
return next, HookDecision{Action: HookActionModify}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
if resp != "after:modified" {
|
||||
t.Fatalf("expected rewritten tool result, got %q", resp)
|
||||
}
|
||||
}
|
||||
|
||||
type denyApprovalHook struct{}
|
||||
|
||||
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
|
||||
return ApprovalDecision{
|
||||
Approved: false,
|
||||
Reason: "blocked",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
|
||||
provider := &toolHookProvider{}
|
||||
al, agent, cleanup := newHookTestLoop(t, provider)
|
||||
defer cleanup()
|
||||
|
||||
al.RegisterTool(&echoTextTool{})
|
||||
if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
|
||||
t.Fatalf("MountHook failed: %v", err)
|
||||
}
|
||||
|
||||
sub := al.SubscribeEvents(16)
|
||||
defer al.UnsubscribeEvents(sub.ID)
|
||||
|
||||
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
|
||||
SessionKey: "session-1",
|
||||
Channel: "cli",
|
||||
ChatID: "direct",
|
||||
UserMessage: "run tool",
|
||||
DefaultResponse: defaultResponse,
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("runAgentLoop failed: %v", err)
|
||||
}
|
||||
expected := "Tool execution denied by approval hook: blocked"
|
||||
if resp != expected {
|
||||
t.Fatalf("expected %q, got %q", expected, resp)
|
||||
}
|
||||
|
||||
events := collectEventStream(sub.C)
|
||||
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
|
||||
if !ok {
|
||||
t.Fatal("expected tool skipped event")
|
||||
}
|
||||
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
|
||||
}
|
||||
if payload.Reason != expected {
|
||||
t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
|
||||
}
|
||||
}
|
||||
+280
-47
@@ -40,6 +40,7 @@ type AgentLoop struct {
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
eventBus *EventBus
|
||||
hooks *HookManager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
@@ -48,6 +49,7 @@ type AgentLoop struct {
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
hookRuntime hookRuntime
|
||||
steering *steeringQueue
|
||||
mu sync.RWMutex
|
||||
activeTurnMu sync.RWMutex
|
||||
@@ -109,17 +111,20 @@ func NewAgentLoop(
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
eventBus := NewEventBus()
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
eventBus: NewEventBus(),
|
||||
eventBus: eventBus,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
|
||||
return al
|
||||
}
|
||||
@@ -257,6 +262,9 @@ func registerSharedTools(
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -512,11 +520,30 @@ func (al *AgentLoop) Close() {
|
||||
}
|
||||
|
||||
al.GetRegistry().Close()
|
||||
if al.hooks != nil {
|
||||
al.hooks.Close()
|
||||
}
|
||||
if al.eventBus != nil {
|
||||
al.eventBus.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// MountHook registers an in-process hook on the agent loop.
|
||||
func (al *AgentLoop) MountHook(reg HookRegistration) error {
|
||||
if al == nil || al.hooks == nil {
|
||||
return fmt.Errorf("hook manager is not initialized")
|
||||
}
|
||||
return al.hooks.Mount(reg)
|
||||
}
|
||||
|
||||
// UnmountHook removes a previously registered in-process hook.
|
||||
func (al *AgentLoop) UnmountHook(name string) {
|
||||
if al == nil || al.hooks == nil {
|
||||
return
|
||||
}
|
||||
al.hooks.Unmount(name)
|
||||
}
|
||||
|
||||
// SubscribeEvents registers a subscriber for agent-loop events.
|
||||
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
|
||||
if al == nil || al.eventBus == nil {
|
||||
@@ -596,6 +623,31 @@ func cloneEventArguments(args map[string]any) map[string]any {
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error {
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
|
||||
err := fmt.Errorf("hook aborted turn during %s: %s", stage, reason)
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("hooks", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "hook." + stage,
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func hookDeniedToolContent(prefix, reason string) string {
|
||||
if reason == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + ": " + reason
|
||||
}
|
||||
|
||||
func (al *AgentLoop) logEvent(evt Event) {
|
||||
fields := map[string]any{
|
||||
"event_kind": evt.Kind.String(),
|
||||
@@ -778,6 +830,9 @@ func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
|
||||
al.mu.Unlock()
|
||||
|
||||
al.hookRuntime.reset(al)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
|
||||
// Close old provider after releasing the lock
|
||||
// This prevents blocking readers while closing
|
||||
if oldProvider, ok := extractProvider(oldRegistry); ok {
|
||||
@@ -992,6 +1047,9 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
ctx context.Context,
|
||||
content, sessionKey, channel, chatID string,
|
||||
) (string, error) {
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1013,6 +1071,13 @@ func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
@@ -1504,36 +1569,6 @@ turnLoop:
|
||||
ts.markGracefulTerminalUsed()
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: activeModel,
|
||||
MessagesCount: len(callMessages),
|
||||
ToolsCount: len(providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"messages_count": len(callMessages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(callMessages),
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
llmOpts := map[string]any{
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
@@ -1548,6 +1583,66 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
|
||||
llmModel := activeModel
|
||||
if al.hooks != nil {
|
||||
llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
Model: llmModel,
|
||||
Messages: callMessages,
|
||||
Tools: providerToolDefs,
|
||||
Options: llmOpts,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
GracefulTerminal: gracefulTerminal,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmReq != nil {
|
||||
llmModel = llmReq.Model
|
||||
callMessages = llmReq.Messages
|
||||
providerToolDefs = llmReq.Tools
|
||||
llmOpts = llmReq.Options
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "before_llm", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: llmModel,
|
||||
MessagesCount: len(callMessages),
|
||||
ToolsCount: len(providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": llmModel,
|
||||
"messages_count": len(callMessages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(callMessages),
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
|
||||
providerCtx, providerCancel := context.WithCancel(turnCtx)
|
||||
ts.setProviderCancel(providerCancel)
|
||||
@@ -1580,7 +1675,7 @@ turnLoop:
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts)
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
}
|
||||
|
||||
var response *providers.LLMResponse
|
||||
@@ -1712,12 +1807,35 @@ turnLoop:
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"model": llmModel,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
Model: llmModel,
|
||||
Response: response,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmResp != nil && llmResp.Response != nil {
|
||||
response = llmResp.Response
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "after_llm", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
response.Reasoning,
|
||||
@@ -1825,25 +1943,106 @@ turnLoop:
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
toolName := tc.Name
|
||||
toolArgs := cloneStringAnyMap(tc.Arguments)
|
||||
|
||||
if al.hooks != nil {
|
||||
toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.before"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolReq != nil {
|
||||
toolName = toolReq.Tool
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionDenyTool:
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "before_tool", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.approve"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
if !approval.Approved {
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": tc.Name,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: tc.Name,
|
||||
Arguments: cloneEventArguments(tc.Arguments),
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
toolCall := tc
|
||||
toolCallID := tc.ID
|
||||
toolIteration := iteration
|
||||
asyncToolName := toolName
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -1865,7 +2064,7 @@ turnLoop:
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": toolCall.Name,
|
||||
"tool": asyncToolName,
|
||||
"content_len": len(content),
|
||||
"channel": ts.channel,
|
||||
})
|
||||
@@ -1873,7 +2072,7 @@ turnLoop:
|
||||
EventKindFollowUpQueued,
|
||||
ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
|
||||
FollowUpQueuedPayload{
|
||||
SourceTool: toolCall.Name,
|
||||
SourceTool: asyncToolName,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
ContentLen: len(content),
|
||||
@@ -1884,7 +2083,7 @@ turnLoop:
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", toolCall.Name),
|
||||
SenderID: fmt.Sprintf("async:%s", asyncToolName),
|
||||
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
|
||||
Content: content,
|
||||
})
|
||||
@@ -1893,8 +2092,8 @@ turnLoop:
|
||||
toolStart := time.Now()
|
||||
toolResult := ts.agent.Tools.ExecuteWithContext(
|
||||
turnCtx,
|
||||
toolCall.Name,
|
||||
toolCall.Arguments,
|
||||
toolName,
|
||||
toolArgs,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
asyncCallback,
|
||||
@@ -1906,6 +2105,40 @@ turnLoop:
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.after"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Result: toolResult,
|
||||
Duration: toolDuration,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolResp != nil {
|
||||
if toolResp.Tool != "" {
|
||||
toolName = toolResp.Tool
|
||||
}
|
||||
if toolResp.Result != nil {
|
||||
toolResult = toolResp.Result
|
||||
}
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "after_tool", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
@@ -1914,7 +2147,7 @@ turnLoop:
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": toolCall.Name,
|
||||
"tool": toolName,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
@@ -1947,13 +2180,13 @@ turnLoop:
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: toolCall.ID,
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolCall.Name,
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(toolResult.ForUser),
|
||||
|
||||
@@ -328,6 +328,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
|
||||
if active := al.GetActiveTurn(); active != nil {
|
||||
return "", fmt.Errorf("turn %s is still active", active.TurnID)
|
||||
}
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
|
||||
if len(steeringMsgs) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user