feat(agent): add configurable hook mounting

This commit is contained in:
Hoshina
2026-03-21 19:46:16 +08:00
parent cf68c91eca
commit 337e43e5a5
11 changed files with 1634 additions and 36 deletions
+317
View File
@@ -0,0 +1,317 @@
package agent
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
type hookRuntime struct {
initOnce sync.Once
mu sync.Mutex
initErr error
mounted []string
}
func (r *hookRuntime) setInitErr(err error) {
r.mu.Lock()
r.initErr = err
r.mu.Unlock()
}
func (r *hookRuntime) getInitErr() error {
r.mu.Lock()
defer r.mu.Unlock()
return r.initErr
}
func (r *hookRuntime) setMounted(names []string) {
r.mu.Lock()
r.mounted = append([]string(nil), names...)
r.mu.Unlock()
}
func (r *hookRuntime) reset(al *AgentLoop) {
r.mu.Lock()
names := append([]string(nil), r.mounted...)
r.mounted = nil
r.initErr = nil
r.initOnce = sync.Once{}
r.mu.Unlock()
for _, name := range names {
al.UnmountHook(name)
}
}
// BuiltinHookFactory constructs an in-process hook from config.
type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
var (
builtinHookRegistryMu sync.RWMutex
builtinHookRegistry = map[string]BuiltinHookFactory{}
)
// RegisterBuiltinHook registers a named in-process hook factory for config-driven mounting.
func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
if name == "" {
return fmt.Errorf("builtin hook name is required")
}
if factory == nil {
return fmt.Errorf("builtin hook %q factory is nil", name)
}
builtinHookRegistryMu.Lock()
defer builtinHookRegistryMu.Unlock()
if _, exists := builtinHookRegistry[name]; exists {
return fmt.Errorf("builtin hook %q is already registered", name)
}
builtinHookRegistry[name] = factory
return nil
}
func unregisterBuiltinHook(name string) {
if name == "" {
return
}
builtinHookRegistryMu.Lock()
delete(builtinHookRegistry, name)
builtinHookRegistryMu.Unlock()
}
func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
builtinHookRegistryMu.RLock()
defer builtinHookRegistryMu.RUnlock()
factory, ok := builtinHookRegistry[name]
return factory, ok
}
func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
if hm == nil || cfg == nil {
return
}
hm.ConfigureTimeouts(
hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS),
hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS),
hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS),
)
}
func hookTimeoutFromMS(ms int) time.Duration {
if ms <= 0 {
return 0
}
return time.Duration(ms) * time.Millisecond
}
func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
if al == nil || al.cfg == nil || al.hooks == nil {
return nil
}
al.hookRuntime.initOnce.Do(func() {
al.hookRuntime.setInitErr(al.loadConfiguredHooks(ctx))
})
return al.hookRuntime.getInitErr()
}
func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
return nil
}
mounted := make([]string, 0)
defer func() {
if err != nil {
for _, name := range mounted {
al.UnmountHook(name)
}
return
}
al.hookRuntime.setMounted(mounted)
}()
builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
for _, name := range builtinNames {
spec := al.cfg.Hooks.Builtins[name]
factory, ok := lookupBuiltinHook(name)
if !ok {
return fmt.Errorf("builtin hook %q is not registered", name)
}
hook, factoryErr := factory(ctx, spec)
if factoryErr != nil {
return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
}
if err := al.MountHook(HookRegistration{
Name: name,
Priority: spec.Priority,
Source: HookSourceInProcess,
Hook: hook,
}); err != nil {
return fmt.Errorf("mount builtin hook %q: %w", name, err)
}
mounted = append(mounted, name)
}
processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
for _, name := range processNames {
spec := al.cfg.Hooks.Processes[name]
opts, buildErr := processHookOptionsFromConfig(spec)
if buildErr != nil {
return fmt.Errorf("configure process hook %q: %w", name, buildErr)
}
processHook, buildErr := NewProcessHook(ctx, name, opts)
if buildErr != nil {
return fmt.Errorf("start process hook %q: %w", name, buildErr)
}
if err := al.MountHook(HookRegistration{
Name: name,
Priority: spec.Priority,
Source: HookSourceProcess,
Hook: processHook,
}); err != nil {
_ = processHook.Close()
return fmt.Errorf("mount process hook %q: %w", name, err)
}
mounted = append(mounted, name)
}
return nil
}
func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
if len(specs) == 0 {
return nil
}
names := make([]string, 0, len(specs))
for name, spec := range specs {
if spec.Enabled {
names = append(names, name)
}
}
sort.Strings(names)
return names
}
func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
if len(specs) == 0 {
return nil
}
names := make([]string, 0, len(specs))
for name, spec := range specs {
if spec.Enabled {
names = append(names, name)
}
}
sort.Strings(names)
return names
}
func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
transport := spec.Transport
if transport == "" {
transport = "stdio"
}
if transport != "stdio" {
return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", transport)
}
if len(spec.Command) == 0 {
return ProcessHookOptions{}, fmt.Errorf("command is required")
}
opts := ProcessHookOptions{
Command: append([]string(nil), spec.Command...),
Dir: spec.Dir,
Env: processHookEnvFromMap(spec.Env),
}
observeKinds, observeEnabled, err := processHookObserveKindsFromConfig(spec.Observe)
if err != nil {
return ProcessHookOptions{}, err
}
opts.Observe = observeEnabled
opts.ObserveKinds = observeKinds
for _, intercept := range spec.Intercept {
switch intercept {
case "before_llm", "after_llm":
opts.InterceptLLM = true
case "before_tool", "after_tool":
opts.InterceptTool = true
case "approve_tool":
opts.ApproveTool = true
case "":
continue
default:
return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", intercept)
}
}
if !opts.Observe && !opts.InterceptLLM && !opts.InterceptTool && !opts.ApproveTool {
return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled")
}
return opts, nil
}
func processHookEnvFromMap(envMap map[string]string) []string {
if len(envMap) == 0 {
return nil
}
keys := make([]string, 0, len(envMap))
for key := range envMap {
keys = append(keys, key)
}
sort.Strings(keys)
env := make([]string, 0, len(keys))
for _, key := range keys {
env = append(env, key+"="+envMap[key])
}
return env
}
func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
if len(observe) == 0 {
return nil, false, nil
}
validKinds := validHookEventKinds()
normalized := make([]string, 0, len(observe))
for _, kind := range observe {
switch kind {
case "", "*", "all":
return nil, true, nil
default:
if _, ok := validKinds[kind]; !ok {
return nil, false, fmt.Errorf("unsupported observe event %q", kind)
}
normalized = append(normalized, kind)
}
}
if len(normalized) == 0 {
return nil, false, nil
}
return normalized, true, nil
}
func validHookEventKinds() map[string]struct{} {
kinds := make(map[string]struct{}, int(eventKindCount))
for kind := EventKind(0); kind < eventKindCount; kind++ {
kinds[kind.String()] = struct{}{}
}
return kinds
}
+179
View File
@@ -0,0 +1,179 @@
package agent
import (
"context"
"encoding/json"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
type builtinAutoHookConfig struct {
Model string `json:"model"`
Suffix string `json:"suffix"`
}
type builtinAutoHook struct {
model string
suffix string
}
func (h *builtinAutoHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
next := req.Clone()
next.Model = h.model
return next, HookDecision{Action: HookActionModify}, nil
}
func (h *builtinAutoHook) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
next := resp.Clone()
if next.Response != nil {
next.Response.Content += h.suffix
}
return next, HookDecision{Action: HookActionModify}, nil
}
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
t.Helper()
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: t.TempDir(),
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
Hooks: hooks,
}
return NewAgentLoop(cfg, bus.NewMessageBus(), provider)
}
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
const hookName = "test-auto-builtin-hook"
if err := RegisterBuiltinHook(hookName, func(
ctx context.Context,
spec config.BuiltinHookConfig,
) (any, error) {
var hookCfg builtinAutoHookConfig
if len(spec.Config) > 0 {
if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
return nil, err
}
}
return &builtinAutoHook{
model: hookCfg.Model,
suffix: hookCfg.Suffix,
}, nil
}); err != nil {
t.Fatalf("RegisterBuiltinHook failed: %v", err)
}
t.Cleanup(func() {
unregisterBuiltinHook(hookName)
})
rawCfg, err := json.Marshal(builtinAutoHookConfig{
Model: "builtin-model",
Suffix: "|builtin",
})
if err != nil {
t.Fatalf("json.Marshal failed: %v", err)
}
provider := &llmHookTestProvider{}
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
Enabled: true,
Builtins: map[string]config.BuiltinHookConfig{
hookName: {
Enabled: true,
Config: rawCfg,
},
},
})
defer al.Close()
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
if resp != "provider content|builtin" {
t.Fatalf("expected builtin-hooked content, got %q", resp)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "builtin-model" {
t.Fatalf("expected builtin model, got %q", lastModel)
}
}
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
provider := &llmHookTestProvider{}
eventLog := filepath.Join(t.TempDir(), "events.log")
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
Enabled: true,
Processes: map[string]config.ProcessHookConfig{
"ipc-auto": {
Enabled: true,
Command: processHookHelperCommand(),
Env: map[string]string{
"PICOCLAW_HOOK_HELPER": "1",
"PICOCLAW_HOOK_MODE": "rewrite",
"PICOCLAW_HOOK_EVENT_LOG": eventLog,
},
Observe: []string{"turn_end"},
Intercept: []string{"before_llm", "after_llm"},
},
},
})
defer al.Close()
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
if resp != "provider content|ipc" {
t.Fatalf("expected process-hooked content, got %q", resp)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "process-model" {
t.Fatalf("expected process model, got %q", lastModel)
}
waitForFileContains(t, eventLog, "turn_end")
}
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
provider := &llmHookTestProvider{}
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
Enabled: true,
Processes: map[string]config.ProcessHookConfig{
"bad-hook": {
Enabled: true,
Command: processHookHelperCommand(),
Intercept: []string{"not_supported"},
},
},
})
defer al.Close()
_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err == nil {
t.Fatal("expected invalid configured hook error")
}
}
+511
View File
@@ -0,0 +1,511 @@
package agent
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
const (
processHookJSONRPCVersion = "2.0"
processHookReadBufferSize = 1024 * 1024
processHookCloseTimeout = 2 * time.Second
)
type ProcessHookOptions struct {
Command []string
Dir string
Env []string
Observe bool
ObserveKinds []string
InterceptLLM bool
InterceptTool bool
ApproveTool bool
}
type ProcessHook struct {
name string
opts ProcessHookOptions
cmd *exec.Cmd
stdin io.WriteCloser
observeKinds map[string]struct{}
writeMu sync.Mutex
pendingMu sync.Mutex
pending map[uint64]chan processHookRPCMessage
nextID atomic.Uint64
closed atomic.Bool
done chan struct{}
closeErr error
closeMu sync.Mutex
closeOnce sync.Once
}
type processHookRPCMessage struct {
JSONRPC string `json:"jsonrpc,omitempty"`
ID uint64 `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *processHookRPCError `json:"error,omitempty"`
}
type processHookRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type processHookHelloParams struct {
Name string `json:"name"`
Version int `json:"version"`
Modes []string `json:"modes,omitempty"`
}
type processHookDecisionResponse struct {
Action HookAction `json:"action"`
Reason string `json:"reason,omitempty"`
}
type processHookBeforeLLMResponse struct {
processHookDecisionResponse
Request *LLMHookRequest `json:"request,omitempty"`
}
type processHookAfterLLMResponse struct {
processHookDecisionResponse
Response *LLMHookResponse `json:"response,omitempty"`
}
type processHookBeforeToolResponse struct {
processHookDecisionResponse
Call *ToolCallHookRequest `json:"call,omitempty"`
}
type processHookAfterToolResponse struct {
processHookDecisionResponse
Result *ToolResultHookResponse `json:"result,omitempty"`
}
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
if len(opts.Command) == 0 {
return nil, fmt.Errorf("process hook command is required")
}
cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
cmd.Dir = opts.Dir
if len(opts.Env) > 0 {
cmd.Env = append(os.Environ(), opts.Env...)
}
stdin, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("create process hook stdin: %w", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("create process hook stdout: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("create process hook stderr: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("start process hook: %w", err)
}
ph := &ProcessHook{
name: name,
opts: opts,
cmd: cmd,
stdin: stdin,
observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
pending: make(map[uint64]chan processHookRPCMessage),
done: make(chan struct{}),
}
go ph.readLoop(stdout)
go ph.readStderr(stderr)
go ph.waitLoop()
helloCtx := ctx
if helloCtx == nil {
var cancel context.CancelFunc
helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
}
if err := ph.hello(helloCtx); err != nil {
_ = ph.Close()
return nil, err
}
return ph, nil
}
func (ph *ProcessHook) Close() error {
if ph == nil {
return nil
}
ph.closeOnce.Do(func() {
ph.closed.Store(true)
if ph.stdin != nil {
_ = ph.stdin.Close()
}
select {
case <-ph.done:
case <-time.After(processHookCloseTimeout):
if ph.cmd != nil && ph.cmd.Process != nil {
_ = ph.cmd.Process.Kill()
}
<-ph.done
}
})
ph.closeMu.Lock()
defer ph.closeMu.Unlock()
return ph.closeErr
}
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
if ph == nil || !ph.opts.Observe {
return nil
}
if len(ph.observeKinds) > 0 {
if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
return nil
}
}
return ph.notify(ctx, "hook.event", evt)
}
func (ph *ProcessHook) BeforeLLM(
ctx context.Context,
req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
if ph == nil || !ph.opts.InterceptLLM {
return req, HookDecision{Action: HookActionContinue}, nil
}
var resp processHookBeforeLLMResponse
if err := ph.call(ctx, "hook.before_llm", req, &resp); err != nil {
return nil, HookDecision{}, err
}
if resp.Request == nil {
resp.Request = req
}
return resp.Request, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
}
func (ph *ProcessHook) AfterLLM(
ctx context.Context,
resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
if ph == nil || !ph.opts.InterceptLLM {
return resp, HookDecision{Action: HookActionContinue}, nil
}
var result processHookAfterLLMResponse
if err := ph.call(ctx, "hook.after_llm", resp, &result); err != nil {
return nil, HookDecision{}, err
}
if result.Response == nil {
result.Response = resp
}
return result.Response, HookDecision{Action: result.Action, Reason: result.Reason}, nil
}
func (ph *ProcessHook) BeforeTool(
ctx context.Context,
call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
if ph == nil || !ph.opts.InterceptTool {
return call, HookDecision{Action: HookActionContinue}, nil
}
var resp processHookBeforeToolResponse
if err := ph.call(ctx, "hook.before_tool", call, &resp); err != nil {
return nil, HookDecision{}, err
}
if resp.Call == nil {
resp.Call = call
}
return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
}
func (ph *ProcessHook) AfterTool(
ctx context.Context,
result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
if ph == nil || !ph.opts.InterceptTool {
return result, HookDecision{Action: HookActionContinue}, nil
}
var resp processHookAfterToolResponse
if err := ph.call(ctx, "hook.after_tool", result, &resp); err != nil {
return nil, HookDecision{}, err
}
if resp.Result == nil {
resp.Result = result
}
return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
}
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
if ph == nil || !ph.opts.ApproveTool {
return ApprovalDecision{Approved: true}, nil
}
var resp ApprovalDecision
if err := ph.call(ctx, "hook.approve_tool", req, &resp); err != nil {
return ApprovalDecision{}, err
}
return resp, nil
}
func (ph *ProcessHook) hello(ctx context.Context) error {
modes := make([]string, 0, 4)
if ph.opts.Observe {
modes = append(modes, "observe")
}
if ph.opts.InterceptLLM {
modes = append(modes, "llm")
}
if ph.opts.InterceptTool {
modes = append(modes, "tool")
}
if ph.opts.ApproveTool {
modes = append(modes, "approve")
}
var result map[string]any
return ph.call(ctx, "hook.hello", processHookHelloParams{
Name: ph.name,
Version: 1,
Modes: modes,
}, &result)
}
func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
msg := processHookRPCMessage{
JSONRPC: processHookJSONRPCVersion,
Method: method,
}
if params != nil {
body, err := json.Marshal(params)
if err != nil {
return err
}
msg.Params = body
}
return ph.send(ctx, msg)
}
func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
if ph.closed.Load() {
return fmt.Errorf("process hook %q is closed", ph.name)
}
id := ph.nextID.Add(1)
respCh := make(chan processHookRPCMessage, 1)
ph.pendingMu.Lock()
ph.pending[id] = respCh
ph.pendingMu.Unlock()
msg := processHookRPCMessage{
JSONRPC: processHookJSONRPCVersion,
ID: id,
Method: method,
}
if params != nil {
body, err := json.Marshal(params)
if err != nil {
ph.removePending(id)
return err
}
msg.Params = body
}
if err := ph.send(ctx, msg); err != nil {
ph.removePending(id)
return err
}
select {
case resp, ok := <-respCh:
if !ok {
return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
}
if resp.Error != nil {
return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
}
if out != nil && len(resp.Result) > 0 {
if err := json.Unmarshal(resp.Result, out); err != nil {
return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
}
}
return nil
case <-ctx.Done():
ph.removePending(id)
return ctx.Err()
}
}
func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
body, err := json.Marshal(msg)
if err != nil {
return err
}
body = append(body, '\n')
ph.writeMu.Lock()
defer ph.writeMu.Unlock()
if ph.closed.Load() {
return fmt.Errorf("process hook %q is closed", ph.name)
}
done := make(chan error, 1)
go func() {
_, writeErr := ph.stdin.Write(body)
done <- writeErr
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("write process hook %q message: %w", ph.name, err)
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (ph *ProcessHook) readLoop(stdout io.Reader) {
scanner := bufio.NewScanner(stdout)
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
for scanner.Scan() {
var msg processHookRPCMessage
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
"hook": ph.name,
"error": err.Error(),
})
continue
}
if msg.ID == 0 {
continue
}
ph.pendingMu.Lock()
respCh, ok := ph.pending[msg.ID]
if ok {
delete(ph.pending, msg.ID)
}
ph.pendingMu.Unlock()
if ok {
respCh <- msg
close(respCh)
}
}
}
func (ph *ProcessHook) readStderr(stderr io.Reader) {
scanner := bufio.NewScanner(stderr)
scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
for scanner.Scan() {
logger.WarnCF("hooks", "Process hook stderr", map[string]any{
"hook": ph.name,
"stderr": scanner.Text(),
})
}
}
func (ph *ProcessHook) waitLoop() {
err := ph.cmd.Wait()
ph.closeMu.Lock()
ph.closeErr = err
ph.closeMu.Unlock()
ph.failPending(err)
close(ph.done)
}
func (ph *ProcessHook) failPending(err error) {
ph.pendingMu.Lock()
defer ph.pendingMu.Unlock()
msg := processHookRPCMessage{
Error: &processHookRPCError{
Code: -32000,
Message: "process exited",
},
}
if err != nil {
msg.Error.Message = err.Error()
}
for id, ch := range ph.pending {
delete(ph.pending, id)
ch <- msg
close(ch)
}
}
func (ph *ProcessHook) removePending(id uint64) {
ph.pendingMu.Lock()
defer ph.pendingMu.Unlock()
if ch, ok := ph.pending[id]; ok {
delete(ph.pending, id)
close(ch)
}
}
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
if al == nil {
return fmt.Errorf("agent loop is nil")
}
processHook, err := NewProcessHook(ctx, name, opts)
if err != nil {
return err
}
if err := al.MountHook(HookRegistration{
Name: name,
Source: HookSourceProcess,
Hook: processHook,
}); err != nil {
_ = processHook.Close()
return err
}
return nil
}
func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
if len(kinds) == 0 {
return nil
}
normalized := make(map[string]struct{}, len(kinds))
for _, kind := range kinds {
if kind == "" {
continue
}
normalized[kind] = struct{}{}
}
if len(normalized) == 0 {
return nil
}
return normalized
}
+339
View File
@@ -0,0 +1,339 @@
package agent
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
func TestProcessHook_HelperProcess(t *testing.T) {
if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
return
}
if err := runProcessHookHelper(); err != nil {
fmt.Fprintln(os.Stderr, err.Error())
os.Exit(1)
}
os.Exit(0)
}
func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
provider := &llmHookTestProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
eventLog := filepath.Join(t.TempDir(), "events.log")
if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{
Command: processHookHelperCommand(),
Env: processHookHelperEnv("rewrite", eventLog),
Observe: true,
InterceptLLM: true,
}); err != nil {
t.Fatalf("MountProcessHook failed: %v", err)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "hello",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "provider content|ipc" {
t.Fatalf("expected process-hooked llm content, got %q", resp)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "process-model" {
t.Fatalf("expected process model, got %q", lastModel)
}
waitForFileContains(t, eventLog, "turn_end")
}
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
provider := &toolHookProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
al.RegisterTool(&echoTextTool{})
if err := al.MountProcessHook(context.Background(), "ipc-tool", ProcessHookOptions{
Command: processHookHelperCommand(),
Env: processHookHelperEnv("rewrite", ""),
InterceptTool: true,
}); err != nil {
t.Fatalf("MountProcessHook failed: %v", err)
}
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "ipc:ipc" {
t.Fatalf("expected rewritten process-hook tool result, got %q", resp)
}
}
type blockedToolProvider struct {
calls int
}
func (p *blockedToolProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.calls++
if p.calls == 1 {
return &providers.LLMResponse{
ToolCalls: []providers.ToolCall{
{
ID: "call-1",
Name: "blocked_tool",
Arguments: map[string]any{},
},
},
}, nil
}
return &providers.LLMResponse{
Content: messages[len(messages)-1].Content,
}, nil
}
func (p *blockedToolProvider) GetDefaultModel() string {
return "blocked-tool-provider"
}
func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
provider := &blockedToolProvider{}
al, agent, cleanup := newHookTestLoop(t, provider)
defer cleanup()
if err := al.MountProcessHook(context.Background(), "ipc-approval", ProcessHookOptions{
Command: processHookHelperCommand(),
Env: processHookHelperEnv("deny", ""),
ApproveTool: true,
}); err != nil {
t.Fatalf("MountProcessHook failed: %v", err)
}
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run blocked tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
expected := "Tool execution denied by approval hook: blocked by ipc hook"
if resp != expected {
t.Fatalf("expected %q, got %q", expected, resp)
}
events := collectEventStream(sub.C)
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
if !ok {
t.Fatal("expected tool skipped event")
}
payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
}
if payload.Reason != expected {
t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
}
}
func processHookHelperCommand() []string {
return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
}
func processHookHelperEnv(mode, eventLog string) []string {
env := []string{
"PICOCLAW_HOOK_HELPER=1",
"PICOCLAW_HOOK_MODE=" + mode,
}
if eventLog != "" {
env = append(env, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
}
return env
}
func waitForFileContains(t *testing.T, path, substring string) {
t.Helper()
deadline := time.Now().Add(3 * time.Second)
for time.Now().Before(deadline) {
data, err := os.ReadFile(path)
if err == nil && strings.Contains(string(data), substring) {
return
}
time.Sleep(20 * time.Millisecond)
}
data, _ := os.ReadFile(path)
t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(data))
}
func runProcessHookHelper() error {
mode := os.Getenv("PICOCLAW_HOOK_MODE")
eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")
scanner := bufio.NewScanner(os.Stdin)
scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
encoder := json.NewEncoder(os.Stdout)
for scanner.Scan() {
var msg processHookRPCMessage
if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
return err
}
if msg.ID == 0 {
if msg.Method == "hook.event" && eventLog != "" {
var evt map[string]any
if err := json.Unmarshal(msg.Params, &evt); err == nil {
if rawKind, ok := evt["Kind"].(float64); ok {
kind := EventKind(rawKind)
_ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
}
}
}
continue
}
result, rpcErr := handleProcessHookRequest(mode, msg)
resp := processHookRPCMessage{
JSONRPC: processHookJSONRPCVersion,
ID: msg.ID,
}
if rpcErr != nil {
resp.Error = rpcErr
} else if result != nil {
body, err := json.Marshal(result)
if err != nil {
return err
}
resp.Result = body
} else {
resp.Result = []byte("{}")
}
if err := encoder.Encode(resp); err != nil {
return err
}
}
return scanner.Err()
}
func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
switch msg.Method {
case "hook.hello":
return map[string]any{"ok": true}, nil
case "hook.before_llm":
if mode != "rewrite" {
return map[string]any{"action": HookActionContinue}, nil
}
var req map[string]any
_ = json.Unmarshal(msg.Params, &req)
req["model"] = "process-model"
return map[string]any{
"action": HookActionModify,
"request": req,
}, nil
case "hook.after_llm":
if mode != "rewrite" {
return map[string]any{"action": HookActionContinue}, nil
}
var resp map[string]any
_ = json.Unmarshal(msg.Params, &resp)
if rawResponse, ok := resp["response"].(map[string]any); ok {
if content, ok := rawResponse["content"].(string); ok {
rawResponse["content"] = content + "|ipc"
}
}
return map[string]any{
"action": HookActionModify,
"response": resp,
}, nil
case "hook.before_tool":
if mode != "rewrite" {
return map[string]any{"action": HookActionContinue}, nil
}
var call map[string]any
_ = json.Unmarshal(msg.Params, &call)
rawArgs, ok := call["arguments"].(map[string]any)
if !ok || rawArgs == nil {
rawArgs = map[string]any{}
}
rawArgs["text"] = "ipc"
call["arguments"] = rawArgs
return map[string]any{
"action": HookActionModify,
"call": call,
}, nil
case "hook.after_tool":
if mode != "rewrite" {
return map[string]any{"action": HookActionContinue}, nil
}
var result map[string]any
_ = json.Unmarshal(msg.Params, &result)
if rawResult, ok := result["result"].(map[string]any); ok {
if forLLM, ok := rawResult["for_llm"].(string); ok {
rawResult["for_llm"] = "ipc:" + forLLM
}
}
return map[string]any{
"action": HookActionModify,
"result": result,
}, nil
case "hook.approve_tool":
if mode == "deny" {
return ApprovalDecision{
Approved: false,
Reason: "blocked by ipc hook",
}, nil
}
return ApprovalDecision{Approved: true}, nil
default:
return nil, &processHookRPCError{
Code: -32601,
Message: "method not found",
}
}
}
+94 -36
View File
@@ -3,6 +3,7 @@ package agent
import (
"context"
"fmt"
"io"
"sort"
"sync"
"time"
@@ -30,8 +31,8 @@ const (
)
type HookDecision struct {
Action HookAction
Reason string
Action HookAction `json:"action"`
Reason string `json:"reason,omitempty"`
}
func (d HookDecision) normalizedAction() HookAction {
@@ -42,20 +43,29 @@ func (d HookDecision) normalizedAction() HookAction {
}
type ApprovalDecision struct {
Approved bool
Reason string
Approved bool `json:"approved"`
Reason string `json:"reason,omitempty"`
}
type HookSource uint8
const (
HookSourceInProcess HookSource = iota
HookSourceProcess
)
type HookRegistration struct {
Name string
Priority int
Source HookSource
Hook any
}
func NamedHook(name string, hook any) HookRegistration {
return HookRegistration{
Name: name,
Hook: hook,
Name: name,
Source: HookSourceInProcess,
Hook: hook,
}
}
@@ -78,14 +88,14 @@ type ToolApprover interface {
}
type LLMHookRequest struct {
Meta EventMeta
Model string
Messages []providers.Message
Tools []providers.ToolDefinition
Options map[string]any
Channel string
ChatID string
GracefulTerminal bool
Meta EventMeta `json:"meta"`
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
Tools []providers.ToolDefinition `json:"tools,omitempty"`
Options map[string]any `json:"options,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
}
func (r *LLMHookRequest) Clone() *LLMHookRequest {
@@ -100,11 +110,11 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
}
type LLMHookResponse struct {
Meta EventMeta
Model string
Response *providers.LLMResponse
Channel string
ChatID string
Meta EventMeta `json:"meta"`
Model string `json:"model"`
Response *providers.LLMResponse `json:"response,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *LLMHookResponse) Clone() *LLMHookResponse {
@@ -117,11 +127,11 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
}
type ToolCallHookRequest struct {
Meta EventMeta
Tool string
Arguments map[string]any
Channel string
ChatID string
Meta EventMeta `json:"meta"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
@@ -134,11 +144,11 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
}
type ToolApprovalRequest struct {
Meta EventMeta
Tool string
Arguments map[string]any
Channel string
ChatID string
Meta EventMeta `json:"meta"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
@@ -151,13 +161,13 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
}
type ToolResultHookResponse struct {
Meta EventMeta
Tool string
Arguments map[string]any
Result *tools.ToolResult
Duration time.Duration
Channel string
ChatID string
Meta EventMeta `json:"meta"`
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
Result *tools.ToolResult `json:"result,omitempty"`
Duration time.Duration `json:"duration"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
@@ -215,9 +225,25 @@ func (hm *HookManager) Close() {
hm.eventBus.Unsubscribe(hm.sub.ID)
}
<-hm.done
hm.closeAllHooks()
})
}
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
if hm == nil {
return
}
if observer > 0 {
hm.observerTimeout = observer
}
if interceptor > 0 {
hm.interceptorTimeout = interceptor
}
if approval > 0 {
hm.approvalTimeout = approval
}
}
func (hm *HookManager) Mount(reg HookRegistration) error {
if hm == nil {
return fmt.Errorf("hook manager is nil")
@@ -232,6 +258,9 @@ func (hm *HookManager) Mount(reg HookRegistration) error {
hm.mu.Lock()
defer hm.mu.Unlock()
if existing, ok := hm.hooks[reg.Name]; ok {
closeHookIfPossible(existing.Hook)
}
hm.hooks[reg.Name] = reg
hm.rebuildOrdered()
return nil
@@ -245,6 +274,9 @@ func (hm *HookManager) Unmount(name string) {
hm.mu.Lock()
defer hm.mu.Unlock()
if existing, ok := hm.hooks[name]; ok {
closeHookIfPossible(existing.Hook)
}
delete(hm.hooks, name)
hm.rebuildOrdered()
}
@@ -425,6 +457,9 @@ func (hm *HookManager) rebuildOrdered() {
hm.ordered = append(hm.ordered, reg)
}
sort.SliceStable(hm.ordered, func(i, j int) bool {
if hm.ordered[i].Source != hm.ordered[j].Source {
return hm.ordered[i].Source < hm.ordered[j].Source
}
if hm.ordered[i].Priority == hm.ordered[j].Priority {
return hm.ordered[i].Name < hm.ordered[j].Name
}
@@ -441,6 +476,17 @@ func (hm *HookManager) snapshotHooks() []HookRegistration {
return snapshot
}
func (hm *HookManager) closeAllHooks() {
hm.mu.Lock()
defer hm.mu.Unlock()
for name, reg := range hm.hooks {
closeHookIfPossible(reg.Hook)
delete(hm.hooks, name)
}
hm.ordered = nil
}
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
defer cancel()
@@ -749,3 +795,15 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
}
return &cloned
}
func closeHookIfPossible(hook any) {
closer, ok := hook.(io.Closer)
if !ok {
return
}
if err := closer.Close(); err != nil {
logger.WarnCF("hooks", "Failed to close hook", map[string]any{
"error": err.Error(),
})
}
}
+33
View File
@@ -47,6 +47,39 @@ func newHookTestLoop(
}
}
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
hm := NewHookManager(nil)
defer hm.Close()
if err := hm.Mount(HookRegistration{
Name: "process",
Priority: -10,
Source: HookSourceProcess,
Hook: struct{}{},
}); err != nil {
t.Fatalf("mount process hook: %v", err)
}
if err := hm.Mount(HookRegistration{
Name: "in-process",
Priority: 100,
Source: HookSourceInProcess,
Hook: struct{}{},
}); err != nil {
t.Fatalf("mount in-process hook: %v", err)
}
ordered := hm.snapshotHooks()
if len(ordered) != 2 {
t.Fatalf("expected 2 hooks, got %d", len(ordered))
}
if ordered[0].Name != "in-process" {
t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
}
if ordered[1].Name != "process" {
t.Fatalf("expected process hook second, got %q", ordered[1].Name)
}
}
type llmHookTestProvider struct {
mu sync.Mutex
lastModel string
+18
View File
@@ -49,6 +49,7 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
hookRuntime hookRuntime
steering *steeringQueue
mu sync.RWMutex
activeTurnMu sync.RWMutex
@@ -122,6 +123,7 @@ func NewAgentLoop(
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.hooks = NewHookManager(eventBus)
configureHookManagerFromConfig(al.hooks, cfg)
return al
}
@@ -259,6 +261,9 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
if err := al.ensureHooksInitialized(ctx); err != nil {
return err
}
if err := al.ensureMCPInitialized(ctx); err != nil {
return err
}
@@ -773,6 +778,9 @@ func (al *AgentLoop) ReloadProviderAndConfig(
al.mu.Unlock()
al.hookRuntime.reset(al)
configureHookManagerFromConfig(al.hooks, cfg)
// Close old provider after releasing the lock
// This prevents blocking readers while closing
if oldProvider, ok := extractProvider(oldRegistry); ok {
@@ -987,6 +995,9 @@ func (al *AgentLoop) ProcessDirectWithChannel(
ctx context.Context,
content, sessionKey, channel, chatID string,
) (string, error) {
if err := al.ensureHooksInitialized(ctx); err != nil {
return "", err
}
if err := al.ensureMCPInitialized(ctx); err != nil {
return "", err
}
@@ -1008,6 +1019,13 @@ func (al *AgentLoop) ProcessHeartbeat(
ctx context.Context,
content, channel, chatID string,
) (string, error) {
if err := al.ensureHooksInitialized(ctx); err != nil {
return "", err
}
if err := al.ensureMCPInitialized(ctx); err != nil {
return "", err
}
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
+6
View File
@@ -183,6 +183,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
if err := al.ensureHooksInitialized(ctx); err != nil {
return "", err
}
if err := al.ensureMCPInitialized(ctx); err != nil {
return "", err
}
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
+31
View File
@@ -82,6 +82,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers,omitempty"`
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway"`
Hooks HooksConfig `json:"hooks,omitempty"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
Devices DevicesConfig `json:"devices"`
@@ -90,6 +91,36 @@ type Config struct {
BuildInfo BuildInfo `json:"build_info,omitempty"`
}
type HooksConfig struct {
Enabled bool `json:"enabled"`
Defaults HookDefaultsConfig `json:"defaults,omitempty"`
Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"`
Processes map[string]ProcessHookConfig `json:"processes,omitempty"`
}
type HookDefaultsConfig struct {
ObserverTimeoutMS int `json:"observer_timeout_ms,omitempty"`
InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"`
ApprovalTimeoutMS int `json:"approval_timeout_ms,omitempty"`
}
type BuiltinHookConfig struct {
Enabled bool `json:"enabled"`
Priority int `json:"priority,omitempty"`
Config json.RawMessage `json:"config,omitempty"`
}
type ProcessHookConfig struct {
Enabled bool `json:"enabled"`
Priority int `json:"priority,omitempty"`
Transport string `json:"transport,omitempty"`
Command []string `json:"command,omitempty"`
Dir string `json:"dir,omitempty"`
Env map[string]string `json:"env,omitempty"`
Observe []string `json:"observe,omitempty"`
Intercept []string `json:"intercept,omitempty"`
}
// BuildInfo contains build-time version information
type BuildInfo struct {
Version string `json:"version"`
+98
View File
@@ -391,6 +391,22 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
}
}
func TestDefaultConfig_HooksDefaults(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Hooks.Enabled {
t.Fatal("DefaultConfig().Hooks.Enabled should be true")
}
if cfg.Hooks.Defaults.ObserverTimeoutMS != 500 {
t.Fatalf("ObserverTimeoutMS = %d, want 500", cfg.Hooks.Defaults.ObserverTimeoutMS)
}
if cfg.Hooks.Defaults.InterceptorTimeoutMS != 5000 {
t.Fatalf("InterceptorTimeoutMS = %d, want 5000", cfg.Hooks.Defaults.InterceptorTimeoutMS)
}
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
}
}
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -460,6 +476,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
}
func TestLoadConfig_HooksProcessConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
configJSON := `{
"hooks": {
"processes": {
"review-gate": {
"enabled": true,
"transport": "stdio",
"command": ["uvx", "picoclaw-hook-reviewer"],
"dir": "/tmp/hooks",
"env": {
"HOOK_MODE": "rewrite"
},
"observe": ["turn_start", "turn_end"],
"intercept": ["before_tool", "approve_tool"]
}
},
"builtins": {
"audit": {
"enabled": true,
"priority": 5,
"config": {
"label": "audit"
}
}
}
}
}`
if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
t.Fatalf("os.WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
processCfg, ok := cfg.Hooks.Processes["review-gate"]
if !ok {
t.Fatal("expected review-gate process hook")
}
if !processCfg.Enabled {
t.Fatal("expected review-gate process hook to be enabled")
}
if processCfg.Transport != "stdio" {
t.Fatalf("Transport = %q, want stdio", processCfg.Transport)
}
if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" {
t.Fatalf("Command = %v", processCfg.Command)
}
if processCfg.Dir != "/tmp/hooks" {
t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir)
}
if processCfg.Env["HOOK_MODE"] != "rewrite" {
t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"])
}
if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" {
t.Fatalf("Observe = %v", processCfg.Observe)
}
if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" {
t.Fatalf("Intercept = %v", processCfg.Intercept)
}
builtinCfg, ok := cfg.Hooks.Builtins["audit"]
if !ok {
t.Fatal("expected audit builtin hook")
}
if !builtinCfg.Enabled {
t.Fatal("expected audit builtin hook to be enabled")
}
if builtinCfg.Priority != 5 {
t.Fatalf("Priority = %d, want 5", builtinCfg.Priority)
}
if !strings.Contains(string(builtinCfg.Config), `"audit"`) {
t.Fatalf("Config = %s", string(builtinCfg.Config))
}
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
}
}
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
+8
View File
@@ -177,6 +177,14 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
},
Hooks: HooksConfig{
Enabled: true,
Defaults: HookDefaultsConfig{
ObserverTimeoutMS: 500,
InterceptorTimeoutMS: 5000,
ApprovalTimeoutMS: 60000,
},
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{WebSearch: true},
},