diff --git a/pkg/agent/hook_mount.go b/pkg/agent/hook_mount.go new file mode 100644 index 000000000..c92145f1f --- /dev/null +++ b/pkg/agent/hook_mount.go @@ -0,0 +1,317 @@ +package agent + +import ( + "context" + "fmt" + "sort" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/config" +) + +type hookRuntime struct { + initOnce sync.Once + mu sync.Mutex + initErr error + mounted []string +} + +func (r *hookRuntime) setInitErr(err error) { + r.mu.Lock() + r.initErr = err + r.mu.Unlock() +} + +func (r *hookRuntime) getInitErr() error { + r.mu.Lock() + defer r.mu.Unlock() + return r.initErr +} + +func (r *hookRuntime) setMounted(names []string) { + r.mu.Lock() + r.mounted = append([]string(nil), names...) + r.mu.Unlock() +} + +func (r *hookRuntime) reset(al *AgentLoop) { + r.mu.Lock() + names := append([]string(nil), r.mounted...) + r.mounted = nil + r.initErr = nil + r.initOnce = sync.Once{} + r.mu.Unlock() + + for _, name := range names { + al.UnmountHook(name) + } +} + +// BuiltinHookFactory constructs an in-process hook from config. +type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error) + +var ( + builtinHookRegistryMu sync.RWMutex + builtinHookRegistry = map[string]BuiltinHookFactory{} +) + +// RegisterBuiltinHook registers a named in-process hook factory for config-driven mounting. +func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error { + if name == "" { + return fmt.Errorf("builtin hook name is required") + } + if factory == nil { + return fmt.Errorf("builtin hook %q factory is nil", name) + } + + builtinHookRegistryMu.Lock() + defer builtinHookRegistryMu.Unlock() + + if _, exists := builtinHookRegistry[name]; exists { + return fmt.Errorf("builtin hook %q is already registered", name) + } + builtinHookRegistry[name] = factory + return nil +} + +func unregisterBuiltinHook(name string) { + if name == "" { + return + } + builtinHookRegistryMu.Lock() + delete(builtinHookRegistry, name) + builtinHookRegistryMu.Unlock() +} + +func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) { + builtinHookRegistryMu.RLock() + defer builtinHookRegistryMu.RUnlock() + + factory, ok := builtinHookRegistry[name] + return factory, ok +} + +func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) { + if hm == nil || cfg == nil { + return + } + hm.ConfigureTimeouts( + hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS), + hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS), + hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS), + ) +} + +func hookTimeoutFromMS(ms int) time.Duration { + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond +} + +func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error { + if al == nil || al.cfg == nil || al.hooks == nil { + return nil + } + + al.hookRuntime.initOnce.Do(func() { + al.hookRuntime.setInitErr(al.loadConfiguredHooks(ctx)) + }) + + return al.hookRuntime.getInitErr() +} + +func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) { + if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled { + return nil + } + + mounted := make([]string, 0) + defer func() { + if err != nil { + for _, name := range mounted { + al.UnmountHook(name) + } + return + } + al.hookRuntime.setMounted(mounted) + }() + + builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins) + for _, name := range builtinNames { + spec := al.cfg.Hooks.Builtins[name] + factory, ok := lookupBuiltinHook(name) + if !ok { + return fmt.Errorf("builtin hook %q is not registered", name) + } + + hook, factoryErr := factory(ctx, spec) + if factoryErr != nil { + return fmt.Errorf("build builtin hook %q: %w", name, factoryErr) + } + if err := al.MountHook(HookRegistration{ + Name: name, + Priority: spec.Priority, + Source: HookSourceInProcess, + Hook: hook, + }); err != nil { + return fmt.Errorf("mount builtin hook %q: %w", name, err) + } + mounted = append(mounted, name) + } + + processNames := enabledProcessHookNames(al.cfg.Hooks.Processes) + for _, name := range processNames { + spec := al.cfg.Hooks.Processes[name] + opts, buildErr := processHookOptionsFromConfig(spec) + if buildErr != nil { + return fmt.Errorf("configure process hook %q: %w", name, buildErr) + } + + processHook, buildErr := NewProcessHook(ctx, name, opts) + if buildErr != nil { + return fmt.Errorf("start process hook %q: %w", name, buildErr) + } + if err := al.MountHook(HookRegistration{ + Name: name, + Priority: spec.Priority, + Source: HookSourceProcess, + Hook: processHook, + }); err != nil { + _ = processHook.Close() + return fmt.Errorf("mount process hook %q: %w", name, err) + } + mounted = append(mounted, name) + } + + return nil +} + +func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string { + if len(specs) == 0 { + return nil + } + + names := make([]string, 0, len(specs)) + for name, spec := range specs { + if spec.Enabled { + names = append(names, name) + } + } + sort.Strings(names) + return names +} + +func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string { + if len(specs) == 0 { + return nil + } + + names := make([]string, 0, len(specs)) + for name, spec := range specs { + if spec.Enabled { + names = append(names, name) + } + } + sort.Strings(names) + return names +} + +func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) { + transport := spec.Transport + if transport == "" { + transport = "stdio" + } + if transport != "stdio" { + return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", transport) + } + if len(spec.Command) == 0 { + return ProcessHookOptions{}, fmt.Errorf("command is required") + } + + opts := ProcessHookOptions{ + Command: append([]string(nil), spec.Command...), + Dir: spec.Dir, + Env: processHookEnvFromMap(spec.Env), + } + + observeKinds, observeEnabled, err := processHookObserveKindsFromConfig(spec.Observe) + if err != nil { + return ProcessHookOptions{}, err + } + opts.Observe = observeEnabled + opts.ObserveKinds = observeKinds + + for _, intercept := range spec.Intercept { + switch intercept { + case "before_llm", "after_llm": + opts.InterceptLLM = true + case "before_tool", "after_tool": + opts.InterceptTool = true + case "approve_tool": + opts.ApproveTool = true + case "": + continue + default: + return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", intercept) + } + } + + if !opts.Observe && !opts.InterceptLLM && !opts.InterceptTool && !opts.ApproveTool { + return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled") + } + + return opts, nil +} + +func processHookEnvFromMap(envMap map[string]string) []string { + if len(envMap) == 0 { + return nil + } + + keys := make([]string, 0, len(envMap)) + for key := range envMap { + keys = append(keys, key) + } + sort.Strings(keys) + + env := make([]string, 0, len(keys)) + for _, key := range keys { + env = append(env, key+"="+envMap[key]) + } + return env +} + +func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) { + if len(observe) == 0 { + return nil, false, nil + } + + validKinds := validHookEventKinds() + normalized := make([]string, 0, len(observe)) + for _, kind := range observe { + switch kind { + case "", "*", "all": + return nil, true, nil + default: + if _, ok := validKinds[kind]; !ok { + return nil, false, fmt.Errorf("unsupported observe event %q", kind) + } + normalized = append(normalized, kind) + } + } + + if len(normalized) == 0 { + return nil, false, nil + } + return normalized, true, nil +} + +func validHookEventKinds() map[string]struct{} { + kinds := make(map[string]struct{}, int(eventKindCount)) + for kind := EventKind(0); kind < eventKindCount; kind++ { + kinds[kind.String()] = struct{}{} + } + return kinds +} diff --git a/pkg/agent/hook_mount_test.go b/pkg/agent/hook_mount_test.go new file mode 100644 index 000000000..a9d8f27c5 --- /dev/null +++ b/pkg/agent/hook_mount_test.go @@ -0,0 +1,179 @@ +package agent + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +type builtinAutoHookConfig struct { + Model string `json:"model"` + Suffix string `json:"suffix"` +} + +type builtinAutoHook struct { + model string + suffix string +} + +func (h *builtinAutoHook) BeforeLLM( + ctx context.Context, + req *LLMHookRequest, +) (*LLMHookRequest, HookDecision, error) { + next := req.Clone() + next.Model = h.model + return next, HookDecision{Action: HookActionModify}, nil +} + +func (h *builtinAutoHook) AfterLLM( + ctx context.Context, + resp *LLMHookResponse, +) (*LLMHookResponse, HookDecision, error) { + next := resp.Clone() + if next.Response != nil { + next.Response.Content += h.suffix + } + return next, HookDecision{Action: HookActionModify}, nil +} + +func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop { + t.Helper() + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: t.TempDir(), + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + Hooks: hooks, + } + + return NewAgentLoop(cfg, bus.NewMessageBus(), provider) +} + +func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) { + const hookName = "test-auto-builtin-hook" + + if err := RegisterBuiltinHook(hookName, func( + ctx context.Context, + spec config.BuiltinHookConfig, + ) (any, error) { + var hookCfg builtinAutoHookConfig + if len(spec.Config) > 0 { + if err := json.Unmarshal(spec.Config, &hookCfg); err != nil { + return nil, err + } + } + return &builtinAutoHook{ + model: hookCfg.Model, + suffix: hookCfg.Suffix, + }, nil + }); err != nil { + t.Fatalf("RegisterBuiltinHook failed: %v", err) + } + t.Cleanup(func() { + unregisterBuiltinHook(hookName) + }) + + rawCfg, err := json.Marshal(builtinAutoHookConfig{ + Model: "builtin-model", + Suffix: "|builtin", + }) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + provider := &llmHookTestProvider{} + al := newConfiguredHookLoop(t, provider, config.HooksConfig{ + Enabled: true, + Builtins: map[string]config.BuiltinHookConfig{ + hookName: { + Enabled: true, + Config: rawCfg, + }, + }, + }) + defer al.Close() + + resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct") + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed: %v", err) + } + if resp != "provider content|builtin" { + t.Fatalf("expected builtin-hooked content, got %q", resp) + } + + provider.mu.Lock() + lastModel := provider.lastModel + provider.mu.Unlock() + if lastModel != "builtin-model" { + t.Fatalf("expected builtin model, got %q", lastModel) + } +} + +func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) { + provider := &llmHookTestProvider{} + eventLog := filepath.Join(t.TempDir(), "events.log") + + al := newConfiguredHookLoop(t, provider, config.HooksConfig{ + Enabled: true, + Processes: map[string]config.ProcessHookConfig{ + "ipc-auto": { + Enabled: true, + Command: processHookHelperCommand(), + Env: map[string]string{ + "PICOCLAW_HOOK_HELPER": "1", + "PICOCLAW_HOOK_MODE": "rewrite", + "PICOCLAW_HOOK_EVENT_LOG": eventLog, + }, + Observe: []string{"turn_end"}, + Intercept: []string{"before_llm", "after_llm"}, + }, + }, + }) + defer al.Close() + + resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct") + if err != nil { + t.Fatalf("ProcessDirectWithChannel failed: %v", err) + } + if resp != "provider content|ipc" { + t.Fatalf("expected process-hooked content, got %q", resp) + } + + provider.mu.Lock() + lastModel := provider.lastModel + provider.mu.Unlock() + if lastModel != "process-model" { + t.Fatalf("expected process model, got %q", lastModel) + } + + waitForFileContains(t, eventLog, "turn_end") +} + +func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) { + provider := &llmHookTestProvider{} + al := newConfiguredHookLoop(t, provider, config.HooksConfig{ + Enabled: true, + Processes: map[string]config.ProcessHookConfig{ + "bad-hook": { + Enabled: true, + Command: processHookHelperCommand(), + Intercept: []string{"not_supported"}, + }, + }, + }) + defer al.Close() + + _, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct") + if err == nil { + t.Fatal("expected invalid configured hook error") + } +} diff --git a/pkg/agent/hook_process.go b/pkg/agent/hook_process.go new file mode 100644 index 000000000..e5632913d --- /dev/null +++ b/pkg/agent/hook_process.go @@ -0,0 +1,511 @@ +package agent + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "sync" + "sync/atomic" + "time" + + "github.com/sipeed/picoclaw/pkg/logger" +) + +const ( + processHookJSONRPCVersion = "2.0" + processHookReadBufferSize = 1024 * 1024 + processHookCloseTimeout = 2 * time.Second +) + +type ProcessHookOptions struct { + Command []string + Dir string + Env []string + Observe bool + ObserveKinds []string + InterceptLLM bool + InterceptTool bool + ApproveTool bool +} + +type ProcessHook struct { + name string + opts ProcessHookOptions + + cmd *exec.Cmd + stdin io.WriteCloser + observeKinds map[string]struct{} + + writeMu sync.Mutex + + pendingMu sync.Mutex + pending map[uint64]chan processHookRPCMessage + nextID atomic.Uint64 + + closed atomic.Bool + done chan struct{} + closeErr error + closeMu sync.Mutex + closeOnce sync.Once +} + +type processHookRPCMessage struct { + JSONRPC string `json:"jsonrpc,omitempty"` + ID uint64 `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *processHookRPCError `json:"error,omitempty"` +} + +type processHookRPCError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type processHookHelloParams struct { + Name string `json:"name"` + Version int `json:"version"` + Modes []string `json:"modes,omitempty"` +} + +type processHookDecisionResponse struct { + Action HookAction `json:"action"` + Reason string `json:"reason,omitempty"` +} + +type processHookBeforeLLMResponse struct { + processHookDecisionResponse + Request *LLMHookRequest `json:"request,omitempty"` +} + +type processHookAfterLLMResponse struct { + processHookDecisionResponse + Response *LLMHookResponse `json:"response,omitempty"` +} + +type processHookBeforeToolResponse struct { + processHookDecisionResponse + Call *ToolCallHookRequest `json:"call,omitempty"` +} + +type processHookAfterToolResponse struct { + processHookDecisionResponse + Result *ToolResultHookResponse `json:"result,omitempty"` +} + +func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) { + if len(opts.Command) == 0 { + return nil, fmt.Errorf("process hook command is required") + } + + cmd := exec.Command(opts.Command[0], opts.Command[1:]...) + cmd.Dir = opts.Dir + if len(opts.Env) > 0 { + cmd.Env = append(os.Environ(), opts.Env...) + } + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("create process hook stdin: %w", err) + } + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("create process hook stdout: %w", err) + } + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("create process hook stderr: %w", err) + } + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start process hook: %w", err) + } + + ph := &ProcessHook{ + name: name, + opts: opts, + cmd: cmd, + stdin: stdin, + observeKinds: newProcessHookObserveKinds(opts.ObserveKinds), + pending: make(map[uint64]chan processHookRPCMessage), + done: make(chan struct{}), + } + + go ph.readLoop(stdout) + go ph.readStderr(stderr) + go ph.waitLoop() + + helloCtx := ctx + if helloCtx == nil { + var cancel context.CancelFunc + helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + if err := ph.hello(helloCtx); err != nil { + _ = ph.Close() + return nil, err + } + + return ph, nil +} + +func (ph *ProcessHook) Close() error { + if ph == nil { + return nil + } + + ph.closeOnce.Do(func() { + ph.closed.Store(true) + if ph.stdin != nil { + _ = ph.stdin.Close() + } + + select { + case <-ph.done: + case <-time.After(processHookCloseTimeout): + if ph.cmd != nil && ph.cmd.Process != nil { + _ = ph.cmd.Process.Kill() + } + <-ph.done + } + }) + + ph.closeMu.Lock() + defer ph.closeMu.Unlock() + return ph.closeErr +} + +func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error { + if ph == nil || !ph.opts.Observe { + return nil + } + if len(ph.observeKinds) > 0 { + if _, ok := ph.observeKinds[evt.Kind.String()]; !ok { + return nil + } + } + return ph.notify(ctx, "hook.event", evt) +} + +func (ph *ProcessHook) BeforeLLM( + ctx context.Context, + req *LLMHookRequest, +) (*LLMHookRequest, HookDecision, error) { + if ph == nil || !ph.opts.InterceptLLM { + return req, HookDecision{Action: HookActionContinue}, nil + } + + var resp processHookBeforeLLMResponse + if err := ph.call(ctx, "hook.before_llm", req, &resp); err != nil { + return nil, HookDecision{}, err + } + if resp.Request == nil { + resp.Request = req + } + return resp.Request, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil +} + +func (ph *ProcessHook) AfterLLM( + ctx context.Context, + resp *LLMHookResponse, +) (*LLMHookResponse, HookDecision, error) { + if ph == nil || !ph.opts.InterceptLLM { + return resp, HookDecision{Action: HookActionContinue}, nil + } + + var result processHookAfterLLMResponse + if err := ph.call(ctx, "hook.after_llm", resp, &result); err != nil { + return nil, HookDecision{}, err + } + if result.Response == nil { + result.Response = resp + } + return result.Response, HookDecision{Action: result.Action, Reason: result.Reason}, nil +} + +func (ph *ProcessHook) BeforeTool( + ctx context.Context, + call *ToolCallHookRequest, +) (*ToolCallHookRequest, HookDecision, error) { + if ph == nil || !ph.opts.InterceptTool { + return call, HookDecision{Action: HookActionContinue}, nil + } + + var resp processHookBeforeToolResponse + if err := ph.call(ctx, "hook.before_tool", call, &resp); err != nil { + return nil, HookDecision{}, err + } + if resp.Call == nil { + resp.Call = call + } + return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil +} + +func (ph *ProcessHook) AfterTool( + ctx context.Context, + result *ToolResultHookResponse, +) (*ToolResultHookResponse, HookDecision, error) { + if ph == nil || !ph.opts.InterceptTool { + return result, HookDecision{Action: HookActionContinue}, nil + } + + var resp processHookAfterToolResponse + if err := ph.call(ctx, "hook.after_tool", result, &resp); err != nil { + return nil, HookDecision{}, err + } + if resp.Result == nil { + resp.Result = result + } + return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil +} + +func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) { + if ph == nil || !ph.opts.ApproveTool { + return ApprovalDecision{Approved: true}, nil + } + + var resp ApprovalDecision + if err := ph.call(ctx, "hook.approve_tool", req, &resp); err != nil { + return ApprovalDecision{}, err + } + return resp, nil +} + +func (ph *ProcessHook) hello(ctx context.Context) error { + modes := make([]string, 0, 4) + if ph.opts.Observe { + modes = append(modes, "observe") + } + if ph.opts.InterceptLLM { + modes = append(modes, "llm") + } + if ph.opts.InterceptTool { + modes = append(modes, "tool") + } + if ph.opts.ApproveTool { + modes = append(modes, "approve") + } + + var result map[string]any + return ph.call(ctx, "hook.hello", processHookHelloParams{ + Name: ph.name, + Version: 1, + Modes: modes, + }, &result) +} + +func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error { + msg := processHookRPCMessage{ + JSONRPC: processHookJSONRPCVersion, + Method: method, + } + if params != nil { + body, err := json.Marshal(params) + if err != nil { + return err + } + msg.Params = body + } + return ph.send(ctx, msg) +} + +func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error { + if ph.closed.Load() { + return fmt.Errorf("process hook %q is closed", ph.name) + } + + id := ph.nextID.Add(1) + respCh := make(chan processHookRPCMessage, 1) + ph.pendingMu.Lock() + ph.pending[id] = respCh + ph.pendingMu.Unlock() + + msg := processHookRPCMessage{ + JSONRPC: processHookJSONRPCVersion, + ID: id, + Method: method, + } + if params != nil { + body, err := json.Marshal(params) + if err != nil { + ph.removePending(id) + return err + } + msg.Params = body + } + + if err := ph.send(ctx, msg); err != nil { + ph.removePending(id) + return err + } + + select { + case resp, ok := <-respCh: + if !ok { + return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method) + } + if resp.Error != nil { + return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message) + } + if out != nil && len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, out); err != nil { + return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err) + } + } + return nil + case <-ctx.Done(): + ph.removePending(id) + return ctx.Err() + } +} + +func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error { + body, err := json.Marshal(msg) + if err != nil { + return err + } + body = append(body, '\n') + + ph.writeMu.Lock() + defer ph.writeMu.Unlock() + + if ph.closed.Load() { + return fmt.Errorf("process hook %q is closed", ph.name) + } + + done := make(chan error, 1) + go func() { + _, writeErr := ph.stdin.Write(body) + done <- writeErr + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("write process hook %q message: %w", ph.name, err) + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (ph *ProcessHook) readLoop(stdout io.Reader) { + scanner := bufio.NewScanner(stdout) + scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize) + + for scanner.Scan() { + var msg processHookRPCMessage + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{ + "hook": ph.name, + "error": err.Error(), + }) + continue + } + if msg.ID == 0 { + continue + } + ph.pendingMu.Lock() + respCh, ok := ph.pending[msg.ID] + if ok { + delete(ph.pending, msg.ID) + } + ph.pendingMu.Unlock() + if ok { + respCh <- msg + close(respCh) + } + } +} + +func (ph *ProcessHook) readStderr(stderr io.Reader) { + scanner := bufio.NewScanner(stderr) + scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize) + for scanner.Scan() { + logger.WarnCF("hooks", "Process hook stderr", map[string]any{ + "hook": ph.name, + "stderr": scanner.Text(), + }) + } +} + +func (ph *ProcessHook) waitLoop() { + err := ph.cmd.Wait() + ph.closeMu.Lock() + ph.closeErr = err + ph.closeMu.Unlock() + ph.failPending(err) + close(ph.done) +} + +func (ph *ProcessHook) failPending(err error) { + ph.pendingMu.Lock() + defer ph.pendingMu.Unlock() + + msg := processHookRPCMessage{ + Error: &processHookRPCError{ + Code: -32000, + Message: "process exited", + }, + } + if err != nil { + msg.Error.Message = err.Error() + } + + for id, ch := range ph.pending { + delete(ph.pending, id) + ch <- msg + close(ch) + } +} + +func (ph *ProcessHook) removePending(id uint64) { + ph.pendingMu.Lock() + defer ph.pendingMu.Unlock() + + if ch, ok := ph.pending[id]; ok { + delete(ph.pending, id) + close(ch) + } +} + +func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error { + if al == nil { + return fmt.Errorf("agent loop is nil") + } + processHook, err := NewProcessHook(ctx, name, opts) + if err != nil { + return err + } + if err := al.MountHook(HookRegistration{ + Name: name, + Source: HookSourceProcess, + Hook: processHook, + }); err != nil { + _ = processHook.Close() + return err + } + return nil +} + +func newProcessHookObserveKinds(kinds []string) map[string]struct{} { + if len(kinds) == 0 { + return nil + } + + normalized := make(map[string]struct{}, len(kinds)) + for _, kind := range kinds { + if kind == "" { + continue + } + normalized[kind] = struct{}{} + } + if len(normalized) == 0 { + return nil + } + return normalized +} diff --git a/pkg/agent/hook_process_test.go b/pkg/agent/hook_process_test.go new file mode 100644 index 000000000..50f89811f --- /dev/null +++ b/pkg/agent/hook_process_test.go @@ -0,0 +1,339 @@ +package agent + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +func TestProcessHook_HelperProcess(t *testing.T) { + if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" { + return + } + if err := runProcessHookHelper(); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + os.Exit(0) +} + +func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) { + provider := &llmHookTestProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + eventLog := filepath.Join(t.TempDir(), "events.log") + if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{ + Command: processHookHelperCommand(), + Env: processHookHelperEnv("rewrite", eventLog), + Observe: true, + InterceptLLM: true, + }); err != nil { + t.Fatalf("MountProcessHook failed: %v", err) + } + + resp, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "hello", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + if resp != "provider content|ipc" { + t.Fatalf("expected process-hooked llm content, got %q", resp) + } + + provider.mu.Lock() + lastModel := provider.lastModel + provider.mu.Unlock() + if lastModel != "process-model" { + t.Fatalf("expected process model, got %q", lastModel) + } + + waitForFileContains(t, eventLog, "turn_end") +} + +func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) { + provider := &toolHookProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + al.RegisterTool(&echoTextTool{}) + if err := al.MountProcessHook(context.Background(), "ipc-tool", ProcessHookOptions{ + Command: processHookHelperCommand(), + Env: processHookHelperEnv("rewrite", ""), + InterceptTool: true, + }); err != nil { + t.Fatalf("MountProcessHook failed: %v", err) + } + + resp, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "run tool", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + if resp != "ipc:ipc" { + t.Fatalf("expected rewritten process-hook tool result, got %q", resp) + } +} + +type blockedToolProvider struct { + calls int +} + +func (p *blockedToolProvider) Chat( + ctx context.Context, + messages []providers.Message, + tools []providers.ToolDefinition, + model string, + opts map[string]any, +) (*providers.LLMResponse, error) { + p.calls++ + if p.calls == 1 { + return &providers.LLMResponse{ + ToolCalls: []providers.ToolCall{ + { + ID: "call-1", + Name: "blocked_tool", + Arguments: map[string]any{}, + }, + }, + }, nil + } + + return &providers.LLMResponse{ + Content: messages[len(messages)-1].Content, + }, nil +} + +func (p *blockedToolProvider) GetDefaultModel() string { + return "blocked-tool-provider" +} + +func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) { + provider := &blockedToolProvider{} + al, agent, cleanup := newHookTestLoop(t, provider) + defer cleanup() + + if err := al.MountProcessHook(context.Background(), "ipc-approval", ProcessHookOptions{ + Command: processHookHelperCommand(), + Env: processHookHelperEnv("deny", ""), + ApproveTool: true, + }); err != nil { + t.Fatalf("MountProcessHook failed: %v", err) + } + + sub := al.SubscribeEvents(16) + defer al.UnsubscribeEvents(sub.ID) + + resp, err := al.runAgentLoop(context.Background(), agent, processOptions{ + SessionKey: "session-1", + Channel: "cli", + ChatID: "direct", + UserMessage: "run blocked tool", + DefaultResponse: defaultResponse, + EnableSummary: false, + SendResponse: false, + }) + if err != nil { + t.Fatalf("runAgentLoop failed: %v", err) + } + + expected := "Tool execution denied by approval hook: blocked by ipc hook" + if resp != expected { + t.Fatalf("expected %q, got %q", expected, resp) + } + + events := collectEventStream(sub.C) + skippedEvt, ok := findEvent(events, EventKindToolExecSkipped) + if !ok { + t.Fatal("expected tool skipped event") + } + payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload) + if !ok { + t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload) + } + if payload.Reason != expected { + t.Fatalf("expected reason %q, got %q", expected, payload.Reason) + } +} + +func processHookHelperCommand() []string { + return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"} +} + +func processHookHelperEnv(mode, eventLog string) []string { + env := []string{ + "PICOCLAW_HOOK_HELPER=1", + "PICOCLAW_HOOK_MODE=" + mode, + } + if eventLog != "" { + env = append(env, "PICOCLAW_HOOK_EVENT_LOG="+eventLog) + } + return env +} + +func waitForFileContains(t *testing.T, path, substring string) { + t.Helper() + + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + data, err := os.ReadFile(path) + if err == nil && strings.Contains(string(data), substring) { + return + } + time.Sleep(20 * time.Millisecond) + } + + data, _ := os.ReadFile(path) + t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(data)) +} + +func runProcessHookHelper() error { + mode := os.Getenv("PICOCLAW_HOOK_MODE") + eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG") + + scanner := bufio.NewScanner(os.Stdin) + scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize) + encoder := json.NewEncoder(os.Stdout) + + for scanner.Scan() { + var msg processHookRPCMessage + if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + return err + } + + if msg.ID == 0 { + if msg.Method == "hook.event" && eventLog != "" { + var evt map[string]any + if err := json.Unmarshal(msg.Params, &evt); err == nil { + if rawKind, ok := evt["Kind"].(float64); ok { + kind := EventKind(rawKind) + _ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644) + } + } + } + continue + } + + result, rpcErr := handleProcessHookRequest(mode, msg) + resp := processHookRPCMessage{ + JSONRPC: processHookJSONRPCVersion, + ID: msg.ID, + } + if rpcErr != nil { + resp.Error = rpcErr + } else if result != nil { + body, err := json.Marshal(result) + if err != nil { + return err + } + resp.Result = body + } else { + resp.Result = []byte("{}") + } + + if err := encoder.Encode(resp); err != nil { + return err + } + } + + return scanner.Err() +} + +func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) { + switch msg.Method { + case "hook.hello": + return map[string]any{"ok": true}, nil + case "hook.before_llm": + if mode != "rewrite" { + return map[string]any{"action": HookActionContinue}, nil + } + var req map[string]any + _ = json.Unmarshal(msg.Params, &req) + req["model"] = "process-model" + return map[string]any{ + "action": HookActionModify, + "request": req, + }, nil + case "hook.after_llm": + if mode != "rewrite" { + return map[string]any{"action": HookActionContinue}, nil + } + var resp map[string]any + _ = json.Unmarshal(msg.Params, &resp) + if rawResponse, ok := resp["response"].(map[string]any); ok { + if content, ok := rawResponse["content"].(string); ok { + rawResponse["content"] = content + "|ipc" + } + } + return map[string]any{ + "action": HookActionModify, + "response": resp, + }, nil + case "hook.before_tool": + if mode != "rewrite" { + return map[string]any{"action": HookActionContinue}, nil + } + var call map[string]any + _ = json.Unmarshal(msg.Params, &call) + rawArgs, ok := call["arguments"].(map[string]any) + if !ok || rawArgs == nil { + rawArgs = map[string]any{} + } + rawArgs["text"] = "ipc" + call["arguments"] = rawArgs + return map[string]any{ + "action": HookActionModify, + "call": call, + }, nil + case "hook.after_tool": + if mode != "rewrite" { + return map[string]any{"action": HookActionContinue}, nil + } + var result map[string]any + _ = json.Unmarshal(msg.Params, &result) + if rawResult, ok := result["result"].(map[string]any); ok { + if forLLM, ok := rawResult["for_llm"].(string); ok { + rawResult["for_llm"] = "ipc:" + forLLM + } + } + return map[string]any{ + "action": HookActionModify, + "result": result, + }, nil + case "hook.approve_tool": + if mode == "deny" { + return ApprovalDecision{ + Approved: false, + Reason: "blocked by ipc hook", + }, nil + } + return ApprovalDecision{Approved: true}, nil + default: + return nil, &processHookRPCError{ + Code: -32601, + Message: "method not found", + } + } +} diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go index 74af542fa..c1ef58ffd 100644 --- a/pkg/agent/hooks.go +++ b/pkg/agent/hooks.go @@ -3,6 +3,7 @@ package agent import ( "context" "fmt" + "io" "sort" "sync" "time" @@ -30,8 +31,8 @@ const ( ) type HookDecision struct { - Action HookAction - Reason string + Action HookAction `json:"action"` + Reason string `json:"reason,omitempty"` } func (d HookDecision) normalizedAction() HookAction { @@ -42,20 +43,29 @@ func (d HookDecision) normalizedAction() HookAction { } type ApprovalDecision struct { - Approved bool - Reason string + Approved bool `json:"approved"` + Reason string `json:"reason,omitempty"` } +type HookSource uint8 + +const ( + HookSourceInProcess HookSource = iota + HookSourceProcess +) + type HookRegistration struct { Name string Priority int + Source HookSource Hook any } func NamedHook(name string, hook any) HookRegistration { return HookRegistration{ - Name: name, - Hook: hook, + Name: name, + Source: HookSourceInProcess, + Hook: hook, } } @@ -78,14 +88,14 @@ type ToolApprover interface { } type LLMHookRequest struct { - Meta EventMeta - Model string - Messages []providers.Message - Tools []providers.ToolDefinition - Options map[string]any - Channel string - ChatID string - GracefulTerminal bool + Meta EventMeta `json:"meta"` + Model string `json:"model"` + Messages []providers.Message `json:"messages,omitempty"` + Tools []providers.ToolDefinition `json:"tools,omitempty"` + Options map[string]any `json:"options,omitempty"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` + GracefulTerminal bool `json:"graceful_terminal,omitempty"` } func (r *LLMHookRequest) Clone() *LLMHookRequest { @@ -100,11 +110,11 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest { } type LLMHookResponse struct { - Meta EventMeta - Model string - Response *providers.LLMResponse - Channel string - ChatID string + Meta EventMeta `json:"meta"` + Model string `json:"model"` + Response *providers.LLMResponse `json:"response,omitempty"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` } func (r *LLMHookResponse) Clone() *LLMHookResponse { @@ -117,11 +127,11 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse { } type ToolCallHookRequest struct { - Meta EventMeta - Tool string - Arguments map[string]any - Channel string - ChatID string + Meta EventMeta `json:"meta"` + Tool string `json:"tool"` + Arguments map[string]any `json:"arguments,omitempty"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` } func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest { @@ -134,11 +144,11 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest { } type ToolApprovalRequest struct { - Meta EventMeta - Tool string - Arguments map[string]any - Channel string - ChatID string + Meta EventMeta `json:"meta"` + Tool string `json:"tool"` + Arguments map[string]any `json:"arguments,omitempty"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` } func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest { @@ -151,13 +161,13 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest { } type ToolResultHookResponse struct { - Meta EventMeta - Tool string - Arguments map[string]any - Result *tools.ToolResult - Duration time.Duration - Channel string - ChatID string + Meta EventMeta `json:"meta"` + Tool string `json:"tool"` + Arguments map[string]any `json:"arguments,omitempty"` + Result *tools.ToolResult `json:"result,omitempty"` + Duration time.Duration `json:"duration"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` } func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse { @@ -215,9 +225,25 @@ func (hm *HookManager) Close() { hm.eventBus.Unsubscribe(hm.sub.ID) } <-hm.done + hm.closeAllHooks() }) } +func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) { + if hm == nil { + return + } + if observer > 0 { + hm.observerTimeout = observer + } + if interceptor > 0 { + hm.interceptorTimeout = interceptor + } + if approval > 0 { + hm.approvalTimeout = approval + } +} + func (hm *HookManager) Mount(reg HookRegistration) error { if hm == nil { return fmt.Errorf("hook manager is nil") @@ -232,6 +258,9 @@ func (hm *HookManager) Mount(reg HookRegistration) error { hm.mu.Lock() defer hm.mu.Unlock() + if existing, ok := hm.hooks[reg.Name]; ok { + closeHookIfPossible(existing.Hook) + } hm.hooks[reg.Name] = reg hm.rebuildOrdered() return nil @@ -245,6 +274,9 @@ func (hm *HookManager) Unmount(name string) { hm.mu.Lock() defer hm.mu.Unlock() + if existing, ok := hm.hooks[name]; ok { + closeHookIfPossible(existing.Hook) + } delete(hm.hooks, name) hm.rebuildOrdered() } @@ -425,6 +457,9 @@ func (hm *HookManager) rebuildOrdered() { hm.ordered = append(hm.ordered, reg) } sort.SliceStable(hm.ordered, func(i, j int) bool { + if hm.ordered[i].Source != hm.ordered[j].Source { + return hm.ordered[i].Source < hm.ordered[j].Source + } if hm.ordered[i].Priority == hm.ordered[j].Priority { return hm.ordered[i].Name < hm.ordered[j].Name } @@ -441,6 +476,17 @@ func (hm *HookManager) snapshotHooks() []HookRegistration { return snapshot } +func (hm *HookManager) closeAllHooks() { + hm.mu.Lock() + defer hm.mu.Unlock() + + for name, reg := range hm.hooks { + closeHookIfPossible(reg.Hook) + delete(hm.hooks, name) + } + hm.ordered = nil +} + func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) { ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout) defer cancel() @@ -749,3 +795,15 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult { } return &cloned } + +func closeHookIfPossible(hook any) { + closer, ok := hook.(io.Closer) + if !ok { + return + } + if err := closer.Close(); err != nil { + logger.WarnCF("hooks", "Failed to close hook", map[string]any{ + "error": err.Error(), + }) + } +} diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go index 6607b5fe7..e6471e9cc 100644 --- a/pkg/agent/hooks_test.go +++ b/pkg/agent/hooks_test.go @@ -47,6 +47,39 @@ func newHookTestLoop( } } +func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) { + hm := NewHookManager(nil) + defer hm.Close() + + if err := hm.Mount(HookRegistration{ + Name: "process", + Priority: -10, + Source: HookSourceProcess, + Hook: struct{}{}, + }); err != nil { + t.Fatalf("mount process hook: %v", err) + } + if err := hm.Mount(HookRegistration{ + Name: "in-process", + Priority: 100, + Source: HookSourceInProcess, + Hook: struct{}{}, + }); err != nil { + t.Fatalf("mount in-process hook: %v", err) + } + + ordered := hm.snapshotHooks() + if len(ordered) != 2 { + t.Fatalf("expected 2 hooks, got %d", len(ordered)) + } + if ordered[0].Name != "in-process" { + t.Fatalf("expected in-process hook first, got %q", ordered[0].Name) + } + if ordered[1].Name != "process" { + t.Fatalf("expected process hook second, got %q", ordered[1].Name) + } +} + type llmHookTestProvider struct { mu sync.Mutex lastModel string diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index a85abcb60..41dfdff5f 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -49,6 +49,7 @@ type AgentLoop struct { transcriber voice.Transcriber cmdRegistry *commands.Registry mcp mcpRuntime + hookRuntime hookRuntime steering *steeringQueue mu sync.RWMutex activeTurnMu sync.RWMutex @@ -122,6 +123,7 @@ func NewAgentLoop( steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)), } al.hooks = NewHookManager(eventBus) + configureHookManagerFromConfig(al.hooks, cfg) return al } @@ -259,6 +261,9 @@ func registerSharedTools( func (al *AgentLoop) Run(ctx context.Context) error { al.running.Store(true) + if err := al.ensureHooksInitialized(ctx); err != nil { + return err + } if err := al.ensureMCPInitialized(ctx); err != nil { return err } @@ -773,6 +778,9 @@ func (al *AgentLoop) ReloadProviderAndConfig( al.mu.Unlock() + al.hookRuntime.reset(al) + configureHookManagerFromConfig(al.hooks, cfg) + // Close old provider after releasing the lock // This prevents blocking readers while closing if oldProvider, ok := extractProvider(oldRegistry); ok { @@ -987,6 +995,9 @@ func (al *AgentLoop) ProcessDirectWithChannel( ctx context.Context, content, sessionKey, channel, chatID string, ) (string, error) { + if err := al.ensureHooksInitialized(ctx); err != nil { + return "", err + } if err := al.ensureMCPInitialized(ctx); err != nil { return "", err } @@ -1008,6 +1019,13 @@ func (al *AgentLoop) ProcessHeartbeat( ctx context.Context, content, channel, chatID string, ) (string, error) { + if err := al.ensureHooksInitialized(ctx); err != nil { + return "", err + } + if err := al.ensureMCPInitialized(ctx); err != nil { + return "", err + } + agent := al.GetRegistry().GetDefaultAgent() if agent == nil { return "", fmt.Errorf("no default agent for heartbeat") diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go index 77c2e0c17..55ee45ad1 100644 --- a/pkg/agent/steering.go +++ b/pkg/agent/steering.go @@ -183,6 +183,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s if active := al.GetActiveTurn(); active != nil { return "", fmt.Errorf("turn %s is still active", active.TurnID) } + if err := al.ensureHooksInitialized(ctx); err != nil { + return "", err + } + if err := al.ensureMCPInitialized(ctx); err != nil { + return "", err + } steeringMsgs := al.dequeueSteeringMessages() if len(steeringMsgs) == 0 { diff --git a/pkg/config/config.go b/pkg/config/config.go index a3720b656..a7c44c825 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -82,6 +82,7 @@ type Config struct { Providers ProvidersConfig `json:"providers,omitempty"` ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration Gateway GatewayConfig `json:"gateway"` + Hooks HooksConfig `json:"hooks,omitempty"` Tools ToolsConfig `json:"tools"` Heartbeat HeartbeatConfig `json:"heartbeat"` Devices DevicesConfig `json:"devices"` @@ -90,6 +91,36 @@ type Config struct { BuildInfo BuildInfo `json:"build_info,omitempty"` } +type HooksConfig struct { + Enabled bool `json:"enabled"` + Defaults HookDefaultsConfig `json:"defaults,omitempty"` + Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"` + Processes map[string]ProcessHookConfig `json:"processes,omitempty"` +} + +type HookDefaultsConfig struct { + ObserverTimeoutMS int `json:"observer_timeout_ms,omitempty"` + InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"` + ApprovalTimeoutMS int `json:"approval_timeout_ms,omitempty"` +} + +type BuiltinHookConfig struct { + Enabled bool `json:"enabled"` + Priority int `json:"priority,omitempty"` + Config json.RawMessage `json:"config,omitempty"` +} + +type ProcessHookConfig struct { + Enabled bool `json:"enabled"` + Priority int `json:"priority,omitempty"` + Transport string `json:"transport,omitempty"` + Command []string `json:"command,omitempty"` + Dir string `json:"dir,omitempty"` + Env map[string]string `json:"env,omitempty"` + Observe []string `json:"observe,omitempty"` + Intercept []string `json:"intercept,omitempty"` +} + // BuildInfo contains build-time version information type BuildInfo struct { Version string `json:"version"` diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index c5bdbf3c3..caab8a152 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -391,6 +391,22 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) { } } +func TestDefaultConfig_HooksDefaults(t *testing.T) { + cfg := DefaultConfig() + if !cfg.Hooks.Enabled { + t.Fatal("DefaultConfig().Hooks.Enabled should be true") + } + if cfg.Hooks.Defaults.ObserverTimeoutMS != 500 { + t.Fatalf("ObserverTimeoutMS = %d, want 500", cfg.Hooks.Defaults.ObserverTimeoutMS) + } + if cfg.Hooks.Defaults.InterceptorTimeoutMS != 5000 { + t.Fatalf("InterceptorTimeoutMS = %d, want 5000", cfg.Hooks.Defaults.InterceptorTimeoutMS) + } + if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 { + t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS) + } +} + func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) { dir := t.TempDir() configPath := filepath.Join(dir, "config.json") @@ -460,6 +476,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) { } } +func TestLoadConfig_HooksProcessConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + configJSON := `{ + "hooks": { + "processes": { + "review-gate": { + "enabled": true, + "transport": "stdio", + "command": ["uvx", "picoclaw-hook-reviewer"], + "dir": "/tmp/hooks", + "env": { + "HOOK_MODE": "rewrite" + }, + "observe": ["turn_start", "turn_end"], + "intercept": ["before_tool", "approve_tool"] + } + }, + "builtins": { + "audit": { + "enabled": true, + "priority": 5, + "config": { + "label": "audit" + } + } + } + } +}` + if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil { + t.Fatalf("os.WriteFile() error: %v", err) + } + + cfg, err := LoadConfig(configPath) + if err != nil { + t.Fatalf("LoadConfig() error: %v", err) + } + + processCfg, ok := cfg.Hooks.Processes["review-gate"] + if !ok { + t.Fatal("expected review-gate process hook") + } + if !processCfg.Enabled { + t.Fatal("expected review-gate process hook to be enabled") + } + if processCfg.Transport != "stdio" { + t.Fatalf("Transport = %q, want stdio", processCfg.Transport) + } + if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" { + t.Fatalf("Command = %v", processCfg.Command) + } + if processCfg.Dir != "/tmp/hooks" { + t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir) + } + if processCfg.Env["HOOK_MODE"] != "rewrite" { + t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"]) + } + if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" { + t.Fatalf("Observe = %v", processCfg.Observe) + } + if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" { + t.Fatalf("Intercept = %v", processCfg.Intercept) + } + + builtinCfg, ok := cfg.Hooks.Builtins["audit"] + if !ok { + t.Fatal("expected audit builtin hook") + } + if !builtinCfg.Enabled { + t.Fatal("expected audit builtin hook to be enabled") + } + if builtinCfg.Priority != 5 { + t.Fatalf("Priority = %d, want 5", builtinCfg.Priority) + } + if !strings.Contains(string(builtinCfg.Config), `"audit"`) { + t.Fatalf("Config = %s", string(builtinCfg.Config)) + } + if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 { + t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS) + } +} + // TestDefaultConfig_DMScope verifies the default dm_scope value // TestDefaultConfig_SummarizationThresholds verifies summarization defaults func TestDefaultConfig_SummarizationThresholds(t *testing.T) { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 5e6b89a4c..bfb54fb97 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -177,6 +177,14 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, }, }, + Hooks: HooksConfig{ + Enabled: true, + Defaults: HookDefaultsConfig{ + ObserverTimeoutMS: 500, + InterceptorTimeoutMS: 5000, + ApprovalTimeoutMS: 60000, + }, + }, Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true}, },