mirror of
https://github.com/sipeed/picoclaw.git
synced 2026-06-12 18:08:54 +00:00
Merge pull request #1863 from alexhoshina/feat/hook-manager
Feat/hook manager
This commit is contained in:
+280
-47
@@ -40,6 +40,7 @@ type AgentLoop struct {
|
||||
registry *AgentRegistry
|
||||
state *state.Manager
|
||||
eventBus *EventBus
|
||||
hooks *HookManager
|
||||
running atomic.Bool
|
||||
summarizing sync.Map
|
||||
fallback *providers.FallbackChain
|
||||
@@ -48,6 +49,7 @@ type AgentLoop struct {
|
||||
transcriber voice.Transcriber
|
||||
cmdRegistry *commands.Registry
|
||||
mcp mcpRuntime
|
||||
hookRuntime hookRuntime
|
||||
steering *steeringQueue
|
||||
mu sync.RWMutex
|
||||
activeTurnMu sync.RWMutex
|
||||
@@ -109,17 +111,20 @@ func NewAgentLoop(
|
||||
stateManager = state.NewManager(defaultAgent.Workspace)
|
||||
}
|
||||
|
||||
eventBus := NewEventBus()
|
||||
al := &AgentLoop{
|
||||
bus: msgBus,
|
||||
cfg: cfg,
|
||||
registry: registry,
|
||||
state: stateManager,
|
||||
eventBus: NewEventBus(),
|
||||
eventBus: eventBus,
|
||||
summarizing: sync.Map{},
|
||||
fallback: fallbackChain,
|
||||
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
|
||||
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
|
||||
}
|
||||
al.hooks = NewHookManager(eventBus)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
|
||||
return al
|
||||
}
|
||||
@@ -257,6 +262,9 @@ func registerSharedTools(
|
||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
al.running.Store(true)
|
||||
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -512,11 +520,30 @@ func (al *AgentLoop) Close() {
|
||||
}
|
||||
|
||||
al.GetRegistry().Close()
|
||||
if al.hooks != nil {
|
||||
al.hooks.Close()
|
||||
}
|
||||
if al.eventBus != nil {
|
||||
al.eventBus.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// MountHook registers an in-process hook on the agent loop.
|
||||
func (al *AgentLoop) MountHook(reg HookRegistration) error {
|
||||
if al == nil || al.hooks == nil {
|
||||
return fmt.Errorf("hook manager is not initialized")
|
||||
}
|
||||
return al.hooks.Mount(reg)
|
||||
}
|
||||
|
||||
// UnmountHook removes a previously registered in-process hook.
|
||||
func (al *AgentLoop) UnmountHook(name string) {
|
||||
if al == nil || al.hooks == nil {
|
||||
return
|
||||
}
|
||||
al.hooks.Unmount(name)
|
||||
}
|
||||
|
||||
// SubscribeEvents registers a subscriber for agent-loop events.
|
||||
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
|
||||
if al == nil || al.eventBus == nil {
|
||||
@@ -596,6 +623,31 @@ func cloneEventArguments(args map[string]any) map[string]any {
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error {
|
||||
reason := decision.Reason
|
||||
if reason == "" {
|
||||
reason = "hook requested turn abort"
|
||||
}
|
||||
|
||||
err := fmt.Errorf("hook aborted turn during %s: %s", stage, reason)
|
||||
al.emitEvent(
|
||||
EventKindError,
|
||||
ts.eventMeta("hooks", "turn.error"),
|
||||
ErrorPayload{
|
||||
Stage: "hook." + stage,
|
||||
Message: err.Error(),
|
||||
},
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func hookDeniedToolContent(prefix, reason string) string {
|
||||
if reason == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + ": " + reason
|
||||
}
|
||||
|
||||
func (al *AgentLoop) logEvent(evt Event) {
|
||||
fields := map[string]any{
|
||||
"event_kind": evt.Kind.String(),
|
||||
@@ -778,6 +830,9 @@ func (al *AgentLoop) ReloadProviderAndConfig(
|
||||
|
||||
al.mu.Unlock()
|
||||
|
||||
al.hookRuntime.reset(al)
|
||||
configureHookManagerFromConfig(al.hooks, cfg)
|
||||
|
||||
// Close old provider after releasing the lock
|
||||
// This prevents blocking readers while closing
|
||||
if oldProvider, ok := extractProvider(oldRegistry); ok {
|
||||
@@ -992,6 +1047,9 @@ func (al *AgentLoop) ProcessDirectWithChannel(
|
||||
ctx context.Context,
|
||||
content, sessionKey, channel, chatID string,
|
||||
) (string, error) {
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1013,6 +1071,13 @@ func (al *AgentLoop) ProcessHeartbeat(
|
||||
ctx context.Context,
|
||||
content, channel, chatID string,
|
||||
) (string, error) {
|
||||
if err := al.ensureHooksInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := al.ensureMCPInitialized(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
agent := al.GetRegistry().GetDefaultAgent()
|
||||
if agent == nil {
|
||||
return "", fmt.Errorf("no default agent for heartbeat")
|
||||
@@ -1504,36 +1569,6 @@ turnLoop:
|
||||
ts.markGracefulTerminalUsed()
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: activeModel,
|
||||
MessagesCount: len(callMessages),
|
||||
ToolsCount: len(providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"messages_count": len(callMessages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(callMessages),
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
llmOpts := map[string]any{
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
@@ -1548,6 +1583,66 @@ turnLoop:
|
||||
}
|
||||
}
|
||||
|
||||
llmModel := activeModel
|
||||
if al.hooks != nil {
|
||||
llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
Model: llmModel,
|
||||
Messages: callMessages,
|
||||
Tools: providerToolDefs,
|
||||
Options: llmOpts,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
GracefulTerminal: gracefulTerminal,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmReq != nil {
|
||||
llmModel = llmReq.Model
|
||||
callMessages = llmReq.Messages
|
||||
providerToolDefs = llmReq.Tools
|
||||
llmOpts = llmReq.Options
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "before_llm", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
al.emitEvent(
|
||||
EventKindLLMRequest,
|
||||
ts.eventMeta("runTurn", "turn.llm.request"),
|
||||
LLMRequestPayload{
|
||||
Model: llmModel,
|
||||
MessagesCount: len(callMessages),
|
||||
ToolsCount: len(providerToolDefs),
|
||||
MaxTokens: ts.agent.MaxTokens,
|
||||
Temperature: ts.agent.Temperature,
|
||||
},
|
||||
)
|
||||
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": llmModel,
|
||||
"messages_count": len(callMessages),
|
||||
"tools_count": len(providerToolDefs),
|
||||
"max_tokens": ts.agent.MaxTokens,
|
||||
"temperature": ts.agent.Temperature,
|
||||
"system_prompt_len": len(callMessages[0].Content),
|
||||
})
|
||||
logger.DebugCF("agent", "Full LLM request",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"messages_json": formatMessagesForLog(callMessages),
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
|
||||
providerCtx, providerCancel := context.WithCancel(turnCtx)
|
||||
ts.setProviderCancel(providerCancel)
|
||||
@@ -1580,7 +1675,7 @@ turnLoop:
|
||||
}
|
||||
return fbResult.Response, nil
|
||||
}
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts)
|
||||
return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
|
||||
}
|
||||
|
||||
var response *providers.LLMResponse
|
||||
@@ -1712,12 +1807,35 @@ turnLoop:
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"iteration": iteration,
|
||||
"model": activeModel,
|
||||
"model": llmModel,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.llm.response"),
|
||||
Model: llmModel,
|
||||
Response: response,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if llmResp != nil && llmResp.Response != nil {
|
||||
response = llmResp.Response
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "after_llm", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
go al.handleReasoning(
|
||||
turnCtx,
|
||||
response.Reasoning,
|
||||
@@ -1825,25 +1943,106 @@ turnLoop:
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
toolName := tc.Name
|
||||
toolArgs := cloneStringAnyMap(tc.Arguments)
|
||||
|
||||
if al.hooks != nil {
|
||||
toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.before"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolReq != nil {
|
||||
toolName = toolReq.Tool
|
||||
toolArgs = toolReq.Arguments
|
||||
}
|
||||
case HookActionDenyTool:
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "before_tool", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.approve"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
if !approval.Approved {
|
||||
denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
|
||||
al.emitEvent(
|
||||
EventKindToolExecSkipped,
|
||||
ts.eventMeta("runTurn", "turn.tool.skipped"),
|
||||
ToolExecSkippedPayload{
|
||||
Tool: toolName,
|
||||
Reason: denyContent,
|
||||
},
|
||||
)
|
||||
deniedMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: denyContent,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, deniedMsg)
|
||||
if !ts.opts.NoHistory {
|
||||
ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
|
||||
ts.recordPersistedMessage(deniedMsg)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
argsJSON, _ := json.Marshal(toolArgs)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview),
|
||||
map[string]any{
|
||||
"agent_id": ts.agent.ID,
|
||||
"tool": tc.Name,
|
||||
"tool": toolName,
|
||||
"iteration": iteration,
|
||||
})
|
||||
al.emitEvent(
|
||||
EventKindToolExecStart,
|
||||
ts.eventMeta("runTurn", "turn.tool.start"),
|
||||
ToolExecStartPayload{
|
||||
Tool: tc.Name,
|
||||
Arguments: cloneEventArguments(tc.Arguments),
|
||||
Tool: toolName,
|
||||
Arguments: cloneEventArguments(toolArgs),
|
||||
},
|
||||
)
|
||||
|
||||
toolCall := tc
|
||||
toolCallID := tc.ID
|
||||
toolIteration := iteration
|
||||
asyncToolName := toolName
|
||||
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -1865,7 +2064,7 @@ turnLoop:
|
||||
|
||||
logger.InfoCF("agent", "Async tool completed, publishing result",
|
||||
map[string]any{
|
||||
"tool": toolCall.Name,
|
||||
"tool": asyncToolName,
|
||||
"content_len": len(content),
|
||||
"channel": ts.channel,
|
||||
})
|
||||
@@ -1873,7 +2072,7 @@ turnLoop:
|
||||
EventKindFollowUpQueued,
|
||||
ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
|
||||
FollowUpQueuedPayload{
|
||||
SourceTool: toolCall.Name,
|
||||
SourceTool: asyncToolName,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
ContentLen: len(content),
|
||||
@@ -1884,7 +2083,7 @@ turnLoop:
|
||||
defer pubCancel()
|
||||
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("async:%s", toolCall.Name),
|
||||
SenderID: fmt.Sprintf("async:%s", asyncToolName),
|
||||
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
|
||||
Content: content,
|
||||
})
|
||||
@@ -1893,8 +2092,8 @@ turnLoop:
|
||||
toolStart := time.Now()
|
||||
toolResult := ts.agent.Tools.ExecuteWithContext(
|
||||
turnCtx,
|
||||
toolCall.Name,
|
||||
toolCall.Arguments,
|
||||
toolName,
|
||||
toolArgs,
|
||||
ts.channel,
|
||||
ts.chatID,
|
||||
asyncCallback,
|
||||
@@ -1906,6 +2105,40 @@ turnLoop:
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
|
||||
if al.hooks != nil {
|
||||
toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{
|
||||
Meta: ts.eventMeta("runTurn", "turn.tool.after"),
|
||||
Tool: toolName,
|
||||
Arguments: toolArgs,
|
||||
Result: toolResult,
|
||||
Duration: toolDuration,
|
||||
Channel: ts.channel,
|
||||
ChatID: ts.chatID,
|
||||
})
|
||||
switch decision.normalizedAction() {
|
||||
case HookActionContinue, HookActionModify:
|
||||
if toolResp != nil {
|
||||
if toolResp.Tool != "" {
|
||||
toolName = toolResp.Tool
|
||||
}
|
||||
if toolResp.Result != nil {
|
||||
toolResult = toolResp.Result
|
||||
}
|
||||
}
|
||||
case HookActionAbortTurn:
|
||||
turnStatus = TurnEndStatusError
|
||||
return turnResult{}, al.hookAbortError(ts, "after_tool", decision)
|
||||
case HookActionHardAbort:
|
||||
_ = ts.requestHardAbort()
|
||||
turnStatus = TurnEndStatusAborted
|
||||
return al.abortTurn(ts)
|
||||
}
|
||||
}
|
||||
|
||||
if toolResult == nil {
|
||||
toolResult = tools.ErrorResult("hook returned nil tool result")
|
||||
}
|
||||
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
|
||||
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
|
||||
Channel: ts.channel,
|
||||
@@ -1914,7 +2147,7 @@ turnLoop:
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]any{
|
||||
"tool": toolCall.Name,
|
||||
"tool": toolName,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
@@ -1947,13 +2180,13 @@ turnLoop:
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: toolCall.ID,
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
al.emitEvent(
|
||||
EventKindToolExecEnd,
|
||||
ts.eventMeta("runTurn", "turn.tool.end"),
|
||||
ToolExecEndPayload{
|
||||
Tool: toolCall.Name,
|
||||
Tool: toolName,
|
||||
Duration: toolDuration,
|
||||
ForLLMLen: len(contentForLLM),
|
||||
ForUserLen: len(toolResult.ForUser),
|
||||
|
||||
Reference in New Issue
Block a user