From 9c82b0baa224d419cb63ba986bdbb27e3c115785 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 14:20:24 +0800
Subject: [PATCH 01/26] refactor(agent): context boundary detection, proactive
budget check, and safe compression
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Separate context_window from max_tokens — they serve different purposes
(input capacity vs output generation limit). The previous conflation caused
premature summarization or missed compression triggers.
Changes:
- Add context_window field to AgentDefaults config (default: 4x max_tokens)
- Extract boundary-safe truncation helpers (isSafeBoundary, findSafeBoundary)
into context_budget.go — pure functions with no AgentLoop dependency
- forceCompression: align split to safe boundary so tool-call sequences
(assistant+ToolCalls → tool results) are never torn apart
- summarizeSession: use findSafeBoundary instead of hardcoded keep-last-4
- estimateTokens: count ToolCalls arguments and ToolCallID metadata,
not just Content — fixes systematic undercounting in tool-heavy sessions
- Add proactive context budget check before LLM call in runAgentLoop,
preventing 400 context-length errors instead of reacting to them
- Add estimateToolDefsTokens for tool definition token cost
Closes #556, closes #665
Ref #1439
---
pkg/agent/context_budget.go | 133 ++++++++
pkg/agent/context_budget_test.go | 545 +++++++++++++++++++++++++++++++
pkg/agent/instance.go | 13 +-
pkg/agent/loop.go | 49 ++-
pkg/config/config.go | 1 +
5 files changed, 727 insertions(+), 14 deletions(-)
create mode 100644 pkg/agent/context_budget.go
create mode 100644 pkg/agent/context_budget_test.go
diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go
new file mode 100644
index 000000000..2eec9c267
--- /dev/null
+++ b/pkg/agent/context_budget.go
@@ -0,0 +1,133 @@
+// PicoClaw - Ultra-lightweight personal AI agent
+// License: MIT
+//
+// Copyright (c) 2026 PicoClaw contributors
+
+package agent
+
+import (
+ "encoding/json"
+ "unicode/utf8"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// isSafeBoundary reports whether index is a valid position to split a message
+// history for truncation or compression. Splitting at index means:
+// - history[:index] is dropped or summarized
+// - history[index:] is kept
+//
+// A boundary is safe when the kept portion begins at a "user" message,
+// ensuring no tool-call sequence (assistant+ToolCalls → tool results)
+// is torn apart across the split.
+func isSafeBoundary(history []providers.Message, index int) bool {
+ if index <= 0 || index >= len(history) {
+ return true
+ }
+ return history[index].Role == "user"
+}
+
+// findSafeBoundary locates the nearest safe split point to targetIndex.
+// It scans backward first (preserving more context), then forward.
+// Returns targetIndex unchanged only when no safe boundary exists.
+func findSafeBoundary(history []providers.Message, targetIndex int) int {
+ if len(history) == 0 {
+ return 0
+ }
+ if targetIndex <= 0 {
+ return 0
+ }
+ if targetIndex >= len(history) {
+ return len(history)
+ }
+
+ if isSafeBoundary(history, targetIndex) {
+ return targetIndex
+ }
+
+ // Backward scan: prefer keeping more messages.
+ for i := targetIndex - 1; i > 0; i-- {
+ if isSafeBoundary(history, i) {
+ return i
+ }
+ }
+
+ // Forward scan: fall back to keeping fewer messages.
+ for i := targetIndex + 1; i < len(history); i++ {
+ if isSafeBoundary(history, i) {
+ return i
+ }
+ }
+
+ return targetIndex
+}
+
+// estimateMessageTokens estimates the token count for a single message,
+// including Content, ToolCalls arguments, and ToolCallID metadata.
+// Uses a heuristic of 2.5 characters per token.
+func estimateMessageTokens(msg providers.Message) int {
+ chars := utf8.RuneCountInString(msg.Content)
+
+ for _, tc := range msg.ToolCalls {
+ // Count tool call metadata: ID, type, function name
+ chars += len(tc.ID) + len(tc.Type) + len(tc.Name)
+ if tc.Function != nil {
+ chars += len(tc.Function.Name) + len(tc.Function.Arguments)
+ }
+ }
+
+ if msg.ToolCallID != "" {
+ chars += len(msg.ToolCallID)
+ }
+
+ // Per-message overhead for role label, JSON structure, separators.
+ const messageOverhead = 12
+ chars += messageOverhead
+
+ return chars * 2 / 5
+}
+
+// estimateToolDefsTokens estimates the total token cost of tool definitions
+// as they appear in the LLM request. Each tool's name, description, and
+// JSON schema parameters contribute to the context window budget.
+func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
+ if len(defs) == 0 {
+ return 0
+ }
+
+ totalChars := 0
+ for _, d := range defs {
+ totalChars += len(d.Function.Name) + len(d.Function.Description)
+
+ if d.Function.Parameters != nil {
+ if paramJSON, err := json.Marshal(d.Function.Parameters); err == nil {
+ totalChars += len(paramJSON)
+ }
+ }
+
+ // Per-tool overhead: type field, JSON structure, separators.
+ totalChars += 20
+ }
+
+ return totalChars * 2 / 5
+}
+
+// isOverContextBudget checks whether the assembled messages plus tool definitions
+// and output reserve would exceed the model's context window. This enables
+// proactive compression before calling the LLM, rather than reacting to 400 errors.
+func isOverContextBudget(
+ contextWindow int,
+ messages []providers.Message,
+ toolDefs []providers.ToolDefinition,
+ maxTokens int,
+) bool {
+ msgTokens := 0
+ for _, m := range messages {
+ msgTokens += estimateMessageTokens(m)
+ }
+
+ toolTokens := estimateToolDefsTokens(toolDefs)
+ total := msgTokens + toolTokens + maxTokens
+
+ return total > contextWindow
+}
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
new file mode 100644
index 000000000..c8a6b19c5
--- /dev/null
+++ b/pkg/agent/context_budget_test.go
@@ -0,0 +1,545 @@
+package agent
+
+import (
+ "fmt"
+ "strings"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+// msgUser creates a user message.
+func msgUser(content string) providers.Message {
+ return providers.Message{Role: "user", Content: content}
+}
+
+// msgAssistant creates a plain assistant message (no tool calls).
+func msgAssistant(content string) providers.Message {
+ return providers.Message{Role: "assistant", Content: content}
+}
+
+// msgAssistantTC creates an assistant message with tool calls.
+func msgAssistantTC(toolIDs ...string) providers.Message {
+ tcs := make([]providers.ToolCall, len(toolIDs))
+ for i, id := range toolIDs {
+ tcs[i] = providers.ToolCall{
+ ID: id,
+ Type: "function",
+ Name: "tool_" + id,
+ Function: &providers.FunctionCall{
+ Name: "tool_" + id,
+ Arguments: `{"key":"value"}`,
+ },
+ }
+ }
+ return providers.Message{Role: "assistant", ToolCalls: tcs}
+}
+
+// msgTool creates a tool result message.
+func msgTool(callID, content string) providers.Message {
+ return providers.Message{Role: "tool", ToolCallID: callID, Content: content}
+}
+
+func TestIsSafeBoundary(t *testing.T) {
+ tests := []struct {
+ name string
+ history []providers.Message
+ index int
+ want bool
+ }{
+ {
+ name: "empty history, index 0",
+ history: nil,
+ index: 0,
+ want: true,
+ },
+ {
+ name: "single user message, index 0",
+ history: []providers.Message{msgUser("hi")},
+ index: 0,
+ want: true,
+ },
+ {
+ name: "single user message, index 1 (end)",
+ history: []providers.Message{msgUser("hi")},
+ index: 1,
+ want: true,
+ },
+ {
+ name: "at user message",
+ history: []providers.Message{
+ msgAssistant("hello"),
+ msgUser("how are you"),
+ msgAssistant("fine"),
+ },
+ index: 1,
+ want: true,
+ },
+ {
+ name: "at assistant without tool calls",
+ history: []providers.Message{
+ msgUser("hello"),
+ msgAssistant("response"),
+ msgUser("follow up"),
+ },
+ index: 1,
+ want: false,
+ },
+ {
+ name: "at assistant with tool calls",
+ history: []providers.Message{
+ msgUser("search something"),
+ msgAssistantTC("tc1"),
+ msgTool("tc1", "result"),
+ msgAssistant("here is what I found"),
+ },
+ index: 1,
+ want: false,
+ },
+ {
+ name: "at tool result",
+ history: []providers.Message{
+ msgUser("do something"),
+ msgAssistantTC("tc1"),
+ msgTool("tc1", "done"),
+ msgAssistant("completed"),
+ },
+ index: 2,
+ want: false,
+ },
+ {
+ name: "negative index",
+ history: []providers.Message{
+ msgUser("hello"),
+ },
+ index: -1,
+ want: true,
+ },
+ {
+ name: "index beyond length",
+ history: []providers.Message{
+ msgUser("hello"),
+ },
+ index: 5,
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isSafeBoundary(tt.history, tt.index)
+ if got != tt.want {
+ t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tt.index, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFindSafeBoundary(t *testing.T) {
+ tests := []struct {
+ name string
+ history []providers.Message
+ targetIndex int
+ want int
+ }{
+ {
+ name: "empty history",
+ history: nil,
+ targetIndex: 0,
+ want: 0,
+ },
+ {
+ name: "target at 0",
+ history: []providers.Message{msgUser("hi")},
+ targetIndex: 0,
+ want: 0,
+ },
+ {
+ name: "target beyond length",
+ history: []providers.Message{msgUser("hi")},
+ targetIndex: 5,
+ want: 1,
+ },
+ {
+ name: "target already at user message",
+ history: []providers.Message{
+ msgUser("q1"),
+ msgAssistant("a1"),
+ msgUser("q2"),
+ msgAssistant("a2"),
+ },
+ targetIndex: 2,
+ want: 2,
+ },
+ {
+ name: "target at assistant, scan backward finds user",
+ history: []providers.Message{
+ msgUser("q1"),
+ msgAssistant("a1"),
+ msgUser("q2"),
+ msgAssistant("a2"),
+ msgUser("q3"),
+ },
+ targetIndex: 3, // assistant "a2"
+ want: 2, // backward to user "q2"
+ },
+ {
+ name: "target inside tool sequence, scan backward finds user",
+ history: []providers.Message{
+ msgUser("q1"),
+ msgAssistant("a1"),
+ msgUser("q2"),
+ msgAssistantTC("tc1", "tc2"),
+ msgTool("tc1", "r1"),
+ msgTool("tc2", "r2"),
+ msgAssistant("summary"),
+ msgUser("q3"),
+ },
+ targetIndex: 4, // tool result "r1"
+ want: 2, // backward: 3=assistant+TC (not safe), 2=user → safe
+ },
+ {
+ name: "target inside tool sequence, backward finds user before chain",
+ history: []providers.Message{
+ msgUser("q1"),
+ msgAssistant("a1"),
+ msgUser("q2"),
+ msgAssistantTC("tc1", "tc2"),
+ msgTool("tc1", "r1"),
+ msgTool("tc2", "r2"),
+ msgAssistant("summary"),
+ msgUser("q3"),
+ },
+ targetIndex: 5, // tool result "r2"
+ want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
+ },
+ {
+ name: "no backward user, scan forward finds one",
+ history: []providers.Message{
+ msgAssistantTC("tc1"),
+ msgTool("tc1", "r1"),
+ msgAssistant("a1"),
+ msgUser("q1"),
+ },
+ targetIndex: 1, // tool result
+ want: 3, // forward to user "q1"
+ },
+ {
+ name: "multi-step tool chain preserves atomicity",
+ history: []providers.Message{
+ msgUser("q1"),
+ msgAssistant("a1"),
+ msgUser("q2"),
+ msgAssistantTC("tc1"),
+ msgTool("tc1", "r1"),
+ msgAssistantTC("tc2"),
+ msgTool("tc2", "r2"),
+ msgAssistant("final"),
+ msgUser("q3"),
+ msgAssistant("a3"),
+ },
+ targetIndex: 5, // second assistant+TC
+ want: 2, // backward: 4=tool, 3=assistant+TC, 2=user → safe
+ },
+ {
+ name: "all non-user messages returns target unchanged",
+ history: []providers.Message{
+ msgAssistant("a1"),
+ msgAssistant("a2"),
+ msgAssistant("a3"),
+ },
+ targetIndex: 1,
+ want: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := findSafeBoundary(tt.history, tt.targetIndex)
+ if got != tt.want {
+ t.Errorf("findSafeBoundary(history, %d) = %d, want %d",
+ tt.targetIndex, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
+ // A long tool-call chain: user → assistant+TC → tool → tool → ... → assistant → user
+ // Target is inside the chain; boundary should skip the entire chain backward.
+ history := []providers.Message{
+ msgUser("start"), // 0
+ msgAssistant("before chain"), // 1
+ msgUser("trigger"), // 2 ← expected safe boundary
+ msgAssistantTC("t1", "t2", "t3"), // 3
+ msgTool("t1", "r1"), // 4
+ msgTool("t2", "r2"), // 5
+ msgTool("t3", "r3"), // 6
+ msgAssistantTC("t4"), // 7
+ msgTool("t4", "r4"), // 8
+ msgAssistant("chain done"), // 9
+ msgUser("next"), // 10
+ }
+
+ // Target at index 6 (middle of tool results)
+ got := findSafeBoundary(history, 6)
+ if got != 2 {
+ t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got)
+ }
+}
+
+func TestEstimateMessageTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ msg providers.Message
+ want int // minimum expected tokens (exact value depends on overhead)
+ }{
+ {
+ name: "plain user message",
+ msg: msgUser("Hello, world!"),
+ want: 1, // at least some tokens
+ },
+ {
+ name: "empty message still has overhead",
+ msg: providers.Message{Role: "user"},
+ want: 1, // message overhead alone
+ },
+ {
+ name: "assistant with tool calls",
+ msg: msgAssistantTC("tc_123"),
+ want: 1,
+ },
+ {
+ name: "tool result with ID",
+ msg: msgTool("call_abc", "Here is the search result with lots of content"),
+ want: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := estimateMessageTokens(tt.msg)
+ if got < tt.want {
+ t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
+ plain := msgAssistant("thinking")
+ withTC := providers.Message{
+ Role: "assistant",
+ Content: "thinking",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "web_search",
+ Function: &providers.FunctionCall{
+ Name: "web_search",
+ Arguments: `{"query":"picoclaw agent framework","max_results":5}`,
+ },
+ },
+ },
+ }
+
+ plainTokens := estimateMessageTokens(plain)
+ withTCTokens := estimateMessageTokens(withTC)
+
+ if withTCTokens <= plainTokens {
+ t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
+ withTCTokens, plainTokens)
+ }
+}
+
+func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
+ // Multi-byte characters (e.g. emoji, accented letters) are single runes
+ // but may map to different token counts. The heuristic should still produce
+ // reasonable estimates via RuneCountInString.
+ msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
+ tokens := estimateMessageTokens(msg)
+ if tokens <= 0 {
+ t.Errorf("multibyte message should produce positive token count, got %d", tokens)
+ }
+}
+
+func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
+ // Simulate a tool call with large JSON arguments.
+ largeArgs := fmt.Sprintf(`{"content":"%s"}`, strings.Repeat("x", 5000))
+ msg := providers.Message{
+ Role: "assistant",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_large",
+ Type: "function",
+ Name: "write_file",
+ Function: &providers.FunctionCall{
+ Name: "write_file",
+ Arguments: largeArgs,
+ },
+ },
+ },
+ }
+
+ tokens := estimateMessageTokens(msg)
+ // 5000+ chars → at least 2000 tokens with the 2.5 char/token heuristic
+ if tokens < 2000 {
+ t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
+ }
+}
+
+// --- estimateToolDefsTokens tests ---
+
+func TestEstimateToolDefsTokens(t *testing.T) {
+ tests := []struct {
+ name string
+ defs []providers.ToolDefinition
+ want int // minimum expected tokens
+ }{
+ {
+ name: "empty tool list",
+ defs: nil,
+ want: 0,
+ },
+ {
+ name: "single tool with params",
+ defs: []providers.ToolDefinition{
+ {
+ Type: "function",
+ Function: providers.ToolFunctionDefinition{
+ Name: "web_search",
+ Description: "Search the web for information",
+ Parameters: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "query": map[string]any{"type": "string"},
+ },
+ "required": []any{"query"},
+ },
+ },
+ },
+ },
+ want: 1,
+ },
+ {
+ name: "tool without params",
+ defs: []providers.ToolDefinition{
+ {
+ Type: "function",
+ Function: providers.ToolFunctionDefinition{
+ Name: "list_dir",
+ Description: "List directory contents",
+ },
+ },
+ },
+ want: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := estimateToolDefsTokens(tt.defs)
+ if got < tt.want {
+ t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
+ makeTool := func(name string) providers.ToolDefinition {
+ return providers.ToolDefinition{
+ Type: "function",
+ Function: providers.ToolFunctionDefinition{
+ Name: name,
+ Description: "A test tool that does something useful",
+ Parameters: map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "input": map[string]any{"type": "string", "description": "Input value"},
+ },
+ },
+ },
+ }
+ }
+
+ one := estimateToolDefsTokens([]providers.ToolDefinition{makeTool("tool_a")})
+ three := estimateToolDefsTokens([]providers.ToolDefinition{
+ makeTool("tool_a"), makeTool("tool_b"), makeTool("tool_c"),
+ })
+
+ if three <= one {
+ t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one)
+ }
+}
+
+// --- isOverContextBudget tests ---
+
+func TestIsOverContextBudget(t *testing.T) {
+ systemMsg := providers.Message{Role: "system", Content: strings.Repeat("x", 1000)}
+ userMsg := msgUser("hello")
+ smallHistory := []providers.Message{systemMsg, msgUser("q1"), msgAssistant("a1"), userMsg}
+
+ tools := []providers.ToolDefinition{
+ {
+ Type: "function",
+ Function: providers.ToolFunctionDefinition{
+ Name: "test_tool",
+ Description: "A test tool",
+ Parameters: map[string]any{"type": "object"},
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ contextWindow int
+ messages []providers.Message
+ toolDefs []providers.ToolDefinition
+ maxTokens int
+ want bool
+ }{
+ {
+ name: "within budget",
+ contextWindow: 100000,
+ messages: smallHistory,
+ toolDefs: tools,
+ maxTokens: 4096,
+ want: false,
+ },
+ {
+ name: "over budget with small window",
+ contextWindow: 100, // very small window
+ messages: smallHistory,
+ toolDefs: tools,
+ maxTokens: 4096,
+ want: true,
+ },
+ {
+ name: "large max_tokens eats budget",
+ contextWindow: 2000,
+ messages: smallHistory,
+ toolDefs: tools,
+ maxTokens: 1800, // leaves almost no room
+ want: true,
+ },
+ {
+ name: "empty messages within budget",
+ contextWindow: 10000,
+ messages: nil,
+ toolDefs: nil,
+ maxTokens: 4096,
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := isOverContextBudget(tt.contextWindow, tt.messages, tt.toolDefs, tt.maxTokens)
+ if got != tt.want {
+ t.Errorf("isOverContextBudget() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/pkg/agent/instance.go b/pkg/agent/instance.go
index 0c7baa1ee..c34f9b4a4 100644
--- a/pkg/agent/instance.go
+++ b/pkg/agent/instance.go
@@ -127,6 +127,17 @@ func NewAgentInstance(
maxTokens = 8192
}
+ contextWindow := defaults.ContextWindow
+ if contextWindow == 0 {
+ // Default heuristic: 4x the output token limit.
+ // Most models have context windows well above their output limits
+ // (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out).
+ // 4x is a conservative lower bound that avoids premature
+ // summarization while remaining safe — the reactive
+ // forceCompression handles any overshoot.
+ contextWindow = maxTokens * 4
+ }
+
temperature := 0.7
if defaults.Temperature != nil {
temperature = *defaults.Temperature
@@ -224,7 +235,7 @@ func NewAgentInstance(
MaxTokens: maxTokens,
Temperature: temperature,
ThinkingLevel: thinkingLevel,
- ContextWindow: maxTokens,
+ ContextWindow: contextWindow,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 21516e7de..f20f2c938 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -17,7 +17,6 @@ import (
"sync"
"sync/atomic"
"time"
- "unicode/utf8"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
@@ -931,6 +930,24 @@ func (al *AgentLoop) runAgentLoop(
maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
+ // 1.5. Proactive context budget check: compress before LLM call
+ // rather than waiting for a 400 context-length error.
+ if !opts.NoHistory {
+ toolDefs := agent.Tools.ToProviderDefs()
+ if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) {
+ logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
+ map[string]any{"session_key": opts.SessionKey})
+ al.forceCompression(agent, opts.SessionKey)
+ newHistory := agent.Sessions.GetHistory(opts.SessionKey)
+ newSummary := agent.Sessions.GetSummary(opts.SessionKey)
+ messages = agent.ContextBuilder.BuildMessages(
+ newHistory, newSummary, opts.UserMessage,
+ opts.Media, opts.Channel, opts.ChatID,
+ )
+ messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
+ }
+ }
+
// 2. Save user message to session
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
@@ -1539,7 +1556,8 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
}
// forceCompression aggressively reduces context when the limit is hit.
-// It drops the oldest 50% of messages (keeping system prompt and last user message).
+// It drops the oldest ~50% of messages (keeping system prompt and last user message),
+// aligning the split to a safe boundary so tool-call sequences stay intact.
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 4 {
@@ -1554,8 +1572,8 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
return
}
- // Helper to find the mid-point of the conversation
- mid := len(conversation) / 2
+ // Find a safe mid-point that does not split a tool-call sequence.
+ mid := findSafeBoundary(conversation, len(conversation)/2)
// New history structure:
// 1. System Prompt (with compression note appended)
@@ -1687,12 +1705,18 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
history := agent.Sessions.GetHistory(sessionKey)
summary := agent.Sessions.GetSummary(sessionKey)
- // Keep last 4 messages for continuity
+ // Keep last few messages for continuity, aligned to a safe boundary
+ // so that no tool-call sequence is split.
if len(history) <= 4 {
return
}
- toSummarize := history[:len(history)-4]
+ safeCut := findSafeBoundary(history, len(history)-4)
+ if safeCut <= 0 {
+ return
+ }
+ keepCount := len(history) - safeCut
+ toSummarize := history[:safeCut]
// Oversized Message Guard
maxMessageTokens := agent.ContextWindow / 2
@@ -1757,7 +1781,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
if finalSummary != "" {
agent.Sessions.SetSummary(sessionKey, finalSummary)
- agent.Sessions.TruncateHistory(sessionKey, 4)
+ agent.Sessions.TruncateHistory(sessionKey, keepCount)
agent.Sessions.Save(sessionKey)
}
}
@@ -1895,15 +1919,14 @@ func (al *AgentLoop) summarizeBatch(
}
// estimateTokens estimates the number of tokens in a message list.
-// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
-// overheads better than the previous 3 chars/token.
+// Counts Content, ToolCalls arguments, and ToolCallID metadata so that
+// tool-heavy conversations are not systematically undercounted.
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
- totalChars := 0
+ total := 0
for _, m := range messages {
- totalChars += utf8.RuneCountInString(m.Content)
+ total += estimateMessageTokens(m)
}
- // 2.5 chars per token = totalChars * 2 / 5
- return totalChars * 2 / 5
+ return total
}
func (al *AgentLoop) handleCommand(
diff --git a/pkg/config/config.go b/pkg/config/config.go
index a8b8f337f..a3720b656 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -228,6 +228,7 @@ type AgentDefaults struct {
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
+ ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
From 9c65d78b07ca82b556dac227b57c76a58013527d Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 15:13:04 +0800
Subject: [PATCH 02/26] fix(agent): forceCompression must not assume history[0]
is system prompt
Session history (GetHistory) contains only user/assistant/tool messages.
The system prompt is built dynamically by BuildMessages and is never
stored in session. The previous code incorrectly treated history[0] as
a system prompt, skipping the first user message and appending a
compression note to it.
Fix: operate on the full history slice, and record the compression
note in the session summary (which BuildMessages already injects into
the system prompt) rather than modifying any history message.
---
pkg/agent/loop.go | 55 ++++++++++++++++++++---------------------------
1 file changed, 23 insertions(+), 32 deletions(-)
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index f20f2c938..14dc8c5ca 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -1556,56 +1556,47 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
}
// forceCompression aggressively reduces context when the limit is hit.
-// It drops the oldest ~50% of messages (keeping system prompt and last user message),
-// aligning the split to a safe boundary so tool-call sequences stay intact.
+// It drops the oldest ~50% of messages, aligning the split to a safe
+// boundary so tool-call sequences stay intact.
+//
+// Session history contains only user/assistant/tool messages — the system
+// prompt is built dynamically by BuildMessages and is NOT stored here.
+// The compression note is recorded in the session summary so that
+// BuildMessages can include it in the next system prompt.
func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
history := agent.Sessions.GetHistory(sessionKey)
- if len(history) <= 4 {
- return
- }
-
- // Keep system prompt (usually [0]) and the very last message (user's trigger)
- // We want to drop the oldest half of the *conversation*
- // Assuming [0] is system, [1:] is conversation
- conversation := history[1 : len(history)-1]
- if len(conversation) == 0 {
+ if len(history) <= 2 {
return
}
// Find a safe mid-point that does not split a tool-call sequence.
- mid := findSafeBoundary(conversation, len(conversation)/2)
-
- // New history structure:
- // 1. System Prompt (with compression note appended)
- // 2. Second half of conversation
- // 3. Last message
+ mid := findSafeBoundary(history, len(history)/2)
+ if mid <= 0 {
+ return
+ }
droppedCount := mid
- keptConversation := conversation[mid:]
+ keptHistory := history[mid:]
- newHistory := make([]providers.Message, 0, 1+len(keptConversation)+1)
-
- // Append compression note to the original system prompt instead of adding a new system message
- // This avoids having two consecutive system messages which some APIs (like Zhipu) reject
+ // Record compression in the session summary so BuildMessages includes it
+ // in the system prompt. We do not modify history messages themselves.
+ existingSummary := agent.Sessions.GetSummary(sessionKey)
compressionNote := fmt.Sprintf(
- "\n\n[System Note: Emergency compression dropped %d oldest messages due to context limit]",
+ "[Emergency compression dropped %d oldest messages due to context limit]",
droppedCount,
)
- enhancedSystemPrompt := history[0]
- enhancedSystemPrompt.Content = enhancedSystemPrompt.Content + compressionNote
- newHistory = append(newHistory, enhancedSystemPrompt)
+ if existingSummary != "" {
+ compressionNote = existingSummary + "\n\n" + compressionNote
+ }
+ agent.Sessions.SetSummary(sessionKey, compressionNote)
- newHistory = append(newHistory, keptConversation...)
- newHistory = append(newHistory, history[len(history)-1]) // Last message
-
- // Update session
- agent.Sessions.SetHistory(sessionKey, newHistory)
+ agent.Sessions.SetHistory(sessionKey, keptHistory)
agent.Sessions.Save(sessionKey)
logger.WarnCF("agent", "Forced compression executed", map[string]any{
"session_key": sessionKey,
"dropped_msgs": droppedCount,
- "new_count": len(newHistory),
+ "new_count": len(keptHistory),
})
}
From d5fdd5ebd2644408d45a5525ead50b16938a5012 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 15:14:00 +0800
Subject: [PATCH 03/26] fix(agent): include ReasoningContent and Media in token
estimation
estimateMessageTokens now counts ReasoningContent (extended thinking /
chain-of-thought) which can be substantial and is persisted in session
history. Media items get a fixed per-item overhead (256 tokens) since
actual cost depends on provider-specific image tokenization.
---
pkg/agent/context_budget.go | 16 +++++++++++++--
pkg/agent/context_budget_test.go | 34 ++++++++++++++++++++++++++++++++
2 files changed, 48 insertions(+), 2 deletions(-)
diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go
index 2eec9c267..71da5d8f7 100644
--- a/pkg/agent/context_budget.go
+++ b/pkg/agent/context_budget.go
@@ -63,11 +63,17 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
}
// estimateMessageTokens estimates the token count for a single message,
-// including Content, ToolCalls arguments, and ToolCallID metadata.
-// Uses a heuristic of 2.5 characters per token.
+// including Content, ReasoningContent, ToolCalls arguments, ToolCallID
+// metadata, and Media items. Uses a heuristic of 2.5 characters per token.
func estimateMessageTokens(msg providers.Message) int {
chars := utf8.RuneCountInString(msg.Content)
+ // ReasoningContent (extended thinking / chain-of-thought) can be
+ // substantial and is stored in session history via AddFullMessage.
+ if msg.ReasoningContent != "" {
+ chars += utf8.RuneCountInString(msg.ReasoningContent)
+ }
+
for _, tc := range msg.ToolCalls {
// Count tool call metadata: ID, type, function name
chars += len(tc.ID) + len(tc.Type) + len(tc.Name)
@@ -80,6 +86,12 @@ func estimateMessageTokens(msg providers.Message) int {
chars += len(msg.ToolCallID)
}
+ // Media items (images, files) are serialized by provider adapters into
+ // multipart or image_url payloads. Use a fixed per-item estimate since
+ // actual token cost depends on resolution and provider tokenization.
+ const mediaTokensPerItem = 256
+ chars += len(msg.Media) * mediaTokensPerItem
+
// Per-message overhead for role label, JSON structure, separators.
const messageOverhead = 12
chars += messageOverhead
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index c8a6b19c5..03ace82e2 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -389,6 +389,40 @@ func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
}
}
+func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
+ plain := msgAssistant("result")
+ withReasoning := providers.Message{
+ Role: "assistant",
+ Content: "result",
+ ReasoningContent: strings.Repeat("thinking step ", 200),
+ }
+
+ plainTokens := estimateMessageTokens(plain)
+ reasoningTokens := estimateMessageTokens(withReasoning)
+
+ if reasoningTokens <= plainTokens {
+ t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
+ reasoningTokens, plainTokens)
+ }
+}
+
+func TestEstimateMessageTokens_MediaItems(t *testing.T) {
+ plain := msgUser("describe this")
+ withMedia := providers.Message{
+ Role: "user",
+ Content: "describe this",
+ Media: []string{"media://img1.png", "media://img2.png"},
+ }
+
+ plainTokens := estimateMessageTokens(plain)
+ mediaTokens := estimateMessageTokens(withMedia)
+
+ if mediaTokens <= plainTokens {
+ t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
+ mediaTokens, plainTokens)
+ }
+}
+
// --- estimateToolDefsTokens tests ---
func TestEstimateToolDefsTokens(t *testing.T) {
From e35906bb1447b60b4836587d824b488698e12b14 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 15:16:57 +0800
Subject: [PATCH 04/26] feat(config): expose context_window in example config
and web UI
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add context_window to config.example.json, the web configuration page
(form model, input field, save handler), and i18n strings (en/zh).
The field is optional — leaving it empty falls back to the 4x max_tokens
heuristic.
---
config/config.example.json | 1 +
web/frontend/src/components/config/config-page.tsx | 4 ++++
.../src/components/config/config-sections.tsx | 14 ++++++++++++++
web/frontend/src/components/config/form-model.ts | 3 +++
web/frontend/src/i18n/locales/en.json | 2 ++
web/frontend/src/i18n/locales/zh.json | 2 ++
6 files changed, 26 insertions(+)
diff --git a/config/config.example.json b/config/config.example.json
index 094aa46df..20c10e60d 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -5,6 +5,7 @@
"restrict_to_workspace": true,
"model_name": "gpt-5.4",
"max_tokens": 8192,
+ "context_window": 131072,
"temperature": 0.7,
"max_tool_iterations": 20,
"summarize_message_threshold": 20,
diff --git a/web/frontend/src/components/config/config-page.tsx b/web/frontend/src/components/config/config-page.tsx
index cbce7d27e..dc6797749 100644
--- a/web/frontend/src/components/config/config-page.tsx
+++ b/web/frontend/src/components/config/config-page.tsx
@@ -144,6 +144,9 @@ export function ConfigPage() {
const maxTokens = parseIntField(form.maxTokens, "Max tokens", {
min: 1,
})
+ const contextWindow = form.contextWindow.trim()
+ ? parseIntField(form.contextWindow, "Context window", { min: 1 })
+ : undefined
const maxToolIterations = parseIntField(
form.maxToolIterations,
"Max tool iterations",
@@ -171,6 +174,7 @@ export function ConfigPage() {
workspace,
restrict_to_workspace: form.restrictToWorkspace,
max_tokens: maxTokens,
+ context_window: contextWindow,
max_tool_iterations: maxToolIterations,
summarize_message_threshold: summarizeMessageThreshold,
summarize_token_percent: summarizeTokenPercent,
diff --git a/web/frontend/src/components/config/config-sections.tsx b/web/frontend/src/components/config/config-sections.tsx
index dfbe22fc3..825d882b7 100644
--- a/web/frontend/src/components/config/config-sections.tsx
+++ b/web/frontend/src/components/config/config-sections.tsx
@@ -114,6 +114,20 @@ export function AgentDefaultsSection({
/>
+
+ onFieldChange("contextWindow", e.target.value)}
+ placeholder="131072"
+ />
+
+
Date: Fri, 13 Mar 2026 15:18:07 +0800
Subject: [PATCH 05/26] test(agent): add realistic session-shaped tests for
context budget
Add tests that reflect actual session data shape: history starts with
user messages (no system prompt), includes chained tool-call sequences,
reasoning content, and media items. Exercises the proactive budget check
path with BuildMessages-style assembled messages.
---
pkg/agent/context_budget_test.go | 140 +++++++++++++++++++++++++++++++
1 file changed, 140 insertions(+)
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 03ace82e2..6b51a8cb7 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -577,3 +577,143 @@ func TestIsOverContextBudget(t *testing.T) {
})
}
}
+
+// --- Tests reflecting actual session data shape ---
+// Session history never contains system messages. The system prompt is
+// built dynamically by BuildMessages. These tests use realistic history
+// shapes: user/assistant/tool only, with tool chains and reasoning content.
+
+func TestFindSafeBoundary_SessionHistoryNoSystem(t *testing.T) {
+ // Real session history starts with a user message, not a system message.
+ history := []providers.Message{
+ msgUser("hello"), // 0
+ msgAssistant("hi there"), // 1
+ msgUser("search for X"), // 2
+ msgAssistantTC("tc1"), // 3
+ msgTool("tc1", "found X"), // 4
+ msgAssistant("here is X"), // 5
+ msgUser("thanks"), // 6
+ msgAssistant("you're welcome"), // 7
+ }
+
+ // Mid-point is 4 (tool result). Should snap backward to 2 (user).
+ got := findSafeBoundary(history, 4)
+ if got != 2 {
+ t.Errorf("findSafeBoundary(session_history, 4) = %d, want 2", got)
+ }
+}
+
+func TestFindSafeBoundary_SessionWithChainedTools(t *testing.T) {
+ // Session with chained tool calls (save then notify).
+ history := []providers.Message{
+ msgUser("save and notify"), // 0
+ msgAssistantTC("tc_save"), // 1
+ msgTool("tc_save", "saved"), // 2
+ msgAssistantTC("tc_notify"), // 3
+ msgTool("tc_notify", "notified"), // 4
+ msgAssistant("done"), // 5
+ msgUser("check status"), // 6
+ msgAssistant("all good"), // 7
+ }
+
+ // Target at 3 (inside chain). Should find user at 0, but backward
+ // scan stops at i>0, so forward scan finds user at 6.
+ // Actually: backward from 3: 2=tool (no), 1=assistantTC (no). Forward: 4=tool, 5=asst, 6=user ✓
+ got := findSafeBoundary(history, 3)
+ if got != 6 {
+ t.Errorf("findSafeBoundary(chained_tools, 3) = %d, want 6", got)
+ }
+}
+
+func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
+ // Message with all fields populated — mirrors what AddFullMessage stores.
+ msg := providers.Message{
+ Role: "assistant",
+ Content: "Here is the analysis.",
+ ReasoningContent: strings.Repeat("Let me think about this carefully. ", 50),
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "analyze",
+ Function: &providers.FunctionCall{
+ Name: "analyze",
+ Arguments: `{"data":"sample","depth":3}`,
+ },
+ },
+ },
+ }
+
+ tokens := estimateMessageTokens(msg)
+
+ // ReasoningContent alone is ~1700 chars → ~680 tokens.
+ // Content + TC + overhead adds more. Should be well above 500.
+ if tokens < 500 {
+ t.Errorf("message with reasoning+toolcalls should have significant tokens, got %d", tokens)
+ }
+
+ // Compare without reasoning to ensure it's counted.
+ msgNoReasoning := msg
+ msgNoReasoning.ReasoningContent = ""
+ tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
+
+ if tokens <= tokensNoReasoning {
+ t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
+ }
+}
+
+func TestIsOverContextBudget_RealisticSession(t *testing.T) {
+ // Simulate what BuildMessages produces: system + session history + current user.
+ // System message is built by BuildMessages, not stored in session.
+ systemMsg := providers.Message{
+ Role: "system",
+ Content: strings.Repeat("system prompt content ", 100),
+ }
+ sessionHistory := []providers.Message{
+ msgUser("first question"),
+ msgAssistant("first answer"),
+ msgUser("use tool X"),
+ {
+ Role: "assistant",
+ Content: "I'll use tool X",
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "tc1", Type: "function", Name: "tool_x",
+ Function: &providers.FunctionCall{
+ Name: "tool_x",
+ Arguments: `{"query":"test","verbose":true}`,
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: strings.Repeat("result data ", 200), ToolCallID: "tc1"},
+ msgAssistant("Here are the results from tool X."),
+ }
+ currentUser := msgUser("follow up question")
+
+ // Assemble as BuildMessages would.
+ messages := []providers.Message{systemMsg}
+ messages = append(messages, sessionHistory...)
+ messages = append(messages, currentUser)
+
+ tools := []providers.ToolDefinition{
+ {
+ Type: "function",
+ Function: providers.ToolFunctionDefinition{
+ Name: "tool_x",
+ Description: "A useful tool",
+ Parameters: map[string]any{"type": "object"},
+ },
+ },
+ }
+
+ // With a large context window, should be within budget.
+ if isOverContextBudget(131072, messages, tools, 32768) {
+ t.Error("realistic session should be within 131072 context window")
+ }
+
+ // With a tiny context window, should exceed budget.
+ if !isOverContextBudget(500, messages, tools, 32768) {
+ t.Error("realistic session should exceed 500 context window")
+ }
+}
From efd403242e8633dfbdf6b3a2c02840adfae338d1 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 15:50:51 +0800
Subject: [PATCH 06/26] fix(agent): preallocate messages slice in budget test
Fixes prealloc lint warning by using make() with capacity hint.
---
pkg/agent/context_budget_test.go | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 6b51a8cb7..4073506cf 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -692,7 +692,8 @@ func TestIsOverContextBudget_RealisticSession(t *testing.T) {
currentUser := msgUser("follow up question")
// Assemble as BuildMessages would.
- messages := []providers.Message{systemMsg}
+ messages := make([]providers.Message, 0, 1+len(sessionHistory)+1)
+ messages = append(messages, systemMsg)
messages = append(messages, sessionHistory...)
messages = append(messages, currentUser)
From 639739cb8512e7b3610015265f30197dbe421096 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 15:54:50 +0800
Subject: [PATCH 07/26] refactor(agent): use Turn as the atomic unit for
compression cut-off
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Introduce parseTurnBoundaries() which identifies each Turn start index
in the session history. A Turn is a complete "user input → LLM iterations
→ final response" cycle (as defined in the agent refactor design #1316).
findSafeBoundary now uses Turn boundaries instead of raw role-scanning,
making the intent explicit: "find the nearest Turn boundary."
forceCompression drops the oldest half of Turns (not arbitrary messages),
which is simpler and more intuitive. The Turn-based approach naturally
prevents splitting tool-call sequences since each Turn is atomic.
---
pkg/agent/context_budget.go | 58 ++++++++++++++--------
pkg/agent/context_budget_test.go | 82 ++++++++++++++++++++++++++++++++
pkg/agent/loop.go | 20 ++++++--
3 files changed, 136 insertions(+), 24 deletions(-)
diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go
index 71da5d8f7..05e27e18a 100644
--- a/pkg/agent/context_budget.go
+++ b/pkg/agent/context_budget.go
@@ -12,14 +12,26 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
)
-// isSafeBoundary reports whether index is a valid position to split a message
-// history for truncation or compression. Splitting at index means:
-// - history[:index] is dropped or summarized
-// - history[index:] is kept
+// parseTurnBoundaries returns the starting index of each Turn in the history.
+// A Turn is a complete "user input → LLM iterations → final response" cycle
+// (as defined in #1316). Each Turn begins at a user message and extends
+// through all subsequent assistant/tool messages until the next user message.
//
-// A boundary is safe when the kept portion begins at a "user" message,
-// ensuring no tool-call sequence (assistant+ToolCalls → tool results)
-// is torn apart across the split.
+// Cutting at a Turn boundary guarantees that no tool-call sequence
+// (assistant+ToolCalls → tool results) is split across the cut.
+func parseTurnBoundaries(history []providers.Message) []int {
+ var starts []int
+ for i, msg := range history {
+ if msg.Role == "user" {
+ starts = append(starts, i)
+ }
+ }
+ return starts
+}
+
+// isSafeBoundary reports whether index is a valid Turn boundary — i.e.,
+// a position where the kept portion (history[index:]) begins at a user
+// message, so no tool-call sequence is torn apart.
func isSafeBoundary(history []providers.Message, index int) bool {
if index <= 0 || index >= len(history) {
return true
@@ -27,9 +39,10 @@ func isSafeBoundary(history []providers.Message, index int) bool {
return history[index].Role == "user"
}
-// findSafeBoundary locates the nearest safe split point to targetIndex.
-// It scans backward first (preserving more context), then forward.
-// Returns targetIndex unchanged only when no safe boundary exists.
+// findSafeBoundary locates the nearest Turn boundary to targetIndex.
+// It prefers the boundary at or before targetIndex (preserving more recent
+// context). Falls back to the nearest boundary after targetIndex, and
+// returns targetIndex unchanged only when no Turn boundary exists at all.
func findSafeBoundary(history []providers.Message, targetIndex int) int {
if len(history) == 0 {
return 0
@@ -41,21 +54,28 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
return len(history)
}
- if isSafeBoundary(history, targetIndex) {
+ turns := parseTurnBoundaries(history)
+ if len(turns) == 0 {
return targetIndex
}
- // Backward scan: prefer keeping more messages.
- for i := targetIndex - 1; i > 0; i-- {
- if isSafeBoundary(history, i) {
- return i
+ // Find the last Turn boundary at or before targetIndex.
+ // Prefer backward: keeps more recent messages.
+ backward := -1
+ for _, t := range turns {
+ if t <= targetIndex {
+ backward = t
}
}
+ if backward > 0 {
+ return backward
+ }
- // Forward scan: fall back to keeping fewer messages.
- for i := targetIndex + 1; i < len(history); i++ {
- if isSafeBoundary(history, i) {
- return i
+ // No valid Turn boundary before target (or only at index 0 which
+ // would keep everything). Use the first Turn after targetIndex.
+ for _, t := range turns {
+ if t > targetIndex {
+ return t
}
}
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 4073506cf..15198d03b 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -40,6 +40,88 @@ func msgTool(callID, content string) providers.Message {
return providers.Message{Role: "tool", ToolCallID: callID, Content: content}
}
+func TestParseTurnBoundaries(t *testing.T) {
+ tests := []struct {
+ name string
+ history []providers.Message
+ want []int
+ }{
+ {
+ name: "empty history",
+ history: nil,
+ want: nil,
+ },
+ {
+ name: "simple exchange",
+ history: []providers.Message{
+ msgUser("q1"),
+ msgAssistant("a1"),
+ msgUser("q2"),
+ msgAssistant("a2"),
+ },
+ want: []int{0, 2},
+ },
+ {
+ name: "tool-call Turn",
+ history: []providers.Message{
+ msgUser("search"),
+ msgAssistantTC("tc1"),
+ msgTool("tc1", "result"),
+ msgAssistant("found it"),
+ msgUser("thanks"),
+ msgAssistant("welcome"),
+ },
+ want: []int{0, 4},
+ },
+ {
+ name: "chained tool calls in single Turn",
+ history: []providers.Message{
+ msgUser("save and notify"),
+ msgAssistantTC("tc_save"),
+ msgTool("tc_save", "saved"),
+ msgAssistantTC("tc_notify"),
+ msgTool("tc_notify", "notified"),
+ msgAssistant("done"),
+ },
+ want: []int{0},
+ },
+ {
+ name: "no user messages",
+ history: []providers.Message{
+ msgAssistant("a1"),
+ msgAssistant("a2"),
+ },
+ want: nil,
+ },
+ {
+ name: "leading non-user messages",
+ history: []providers.Message{
+ msgAssistantTC("tc1"),
+ msgTool("tc1", "r1"),
+ msgAssistant("greeting"),
+ msgUser("hello"),
+ msgAssistant("hi"),
+ },
+ want: []int{3},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := parseTurnBoundaries(tt.history)
+ if len(got) != len(tt.want) {
+ t.Errorf("parseTurnBoundaries() = %v, want %v", got, tt.want)
+ return
+ }
+ for i := range got {
+ if got[i] != tt.want[i] {
+ t.Errorf("parseTurnBoundaries()[%d] = %d, want %d", i, got[i], tt.want[i])
+ }
+ }
+ })
+ }
+}
+
func TestIsSafeBoundary(t *testing.T) {
tests := []struct {
name string
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 14dc8c5ca..688d0ed1d 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -1556,8 +1556,8 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
}
// forceCompression aggressively reduces context when the limit is hit.
-// It drops the oldest ~50% of messages, aligning the split to a safe
-// boundary so tool-call sequences stay intact.
+// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
+// cycle, as defined in #1316), so tool-call sequences are never split.
//
// Session history contains only user/assistant/tool messages — the system
// prompt is built dynamically by BuildMessages and is NOT stored here.
@@ -1569,8 +1569,18 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
return
}
- // Find a safe mid-point that does not split a tool-call sequence.
- mid := findSafeBoundary(history, len(history)/2)
+ // Split at a Turn boundary so no tool-call sequence is torn apart.
+ // parseTurnBoundaries gives us the start of each Turn; we drop the
+ // oldest half of Turns and keep the most recent ones.
+ turns := parseTurnBoundaries(history)
+ var mid int
+ if len(turns) >= 2 {
+ mid = turns[len(turns)/2]
+ } else {
+ // Fewer than 2 Turns — fall back to message-level midpoint
+ // aligned to the nearest Turn boundary.
+ mid = findSafeBoundary(history, len(history)/2)
+ }
if mid <= 0 {
return
}
@@ -1696,7 +1706,7 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
history := agent.Sessions.GetHistory(sessionKey)
summary := agent.Sessions.GetSummary(sessionKey)
- // Keep last few messages for continuity, aligned to a safe boundary
+ // Keep the most recent Turns for continuity, aligned to a Turn boundary
// so that no tool-call sequence is split.
if len(history) <= 4 {
return
From 8034ee7be13f891dd1e578390cad9bf09dbfa5e2 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 16:02:04 +0800
Subject: [PATCH 08/26] fix(agent): correct media token arithmetic and tool
call double-counting
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Two estimation bugs fixed:
1. Media tokens were added to the chars accumulator before the chars*2/5
conversion, resulting in 256*2/5=102 tokens per item instead of 256.
Fix: add media tokens directly to the final token count, bypassing
the character-based heuristic.
2. estimateMessageTokens counted both tc.Name and tc.Function.Name for
tool calls, but providers only send one (OpenAI-compat uses
function.name, Anthropic uses tc.Name). Fix: count tc.Function.Name
when Function is present, fall back to tc.Name only otherwise.
Also fix i18n hint text: "auto-detect" was misleading — the backend
uses a 4x max_tokens heuristic, not actual model detection.
---
pkg/agent/context_budget.go | 25 ++++++++++++++++---------
pkg/agent/context_budget_test.go | 7 +++++++
web/frontend/src/i18n/locales/en.json | 2 +-
web/frontend/src/i18n/locales/zh.json | 2 +-
4 files changed, 25 insertions(+), 11 deletions(-)
diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go
index 05e27e18a..0b7f443e6 100644
--- a/pkg/agent/context_budget.go
+++ b/pkg/agent/context_budget.go
@@ -95,10 +95,14 @@ func estimateMessageTokens(msg providers.Message) int {
}
for _, tc := range msg.ToolCalls {
- // Count tool call metadata: ID, type, function name
- chars += len(tc.ID) + len(tc.Type) + len(tc.Name)
+ chars += len(tc.ID) + len(tc.Type)
if tc.Function != nil {
+ // Count function name + arguments (the wire format for most providers).
+ // tc.Name mirrors tc.Function.Name — count only once to avoid double-counting.
chars += len(tc.Function.Name) + len(tc.Function.Arguments)
+ } else {
+ // Fallback: some provider formats use top-level Name without Function.
+ chars += len(tc.Name)
}
}
@@ -106,17 +110,20 @@ func estimateMessageTokens(msg providers.Message) int {
chars += len(msg.ToolCallID)
}
- // Media items (images, files) are serialized by provider adapters into
- // multipart or image_url payloads. Use a fixed per-item estimate since
- // actual token cost depends on resolution and provider tokenization.
- const mediaTokensPerItem = 256
- chars += len(msg.Media) * mediaTokensPerItem
-
// Per-message overhead for role label, JSON structure, separators.
const messageOverhead = 12
chars += messageOverhead
- return chars * 2 / 5
+ tokens := chars * 2 / 5
+
+ // Media items (images, files) are serialized by provider adapters into
+ // multipart or image_url payloads. Add a fixed per-item token estimate
+ // directly (not through the chars heuristic) since actual cost depends
+ // on resolution and provider-specific image tokenization.
+ const mediaTokensPerItem = 256
+ tokens += len(msg.Media) * mediaTokensPerItem
+
+ return tokens
}
// estimateToolDefsTokens estimates the total token cost of tool definitions
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 15198d03b..175e04885 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -503,6 +503,13 @@ func TestEstimateMessageTokens_MediaItems(t *testing.T) {
t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
mediaTokens, plainTokens)
}
+
+ // Each media item should add exactly 256 tokens (not run through chars*2/5).
+ expectedDelta := 256 * 2
+ actualDelta := mediaTokens - plainTokens
+ if actualDelta != expectedDelta {
+ t.Errorf("2 media items should add %d tokens, got delta %d", expectedDelta, actualDelta)
+ }
}
// --- estimateToolDefsTokens tests ---
diff --git a/web/frontend/src/i18n/locales/en.json b/web/frontend/src/i18n/locales/en.json
index 116ee4441..09852e0c7 100644
--- a/web/frontend/src/i18n/locales/en.json
+++ b/web/frontend/src/i18n/locales/en.json
@@ -397,7 +397,7 @@
"max_tokens": "Max Tokens",
"max_tokens_hint": "Upper token limit per model response.",
"context_window": "Context Window",
- "context_window_hint": "Model input context capacity in tokens. Leave empty to auto-detect (default: 4x max tokens).",
+ "context_window_hint": "Model input context capacity in tokens. Leave empty to use the default (4x max tokens).",
"max_tool_iterations": "Max Tool Iterations",
"max_tool_iterations_hint": "Maximum tool-call loops in a single task.",
"summarize_threshold": "Summarize Message Threshold",
diff --git a/web/frontend/src/i18n/locales/zh.json b/web/frontend/src/i18n/locales/zh.json
index e68c46085..c92ea0032 100644
--- a/web/frontend/src/i18n/locales/zh.json
+++ b/web/frontend/src/i18n/locales/zh.json
@@ -397,7 +397,7 @@
"max_tokens": "最大 Token 数",
"max_tokens_hint": "单次模型响应允许的最大 Token 数。",
"context_window": "上下文窗口",
- "context_window_hint": "模型输入上下文容量(Token 数)。留空则自动推算(默认为最大 Token 数的 4 倍)。",
+ "context_window_hint": "模型输入上下文容量(Token 数)。留空使用默认值(最大 Token 数的 4 倍)。",
"max_tool_iterations": "最大工具迭代次数",
"max_tool_iterations_hint": "单个任务中允许的工具调用循环上限。",
"summarize_threshold": "触发摘要的消息阈值",
From edbdc3bcf106a60540348f01baa45d39a6627e00 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 16:25:27 +0800
Subject: [PATCH 09/26] fix(agent): findSafeBoundary returns 0 for single-Turn
history
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
When the entire history is a single Turn (one user message followed by
tool calls and responses, no subsequent user message), the only Turn
boundary is at index 0. Previously the fallback returned targetIndex,
which could land on a tool or assistant message — splitting the Turn.
Return 0 instead, so callers (forceCompression, summarizeSession) see
mid <= 0 and skip compression rather than cutting inside the Turn.
---
pkg/agent/context_budget.go | 6 +++++-
pkg/agent/context_budget_test.go | 17 +++++++++++++++++
2 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/pkg/agent/context_budget.go b/pkg/agent/context_budget.go
index 0b7f443e6..c87695c7a 100644
--- a/pkg/agent/context_budget.go
+++ b/pkg/agent/context_budget.go
@@ -79,7 +79,11 @@ func findSafeBoundary(history []providers.Message, targetIndex int) int {
}
}
- return targetIndex
+ // No Turn boundary after targetIndex either. The only boundary is at
+ // index 0, meaning the entire history is a single Turn. Return 0 to
+ // signal that safe compression is not possible — callers check for
+ // mid <= 0 and skip compression in that case.
+ return 0
}
// estimateMessageTokens estimates the token count for a single message,
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 175e04885..30b3fe6a2 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -346,6 +346,23 @@ func TestFindSafeBoundary(t *testing.T) {
}
}
+func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) {
+ // A single Turn with no subsequent user message. The only Turn boundary
+ // is at index 0; cutting anywhere else would split the Turn's tool
+ // sequence. findSafeBoundary must return 0 so callers skip compression.
+ history := []providers.Message{
+ msgUser("do everything"), // 0 ← only Turn boundary
+ msgAssistantTC("tc1"), // 1
+ msgTool("tc1", "result"), // 2
+ msgAssistant("all done"), // 3
+ }
+
+ got := findSafeBoundary(history, 2)
+ if got != 0 {
+ t.Errorf("findSafeBoundary(single_turn, 2) = %d, want 0 (cannot split single Turn)", got)
+ }
+}
+
func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
// A long tool-call chain: user → assistant+TC → tool → tool → ... → assistant → user
// Target is inside the chain; boundary should skip the entire chain backward.
From 7c1a1c2c1a8554d29c11903103d231962ffdac4f Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 16:30:26 +0800
Subject: [PATCH 10/26] style(agent): fix gci comment alignment in test
---
pkg/agent/context_budget_test.go | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/pkg/agent/context_budget_test.go b/pkg/agent/context_budget_test.go
index 30b3fe6a2..870f0fbe6 100644
--- a/pkg/agent/context_budget_test.go
+++ b/pkg/agent/context_budget_test.go
@@ -351,10 +351,10 @@ func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) {
// is at index 0; cutting anywhere else would split the Turn's tool
// sequence. findSafeBoundary must return 0 so callers skip compression.
history := []providers.Message{
- msgUser("do everything"), // 0 ← only Turn boundary
- msgAssistantTC("tc1"), // 1
- msgTool("tc1", "result"), // 2
- msgAssistant("all done"), // 3
+ msgUser("do everything"), // 0 ← only Turn boundary
+ msgAssistantTC("tc1"), // 1
+ msgTool("tc1", "result"), // 2
+ msgAssistant("all done"), // 3
}
got := findSafeBoundary(history, 2)
From b768dab822bee2affa417d7318e68b8e9eec31b3 Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Fri, 13 Mar 2026 17:04:34 +0800
Subject: [PATCH 11/26] test(agent): use realistic session data in context
retry test
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Session history only stores user/assistant/tool messages — the system
prompt is built dynamically by BuildMessages. Remove the incorrect
system message from TestAgentLoop_ContextExhaustionRetry test data
to match the real data model that forceCompression operates on.
---
pkg/agent/loop_test.go | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go
index a6604e87f..b65c0e21c 100644
--- a/pkg/agent/loop_test.go
+++ b/pkg/agent/loop_test.go
@@ -719,11 +719,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
- // Inject some history to simulate a full context
+ // Inject some history to simulate a full context.
+ // Session history only stores user/assistant/tool messages — the system
+ // prompt is built dynamically by BuildMessages and is NOT stored here.
sessionKey := "test-session-context"
- // Create dummy history
history := []providers.Message{
- {Role: "system", Content: "System prompt"},
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
@@ -761,12 +761,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
// Check final history length
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
- // Original length: 6
- // Expected behavior: compression drops ~50% of history (mid slice)
- // We can assert that the length is NOT what it would be without compression.
- // Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
- if len(finalHistory) >= 8 {
- t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
+ // Original length: 5
+ // Expected behavior: compression drops ~50% of Turns
+ // Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7
+ if len(finalHistory) >= 7 {
+ t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory))
}
}
From 08259d7e9a1bf7675e52c0344f8570faad628d0d Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Sat, 14 Mar 2026 10:46:32 +0800
Subject: [PATCH 12/26] docs(agent-refactor): add context.md for Track 6
boundary clarification
Document the semantic boundaries of context management as called for
in the agent-refactor README (suggested document split, item 5):
- context window region definitions and history budget formula
- ContextWindow vs MaxTokens distinction
- session history contents (no system prompt stored)
- Turn as the atomic compression unit (#1316)
- three compression paths and their ordering
- token estimation approach and its limitations
- interface boundaries between budget functions and BuildMessages
Also documents known gaps: summarization trigger not using the full
budget formula, heuristic-only token estimation, and reactive retry
not preserving media references.
Ref #1439
---
docs/agent-refactor/context.md | 162 +++++++++++++++++++++++++++++++++
1 file changed, 162 insertions(+)
create mode 100644 docs/agent-refactor/context.md
diff --git a/docs/agent-refactor/context.md b/docs/agent-refactor/context.md
new file mode 100644
index 000000000..785fae2be
--- /dev/null
+++ b/docs/agent-refactor/context.md
@@ -0,0 +1,162 @@
+# Context
+
+## What this document covers
+
+This document makes explicit the boundaries of context management in the agent loop:
+
+- what fills the context window and how space is divided
+- what is stored in session history vs. built at request time
+- when and how context compression happens
+- how token budgets are estimated
+
+These are existing concepts. This document clarifies their boundaries rather than introducing new ones.
+
+---
+
+## Context window regions
+
+The context window is the model's total input capacity. Four regions fill it:
+
+| Region | Assembled by | Stored in session? |
+|---|---|---|
+| System prompt | `BuildMessages()` — static + dynamic parts | No |
+| Summary | `SetSummary()` stores it; `BuildMessages()` injects it | Separate from history |
+| Session history | User / assistant / tool messages | Yes |
+| Tool definitions | Provider adapter injects at call time | No |
+
+`MaxTokens` (the output generation limit) must also be reserved from the total budget.
+
+The available space for history is therefore:
+
+```
+history_budget = ContextWindow - system_prompt - summary - tool_definitions - MaxTokens
+```
+
+---
+
+## ContextWindow vs MaxTokens
+
+These serve different purposes:
+
+- **MaxTokens** — maximum tokens the LLM may generate in one response. Sent as the `max_tokens` request parameter.
+- **ContextWindow** — the model's total input context capacity.
+
+These were previously set to the same value, which caused the summarization threshold to fire either far too early (at the default 32K) or not at all (when a user raised `max_tokens`).
+
+Current default when not explicitly configured: `ContextWindow = MaxTokens * 4`.
+
+---
+
+## Session history
+
+Session history stores only conversation messages:
+
+- `user` — user input
+- `assistant` — LLM response (may include `ToolCalls`)
+- `tool` — tool execution results
+
+Session history does **not** contain:
+
+- System prompts — assembled at request time by `BuildMessages`
+- Summary content — stored separately via `SetSummary`, injected by `BuildMessages`
+
+This distinction matters: any code that operates on session history — compression, boundary detection, token estimation — must not assume a system message is present.
+
+---
+
+## Turn
+
+A **Turn** is one complete cycle:
+
+> user message -> LLM iterations (possibly including tool calls) -> final assistant response
+
+This definition comes from the agent loop design (#1316). In session history, Turn boundaries are identified by `user`-role messages.
+
+Turn is the atomic unit for compression. Cutting inside a Turn can orphan tool-call sequences — an assistant message with `ToolCalls` separated from its corresponding `tool` results. Compressing at Turn boundaries avoids this by construction.
+
+`parseTurnBoundaries(history)` returns the starting index of each Turn.
+`findSafeBoundary(history, targetIndex)` snaps a target cut point to the nearest Turn boundary.
+
+---
+
+## Compression paths
+
+Three compression paths exist, in order of preference:
+
+### 1. Async summarization
+
+`maybeSummarize` runs after each Turn completes.
+
+Triggers when message count exceeds a threshold, or when estimated history tokens exceed a percentage of `ContextWindow`. If triggered, a background goroutine calls the LLM to produce a summary of the oldest messages. The summary is stored via `SetSummary`; `BuildMessages` injects it into the system prompt on the next call.
+
+Cut point uses `findSafeBoundary` so no Turn is split.
+
+### 2. Proactive budget check
+
+`isOverContextBudget` runs before each LLM call.
+
+Uses the full budget formula: `message_tokens + tool_def_tokens + MaxTokens > ContextWindow`. If over budget, triggers `forceCompression` and rebuilds messages before calling the LLM.
+
+This prevents wasted (and billed) LLM calls that would otherwise fail with a context-window error.
+
+### 3. Emergency compression (reactive)
+
+`forceCompression` runs when the LLM returns a context-window error despite the proactive check.
+
+Drops the oldest ~50% of Turns. Stores a compression note in the session summary (not in history messages) so `BuildMessages` can include it in the next system prompt.
+
+This is the fallback for when the token estimate undershoots reality.
+
+---
+
+## Token estimation
+
+Estimation uses a heuristic of ~2.5 characters per token (`chars * 2 / 5`).
+
+`estimateMessageTokens` counts:
+
+- `Content` (rune count, for multibyte correctness)
+- `ReasoningContent` (extended thinking / chain-of-thought)
+- `ToolCalls` — ID, type, function name, arguments
+- `ToolCallID` (tool result metadata)
+- Per-message overhead (role label, JSON structure)
+- `Media` items — flat per-item token estimate, added directly to the final count (not through the character heuristic, since actual cost depends on resolution and provider-specific image tokenization)
+
+`estimateToolDefsTokens` counts tool definition overhead: name, description, JSON schema of parameters.
+
+These are deliberately heuristic. The proactive check handles the common case; the reactive path catches estimation errors.
+
+---
+
+## Interface boundaries
+
+Context budget functions (`parseTurnBoundaries`, `findSafeBoundary`, `estimateMessageTokens`, `isOverContextBudget`) are **pure functions**. They take `[]providers.Message` and integer parameters. They have no dependency on `AgentLoop` or any other runtime struct.
+
+`BuildMessages` is the sole assembler of the final message array sent to the LLM. Budget functions inform compression decisions but do not construct messages.
+
+`forceCompression` and `summarizeSession` mutate session state (history and summary). `BuildMessages` reads that state to construct context. The flow is:
+
+```
+budget check --> compression decision --> mutate session --> BuildMessages reads session --> LLM call
+```
+
+---
+
+## Known gaps
+
+These are recognized limitations in the current implementation, documented here for visibility:
+
+- **Summarization trigger does not use the full budget formula.** `maybeSummarize` compares estimated history tokens against a percentage of `ContextWindow`. It does not account for system prompt size, tool definition overhead, or `MaxTokens` reserve. The proactive check covers the critical path (preventing 400 errors), but the summarization trigger could be aligned with the same budget model for more accurate early compression.
+
+- **Token estimation is heuristic.** It does not account for provider-specific tokenization, exact system prompt size (assembled separately), or variable image token costs. The two-path design (proactive + reactive) is intended to tolerate this imprecision.
+
+- **Reactive retry does not preserve media.** When the reactive path rebuilds context after compression, it currently passes empty values for media references. This is a pre-existing issue in the main loop, not introduced by the budget system.
+
+---
+
+## What this document does not cover
+
+- How `AGENT.md` frontmatter configures context parameters — that is part of the Agent definition work
+- How the context builder assembles context in the new architecture — that is upcoming work
+- How compression events surface through the event system — that is part of the event model (#1316)
+- Subagent context isolation — that is a separate track
From c63c6449b4a3a9fbe15fb2a269eddddc8817084f Mon Sep 17 00:00:00 2001
From: xiaoen <2768753269@qq.com>
Date: Tue, 17 Mar 2026 10:23:16 +0800
Subject: [PATCH 13/26] fix(agent): forceCompression recovers from single
oversized Turn
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
When the entire session history is a single Turn (e.g. one user message
followed by a massive tool response), findSafeBoundary returns 0 and
forceCompression previously did nothing — leaving the agent stuck in
a context-exceeded retry loop.
Now falls back to keeping only the most recent user message when no
safe Turn boundary exists. This breaks Turn atomicity as a last resort
but guarantees the agent can recover.
Also updates docs/agent-refactor/context.md to document this behavior.
Ref #1490
---
docs/agent-refactor/context.md | 4 +++-
pkg/agent/loop.go | 22 +++++++++++++++++++---
2 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/docs/agent-refactor/context.md b/docs/agent-refactor/context.md
index 785fae2be..2269d9258 100644
--- a/docs/agent-refactor/context.md
+++ b/docs/agent-refactor/context.md
@@ -103,7 +103,9 @@ This prevents wasted (and billed) LLM calls that would otherwise fail with a con
`forceCompression` runs when the LLM returns a context-window error despite the proactive check.
-Drops the oldest ~50% of Turns. Stores a compression note in the session summary (not in history messages) so `BuildMessages` can include it in the next system prompt.
+Drops the oldest ~50% of Turns. If the history is a single Turn with no safe split point (e.g. one user message followed by a massive tool response), falls back to keeping only the most recent user message — breaking Turn atomicity as a last resort to avoid a context-exceeded loop.
+
+Stores a compression note in the session summary (not in history messages) so `BuildMessages` can include it in the next system prompt.
This is the fallback for when the token estimate undershoots reality.
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 688d0ed1d..c583f5ca5 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -1559,6 +1559,10 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
// cycle, as defined in #1316), so tool-call sequences are never split.
//
+// If the history is a single Turn with no safe split point, the function
+// falls back to keeping only the most recent user message. This breaks
+// Turn atomicity as a last resort to avoid a context-exceeded loop.
+//
// Session history contains only user/assistant/tool messages — the system
// prompt is built dynamically by BuildMessages and is NOT stored here.
// The compression note is recorded in the session summary so that
@@ -1581,12 +1585,24 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
// aligned to the nearest Turn boundary.
mid = findSafeBoundary(history, len(history)/2)
}
+ var keptHistory []providers.Message
if mid <= 0 {
- return
+ // No safe Turn boundary — the entire history is a single Turn
+ // (e.g. one user message followed by a massive tool response).
+ // Keeping everything would leave the agent stuck in a context-
+ // exceeded loop, so fall back to keeping only the most recent
+ // user message. This breaks Turn atomicity as a last resort.
+ for i := len(history) - 1; i >= 0; i-- {
+ if history[i].Role == "user" {
+ keptHistory = []providers.Message{history[i]}
+ break
+ }
+ }
+ } else {
+ keptHistory = history[mid:]
}
- droppedCount := mid
- keptHistory := history[mid:]
+ droppedCount := len(history) - len(keptHistory)
// Record compression in the session summary so BuildMessages includes it
// in the system prompt. We do not modify history messages themselves.
From 899558bbfaf89414696070d240b6718628c93c52 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BE=8E=E9=9B=BB=E7=90=83?=
Date: Wed, 18 Mar 2026 22:42:57 +0800
Subject: [PATCH 14/26] Feat/issue 1218 agent md context structure (#1705)
* feat(agent): add structured agent definition loader
Parse AGENT.md frontmatter into a runtime definition and pair it with SOUL.md while keeping a legacy AGENTS.md fallback for transition.
Refs #1218
* refactor(agent): build context from structured agent files
Use AGENT.md and SOUL.md as the structured bootstrap source, ignore IDENTITY.md for structured agents, remove USER.md from the new context flow, and update pkg/agent tests accordingly.
Refs #1218
* refactor(onboard): switch workspace templates to AGENT.md
Replace the legacy AGENTS.md, IDENTITY.md, and USER.md templates with a structured AGENT.md plus SOUL.md, and update the onboard template test to assert the new generated files.
Refs #1218
* docs(readme): update workspace layout for AGENT.md
Refresh the documented workspace tree across the README translations so onboarding now points to AGENT.md and SOUL.md instead of the retired AGENTS.md, IDENTITY.md, and USER.md files.
Refs #1218
* feat(agent): restore workspace USER.md context
* docs(readme): document workspace USER.md layout
* fix: sort agent definition imports for gci
---
README.fr.md | 5 +-
README.ja.md | 5 +-
README.md | 18 +-
README.pt-br.md | 5 +-
README.vi.md | 5 +-
README.zh.md | 18 +-
cmd/picoclaw/internal/onboard/helpers_test.go | 26 +-
pkg/agent/context.go | 43 ++-
pkg/agent/context_cache_test.go | 20 +-
pkg/agent/definition.go | 255 +++++++++++++++
pkg/agent/definition_test.go | 302 ++++++++++++++++++
workspace/AGENT.md | 45 +++
workspace/AGENTS.md | 12 -
workspace/IDENTITY.md | 53 ---
workspace/SOUL.md | 6 +-
workspace/USER.md | 4 +-
16 files changed, 690 insertions(+), 132 deletions(-)
create mode 100644 pkg/agent/definition.go
create mode 100644 pkg/agent/definition_test.go
create mode 100644 workspace/AGENT.md
delete mode 100644 workspace/AGENTS.md
delete mode 100644 workspace/IDENTITY.md
diff --git a/README.fr.md b/README.fr.md
index d5fe873bf..97dabe125 100644
--- a/README.fr.md
+++ b/README.fr.md
@@ -653,11 +653,10 @@ PicoClaw stocke les données dans votre workspace configuré (par défaut : `~/.
├── state/ # État persistant (dernier canal, etc.)
├── cron/ # Base de données des tâches planifiées
├── skills/ # Compétences personnalisées
-├── AGENTS.md # Guide de comportement de l'Agent
+├── AGENT.md # Définition structurée de l'agent et prompt système
├── HEARTBEAT.md # Invites de tâches périodiques (vérifiées toutes les 30 min)
-├── IDENTITY.md # Identité de l'Agent
├── SOUL.md # Âme de l'Agent
-└── USER.md # Préférences utilisateur
+└── ...
```
### 🔒 Bac à Sable de Sécurité
diff --git a/README.ja.md b/README.ja.md
index 7fff46d13..3f43e29ad 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -617,11 +617,10 @@ PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw
├── state/ # 永続状態(最後のチャネルなど)
├── cron/ # スケジュールジョブデータベース
├── skills/ # カスタムスキル
-├── AGENTS.md # エージェントの行動ガイド
+├── AGENT.md # 構造化されたエージェント定義とシステムプロンプト
├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認)
-├── IDENTITY.md # エージェントのアイデンティティ
├── SOUL.md # エージェントのソウル
-└── USER.md # ユーザー設定
+└── ...
```
### 🔒 セキュリティサンドボックス
diff --git a/README.md b/README.md
index e64daf0e4..75ad7255a 100644
--- a/README.md
+++ b/README.md
@@ -784,15 +784,15 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
```
~/.picoclaw/workspace/
├── sessions/ # Conversation sessions and history
-├── memory/ # Long-term memory (MEMORY.md)
-├── state/ # Persistent state (last channel, etc.)
-├── cron/ # Scheduled jobs database
-├── skills/ # Custom skills
-├── AGENTS.md # Agent behavior guide
-├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
-├── IDENTITY.md # Agent identity
-├── SOUL.md # Agent soul
-└── USER.md # User preferences
+├── memory/ # Long-term memory (MEMORY.md)
+├── state/ # Persistent state (last channel, etc.)
+├── cron/ # Scheduled jobs database
+├── skills/ # Workspace-specific skills
+├── AGENT.md # Structured agent definition and system prompt
+├── SOUL.md # Agent soul
+├── USER.md # User profile and preferences for this workspace
+├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
+└── ...
```
### Skill Sources
diff --git a/README.pt-br.md b/README.pt-br.md
index 3fe24d7ea..fab8b8b0f 100644
--- a/README.pt-br.md
+++ b/README.pt-br.md
@@ -649,11 +649,10 @@ O PicoClaw armazena dados no workspace configurado (padrão: `~/.picoclaw/worksp
├── state/ # Estado persistente (ultimo canal, etc.)
├── cron/ # Banco de dados de tarefas agendadas
├── skills/ # Skills personalizadas
-├── AGENTS.md # Guia de comportamento do Agente
+├── AGENT.md # Definicao estruturada do agente e prompt do sistema
├── HEARTBEAT.md # Prompts de tarefas periodicas (verificado a cada 30 min)
-├── IDENTITY.md # Identidade do Agente
├── SOUL.md # Alma do Agente
-└── USER.md # Preferencias do usuario
+└── ...
```
### 🔒 Sandbox de Segurança
diff --git a/README.vi.md b/README.vi.md
index 3ee0209f6..337e3d68a 100644
--- a/README.vi.md
+++ b/README.vi.md
@@ -621,11 +621,10 @@ PicoClaw lưu trữ dữ liệu trong workspace đã cấu hình (mặc định:
├── state/ # Trạng thái lưu trữ (kênh cuối cùng, v.v.)
├── cron/ # Cơ sở dữ liệu tác vụ định kỳ
├── skills/ # Kỹ năng tùy chỉnh
-├── AGENTS.md # Hướng dẫn hành vi Agent
+├── AGENT.md # Định nghĩa agent có cấu trúc và system prompt
├── HEARTBEAT.md # Prompt tác vụ định kỳ (kiểm tra mỗi 30 phút)
-├── IDENTITY.md # Danh tính Agent
├── SOUL.md # Tâm hồn/Tính cách Agent
-└── USER.md # Tùy chọn người dùng
+└── ...
```
### 🔒 Hộp cát bảo mật (Security Sandbox)
diff --git a/README.zh.md b/README.zh.md
index 66d7c5f7c..aba133eef 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -365,15 +365,15 @@ PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/work
```
~/.picoclaw/workspace/
├── sessions/ # 对话会话和历史
-├── memory/ # 长期记忆 (MEMORY.md)
-├── state/ # 持久化状态 (最后一次频道等)
-├── cron/ # 定时任务数据库
-├── skills/ # 自定义技能
-├── AGENTS.md # Agent 行为指南
-├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次)
-├── IDENTITY.md # Agent 身份设定
-├── SOUL.md # Agent 灵魂/性格
-└── USER.md # 用户偏好
+├── memory/ # 长期记忆 (MEMORY.md)
+├── state/ # 持久化状态 (最后一次频道等)
+├── cron/ # 定时任务数据库
+├── skills/ # 工作区级技能
+├── AGENT.md # 结构化 Agent 定义与系统提示词
+├── SOUL.md # Agent 灵魂/性格
+├── USER.md # 当前工作区的用户资料与偏好
+├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次)
+└── ...
```
diff --git a/cmd/picoclaw/internal/onboard/helpers_test.go b/cmd/picoclaw/internal/onboard/helpers_test.go
index f3e0c92e0..23fc97c5a 100644
--- a/cmd/picoclaw/internal/onboard/helpers_test.go
+++ b/cmd/picoclaw/internal/onboard/helpers_test.go
@@ -6,20 +6,32 @@ import (
"testing"
)
-func TestCopyEmbeddedToTargetUsesAgentsMarkdown(t *testing.T) {
+func TestCopyEmbeddedToTargetUsesStructuredAgentFiles(t *testing.T) {
targetDir := t.TempDir()
if err := copyEmbeddedToTarget(targetDir); err != nil {
t.Fatalf("copyEmbeddedToTarget() error = %v", err)
}
- agentsPath := filepath.Join(targetDir, "AGENTS.md")
- if _, err := os.Stat(agentsPath); err != nil {
- t.Fatalf("expected %s to exist: %v", agentsPath, err)
+ agentPath := filepath.Join(targetDir, "AGENT.md")
+ if _, err := os.Stat(agentPath); err != nil {
+ t.Fatalf("expected %s to exist: %v", agentPath, err)
}
- legacyPath := filepath.Join(targetDir, "AGENT.md")
- if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
- t.Fatalf("expected legacy file %s to be absent, got err=%v", legacyPath, err)
+ soulPath := filepath.Join(targetDir, "SOUL.md")
+ if _, err := os.Stat(soulPath); err != nil {
+ t.Fatalf("expected %s to exist: %v", soulPath, err)
+ }
+
+ userPath := filepath.Join(targetDir, "USER.md")
+ if _, err := os.Stat(userPath); err != nil {
+ t.Fatalf("expected %s to exist: %v", userPath, err)
+ }
+
+ for _, legacyName := range []string{"AGENTS.md", "IDENTITY.md"} {
+ legacyPath := filepath.Join(targetDir, legacyName)
+ if _, err := os.Stat(legacyPath); !os.IsNotExist(err) {
+ t.Fatalf("expected legacy file %s to be absent, got err=%v", legacyPath, err)
+ }
}
}
diff --git a/pkg/agent/context.go b/pkg/agent/context.go
index 5a84c45e2..cb566f02b 100644
--- a/pkg/agent/context.go
+++ b/pkg/agent/context.go
@@ -222,13 +222,10 @@ func (cb *ContextBuilder) InvalidateCache() {
// invalidation (bootstrap files + memory). Skill roots are handled separately
// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
- return []string{
- filepath.Join(cb.workspace, "AGENTS.md"),
- filepath.Join(cb.workspace, "SOUL.md"),
- filepath.Join(cb.workspace, "USER.md"),
- filepath.Join(cb.workspace, "IDENTITY.md"),
- filepath.Join(cb.workspace, "memory", "MEMORY.md"),
- }
+ agentDefinition := cb.LoadAgentDefinition()
+ paths := agentDefinition.trackedPaths(cb.workspace)
+ paths = append(paths, filepath.Join(cb.workspace, "memory", "MEMORY.md"))
+ return uniquePaths(paths)
}
// skillRoots returns all skill root directories that can affect
@@ -432,18 +429,32 @@ func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Ti
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
- bootstrapFiles := []string{
- "AGENTS.md",
- "SOUL.md",
- "USER.md",
- "IDENTITY.md",
+ var sb strings.Builder
+
+ agentDefinition := cb.LoadAgentDefinition()
+ if agentDefinition.Agent != nil {
+ label := string(agentDefinition.Source)
+ if label == "" {
+ label = relativeWorkspacePath(cb.workspace, agentDefinition.Agent.Path)
+ }
+ fmt.Fprintf(&sb, "## %s\n\n%s\n\n", label, agentDefinition.Agent.Body)
+ }
+ if agentDefinition.Soul != nil {
+ fmt.Fprintf(
+ &sb,
+ "## %s\n\n%s\n\n",
+ relativeWorkspacePath(cb.workspace, agentDefinition.Soul.Path),
+ agentDefinition.Soul.Content,
+ )
+ }
+ if agentDefinition.User != nil {
+ fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "USER.md", agentDefinition.User.Content)
}
- var sb strings.Builder
- for _, filename := range bootstrapFiles {
- filePath := filepath.Join(cb.workspace, filename)
+ if agentDefinition.Source != AgentDefinitionSourceAgent {
+ filePath := filepath.Join(cb.workspace, "IDENTITY.md")
if data, err := os.ReadFile(filePath); err == nil {
- fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data)
+ fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "IDENTITY.md", data)
}
}
diff --git a/pkg/agent/context_cache_test.go b/pkg/agent/context_cache_test.go
index 707510820..1f9423a3a 100644
--- a/pkg/agent/context_cache_test.go
+++ b/pkg/agent/context_cache_test.go
@@ -37,7 +37,7 @@ func setupWorkspace(t *testing.T, files map[string]string) string {
// Codex (only reads last system message as instructions).
func TestSingleSystemMessage(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
- "IDENTITY.md": "# Identity\nTest agent.",
+ "AGENT.md": "# Agent\nTest agent.",
})
defer os.RemoveAll(tmpDir)
@@ -140,10 +140,10 @@ func TestMtimeAutoInvalidation(t *testing.T) {
}{
{
name: "bootstrap file change",
- file: "IDENTITY.md",
- contentV1: "# Original Identity",
- contentV2: "# Updated Identity",
- checkField: "Updated Identity",
+ file: "AGENT.md",
+ contentV1: "# Original Agent",
+ contentV2: "# Updated Agent",
+ checkField: "Updated Agent",
},
{
name: "memory file change",
@@ -218,7 +218,7 @@ func TestMtimeAutoInvalidation(t *testing.T) {
// even when source files haven't changed (useful for tests and reload commands).
func TestExplicitInvalidateCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
- "IDENTITY.md": "# Test Identity",
+ "AGENT.md": "# Test Agent",
})
defer os.RemoveAll(tmpDir)
@@ -245,8 +245,8 @@ func TestExplicitInvalidateCache(t *testing.T) {
// when no files change (regression test for issue #607).
func TestCacheStability(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
- "IDENTITY.md": "# Identity\nContent",
- "SOUL.md": "# Soul\nContent",
+ "AGENT.md": "# Agent\nContent",
+ "SOUL.md": "# Soul\nContent",
})
defer os.RemoveAll(tmpDir)
@@ -545,7 +545,7 @@ description: delete-me-v1
// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache
func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
- "IDENTITY.md": "# Identity\nConcurrency test agent.",
+ "AGENT.md": "# Agent\nConcurrency test agent.",
"SOUL.md": "# Soul\nBe helpful.",
"memory/MEMORY.md": "# Memory\nUser prefers Go.",
"skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo",
@@ -652,7 +652,7 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
- for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} {
+ for _, name := range []string{"AGENT.md", "SOUL.md"} {
os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644)
}
diff --git a/pkg/agent/definition.go b/pkg/agent/definition.go
new file mode 100644
index 000000000..cf73d607c
--- /dev/null
+++ b/pkg/agent/definition.go
@@ -0,0 +1,255 @@
+package agent
+
+import (
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "github.com/gomarkdown/markdown/parser"
+ "gopkg.in/yaml.v3"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// AgentDefinitionSource identifies which agent bootstrap file produced the definition.
+type AgentDefinitionSource string
+
+const (
+ // AgentDefinitionSourceAgent indicates the new AGENT.md format.
+ AgentDefinitionSourceAgent AgentDefinitionSource = "AGENT.md"
+ // AgentDefinitionSourceAgents indicates the legacy AGENTS.md format.
+ AgentDefinitionSourceAgents AgentDefinitionSource = "AGENTS.md"
+)
+
+// AgentFrontmatter holds machine-readable AGENT.md configuration.
+//
+// Known fields are exposed directly for convenience. Fields keeps the full
+// parsed frontmatter so future refactors can read additional keys without
+// changing the loader contract again.
+type AgentFrontmatter struct {
+ Name string `json:"name"`
+ Description string `json:"description"`
+ Tools []string `json:"tools,omitempty"`
+ Model string `json:"model,omitempty"`
+ MaxTurns *int `json:"maxTurns,omitempty"`
+ Skills []string `json:"skills,omitempty"`
+ MCPServers []string `json:"mcpServers,omitempty"`
+ Fields map[string]any `json:"fields,omitempty"`
+}
+
+// AgentPromptDefinition represents the parsed AGENT.md or AGENTS.md prompt file.
+type AgentPromptDefinition struct {
+ Path string `json:"path"`
+ Raw string `json:"raw"`
+ Body string `json:"body"`
+ RawFrontmatter string `json:"raw_frontmatter,omitempty"`
+ Frontmatter AgentFrontmatter `json:"frontmatter"`
+}
+
+// SoulDefinition represents the resolved SOUL.md file linked to the agent.
+type SoulDefinition struct {
+ Path string `json:"path"`
+ Content string `json:"content"`
+}
+
+// UserDefinition represents the resolved USER.md file linked to the workspace.
+type UserDefinition struct {
+ Path string `json:"path"`
+ Content string `json:"content"`
+}
+
+// AgentContextDefinition captures the workspace agent definition in a runtime-friendly shape.
+type AgentContextDefinition struct {
+ Source AgentDefinitionSource `json:"source,omitempty"`
+ Agent *AgentPromptDefinition `json:"agent,omitempty"`
+ Soul *SoulDefinition `json:"soul,omitempty"`
+ User *UserDefinition `json:"user,omitempty"`
+}
+
+// LoadAgentDefinition parses the workspace agent bootstrap files.
+//
+// It prefers the new AGENT.md format and its paired SOUL.md file. When the
+// structured files are absent, it falls back to the legacy AGENTS.md layout so
+// the current runtime can transition incrementally.
+func (cb *ContextBuilder) LoadAgentDefinition() AgentContextDefinition {
+ return loadAgentDefinition(cb.workspace)
+}
+
+func loadAgentDefinition(workspace string) AgentContextDefinition {
+ definition := AgentContextDefinition{}
+ definition.User = loadUserDefinition(workspace)
+ agentPath := filepath.Join(workspace, string(AgentDefinitionSourceAgent))
+ if content, err := os.ReadFile(agentPath); err == nil {
+ prompt := parseAgentPromptDefinition(agentPath, string(content))
+ definition.Source = AgentDefinitionSourceAgent
+ definition.Agent = &prompt
+ soulPath := filepath.Join(workspace, "SOUL.md")
+ if content, err := os.ReadFile(soulPath); err == nil {
+ definition.Soul = &SoulDefinition{
+ Path: soulPath,
+ Content: string(content),
+ }
+ }
+ return definition
+ }
+
+ legacyPath := filepath.Join(workspace, string(AgentDefinitionSourceAgents))
+ if content, err := os.ReadFile(legacyPath); err == nil {
+ definition.Source = AgentDefinitionSourceAgents
+ definition.Agent = &AgentPromptDefinition{
+ Path: legacyPath,
+ Raw: string(content),
+ Body: string(content),
+ }
+ }
+
+ defaultSoulPath := filepath.Join(workspace, "SOUL.md")
+ if definition.Source != "" || fileExists(defaultSoulPath) {
+ if content, err := os.ReadFile(defaultSoulPath); err == nil {
+ definition.Soul = &SoulDefinition{
+ Path: defaultSoulPath,
+ Content: string(content),
+ }
+ }
+ }
+
+ return definition
+}
+
+func (definition AgentContextDefinition) trackedPaths(workspace string) []string {
+ paths := []string{
+ filepath.Join(workspace, string(AgentDefinitionSourceAgent)),
+ filepath.Join(workspace, "SOUL.md"),
+ filepath.Join(workspace, "USER.md"),
+ }
+ if definition.Source != AgentDefinitionSourceAgent {
+ paths = append(paths,
+ filepath.Join(workspace, string(AgentDefinitionSourceAgents)),
+ filepath.Join(workspace, "IDENTITY.md"),
+ )
+ }
+ return uniquePaths(paths)
+}
+
+func loadUserDefinition(workspace string) *UserDefinition {
+ userPath := filepath.Join(workspace, "USER.md")
+ if content, err := os.ReadFile(userPath); err == nil {
+ return &UserDefinition{
+ Path: userPath,
+ Content: string(content),
+ }
+ }
+
+ return nil
+}
+
+func parseAgentPromptDefinition(path, content string) AgentPromptDefinition {
+ frontmatter, body := splitAgentFrontmatter(content)
+ return AgentPromptDefinition{
+ Path: path,
+ Raw: content,
+ Body: body,
+ RawFrontmatter: frontmatter,
+ Frontmatter: parseAgentFrontmatter(path, frontmatter),
+ }
+}
+
+func parseAgentFrontmatter(path, frontmatter string) AgentFrontmatter {
+ frontmatter = strings.TrimSpace(frontmatter)
+ if frontmatter == "" {
+ return AgentFrontmatter{}
+ }
+
+ rawFields := make(map[string]any)
+ if err := yaml.Unmarshal([]byte(frontmatter), &rawFields); err != nil {
+ logger.WarnCF("agent", "Failed to parse AGENT.md frontmatter", map[string]any{
+ "path": path,
+ "error": err.Error(),
+ })
+ return AgentFrontmatter{}
+ }
+
+ var typed struct {
+ Name string `yaml:"name"`
+ Description string `yaml:"description"`
+ Tools []string `yaml:"tools"`
+ Model string `yaml:"model"`
+ MaxTurns *int `yaml:"maxTurns"`
+ Skills []string `yaml:"skills"`
+ MCPServers []string `yaml:"mcpServers"`
+ }
+ if err := yaml.Unmarshal([]byte(frontmatter), &typed); err != nil {
+ logger.WarnCF("agent", "Failed to decode AGENT.md frontmatter fields", map[string]any{
+ "path": path,
+ "error": err.Error(),
+ })
+ return AgentFrontmatter{}
+ }
+
+ return AgentFrontmatter{
+ Name: strings.TrimSpace(typed.Name),
+ Description: strings.TrimSpace(typed.Description),
+ Tools: append([]string(nil), typed.Tools...),
+ Model: strings.TrimSpace(typed.Model),
+ MaxTurns: typed.MaxTurns,
+ Skills: append([]string(nil), typed.Skills...),
+ MCPServers: append([]string(nil), typed.MCPServers...),
+ Fields: rawFields,
+ }
+}
+
+func splitAgentFrontmatter(content string) (frontmatter, body string) {
+ normalized := string(parser.NormalizeNewlines([]byte(content)))
+ lines := strings.Split(normalized, "\n")
+ if len(lines) == 0 || lines[0] != "---" {
+ return "", content
+ }
+
+ end := -1
+ for i := 1; i < len(lines); i++ {
+ if lines[i] == "---" {
+ end = i
+ break
+ }
+ }
+ if end == -1 {
+ return "", content
+ }
+
+ frontmatter = strings.Join(lines[1:end], "\n")
+ body = strings.Join(lines[end+1:], "\n")
+ body = strings.TrimLeft(body, "\n")
+ return frontmatter, body
+}
+
+func relativeWorkspacePath(workspace, path string) string {
+ if strings.TrimSpace(path) == "" {
+ return ""
+ }
+ relativePath, err := filepath.Rel(workspace, path)
+ if err == nil && relativePath != "." && !strings.HasPrefix(relativePath, "..") {
+ return filepath.ToSlash(relativePath)
+ }
+ return filepath.Clean(path)
+}
+
+func uniquePaths(paths []string) []string {
+ result := make([]string, 0, len(paths))
+ for _, path := range paths {
+ if strings.TrimSpace(path) == "" {
+ continue
+ }
+ cleaned := filepath.Clean(path)
+ if slices.Contains(result, cleaned) {
+ continue
+ }
+ result = append(result, cleaned)
+ }
+ return result
+}
+
+func fileExists(path string) bool {
+ _, err := os.Stat(path)
+ return err == nil
+}
diff --git a/pkg/agent/definition_test.go b/pkg/agent/definition_test.go
new file mode 100644
index 000000000..5ee996967
--- /dev/null
+++ b/pkg/agent/definition_test.go
@@ -0,0 +1,302 @@
+package agent
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestLoadAgentDefinitionParsesFrontmatterAndSoul(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": `---
+name: pico
+description: Structured agent
+model: claude-3-7-sonnet
+tools:
+ - shell
+ - search
+maxTurns: 8
+skills:
+ - review
+ - search-docs
+mcpServers:
+ - github
+metadata:
+ mode: strict
+---
+# Agent
+
+Act directly and use tools first.
+`,
+ "SOUL.md": "# Soul\nStay precise.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ definition := cb.LoadAgentDefinition()
+
+ if definition.Source != AgentDefinitionSourceAgent {
+ t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgent, definition.Source)
+ }
+ if definition.Agent == nil {
+ t.Fatal("expected AGENT.md definition to be loaded")
+ }
+ if definition.Agent.Body == "" || !strings.Contains(definition.Agent.Body, "Act directly") {
+ t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
+ }
+ if definition.Agent.Frontmatter.Name != "pico" {
+ t.Fatalf("expected name to be parsed, got %q", definition.Agent.Frontmatter.Name)
+ }
+ if definition.Agent.Frontmatter.Model != "claude-3-7-sonnet" {
+ t.Fatalf("expected model to be parsed, got %q", definition.Agent.Frontmatter.Model)
+ }
+ if len(definition.Agent.Frontmatter.Tools) != 2 {
+ t.Fatalf("expected tools to be parsed, got %v", definition.Agent.Frontmatter.Tools)
+ }
+ if definition.Agent.Frontmatter.MaxTurns == nil || *definition.Agent.Frontmatter.MaxTurns != 8 {
+ t.Fatalf("expected maxTurns to be parsed, got %v", definition.Agent.Frontmatter.MaxTurns)
+ }
+ if len(definition.Agent.Frontmatter.Skills) != 2 {
+ t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
+ }
+ if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
+ t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
+ }
+ if definition.Agent.Frontmatter.Fields["metadata"] == nil {
+ t.Fatal("expected arbitrary frontmatter fields to remain available")
+ }
+
+ if definition.Soul == nil {
+ t.Fatal("expected SOUL.md to be loaded")
+ }
+ if !strings.Contains(definition.Soul.Content, "Stay precise") {
+ t.Fatalf("expected soul content to be loaded, got %q", definition.Soul.Content)
+ }
+ if definition.Soul.Path != filepath.Join(tmpDir, "SOUL.md") {
+ t.Fatalf("expected default SOUL.md path, got %q", definition.Soul.Path)
+ }
+}
+
+func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENTS.md": "# Legacy Agent\nKeep compatibility.",
+ "SOUL.md": "# Soul\nLegacy soul.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ definition := cb.LoadAgentDefinition()
+
+ if definition.Source != AgentDefinitionSourceAgents {
+ t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgents, definition.Source)
+ }
+ if definition.Agent == nil {
+ t.Fatal("expected AGENTS.md to be loaded")
+ }
+ if definition.Agent.RawFrontmatter != "" {
+ t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
+ }
+ if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
+ t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
+ }
+ if definition.Soul == nil || !strings.Contains(definition.Soul.Content, "Legacy soul") {
+ t.Fatal("expected default SOUL.md to be loaded for legacy format")
+ }
+}
+
+func TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": "# Agent\nStructured agent.",
+ "USER.md": "# User\nWorkspace preferences.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ definition := cb.LoadAgentDefinition()
+
+ if definition.User == nil {
+ t.Fatal("expected USER.md to be loaded")
+ }
+ if definition.User.Path != filepath.Join(tmpDir, "USER.md") {
+ t.Fatalf("expected workspace USER.md path, got %q", definition.User.Path)
+ }
+ if !strings.Contains(definition.User.Content, "Workspace preferences") {
+ t.Fatalf("expected workspace USER.md content, got %q", definition.User.Content)
+ }
+}
+
+func TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": `---
+name: pico
+tools:
+ - shell
+ broken
+---
+# Agent
+
+Keep going.
+`,
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ definition := cb.LoadAgentDefinition()
+
+ if definition.Agent == nil {
+ t.Fatal("expected AGENT.md definition to be loaded")
+ }
+ if !strings.Contains(definition.Agent.Body, "Keep going.") {
+ t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
+ }
+ if definition.Agent.Frontmatter.Name != "" ||
+ definition.Agent.Frontmatter.Description != "" ||
+ definition.Agent.Frontmatter.Model != "" ||
+ definition.Agent.Frontmatter.MaxTurns != nil ||
+ len(definition.Agent.Frontmatter.Tools) != 0 ||
+ len(definition.Agent.Frontmatter.Skills) != 0 ||
+ len(definition.Agent.Frontmatter.MCPServers) != 0 ||
+ len(definition.Agent.Frontmatter.Fields) != 0 {
+ t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
+ }
+}
+
+func TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": `---
+name: pico
+model: codex-mini
+---
+# Agent
+
+Follow the body prompt.
+`,
+ "SOUL.md": "# Soul\nSpeak plainly.",
+ "IDENTITY.md": "# Identity\nWorkspace identity.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ bootstrap := cb.LoadBootstrapFiles()
+
+ if !strings.Contains(bootstrap, "Follow the body prompt") {
+ t.Fatalf("expected AGENT.md body in bootstrap, got %q", bootstrap)
+ }
+ if !strings.Contains(bootstrap, "Speak plainly") {
+ t.Fatalf("expected resolved soul content in bootstrap, got %q", bootstrap)
+ }
+ if strings.Contains(bootstrap, "name: pico") {
+ t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
+ }
+ if strings.Contains(bootstrap, "model: codex-mini") {
+ t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
+ }
+ if !strings.Contains(bootstrap, "SOUL.md") {
+ t.Fatalf("expected bootstrap to label SOUL.md, got %q", bootstrap)
+ }
+ if strings.Contains(bootstrap, "Workspace identity") {
+ t.Fatalf("structured bootstrap should ignore IDENTITY.md, got %q", bootstrap)
+ }
+}
+
+func TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": "# Agent\nFollow the new structure.",
+ "SOUL.md": "# Soul\nSpeak plainly.",
+ "USER.md": "# User\nShared profile.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+ bootstrap := cb.LoadBootstrapFiles()
+
+ if !strings.Contains(bootstrap, "Shared profile") {
+ t.Fatalf("expected workspace USER.md in bootstrap, got %q", bootstrap)
+ }
+ if !strings.Contains(bootstrap, "## USER.md") {
+ t.Fatalf("expected USER.md heading in bootstrap, got %q", bootstrap)
+ }
+}
+
+func TestStructuredAgentIgnoresIdentityChanges(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": "# Agent\nFollow the new structure.",
+ "SOUL.md": "# Soul\nVersion one.",
+ "IDENTITY.md": "# Identity\nLegacy identity.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ promptV1 := cb.BuildSystemPromptWithCache()
+ if strings.Contains(promptV1, "Legacy identity") {
+ t.Fatalf("structured prompt should not include IDENTITY.md, got %q", promptV1)
+ }
+
+ identityPath := filepath.Join(tmpDir, "IDENTITY.md")
+ if err := os.WriteFile(identityPath, []byte("# Identity\nVersion two."), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ future := time.Now().Add(2 * time.Second)
+ if err := os.Chtimes(identityPath, future, future); err != nil {
+ t.Fatal(err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if changed {
+ t.Fatal("IDENTITY.md should not invalidate cache for structured agent definitions")
+ }
+
+ promptV2 := cb.BuildSystemPromptWithCache()
+ if promptV1 != promptV2 {
+ t.Fatal("structured prompt should remain stable after IDENTITY.md changes")
+ }
+}
+
+func TestStructuredAgentUserChangesInvalidateCache(t *testing.T) {
+ tmpDir := setupWorkspace(t, map[string]string{
+ "AGENT.md": "# Agent\nFollow the new structure.",
+ "SOUL.md": "# Soul\nVersion one.",
+ "USER.md": "# User\nInitial workspace preferences.",
+ })
+ defer cleanupWorkspace(t, tmpDir)
+
+ cb := NewContextBuilder(tmpDir)
+
+ promptV1 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(promptV1, "Initial workspace preferences") {
+ t.Fatalf("expected workspace USER.md in prompt, got %q", promptV1)
+ }
+
+ userPath := filepath.Join(tmpDir, "USER.md")
+ if err := os.WriteFile(userPath, []byte("# User\nUpdated workspace preferences."), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ future := time.Now().Add(2 * time.Second)
+ if err := os.Chtimes(userPath, future, future); err != nil {
+ t.Fatal(err)
+ }
+
+ cb.systemPromptMutex.RLock()
+ changed := cb.sourceFilesChangedLocked()
+ cb.systemPromptMutex.RUnlock()
+ if !changed {
+ t.Fatal("workspace USER.md changes should invalidate cache")
+ }
+
+ promptV2 := cb.BuildSystemPromptWithCache()
+ if !strings.Contains(promptV2, "Updated workspace preferences") {
+ t.Fatalf("expected updated workspace USER.md in prompt, got %q", promptV2)
+ }
+}
+
+func cleanupWorkspace(t *testing.T, path string) {
+ t.Helper()
+ if err := os.RemoveAll(path); err != nil {
+ t.Fatalf("failed to clean up workspace %s: %v", path, err)
+ }
+}
diff --git a/workspace/AGENT.md b/workspace/AGENT.md
new file mode 100644
index 000000000..08f55a1b7
--- /dev/null
+++ b/workspace/AGENT.md
@@ -0,0 +1,45 @@
+---
+name: pico
+description: >
+ The default general-purpose assistant for everyday conversation, problem
+ solving, and workspace help.
+---
+
+You are Pico, the default assistant for this workspace.
+Your name is PicoClaw 🦞.
+## Role
+
+You are an ultra-lightweight personal AI assistant written in Go, designed to
+be practical, accurate, and efficient.
+
+## Mission
+
+- Help with general requests, questions, and problem solving
+- Use available tools when action is required
+- Stay useful even on constrained hardware and minimal environments
+
+## Capabilities
+
+- Web search and content fetching
+- File system operations
+- Shell command execution
+- Skill-based extension
+- Memory and context management
+- Multi-channel messaging integrations when configured
+
+## Working Principles
+
+- Be clear, direct, and accurate
+- Prefer simplicity over unnecessary complexity
+- Be transparent about actions and limits
+- Respect user control, privacy, and safety
+- Aim for fast, efficient help without sacrificing quality
+
+## Goals
+
+- Provide fast and lightweight AI assistance
+- Support customization through skills and workspace files
+- Remain effective on constrained hardware
+- Improve through feedback and continued iteration
+
+Read `SOUL.md` as part of your identity and communication style.
diff --git a/workspace/AGENTS.md b/workspace/AGENTS.md
deleted file mode 100644
index 5f5fa6480..000000000
--- a/workspace/AGENTS.md
+++ /dev/null
@@ -1,12 +0,0 @@
-# Agent Instructions
-
-You are a helpful AI assistant. Be concise, accurate, and friendly.
-
-## Guidelines
-
-- Always explain what you're doing before taking actions
-- Ask for clarification when request is ambiguous
-- Use tools to help accomplish tasks
-- Remember important information in your memory files
-- Be proactive and helpful
-- Learn from user feedback
\ No newline at end of file
diff --git a/workspace/IDENTITY.md b/workspace/IDENTITY.md
deleted file mode 100644
index 20e3e49fa..000000000
--- a/workspace/IDENTITY.md
+++ /dev/null
@@ -1,53 +0,0 @@
-# Identity
-
-## Name
-PicoClaw 🦞
-
-## Description
-Ultra-lightweight personal AI assistant written in Go, inspired by nanobot.
-
-## Purpose
-- Provide intelligent AI assistance with minimal resource usage
-- Support multiple LLM providers (OpenAI, Anthropic, Zhipu, etc.)
-- Enable easy customization through skills system
-- Run on minimal hardware ($10 boards, <10MB RAM)
-
-## Capabilities
-
-- Web search and content fetching
-- File system operations (read, write, edit)
-- Shell command execution
-- Multi-channel messaging (Telegram, WhatsApp, Feishu)
-- Skill-based extensibility
-- Memory and context management
-
-## Philosophy
-
-- Simplicity over complexity
-- Performance over features
-- User control and privacy
-- Transparent operation
-- Community-driven development
-
-## Goals
-
-- Provide a fast, lightweight AI assistant
-- Support offline-first operation where possible
-- Enable easy customization and extension
-- Maintain high quality responses
-- Run efficiently on constrained hardware
-
-## License
-MIT License - Free and open source
-
-## Repository
-https://github.com/sipeed/picoclaw
-
-## Contact
-Issues: https://github.com/sipeed/picoclaw/issues
-Discussions: https://github.com/sipeed/picoclaw/discussions
-
----
-
-"Every bit helps, every bit matters."
-- Picoclaw
\ No newline at end of file
diff --git a/workspace/SOUL.md b/workspace/SOUL.md
index 0be8834f5..8a6371ff9 100644
--- a/workspace/SOUL.md
+++ b/workspace/SOUL.md
@@ -1,6 +1,6 @@
# Soul
-I am picoclaw, a lightweight AI assistant powered by AI.
+I am PicoClaw: calm, helpful, and practical.
## Personality
@@ -8,10 +8,12 @@ I am picoclaw, a lightweight AI assistant powered by AI.
- Concise and to the point
- Curious and eager to learn
- Honest and transparent
+- Calm under uncertainty
## Values
- Accuracy over speed
- User privacy and safety
- Transparency in actions
-- Continuous improvement
\ No newline at end of file
+- Continuous improvement
+- Simplicity over unnecessary complexity
diff --git a/workspace/USER.md b/workspace/USER.md
index 91398a019..9a3419d87 100644
--- a/workspace/USER.md
+++ b/workspace/USER.md
@@ -1,6 +1,6 @@
# User
-Information about user goes here.
+Information about the user goes here.
## Preferences
@@ -18,4 +18,4 @@ Information about user goes here.
- What the user wants to learn from AI
- Preferred interaction style
-- Areas of interest
\ No newline at end of file
+- Areas of interest
From af61d0bca720340030fdc2afe2d858e57ff9a583 Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Fri, 20 Mar 2026 14:53:22 +0800
Subject: [PATCH 15/26] feat(agent): add event bus foundation
---
pkg/agent/eventbus.go | 121 +++++++++++++++++++
pkg/agent/eventbus_test.go | 235 +++++++++++++++++++++++++++++++++++++
pkg/agent/events.go | 129 ++++++++++++++++++++
pkg/agent/loop.go | 166 +++++++++++++++++++++++++-
4 files changed, 650 insertions(+), 1 deletion(-)
create mode 100644 pkg/agent/eventbus.go
create mode 100644 pkg/agent/eventbus_test.go
create mode 100644 pkg/agent/events.go
diff --git a/pkg/agent/eventbus.go b/pkg/agent/eventbus.go
new file mode 100644
index 000000000..546d8436d
--- /dev/null
+++ b/pkg/agent/eventbus.go
@@ -0,0 +1,121 @@
+package agent
+
+import (
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const defaultEventSubscriberBuffer = 16
+
+// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
+type EventSubscription struct {
+ ID uint64
+ C <-chan Event
+}
+
+type eventSubscriber struct {
+ ch chan Event
+}
+
+// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
+type EventBus struct {
+ mu sync.RWMutex
+ subs map[uint64]eventSubscriber
+ nextID uint64
+ closed bool
+ dropped [eventKindCount]atomic.Int64
+}
+
+// NewEventBus creates a new in-process event broadcaster.
+func NewEventBus() *EventBus {
+ return &EventBus{
+ subs: make(map[uint64]eventSubscriber),
+ }
+}
+
+// Subscribe registers a new subscriber with the requested channel buffer size.
+// A non-positive buffer uses the default size.
+func (b *EventBus) Subscribe(buffer int) EventSubscription {
+ if buffer <= 0 {
+ buffer = defaultEventSubscriberBuffer
+ }
+
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if b.closed {
+ ch := make(chan Event)
+ close(ch)
+ return EventSubscription{C: ch}
+ }
+
+ b.nextID++
+ id := b.nextID
+ ch := make(chan Event, buffer)
+ b.subs[id] = eventSubscriber{ch: ch}
+ return EventSubscription{ID: id, C: ch}
+}
+
+// Unsubscribe removes a subscriber and closes its channel.
+func (b *EventBus) Unsubscribe(id uint64) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ sub, ok := b.subs[id]
+ if !ok {
+ return
+ }
+
+ delete(b.subs, id)
+ close(sub.ch)
+}
+
+// Emit broadcasts an event to all current subscribers without blocking.
+// When a subscriber channel is full, the event is dropped for that subscriber.
+func (b *EventBus) Emit(evt Event) {
+ if evt.Time.IsZero() {
+ evt.Time = time.Now()
+ }
+
+ b.mu.RLock()
+ defer b.mu.RUnlock()
+
+ if b.closed {
+ return
+ }
+
+ for _, sub := range b.subs {
+ select {
+ case sub.ch <- evt:
+ default:
+ if evt.Kind < eventKindCount {
+ b.dropped[evt.Kind].Add(1)
+ }
+ }
+ }
+}
+
+// Dropped returns the number of dropped events for a given kind.
+func (b *EventBus) Dropped(kind EventKind) int64 {
+ if kind >= eventKindCount {
+ return 0
+ }
+ return b.dropped[kind].Load()
+}
+
+// Close closes all subscriber channels and stops future broadcasts.
+func (b *EventBus) Close() {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ if b.closed {
+ return
+ }
+
+ b.closed = true
+ for id, sub := range b.subs {
+ close(sub.ch)
+ delete(b.subs, id)
+ }
+}
diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go
new file mode 100644
index 000000000..d57fac094
--- /dev/null
+++ b/pkg/agent/eventbus_test.go
@@ -0,0 +1,235 @@
+package agent
+
+import (
+ "context"
+ "os"
+ "slices"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
+ eventBus := NewEventBus()
+ sub := eventBus.Subscribe(1)
+
+ eventBus.Emit(Event{
+ Kind: EventKindTurnStart,
+ Meta: EventMeta{TurnID: "turn-1"},
+ })
+
+ select {
+ case evt := <-sub.C:
+ if evt.Kind != EventKindTurnStart {
+ t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
+ }
+ if evt.Meta.TurnID != "turn-1" {
+ t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
+ }
+ case <-time.After(time.Second):
+ t.Fatal("timed out waiting for event")
+ }
+
+ eventBus.Unsubscribe(sub.ID)
+ if _, ok := <-sub.C; ok {
+ t.Fatal("expected subscriber channel to be closed after unsubscribe")
+ }
+
+ eventBus.Close()
+ closedSub := eventBus.Subscribe(1)
+ if _, ok := <-closedSub.C; ok {
+ t.Fatal("expected closed bus to return a closed subscriber channel")
+ }
+}
+
+func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
+ eventBus := NewEventBus()
+ sub := eventBus.Subscribe(1)
+ defer eventBus.Unsubscribe(sub.ID)
+
+ start := time.Now()
+ for i := 0; i < 1000; i++ {
+ eventBus.Emit(Event{Kind: EventKindLLMRequest})
+ }
+
+ if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
+ t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
+ }
+
+ if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
+ t.Fatalf("expected 999 dropped events, got %d", got)
+ }
+}
+
+type scriptedToolProvider struct {
+ calls int
+}
+
+func (m *scriptedToolProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ toolDefs []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ m.calls++
+ if m.calls == 1 {
+ return &providers.LLMResponse{
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call-1",
+ Name: "mock_custom",
+ Arguments: map[string]any{"task": "ping"},
+ },
+ },
+ }, nil
+ }
+
+ return &providers.LLMResponse{
+ Content: "done",
+ }, nil
+}
+
+func (m *scriptedToolProvider) GetDefaultModel() string {
+ return "scripted-tool-model"
+}
+
+func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-eventbus-*")
+ if err != nil {
+ t.Fatalf("failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ provider := &scriptedToolProvider{}
+ al := NewAgentLoop(cfg, msgBus, provider)
+ al.RegisterTool(&mockCustomTool{})
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "run tool",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if response != "done" {
+ t.Fatalf("expected final response 'done', got %q", response)
+ }
+
+ events := collectEventStream(sub.C)
+ if len(events) != 8 {
+ t.Fatalf("expected 8 events, got %d", len(events))
+ }
+
+ kinds := make([]EventKind, 0, len(events))
+ for _, evt := range events {
+ kinds = append(kinds, evt.Kind)
+ }
+
+ expectedKinds := []EventKind{
+ EventKindTurnStart,
+ EventKindLLMRequest,
+ EventKindLLMResponse,
+ EventKindToolExecStart,
+ EventKindToolExecEnd,
+ EventKindLLMRequest,
+ EventKindLLMResponse,
+ EventKindTurnEnd,
+ }
+ if !slices.Equal(kinds, expectedKinds) {
+ t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
+ }
+
+ turnID := events[0].Meta.TurnID
+ for i, evt := range events {
+ if evt.Meta.TurnID != turnID {
+ t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
+ }
+ if evt.Meta.SessionKey != "session-1" {
+ t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
+ }
+ }
+
+ startPayload, ok := events[0].Payload.(TurnStartPayload)
+ if !ok {
+ t.Fatalf("expected TurnStartPayload, got %T", events[0].Payload)
+ }
+ if startPayload.UserMessage != "run tool" {
+ t.Fatalf("expected user message 'run tool', got %q", startPayload.UserMessage)
+ }
+
+ toolStartPayload, ok := events[3].Payload.(ToolExecStartPayload)
+ if !ok {
+ t.Fatalf("expected ToolExecStartPayload, got %T", events[3].Payload)
+ }
+ if toolStartPayload.Tool != "mock_custom" {
+ t.Fatalf("expected tool name mock_custom, got %q", toolStartPayload.Tool)
+ }
+
+ toolEndPayload, ok := events[4].Payload.(ToolExecEndPayload)
+ if !ok {
+ t.Fatalf("expected ToolExecEndPayload, got %T", events[4].Payload)
+ }
+ if toolEndPayload.Tool != "mock_custom" {
+ t.Fatalf("expected tool end payload for mock_custom, got %q", toolEndPayload.Tool)
+ }
+ if toolEndPayload.IsError {
+ t.Fatal("expected mock_custom tool to succeed")
+ }
+
+ turnEndPayload, ok := events[len(events)-1].Payload.(TurnEndPayload)
+ if !ok {
+ t.Fatalf("expected TurnEndPayload, got %T", events[len(events)-1].Payload)
+ }
+ if turnEndPayload.Status != TurnEndStatusCompleted {
+ t.Fatalf("expected completed turn, got %q", turnEndPayload.Status)
+ }
+ if turnEndPayload.Iterations != 2 {
+ t.Fatalf("expected 2 iterations, got %d", turnEndPayload.Iterations)
+ }
+}
+
+func collectEventStream(ch <-chan Event) []Event {
+ var events []Event
+ for {
+ select {
+ case evt, ok := <-ch:
+ if !ok {
+ return events
+ }
+ events = append(events, evt)
+ default:
+ return events
+ }
+ }
+}
+
+var _ tools.Tool = (*mockCustomTool)(nil)
diff --git a/pkg/agent/events.go b/pkg/agent/events.go
new file mode 100644
index 000000000..92aec7436
--- /dev/null
+++ b/pkg/agent/events.go
@@ -0,0 +1,129 @@
+package agent
+
+import (
+ "fmt"
+ "time"
+)
+
+// EventKind identifies a structured agent-loop event.
+type EventKind uint8
+
+const (
+ // EventKindTurnStart is emitted when a turn begins processing.
+ EventKindTurnStart EventKind = iota
+ // EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
+ EventKindTurnEnd
+ // EventKindLLMRequest is emitted before a provider chat request is made.
+ EventKindLLMRequest
+ // EventKindLLMResponse is emitted after a provider chat response is received.
+ EventKindLLMResponse
+ // EventKindToolExecStart is emitted immediately before a tool executes.
+ EventKindToolExecStart
+ // EventKindToolExecEnd is emitted immediately after a tool finishes executing.
+ EventKindToolExecEnd
+ // EventKindError is emitted when a turn encounters an execution error.
+ EventKindError
+
+ eventKindCount
+)
+
+var eventKindNames = [...]string{
+ "turn_start",
+ "turn_end",
+ "llm_request",
+ "llm_response",
+ "tool_exec_start",
+ "tool_exec_end",
+ "error",
+}
+
+// String returns the stable string form of an EventKind.
+func (k EventKind) String() string {
+ if k >= eventKindCount {
+ return fmt.Sprintf("event_kind(%d)", k)
+ }
+ return eventKindNames[k]
+}
+
+// Event is the structured envelope broadcast by the agent EventBus.
+type Event struct {
+ Kind EventKind
+ Time time.Time
+ Meta EventMeta
+ Payload any
+}
+
+// EventMeta contains correlation fields shared by all agent-loop events.
+type EventMeta struct {
+ AgentID string
+ TurnID string
+ ParentTurnID string
+ SessionKey string
+ Iteration int
+ TracePath string
+ Source string
+}
+
+// TurnEndStatus describes the terminal state of a turn.
+type TurnEndStatus string
+
+const (
+ // TurnEndStatusCompleted indicates the turn finished normally.
+ TurnEndStatusCompleted TurnEndStatus = "completed"
+ // TurnEndStatusError indicates the turn ended because of an error.
+ TurnEndStatusError TurnEndStatus = "error"
+)
+
+// TurnStartPayload describes the start of a turn.
+type TurnStartPayload struct {
+ Channel string
+ ChatID string
+ UserMessage string
+ MediaCount int
+}
+
+// TurnEndPayload describes the completion of a turn.
+type TurnEndPayload struct {
+ Status TurnEndStatus
+ Iterations int
+ Duration time.Duration
+ FinalContentLen int
+}
+
+// LLMRequestPayload describes an outbound LLM request.
+type LLMRequestPayload struct {
+ Model string
+ MessagesCount int
+ ToolsCount int
+ MaxTokens int
+ Temperature float64
+}
+
+// LLMResponsePayload describes an inbound LLM response.
+type LLMResponsePayload struct {
+ ContentLen int
+ ToolCalls int
+ HasReasoning bool
+}
+
+// ToolExecStartPayload describes a tool execution request.
+type ToolExecStartPayload struct {
+ Tool string
+ Arguments map[string]any
+}
+
+// ToolExecEndPayload describes the outcome of a tool execution.
+type ToolExecEndPayload struct {
+ Tool string
+ Duration time.Duration
+ ForLLMLen int
+ ForUserLen int
+ IsError bool
+ Async bool
+}
+
+// ErrorPayload describes an execution error inside the agent loop.
+type ErrorPayload struct {
+ Stage string
+ Message string
+}
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index c583f5ca5..2c9c86cf9 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -39,6 +39,7 @@ type AgentLoop struct {
cfg *config.Config
registry *AgentRegistry
state *state.Manager
+ eventBus *EventBus
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
@@ -49,6 +50,7 @@ type AgentLoop struct {
mcp mcpRuntime
steering *steeringQueue
mu sync.RWMutex
+ turnSeq atomic.Uint64
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
}
@@ -103,6 +105,7 @@ func NewAgentLoop(
cfg: cfg,
registry: registry,
state: stateManager,
+ eventBus: NewEventBus(),
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
@@ -380,6 +383,84 @@ func (al *AgentLoop) Close() {
}
al.GetRegistry().Close()
+ if al.eventBus != nil {
+ al.eventBus.Close()
+ }
+}
+
+// SubscribeEvents registers a subscriber for agent-loop events.
+func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
+ if al == nil || al.eventBus == nil {
+ ch := make(chan Event)
+ close(ch)
+ return EventSubscription{C: ch}
+ }
+ return al.eventBus.Subscribe(buffer)
+}
+
+// UnsubscribeEvents removes a previously registered event subscriber.
+func (al *AgentLoop) UnsubscribeEvents(id uint64) {
+ if al == nil || al.eventBus == nil {
+ return
+ }
+ al.eventBus.Unsubscribe(id)
+}
+
+// EventDrops returns the number of dropped events for the given kind.
+func (al *AgentLoop) EventDrops(kind EventKind) int64 {
+ if al == nil || al.eventBus == nil {
+ return 0
+ }
+ return al.eventBus.Dropped(kind)
+}
+
+type turnEventScope struct {
+ agentID string
+ sessionKey string
+ turnID string
+}
+
+func (al *AgentLoop) newTurnEventScope(agentID, sessionKey string) turnEventScope {
+ seq := al.turnSeq.Add(1)
+ return turnEventScope{
+ agentID: agentID,
+ sessionKey: sessionKey,
+ turnID: fmt.Sprintf("%s-turn-%d", agentID, seq),
+ }
+}
+
+func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta {
+ return EventMeta{
+ AgentID: ts.agentID,
+ TurnID: ts.turnID,
+ SessionKey: ts.sessionKey,
+ Iteration: iteration,
+ Source: source,
+ TracePath: tracePath,
+ }
+}
+
+func (al *AgentLoop) emitEvent(kind EventKind, meta EventMeta, payload any) {
+ if al == nil || al.eventBus == nil {
+ return
+ }
+ al.eventBus.Emit(Event{
+ Kind: kind,
+ Meta: meta,
+ Payload: payload,
+ })
+}
+
+func cloneEventArguments(args map[string]any) map[string]any {
+ if len(args) == 0 {
+ return nil
+ }
+
+ cloned := make(map[string]any, len(args))
+ for k, v := range args {
+ cloned[k] = v
+ }
+ return cloned
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
@@ -895,6 +976,35 @@ func (al *AgentLoop) runAgentLoop(
agent *AgentInstance,
opts processOptions,
) (string, error) {
+ turnScope := al.newTurnEventScope(agent.ID, opts.SessionKey)
+ turnStartedAt := time.Now()
+ turnIterations := 0
+ turnFinalContentLen := 0
+ turnStatus := TurnEndStatusCompleted
+ defer func() {
+ al.emitEvent(
+ EventKindTurnEnd,
+ turnScope.meta(turnIterations, "runAgentLoop", "turn.end"),
+ TurnEndPayload{
+ Status: turnStatus,
+ Iterations: turnIterations,
+ Duration: time.Since(turnStartedAt),
+ FinalContentLen: turnFinalContentLen,
+ },
+ )
+ }()
+
+ al.emitEvent(
+ EventKindTurnStart,
+ turnScope.meta(0, "runAgentLoop", "turn.start"),
+ TurnStartPayload{
+ Channel: opts.Channel,
+ ChatID: opts.ChatID,
+ UserMessage: opts.UserMessage,
+ MediaCount: len(opts.Media),
+ },
+ )
+
// 0. Record last channel for heartbeat notifications (skip internal channels and cli)
if opts.Channel != "" && opts.ChatID != "" {
if !constants.IsInternalChannel(opts.Channel) {
@@ -952,8 +1062,10 @@ func (al *AgentLoop) runAgentLoop(
agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// 3. Run LLM iteration loop
- finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts)
+ finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts, turnScope)
+ turnIterations = iteration
if err != nil {
+ turnStatus = TurnEndStatusError
return "", err
}
@@ -964,6 +1076,7 @@ func (al *AgentLoop) runAgentLoop(
if finalContent == "" {
finalContent = opts.DefaultResponse
}
+ turnFinalContentLen = len(finalContent)
// 5. Save final assistant message to session
agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
@@ -1058,6 +1171,7 @@ func (al *AgentLoop) runLLMIteration(
agent *AgentInstance,
messages []providers.Message,
opts processOptions,
+ turnScope turnEventScope,
) (string, int, error) {
iteration := 0
var finalContent string
@@ -1106,6 +1220,17 @@ func (al *AgentLoop) runLLMIteration(
// Build tool definitions
providerToolDefs := agent.Tools.ToProviderDefs()
+ al.emitEvent(
+ EventKindLLMRequest,
+ turnScope.meta(iteration, "runLLMIteration", "turn.llm.request"),
+ LLMRequestPayload{
+ Model: activeModel,
+ MessagesCount: len(messages),
+ ToolsCount: len(providerToolDefs),
+ MaxTokens: agent.MaxTokens,
+ Temperature: agent.Temperature,
+ },
+ )
// Log LLM request details
logger.DebugCF("agent", "LLM request",
@@ -1246,6 +1371,14 @@ func (al *AgentLoop) runLLMIteration(
}
if err != nil {
+ al.emitEvent(
+ EventKindError,
+ turnScope.meta(iteration, "runLLMIteration", "turn.error"),
+ ErrorPayload{
+ Stage: "llm",
+ Message: err.Error(),
+ },
+ )
logger.ErrorCF("agent", "LLM call failed",
map[string]any{
"agent_id": agent.ID,
@@ -1262,6 +1395,15 @@ func (al *AgentLoop) runLLMIteration(
opts.Channel,
al.targetReasoningChannelID(opts.Channel),
)
+ al.emitEvent(
+ EventKindLLMResponse,
+ turnScope.meta(iteration, "runLLMIteration", "turn.llm.response"),
+ LLMResponsePayload{
+ ContentLen: len(response.Content),
+ ToolCalls: len(response.ToolCalls),
+ HasReasoning: response.Reasoning != "" || response.ReasoningContent != "",
+ },
+ )
logger.DebugCF("agent", "LLM response",
map[string]any{
@@ -1352,6 +1494,14 @@ func (al *AgentLoop) runLLMIteration(
"tool": tc.Name,
"iteration": iteration,
})
+ al.emitEvent(
+ EventKindToolExecStart,
+ turnScope.meta(iteration, "runLLMIteration", "turn.tool.start"),
+ ToolExecStartPayload{
+ Tool: tc.Name,
+ Arguments: cloneEventArguments(tc.Arguments),
+ },
+ )
// Create async callback for tools that implement AsyncExecutor.
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
@@ -1390,6 +1540,7 @@ func (al *AgentLoop) runLLMIteration(
})
}
+ toolStart := time.Now()
toolResult := agent.Tools.ExecuteWithContext(
ctx,
tc.Name,
@@ -1398,6 +1549,7 @@ func (al *AgentLoop) runLLMIteration(
opts.ChatID,
asyncCallback,
)
+ toolDuration := time.Since(toolStart)
// Process tool result
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
@@ -1443,6 +1595,18 @@ func (al *AgentLoop) runLLMIteration(
Content: contentForLLM,
ToolCallID: tc.ID,
}
+ al.emitEvent(
+ EventKindToolExecEnd,
+ turnScope.meta(iteration, "runLLMIteration", "turn.tool.end"),
+ ToolExecEndPayload{
+ Tool: tc.Name,
+ Duration: toolDuration,
+ ForLLMLen: len(contentForLLM),
+ ForUserLen: len(toolResult.ForUser),
+ IsError: toolResult.IsError,
+ Async: toolResult.Async,
+ },
+ )
messages = append(messages, toolResultMsg)
agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
From 50cc7100cee14247690bfb2690bf6fbea5be4e37 Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Fri, 20 Mar 2026 15:06:43 +0800
Subject: [PATCH 16/26] feat(agent): make event logs show event kind clearly
---
pkg/agent/loop.go | 68 +++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 63 insertions(+), 5 deletions(-)
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 2c9c86cf9..ac97104b1 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -441,14 +441,18 @@ func (ts turnEventScope) meta(iteration int, source, tracePath string) EventMeta
}
func (al *AgentLoop) emitEvent(kind EventKind, meta EventMeta, payload any) {
- if al == nil || al.eventBus == nil {
- return
- }
- al.eventBus.Emit(Event{
+ evt := Event{
Kind: kind,
Meta: meta,
Payload: payload,
- })
+ }
+
+ al.logEvent(evt)
+
+ if al == nil || al.eventBus == nil {
+ return
+ }
+ al.eventBus.Emit(evt)
}
func cloneEventArguments(args map[string]any) map[string]any {
@@ -463,6 +467,60 @@ func cloneEventArguments(args map[string]any) map[string]any {
return cloned
}
+func (al *AgentLoop) logEvent(evt Event) {
+ fields := map[string]any{
+ "event_kind": evt.Kind.String(),
+ "agent_id": evt.Meta.AgentID,
+ "turn_id": evt.Meta.TurnID,
+ "session_key": evt.Meta.SessionKey,
+ "iteration": evt.Meta.Iteration,
+ }
+
+ if evt.Meta.TracePath != "" {
+ fields["trace"] = evt.Meta.TracePath
+ }
+ if evt.Meta.Source != "" {
+ fields["source"] = evt.Meta.Source
+ }
+
+ switch payload := evt.Payload.(type) {
+ case TurnStartPayload:
+ fields["channel"] = payload.Channel
+ fields["chat_id"] = payload.ChatID
+ fields["user_len"] = len(payload.UserMessage)
+ fields["media_count"] = payload.MediaCount
+ case TurnEndPayload:
+ fields["status"] = payload.Status
+ fields["iterations_total"] = payload.Iterations
+ fields["duration_ms"] = payload.Duration.Milliseconds()
+ fields["final_len"] = payload.FinalContentLen
+ case LLMRequestPayload:
+ fields["model"] = payload.Model
+ fields["messages"] = payload.MessagesCount
+ fields["tools"] = payload.ToolsCount
+ fields["max_tokens"] = payload.MaxTokens
+ case LLMResponsePayload:
+ fields["content_len"] = payload.ContentLen
+ fields["tool_calls"] = payload.ToolCalls
+ fields["has_reasoning"] = payload.HasReasoning
+ case ToolExecStartPayload:
+ fields["tool"] = payload.Tool
+ fields["args_count"] = len(payload.Arguments)
+ case ToolExecEndPayload:
+ fields["tool"] = payload.Tool
+ fields["duration_ms"] = payload.Duration.Milliseconds()
+ fields["for_llm_len"] = payload.ForLLMLen
+ fields["for_user_len"] = payload.ForUserLen
+ fields["is_error"] = payload.IsError
+ fields["async"] = payload.Async
+ case ErrorPayload:
+ fields["stage"] = payload.Stage
+ fields["error"] = payload.Message
+ }
+
+ logger.InfoCF("eventbus", fmt.Sprintf("Agent event: %s", evt.Kind.String()), fields)
+}
+
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
registry := al.GetRegistry()
for _, agentID := range registry.ListAgentIDs() {
From 57cde73b36cc27da4f7979b5526eabaad0f0bfed Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Fri, 20 Mar 2026 15:29:52 +0800
Subject: [PATCH 17/26] feat(agent): expand event bus coverage
---
pkg/agent/eventbus_test.go | 444 +++++++++++++++++++++++++++++++++++++
pkg/agent/events.go | 119 ++++++++++
pkg/agent/loop.go | 150 ++++++++++++-
pkg/agent/steering.go | 19 ++
pkg/tools/spawn.go | 3 +
pkg/tools/subagent.go | 3 +
6 files changed, 730 insertions(+), 8 deletions(-)
diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go
index d57fac094..dadbc2f94 100644
--- a/pkg/agent/eventbus_test.go
+++ b/pkg/agent/eventbus_test.go
@@ -217,6 +217,374 @@ func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
}
}
+func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-eventbus-steering-*")
+ if err != nil {
+ t.Fatalf("failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ tool1ExecCh := make(chan struct{})
+ tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
+ tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
+
+ provider := &toolCallProvider{
+ toolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "tool_one",
+ Function: &providers.FunctionCall{
+ Name: "tool_one",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ {
+ ID: "call_2",
+ Type: "function",
+ Name: "tool_two",
+ Function: &providers.FunctionCall{
+ Name: "tool_two",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ },
+ finalResp: "steered response",
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ al.RegisterTool(tool1)
+ al.RegisterTool(tool2)
+
+ sub := al.SubscribeEvents(32)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ resultCh := make(chan string, 1)
+ go func() {
+ resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
+ resultCh <- resp
+ }()
+
+ select {
+ case <-tool1ExecCh:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for tool_one to start")
+ }
+
+ if err := al.Steer(providers.Message{Role: "user", Content: "change course"}); err != nil {
+ t.Fatalf("Steer failed: %v", err)
+ }
+
+ select {
+ case resp := <-resultCh:
+ if resp != "steered response" {
+ t.Fatalf("expected steered response, got %q", resp)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for steered response")
+ }
+
+ events := collectEventStream(sub.C)
+ steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
+ if !ok {
+ t.Fatal("expected steering injected event")
+ }
+ steeringPayload, ok := steeringEvt.Payload.(SteeringInjectedPayload)
+ if !ok {
+ t.Fatalf("expected SteeringInjectedPayload, got %T", steeringEvt.Payload)
+ }
+ if steeringPayload.Count != 1 {
+ t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
+ }
+
+ skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
+ if !ok {
+ t.Fatal("expected skipped tool event")
+ }
+ skippedPayload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
+ if !ok {
+ t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
+ }
+ if skippedPayload.Tool != "tool_two" {
+ t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
+ }
+
+ interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
+ if !ok {
+ t.Fatal("expected interrupt received event")
+ }
+ interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
+ if !ok {
+ t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
+ }
+ if interruptPayload.Role != "user" {
+ t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
+ }
+ if interruptPayload.ContentLen != len("change course") {
+ t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
+ }
+}
+
+func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-eventbus-compress-*")
+ if err != nil {
+ t.Fatalf("failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ contextErr := errString("InvalidParameter: Total tokens of image and text exceed max message tokens")
+ provider := &failFirstMockProvider{
+ failures: 1,
+ failError: contextErr,
+ successResp: "Recovered from context error",
+ }
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
+ {Role: "user", Content: "Old message 1"},
+ {Role: "assistant", Content: "Old response 1"},
+ {Role: "user", Content: "Old message 2"},
+ {Role: "assistant", Content: "Old response 2"},
+ {Role: "user", Content: "Trigger message"},
+ })
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "Trigger message",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if resp != "Recovered from context error" {
+ t.Fatalf("expected retry success, got %q", resp)
+ }
+
+ events := collectEventStream(sub.C)
+ retryEvt, ok := findEvent(events, EventKindLLMRetry)
+ if !ok {
+ t.Fatal("expected llm retry event")
+ }
+ retryPayload, ok := retryEvt.Payload.(LLMRetryPayload)
+ if !ok {
+ t.Fatalf("expected LLMRetryPayload, got %T", retryEvt.Payload)
+ }
+ if retryPayload.Reason != "context_limit" {
+ t.Fatalf("expected context_limit retry reason, got %q", retryPayload.Reason)
+ }
+ if retryPayload.Attempt != 1 {
+ t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
+ }
+
+ compressEvt, ok := findEvent(events, EventKindContextCompress)
+ if !ok {
+ t.Fatal("expected context compress event")
+ }
+ payload, ok := compressEvt.Payload.(ContextCompressPayload)
+ if !ok {
+ t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
+ }
+ if payload.Reason != ContextCompressReasonRetry {
+ t.Fatalf("expected retry compress reason, got %q", payload.Reason)
+ }
+ if payload.DroppedMessages == 0 {
+ t.Fatal("expected dropped messages to be recorded")
+ }
+}
+
+func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-eventbus-summary-*")
+ if err != nil {
+ t.Fatalf("failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ ContextWindow: 8000,
+ SummarizeMessageThreshold: 2,
+ SummarizeTokenPercent: 75,
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary text"})
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
+ {Role: "user", Content: "Question one"},
+ {Role: "assistant", Content: "Answer one"},
+ {Role: "user", Content: "Question two"},
+ {Role: "assistant", Content: "Answer two"},
+ {Role: "user", Content: "Question three"},
+ {Role: "assistant", Content: "Answer three"},
+ })
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
+ al.summarizeSession(defaultAgent, "session-1", turnScope)
+
+ events := collectEventStream(sub.C)
+ summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
+ if !ok {
+ t.Fatal("expected session summarize event")
+ }
+ payload, ok := summaryEvt.Payload.(SessionSummarizePayload)
+ if !ok {
+ t.Fatalf("expected SessionSummarizePayload, got %T", summaryEvt.Payload)
+ }
+ if payload.SummaryLen == 0 {
+ t.Fatal("expected non-empty summary length")
+ }
+}
+
+func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-eventbus-followup-*")
+ if err != nil {
+ t.Fatalf("failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ provider := &toolCallProvider{
+ toolCalls: []providers.ToolCall{
+ {
+ ID: "call_async_1",
+ Type: "function",
+ Name: "async_followup",
+ Function: &providers.FunctionCall{
+ Name: "async_followup",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ },
+ finalResp: "async launched",
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ doneCh := make(chan struct{})
+ al.RegisterTool(&asyncFollowUpTool{
+ name: "async_followup",
+ followUpText: "background result",
+ completionSig: doneCh,
+ })
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ sub := al.SubscribeEvents(32)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "run async tool",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if resp != "async launched" {
+ t.Fatalf("expected final response 'async launched', got %q", resp)
+ }
+
+ select {
+ case <-doneCh:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for async tool completion")
+ }
+
+ followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
+ return evt.Kind == EventKindFollowUpQueued
+ })
+ payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
+ if !ok {
+ t.Fatalf("expected FollowUpQueuedPayload, got %T", followUpEvt.Payload)
+ }
+ if payload.SourceTool != "async_followup" {
+ t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
+ }
+ if payload.Channel != "cli" {
+ t.Fatalf("expected channel cli, got %q", payload.Channel)
+ }
+ if payload.ChatID != "direct" {
+ t.Fatalf("expected chat id direct, got %q", payload.ChatID)
+ }
+ if payload.ContentLen != len("background result") {
+ t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
+ }
+ if followUpEvt.Meta.SessionKey != "session-1" {
+ t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
+ }
+ if followUpEvt.Meta.TurnID == "" {
+ t.Fatal("expected follow-up event to include turn id")
+ }
+}
+
func collectEventStream(ch <-chan Event) []Event {
var events []Event
for {
@@ -232,4 +600,80 @@ func collectEventStream(ch <-chan Event) []Event {
}
}
+func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
+ t.Helper()
+
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
+
+ for {
+ select {
+ case evt, ok := <-ch:
+ if !ok {
+ t.Fatal("event stream closed before expected event arrived")
+ }
+ if match(evt) {
+ return evt
+ }
+ case <-timer.C:
+ t.Fatal("timed out waiting for expected event")
+ }
+ }
+}
+
+func findEvent(events []Event, kind EventKind) (Event, bool) {
+ for _, evt := range events {
+ if evt.Kind == kind {
+ return evt, true
+ }
+ }
+ return Event{}, false
+}
+
+type errString string
+
+func (e errString) Error() string {
+ return string(e)
+}
+
+type asyncFollowUpTool struct {
+ name string
+ followUpText string
+ completionSig chan struct{}
+}
+
+func (t *asyncFollowUpTool) Name() string {
+ return t.name
+}
+
+func (t *asyncFollowUpTool) Description() string {
+ return "async follow-up tool for testing"
+}
+
+func (t *asyncFollowUpTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ }
+}
+
+func (t *asyncFollowUpTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
+ return tools.AsyncResult("async follow-up scheduled")
+}
+
+func (t *asyncFollowUpTool) ExecuteAsync(
+ ctx context.Context,
+ args map[string]any,
+ cb tools.AsyncCallback,
+) *tools.ToolResult {
+ go func() {
+ cb(ctx, &tools.ToolResult{ForLLM: t.followUpText})
+ if t.completionSig != nil {
+ close(t.completionSig)
+ }
+ }()
+ return tools.AsyncResult("async follow-up scheduled")
+}
+
var _ tools.Tool = (*mockCustomTool)(nil)
+var _ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
diff --git a/pkg/agent/events.go b/pkg/agent/events.go
index 92aec7436..fae5033a3 100644
--- a/pkg/agent/events.go
+++ b/pkg/agent/events.go
@@ -15,12 +15,34 @@ const (
EventKindTurnEnd
// EventKindLLMRequest is emitted before a provider chat request is made.
EventKindLLMRequest
+ // EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
+ EventKindLLMDelta
// EventKindLLMResponse is emitted after a provider chat response is received.
EventKindLLMResponse
+ // EventKindLLMRetry is emitted when an LLM request is retried.
+ EventKindLLMRetry
+ // EventKindContextCompress is emitted when session history is forcibly compressed.
+ EventKindContextCompress
+ // EventKindSessionSummarize is emitted when asynchronous summarization completes.
+ EventKindSessionSummarize
// EventKindToolExecStart is emitted immediately before a tool executes.
EventKindToolExecStart
// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
EventKindToolExecEnd
+ // EventKindToolExecSkipped is emitted when a queued tool call is skipped.
+ EventKindToolExecSkipped
+ // EventKindSteeringInjected is emitted when queued steering is injected into context.
+ EventKindSteeringInjected
+ // EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
+ EventKindFollowUpQueued
+ // EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
+ EventKindInterruptReceived
+ // EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
+ EventKindSubTurnSpawn
+ // EventKindSubTurnEnd is emitted when a sub-turn finishes.
+ EventKindSubTurnEnd
+ // EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
+ EventKindSubTurnResultDelivered
// EventKindError is emitted when a turn encounters an execution error.
EventKindError
@@ -31,9 +53,20 @@ var eventKindNames = [...]string{
"turn_start",
"turn_end",
"llm_request",
+ "llm_delta",
"llm_response",
+ "llm_retry",
+ "context_compress",
+ "session_summarize",
"tool_exec_start",
"tool_exec_end",
+ "tool_exec_skipped",
+ "steering_injected",
+ "follow_up_queued",
+ "interrupt_received",
+ "subturn_spawn",
+ "subturn_end",
+ "subturn_result_delivered",
"error",
}
@@ -106,6 +139,46 @@ type LLMResponsePayload struct {
HasReasoning bool
}
+// LLMDeltaPayload describes a streamed LLM delta.
+type LLMDeltaPayload struct {
+ ContentDeltaLen int
+ ReasoningDeltaLen int
+}
+
+// LLMRetryPayload describes a retry of an LLM request.
+type LLMRetryPayload struct {
+ Attempt int
+ MaxRetries int
+ Reason string
+ Error string
+ Backoff time.Duration
+}
+
+// ContextCompressReason identifies why emergency compression ran.
+type ContextCompressReason string
+
+const (
+ // ContextCompressReasonProactive indicates compression before the first LLM call.
+ ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
+ // ContextCompressReasonRetry indicates compression during context-error retry handling.
+ ContextCompressReasonRetry ContextCompressReason = "llm_retry"
+)
+
+// ContextCompressPayload describes a forced history compression.
+type ContextCompressPayload struct {
+ Reason ContextCompressReason
+ DroppedMessages int
+ RemainingMessages int
+}
+
+// SessionSummarizePayload describes a completed async session summarization.
+type SessionSummarizePayload struct {
+ SummarizedMessages int
+ KeptMessages int
+ SummaryLen int
+ OmittedOversized bool
+}
+
// ToolExecStartPayload describes a tool execution request.
type ToolExecStartPayload struct {
Tool string
@@ -122,6 +195,52 @@ type ToolExecEndPayload struct {
Async bool
}
+// ToolExecSkippedPayload describes a skipped tool call.
+type ToolExecSkippedPayload struct {
+ Tool string
+ Reason string
+}
+
+// SteeringInjectedPayload describes steering messages appended before the next LLM call.
+type SteeringInjectedPayload struct {
+ Count int
+ TotalContentLen int
+}
+
+// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
+type FollowUpQueuedPayload struct {
+ SourceTool string
+ Channel string
+ ChatID string
+ ContentLen int
+}
+
+// InterruptReceivedPayload describes a queued soft interrupt.
+type InterruptReceivedPayload struct {
+ Role string
+ ContentLen int
+ QueueDepth int
+}
+
+// SubTurnSpawnPayload describes the creation of a child turn.
+type SubTurnSpawnPayload struct {
+ AgentID string
+ Label string
+}
+
+// SubTurnEndPayload describes the completion of a child turn.
+type SubTurnEndPayload struct {
+ AgentID string
+ Status string
+}
+
+// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
+type SubTurnResultDeliveredPayload struct {
+ TargetChannel string
+ TargetChatID string
+ ContentLen int
+}
+
// ErrorPayload describes an execution error inside the agent loop.
type ErrorPayload struct {
Stage string
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index ac97104b1..877dbbd94 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -499,10 +499,28 @@ func (al *AgentLoop) logEvent(evt Event) {
fields["messages"] = payload.MessagesCount
fields["tools"] = payload.ToolsCount
fields["max_tokens"] = payload.MaxTokens
+ case LLMDeltaPayload:
+ fields["content_delta_len"] = payload.ContentDeltaLen
+ fields["reasoning_delta_len"] = payload.ReasoningDeltaLen
case LLMResponsePayload:
fields["content_len"] = payload.ContentLen
fields["tool_calls"] = payload.ToolCalls
fields["has_reasoning"] = payload.HasReasoning
+ case LLMRetryPayload:
+ fields["attempt"] = payload.Attempt
+ fields["max_retries"] = payload.MaxRetries
+ fields["reason"] = payload.Reason
+ fields["error"] = payload.Error
+ fields["backoff_ms"] = payload.Backoff.Milliseconds()
+ case ContextCompressPayload:
+ fields["reason"] = payload.Reason
+ fields["dropped_messages"] = payload.DroppedMessages
+ fields["remaining_messages"] = payload.RemainingMessages
+ case SessionSummarizePayload:
+ fields["summarized_messages"] = payload.SummarizedMessages
+ fields["kept_messages"] = payload.KeptMessages
+ fields["summary_len"] = payload.SummaryLen
+ fields["omitted_oversized"] = payload.OmittedOversized
case ToolExecStartPayload:
fields["tool"] = payload.Tool
fields["args_count"] = len(payload.Arguments)
@@ -513,6 +531,31 @@ func (al *AgentLoop) logEvent(evt Event) {
fields["for_user_len"] = payload.ForUserLen
fields["is_error"] = payload.IsError
fields["async"] = payload.Async
+ case ToolExecSkippedPayload:
+ fields["tool"] = payload.Tool
+ fields["reason"] = payload.Reason
+ case SteeringInjectedPayload:
+ fields["count"] = payload.Count
+ fields["total_content_len"] = payload.TotalContentLen
+ case FollowUpQueuedPayload:
+ fields["source_tool"] = payload.SourceTool
+ fields["channel"] = payload.Channel
+ fields["chat_id"] = payload.ChatID
+ fields["content_len"] = payload.ContentLen
+ case InterruptReceivedPayload:
+ fields["role"] = payload.Role
+ fields["content_len"] = payload.ContentLen
+ fields["queue_depth"] = payload.QueueDepth
+ case SubTurnSpawnPayload:
+ fields["child_agent_id"] = payload.AgentID
+ fields["label"] = payload.Label
+ case SubTurnEndPayload:
+ fields["child_agent_id"] = payload.AgentID
+ fields["status"] = payload.Status
+ case SubTurnResultDeliveredPayload:
+ fields["target_channel"] = payload.TargetChannel
+ fields["target_chat_id"] = payload.TargetChatID
+ fields["content_len"] = payload.ContentLen
case ErrorPayload:
fields["stage"] = payload.Stage
fields["error"] = payload.Message
@@ -1105,7 +1148,17 @@ func (al *AgentLoop) runAgentLoop(
if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) {
logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
map[string]any{"session_key": opts.SessionKey})
- al.forceCompression(agent, opts.SessionKey)
+ if compression, ok := al.forceCompression(agent, opts.SessionKey); ok {
+ al.emitEvent(
+ EventKindContextCompress,
+ turnScope.meta(0, "runAgentLoop", "turn.context.compress"),
+ ContextCompressPayload{
+ Reason: ContextCompressReasonProactive,
+ DroppedMessages: compression.DroppedMessages,
+ RemainingMessages: compression.RemainingMessages,
+ },
+ )
+ }
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
@@ -1142,7 +1195,7 @@ func (al *AgentLoop) runAgentLoop(
// 6. Optional: summarization
if opts.EnableSummary {
- al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID)
+ al.maybeSummarize(agent, opts.SessionKey, turnScope)
}
// 7. Optional: send response via bus
@@ -1256,9 +1309,11 @@ func (al *AgentLoop) runLLMIteration(
// Inject pending steering messages into the conversation context
// before the next LLM call.
if len(pendingMessages) > 0 {
+ totalContentLen := 0
for _, pm := range pendingMessages {
messages = append(messages, pm)
agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
+ totalContentLen += len(pm.Content)
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
"agent_id": agent.ID,
@@ -1266,6 +1321,14 @@ func (al *AgentLoop) runLLMIteration(
"content_len": len(pm.Content),
})
}
+ al.emitEvent(
+ EventKindSteeringInjected,
+ turnScope.meta(iteration, "runLLMIteration", "turn.steering.injected"),
+ SteeringInjectedPayload{
+ Count: len(pendingMessages),
+ TotalContentLen: totalContentLen,
+ },
+ )
pendingMessages = nil
}
@@ -1334,6 +1397,8 @@ func (al *AgentLoop) runLLMIteration(
callLLM := func() (*providers.LLMResponse, error) {
al.activeRequests.Add(1)
defer al.activeRequests.Done()
+ // TODO(eventbus): emit EventKindLLMDelta when providers expose
+ // streaming callbacks instead of only the final Chat response.
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
@@ -1389,6 +1454,17 @@ func (al *AgentLoop) runLLMIteration(
if isTimeoutError && retry < maxRetries {
backoff := time.Duration(retry+1) * 5 * time.Second
+ al.emitEvent(
+ EventKindLLMRetry,
+ turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"),
+ LLMRetryPayload{
+ Attempt: retry + 1,
+ MaxRetries: maxRetries,
+ Reason: "timeout",
+ Error: err.Error(),
+ Backoff: backoff,
+ },
+ )
logger.WarnCF("agent", "Timeout error, retrying after backoff", map[string]any{
"error": err.Error(),
"retry": retry,
@@ -1399,6 +1475,16 @@ func (al *AgentLoop) runLLMIteration(
}
if isContextError && retry < maxRetries {
+ al.emitEvent(
+ EventKindLLMRetry,
+ turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"),
+ LLMRetryPayload{
+ Attempt: retry + 1,
+ MaxRetries: maxRetries,
+ Reason: "context_limit",
+ Error: err.Error(),
+ },
+ )
logger.WarnCF(
"agent",
"Context window error detected, attempting compression",
@@ -1416,7 +1502,17 @@ func (al *AgentLoop) runLLMIteration(
})
}
- al.forceCompression(agent, opts.SessionKey)
+ if compression, ok := al.forceCompression(agent, opts.SessionKey); ok {
+ al.emitEvent(
+ EventKindContextCompress,
+ turnScope.meta(iteration, "runLLMIteration", "turn.context.compress"),
+ ContextCompressPayload{
+ Reason: ContextCompressReasonRetry,
+ DroppedMessages: compression.DroppedMessages,
+ RemainingMessages: compression.RemainingMessages,
+ },
+ )
+ }
newHistory := agent.Sessions.GetHistory(opts.SessionKey)
newSummary := agent.Sessions.GetSummary(opts.SessionKey)
messages = agent.ContextBuilder.BuildMessages(
@@ -1587,6 +1683,16 @@ func (al *AgentLoop) runLLMIteration(
"content_len": len(content),
"channel": opts.Channel,
})
+ al.emitEvent(
+ EventKindFollowUpQueued,
+ turnScope.meta(iteration, "runLLMIteration", "turn.follow_up.queued"),
+ FollowUpQueuedPayload{
+ SourceTool: tc.Name,
+ Channel: opts.Channel,
+ ChatID: opts.ChatID,
+ ContentLen: len(content),
+ },
+ )
pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer pubCancel()
@@ -1686,6 +1792,14 @@ func (al *AgentLoop) runLLMIteration(
// Mark remaining tool calls as skipped
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
+ al.emitEvent(
+ EventKindToolExecSkipped,
+ turnScope.meta(iteration, "runLLMIteration", "turn.tool.skipped"),
+ ToolExecSkippedPayload{
+ Tool: skippedTC.Name,
+ Reason: "queued user steering message",
+ },
+ )
toolResultMsg := providers.Message{
Role: "tool",
Content: "Skipped due to queued user message.",
@@ -1760,7 +1874,7 @@ func (al *AgentLoop) selectCandidates(
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
-func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, chatID string) {
+func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
newHistory := agent.Sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := agent.ContextWindow * agent.SummarizeTokenPercent / 100
@@ -1771,12 +1885,17 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
go func() {
defer al.summarizing.Delete(summarizeKey)
logger.Debug("Memory threshold reached. Optimizing conversation history...")
- al.summarizeSession(agent, sessionKey)
+ al.summarizeSession(agent, sessionKey, turnScope)
}()
}
}
}
+type compressionResult struct {
+ DroppedMessages int
+ RemainingMessages int
+}
+
// forceCompression aggressively reduces context when the limit is hit.
// It drops the oldest ~50% of Turns (a Turn is a complete user→LLM→response
// cycle, as defined in #1316), so tool-call sequences are never split.
@@ -1789,10 +1908,10 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c
// prompt is built dynamically by BuildMessages and is NOT stored here.
// The compression note is recorded in the session summary so that
// BuildMessages can include it in the next system prompt.
-func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
+func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) (compressionResult, bool) {
history := agent.Sessions.GetHistory(sessionKey)
if len(history) <= 2 {
- return
+ return compressionResult{}, false
}
// Split at a Turn boundary so no tool-call sequence is torn apart.
@@ -1846,6 +1965,11 @@ func (al *AgentLoop) forceCompression(agent *AgentInstance, sessionKey string) {
"dropped_msgs": droppedCount,
"new_count": len(keptHistory),
})
+
+ return compressionResult{
+ DroppedMessages: droppedCount,
+ RemainingMessages: len(keptHistory),
+ }, true
}
// GetStartupInfo returns information about loaded tools and skills for logging.
@@ -1937,7 +2061,7 @@ func formatToolsForLog(toolDefs []providers.ToolDefinition) string {
}
// summarizeSession summarizes the conversation history for a session.
-func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
+func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string, turnScope turnEventScope) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
@@ -2022,6 +2146,16 @@ func (al *AgentLoop) summarizeSession(agent *AgentInstance, sessionKey string) {
agent.Sessions.SetSummary(sessionKey, finalSummary)
agent.Sessions.TruncateHistory(sessionKey, keepCount)
agent.Sessions.Save(sessionKey)
+ al.emitEvent(
+ EventKindSessionSummarize,
+ turnScope.meta(0, "summarizeSession", "turn.session.summarize"),
+ SessionSummarizePayload{
+ SummarizedMessages: len(validMessages),
+ KeptMessages: keepCount,
+ SummaryLen: len(finalSummary),
+ OmittedOversized: omitted,
+ },
+ )
}
}
diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go
index 8c7c79c16..90d1cc091 100644
--- a/pkg/agent/steering.go
+++ b/pkg/agent/steering.go
@@ -122,6 +122,25 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
+ agentID := ""
+ if registry := al.GetRegistry(); registry != nil {
+ if agent := registry.GetDefaultAgent(); agent != nil {
+ agentID = agent.ID
+ }
+ }
+ al.emitEvent(
+ EventKindInterruptReceived,
+ EventMeta{
+ AgentID: agentID,
+ Source: "Steer",
+ TracePath: "turn.interrupt.received",
+ },
+ InterruptReceivedPayload{
+ Role: msg.Role,
+ ContentLen: len(msg.Content),
+ QueueDepth: al.steering.len(),
+ },
+ )
return nil
}
diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go
index be40ffda2..34ccc80e4 100644
--- a/pkg/tools/spawn.go
+++ b/pkg/tools/spawn.go
@@ -96,6 +96,9 @@ func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCa
}
// Pass callback to manager for async completion notification
+ // TODO(eventbus): when background subagents are migrated onto the
+ // agent package's runTurn/sub-turn tree, emit SubTurnSpawn here and move
+ // lifecycle events out of the legacy SubagentManager path.
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go
index e51cbaafa..9915c5900 100644
--- a/pkg/tools/subagent.go
+++ b/pkg/tools/subagent.go
@@ -111,6 +111,9 @@ func (sm *SubagentManager) Spawn(
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
+ // TODO(eventbus): once subagents are modeled as child turns inside
+ // pkg/agent, emit SubTurnEnd and SubTurnResultDelivered from the parent
+ // AgentLoop instead of this legacy manager.
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
From a65e0e95d618bc7437d80acb529a9568cce7b44c Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Fri, 20 Mar 2026 15:45:27 +0800
Subject: [PATCH 18/26] fix: lint err
---
pkg/agent/eventbus_test.go | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go
index dadbc2f94..13f2f2282 100644
--- a/pkg/agent/eventbus_test.go
+++ b/pkg/agent/eventbus_test.go
@@ -357,7 +357,7 @@ func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
},
}
- contextErr := errString("InvalidParameter: Total tokens of image and text exceed max message tokens")
+ contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
@@ -630,9 +630,9 @@ func findEvent(events []Event, kind EventKind) (Event, bool) {
return Event{}, false
}
-type errString string
+type stringError string
-func (e errString) Error() string {
+func (e stringError) Error() string {
return string(e)
}
@@ -675,5 +675,7 @@ func (t *asyncFollowUpTool) ExecuteAsync(
return tools.AsyncResult("async follow-up scheduled")
}
-var _ tools.Tool = (*mockCustomTool)(nil)
-var _ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
+var (
+ _ tools.Tool = (*mockCustomTool)(nil)
+ _ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
+)
From 0e075f7300014e4d305c346f3555742e34cb8174 Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Fri, 20 Mar 2026 17:28:12 +0800
Subject: [PATCH 19/26] feat(agent): centralize turn lifecycle and continue
queued steering
Refactor agent loop execution around runTurn, add explicit turn state and interrupt semantics, and automatically continue queued steering that misses the current turn boundary.
---
pkg/agent/eventbus_test.go | 3 +
pkg/agent/events.go | 14 +-
pkg/agent/loop.go | 818 ++++++++++++++++++++++---------------
pkg/agent/steering.go | 70 +++-
pkg/agent/steering_test.go | 518 +++++++++++++++++++++++
pkg/agent/turn.go | 309 ++++++++++++++
6 files changed, 1395 insertions(+), 337 deletions(-)
create mode 100644 pkg/agent/turn.go
diff --git a/pkg/agent/eventbus_test.go b/pkg/agent/eventbus_test.go
index 13f2f2282..9acc6ddd8 100644
--- a/pkg/agent/eventbus_test.go
+++ b/pkg/agent/eventbus_test.go
@@ -334,6 +334,9 @@ func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
if interruptPayload.Role != "user" {
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
}
+ if interruptPayload.Kind != InterruptKindSteering {
+ t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
+ }
if interruptPayload.ContentLen != len("change course") {
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
}
diff --git a/pkg/agent/events.go b/pkg/agent/events.go
index fae5033a3..95e4c90d0 100644
--- a/pkg/agent/events.go
+++ b/pkg/agent/events.go
@@ -105,6 +105,8 @@ const (
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
+ // TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
+ TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
@@ -215,11 +217,21 @@ type FollowUpQueuedPayload struct {
ContentLen int
}
-// InterruptReceivedPayload describes a queued soft interrupt.
+type InterruptKind string
+
+const (
+ InterruptKindSteering InterruptKind = "steering"
+ InterruptKindGraceful InterruptKind = "graceful"
+ InterruptKindHard InterruptKind = "hard_abort"
+)
+
+// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
+ Kind InterruptKind
Role string
ContentLen int
QueueDepth int
+ HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 877dbbd94..f54482ae8 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -50,6 +50,8 @@ type AgentLoop struct {
mcp mcpRuntime
steering *steeringQueue
mu sync.RWMutex
+ activeTurnMu sync.RWMutex
+ activeTurn *turnState
turnSeq atomic.Uint64
// Track active requests for safe provider cleanup
activeRequests sync.WaitGroup
@@ -69,6 +71,12 @@ type processOptions struct {
SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
+type continuationTarget struct {
+ SessionKey string
+ Channel string
+ ChatID string
+}
+
const (
defaultResponse = "I've completed processing but have no response to give. Increase `max_tool_iterations` in config.json."
sessionKeyAgentPrefix = "agent:"
@@ -292,38 +300,46 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
if response != "" {
- // Check if the message tool already sent a response during this round.
- // If so, skip publishing to avoid duplicate messages to the user.
- // Use default agent's tools to check (message tool is shared).
- alreadySent := false
- defaultAgent := al.GetRegistry().GetDefaultAgent()
- if defaultAgent != nil {
- if tool, ok := defaultAgent.Tools.Get("message"); ok {
- if mt, ok := tool.(*tools.MessageTool); ok {
- alreadySent = mt.HasSentInRound()
- }
- }
+ al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, response)
+ }
+
+ target, targetErr := al.buildContinuationTarget(msg)
+ if targetErr != nil {
+ logger.WarnCF("agent", "Failed to build steering continuation target",
+ map[string]any{
+ "channel": msg.Channel,
+ "error": targetErr.Error(),
+ })
+ return
+ }
+ if target == nil {
+ return
+ }
+
+ for al.pendingSteeringCount() > 0 {
+ logger.InfoCF("agent", "Continuing queued steering after turn end",
+ map[string]any{
+ "channel": target.Channel,
+ "chat_id": target.ChatID,
+ "session_key": target.SessionKey,
+ "queue_depth": al.pendingSteeringCount(),
+ })
+
+ continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
+ if continueErr != nil {
+ logger.WarnCF("agent", "Failed to continue queued steering",
+ map[string]any{
+ "channel": target.Channel,
+ "chat_id": target.ChatID,
+ "error": continueErr.Error(),
+ })
+ return
+ }
+ if continued == "" {
+ return
}
- if !alreadySent {
- al.bus.PublishOutbound(ctx, bus.OutboundMessage{
- Channel: msg.Channel,
- ChatID: msg.ChatID,
- Content: response,
- })
- logger.InfoCF("agent", "Published outbound response",
- map[string]any{
- "channel": msg.Channel,
- "chat_id": msg.ChatID,
- "content_len": len(response),
- })
- } else {
- logger.DebugCF(
- "agent",
- "Skipped outbound (message tool already sent)",
- map[string]any{"channel": msg.Channel},
- )
- }
+ al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, continued)
}
}()
}
@@ -369,6 +385,67 @@ func (al *AgentLoop) Stop() {
al.running.Store(false)
}
+func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatID, response string) {
+ if response == "" {
+ return
+ }
+
+ alreadySent := false
+ defaultAgent := al.GetRegistry().GetDefaultAgent()
+ if defaultAgent != nil {
+ if tool, ok := defaultAgent.Tools.Get("message"); ok {
+ if mt, ok := tool.(*tools.MessageTool); ok {
+ alreadySent = mt.HasSentInRound()
+ }
+ }
+ }
+
+ if alreadySent {
+ logger.DebugCF(
+ "agent",
+ "Skipped outbound (message tool already sent)",
+ map[string]any{"channel": channel},
+ )
+ return
+ }
+
+ al.bus.PublishOutbound(ctx, bus.OutboundMessage{
+ Channel: channel,
+ ChatID: chatID,
+ Content: response,
+ })
+ logger.InfoCF("agent", "Published outbound response",
+ map[string]any{
+ "channel": channel,
+ "chat_id": chatID,
+ "content_len": len(response),
+ })
+}
+
+func (al *AgentLoop) pendingSteeringCount() int {
+ if al.steering == nil {
+ return 0
+ }
+ return al.steering.len()
+}
+
+func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuationTarget, error) {
+ if msg.Channel == "system" {
+ return nil, nil
+ }
+
+ route, _, err := al.resolveMessageRoute(msg)
+ if err != nil {
+ return nil, err
+ }
+
+ return &continuationTarget{
+ SessionKey: resolveScopeKey(route, msg.SessionKey),
+ Channel: msg.Channel,
+ ChatID: msg.ChatID,
+ }, nil
+}
+
// Close releases resources held by agent session stores. Call after Stop.
func (al *AgentLoop) Close() {
mcpManager := al.mcp.takeManager()
@@ -543,9 +620,11 @@ func (al *AgentLoop) logEvent(evt Event) {
fields["chat_id"] = payload.ChatID
fields["content_len"] = payload.ContentLen
case InterruptReceivedPayload:
+ fields["interrupt_kind"] = payload.Kind
fields["role"] = payload.Role
fields["content_len"] = payload.ContentLen
fields["queue_depth"] = payload.QueueDepth
+ fields["hint_len"] = payload.HintLen
case SubTurnSpawnPayload:
fields["child_agent_id"] = payload.AgentID
fields["label"] = payload.Label
@@ -1071,153 +1150,63 @@ func (al *AgentLoop) processSystemMessage(
})
}
-// runAgentLoop is the core message processing logic.
+// runAgentLoop remains the top-level shell that starts a turn and publishes
+// any post-turn work. runTurn owns the full turn lifecycle.
func (al *AgentLoop) runAgentLoop(
ctx context.Context,
agent *AgentInstance,
opts processOptions,
) (string, error) {
- turnScope := al.newTurnEventScope(agent.ID, opts.SessionKey)
- turnStartedAt := time.Now()
- turnIterations := 0
- turnFinalContentLen := 0
- turnStatus := TurnEndStatusCompleted
- defer func() {
- al.emitEvent(
- EventKindTurnEnd,
- turnScope.meta(turnIterations, "runAgentLoop", "turn.end"),
- TurnEndPayload{
- Status: turnStatus,
- Iterations: turnIterations,
- Duration: time.Since(turnStartedAt),
- FinalContentLen: turnFinalContentLen,
- },
- )
- }()
-
- al.emitEvent(
- EventKindTurnStart,
- turnScope.meta(0, "runAgentLoop", "turn.start"),
- TurnStartPayload{
- Channel: opts.Channel,
- ChatID: opts.ChatID,
- UserMessage: opts.UserMessage,
- MediaCount: len(opts.Media),
- },
- )
-
- // 0. Record last channel for heartbeat notifications (skip internal channels and cli)
- if opts.Channel != "" && opts.ChatID != "" {
- if !constants.IsInternalChannel(opts.Channel) {
- channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
- if err := al.RecordLastChannel(channelKey); err != nil {
- logger.WarnCF(
- "agent",
- "Failed to record last channel",
- map[string]any{"error": err.Error()},
- )
- }
- }
- }
-
- // 1. Build messages (skip history for heartbeat)
- var history []providers.Message
- var summary string
- if !opts.NoHistory {
- history = agent.Sessions.GetHistory(opts.SessionKey)
- summary = agent.Sessions.GetSummary(opts.SessionKey)
- }
- messages := agent.ContextBuilder.BuildMessages(
- history,
- summary,
- opts.UserMessage,
- opts.Media,
- opts.Channel,
- opts.ChatID,
- )
-
- // Resolve media:// refs: images→base64 data URLs, non-images→local paths in content
- cfg := al.GetConfig()
- maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
- messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
-
- // 1.5. Proactive context budget check: compress before LLM call
- // rather than waiting for a 400 context-length error.
- if !opts.NoHistory {
- toolDefs := agent.Tools.ToProviderDefs()
- if isOverContextBudget(agent.ContextWindow, messages, toolDefs, agent.MaxTokens) {
- logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
- map[string]any{"session_key": opts.SessionKey})
- if compression, ok := al.forceCompression(agent, opts.SessionKey); ok {
- al.emitEvent(
- EventKindContextCompress,
- turnScope.meta(0, "runAgentLoop", "turn.context.compress"),
- ContextCompressPayload{
- Reason: ContextCompressReasonProactive,
- DroppedMessages: compression.DroppedMessages,
- RemainingMessages: compression.RemainingMessages,
- },
- )
- }
- newHistory := agent.Sessions.GetHistory(opts.SessionKey)
- newSummary := agent.Sessions.GetSummary(opts.SessionKey)
- messages = agent.ContextBuilder.BuildMessages(
- newHistory, newSummary, opts.UserMessage,
- opts.Media, opts.Channel, opts.ChatID,
+ if opts.Channel != "" && opts.ChatID != "" && !constants.IsInternalChannel(opts.Channel) {
+ channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
+ if err := al.RecordLastChannel(channelKey); err != nil {
+ logger.WarnCF(
+ "agent",
+ "Failed to record last channel",
+ map[string]any{"error": err.Error()},
)
- messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
}
}
- // 2. Save user message to session
- agent.Sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
-
- // 3. Run LLM iteration loop
- finalContent, iteration, err := al.runLLMIteration(ctx, agent, messages, opts, turnScope)
- turnIterations = iteration
+ ts := newTurnState(agent, opts, al.newTurnEventScope(agent.ID, opts.SessionKey))
+ result, err := al.runTurn(ctx, ts)
if err != nil {
- turnStatus = TurnEndStatusError
return "", err
}
-
- // If last tool had ForUser content and we already sent it, we might not need to send final response
- // This is controlled by the tool's Silent flag and ForUser content
-
- // 4. Handle empty response
- if finalContent == "" {
- finalContent = opts.DefaultResponse
- }
- turnFinalContentLen = len(finalContent)
-
- // 5. Save final assistant message to session
- agent.Sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
- agent.Sessions.Save(opts.SessionKey)
-
- // 6. Optional: summarization
- if opts.EnableSummary {
- al.maybeSummarize(agent, opts.SessionKey, turnScope)
+ if result.status == TurnEndStatusAborted {
+ return "", nil
}
- // 7. Optional: send response via bus
- if opts.SendResponse {
+ for _, followUp := range result.followUps {
+ if pubErr := al.bus.PublishInbound(ctx, followUp); pubErr != nil {
+ logger.WarnCF("agent", "Failed to publish follow-up after turn",
+ map[string]any{
+ "turn_id": ts.turnID,
+ "error": pubErr.Error(),
+ })
+ }
+ }
+
+ if opts.SendResponse && result.finalContent != "" {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
- Content: finalContent,
+ Content: result.finalContent,
})
}
- // 8. Log response
- responsePreview := utils.Truncate(finalContent, 120)
- logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
- map[string]any{
- "agent_id": agent.ID,
- "session_key": opts.SessionKey,
- "iterations": iteration,
- "final_length": len(finalContent),
- })
+ if result.finalContent != "" {
+ responsePreview := utils.Truncate(result.finalContent, 120)
+ logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
+ map[string]any{
+ "agent_id": agent.ID,
+ "session_key": opts.SessionKey,
+ "iterations": ts.currentIteration(),
+ "final_length": len(result.finalContent),
+ })
+ }
- return finalContent, nil
+ return result.finalContent, nil
}
func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) {
@@ -1276,54 +1265,135 @@ func (al *AgentLoop) handleReasoning(
}
}
-// runLLMIteration executes the LLM call loop with tool handling.
-func (al *AgentLoop) runLLMIteration(
- ctx context.Context,
- agent *AgentInstance,
- messages []providers.Message,
- opts processOptions,
- turnScope turnEventScope,
-) (string, int, error) {
- iteration := 0
- var finalContent string
- var pendingMessages []providers.Message
+func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, error) {
+ turnCtx, turnCancel := context.WithCancel(ctx)
+ defer turnCancel()
+ ts.setTurnCancel(turnCancel)
- // Poll for steering messages at loop start (in case the user typed while
- // the agent was setting up), unless the caller already provided initial
- // steering messages (e.g. Continue).
- if !opts.SkipInitialSteeringPoll {
- if msgs := al.dequeueSteeringMessages(); len(msgs) > 0 {
- pendingMessages = msgs
+ al.registerActiveTurn(ts)
+ defer al.clearActiveTurn(ts)
+
+ turnStatus := TurnEndStatusCompleted
+ defer func() {
+ al.emitEvent(
+ EventKindTurnEnd,
+ ts.eventMeta("runTurn", "turn.end"),
+ TurnEndPayload{
+ Status: turnStatus,
+ Iterations: ts.currentIteration(),
+ Duration: time.Since(ts.startedAt),
+ FinalContentLen: ts.finalContentLen(),
+ },
+ )
+ }()
+
+ al.emitEvent(
+ EventKindTurnStart,
+ ts.eventMeta("runTurn", "turn.start"),
+ TurnStartPayload{
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ UserMessage: ts.userMessage,
+ MediaCount: len(ts.media),
+ },
+ )
+
+ var history []providers.Message
+ var summary string
+ if !ts.opts.NoHistory {
+ history = ts.agent.Sessions.GetHistory(ts.sessionKey)
+ summary = ts.agent.Sessions.GetSummary(ts.sessionKey)
+ }
+ ts.captureRestorePoint(history, summary)
+
+ messages := ts.agent.ContextBuilder.BuildMessages(
+ history,
+ summary,
+ ts.userMessage,
+ ts.media,
+ ts.channel,
+ ts.chatID,
+ )
+
+ cfg := al.GetConfig()
+ maxMediaSize := cfg.Agents.Defaults.GetMaxMediaSize()
+ messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
+
+ if !ts.opts.NoHistory {
+ toolDefs := ts.agent.Tools.ToProviderDefs()
+ if isOverContextBudget(ts.agent.ContextWindow, messages, toolDefs, ts.agent.MaxTokens) {
+ logger.WarnCF("agent", "Proactive compression: context budget exceeded before LLM call",
+ map[string]any{"session_key": ts.sessionKey})
+ if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
+ al.emitEvent(
+ EventKindContextCompress,
+ ts.eventMeta("runTurn", "turn.context.compress"),
+ ContextCompressPayload{
+ Reason: ContextCompressReasonProactive,
+ DroppedMessages: compression.DroppedMessages,
+ RemainingMessages: compression.RemainingMessages,
+ },
+ )
+ ts.refreshRestorePointFromSession(ts.agent)
+ }
+ newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
+ newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
+ messages = ts.agent.ContextBuilder.BuildMessages(
+ newHistory, newSummary, ts.userMessage,
+ ts.media, ts.channel, ts.chatID,
+ )
+ messages = resolveMediaRefs(messages, al.mediaStore, maxMediaSize)
}
}
- // Determine effective model tier for this conversation turn.
- // selectCandidates evaluates routing once and the decision is sticky for
- // all tool-follow-up iterations within the same turn so that a multi-step
- // tool chain doesn't switch models mid-way through.
- activeCandidates, activeModel := al.selectCandidates(agent, opts.UserMessage, messages)
+ if !ts.opts.NoHistory {
+ rootMsg := providers.Message{Role: "user", Content: ts.userMessage}
+ ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
+ ts.recordPersistedMessage(rootMsg)
+ }
- for iteration < agent.MaxIterations || len(pendingMessages) > 0 {
- iteration++
+ activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
+ var pendingMessages []providers.Message
+ var finalContent string
+
+ for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool {
+ graceful, _ := ts.gracefulInterruptRequested()
+ return graceful
+ }() {
+ if ts.hardAbortRequested() {
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+
+ iteration := ts.currentIteration() + 1
+ ts.setIteration(iteration)
+ ts.setPhase(TurnPhaseRunning)
+
+ if iteration > 1 || !ts.opts.SkipInitialSteeringPoll {
+ if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
+ pendingMessages = append(pendingMessages, steerMsgs...)
+ }
+ }
- // Inject pending steering messages into the conversation context
- // before the next LLM call.
if len(pendingMessages) > 0 {
totalContentLen := 0
for _, pm := range pendingMessages {
messages = append(messages, pm)
- agent.Sessions.AddMessage(opts.SessionKey, pm.Role, pm.Content)
totalContentLen += len(pm.Content)
+ if !ts.opts.NoHistory {
+ ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content)
+ ts.recordPersistedMessage(pm)
+ }
logger.InfoCF("agent", "Injected steering message into context",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"iteration": iteration,
"content_len": len(pm.Content),
})
}
al.emitEvent(
EventKindSteeringInjected,
- turnScope.meta(iteration, "runLLMIteration", "turn.steering.injected"),
+ ts.eventMeta("runTurn", "turn.steering.injected"),
SteeringInjectedPayload{
Count: len(pendingMessages),
TotalContentLen: totalContentLen,
@@ -1334,78 +1404,81 @@ func (al *AgentLoop) runLLMIteration(
logger.DebugCF("agent", "LLM iteration",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"iteration": iteration,
- "max": agent.MaxIterations,
+ "max": ts.agent.MaxIterations,
})
- // Build tool definitions
- providerToolDefs := agent.Tools.ToProviderDefs()
+ gracefulTerminal, _ := ts.gracefulInterruptRequested()
+ providerToolDefs := ts.agent.Tools.ToProviderDefs()
+ callMessages := messages
+ if gracefulTerminal {
+ callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage())
+ providerToolDefs = nil
+ ts.markGracefulTerminalUsed()
+ }
+
al.emitEvent(
EventKindLLMRequest,
- turnScope.meta(iteration, "runLLMIteration", "turn.llm.request"),
+ ts.eventMeta("runTurn", "turn.llm.request"),
LLMRequestPayload{
Model: activeModel,
- MessagesCount: len(messages),
+ MessagesCount: len(callMessages),
ToolsCount: len(providerToolDefs),
- MaxTokens: agent.MaxTokens,
- Temperature: agent.Temperature,
+ MaxTokens: ts.agent.MaxTokens,
+ Temperature: ts.agent.Temperature,
},
)
- // Log LLM request details
logger.DebugCF("agent", "LLM request",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"iteration": iteration,
"model": activeModel,
- "messages_count": len(messages),
+ "messages_count": len(callMessages),
"tools_count": len(providerToolDefs),
- "max_tokens": agent.MaxTokens,
- "temperature": agent.Temperature,
- "system_prompt_len": len(messages[0].Content),
+ "max_tokens": ts.agent.MaxTokens,
+ "temperature": ts.agent.Temperature,
+ "system_prompt_len": len(callMessages[0].Content),
})
-
- // Log full messages (detailed)
logger.DebugCF("agent", "Full LLM request",
map[string]any{
"iteration": iteration,
- "messages_json": formatMessagesForLog(messages),
+ "messages_json": formatMessagesForLog(callMessages),
"tools_json": formatToolsForLog(providerToolDefs),
})
- // Call LLM with fallback chain if multiple candidates are configured.
- var response *providers.LLMResponse
- var err error
-
llmOpts := map[string]any{
- "max_tokens": agent.MaxTokens,
- "temperature": agent.Temperature,
- "prompt_cache_key": agent.ID,
+ "max_tokens": ts.agent.MaxTokens,
+ "temperature": ts.agent.Temperature,
+ "prompt_cache_key": ts.agent.ID,
}
- // parseThinkingLevel guarantees ThinkingOff for empty/unknown values,
- // so checking != ThinkingOff is sufficient.
- if agent.ThinkingLevel != ThinkingOff {
- if tc, ok := agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
- llmOpts["thinking_level"] = string(agent.ThinkingLevel)
+ if ts.agent.ThinkingLevel != ThinkingOff {
+ if tc, ok := ts.agent.Provider.(providers.ThinkingCapable); ok && tc.SupportsThinking() {
+ llmOpts["thinking_level"] = string(ts.agent.ThinkingLevel)
} else {
logger.WarnCF("agent", "thinking_level is set but current provider does not support it, ignoring",
- map[string]any{"agent_id": agent.ID, "thinking_level": string(agent.ThinkingLevel)})
+ map[string]any{"agent_id": ts.agent.ID, "thinking_level": string(ts.agent.ThinkingLevel)})
}
}
- callLLM := func() (*providers.LLMResponse, error) {
+ callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
+ providerCtx, providerCancel := context.WithCancel(turnCtx)
+ ts.setProviderCancel(providerCancel)
+ defer func() {
+ providerCancel()
+ ts.clearProviderCancel(providerCancel)
+ }()
+
al.activeRequests.Add(1)
defer al.activeRequests.Done()
- // TODO(eventbus): emit EventKindLLMDelta when providers expose
- // streaming callbacks instead of only the final Chat response.
if len(activeCandidates) > 1 && al.fallback != nil {
fbResult, fbErr := al.fallback.Execute(
- ctx,
+ providerCtx,
activeCandidates,
func(ctx context.Context, provider, model string) (*providers.LLMResponse, error) {
- return agent.Provider.Chat(ctx, messages, providerToolDefs, model, llmOpts)
+ return ts.agent.Provider.Chat(ctx, messagesForCall, toolDefsForCall, model, llmOpts)
},
)
if fbErr != nil {
@@ -1416,32 +1489,34 @@ func (al *AgentLoop) runLLMIteration(
"agent",
fmt.Sprintf("Fallback: succeeded with %s/%s after %d attempts",
fbResult.Provider, fbResult.Model, len(fbResult.Attempts)+1),
- map[string]any{"agent_id": agent.ID, "iteration": iteration},
+ map[string]any{"agent_id": ts.agent.ID, "iteration": iteration},
)
}
return fbResult.Response, nil
}
- return agent.Provider.Chat(ctx, messages, providerToolDefs, activeModel, llmOpts)
+ return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts)
}
- // Retry loop for context/token errors
+ var response *providers.LLMResponse
+ var err error
maxRetries := 2
for retry := 0; retry <= maxRetries; retry++ {
- response, err = callLLM()
+ response, err = callLLM(callMessages, providerToolDefs)
if err == nil {
break
}
+ if ts.hardAbortRequested() && errors.Is(err, context.Canceled) {
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
errMsg := strings.ToLower(err.Error())
-
- // Check if this is a network/HTTP timeout — not a context window error.
isTimeoutError := errors.Is(err, context.DeadlineExceeded) ||
strings.Contains(errMsg, "deadline exceeded") ||
strings.Contains(errMsg, "client.timeout") ||
strings.Contains(errMsg, "timed out") ||
strings.Contains(errMsg, "timeout exceeded")
- // Detect real context window / token limit errors, excluding network timeouts.
isContextError := !isTimeoutError && (strings.Contains(errMsg, "context_length_exceeded") ||
strings.Contains(errMsg, "context window") ||
strings.Contains(errMsg, "maximum context length") ||
@@ -1456,7 +1531,7 @@ func (al *AgentLoop) runLLMIteration(
backoff := time.Duration(retry+1) * 5 * time.Second
al.emitEvent(
EventKindLLMRetry,
- turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"),
+ ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
MaxRetries: maxRetries,
@@ -1470,14 +1545,21 @@ func (al *AgentLoop) runLLMIteration(
"retry": retry,
"backoff": backoff.String(),
})
- time.Sleep(backoff)
+ if sleepErr := sleepWithContext(turnCtx, backoff); sleepErr != nil {
+ if ts.hardAbortRequested() {
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+ err = sleepErr
+ break
+ }
continue
}
- if isContextError && retry < maxRetries {
+ if isContextError && retry < maxRetries && !ts.opts.NoHistory {
al.emitEvent(
EventKindLLMRetry,
- turnScope.meta(iteration, "runLLMIteration", "turn.llm.retry"),
+ ts.eventMeta("runTurn", "turn.llm.retry"),
LLMRetryPayload{
Attempt: retry + 1,
MaxRetries: maxRetries,
@@ -1494,40 +1576,47 @@ func (al *AgentLoop) runLLMIteration(
},
)
- if retry == 0 && !constants.IsInternalChannel(opts.Channel) {
+ if retry == 0 && !constants.IsInternalChannel(ts.channel) {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
- Channel: opts.Channel,
- ChatID: opts.ChatID,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
Content: "Context window exceeded. Compressing history and retrying...",
})
}
- if compression, ok := al.forceCompression(agent, opts.SessionKey); ok {
+ if compression, ok := al.forceCompression(ts.agent, ts.sessionKey); ok {
al.emitEvent(
EventKindContextCompress,
- turnScope.meta(iteration, "runLLMIteration", "turn.context.compress"),
+ ts.eventMeta("runTurn", "turn.context.compress"),
ContextCompressPayload{
Reason: ContextCompressReasonRetry,
DroppedMessages: compression.DroppedMessages,
RemainingMessages: compression.RemainingMessages,
},
)
+ ts.refreshRestorePointFromSession(ts.agent)
}
- newHistory := agent.Sessions.GetHistory(opts.SessionKey)
- newSummary := agent.Sessions.GetSummary(opts.SessionKey)
- messages = agent.ContextBuilder.BuildMessages(
+
+ newHistory := ts.agent.Sessions.GetHistory(ts.sessionKey)
+ newSummary := ts.agent.Sessions.GetSummary(ts.sessionKey)
+ messages = ts.agent.ContextBuilder.BuildMessages(
newHistory, newSummary, "",
- nil, opts.Channel, opts.ChatID,
+ nil, ts.channel, ts.chatID,
)
+ callMessages = messages
+ if gracefulTerminal {
+ callMessages = append(append([]providers.Message(nil), messages...), ts.interruptHintMessage())
+ }
continue
}
break
}
if err != nil {
+ turnStatus = TurnEndStatusError
al.emitEvent(
EventKindError,
- turnScope.meta(iteration, "runLLMIteration", "turn.error"),
+ ts.eventMeta("runTurn", "turn.error"),
ErrorPayload{
Stage: "llm",
Message: err.Error(),
@@ -1535,23 +1624,23 @@ func (al *AgentLoop) runLLMIteration(
)
logger.ErrorCF("agent", "LLM call failed",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"iteration": iteration,
"model": activeModel,
"error": err.Error(),
})
- return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
+ return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err)
}
go al.handleReasoning(
- ctx,
+ turnCtx,
response.Reasoning,
- opts.Channel,
- al.targetReasoningChannelID(opts.Channel),
+ ts.channel,
+ al.targetReasoningChannelID(ts.channel),
)
al.emitEvent(
EventKindLLMResponse,
- turnScope.meta(iteration, "runLLMIteration", "turn.llm.response"),
+ ts.eventMeta("runTurn", "turn.llm.response"),
LLMResponsePayload{
ContentLen: len(response.Content),
ToolCalls: len(response.ToolCalls),
@@ -1561,23 +1650,23 @@ func (al *AgentLoop) runLLMIteration(
logger.DebugCF("agent", "LLM response",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"iteration": iteration,
"content_chars": len(response.Content),
"tool_calls": len(response.ToolCalls),
"reasoning": response.Reasoning,
- "target_channel": al.targetReasoningChannelID(opts.Channel),
- "channel": opts.Channel,
+ "target_channel": al.targetReasoningChannelID(ts.channel),
+ "channel": ts.channel,
})
- // Check if no tool calls - then check reasoning content if any
- if len(response.ToolCalls) == 0 {
+
+ if len(response.ToolCalls) == 0 || gracefulTerminal {
finalContent = response.Content
if finalContent == "" && response.ReasoningContent != "" {
finalContent = response.ReasoningContent
}
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"iteration": iteration,
"content_chars": len(finalContent),
})
@@ -1589,20 +1678,18 @@ func (al *AgentLoop) runLLMIteration(
normalizedToolCalls = append(normalizedToolCalls, providers.NormalizeToolCall(tc))
}
- // Log tool calls
toolNames := make([]string, 0, len(normalizedToolCalls))
for _, tc := range normalizedToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("agent", "LLM requested tool calls",
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"tools": toolNames,
"count": len(normalizedToolCalls),
"iteration": iteration,
})
- // Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
@@ -1610,13 +1697,11 @@ func (al *AgentLoop) runLLMIteration(
}
for _, tc := range normalizedToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
- // Copy ExtraContent to ensure thought_signature is persisted for Gemini 3
extraContent := tc.ExtraContent
thoughtSignature := ""
if tc.Function != nil {
thoughtSignature = tc.Function.ThoughtSignature
}
-
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
@@ -1631,40 +1716,44 @@ func (al *AgentLoop) runLLMIteration(
})
}
messages = append(messages, assistantMsg)
+ if !ts.opts.NoHistory {
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, assistantMsg)
+ ts.recordPersistedMessage(assistantMsg)
+ }
- // Save assistant message with tool calls to session
- agent.Sessions.AddFullMessage(opts.SessionKey, assistantMsg)
-
- // Execute tool calls sequentially. After each tool completes, check
- // for steering messages. If any are found, skip remaining tools.
- var steeringAfterTools []providers.Message
-
+ ts.setPhase(TurnPhaseTools)
for i, tc := range normalizedToolCalls {
+ if ts.hardAbortRequested() {
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
- "agent_id": agent.ID,
+ "agent_id": ts.agent.ID,
"tool": tc.Name,
"iteration": iteration,
})
al.emitEvent(
EventKindToolExecStart,
- turnScope.meta(iteration, "runLLMIteration", "turn.tool.start"),
+ ts.eventMeta("runTurn", "turn.tool.start"),
ToolExecStartPayload{
Tool: tc.Name,
Arguments: cloneEventArguments(tc.Arguments),
},
)
- // Create async callback for tools that implement AsyncExecutor.
+ toolCall := tc
+ toolIteration := iteration
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer outCancel()
_ = al.bus.PublishOutbound(outCtx, bus.OutboundMessage{
- Channel: opts.Channel,
- ChatID: opts.ChatID,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
Content: result.ForUser,
})
}
@@ -1679,17 +1768,17 @@ func (al *AgentLoop) runLLMIteration(
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
- "tool": tc.Name,
+ "tool": toolCall.Name,
"content_len": len(content),
- "channel": opts.Channel,
+ "channel": ts.channel,
})
al.emitEvent(
EventKindFollowUpQueued,
- turnScope.meta(iteration, "runLLMIteration", "turn.follow_up.queued"),
+ ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
FollowUpQueuedPayload{
- SourceTool: tc.Name,
- Channel: opts.Channel,
- ChatID: opts.ChatID,
+ SourceTool: toolCall.Name,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
ContentLen: len(content),
},
)
@@ -1698,33 +1787,37 @@ func (al *AgentLoop) runLLMIteration(
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
- SenderID: fmt.Sprintf("async:%s", tc.Name),
- ChatID: fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID),
+ SenderID: fmt.Sprintf("async:%s", toolCall.Name),
+ ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
Content: content,
})
}
toolStart := time.Now()
- toolResult := agent.Tools.ExecuteWithContext(
- ctx,
- tc.Name,
- tc.Arguments,
- opts.Channel,
- opts.ChatID,
+ toolResult := ts.agent.Tools.ExecuteWithContext(
+ turnCtx,
+ toolCall.Name,
+ toolCall.Arguments,
+ ts.channel,
+ ts.chatID,
asyncCallback,
)
toolDuration := time.Since(toolStart)
- // Process tool result
- if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
+ if ts.hardAbortRequested() {
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+
+ if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
- Channel: opts.Channel,
- ChatID: opts.ChatID,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
- "tool": tc.Name,
+ "tool": toolCall.Name,
"content_len": len(toolResult.ForUser),
})
}
@@ -1743,8 +1836,8 @@ func (al *AgentLoop) runLLMIteration(
parts = append(parts, part)
}
al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{
- Channel: opts.Channel,
- ChatID: opts.ChatID,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
Parts: parts,
})
}
@@ -1757,13 +1850,13 @@ func (al *AgentLoop) runLLMIteration(
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
- ToolCallID: tc.ID,
+ ToolCallID: toolCall.ID,
}
al.emitEvent(
EventKindToolExecEnd,
- turnScope.meta(iteration, "runLLMIteration", "turn.tool.end"),
+ ts.eventMeta("runTurn", "turn.tool.end"),
ToolExecEndPayload{
- Tool: tc.Name,
+ Tool: toolCall.Name,
Duration: toolDuration,
ForLLMLen: len(contentForLLM),
ForUserLen: len(toolResult.ForUser),
@@ -1772,67 +1865,136 @@ func (al *AgentLoop) runLLMIteration(
},
)
messages = append(messages, toolResultMsg)
- agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
+ if !ts.opts.NoHistory {
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, toolResultMsg)
+ ts.recordPersistedMessage(toolResultMsg)
+ }
- // After EVERY tool (including the first and last), check for
- // steering messages. If found and there are remaining tools,
- // skip them all.
if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
+ pendingMessages = append(pendingMessages, steerMsgs...)
+ }
+
+ skipReason := ""
+ skipMessage := ""
+ if len(pendingMessages) > 0 {
+ skipReason = "queued user steering message"
+ skipMessage = "Skipped due to queued user message."
+ } else if gracefulPending, _ := ts.gracefulInterruptRequested(); gracefulPending {
+ skipReason = "graceful interrupt requested"
+ skipMessage = "Skipped due to graceful interrupt."
+ }
+
+ if skipReason != "" {
remaining := len(normalizedToolCalls) - i - 1
if remaining > 0 {
- logger.InfoCF("agent", "Steering interrupt: skipping remaining tools",
+ logger.InfoCF("agent", "Turn checkpoint: skipping remaining tools",
map[string]any{
- "agent_id": agent.ID,
- "completed": i + 1,
- "skipped": remaining,
- "total_tools": len(normalizedToolCalls),
- "steering_count": len(steerMsgs),
+ "agent_id": ts.agent.ID,
+ "completed": i + 1,
+ "skipped": remaining,
+ "reason": skipReason,
})
-
- // Mark remaining tool calls as skipped
for j := i + 1; j < len(normalizedToolCalls); j++ {
skippedTC := normalizedToolCalls[j]
al.emitEvent(
EventKindToolExecSkipped,
- turnScope.meta(iteration, "runLLMIteration", "turn.tool.skipped"),
+ ts.eventMeta("runTurn", "turn.tool.skipped"),
ToolExecSkippedPayload{
Tool: skippedTC.Name,
- Reason: "queued user steering message",
+ Reason: skipReason,
},
)
- toolResultMsg := providers.Message{
+ skippedMsg := providers.Message{
Role: "tool",
- Content: "Skipped due to queued user message.",
+ Content: skipMessage,
ToolCallID: skippedTC.ID,
}
- messages = append(messages, toolResultMsg)
- agent.Sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
+ messages = append(messages, skippedMsg)
+ if !ts.opts.NoHistory {
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, skippedMsg)
+ ts.recordPersistedMessage(skippedMsg)
+ }
}
}
- steeringAfterTools = steerMsgs
break
}
}
- // If steering messages were captured during tool execution, they
- // become pendingMessages for the next iteration of the inner loop.
- if len(steeringAfterTools) > 0 {
- pendingMessages = steeringAfterTools
- }
-
- // Tick down TTL of discovered tools after processing tool results.
- // Only reached when tool calls were made (the loop continues);
- // the break on no-tool-call responses skips this.
- // NOTE: This is safe because processMessage is sequential per agent.
- // If per-agent concurrency is added, TTL consistency between
- // ToProviderDefs and Get must be re-evaluated.
- agent.Tools.TickTTL()
+ ts.agent.Tools.TickTTL()
logger.DebugCF("agent", "TTL tick after tool execution", map[string]any{
- "agent_id": agent.ID, "iteration": iteration,
+ "agent_id": ts.agent.ID, "iteration": iteration,
})
}
- return finalContent, iteration, nil
+ if ts.hardAbortRequested() {
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+
+ if finalContent == "" {
+ finalContent = ts.opts.DefaultResponse
+ }
+
+ ts.setPhase(TurnPhaseFinalizing)
+ ts.setFinalContent(finalContent)
+ if !ts.opts.NoHistory {
+ finalMsg := providers.Message{Role: "assistant", Content: finalContent}
+ ts.agent.Sessions.AddMessage(ts.sessionKey, finalMsg.Role, finalMsg.Content)
+ ts.recordPersistedMessage(finalMsg)
+ if err := ts.agent.Sessions.Save(ts.sessionKey); err != nil {
+ turnStatus = TurnEndStatusError
+ al.emitEvent(
+ EventKindError,
+ ts.eventMeta("runTurn", "turn.error"),
+ ErrorPayload{
+ Stage: "session_save",
+ Message: err.Error(),
+ },
+ )
+ return turnResult{}, err
+ }
+ }
+
+ if ts.opts.EnableSummary {
+ al.maybeSummarize(ts.agent, ts.sessionKey, ts.scope)
+ }
+
+ ts.setPhase(TurnPhaseCompleted)
+ return turnResult{
+ finalContent: finalContent,
+ status: turnStatus,
+ followUps: append([]bus.InboundMessage(nil), ts.followUps...),
+ }, nil
+}
+
+func (al *AgentLoop) abortTurn(ts *turnState) (turnResult, error) {
+ ts.setPhase(TurnPhaseAborted)
+ if !ts.opts.NoHistory {
+ if err := ts.restoreSession(ts.agent); err != nil {
+ al.emitEvent(
+ EventKindError,
+ ts.eventMeta("abortTurn", "turn.error"),
+ ErrorPayload{
+ Stage: "session_restore",
+ Message: err.Error(),
+ },
+ )
+ return turnResult{}, err
+ }
+ }
+ return turnResult{status: TurnEndStatusAborted}, nil
+}
+
+func sleepWithContext(ctx context.Context, d time.Duration) error {
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ return nil
+ }
}
// selectCandidates returns the model candidates and resolved model name to use
diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go
index 90d1cc091..77c2e0c17 100644
--- a/pkg/agent/steering.go
+++ b/pkg/agent/steering.go
@@ -122,20 +122,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
})
- agentID := ""
- if registry := al.GetRegistry(); registry != nil {
+
+ meta := EventMeta{
+ Source: "Steer",
+ TracePath: "turn.interrupt.received",
+ }
+ if ts := al.getActiveTurnState(); ts != nil {
+ meta = ts.eventMeta("Steer", "turn.interrupt.received")
+ } else if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
- agentID = agent.ID
+ meta.AgentID = agent.ID
}
}
al.emitEvent(
EventKindInterruptReceived,
- EventMeta{
- AgentID: agentID,
- Source: "Steer",
- TracePath: "turn.interrupt.received",
- },
+ meta,
InterruptReceivedPayload{
+ Kind: InterruptKindSteering,
Role: msg.Role,
ContentLen: len(msg.Content),
QueueDepth: al.steering.len(),
@@ -177,6 +180,10 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
+ if active := al.GetActiveTurn(); active != nil {
+ return "", fmt.Errorf("turn %s is still active", active.TurnID)
+ }
+
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
return "", nil
@@ -187,6 +194,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
return "", fmt.Errorf("no default agent available")
}
+ if tool, ok := agent.Tools.Get("message"); ok {
+ if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
+ resetter.ResetSentInRound()
+ }
+ }
+
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
@@ -205,3 +218,44 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
SkipInitialSteeringPoll: true,
})
}
+
+func (al *AgentLoop) InterruptGraceful(hint string) error {
+ ts := al.getActiveTurnState()
+ if ts == nil {
+ return fmt.Errorf("no active turn")
+ }
+ if !ts.requestGracefulInterrupt(hint) {
+ return fmt.Errorf("turn %s cannot accept graceful interrupt", ts.turnID)
+ }
+
+ al.emitEvent(
+ EventKindInterruptReceived,
+ ts.eventMeta("InterruptGraceful", "turn.interrupt.received"),
+ InterruptReceivedPayload{
+ Kind: InterruptKindGraceful,
+ HintLen: len(hint),
+ },
+ )
+
+ return nil
+}
+
+func (al *AgentLoop) InterruptHard() error {
+ ts := al.getActiveTurnState()
+ if ts == nil {
+ return fmt.Errorf("no active turn")
+ }
+ if !ts.requestHardAbort() {
+ return fmt.Errorf("turn %s is already aborting", ts.turnID)
+ }
+
+ al.emitEvent(
+ EventKindInterruptReceived,
+ ts.eventMeta("InterruptHard", "turn.interrupt.received"),
+ InterruptReceivedPayload{
+ Kind: InterruptKindHard,
+ },
+ )
+
+ return nil
+}
diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go
index e8cdb2344..f8c046ea9 100644
--- a/pkg/agent/steering_test.go
+++ b/pkg/agent/steering_test.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
+ "reflect"
"sync"
"testing"
"time"
@@ -12,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -396,6 +398,103 @@ func (m *toolCallProvider) GetDefaultModel() string {
return "tool-call-mock"
}
+type gracefulCaptureProvider struct {
+ mu sync.Mutex
+ calls int
+ toolCalls []providers.ToolCall
+ finalResp string
+ terminalMessages []providers.Message
+ terminalToolsCount int
+}
+
+func (p *gracefulCaptureProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.calls++
+
+ if p.calls == 1 {
+ return &providers.LLMResponse{
+ ToolCalls: p.toolCalls,
+ }, nil
+ }
+
+ p.terminalMessages = append([]providers.Message(nil), messages...)
+ p.terminalToolsCount = len(tools)
+ return &providers.LLMResponse{
+ Content: p.finalResp,
+ }, nil
+}
+
+func (p *gracefulCaptureProvider) GetDefaultModel() string {
+ return "graceful-capture-mock"
+}
+
+type lateSteeringProvider struct {
+ mu sync.Mutex
+ calls int
+ firstCallStarted chan struct{}
+ releaseFirstCall chan struct{}
+ firstStartOnce sync.Once
+ secondCallMessages []providers.Message
+}
+
+func (p *lateSteeringProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ p.mu.Lock()
+ p.calls++
+ call := p.calls
+ p.mu.Unlock()
+
+ if call == 1 {
+ p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
+ <-p.releaseFirstCall
+ return &providers.LLMResponse{Content: "first response"}, nil
+ }
+
+ p.mu.Lock()
+ p.secondCallMessages = append([]providers.Message(nil), messages...)
+ p.mu.Unlock()
+ return &providers.LLMResponse{Content: "continued response"}, nil
+}
+
+func (p *lateSteeringProvider) GetDefaultModel() string {
+ return "late-steering-mock"
+}
+
+type interruptibleTool struct {
+ name string
+ started chan struct{}
+ once sync.Once
+}
+
+func (t *interruptibleTool) Name() string { return t.name }
+func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }
+func (t *interruptibleTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{},
+ }
+}
+
+func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
+ if t.started != nil {
+ t.once.Do(func() { close(t.started) })
+ }
+ <-ctx.Done()
+ return tools.ErrorResult(ctx.Err().Error()).WithError(ctx.Err())
+}
+
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -568,6 +667,425 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
}
}
+func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ provider := &lateSteeringProvider{
+ firstCallStarted: make(chan struct{}),
+ releaseFirstCall: make(chan struct{}),
+ }
+ al := NewAgentLoop(cfg, msgBus, provider)
+
+ runCtx, cancelRun := context.WithCancel(context.Background())
+ defer cancelRun()
+
+ runErrCh := make(chan error, 1)
+ go func() {
+ runErrCh <- al.Run(runCtx)
+ }()
+
+ first := bus.InboundMessage{
+ Channel: "test",
+ SenderID: "user1",
+ ChatID: "chat1",
+ Content: "first message",
+ Peer: bus.Peer{
+ Kind: "direct",
+ ID: "user1",
+ },
+ }
+ late := bus.InboundMessage{
+ Channel: "test",
+ SenderID: "user1",
+ ChatID: "chat1",
+ Content: "late append",
+ Peer: bus.Peer{
+ Kind: "direct",
+ ID: "user1",
+ },
+ }
+
+ pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer pubCancel()
+ if err := msgBus.PublishInbound(pubCtx, first); err != nil {
+ t.Fatalf("publish first inbound: %v", err)
+ }
+
+ select {
+ case <-provider.firstCallStarted:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for first provider call to start")
+ }
+
+ if err := msgBus.PublishInbound(pubCtx, late); err != nil {
+ t.Fatalf("publish late inbound: %v", err)
+ }
+
+ close(provider.releaseFirstCall)
+
+ subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer subCancel()
+
+ out1, ok := msgBus.SubscribeOutbound(subCtx)
+ if !ok {
+ t.Fatal("expected first outbound response")
+ }
+ if out1.Content != "first response" {
+ t.Fatalf("expected first response, got %q", out1.Content)
+ }
+
+ out2, ok := msgBus.SubscribeOutbound(subCtx)
+ if !ok {
+ t.Fatal("expected continued outbound response")
+ }
+ if out2.Content != "continued response" {
+ t.Fatalf("expected continued response, got %q", out2.Content)
+ }
+
+ cancelRun()
+ select {
+ case err := <-runErrCh:
+ if err != nil {
+ t.Fatalf("Run returned error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for Run to stop")
+ }
+
+ provider.mu.Lock()
+ calls := provider.calls
+ secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
+ provider.mu.Unlock()
+
+ if calls != 2 {
+ t.Fatalf("expected 2 provider calls, got %d", calls)
+ }
+
+ foundLateMessage := false
+ for _, msg := range secondMessages {
+ if msg.Role == "user" && msg.Content == "late append" {
+ foundLateMessage = true
+ break
+ }
+ }
+ if !foundLateMessage {
+ t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
+ }
+}
+
+func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ tool1ExecCh := make(chan struct{})
+ tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
+ tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
+
+ provider := &gracefulCaptureProvider{
+ toolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "tool_one",
+ Function: &providers.FunctionCall{
+ Name: "tool_one",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ {
+ ID: "call_2",
+ Type: "function",
+ Name: "tool_two",
+ Function: &providers.FunctionCall{
+ Name: "tool_two",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ },
+ finalResp: "graceful summary",
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ al.RegisterTool(tool1)
+ al.RegisterTool(tool2)
+ sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
+
+ sub := al.SubscribeEvents(32)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ type result struct {
+ resp string
+ err error
+ }
+ resultCh := make(chan result, 1)
+ go func() {
+ resp, err := al.ProcessDirectWithChannel(
+ context.Background(),
+ "do something",
+ sessionKey,
+ "test",
+ "chat1",
+ )
+ resultCh <- result{resp: resp, err: err}
+ }()
+
+ select {
+ case <-tool1ExecCh:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for tool_one to start")
+ }
+
+ active := al.GetActiveTurn()
+ if active == nil {
+ t.Fatal("expected active turn while tool is running")
+ }
+ if active.SessionKey != sessionKey {
+ t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
+ }
+ if active.Channel != "test" || active.ChatID != "chat1" {
+ t.Fatalf("unexpected active turn target: %#v", active)
+ }
+
+ if err := al.InterruptGraceful("wrap it up"); err != nil {
+ t.Fatalf("InterruptGraceful failed: %v", err)
+ }
+
+ select {
+ case r := <-resultCh:
+ if r.err != nil {
+ t.Fatalf("unexpected error: %v", r.err)
+ }
+ if r.resp != "graceful summary" {
+ t.Fatalf("expected graceful summary, got %q", r.resp)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for graceful interrupt result")
+ }
+
+ if active := al.GetActiveTurn(); active != nil {
+ t.Fatalf("expected no active turn after completion, got %#v", active)
+ }
+
+ provider.mu.Lock()
+ terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
+ terminalToolsCount := provider.terminalToolsCount
+ calls := provider.calls
+ provider.mu.Unlock()
+
+ if calls != 2 {
+ t.Fatalf("expected 2 provider calls, got %d", calls)
+ }
+ if terminalToolsCount != 0 {
+ t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
+ }
+
+ foundHint := false
+ foundSkipped := false
+ for _, msg := range terminalMessages {
+ if msg.Role == "user" && msg.Content == "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\nInterrupt hint: wrap it up" {
+ foundHint = true
+ }
+ if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
+ foundSkipped = true
+ }
+ }
+ if !foundHint {
+ t.Fatal("expected graceful terminal call to include interrupt hint message")
+ }
+ if !foundSkipped {
+ t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
+ }
+
+ events := collectEventStream(sub.C)
+ interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
+ if !ok {
+ t.Fatal("expected interrupt received event")
+ }
+ interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
+ if !ok {
+ t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
+ }
+ if interruptPayload.Kind != InterruptKindGraceful {
+ t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
+ }
+
+ turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
+ if !ok {
+ t.Fatal("expected turn end event")
+ }
+ turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
+ if !ok {
+ t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
+ }
+ if turnEndPayload.Status != TurnEndStatusCompleted {
+ t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
+ }
+}
+
+func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ provider := &toolCallProvider{
+ toolCalls: []providers.ToolCall{
+ {
+ ID: "call_1",
+ Type: "function",
+ Name: "cancel_tool",
+ Function: &providers.FunctionCall{
+ Name: "cancel_tool",
+ Arguments: "{}",
+ },
+ Arguments: map[string]any{},
+ },
+ },
+ finalResp: "should not happen",
+ }
+
+ al := NewAgentLoop(cfg, msgBus, provider)
+ started := make(chan struct{})
+ al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: started})
+ sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ originalHistory := []providers.Message{
+ {Role: "user", Content: "before"},
+ {Role: "assistant", Content: "after"},
+ }
+ defaultAgent.Sessions.SetHistory(sessionKey, originalHistory)
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ type result struct {
+ resp string
+ err error
+ }
+ resultCh := make(chan result, 1)
+ go func() {
+ resp, err := al.ProcessDirectWithChannel(
+ context.Background(),
+ "do work",
+ sessionKey,
+ "test",
+ "chat1",
+ )
+ resultCh <- result{resp: resp, err: err}
+ }()
+
+ select {
+ case <-started:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for interruptible tool to start")
+ }
+
+ if active := al.GetActiveTurn(); active == nil {
+ t.Fatal("expected active turn before hard abort")
+ }
+
+ if err := al.InterruptHard(); err != nil {
+ t.Fatalf("InterruptHard failed: %v", err)
+ }
+
+ select {
+ case r := <-resultCh:
+ if r.err != nil {
+ t.Fatalf("unexpected error: %v", r.err)
+ }
+ if r.resp != "" {
+ t.Fatalf("expected no final response after hard abort, got %q", r.resp)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for hard abort result")
+ }
+
+ if active := al.GetActiveTurn(); active != nil {
+ t.Fatalf("expected no active turn after hard abort, got %#v", active)
+ }
+
+ finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
+ if !reflect.DeepEqual(finalHistory, originalHistory) {
+ t.Fatalf("expected history rollback after hard abort, got %#v", finalHistory)
+ }
+
+ events := collectEventStream(sub.C)
+ interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
+ if !ok {
+ t.Fatal("expected interrupt received event")
+ }
+ interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
+ if !ok {
+ t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
+ }
+ if interruptPayload.Kind != InterruptKindHard {
+ t.Fatalf("expected hard interrupt payload, got %q", interruptPayload.Kind)
+ }
+
+ turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
+ if !ok {
+ t.Fatal("expected turn end event")
+ }
+ turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
+ if !ok {
+ t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
+ }
+ if turnEndPayload.Status != TurnEndStatusAborted {
+ t.Fatalf("expected aborted turn, got %q", turnEndPayload.Status)
+ }
+}
+
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go
new file mode 100644
index 000000000..c44a4f80e
--- /dev/null
+++ b/pkg/agent/turn.go
@@ -0,0 +1,309 @@
+package agent
+
+import (
+ "context"
+ "reflect"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+type TurnPhase string
+
+const (
+ TurnPhaseSetup TurnPhase = "setup"
+ TurnPhaseRunning TurnPhase = "running"
+ TurnPhaseTools TurnPhase = "tools"
+ TurnPhaseFinalizing TurnPhase = "finalizing"
+ TurnPhaseCompleted TurnPhase = "completed"
+ TurnPhaseAborted TurnPhase = "aborted"
+)
+
+type ActiveTurnInfo struct {
+ TurnID string
+ AgentID string
+ SessionKey string
+ Channel string
+ ChatID string
+ UserMessage string
+ Phase TurnPhase
+ Iteration int
+ StartedAt time.Time
+}
+
+type turnResult struct {
+ finalContent string
+ status TurnEndStatus
+ followUps []bus.InboundMessage
+}
+
+type turnState struct {
+ mu sync.RWMutex
+
+ agent *AgentInstance
+ opts processOptions
+ scope turnEventScope
+
+ turnID string
+ agentID string
+ sessionKey string
+
+ channel string
+ chatID string
+ userMessage string
+ media []string
+
+ phase TurnPhase
+ iteration int
+ startedAt time.Time
+ finalContent string
+
+ pendingSteering []providers.Message
+ followUps []bus.InboundMessage
+
+ gracefulInterrupt bool
+ gracefulInterruptHint string
+ gracefulTerminalUsed bool
+ hardAbort bool
+ providerCancel context.CancelFunc
+ turnCancel context.CancelFunc
+
+ restorePointHistory []providers.Message
+ restorePointSummary string
+ persistedMessages []providers.Message
+}
+
+func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
+ return &turnState{
+ agent: agent,
+ opts: opts,
+ scope: scope,
+ turnID: scope.turnID,
+ agentID: agent.ID,
+ sessionKey: opts.SessionKey,
+ channel: opts.Channel,
+ chatID: opts.ChatID,
+ userMessage: opts.UserMessage,
+ media: append([]string(nil), opts.Media...),
+ phase: TurnPhaseSetup,
+ startedAt: time.Now(),
+ }
+}
+
+func (al *AgentLoop) registerActiveTurn(ts *turnState) {
+ al.activeTurnMu.Lock()
+ defer al.activeTurnMu.Unlock()
+ al.activeTurn = ts
+}
+
+func (al *AgentLoop) clearActiveTurn(ts *turnState) {
+ al.activeTurnMu.Lock()
+ defer al.activeTurnMu.Unlock()
+ if al.activeTurn == ts {
+ al.activeTurn = nil
+ }
+}
+
+func (al *AgentLoop) getActiveTurnState() *turnState {
+ al.activeTurnMu.RLock()
+ defer al.activeTurnMu.RUnlock()
+ return al.activeTurn
+}
+
+func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
+ ts := al.getActiveTurnState()
+ if ts == nil {
+ return nil
+ }
+ info := ts.snapshot()
+ return &info
+}
+
+func (ts *turnState) snapshot() ActiveTurnInfo {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+
+ return ActiveTurnInfo{
+ TurnID: ts.turnID,
+ AgentID: ts.agentID,
+ SessionKey: ts.sessionKey,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ UserMessage: ts.userMessage,
+ Phase: ts.phase,
+ Iteration: ts.iteration,
+ StartedAt: ts.startedAt,
+ }
+}
+
+func (ts *turnState) setPhase(phase TurnPhase) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.phase = phase
+}
+
+func (ts *turnState) setIteration(iteration int) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.iteration = iteration
+}
+
+func (ts *turnState) currentIteration() int {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ return ts.iteration
+}
+
+func (ts *turnState) setFinalContent(content string) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.finalContent = content
+}
+
+func (ts *turnState) finalContentLen() int {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ return len(ts.finalContent)
+}
+
+func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.turnCancel = cancel
+}
+
+func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.providerCancel = cancel
+}
+
+func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.providerCancel = nil
+}
+
+func (ts *turnState) requestGracefulInterrupt(hint string) bool {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ if ts.hardAbort {
+ return false
+ }
+ ts.gracefulInterrupt = true
+ ts.gracefulInterruptHint = hint
+ return true
+}
+
+func (ts *turnState) gracefulInterruptRequested() (bool, string) {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ return ts.gracefulInterrupt && !ts.gracefulTerminalUsed, ts.gracefulInterruptHint
+}
+
+func (ts *turnState) markGracefulTerminalUsed() {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.gracefulTerminalUsed = true
+}
+
+func (ts *turnState) requestHardAbort() bool {
+ ts.mu.Lock()
+ if ts.hardAbort {
+ ts.mu.Unlock()
+ return false
+ }
+ ts.hardAbort = true
+ turnCancel := ts.turnCancel
+ providerCancel := ts.providerCancel
+ ts.mu.Unlock()
+
+ if providerCancel != nil {
+ providerCancel()
+ }
+ if turnCancel != nil {
+ turnCancel()
+ }
+ return true
+}
+
+func (ts *turnState) hardAbortRequested() bool {
+ ts.mu.RLock()
+ defer ts.mu.RUnlock()
+ return ts.hardAbort
+}
+
+func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
+ snap := ts.snapshot()
+ return EventMeta{
+ AgentID: snap.AgentID,
+ TurnID: snap.TurnID,
+ SessionKey: snap.SessionKey,
+ Iteration: snap.Iteration,
+ Source: source,
+ TracePath: tracePath,
+ }
+}
+
+func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.restorePointHistory = append([]providers.Message(nil), history...)
+ ts.restorePointSummary = summary
+}
+
+func (ts *turnState) recordPersistedMessage(msg providers.Message) {
+ ts.mu.Lock()
+ defer ts.mu.Unlock()
+ ts.persistedMessages = append(ts.persistedMessages, msg)
+}
+
+func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
+ history := agent.Sessions.GetHistory(ts.sessionKey)
+ summary := agent.Sessions.GetSummary(ts.sessionKey)
+
+ ts.mu.RLock()
+ persisted := append([]providers.Message(nil), ts.persistedMessages...)
+ ts.mu.RUnlock()
+
+ if matched := matchingTurnMessageTail(history, persisted); matched > 0 {
+ history = append([]providers.Message(nil), history[:len(history)-matched]...)
+ }
+
+ ts.captureRestorePoint(history, summary)
+}
+
+func (ts *turnState) restoreSession(agent *AgentInstance) error {
+ ts.mu.RLock()
+ history := append([]providers.Message(nil), ts.restorePointHistory...)
+ summary := ts.restorePointSummary
+ ts.mu.RUnlock()
+
+ agent.Sessions.SetHistory(ts.sessionKey, history)
+ agent.Sessions.SetSummary(ts.sessionKey, summary)
+ return agent.Sessions.Save(ts.sessionKey)
+}
+
+func matchingTurnMessageTail(history, persisted []providers.Message) int {
+ maxMatch := min(len(history), len(persisted))
+ for size := maxMatch; size > 0; size-- {
+ if reflect.DeepEqual(history[len(history)-size:], persisted[len(persisted)-size:]) {
+ return size
+ }
+ }
+ return 0
+}
+
+func (ts *turnState) interruptHintMessage() providers.Message {
+ _, hint := ts.gracefulInterruptRequested()
+ content := "Interrupt requested. Stop scheduling tools and provide a short final summary."
+ if hint != "" {
+ content += "\n\nInterrupt hint: " + hint
+ }
+ return providers.Message{
+ Role: "user",
+ Content: content,
+ }
+}
From 2b3c95b1f19357c289419b06eba7528926200823 Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Fri, 20 Mar 2026 17:46:31 +0800
Subject: [PATCH 20/26] fix: lint err
---
pkg/agent/steering_test.go | 4 +++-
pkg/agent/turn.go | 3 +--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go
index f8c046ea9..bb5d42c73 100644
--- a/pkg/agent/steering_test.go
+++ b/pkg/agent/steering_test.go
@@ -914,8 +914,10 @@ func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
foundHint := false
foundSkipped := false
+ expectedHint := "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\n" +
+ "Interrupt hint: wrap it up"
for _, msg := range terminalMessages {
- if msg.Role == "user" && msg.Content == "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\nInterrupt hint: wrap it up" {
+ if msg.Role == "user" && msg.Content == expectedHint {
foundHint = true
}
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
diff --git a/pkg/agent/turn.go b/pkg/agent/turn.go
index c44a4f80e..358dae2b4 100644
--- a/pkg/agent/turn.go
+++ b/pkg/agent/turn.go
@@ -60,8 +60,7 @@ type turnState struct {
startedAt time.Time
finalContent string
- pendingSteering []providers.Message
- followUps []bus.InboundMessage
+ followUps []bus.InboundMessage
gracefulInterrupt bool
gracefulInterruptHint string
From 1c6586681d9d5f1b6dc3708edd91fa55ca70554f Mon Sep 17 00:00:00 2001
From: afjcjsbx
Date: Fri, 20 Mar 2026 19:44:00 +0100
Subject: [PATCH 21/26] fix(agent) scope steering
---
docs/steering.md | 35 +++-
pkg/agent/loop.go | 133 +++++++++++----
pkg/agent/steering.go | 223 +++++++++++++++++++-----
pkg/agent/steering_test.go | 340 ++++++++++++++++++++++++++++++++++++-
4 files changed, 645 insertions(+), 86 deletions(-)
diff --git a/docs/steering.md b/docs/steering.md
index ad08f8425..63294ac5f 100644
--- a/docs/steering.md
+++ b/docs/steering.md
@@ -21,6 +21,18 @@ Agent Loop ▼
└─ new LLM turn with steering message
```
+## Scoped queues
+
+Steering is now isolated per resolved session scope, not stored in a single
+global queue.
+
+- The active turn writes and reads from its own scope key (usually the routed session key such as `agent::...`)
+- `Steer()` still works outside an active turn through a legacy fallback queue
+- `Continue()` first dequeues messages for the requested session scope, then falls back to the legacy queue for backwards compatibility
+
+This prevents a message arriving from another chat, DM peer, or routed agent
+session from being injected into the wrong conversation.
+
## Configuration
In `config.json`, under `agents.defaults`:
@@ -86,12 +98,18 @@ if response == "" {
`Continue` internally uses `SkipInitialSteeringPoll: true` to avoid double-dequeuing the same messages (since it already extracted them and passes them directly as input).
+`Continue` also resolves the target agent from the provided session key, so
+agent-scoped sessions continue on the correct agent instead of always using
+the default one.
+
## Polling points in the loop
-Steering is checked at **two points** in the agent cycle:
+Steering is checked at the following points in the agent cycle:
1. **At loop start** — before the first LLM call, to catch messages enqueued during setup
2. **After every tool completes** — including the first and the last. If steering is found and there are remaining tools, they are all skipped immediately
+3. **After a direct LLM response** — if a new steering message arrived while the model was generating a non-tool response, the loop continues instead of returning a stale answer
+4. **Right before the turn is finalized** — if steering arrived at the very end of the turn, the agent immediately starts a continuation turn instead of leaving the message orphaned in the queue
## Why remaining tools are skipped
@@ -156,11 +174,26 @@ When the agent loop (`Run()`) starts processing a message, it spawns a backgroun
- Users on any channel (Telegram, Discord, etc.) don't need to do anything special — their messages are automatically captured as steering when the agent is busy
- Audio messages are transcribed before being steered, so the agent receives text. If transcription fails, the original (non-transcribed) message is steered as-is
+- Only messages that resolve to the **same steering scope** as the active turn are redirected. Messages for other chats/sessions are requeued onto the inbound bus so they can be processed normally
+- `system` inbound messages are not treated as steering input
- When `processMessage` finishes, the drain goroutine is canceled and normal message consumption resumes
+## Steering with media
+
+Steering messages can include `Media` refs, just like normal inbound user
+messages.
+
+- The original `media://` refs are preserved in session history via `AddFullMessage`
+- Before the next provider call, steering messages go through the normal media resolution pipeline
+- Image refs are converted to data URLs for multimodal providers; non-image refs are resolved the same way as standard inbound media
+
+This applies both to in-turn steering and to idle-session continuation through
+`Continue()`.
+
## Notes
- Steering **does not interrupt** a tool that is currently executing. It waits for the current tool to finish, then checks the queue.
- With `one-at-a-time` mode, if multiple messages are enqueued rapidly, they will be processed one per iteration. This gives the model the opportunity to react to each message individually.
- With `all` mode, all pending messages are combined into a single injection. Useful when you want the agent to receive all the context at once.
- The steering queue has a maximum capacity of 10 messages (`MaxQueueSize`). `Steer()` returns an error when the queue is full. In the bus drain path, the error is logged as a warning and the message is effectively dropped.
+- Manual `Steer()` calls made outside an active turn still go to the legacy fallback queue, so older integrations keep working.
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index f54482ae8..27bafe977 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -64,11 +64,12 @@ type processOptions struct {
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
Media []string // media:// refs from inbound message
- DefaultResponse string // Response when LLM returns empty
- EnableSummary bool // Whether to trigger summarization
- SendResponse bool // Whether to send response via bus
- NoHistory bool // If true, don't load session history (for heartbeat)
- SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
+ InitialSteeringMessages []providers.Message
+ DefaultResponse string // Response when LLM returns empty
+ EnableSummary bool // Whether to trigger summarization
+ SendResponse bool // Whether to send response via bus
+ NoHistory bool // If true, don't load session history (for heartbeat)
+ SkipInitialSteeringPoll bool // If true, skip the steering poll at loop start (used by Continue)
}
type continuationTarget struct {
@@ -271,11 +272,14 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
// Start a goroutine that drains the bus while processMessage is
- // running. Any inbound messages that arrive during processing are
- // redirected into the steering queue so the agent loop can pick
- // them up between tool calls.
- drainCtx, drainCancel := context.WithCancel(ctx)
- go al.drainBusToSteering(drainCtx)
+ // running. Only messages that resolve to the active turn scope are
+ // redirected into steering; other inbound messages are requeued.
+ drainCancel := func() {}
+ if activeScope, activeAgentID, ok := al.resolveSteeringTarget(msg); ok {
+ drainCtx, cancel := context.WithCancel(ctx)
+ drainCancel = cancel
+ go al.drainBusToSteering(drainCtx, activeScope, activeAgentID)
+ }
// Process message
func() {
@@ -316,13 +320,13 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return
}
- for al.pendingSteeringCount() > 0 {
+ for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
logger.InfoCF("agent", "Continuing queued steering after turn end",
map[string]any{
"channel": target.Channel,
"chat_id": target.ChatID,
"session_key": target.SessionKey,
- "queue_depth": al.pendingSteeringCount(),
+ "queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
})
continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
@@ -349,15 +353,27 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
// drainBusToSteering continuously consumes inbound messages and redirects
-// them into the steering queue. It runs in a goroutine while processMessage
-// is active and stops when drainCtx is canceled (i.e., processMessage returns).
-func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
+// messages from the active scope into the steering queue. Messages from other
+// scopes are requeued so they can be processed normally after the active turn.
+func (al *AgentLoop) drainBusToSteering(ctx context.Context, activeScope, activeAgentID string) {
for {
msg, ok := al.bus.ConsumeInbound(ctx)
if !ok {
return
}
+ msgScope, _, scopeOK := al.resolveSteeringTarget(msg)
+ if !scopeOK || msgScope != activeScope {
+ if err := al.requeueInboundMessage(msg); err != nil {
+ logger.WarnCF("agent", "Failed to requeue non-steering inbound message", map[string]any{
+ "error": err.Error(),
+ "channel": msg.Channel,
+ "sender_id": msg.SenderID,
+ })
+ }
+ return
+ }
+
// Transcribe audio if needed before steering, so the agent sees text.
msg, _ = al.transcribeAudioInMessage(ctx, msg)
@@ -366,11 +382,13 @@ func (al *AgentLoop) drainBusToSteering(ctx context.Context) {
"channel": msg.Channel,
"sender_id": msg.SenderID,
"content_len": len(msg.Content),
+ "scope": activeScope,
})
- if err := al.Steer(providers.Message{
+ if err := al.enqueueSteeringMessage(activeScope, activeAgentID, providers.Message{
Role: "user",
Content: msg.Content,
+ Media: append([]string(nil), msg.Media...),
}); err != nil {
logger.WarnCF("agent", "Failed to steer message, will be lost",
map[string]any{
@@ -1085,6 +1103,25 @@ func resolveScopeKey(route routing.ResolvedRoute, msgSessionKey string) string {
return route.SessionKey
}
+func (al *AgentLoop) resolveSteeringTarget(msg bus.InboundMessage) (string, string, bool) {
+ if msg.Channel == "system" {
+ return "", "", false
+ }
+
+ route, agent, err := al.resolveMessageRoute(msg)
+ if err != nil || agent == nil {
+ return "", "", false
+ }
+
+ return resolveScopeKey(route, msg.SessionKey), agent.ID, true
+}
+
+func (al *AgentLoop) requeueInboundMessage(msg bus.InboundMessage) error {
+ pubCtx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+ return al.bus.PublishInbound(pubCtx, msg)
+}
+
func (al *AgentLoop) processSystemMessage(
ctx context.Context,
msg bus.InboundMessage,
@@ -1346,16 +1383,25 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
}
}
- if !ts.opts.NoHistory {
- rootMsg := providers.Message{Role: "user", Content: ts.userMessage}
- ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
+ if !ts.opts.NoHistory && (strings.TrimSpace(ts.userMessage) != "" || len(ts.media) > 0) {
+ rootMsg := providers.Message{
+ Role: "user",
+ Content: ts.userMessage,
+ Media: append([]string(nil), ts.media...),
+ }
+ if len(rootMsg.Media) > 0 {
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, rootMsg)
+ } else {
+ ts.agent.Sessions.AddMessage(ts.sessionKey, rootMsg.Role, rootMsg.Content)
+ }
ts.recordPersistedMessage(rootMsg)
}
activeCandidates, activeModel := al.selectCandidates(ts.agent, ts.userMessage, messages)
- var pendingMessages []providers.Message
+ pendingMessages := append([]providers.Message(nil), ts.opts.InitialSteeringMessages...)
var finalContent string
+turnLoop:
for ts.currentIteration() < ts.agent.MaxIterations || len(pendingMessages) > 0 || func() bool {
graceful, _ := ts.gracefulInterruptRequested()
return graceful
@@ -1369,19 +1415,24 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.setIteration(iteration)
ts.setPhase(TurnPhaseRunning)
- if iteration > 1 || !ts.opts.SkipInitialSteeringPoll {
- if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
+ if iteration > 1 {
+ if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
+ pendingMessages = append(pendingMessages, steerMsgs...)
+ }
+ } else if !ts.opts.SkipInitialSteeringPoll {
+ if steerMsgs := al.dequeueSteeringMessagesForScopeWithFallback(ts.sessionKey); len(steerMsgs) > 0 {
pendingMessages = append(pendingMessages, steerMsgs...)
}
}
if len(pendingMessages) > 0 {
+ resolvedPending := resolveMediaRefs(pendingMessages, al.mediaStore, maxMediaSize)
totalContentLen := 0
- for _, pm := range pendingMessages {
- messages = append(messages, pm)
+ for i, pm := range pendingMessages {
+ messages = append(messages, resolvedPending[i])
totalContentLen += len(pm.Content)
if !ts.opts.NoHistory {
- ts.agent.Sessions.AddMessage(ts.sessionKey, pm.Role, pm.Content)
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, pm)
ts.recordPersistedMessage(pm)
}
logger.InfoCF("agent", "Injected steering message into context",
@@ -1389,6 +1440,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
"agent_id": ts.agent.ID,
"iteration": iteration,
"content_len": len(pm.Content),
+ "media_count": len(pm.Media),
})
}
al.emitEvent(
@@ -1660,10 +1712,21 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
})
if len(response.ToolCalls) == 0 || gracefulTerminal {
- finalContent = response.Content
- if finalContent == "" && response.ReasoningContent != "" {
- finalContent = response.ReasoningContent
+ responseContent := response.Content
+ if responseContent == "" && response.ReasoningContent != "" {
+ responseContent = response.ReasoningContent
}
+ if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
+ logger.InfoCF("agent", "Steering arrived after direct LLM response; continuing turn",
+ map[string]any{
+ "agent_id": ts.agent.ID,
+ "iteration": iteration,
+ "steering_count": len(steerMsgs),
+ })
+ pendingMessages = append(pendingMessages, steerMsgs...)
+ continue
+ }
+ finalContent = responseContent
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]any{
"agent_id": ts.agent.ID,
@@ -1870,7 +1933,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.recordPersistedMessage(toolResultMsg)
}
- if steerMsgs := al.dequeueSteeringMessages(); len(steerMsgs) > 0 {
+ if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
pendingMessages = append(pendingMessages, steerMsgs...)
}
@@ -1926,6 +1989,18 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
})
}
+ if steerMsgs := al.dequeueSteeringMessagesForScope(ts.sessionKey); len(steerMsgs) > 0 {
+ logger.InfoCF("agent", "Steering arrived after turn completion; continuing turn before finalizing",
+ map[string]any{
+ "agent_id": ts.agent.ID,
+ "steering_count": len(steerMsgs),
+ "session_key": ts.sessionKey,
+ })
+ pendingMessages = append(pendingMessages, steerMsgs...)
+ finalContent = ""
+ goto turnLoop
+ }
+
if ts.hardAbortRequested() {
turnStatus = TurnEndStatusAborted
return al.abortTurn(ts)
diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go
index 77c2e0c17..eb8afa1dd 100644
--- a/pkg/agent/steering.go
+++ b/pkg/agent/steering.go
@@ -8,6 +8,7 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/routing"
)
// SteeringMode controls how queued steering messages are dequeued.
@@ -20,6 +21,9 @@ const (
SteeringAll SteeringMode = "all"
// MaxQueueSize number of possible messages in the Steering Queue
MaxQueueSize = 10
+ // manualSteeringScope is the legacy fallback queue used when no active
+ // turn/session scope is available.
+ manualSteeringScope = "__manual__"
)
// parseSteeringMode normalizes a config string into a SteeringMode.
@@ -35,56 +39,117 @@ func parseSteeringMode(s string) SteeringMode {
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
type steeringQueue struct {
- mu sync.Mutex
- queue []providers.Message
- mode SteeringMode
+ mu sync.Mutex
+ queues map[string][]providers.Message
+ mode SteeringMode
}
func newSteeringQueue(mode SteeringMode) *steeringQueue {
return &steeringQueue{
- mode: mode,
+ queues: make(map[string][]providers.Message),
+ mode: mode,
}
}
-// push enqueues a steering message.
+func normalizeSteeringScope(scope string) string {
+ scope = strings.TrimSpace(scope)
+ if scope == "" {
+ return manualSteeringScope
+ }
+ return scope
+}
+
+// push enqueues a steering message in the legacy fallback scope.
func (sq *steeringQueue) push(msg providers.Message) error {
+ return sq.pushScope(manualSteeringScope, msg)
+}
+
+// pushScope enqueues a steering message for the provided scope.
+func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
sq.mu.Lock()
defer sq.mu.Unlock()
- if len(sq.queue) >= MaxQueueSize {
+
+ scope = normalizeSteeringScope(scope)
+ queue := sq.queues[scope]
+ if len(queue) >= MaxQueueSize {
return fmt.Errorf("steering queue is full")
}
- sq.queue = append(sq.queue, msg)
+ sq.queues[scope] = append(queue, msg)
return nil
}
-// dequeue removes and returns pending steering messages according to the
-// configured mode. Returns nil when the queue is empty.
+// dequeue removes and returns pending steering messages from the legacy
+// fallback scope according to the configured mode.
func (sq *steeringQueue) dequeue() []providers.Message {
+ return sq.dequeueScope(manualSteeringScope)
+}
+
+// dequeueScope removes and returns pending steering messages for the provided
+// scope according to the configured mode.
+func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
- if len(sq.queue) == 0 {
+ return sq.dequeueLocked(normalizeSteeringScope(scope))
+}
+
+// dequeueScopeWithFallback drains the scoped queue first and falls back to the
+// legacy manual scope for backwards compatibility.
+func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+
+ scope = strings.TrimSpace(scope)
+ if scope != "" {
+ if msgs := sq.dequeueLocked(scope); len(msgs) > 0 {
+ return msgs
+ }
+ }
+
+ return sq.dequeueLocked(manualSteeringScope)
+}
+
+func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
+ queue := sq.queues[scope]
+ if len(queue) == 0 {
return nil
}
switch sq.mode {
case SteeringAll:
- msgs := sq.queue
- sq.queue = nil
+ msgs := append([]providers.Message(nil), queue...)
+ delete(sq.queues, scope)
return msgs
- default: // one-at-a-time
- msg := sq.queue[0]
- sq.queue[0] = providers.Message{} // Clear reference for GC
- sq.queue = sq.queue[1:]
+ default:
+ msg := queue[0]
+ queue[0] = providers.Message{} // Clear reference for GC
+ queue = queue[1:]
+ if len(queue) == 0 {
+ delete(sq.queues, scope)
+ } else {
+ sq.queues[scope] = queue
+ }
return []providers.Message{msg}
}
}
-// len returns the number of queued messages.
+// len returns the number of queued messages across all scopes.
func (sq *steeringQueue) len() int {
sq.mu.Lock()
defer sq.mu.Unlock()
- return len(sq.queue)
+
+ total := 0
+ for _, queue := range sq.queues {
+ total += len(queue)
+ }
+ return total
+}
+
+// lenScope returns the number of queued messages for a specific scope.
+func (sq *steeringQueue) lenScope(scope string) int {
+ sq.mu.Lock()
+ defer sq.mu.Unlock()
+ return len(sq.queues[normalizeSteeringScope(scope)])
}
// setMode updates the steering mode.
@@ -101,26 +166,40 @@ func (sq *steeringQueue) getMode() SteeringMode {
return sq.mode
}
-// --- AgentLoop steering API ---
-
// Steer enqueues a user message to be injected into the currently running
// agent loop. The message will be picked up after the current tool finishes
// executing, causing any remaining tool calls in the batch to be skipped.
func (al *AgentLoop) Steer(msg providers.Message) error {
+ scope := ""
+ agentID := ""
+ if ts := al.getActiveTurnState(); ts != nil {
+ scope = ts.sessionKey
+ agentID = ts.agentID
+ }
+ return al.enqueueSteeringMessage(scope, agentID, msg)
+}
+
+func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
if al.steering == nil {
return fmt.Errorf("steering queue is not initialized")
}
- if err := al.steering.push(msg); err != nil {
+
+ if err := al.steering.pushScope(scope, msg); err != nil {
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
"error": err.Error(),
"role": msg.Role,
+ "scope": normalizeSteeringScope(scope),
})
return err
}
+
+ queueDepth := al.steering.lenScope(scope)
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
"role": msg.Role,
"content_len": len(msg.Content),
- "queue_len": al.steering.len(),
+ "media_count": len(msg.Media),
+ "queue_len": queueDepth,
+ "scope": normalizeSteeringScope(scope),
})
meta := EventMeta{
@@ -129,11 +208,23 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
}
if ts := al.getActiveTurnState(); ts != nil {
meta = ts.eventMeta("Steer", "turn.interrupt.received")
- } else if registry := al.GetRegistry(); registry != nil {
- if agent := registry.GetDefaultAgent(); agent != nil {
- meta.AgentID = agent.ID
+ } else {
+ if strings.TrimSpace(agentID) != "" {
+ meta.AgentID = agentID
+ }
+ normalizedScope := normalizeSteeringScope(scope)
+ if normalizedScope != manualSteeringScope {
+ meta.SessionKey = normalizedScope
+ }
+ if meta.AgentID == "" {
+ if registry := al.GetRegistry(); registry != nil {
+ if agent := registry.GetDefaultAgent(); agent != nil {
+ meta.AgentID = agent.ID
+ }
+ }
}
}
+
al.emitEvent(
EventKindInterruptReceived,
meta,
@@ -141,7 +232,7 @@ func (al *AgentLoop) Steer(msg providers.Message) error {
Kind: InterruptKindSteering,
Role: msg.Role,
ContentLen: len(msg.Content),
- QueueDepth: al.steering.len(),
+ QueueDepth: queueDepth,
},
)
@@ -165,7 +256,7 @@ func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
}
// dequeueSteeringMessages is the internal method called by the agent loop
-// to poll for steering messages. Returns nil when no messages are pending.
+// to poll for steering messages in the legacy fallback scope.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
if al.steering == nil {
return nil
@@ -173,6 +264,60 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
return al.steering.dequeue()
}
+func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
+ if al.steering == nil {
+ return nil
+ }
+ return al.steering.dequeueScope(scope)
+}
+
+func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
+ if al.steering == nil {
+ return nil
+ }
+ return al.steering.dequeueScopeWithFallback(scope)
+}
+
+func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
+ if al.steering == nil {
+ return 0
+ }
+ return al.steering.lenScope(scope)
+}
+
+func (al *AgentLoop) continueWithSteeringMessages(
+ ctx context.Context,
+ agent *AgentInstance,
+ sessionKey, channel, chatID string,
+ steeringMsgs []providers.Message,
+) (string, error) {
+ return al.runAgentLoop(ctx, agent, processOptions{
+ SessionKey: sessionKey,
+ Channel: channel,
+ ChatID: chatID,
+ DefaultResponse: defaultResponse,
+ EnableSummary: true,
+ SendResponse: false,
+ InitialSteeringMessages: steeringMsgs,
+ SkipInitialSteeringPoll: true,
+ })
+}
+
+func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
+ registry := al.GetRegistry()
+ if registry == nil {
+ return nil
+ }
+
+ if parsed := routing.ParseAgentSessionKey(sessionKey); parsed != nil {
+ if agent, ok := registry.GetAgent(parsed.AgentID); ok {
+ return agent
+ }
+ }
+
+ return registry.GetDefaultAgent()
+}
+
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
@@ -184,14 +329,14 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
- steeringMsgs := al.dequeueSteeringMessages()
+ steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
if len(steeringMsgs) == 0 {
return "", nil
}
- agent := al.GetRegistry().GetDefaultAgent()
+ agent := al.agentForSession(sessionKey)
if agent == nil {
- return "", fmt.Errorf("no default agent available")
+ return "", fmt.Errorf("no agent available for session %q", sessionKey)
}
if tool, ok := agent.Tools.Get("message"); ok {
@@ -200,23 +345,7 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
}
}
- // Build a combined user message from the steering messages.
- var contents []string
- for _, msg := range steeringMsgs {
- contents = append(contents, msg.Content)
- }
- combinedContent := strings.Join(contents, "\n")
-
- return al.runAgentLoop(ctx, agent, processOptions{
- SessionKey: sessionKey,
- Channel: channel,
- ChatID: chatID,
- UserMessage: combinedContent,
- DefaultResponse: defaultResponse,
- EnableSummary: true,
- SendResponse: false,
- SkipInitialSteeringPoll: true,
- })
+ return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
}
func (al *AgentLoop) InterruptGraceful(hint string) error {
diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go
index bb5d42c73..4c14dc6ef 100644
--- a/pkg/agent/steering_test.go
+++ b/pkg/agent/steering_test.go
@@ -5,13 +5,16 @@ import (
"encoding/json"
"fmt"
"os"
+ "path/filepath"
"reflect"
+ "strings"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -337,6 +340,96 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) {
}
}
+func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ Session: config.SessionConfig{
+ DMScope: "per-peer",
+ },
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, &mockProvider{})
+
+ activeMsg := bus.InboundMessage{
+ Channel: "telegram",
+ SenderID: "user1",
+ ChatID: "chat1",
+ Content: "active turn",
+ Peer: bus.Peer{
+ Kind: "direct",
+ ID: "user1",
+ },
+ }
+ activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
+ if !ok {
+ t.Fatal("expected active message to resolve to a steering scope")
+ }
+
+ otherMsg := bus.InboundMessage{
+ Channel: "telegram",
+ SenderID: "user2",
+ ChatID: "chat2",
+ Content: "other session",
+ Peer: bus.Peer{
+ Kind: "direct",
+ ID: "user2",
+ },
+ }
+ otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
+ if !ok {
+ t.Fatal("expected other message to resolve to a steering scope")
+ }
+ if otherScope == activeScope {
+ t.Fatalf("expected different steering scopes, got same scope %q", activeScope)
+ }
+
+ if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil {
+ t.Fatalf("PublishInbound failed: %v", err)
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+ defer cancel()
+
+ done := make(chan struct{})
+ go func() {
+ al.drainBusToSteering(ctx, activeScope, activeAgentID)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for drainBusToSteering to stop")
+ }
+
+ if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 {
+ t.Fatalf("expected no steering messages for active scope, got %v", msgs)
+ }
+
+ requeued, ok := msgBus.ConsumeInbound(context.Background())
+ if !ok {
+ t.Fatal("expected message to be requeued on the inbound bus")
+ }
+ if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
+ requeued.SenderID != otherMsg.SenderID || requeued.Content != otherMsg.Content {
+ t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
+ }
+}
+
// slowTool simulates a tool that takes some time to execute.
type slowTool struct {
name string
@@ -472,6 +565,52 @@ func (p *lateSteeringProvider) GetDefaultModel() string {
return "late-steering-mock"
}
+type blockingDirectProvider struct {
+ mu sync.Mutex
+ calls int
+ firstStarted chan struct{}
+ releaseFirst chan struct{}
+ firstResp string
+ finalResp string
+}
+
+func (p *blockingDirectProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ p.mu.Lock()
+ p.calls++
+ call := p.calls
+ firstStarted := p.firstStarted
+ releaseFirst := p.releaseFirst
+ firstResp := p.firstResp
+ finalResp := p.finalResp
+ if call == 1 && p.firstStarted != nil {
+ close(p.firstStarted)
+ p.firstStarted = nil
+ }
+ p.mu.Unlock()
+
+ if call == 1 {
+ select {
+ case <-releaseFirst:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return &providers.LLMResponse{Content: firstResp}, nil
+ }
+
+ _ = firstStarted
+ return &providers.LLMResponse{Content: finalResp}, nil
+}
+
+func (p *blockingDirectProvider) GetDefaultModel() string {
+ return "blocking-direct-mock"
+}
+
type interruptibleTool struct {
name string
started chan struct{}
@@ -744,18 +883,16 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
out1, ok := msgBus.SubscribeOutbound(subCtx)
if !ok {
- t.Fatal("expected first outbound response")
+ t.Fatal("expected outbound response")
}
- if out1.Content != "first response" {
- t.Fatalf("expected first response, got %q", out1.Content)
+ if out1.Content != "continued response" {
+ t.Fatalf("expected continued response, got %q", out1.Content)
}
- out2, ok := msgBus.SubscribeOutbound(subCtx)
- if !ok {
- t.Fatal("expected continued outbound response")
- }
- if out2.Content != "continued response" {
- t.Fatalf("expected continued response, got %q", out2.Content)
+ noExtraCtx, cancelNoExtra := context.WithTimeout(context.Background(), 200*time.Millisecond)
+ defer cancelNoExtra()
+ if out2, ok := msgBus.SubscribeOutbound(noExtraCtx); ok {
+ t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
}
cancelRun()
@@ -789,6 +926,191 @@ func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
}
}
+func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
+ provider := &blockingDirectProvider{
+ firstStarted: make(chan struct{}),
+ releaseFirst: make(chan struct{}),
+ firstResp: "stale direct response",
+ finalResp: "fresh response after steering",
+ }
+
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+
+ resultCh := make(chan struct {
+ resp string
+ err error
+ }, 1)
+ go func() {
+ resp, err := al.ProcessDirectWithChannel(
+ context.Background(),
+ "initial request",
+ sessionKey,
+ "test",
+ "chat1",
+ )
+ resultCh <- struct {
+ resp string
+ err error
+ }{resp: resp, err: err}
+ }()
+
+ select {
+ case <-provider.firstStarted:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout waiting for first LLM call to start")
+ }
+
+ if err := al.Steer(providers.Message{Role: "user", Content: "follow-up instruction"}); err != nil {
+ t.Fatalf("Steer failed: %v", err)
+ }
+ close(provider.releaseFirst)
+
+ select {
+ case result := <-resultCh:
+ if result.err != nil {
+ t.Fatalf("unexpected error: %v", result.err)
+ }
+ if result.resp != "fresh response after steering" {
+ t.Fatalf("expected refreshed response, got %q", result.resp)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timeout waiting for ProcessDirectWithChannel")
+ }
+
+ provider.mu.Lock()
+ calls := provider.calls
+ provider.mu.Unlock()
+ if calls != 2 {
+ t.Fatalf("expected 2 provider calls, got %d", calls)
+ }
+
+ if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
+ t.Fatalf("expected steering queue to be empty after continuation, got %v", msgs)
+ }
+}
+
+func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
+ tmpDir, err := os.MkdirTemp("", "agent-test-*")
+ if err != nil {
+ t.Fatalf("Failed to create temp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ store := media.NewFileMediaStore()
+ pngPath := filepath.Join(tmpDir, "steer.png")
+ pngHeader := []byte{
+ 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
+ 0x00, 0x00, 0x00, 0x0D,
+ 0x49, 0x48, 0x44, 0x52,
+ 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
+ 0x00, 0x00, 0x00,
+ 0x90, 0x77, 0x53, 0xDE,
+ }
+ if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
+ t.Fatalf("WriteFile failed: %v", err)
+ }
+ ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
+ if err != nil {
+ t.Fatalf("Store failed: %v", err)
+ }
+
+ var capturedMessages []providers.Message
+ var capMu sync.Mutex
+ provider := &capturingMockProvider{
+ response: "ack",
+ captureFn: func(msgs []providers.Message) {
+ capMu.Lock()
+ defer capMu.Unlock()
+ capturedMessages = append([]providers.Message(nil), msgs...)
+ },
+ }
+
+ sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
+ msgBus := bus.NewMessageBus()
+ al := NewAgentLoop(cfg, msgBus, provider)
+ al.SetMediaStore(store)
+
+ if err := al.Steer(providers.Message{
+ Role: "user",
+ Content: "describe this image",
+ Media: []string{ref},
+ }); err != nil {
+ t.Fatalf("Steer failed: %v", err)
+ }
+
+ resp, err := al.Continue(context.Background(), sessionKey, "test", "chat1")
+ if err != nil {
+ t.Fatalf("Continue failed: %v", err)
+ }
+ if resp != "ack" {
+ t.Fatalf("expected ack, got %q", resp)
+ }
+
+ capMu.Lock()
+ msgs := append([]providers.Message(nil), capturedMessages...)
+ capMu.Unlock()
+
+ foundResolvedMedia := false
+ for _, msg := range msgs {
+ if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
+ continue
+ }
+ if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
+ foundResolvedMedia = true
+ break
+ }
+ }
+ if !foundResolvedMedia {
+ t.Fatal("expected continue path to inject steering media into the provider request")
+ }
+
+ defaultAgent := al.registry.GetDefaultAgent()
+ if defaultAgent == nil {
+ t.Fatal("expected default agent")
+ }
+ history := defaultAgent.Sessions.GetHistory(sessionKey)
+ foundOriginalRef := false
+ for _, msg := range history {
+ if msg.Role == "user" && len(msg.Media) == 1 && msg.Media[0] == ref {
+ foundOriginalRef = true
+ break
+ }
+ }
+ if !foundOriginalRef {
+ t.Fatal("expected original steering media ref to be preserved in session history")
+ }
+}
+
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
From 827449aff35a3f517f6a4c80e58442a1f4c2af69 Mon Sep 17 00:00:00 2001
From: afjcjsbx
Date: Fri, 20 Mar 2026 20:12:55 +0100
Subject: [PATCH 22/26] fix lint
---
pkg/agent/loop.go | 7 -------
pkg/agent/steering_test.go | 4 ++--
2 files changed, 2 insertions(+), 9 deletions(-)
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 27bafe977..01e7ce4c4 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -440,13 +440,6 @@ func (al *AgentLoop) publishResponseIfNeeded(ctx context.Context, channel, chatI
})
}
-func (al *AgentLoop) pendingSteeringCount() int {
- if al.steering == nil {
- return 0
- }
- return al.steering.len()
-}
-
func (al *AgentLoop) buildContinuationTarget(msg bus.InboundMessage) (*continuationTarget, error) {
if msg.Channel == "system" {
return nil, nil
diff --git a/pkg/agent/steering_test.go b/pkg/agent/steering_test.go
index 4c14dc6ef..cf2e86904 100644
--- a/pkg/agent/steering_test.go
+++ b/pkg/agent/steering_test.go
@@ -1036,7 +1036,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
0x00, 0x00, 0x00,
0x90, 0x77, 0x53, 0xDE,
}
- if err := os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
+ if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
@@ -1060,7 +1060,7 @@ func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
al.SetMediaStore(store)
- if err := al.Steer(providers.Message{
+ if err = al.Steer(providers.Message{
Role: "user",
Content: "describe this image",
Media: []string{ref},
From 9e344594a2045faae5ce416f7af7f4879dbbf69f Mon Sep 17 00:00:00 2001
From: afjcjsbx
Date: Fri, 20 Mar 2026 21:07:07 +0100
Subject: [PATCH 23/26] fix logic
---
pkg/agent/loop.go | 53 +++++++++++++++++++++++++++++++++++++++++------
1 file changed, 47 insertions(+), 6 deletions(-)
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index 01e7ce4c4..a3a23fb3d 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -296,16 +296,21 @@ func (al *AgentLoop) Run(ctx context.Context) error {
// }
// }()
- defer drainCancel()
+ drainCanceled := false
+ cancelDrain := func() {
+ if drainCanceled {
+ return
+ }
+ drainCancel()
+ drainCanceled = true
+ }
+ defer cancelDrain()
response, err := al.processMessage(ctx, msg)
if err != nil {
response = fmt.Sprintf("Error processing message: %v", err)
}
-
- if response != "" {
- al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, response)
- }
+ finalResponse := response
target, targetErr := al.buildContinuationTarget(msg)
if targetErr != nil {
@@ -317,6 +322,10 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return
}
if target == nil {
+ cancelDrain()
+ if finalResponse != "" {
+ al.publishResponseIfNeeded(ctx, msg.Channel, msg.ChatID, finalResponse)
+ }
return
}
@@ -343,7 +352,39 @@ func (al *AgentLoop) Run(ctx context.Context) error {
return
}
- al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, continued)
+ finalResponse = continued
+ }
+
+ cancelDrain()
+
+ for al.pendingSteeringCountForScope(target.SessionKey) > 0 {
+ logger.InfoCF("agent", "Draining steering queued during turn shutdown",
+ map[string]any{
+ "channel": target.Channel,
+ "chat_id": target.ChatID,
+ "session_key": target.SessionKey,
+ "queue_depth": al.pendingSteeringCountForScope(target.SessionKey),
+ })
+
+ continued, continueErr := al.Continue(ctx, target.SessionKey, target.Channel, target.ChatID)
+ if continueErr != nil {
+ logger.WarnCF("agent", "Failed to continue queued steering after shutdown drain",
+ map[string]any{
+ "channel": target.Channel,
+ "chat_id": target.ChatID,
+ "error": continueErr.Error(),
+ })
+ return
+ }
+ if continued == "" {
+ break
+ }
+
+ finalResponse = continued
+ }
+
+ if finalResponse != "" {
+ al.publishResponseIfNeeded(ctx, target.Channel, target.ChatID, finalResponse)
}
}()
}
From cf68c91ecaa15c3518686e7cfa9c637fabfcbead Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Sat, 21 Mar 2026 19:15:10 +0800
Subject: [PATCH 24/26] feat(agent): add hook manager foundation
---
docs/design/hook-system-design.zh.md | 476 +++++++++++++++++
pkg/agent/hooks.go | 751 +++++++++++++++++++++++++++
pkg/agent/hooks_test.go | 312 +++++++++++
pkg/agent/loop.go | 309 +++++++++--
4 files changed, 1801 insertions(+), 47 deletions(-)
create mode 100644 docs/design/hook-system-design.zh.md
create mode 100644 pkg/agent/hooks.go
create mode 100644 pkg/agent/hooks_test.go
diff --git a/docs/design/hook-system-design.zh.md b/docs/design/hook-system-design.zh.md
new file mode 100644
index 000000000..ab5566bec
--- /dev/null
+++ b/docs/design/hook-system-design.zh.md
@@ -0,0 +1,476 @@
+# PicoClaw Hook 系统设计(基于 `refactor/agent`)
+
+## 背景
+
+本设计围绕两个议题展开:
+
+- `#1316`:把 agent loop 重构为事件驱动、可中断、可追加、可观测
+- `#1796`:在 EventBus 稳定后,把 hooks 设计为 EventBus 的 consumer,而不是重新发明一套事件模型
+
+当前分支已经完成了第一步里的“事件系统基础”,但还没有真正的 hook 挂载层。因此这里的目标不是重新设计 event,而是在已有实现上补出一层可扩展、可拦截、可外挂的 HookManager。
+
+## 外部项目对比
+
+### OpenClaw
+
+OpenClaw 的扩展能力分成三层:
+
+- Internal hooks:目录发现,运行在 Gateway 进程内
+- Plugin hooks:插件在运行时注册 hook,也在进程内
+- Webhooks:外部系统通过 HTTP 触发 Gateway 动作,属于进程外
+
+值得借鉴的点:
+
+- 有“项目内挂载”和“项目外挂载”两种路径
+- hook 是配置驱动,可启停
+- 外部入口有明确的安全边界和映射层
+
+不建议直接照搬的点:
+
+- OpenClaw 的 hooks / plugin hooks / webhooks 是三套路由,PicoClaw 当前体量下会偏重
+- HTTP webhook 更适合“事件进入系统”,不适合作为“可同步拦截 agent loop”的基础机制
+
+### pi-mono
+
+pi-mono 的核心思路更接近当前分支:
+
+- 扩展统一为 extension API
+- 事件分为观察型和可变更型
+- 某些阶段允许 `transform` / `block` / `replace`
+- 扩展代码主要是进程内执行
+- RPC mode 把 UI 交互桥接到进程外客户端
+
+值得借鉴的点:
+
+- 不把“观察”和“拦截”混成一个接口
+- 允许返回结构化动作,而不是只有回调
+- 进程外通信只暴露必要协议,不把整个内部对象图泄露出去
+
+## 当前分支现状
+
+### 已有能力
+
+当前分支已经具备 hook 系统的地基:
+
+- `pkg/agent/events.go` 定义了稳定的 `EventKind`、`EventMeta` 和 payload
+- `pkg/agent/eventbus.go` 提供了非阻塞 fan-out 的 `EventBus`
+- `pkg/agent/loop.go` 中的 `runTurn()` 已在 turn、llm、tool、interrupt、follow-up、summary 等节点发射事件
+- `pkg/agent/steering.go` 已支持 steering、graceful interrupt、hard abort
+- `pkg/agent/turn.go` 已维护 turn phase、恢复点、active turn、abort 状态
+
+### 现有缺口
+
+当前分支还缺四件事:
+
+- 没有 HookManager,只有 EventBus
+- 没有 Before/After LLM、Before/After Tool 这种同步拦截点
+- 没有审批型 hook
+- 子 agent 仍走 `pkg/tools/SubagentManager + RunToolLoop`,没有接入 `pkg/agent` 的 turn tree 和事件流
+
+### 一个关键现实
+
+`#1316` 文案里提到“只读并行、写入串行”的工具执行策略,但当前 `runTurn()` 实现已经先收敛成“顺序执行 + 每个工具后检查 steering / interrupt”。因此 hook 设计不应依赖未来的并行模型,而应该先兼容当前顺序执行,再为以后增加 `ReadOnlyIndicator` 留口子。
+
+## 设计原则
+
+- Hook 必须建立在 `pkg/agent` 的 EventBus 和 turn 上下文之上
+- EventBus 负责广播,HookManager 负责拦截,两者职责分离
+- 项目内挂载要简单,项目外挂载必须走 IPC
+- 观察型 hook 不能阻塞 loop;拦截型 hook 必须有超时
+- 先覆盖主 turn,不把 sub-turn 一次做满
+- 不新增第二套用户事件命名系统,优先复用 `EventKind.String()`
+
+## 总体架构
+
+分成三层:
+
+1. `EventBus`
+ 负责广播只读事件,现有实现直接复用
+
+2. `HookManager`
+ 负责管理 hook、排序、超时、错误隔离,并在 `runTurn()` 的明确检查点执行同步拦截
+
+3. `HookMount`
+ 负责两种挂载方式:
+ - 进程内 Go hook
+ - 进程外 IPC hook
+
+换句话说:
+
+- EventBus 是“发生了什么”
+- HookManager 是“谁能介入”
+- HookMount 是“这些 hook 从哪里来”
+
+## Hook 分类
+
+不建议把所有 hook 都设计成 `OnEvent(evt)`。
+
+建议拆成两类。
+
+### 1. 观察型
+
+只消费事件,不修改流程:
+
+```go
+type EventObserver interface {
+ OnEvent(ctx context.Context, evt agent.Event) error
+}
+```
+
+这类 hook 直接订阅 EventBus 即可。
+
+适用场景:
+
+- 审计日志
+- 指标上报
+- 调试 trace
+- 将事件转发给外部 UI / TUI / Web 面板
+
+### 2. 拦截型
+
+只在少数明确节点触发,允许返回动作:
+
+```go
+type LLMInterceptor interface {
+ BeforeLLM(ctx context.Context, req *LLMRequest) HookDecision[*LLMRequest]
+ AfterLLM(ctx context.Context, resp *LLMResponse) HookDecision[*LLMResponse]
+}
+
+type ToolInterceptor interface {
+ BeforeTool(ctx context.Context, call *ToolCall) HookDecision[*ToolCall]
+ AfterTool(ctx context.Context, result *ToolResultView) HookDecision[*ToolResultView]
+}
+
+type ToolApprover interface {
+ ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision
+}
+```
+
+这里的 `HookDecision` 统一支持:
+
+- `continue`
+- `modify`
+- `deny_tool`
+- `abort_turn`
+- `hard_abort`
+
+## 对外暴露的最小 hook 面
+
+V1 不需要把所有 EventKind 都变成可拦截点。
+
+建议只开放这些同步 hook:
+
+- `before_llm`
+- `after_llm`
+- `before_tool`
+- `after_tool`
+- `approve_tool`
+
+其余节点继续作为只读事件暴露:
+
+- `turn_start`
+- `turn_end`
+- `llm_request`
+- `llm_response`
+- `tool_exec_start`
+- `tool_exec_end`
+- `tool_exec_skipped`
+- `steering_injected`
+- `follow_up_queued`
+- `interrupt_received`
+- `context_compress`
+- `session_summarize`
+- `error`
+
+`subturn_*` 在 V1 中保留名字,但不承诺一定触发,直到子 turn 迁移完成。
+
+## 项目内挂载
+
+内部挂载必须尽量低摩擦。
+
+建议提供两种等价方式,底层都走 HookManager。
+
+### 方式 A:代码显式挂载
+
+```go
+al.MountHook(hooks.Named("audit", &AuditHook{}))
+```
+
+适用于:
+
+- 仓内内建 hook
+- 单元测试
+- feature flag 控制
+
+### 方式 B:内建 registry
+
+```go
+func init() {
+ hooks.RegisterBuiltin("audit", func() hooks.Hook {
+ return &AuditHook{}
+ })
+}
+```
+
+启动时根据配置启用:
+
+```json
+{
+ "hooks": {
+ "builtins": {
+ "audit": { "enabled": true }
+ }
+ }
+}
+```
+
+这比 OpenClaw 的目录扫描更轻,也更贴合 Go 项目。
+
+## 项目外挂载
+
+这是本设计的硬要求。
+
+建议 V1 采用:
+
+- `JSON-RPC over stdio`
+
+原因:
+
+- 跨平台最简单
+- 不依赖额外端口
+- 非常适合“由 PicoClaw 启动一个外部 hook 进程”
+- 比 HTTP webhook 更适合同步拦截
+
+### 外部 hook 进程模型
+
+PicoClaw 启动外部进程,并在其 stdin/stdout 上跑协议。
+
+配置示例:
+
+```json
+{
+ "hooks": {
+ "processes": {
+ "review-gate": {
+ "enabled": true,
+ "transport": "stdio",
+ "command": ["uvx", "picoclaw-hook-reviewer"],
+ "observe": ["turn_start", "turn_end", "tool_exec_end"],
+ "intercept": ["before_tool", "approve_tool"],
+ "timeout_ms": 5000
+ }
+ }
+ }
+}
+```
+
+### 协议边界
+
+不要把内部 Go 结构体直接暴露给 IPC。
+
+建议定义稳定的协议对象:
+
+- `HookHandshake`
+- `HookEventNotification`
+- `BeforeLLMRequest`
+- `AfterLLMRequest`
+- `BeforeToolRequest`
+- `AfterToolRequest`
+- `ApproveToolRequest`
+- `HookDecision`
+
+其中:
+
+- 观察型事件用 notification,fire-and-forget
+- 拦截型事件用 request/response,同步等待
+
+### 为什么是 stdio,而不是直接用 HTTP webhook
+
+因为两者用途不同:
+
+- HTTP webhook 更适合“外部系统向 PicoClaw 投递事件”
+- stdio/RPC 更适合“PicoClaw 在 turn 内同步询问外部 hook 是否改写 / 放行 / 拒绝”
+
+如果未来需要 OpenClaw 式 webhook,可以作为独立入口层,再把外部事件转成 inbound message 或 steering,而不是直接替代 hook IPC。
+
+## Hook 执行顺序
+
+建议统一排序规则:
+
+- 先内建 in-process hook
+- 再外部 IPC hook
+- 同组内按 `priority` 从小到大执行
+
+原因:
+
+- 内建 hook 延迟更低,适合做基础规范化
+- 外部 hook 更适合做审批、审计、组织级策略
+
+## 超时与错误策略
+
+### 观察型
+
+- 默认超时:`500ms`
+- 超时或报错:记录日志,继续主流程
+
+### 拦截型
+
+- `before_llm` / `after_llm` / `before_tool` / `after_tool`:默认 `5s`
+- `approve_tool`:默认 `60s`
+
+超时行为:
+
+- 普通拦截:`continue`
+- 审批:`deny`
+
+这点应直接沿用 `#1316` 的安全倾向。
+
+## 与当前分支的对接点
+
+### 直接复用
+
+- 事件定义:`pkg/agent/events.go`
+- 事件广播:`pkg/agent/eventbus.go`
+- 活跃 turn / interrupt / rollback:`pkg/agent/turn.go`
+- 事件发射点:`pkg/agent/loop.go`
+
+### 需要新增
+
+- `pkg/agent/hooks.go`
+ - Hook 接口
+ - HookDecision / ApprovalDecision
+ - HookManager
+
+- `pkg/agent/hook_mount.go`
+ - 内建 hook 注册
+ - 外部进程 hook 注册
+
+- `pkg/agent/hook_ipc.go`
+ - stdio JSON-RPC bridge
+
+- `pkg/agent/hook_types.go`
+ - IPC 稳定载荷
+
+### 需要改造
+
+- `pkg/agent/loop.go`
+ - 在 LLM 和 tool 关键路径前后插入 HookManager 调用
+
+- `pkg/tools/base.go`
+ - 可选新增 `ReadOnlyIndicator`
+
+- `pkg/tools/spawn.go`
+- `pkg/tools/subagent.go`
+ - 先保留现状
+ - 等 sub-turn 迁移后再接入 `subturn_*` hook
+
+## 一个更贴合当前分支的数据流
+
+### 观察链路
+
+```text
+runTurn() -> emitEvent() -> EventBus -> observers
+```
+
+### 拦截链路
+
+```text
+runTurn()
+ -> HookManager.BeforeLLM()
+ -> Provider.Chat()
+ -> HookManager.AfterLLM()
+ -> HookManager.BeforeTool()
+ -> HookManager.ApproveTool()
+ -> tool.Execute()
+ -> HookManager.AfterTool()
+```
+
+也就是说:
+
+- observer 不改变现有 `emitEvent()`
+- interceptor 直接插在 `runTurn()` 热路径
+
+## 用户可见配置
+
+建议新增:
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "builtins": {},
+ "processes": {},
+ "defaults": {
+ "observer_timeout_ms": 500,
+ "interceptor_timeout_ms": 5000,
+ "approval_timeout_ms": 60000
+ }
+ }
+}
+```
+
+V1 不做复杂自动发现。
+
+原因:
+
+- 当前分支重点是把地基打稳
+- 目录扫描、安装器、脚手架可以后置
+- 先让仓内和仓外都能挂上去,比“管理体验完整”更重要
+
+## 推荐的 V1 范围
+
+### 必做
+
+- HookManager
+- in-process 挂载
+- stdio IPC 挂载
+- observer hooks
+- `before_tool` / `after_tool` / `approve_tool`
+- `before_llm` / `after_llm`
+
+### 可后置
+
+- hook CLI 管理命令
+- hook 自动发现
+- Unix socket / named pipe transport
+- sub-turn hook 生命周期
+- read-only 并行分组
+- webhook 到 inbound message 的映射入口
+
+## 分阶段落地
+
+### Phase 1
+
+- 引入 HookManager
+- 支持 in-process observer + interceptor
+- 先只接主 turn
+
+### Phase 2
+
+- 引入 `stdio` 外部 hook 进程桥
+- 支持组织级审批 / 审计 / 参数改写
+
+### Phase 3
+
+- 把 `SubagentManager` 迁移到 `runTurn/sub-turn`
+- 接通 `subturn_spawn` / `subturn_end` / `subturn_result_delivered`
+
+### Phase 4
+
+- 视需求补 `ReadOnlyIndicator`
+- 在主 turn 和 sub-turn 上统一只读并行策略
+
+## 最终结论
+
+最适合 PicoClaw 当前分支的方案,不是直接复制 OpenClaw 的 hooks,也不是完整照搬 pi-mono 的 extension system,而是:
+
+- 以现有 `EventBus` 为只读观察面
+- 以新增 `HookManager` 为同步拦截面
+- 项目内通过 Go 对象直接挂载
+- 项目外通过 `stdio JSON-RPC` 进程通信挂载
+
+这样做有三个好处:
+
+- 和 `#1796` 一致,hooks 只是 EventBus 之上的消费层
+- 和当前 `refactor/agent` 实现一致,不需要推翻已有事件系统
+- 同时满足“仓内简单挂载”和“仓外进程通信挂载”两个硬需求
diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go
new file mode 100644
index 000000000..74af542fa
--- /dev/null
+++ b/pkg/agent/hooks.go
@@ -0,0 +1,751 @@
+package agent
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+const (
+ defaultHookObserverTimeout = 500 * time.Millisecond
+ defaultHookInterceptorTimeout = 5 * time.Second
+ defaultHookApprovalTimeout = 60 * time.Second
+ hookObserverBufferSize = 64
+)
+
+type HookAction string
+
+const (
+ HookActionContinue HookAction = "continue"
+ HookActionModify HookAction = "modify"
+ HookActionDenyTool HookAction = "deny_tool"
+ HookActionAbortTurn HookAction = "abort_turn"
+ HookActionHardAbort HookAction = "hard_abort"
+)
+
+type HookDecision struct {
+ Action HookAction
+ Reason string
+}
+
+func (d HookDecision) normalizedAction() HookAction {
+ if d.Action == "" {
+ return HookActionContinue
+ }
+ return d.Action
+}
+
+type ApprovalDecision struct {
+ Approved bool
+ Reason string
+}
+
+type HookRegistration struct {
+ Name string
+ Priority int
+ Hook any
+}
+
+func NamedHook(name string, hook any) HookRegistration {
+ return HookRegistration{
+ Name: name,
+ Hook: hook,
+ }
+}
+
+type EventObserver interface {
+ OnEvent(ctx context.Context, evt Event) error
+}
+
+type LLMInterceptor interface {
+ BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
+ AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
+}
+
+type ToolInterceptor interface {
+ BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
+ AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
+}
+
+type ToolApprover interface {
+ ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
+}
+
+type LLMHookRequest struct {
+ Meta EventMeta
+ Model string
+ Messages []providers.Message
+ Tools []providers.ToolDefinition
+ Options map[string]any
+ Channel string
+ ChatID string
+ GracefulTerminal bool
+}
+
+func (r *LLMHookRequest) Clone() *LLMHookRequest {
+ if r == nil {
+ return nil
+ }
+ cloned := *r
+ cloned.Messages = cloneProviderMessages(r.Messages)
+ cloned.Tools = cloneToolDefinitions(r.Tools)
+ cloned.Options = cloneStringAnyMap(r.Options)
+ return &cloned
+}
+
+type LLMHookResponse struct {
+ Meta EventMeta
+ Model string
+ Response *providers.LLMResponse
+ Channel string
+ ChatID string
+}
+
+func (r *LLMHookResponse) Clone() *LLMHookResponse {
+ if r == nil {
+ return nil
+ }
+ cloned := *r
+ cloned.Response = cloneLLMResponse(r.Response)
+ return &cloned
+}
+
+type ToolCallHookRequest struct {
+ Meta EventMeta
+ Tool string
+ Arguments map[string]any
+ Channel string
+ ChatID string
+}
+
+func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
+ if r == nil {
+ return nil
+ }
+ cloned := *r
+ cloned.Arguments = cloneStringAnyMap(r.Arguments)
+ return &cloned
+}
+
+type ToolApprovalRequest struct {
+ Meta EventMeta
+ Tool string
+ Arguments map[string]any
+ Channel string
+ ChatID string
+}
+
+func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
+ if r == nil {
+ return nil
+ }
+ cloned := *r
+ cloned.Arguments = cloneStringAnyMap(r.Arguments)
+ return &cloned
+}
+
+type ToolResultHookResponse struct {
+ Meta EventMeta
+ Tool string
+ Arguments map[string]any
+ Result *tools.ToolResult
+ Duration time.Duration
+ Channel string
+ ChatID string
+}
+
+func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
+ if r == nil {
+ return nil
+ }
+ cloned := *r
+ cloned.Arguments = cloneStringAnyMap(r.Arguments)
+ cloned.Result = cloneToolResult(r.Result)
+ return &cloned
+}
+
+type HookManager struct {
+ eventBus *EventBus
+ observerTimeout time.Duration
+ interceptorTimeout time.Duration
+ approvalTimeout time.Duration
+
+ mu sync.RWMutex
+ hooks map[string]HookRegistration
+ ordered []HookRegistration
+
+ sub EventSubscription
+ done chan struct{}
+ closeOnce sync.Once
+}
+
+func NewHookManager(eventBus *EventBus) *HookManager {
+ hm := &HookManager{
+ eventBus: eventBus,
+ observerTimeout: defaultHookObserverTimeout,
+ interceptorTimeout: defaultHookInterceptorTimeout,
+ approvalTimeout: defaultHookApprovalTimeout,
+ hooks: make(map[string]HookRegistration),
+ done: make(chan struct{}),
+ }
+
+ if eventBus == nil {
+ close(hm.done)
+ return hm
+ }
+
+ hm.sub = eventBus.Subscribe(hookObserverBufferSize)
+ go hm.dispatchEvents()
+ return hm
+}
+
+func (hm *HookManager) Close() {
+ if hm == nil {
+ return
+ }
+
+ hm.closeOnce.Do(func() {
+ if hm.eventBus != nil {
+ hm.eventBus.Unsubscribe(hm.sub.ID)
+ }
+ <-hm.done
+ })
+}
+
+func (hm *HookManager) Mount(reg HookRegistration) error {
+ if hm == nil {
+ return fmt.Errorf("hook manager is nil")
+ }
+ if reg.Name == "" {
+ return fmt.Errorf("hook name is required")
+ }
+ if reg.Hook == nil {
+ return fmt.Errorf("hook %q is nil", reg.Name)
+ }
+
+ hm.mu.Lock()
+ defer hm.mu.Unlock()
+
+ hm.hooks[reg.Name] = reg
+ hm.rebuildOrdered()
+ return nil
+}
+
+func (hm *HookManager) Unmount(name string) {
+ if hm == nil || name == "" {
+ return
+ }
+
+ hm.mu.Lock()
+ defer hm.mu.Unlock()
+
+ delete(hm.hooks, name)
+ hm.rebuildOrdered()
+}
+
+func (hm *HookManager) dispatchEvents() {
+ defer close(hm.done)
+
+ for evt := range hm.sub.C {
+ for _, reg := range hm.snapshotHooks() {
+ observer, ok := reg.Hook.(EventObserver)
+ if !ok {
+ continue
+ }
+ hm.runObserver(reg.Name, observer, evt)
+ }
+ }
+}
+
+func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
+ if hm == nil || req == nil {
+ return req, HookDecision{Action: HookActionContinue}
+ }
+
+ current := req.Clone()
+ for _, reg := range hm.snapshotHooks() {
+ interceptor, ok := reg.Hook.(LLMInterceptor)
+ if !ok {
+ continue
+ }
+
+ next, decision, ok := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
+ if !ok {
+ continue
+ }
+
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if next != nil {
+ current = next
+ }
+ case HookActionAbortTurn, HookActionHardAbort:
+ return current, decision
+ default:
+ hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
+ }
+ }
+ return current, HookDecision{Action: HookActionContinue}
+}
+
+func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
+ if hm == nil || resp == nil {
+ return resp, HookDecision{Action: HookActionContinue}
+ }
+
+ current := resp.Clone()
+ for _, reg := range hm.snapshotHooks() {
+ interceptor, ok := reg.Hook.(LLMInterceptor)
+ if !ok {
+ continue
+ }
+
+ next, decision, ok := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
+ if !ok {
+ continue
+ }
+
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if next != nil {
+ current = next
+ }
+ case HookActionAbortTurn, HookActionHardAbort:
+ return current, decision
+ default:
+ hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
+ }
+ }
+ return current, HookDecision{Action: HookActionContinue}
+}
+
+func (hm *HookManager) BeforeTool(
+ ctx context.Context,
+ call *ToolCallHookRequest,
+) (*ToolCallHookRequest, HookDecision) {
+ if hm == nil || call == nil {
+ return call, HookDecision{Action: HookActionContinue}
+ }
+
+ current := call.Clone()
+ for _, reg := range hm.snapshotHooks() {
+ interceptor, ok := reg.Hook.(ToolInterceptor)
+ if !ok {
+ continue
+ }
+
+ next, decision, ok := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
+ if !ok {
+ continue
+ }
+
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if next != nil {
+ current = next
+ }
+ case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
+ return current, decision
+ default:
+ hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
+ }
+ }
+ return current, HookDecision{Action: HookActionContinue}
+}
+
+func (hm *HookManager) AfterTool(
+ ctx context.Context,
+ result *ToolResultHookResponse,
+) (*ToolResultHookResponse, HookDecision) {
+ if hm == nil || result == nil {
+ return result, HookDecision{Action: HookActionContinue}
+ }
+
+ current := result.Clone()
+ for _, reg := range hm.snapshotHooks() {
+ interceptor, ok := reg.Hook.(ToolInterceptor)
+ if !ok {
+ continue
+ }
+
+ next, decision, ok := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
+ if !ok {
+ continue
+ }
+
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if next != nil {
+ current = next
+ }
+ case HookActionAbortTurn, HookActionHardAbort:
+ return current, decision
+ default:
+ hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
+ }
+ }
+ return current, HookDecision{Action: HookActionContinue}
+}
+
+func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
+ if hm == nil || req == nil {
+ return ApprovalDecision{Approved: true}
+ }
+
+ for _, reg := range hm.snapshotHooks() {
+ approver, ok := reg.Hook.(ToolApprover)
+ if !ok {
+ continue
+ }
+
+ decision, ok := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
+ if !ok {
+ return ApprovalDecision{
+ Approved: false,
+ Reason: fmt.Sprintf("tool approval hook %q failed", reg.Name),
+ }
+ }
+ if !decision.Approved {
+ return decision
+ }
+ }
+
+ return ApprovalDecision{Approved: true}
+}
+
+func (hm *HookManager) rebuildOrdered() {
+ hm.ordered = hm.ordered[:0]
+ for _, reg := range hm.hooks {
+ hm.ordered = append(hm.ordered, reg)
+ }
+ sort.SliceStable(hm.ordered, func(i, j int) bool {
+ if hm.ordered[i].Priority == hm.ordered[j].Priority {
+ return hm.ordered[i].Name < hm.ordered[j].Name
+ }
+ return hm.ordered[i].Priority < hm.ordered[j].Priority
+ })
+}
+
+func (hm *HookManager) snapshotHooks() []HookRegistration {
+ hm.mu.RLock()
+ defer hm.mu.RUnlock()
+
+ snapshot := make([]HookRegistration, len(hm.ordered))
+ copy(snapshot, hm.ordered)
+ return snapshot
+}
+
+func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
+ ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- observer.OnEvent(ctx, evt)
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ logger.WarnCF("hooks", "Event observer failed", map[string]any{
+ "hook": name,
+ "event": evt.Kind.String(),
+ "error": err.Error(),
+ })
+ }
+ case <-ctx.Done():
+ logger.WarnCF("hooks", "Event observer timed out", map[string]any{
+ "hook": name,
+ "event": evt.Kind.String(),
+ "timeout_ms": hm.observerTimeout.Milliseconds(),
+ })
+ }
+}
+
+func (hm *HookManager) callBeforeLLM(
+ parent context.Context,
+ name string,
+ interceptor LLMInterceptor,
+ req *LLMHookRequest,
+) (*LLMHookRequest, HookDecision, bool) {
+ return runInterceptorHook(
+ parent,
+ hm.interceptorTimeout,
+ name,
+ "before_llm",
+ func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
+ return interceptor.BeforeLLM(ctx, req)
+ },
+ )
+}
+
+func (hm *HookManager) callAfterLLM(
+ parent context.Context,
+ name string,
+ interceptor LLMInterceptor,
+ resp *LLMHookResponse,
+) (*LLMHookResponse, HookDecision, bool) {
+ return runInterceptorHook(
+ parent,
+ hm.interceptorTimeout,
+ name,
+ "after_llm",
+ func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
+ return interceptor.AfterLLM(ctx, resp)
+ },
+ )
+}
+
+func (hm *HookManager) callBeforeTool(
+ parent context.Context,
+ name string,
+ interceptor ToolInterceptor,
+ call *ToolCallHookRequest,
+) (*ToolCallHookRequest, HookDecision, bool) {
+ return runInterceptorHook(
+ parent,
+ hm.interceptorTimeout,
+ name,
+ "before_tool",
+ func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
+ return interceptor.BeforeTool(ctx, call)
+ },
+ )
+}
+
+func (hm *HookManager) callAfterTool(
+ parent context.Context,
+ name string,
+ interceptor ToolInterceptor,
+ resultView *ToolResultHookResponse,
+) (*ToolResultHookResponse, HookDecision, bool) {
+ return runInterceptorHook(
+ parent,
+ hm.interceptorTimeout,
+ name,
+ "after_tool",
+ func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
+ return interceptor.AfterTool(ctx, resultView)
+ },
+ )
+}
+
+func (hm *HookManager) callApproveTool(
+ parent context.Context,
+ name string,
+ approver ToolApprover,
+ req *ToolApprovalRequest,
+) (ApprovalDecision, bool) {
+ return runApprovalHook(
+ parent,
+ hm.approvalTimeout,
+ name,
+ "approve_tool",
+ func(ctx context.Context) (ApprovalDecision, error) {
+ return approver.ApproveTool(ctx, req)
+ },
+ )
+}
+
+func runInterceptorHook[T any](
+ parent context.Context,
+ timeout time.Duration,
+ name string,
+ stage string,
+ fn func(ctx context.Context) (T, HookDecision, error),
+) (T, HookDecision, bool) {
+ var zero T
+
+ ctx, cancel := context.WithTimeout(parent, timeout)
+ defer cancel()
+
+ type result struct {
+ value T
+ decision HookDecision
+ err error
+ }
+ done := make(chan result, 1)
+ go func() {
+ value, decision, err := fn(ctx)
+ done <- result{value: value, decision: decision, err: err}
+ }()
+
+ select {
+ case res := <-done:
+ if res.err != nil {
+ logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
+ "hook": name,
+ "stage": stage,
+ "error": res.err.Error(),
+ })
+ return zero, HookDecision{}, false
+ }
+ return res.value, res.decision, true
+ case <-ctx.Done():
+ logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
+ "hook": name,
+ "stage": stage,
+ "timeout_ms": timeout.Milliseconds(),
+ })
+ return zero, HookDecision{}, false
+ }
+}
+
+func runApprovalHook(
+ parent context.Context,
+ timeout time.Duration,
+ name string,
+ stage string,
+ fn func(ctx context.Context) (ApprovalDecision, error),
+) (ApprovalDecision, bool) {
+ ctx, cancel := context.WithTimeout(parent, timeout)
+ defer cancel()
+
+ type result struct {
+ decision ApprovalDecision
+ err error
+ }
+ done := make(chan result, 1)
+ go func() {
+ decision, err := fn(ctx)
+ done <- result{decision: decision, err: err}
+ }()
+
+ select {
+ case res := <-done:
+ if res.err != nil {
+ logger.WarnCF("hooks", "Approval hook failed", map[string]any{
+ "hook": name,
+ "stage": stage,
+ "error": res.err.Error(),
+ })
+ return ApprovalDecision{}, false
+ }
+ return res.decision, true
+ case <-ctx.Done():
+ logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
+ "hook": name,
+ "stage": stage,
+ "timeout_ms": timeout.Milliseconds(),
+ })
+ return ApprovalDecision{
+ Approved: false,
+ Reason: fmt.Sprintf("tool approval hook %q timed out", name),
+ }, true
+ }
+}
+
+func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
+ logger.WarnCF("hooks", "Hook returned unsupported action for stage", map[string]any{
+ "hook": name,
+ "stage": stage,
+ "action": action,
+ })
+}
+
+func cloneProviderMessages(messages []providers.Message) []providers.Message {
+ if len(messages) == 0 {
+ return nil
+ }
+
+ cloned := make([]providers.Message, len(messages))
+ for i, msg := range messages {
+ cloned[i] = msg
+ if len(msg.Media) > 0 {
+ cloned[i].Media = append([]string(nil), msg.Media...)
+ }
+ if len(msg.SystemParts) > 0 {
+ cloned[i].SystemParts = append([]providers.ContentBlock(nil), msg.SystemParts...)
+ }
+ if len(msg.ToolCalls) > 0 {
+ cloned[i].ToolCalls = cloneProviderToolCalls(msg.ToolCalls)
+ }
+ }
+ return cloned
+}
+
+func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
+ if len(calls) == 0 {
+ return nil
+ }
+
+ cloned := make([]providers.ToolCall, len(calls))
+ for i, call := range calls {
+ cloned[i] = call
+ if call.Function != nil {
+ fn := *call.Function
+ cloned[i].Function = &fn
+ }
+ if call.Arguments != nil {
+ cloned[i].Arguments = cloneStringAnyMap(call.Arguments)
+ }
+ if call.ExtraContent != nil {
+ extra := *call.ExtraContent
+ if call.ExtraContent.Google != nil {
+ google := *call.ExtraContent.Google
+ extra.Google = &google
+ }
+ cloned[i].ExtraContent = &extra
+ }
+ }
+ return cloned
+}
+
+func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
+ if len(defs) == 0 {
+ return nil
+ }
+
+ cloned := make([]providers.ToolDefinition, len(defs))
+ for i, def := range defs {
+ cloned[i] = def
+ cloned[i].Function.Parameters = cloneStringAnyMap(def.Function.Parameters)
+ }
+ return cloned
+}
+
+func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
+ if resp == nil {
+ return nil
+ }
+ cloned := *resp
+ cloned.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
+ if len(resp.ReasoningDetails) > 0 {
+ cloned.ReasoningDetails = append(cloned.ReasoningDetails[:0:0], resp.ReasoningDetails...)
+ }
+ if resp.Usage != nil {
+ usage := *resp.Usage
+ cloned.Usage = &usage
+ }
+ return &cloned
+}
+
+func cloneStringAnyMap(src map[string]any) map[string]any {
+ if len(src) == 0 {
+ return nil
+ }
+
+ cloned := make(map[string]any, len(src))
+ for k, v := range src {
+ cloned[k] = v
+ }
+ return cloned
+}
+
+func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
+ if result == nil {
+ return nil
+ }
+
+ cloned := *result
+ if len(result.Media) > 0 {
+ cloned.Media = append([]string(nil), result.Media...)
+ }
+ return &cloned
+}
diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go
new file mode 100644
index 000000000..6607b5fe7
--- /dev/null
+++ b/pkg/agent/hooks_test.go
@@ -0,0 +1,312 @@
+package agent
+
+import (
+ "context"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+ "github.com/sipeed/picoclaw/pkg/providers"
+ "github.com/sipeed/picoclaw/pkg/tools"
+)
+
+func newHookTestLoop(
+ t *testing.T,
+ provider providers.LLMProvider,
+) (*AgentLoop, *AgentInstance, func()) {
+ t.Helper()
+
+ tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
+ if err != nil {
+ t.Fatalf("failed to create temp dir: %v", err)
+ }
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: tmpDir,
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ }
+
+ al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
+ agent := al.registry.GetDefaultAgent()
+ if agent == nil {
+ t.Fatal("expected default agent")
+ }
+
+ return al, agent, func() {
+ al.Close()
+ _ = os.RemoveAll(tmpDir)
+ }
+}
+
+type llmHookTestProvider struct {
+ mu sync.Mutex
+ lastModel string
+}
+
+func (p *llmHookTestProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ p.mu.Lock()
+ p.lastModel = model
+ p.mu.Unlock()
+
+ return &providers.LLMResponse{
+ Content: "provider content",
+ }, nil
+}
+
+func (p *llmHookTestProvider) GetDefaultModel() string {
+ return "llm-hook-provider"
+}
+
+type llmObserverHook struct {
+ eventCh chan Event
+}
+
+func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
+ if evt.Kind == EventKindTurnEnd {
+ select {
+ case h.eventCh <- evt:
+ default:
+ }
+ }
+ return nil
+}
+
+func (h *llmObserverHook) BeforeLLM(
+ ctx context.Context,
+ req *LLMHookRequest,
+) (*LLMHookRequest, HookDecision, error) {
+ next := req.Clone()
+ next.Model = "hook-model"
+ return next, HookDecision{Action: HookActionModify}, nil
+}
+
+func (h *llmObserverHook) AfterLLM(
+ ctx context.Context,
+ resp *LLMHookResponse,
+) (*LLMHookResponse, HookDecision, error) {
+ next := resp.Clone()
+ next.Response.Content = "hooked content"
+ return next, HookDecision{Action: HookActionModify}, nil
+}
+
+func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
+ provider := &llmHookTestProvider{}
+ al, agent, cleanup := newHookTestLoop(t, provider)
+ defer cleanup()
+
+ hook := &llmObserverHook{eventCh: make(chan Event, 1)}
+ if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
+ t.Fatalf("MountHook failed: %v", err)
+ }
+
+ resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "hello",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if resp != "hooked content" {
+ t.Fatalf("expected hooked content, got %q", resp)
+ }
+
+ provider.mu.Lock()
+ lastModel := provider.lastModel
+ provider.mu.Unlock()
+ if lastModel != "hook-model" {
+ t.Fatalf("expected model hook-model, got %q", lastModel)
+ }
+
+ select {
+ case evt := <-hook.eventCh:
+ if evt.Kind != EventKindTurnEnd {
+ t.Fatalf("expected turn end event, got %v", evt.Kind)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for hook observer event")
+ }
+}
+
+type toolHookProvider struct {
+ mu sync.Mutex
+ calls int
+}
+
+func (p *toolHookProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ p.calls++
+ if p.calls == 1 {
+ return &providers.LLMResponse{
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call-1",
+ Name: "echo_text",
+ Arguments: map[string]any{"text": "original"},
+ },
+ },
+ }, nil
+ }
+
+ last := messages[len(messages)-1]
+ return &providers.LLMResponse{
+ Content: last.Content,
+ }, nil
+}
+
+func (p *toolHookProvider) GetDefaultModel() string {
+ return "tool-hook-provider"
+}
+
+type echoTextTool struct{}
+
+func (t *echoTextTool) Name() string {
+ return "echo_text"
+}
+
+func (t *echoTextTool) Description() string {
+ return "echo a text argument"
+}
+
+func (t *echoTextTool) Parameters() map[string]any {
+ return map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "text": map[string]any{
+ "type": "string",
+ },
+ },
+ }
+}
+
+func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
+ text, _ := args["text"].(string)
+ return tools.SilentResult(text)
+}
+
+type toolRewriteHook struct{}
+
+func (h *toolRewriteHook) BeforeTool(
+ ctx context.Context,
+ call *ToolCallHookRequest,
+) (*ToolCallHookRequest, HookDecision, error) {
+ next := call.Clone()
+ next.Arguments["text"] = "modified"
+ return next, HookDecision{Action: HookActionModify}, nil
+}
+
+func (h *toolRewriteHook) AfterTool(
+ ctx context.Context,
+ result *ToolResultHookResponse,
+) (*ToolResultHookResponse, HookDecision, error) {
+ next := result.Clone()
+ next.Result.ForLLM = "after:" + next.Result.ForLLM
+ return next, HookDecision{Action: HookActionModify}, nil
+}
+
+func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
+ provider := &toolHookProvider{}
+ al, agent, cleanup := newHookTestLoop(t, provider)
+ defer cleanup()
+
+ al.RegisterTool(&echoTextTool{})
+ if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
+ t.Fatalf("MountHook failed: %v", err)
+ }
+
+ resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "run tool",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if resp != "after:modified" {
+ t.Fatalf("expected rewritten tool result, got %q", resp)
+ }
+}
+
+type denyApprovalHook struct{}
+
+func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
+ return ApprovalDecision{
+ Approved: false,
+ Reason: "blocked",
+ }, nil
+}
+
+func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
+ provider := &toolHookProvider{}
+ al, agent, cleanup := newHookTestLoop(t, provider)
+ defer cleanup()
+
+ al.RegisterTool(&echoTextTool{})
+ if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
+ t.Fatalf("MountHook failed: %v", err)
+ }
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "run tool",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ expected := "Tool execution denied by approval hook: blocked"
+ if resp != expected {
+ t.Fatalf("expected %q, got %q", expected, resp)
+ }
+
+ events := collectEventStream(sub.C)
+ skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
+ if !ok {
+ t.Fatal("expected tool skipped event")
+ }
+ payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
+ if !ok {
+ t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
+ }
+ if payload.Reason != expected {
+ t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
+ }
+}
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index f54482ae8..a85abcb60 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -40,6 +40,7 @@ type AgentLoop struct {
registry *AgentRegistry
state *state.Manager
eventBus *EventBus
+ hooks *HookManager
running atomic.Bool
summarizing sync.Map
fallback *providers.FallbackChain
@@ -108,17 +109,19 @@ func NewAgentLoop(
stateManager = state.NewManager(defaultAgent.Workspace)
}
+ eventBus := NewEventBus()
al := &AgentLoop{
bus: msgBus,
cfg: cfg,
registry: registry,
state: stateManager,
- eventBus: NewEventBus(),
+ eventBus: eventBus,
summarizing: sync.Map{},
fallback: fallbackChain,
cmdRegistry: commands.NewRegistry(commands.BuiltinDefinitions()),
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
+ al.hooks = NewHookManager(eventBus)
return al
}
@@ -460,11 +463,30 @@ func (al *AgentLoop) Close() {
}
al.GetRegistry().Close()
+ if al.hooks != nil {
+ al.hooks.Close()
+ }
if al.eventBus != nil {
al.eventBus.Close()
}
}
+// MountHook registers an in-process hook on the agent loop.
+func (al *AgentLoop) MountHook(reg HookRegistration) error {
+ if al == nil || al.hooks == nil {
+ return fmt.Errorf("hook manager is not initialized")
+ }
+ return al.hooks.Mount(reg)
+}
+
+// UnmountHook removes a previously registered in-process hook.
+func (al *AgentLoop) UnmountHook(name string) {
+ if al == nil || al.hooks == nil {
+ return
+ }
+ al.hooks.Unmount(name)
+}
+
// SubscribeEvents registers a subscriber for agent-loop events.
func (al *AgentLoop) SubscribeEvents(buffer int) EventSubscription {
if al == nil || al.eventBus == nil {
@@ -544,6 +566,31 @@ func cloneEventArguments(args map[string]any) map[string]any {
return cloned
}
+func (al *AgentLoop) hookAbortError(ts *turnState, stage string, decision HookDecision) error {
+ reason := decision.Reason
+ if reason == "" {
+ reason = "hook requested turn abort"
+ }
+
+ err := fmt.Errorf("hook aborted turn during %s: %s", stage, reason)
+ al.emitEvent(
+ EventKindError,
+ ts.eventMeta("hooks", "turn.error"),
+ ErrorPayload{
+ Stage: "hook." + stage,
+ Message: err.Error(),
+ },
+ )
+ return err
+}
+
+func hookDeniedToolContent(prefix, reason string) string {
+ if reason == "" {
+ return prefix
+ }
+ return prefix + ": " + reason
+}
+
func (al *AgentLoop) logEvent(evt Event) {
fields := map[string]any{
"event_kind": evt.Kind.String(),
@@ -1418,36 +1465,6 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
ts.markGracefulTerminalUsed()
}
- al.emitEvent(
- EventKindLLMRequest,
- ts.eventMeta("runTurn", "turn.llm.request"),
- LLMRequestPayload{
- Model: activeModel,
- MessagesCount: len(callMessages),
- ToolsCount: len(providerToolDefs),
- MaxTokens: ts.agent.MaxTokens,
- Temperature: ts.agent.Temperature,
- },
- )
-
- logger.DebugCF("agent", "LLM request",
- map[string]any{
- "agent_id": ts.agent.ID,
- "iteration": iteration,
- "model": activeModel,
- "messages_count": len(callMessages),
- "tools_count": len(providerToolDefs),
- "max_tokens": ts.agent.MaxTokens,
- "temperature": ts.agent.Temperature,
- "system_prompt_len": len(callMessages[0].Content),
- })
- logger.DebugCF("agent", "Full LLM request",
- map[string]any{
- "iteration": iteration,
- "messages_json": formatMessagesForLog(callMessages),
- "tools_json": formatToolsForLog(providerToolDefs),
- })
-
llmOpts := map[string]any{
"max_tokens": ts.agent.MaxTokens,
"temperature": ts.agent.Temperature,
@@ -1462,6 +1479,66 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
}
}
+ llmModel := activeModel
+ if al.hooks != nil {
+ llmReq, decision := al.hooks.BeforeLLM(turnCtx, &LLMHookRequest{
+ Meta: ts.eventMeta("runTurn", "turn.llm.request"),
+ Model: llmModel,
+ Messages: callMessages,
+ Tools: providerToolDefs,
+ Options: llmOpts,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ GracefulTerminal: gracefulTerminal,
+ })
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if llmReq != nil {
+ llmModel = llmReq.Model
+ callMessages = llmReq.Messages
+ providerToolDefs = llmReq.Tools
+ llmOpts = llmReq.Options
+ }
+ case HookActionAbortTurn:
+ turnStatus = TurnEndStatusError
+ return turnResult{}, al.hookAbortError(ts, "before_llm", decision)
+ case HookActionHardAbort:
+ _ = ts.requestHardAbort()
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+ }
+
+ al.emitEvent(
+ EventKindLLMRequest,
+ ts.eventMeta("runTurn", "turn.llm.request"),
+ LLMRequestPayload{
+ Model: llmModel,
+ MessagesCount: len(callMessages),
+ ToolsCount: len(providerToolDefs),
+ MaxTokens: ts.agent.MaxTokens,
+ Temperature: ts.agent.Temperature,
+ },
+ )
+
+ logger.DebugCF("agent", "LLM request",
+ map[string]any{
+ "agent_id": ts.agent.ID,
+ "iteration": iteration,
+ "model": llmModel,
+ "messages_count": len(callMessages),
+ "tools_count": len(providerToolDefs),
+ "max_tokens": ts.agent.MaxTokens,
+ "temperature": ts.agent.Temperature,
+ "system_prompt_len": len(callMessages[0].Content),
+ })
+ logger.DebugCF("agent", "Full LLM request",
+ map[string]any{
+ "iteration": iteration,
+ "messages_json": formatMessagesForLog(callMessages),
+ "tools_json": formatToolsForLog(providerToolDefs),
+ })
+
callLLM := func(messagesForCall []providers.Message, toolDefsForCall []providers.ToolDefinition) (*providers.LLMResponse, error) {
providerCtx, providerCancel := context.WithCancel(turnCtx)
ts.setProviderCancel(providerCancel)
@@ -1494,7 +1571,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
}
return fbResult.Response, nil
}
- return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, activeModel, llmOpts)
+ return ts.agent.Provider.Chat(providerCtx, messagesForCall, toolDefsForCall, llmModel, llmOpts)
}
var response *providers.LLMResponse
@@ -1626,12 +1703,35 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
map[string]any{
"agent_id": ts.agent.ID,
"iteration": iteration,
- "model": activeModel,
+ "model": llmModel,
"error": err.Error(),
})
return turnResult{}, fmt.Errorf("LLM call failed after retries: %w", err)
}
+ if al.hooks != nil {
+ llmResp, decision := al.hooks.AfterLLM(turnCtx, &LLMHookResponse{
+ Meta: ts.eventMeta("runTurn", "turn.llm.response"),
+ Model: llmModel,
+ Response: response,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ })
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if llmResp != nil && llmResp.Response != nil {
+ response = llmResp.Response
+ }
+ case HookActionAbortTurn:
+ turnStatus = TurnEndStatusError
+ return turnResult{}, al.hookAbortError(ts, "after_llm", decision)
+ case HookActionHardAbort:
+ _ = ts.requestHardAbort()
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+ }
+
go al.handleReasoning(
turnCtx,
response.Reasoning,
@@ -1728,25 +1828,106 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
return al.abortTurn(ts)
}
- argsJSON, _ := json.Marshal(tc.Arguments)
+ toolName := tc.Name
+ toolArgs := cloneStringAnyMap(tc.Arguments)
+
+ if al.hooks != nil {
+ toolReq, decision := al.hooks.BeforeTool(turnCtx, &ToolCallHookRequest{
+ Meta: ts.eventMeta("runTurn", "turn.tool.before"),
+ Tool: toolName,
+ Arguments: toolArgs,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ })
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if toolReq != nil {
+ toolName = toolReq.Tool
+ toolArgs = toolReq.Arguments
+ }
+ case HookActionDenyTool:
+ denyContent := hookDeniedToolContent("Tool execution denied by hook", decision.Reason)
+ al.emitEvent(
+ EventKindToolExecSkipped,
+ ts.eventMeta("runTurn", "turn.tool.skipped"),
+ ToolExecSkippedPayload{
+ Tool: toolName,
+ Reason: denyContent,
+ },
+ )
+ deniedMsg := providers.Message{
+ Role: "tool",
+ Content: denyContent,
+ ToolCallID: tc.ID,
+ }
+ messages = append(messages, deniedMsg)
+ if !ts.opts.NoHistory {
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
+ ts.recordPersistedMessage(deniedMsg)
+ }
+ continue
+ case HookActionAbortTurn:
+ turnStatus = TurnEndStatusError
+ return turnResult{}, al.hookAbortError(ts, "before_tool", decision)
+ case HookActionHardAbort:
+ _ = ts.requestHardAbort()
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+ }
+
+ if al.hooks != nil {
+ approval := al.hooks.ApproveTool(turnCtx, &ToolApprovalRequest{
+ Meta: ts.eventMeta("runTurn", "turn.tool.approve"),
+ Tool: toolName,
+ Arguments: toolArgs,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ })
+ if !approval.Approved {
+ denyContent := hookDeniedToolContent("Tool execution denied by approval hook", approval.Reason)
+ al.emitEvent(
+ EventKindToolExecSkipped,
+ ts.eventMeta("runTurn", "turn.tool.skipped"),
+ ToolExecSkippedPayload{
+ Tool: toolName,
+ Reason: denyContent,
+ },
+ )
+ deniedMsg := providers.Message{
+ Role: "tool",
+ Content: denyContent,
+ ToolCallID: tc.ID,
+ }
+ messages = append(messages, deniedMsg)
+ if !ts.opts.NoHistory {
+ ts.agent.Sessions.AddFullMessage(ts.sessionKey, deniedMsg)
+ ts.recordPersistedMessage(deniedMsg)
+ }
+ continue
+ }
+ }
+
+ argsJSON, _ := json.Marshal(toolArgs)
argsPreview := utils.Truncate(string(argsJSON), 200)
- logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
+ logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", toolName, argsPreview),
map[string]any{
"agent_id": ts.agent.ID,
- "tool": tc.Name,
+ "tool": toolName,
"iteration": iteration,
})
al.emitEvent(
EventKindToolExecStart,
ts.eventMeta("runTurn", "turn.tool.start"),
ToolExecStartPayload{
- Tool: tc.Name,
- Arguments: cloneEventArguments(tc.Arguments),
+ Tool: toolName,
+ Arguments: cloneEventArguments(toolArgs),
},
)
- toolCall := tc
+ toolCallID := tc.ID
toolIteration := iteration
+ asyncToolName := toolName
asyncCallback := func(_ context.Context, result *tools.ToolResult) {
if !result.Silent && result.ForUser != "" {
outCtx, outCancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -1768,7 +1949,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
logger.InfoCF("agent", "Async tool completed, publishing result",
map[string]any{
- "tool": toolCall.Name,
+ "tool": asyncToolName,
"content_len": len(content),
"channel": ts.channel,
})
@@ -1776,7 +1957,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
EventKindFollowUpQueued,
ts.scope.meta(toolIteration, "runTurn", "turn.follow_up.queued"),
FollowUpQueuedPayload{
- SourceTool: toolCall.Name,
+ SourceTool: asyncToolName,
Channel: ts.channel,
ChatID: ts.chatID,
ContentLen: len(content),
@@ -1787,7 +1968,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
defer pubCancel()
_ = al.bus.PublishInbound(pubCtx, bus.InboundMessage{
Channel: "system",
- SenderID: fmt.Sprintf("async:%s", toolCall.Name),
+ SenderID: fmt.Sprintf("async:%s", asyncToolName),
ChatID: fmt.Sprintf("%s:%s", ts.channel, ts.chatID),
Content: content,
})
@@ -1796,8 +1977,8 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
toolStart := time.Now()
toolResult := ts.agent.Tools.ExecuteWithContext(
turnCtx,
- toolCall.Name,
- toolCall.Arguments,
+ toolName,
+ toolArgs,
ts.channel,
ts.chatID,
asyncCallback,
@@ -1809,6 +1990,40 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
return al.abortTurn(ts)
}
+ if al.hooks != nil {
+ toolResp, decision := al.hooks.AfterTool(turnCtx, &ToolResultHookResponse{
+ Meta: ts.eventMeta("runTurn", "turn.tool.after"),
+ Tool: toolName,
+ Arguments: toolArgs,
+ Result: toolResult,
+ Duration: toolDuration,
+ Channel: ts.channel,
+ ChatID: ts.chatID,
+ })
+ switch decision.normalizedAction() {
+ case HookActionContinue, HookActionModify:
+ if toolResp != nil {
+ if toolResp.Tool != "" {
+ toolName = toolResp.Tool
+ }
+ if toolResp.Result != nil {
+ toolResult = toolResp.Result
+ }
+ }
+ case HookActionAbortTurn:
+ turnStatus = TurnEndStatusError
+ return turnResult{}, al.hookAbortError(ts, "after_tool", decision)
+ case HookActionHardAbort:
+ _ = ts.requestHardAbort()
+ turnStatus = TurnEndStatusAborted
+ return al.abortTurn(ts)
+ }
+ }
+
+ if toolResult == nil {
+ toolResult = tools.ErrorResult("hook returned nil tool result")
+ }
+
if !toolResult.Silent && toolResult.ForUser != "" && ts.opts.SendResponse {
al.bus.PublishOutbound(ctx, bus.OutboundMessage{
Channel: ts.channel,
@@ -1817,7 +2032,7 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]any{
- "tool": toolCall.Name,
+ "tool": toolName,
"content_len": len(toolResult.ForUser),
})
}
@@ -1850,13 +2065,13 @@ func (al *AgentLoop) runTurn(ctx context.Context, ts *turnState) (turnResult, er
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
- ToolCallID: toolCall.ID,
+ ToolCallID: toolCallID,
}
al.emitEvent(
EventKindToolExecEnd,
ts.eventMeta("runTurn", "turn.tool.end"),
ToolExecEndPayload{
- Tool: toolCall.Name,
+ Tool: toolName,
Duration: toolDuration,
ForLLMLen: len(contentForLLM),
ForUserLen: len(toolResult.ForUser),
From 337e43e5a5a2f0a12598a3ac982419bacdde0b15 Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Sat, 21 Mar 2026 19:46:16 +0800
Subject: [PATCH 25/26] feat(agent): add configurable hook mounting
---
pkg/agent/hook_mount.go | 317 ++++++++++++++++++++
pkg/agent/hook_mount_test.go | 179 ++++++++++++
pkg/agent/hook_process.go | 511 +++++++++++++++++++++++++++++++++
pkg/agent/hook_process_test.go | 339 ++++++++++++++++++++++
pkg/agent/hooks.go | 130 ++++++---
pkg/agent/hooks_test.go | 33 +++
pkg/agent/loop.go | 18 ++
pkg/agent/steering.go | 6 +
pkg/config/config.go | 31 ++
pkg/config/config_test.go | 98 +++++++
pkg/config/defaults.go | 8 +
11 files changed, 1634 insertions(+), 36 deletions(-)
create mode 100644 pkg/agent/hook_mount.go
create mode 100644 pkg/agent/hook_mount_test.go
create mode 100644 pkg/agent/hook_process.go
create mode 100644 pkg/agent/hook_process_test.go
diff --git a/pkg/agent/hook_mount.go b/pkg/agent/hook_mount.go
new file mode 100644
index 000000000..c92145f1f
--- /dev/null
+++ b/pkg/agent/hook_mount.go
@@ -0,0 +1,317 @@
+package agent
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+type hookRuntime struct {
+ initOnce sync.Once
+ mu sync.Mutex
+ initErr error
+ mounted []string
+}
+
+func (r *hookRuntime) setInitErr(err error) {
+ r.mu.Lock()
+ r.initErr = err
+ r.mu.Unlock()
+}
+
+func (r *hookRuntime) getInitErr() error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ return r.initErr
+}
+
+func (r *hookRuntime) setMounted(names []string) {
+ r.mu.Lock()
+ r.mounted = append([]string(nil), names...)
+ r.mu.Unlock()
+}
+
+func (r *hookRuntime) reset(al *AgentLoop) {
+ r.mu.Lock()
+ names := append([]string(nil), r.mounted...)
+ r.mounted = nil
+ r.initErr = nil
+ r.initOnce = sync.Once{}
+ r.mu.Unlock()
+
+ for _, name := range names {
+ al.UnmountHook(name)
+ }
+}
+
+// BuiltinHookFactory constructs an in-process hook from config.
+type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
+
+var (
+ builtinHookRegistryMu sync.RWMutex
+ builtinHookRegistry = map[string]BuiltinHookFactory{}
+)
+
+// RegisterBuiltinHook registers a named in-process hook factory for config-driven mounting.
+func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
+ if name == "" {
+ return fmt.Errorf("builtin hook name is required")
+ }
+ if factory == nil {
+ return fmt.Errorf("builtin hook %q factory is nil", name)
+ }
+
+ builtinHookRegistryMu.Lock()
+ defer builtinHookRegistryMu.Unlock()
+
+ if _, exists := builtinHookRegistry[name]; exists {
+ return fmt.Errorf("builtin hook %q is already registered", name)
+ }
+ builtinHookRegistry[name] = factory
+ return nil
+}
+
+func unregisterBuiltinHook(name string) {
+ if name == "" {
+ return
+ }
+ builtinHookRegistryMu.Lock()
+ delete(builtinHookRegistry, name)
+ builtinHookRegistryMu.Unlock()
+}
+
+func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
+ builtinHookRegistryMu.RLock()
+ defer builtinHookRegistryMu.RUnlock()
+
+ factory, ok := builtinHookRegistry[name]
+ return factory, ok
+}
+
+func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
+ if hm == nil || cfg == nil {
+ return
+ }
+ hm.ConfigureTimeouts(
+ hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS),
+ hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS),
+ hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS),
+ )
+}
+
+func hookTimeoutFromMS(ms int) time.Duration {
+ if ms <= 0 {
+ return 0
+ }
+ return time.Duration(ms) * time.Millisecond
+}
+
+func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
+ if al == nil || al.cfg == nil || al.hooks == nil {
+ return nil
+ }
+
+ al.hookRuntime.initOnce.Do(func() {
+ al.hookRuntime.setInitErr(al.loadConfiguredHooks(ctx))
+ })
+
+ return al.hookRuntime.getInitErr()
+}
+
+func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
+ if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
+ return nil
+ }
+
+ mounted := make([]string, 0)
+ defer func() {
+ if err != nil {
+ for _, name := range mounted {
+ al.UnmountHook(name)
+ }
+ return
+ }
+ al.hookRuntime.setMounted(mounted)
+ }()
+
+ builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
+ for _, name := range builtinNames {
+ spec := al.cfg.Hooks.Builtins[name]
+ factory, ok := lookupBuiltinHook(name)
+ if !ok {
+ return fmt.Errorf("builtin hook %q is not registered", name)
+ }
+
+ hook, factoryErr := factory(ctx, spec)
+ if factoryErr != nil {
+ return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
+ }
+ if err := al.MountHook(HookRegistration{
+ Name: name,
+ Priority: spec.Priority,
+ Source: HookSourceInProcess,
+ Hook: hook,
+ }); err != nil {
+ return fmt.Errorf("mount builtin hook %q: %w", name, err)
+ }
+ mounted = append(mounted, name)
+ }
+
+ processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
+ for _, name := range processNames {
+ spec := al.cfg.Hooks.Processes[name]
+ opts, buildErr := processHookOptionsFromConfig(spec)
+ if buildErr != nil {
+ return fmt.Errorf("configure process hook %q: %w", name, buildErr)
+ }
+
+ processHook, buildErr := NewProcessHook(ctx, name, opts)
+ if buildErr != nil {
+ return fmt.Errorf("start process hook %q: %w", name, buildErr)
+ }
+ if err := al.MountHook(HookRegistration{
+ Name: name,
+ Priority: spec.Priority,
+ Source: HookSourceProcess,
+ Hook: processHook,
+ }); err != nil {
+ _ = processHook.Close()
+ return fmt.Errorf("mount process hook %q: %w", name, err)
+ }
+ mounted = append(mounted, name)
+ }
+
+ return nil
+}
+
+func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
+ if len(specs) == 0 {
+ return nil
+ }
+
+ names := make([]string, 0, len(specs))
+ for name, spec := range specs {
+ if spec.Enabled {
+ names = append(names, name)
+ }
+ }
+ sort.Strings(names)
+ return names
+}
+
+func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
+ if len(specs) == 0 {
+ return nil
+ }
+
+ names := make([]string, 0, len(specs))
+ for name, spec := range specs {
+ if spec.Enabled {
+ names = append(names, name)
+ }
+ }
+ sort.Strings(names)
+ return names
+}
+
+func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
+ transport := spec.Transport
+ if transport == "" {
+ transport = "stdio"
+ }
+ if transport != "stdio" {
+ return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", transport)
+ }
+ if len(spec.Command) == 0 {
+ return ProcessHookOptions{}, fmt.Errorf("command is required")
+ }
+
+ opts := ProcessHookOptions{
+ Command: append([]string(nil), spec.Command...),
+ Dir: spec.Dir,
+ Env: processHookEnvFromMap(spec.Env),
+ }
+
+ observeKinds, observeEnabled, err := processHookObserveKindsFromConfig(spec.Observe)
+ if err != nil {
+ return ProcessHookOptions{}, err
+ }
+ opts.Observe = observeEnabled
+ opts.ObserveKinds = observeKinds
+
+ for _, intercept := range spec.Intercept {
+ switch intercept {
+ case "before_llm", "after_llm":
+ opts.InterceptLLM = true
+ case "before_tool", "after_tool":
+ opts.InterceptTool = true
+ case "approve_tool":
+ opts.ApproveTool = true
+ case "":
+ continue
+ default:
+ return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", intercept)
+ }
+ }
+
+ if !opts.Observe && !opts.InterceptLLM && !opts.InterceptTool && !opts.ApproveTool {
+ return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled")
+ }
+
+ return opts, nil
+}
+
+func processHookEnvFromMap(envMap map[string]string) []string {
+ if len(envMap) == 0 {
+ return nil
+ }
+
+ keys := make([]string, 0, len(envMap))
+ for key := range envMap {
+ keys = append(keys, key)
+ }
+ sort.Strings(keys)
+
+ env := make([]string, 0, len(keys))
+ for _, key := range keys {
+ env = append(env, key+"="+envMap[key])
+ }
+ return env
+}
+
+func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
+ if len(observe) == 0 {
+ return nil, false, nil
+ }
+
+ validKinds := validHookEventKinds()
+ normalized := make([]string, 0, len(observe))
+ for _, kind := range observe {
+ switch kind {
+ case "", "*", "all":
+ return nil, true, nil
+ default:
+ if _, ok := validKinds[kind]; !ok {
+ return nil, false, fmt.Errorf("unsupported observe event %q", kind)
+ }
+ normalized = append(normalized, kind)
+ }
+ }
+
+ if len(normalized) == 0 {
+ return nil, false, nil
+ }
+ return normalized, true, nil
+}
+
+func validHookEventKinds() map[string]struct{} {
+ kinds := make(map[string]struct{}, int(eventKindCount))
+ for kind := EventKind(0); kind < eventKindCount; kind++ {
+ kinds[kind.String()] = struct{}{}
+ }
+ return kinds
+}
diff --git a/pkg/agent/hook_mount_test.go b/pkg/agent/hook_mount_test.go
new file mode 100644
index 000000000..a9d8f27c5
--- /dev/null
+++ b/pkg/agent/hook_mount_test.go
@@ -0,0 +1,179 @@
+package agent
+
+import (
+ "context"
+ "encoding/json"
+ "path/filepath"
+ "testing"
+
+ "github.com/sipeed/picoclaw/pkg/bus"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+type builtinAutoHookConfig struct {
+ Model string `json:"model"`
+ Suffix string `json:"suffix"`
+}
+
+type builtinAutoHook struct {
+ model string
+ suffix string
+}
+
+func (h *builtinAutoHook) BeforeLLM(
+ ctx context.Context,
+ req *LLMHookRequest,
+) (*LLMHookRequest, HookDecision, error) {
+ next := req.Clone()
+ next.Model = h.model
+ return next, HookDecision{Action: HookActionModify}, nil
+}
+
+func (h *builtinAutoHook) AfterLLM(
+ ctx context.Context,
+ resp *LLMHookResponse,
+) (*LLMHookResponse, HookDecision, error) {
+ next := resp.Clone()
+ if next.Response != nil {
+ next.Response.Content += h.suffix
+ }
+ return next, HookDecision{Action: HookActionModify}, nil
+}
+
+func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
+ t.Helper()
+
+ cfg := &config.Config{
+ Agents: config.AgentsConfig{
+ Defaults: config.AgentDefaults{
+ Workspace: t.TempDir(),
+ Model: "test-model",
+ MaxTokens: 4096,
+ MaxToolIterations: 10,
+ },
+ },
+ Hooks: hooks,
+ }
+
+ return NewAgentLoop(cfg, bus.NewMessageBus(), provider)
+}
+
+func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
+ const hookName = "test-auto-builtin-hook"
+
+ if err := RegisterBuiltinHook(hookName, func(
+ ctx context.Context,
+ spec config.BuiltinHookConfig,
+ ) (any, error) {
+ var hookCfg builtinAutoHookConfig
+ if len(spec.Config) > 0 {
+ if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
+ return nil, err
+ }
+ }
+ return &builtinAutoHook{
+ model: hookCfg.Model,
+ suffix: hookCfg.Suffix,
+ }, nil
+ }); err != nil {
+ t.Fatalf("RegisterBuiltinHook failed: %v", err)
+ }
+ t.Cleanup(func() {
+ unregisterBuiltinHook(hookName)
+ })
+
+ rawCfg, err := json.Marshal(builtinAutoHookConfig{
+ Model: "builtin-model",
+ Suffix: "|builtin",
+ })
+ if err != nil {
+ t.Fatalf("json.Marshal failed: %v", err)
+ }
+
+ provider := &llmHookTestProvider{}
+ al := newConfiguredHookLoop(t, provider, config.HooksConfig{
+ Enabled: true,
+ Builtins: map[string]config.BuiltinHookConfig{
+ hookName: {
+ Enabled: true,
+ Config: rawCfg,
+ },
+ },
+ })
+ defer al.Close()
+
+ resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
+ if err != nil {
+ t.Fatalf("ProcessDirectWithChannel failed: %v", err)
+ }
+ if resp != "provider content|builtin" {
+ t.Fatalf("expected builtin-hooked content, got %q", resp)
+ }
+
+ provider.mu.Lock()
+ lastModel := provider.lastModel
+ provider.mu.Unlock()
+ if lastModel != "builtin-model" {
+ t.Fatalf("expected builtin model, got %q", lastModel)
+ }
+}
+
+func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
+ provider := &llmHookTestProvider{}
+ eventLog := filepath.Join(t.TempDir(), "events.log")
+
+ al := newConfiguredHookLoop(t, provider, config.HooksConfig{
+ Enabled: true,
+ Processes: map[string]config.ProcessHookConfig{
+ "ipc-auto": {
+ Enabled: true,
+ Command: processHookHelperCommand(),
+ Env: map[string]string{
+ "PICOCLAW_HOOK_HELPER": "1",
+ "PICOCLAW_HOOK_MODE": "rewrite",
+ "PICOCLAW_HOOK_EVENT_LOG": eventLog,
+ },
+ Observe: []string{"turn_end"},
+ Intercept: []string{"before_llm", "after_llm"},
+ },
+ },
+ })
+ defer al.Close()
+
+ resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
+ if err != nil {
+ t.Fatalf("ProcessDirectWithChannel failed: %v", err)
+ }
+ if resp != "provider content|ipc" {
+ t.Fatalf("expected process-hooked content, got %q", resp)
+ }
+
+ provider.mu.Lock()
+ lastModel := provider.lastModel
+ provider.mu.Unlock()
+ if lastModel != "process-model" {
+ t.Fatalf("expected process model, got %q", lastModel)
+ }
+
+ waitForFileContains(t, eventLog, "turn_end")
+}
+
+func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
+ provider := &llmHookTestProvider{}
+ al := newConfiguredHookLoop(t, provider, config.HooksConfig{
+ Enabled: true,
+ Processes: map[string]config.ProcessHookConfig{
+ "bad-hook": {
+ Enabled: true,
+ Command: processHookHelperCommand(),
+ Intercept: []string{"not_supported"},
+ },
+ },
+ })
+ defer al.Close()
+
+ _, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
+ if err == nil {
+ t.Fatal("expected invalid configured hook error")
+ }
+}
diff --git a/pkg/agent/hook_process.go b/pkg/agent/hook_process.go
new file mode 100644
index 000000000..e5632913d
--- /dev/null
+++ b/pkg/agent/hook_process.go
@@ -0,0 +1,511 @@
+package agent
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+const (
+ processHookJSONRPCVersion = "2.0"
+ processHookReadBufferSize = 1024 * 1024
+ processHookCloseTimeout = 2 * time.Second
+)
+
+type ProcessHookOptions struct {
+ Command []string
+ Dir string
+ Env []string
+ Observe bool
+ ObserveKinds []string
+ InterceptLLM bool
+ InterceptTool bool
+ ApproveTool bool
+}
+
+type ProcessHook struct {
+ name string
+ opts ProcessHookOptions
+
+ cmd *exec.Cmd
+ stdin io.WriteCloser
+ observeKinds map[string]struct{}
+
+ writeMu sync.Mutex
+
+ pendingMu sync.Mutex
+ pending map[uint64]chan processHookRPCMessage
+ nextID atomic.Uint64
+
+ closed atomic.Bool
+ done chan struct{}
+ closeErr error
+ closeMu sync.Mutex
+ closeOnce sync.Once
+}
+
+type processHookRPCMessage struct {
+ JSONRPC string `json:"jsonrpc,omitempty"`
+ ID uint64 `json:"id,omitempty"`
+ Method string `json:"method,omitempty"`
+ Params json.RawMessage `json:"params,omitempty"`
+ Result json.RawMessage `json:"result,omitempty"`
+ Error *processHookRPCError `json:"error,omitempty"`
+}
+
+type processHookRPCError struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+}
+
+type processHookHelloParams struct {
+ Name string `json:"name"`
+ Version int `json:"version"`
+ Modes []string `json:"modes,omitempty"`
+}
+
+type processHookDecisionResponse struct {
+ Action HookAction `json:"action"`
+ Reason string `json:"reason,omitempty"`
+}
+
+type processHookBeforeLLMResponse struct {
+ processHookDecisionResponse
+ Request *LLMHookRequest `json:"request,omitempty"`
+}
+
+type processHookAfterLLMResponse struct {
+ processHookDecisionResponse
+ Response *LLMHookResponse `json:"response,omitempty"`
+}
+
+type processHookBeforeToolResponse struct {
+ processHookDecisionResponse
+ Call *ToolCallHookRequest `json:"call,omitempty"`
+}
+
+type processHookAfterToolResponse struct {
+ processHookDecisionResponse
+ Result *ToolResultHookResponse `json:"result,omitempty"`
+}
+
+func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
+ if len(opts.Command) == 0 {
+ return nil, fmt.Errorf("process hook command is required")
+ }
+
+ cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
+ cmd.Dir = opts.Dir
+ if len(opts.Env) > 0 {
+ cmd.Env = append(os.Environ(), opts.Env...)
+ }
+ stdin, err := cmd.StdinPipe()
+ if err != nil {
+ return nil, fmt.Errorf("create process hook stdin: %w", err)
+ }
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ return nil, fmt.Errorf("create process hook stdout: %w", err)
+ }
+ stderr, err := cmd.StderrPipe()
+ if err != nil {
+ return nil, fmt.Errorf("create process hook stderr: %w", err)
+ }
+ if err := cmd.Start(); err != nil {
+ return nil, fmt.Errorf("start process hook: %w", err)
+ }
+
+ ph := &ProcessHook{
+ name: name,
+ opts: opts,
+ cmd: cmd,
+ stdin: stdin,
+ observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
+ pending: make(map[uint64]chan processHookRPCMessage),
+ done: make(chan struct{}),
+ }
+
+ go ph.readLoop(stdout)
+ go ph.readStderr(stderr)
+ go ph.waitLoop()
+
+ helloCtx := ctx
+ if helloCtx == nil {
+ var cancel context.CancelFunc
+ helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ }
+ if err := ph.hello(helloCtx); err != nil {
+ _ = ph.Close()
+ return nil, err
+ }
+
+ return ph, nil
+}
+
+func (ph *ProcessHook) Close() error {
+ if ph == nil {
+ return nil
+ }
+
+ ph.closeOnce.Do(func() {
+ ph.closed.Store(true)
+ if ph.stdin != nil {
+ _ = ph.stdin.Close()
+ }
+
+ select {
+ case <-ph.done:
+ case <-time.After(processHookCloseTimeout):
+ if ph.cmd != nil && ph.cmd.Process != nil {
+ _ = ph.cmd.Process.Kill()
+ }
+ <-ph.done
+ }
+ })
+
+ ph.closeMu.Lock()
+ defer ph.closeMu.Unlock()
+ return ph.closeErr
+}
+
+func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
+ if ph == nil || !ph.opts.Observe {
+ return nil
+ }
+ if len(ph.observeKinds) > 0 {
+ if _, ok := ph.observeKinds[evt.Kind.String()]; !ok {
+ return nil
+ }
+ }
+ return ph.notify(ctx, "hook.event", evt)
+}
+
+func (ph *ProcessHook) BeforeLLM(
+ ctx context.Context,
+ req *LLMHookRequest,
+) (*LLMHookRequest, HookDecision, error) {
+ if ph == nil || !ph.opts.InterceptLLM {
+ return req, HookDecision{Action: HookActionContinue}, nil
+ }
+
+ var resp processHookBeforeLLMResponse
+ if err := ph.call(ctx, "hook.before_llm", req, &resp); err != nil {
+ return nil, HookDecision{}, err
+ }
+ if resp.Request == nil {
+ resp.Request = req
+ }
+ return resp.Request, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
+}
+
+func (ph *ProcessHook) AfterLLM(
+ ctx context.Context,
+ resp *LLMHookResponse,
+) (*LLMHookResponse, HookDecision, error) {
+ if ph == nil || !ph.opts.InterceptLLM {
+ return resp, HookDecision{Action: HookActionContinue}, nil
+ }
+
+ var result processHookAfterLLMResponse
+ if err := ph.call(ctx, "hook.after_llm", resp, &result); err != nil {
+ return nil, HookDecision{}, err
+ }
+ if result.Response == nil {
+ result.Response = resp
+ }
+ return result.Response, HookDecision{Action: result.Action, Reason: result.Reason}, nil
+}
+
+func (ph *ProcessHook) BeforeTool(
+ ctx context.Context,
+ call *ToolCallHookRequest,
+) (*ToolCallHookRequest, HookDecision, error) {
+ if ph == nil || !ph.opts.InterceptTool {
+ return call, HookDecision{Action: HookActionContinue}, nil
+ }
+
+ var resp processHookBeforeToolResponse
+ if err := ph.call(ctx, "hook.before_tool", call, &resp); err != nil {
+ return nil, HookDecision{}, err
+ }
+ if resp.Call == nil {
+ resp.Call = call
+ }
+ return resp.Call, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
+}
+
+func (ph *ProcessHook) AfterTool(
+ ctx context.Context,
+ result *ToolResultHookResponse,
+) (*ToolResultHookResponse, HookDecision, error) {
+ if ph == nil || !ph.opts.InterceptTool {
+ return result, HookDecision{Action: HookActionContinue}, nil
+ }
+
+ var resp processHookAfterToolResponse
+ if err := ph.call(ctx, "hook.after_tool", result, &resp); err != nil {
+ return nil, HookDecision{}, err
+ }
+ if resp.Result == nil {
+ resp.Result = result
+ }
+ return resp.Result, HookDecision{Action: resp.Action, Reason: resp.Reason}, nil
+}
+
+func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
+ if ph == nil || !ph.opts.ApproveTool {
+ return ApprovalDecision{Approved: true}, nil
+ }
+
+ var resp ApprovalDecision
+ if err := ph.call(ctx, "hook.approve_tool", req, &resp); err != nil {
+ return ApprovalDecision{}, err
+ }
+ return resp, nil
+}
+
+func (ph *ProcessHook) hello(ctx context.Context) error {
+ modes := make([]string, 0, 4)
+ if ph.opts.Observe {
+ modes = append(modes, "observe")
+ }
+ if ph.opts.InterceptLLM {
+ modes = append(modes, "llm")
+ }
+ if ph.opts.InterceptTool {
+ modes = append(modes, "tool")
+ }
+ if ph.opts.ApproveTool {
+ modes = append(modes, "approve")
+ }
+
+ var result map[string]any
+ return ph.call(ctx, "hook.hello", processHookHelloParams{
+ Name: ph.name,
+ Version: 1,
+ Modes: modes,
+ }, &result)
+}
+
+func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
+ msg := processHookRPCMessage{
+ JSONRPC: processHookJSONRPCVersion,
+ Method: method,
+ }
+ if params != nil {
+ body, err := json.Marshal(params)
+ if err != nil {
+ return err
+ }
+ msg.Params = body
+ }
+ return ph.send(ctx, msg)
+}
+
+func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
+ if ph.closed.Load() {
+ return fmt.Errorf("process hook %q is closed", ph.name)
+ }
+
+ id := ph.nextID.Add(1)
+ respCh := make(chan processHookRPCMessage, 1)
+ ph.pendingMu.Lock()
+ ph.pending[id] = respCh
+ ph.pendingMu.Unlock()
+
+ msg := processHookRPCMessage{
+ JSONRPC: processHookJSONRPCVersion,
+ ID: id,
+ Method: method,
+ }
+ if params != nil {
+ body, err := json.Marshal(params)
+ if err != nil {
+ ph.removePending(id)
+ return err
+ }
+ msg.Params = body
+ }
+
+ if err := ph.send(ctx, msg); err != nil {
+ ph.removePending(id)
+ return err
+ }
+
+ select {
+ case resp, ok := <-respCh:
+ if !ok {
+ return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
+ }
+ if resp.Error != nil {
+ return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
+ }
+ if out != nil && len(resp.Result) > 0 {
+ if err := json.Unmarshal(resp.Result, out); err != nil {
+ return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
+ }
+ }
+ return nil
+ case <-ctx.Done():
+ ph.removePending(id)
+ return ctx.Err()
+ }
+}
+
+func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
+ body, err := json.Marshal(msg)
+ if err != nil {
+ return err
+ }
+ body = append(body, '\n')
+
+ ph.writeMu.Lock()
+ defer ph.writeMu.Unlock()
+
+ if ph.closed.Load() {
+ return fmt.Errorf("process hook %q is closed", ph.name)
+ }
+
+ done := make(chan error, 1)
+ go func() {
+ _, writeErr := ph.stdin.Write(body)
+ done <- writeErr
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ return fmt.Errorf("write process hook %q message: %w", ph.name, err)
+ }
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+func (ph *ProcessHook) readLoop(stdout io.Reader) {
+ scanner := bufio.NewScanner(stdout)
+ scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
+
+ for scanner.Scan() {
+ var msg processHookRPCMessage
+ if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
+ logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
+ "hook": ph.name,
+ "error": err.Error(),
+ })
+ continue
+ }
+ if msg.ID == 0 {
+ continue
+ }
+ ph.pendingMu.Lock()
+ respCh, ok := ph.pending[msg.ID]
+ if ok {
+ delete(ph.pending, msg.ID)
+ }
+ ph.pendingMu.Unlock()
+ if ok {
+ respCh <- msg
+ close(respCh)
+ }
+ }
+}
+
+func (ph *ProcessHook) readStderr(stderr io.Reader) {
+ scanner := bufio.NewScanner(stderr)
+ scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
+ for scanner.Scan() {
+ logger.WarnCF("hooks", "Process hook stderr", map[string]any{
+ "hook": ph.name,
+ "stderr": scanner.Text(),
+ })
+ }
+}
+
+func (ph *ProcessHook) waitLoop() {
+ err := ph.cmd.Wait()
+ ph.closeMu.Lock()
+ ph.closeErr = err
+ ph.closeMu.Unlock()
+ ph.failPending(err)
+ close(ph.done)
+}
+
+func (ph *ProcessHook) failPending(err error) {
+ ph.pendingMu.Lock()
+ defer ph.pendingMu.Unlock()
+
+ msg := processHookRPCMessage{
+ Error: &processHookRPCError{
+ Code: -32000,
+ Message: "process exited",
+ },
+ }
+ if err != nil {
+ msg.Error.Message = err.Error()
+ }
+
+ for id, ch := range ph.pending {
+ delete(ph.pending, id)
+ ch <- msg
+ close(ch)
+ }
+}
+
+func (ph *ProcessHook) removePending(id uint64) {
+ ph.pendingMu.Lock()
+ defer ph.pendingMu.Unlock()
+
+ if ch, ok := ph.pending[id]; ok {
+ delete(ph.pending, id)
+ close(ch)
+ }
+}
+
+func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
+ if al == nil {
+ return fmt.Errorf("agent loop is nil")
+ }
+ processHook, err := NewProcessHook(ctx, name, opts)
+ if err != nil {
+ return err
+ }
+ if err := al.MountHook(HookRegistration{
+ Name: name,
+ Source: HookSourceProcess,
+ Hook: processHook,
+ }); err != nil {
+ _ = processHook.Close()
+ return err
+ }
+ return nil
+}
+
+func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
+ if len(kinds) == 0 {
+ return nil
+ }
+
+ normalized := make(map[string]struct{}, len(kinds))
+ for _, kind := range kinds {
+ if kind == "" {
+ continue
+ }
+ normalized[kind] = struct{}{}
+ }
+ if len(normalized) == 0 {
+ return nil
+ }
+ return normalized
+}
diff --git a/pkg/agent/hook_process_test.go b/pkg/agent/hook_process_test.go
new file mode 100644
index 000000000..50f89811f
--- /dev/null
+++ b/pkg/agent/hook_process_test.go
@@ -0,0 +1,339 @@
+package agent
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/providers"
+)
+
+func TestProcessHook_HelperProcess(t *testing.T) {
+ if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
+ return
+ }
+ if err := runProcessHookHelper(); err != nil {
+ fmt.Fprintln(os.Stderr, err.Error())
+ os.Exit(1)
+ }
+ os.Exit(0)
+}
+
+func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
+ provider := &llmHookTestProvider{}
+ al, agent, cleanup := newHookTestLoop(t, provider)
+ defer cleanup()
+
+ eventLog := filepath.Join(t.TempDir(), "events.log")
+ if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{
+ Command: processHookHelperCommand(),
+ Env: processHookHelperEnv("rewrite", eventLog),
+ Observe: true,
+ InterceptLLM: true,
+ }); err != nil {
+ t.Fatalf("MountProcessHook failed: %v", err)
+ }
+
+ resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "hello",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if resp != "provider content|ipc" {
+ t.Fatalf("expected process-hooked llm content, got %q", resp)
+ }
+
+ provider.mu.Lock()
+ lastModel := provider.lastModel
+ provider.mu.Unlock()
+ if lastModel != "process-model" {
+ t.Fatalf("expected process model, got %q", lastModel)
+ }
+
+ waitForFileContains(t, eventLog, "turn_end")
+}
+
+func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
+ provider := &toolHookProvider{}
+ al, agent, cleanup := newHookTestLoop(t, provider)
+ defer cleanup()
+
+ al.RegisterTool(&echoTextTool{})
+ if err := al.MountProcessHook(context.Background(), "ipc-tool", ProcessHookOptions{
+ Command: processHookHelperCommand(),
+ Env: processHookHelperEnv("rewrite", ""),
+ InterceptTool: true,
+ }); err != nil {
+ t.Fatalf("MountProcessHook failed: %v", err)
+ }
+
+ resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "run tool",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+ if resp != "ipc:ipc" {
+ t.Fatalf("expected rewritten process-hook tool result, got %q", resp)
+ }
+}
+
+type blockedToolProvider struct {
+ calls int
+}
+
+func (p *blockedToolProvider) Chat(
+ ctx context.Context,
+ messages []providers.Message,
+ tools []providers.ToolDefinition,
+ model string,
+ opts map[string]any,
+) (*providers.LLMResponse, error) {
+ p.calls++
+ if p.calls == 1 {
+ return &providers.LLMResponse{
+ ToolCalls: []providers.ToolCall{
+ {
+ ID: "call-1",
+ Name: "blocked_tool",
+ Arguments: map[string]any{},
+ },
+ },
+ }, nil
+ }
+
+ return &providers.LLMResponse{
+ Content: messages[len(messages)-1].Content,
+ }, nil
+}
+
+func (p *blockedToolProvider) GetDefaultModel() string {
+ return "blocked-tool-provider"
+}
+
+func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
+ provider := &blockedToolProvider{}
+ al, agent, cleanup := newHookTestLoop(t, provider)
+ defer cleanup()
+
+ if err := al.MountProcessHook(context.Background(), "ipc-approval", ProcessHookOptions{
+ Command: processHookHelperCommand(),
+ Env: processHookHelperEnv("deny", ""),
+ ApproveTool: true,
+ }); err != nil {
+ t.Fatalf("MountProcessHook failed: %v", err)
+ }
+
+ sub := al.SubscribeEvents(16)
+ defer al.UnsubscribeEvents(sub.ID)
+
+ resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
+ SessionKey: "session-1",
+ Channel: "cli",
+ ChatID: "direct",
+ UserMessage: "run blocked tool",
+ DefaultResponse: defaultResponse,
+ EnableSummary: false,
+ SendResponse: false,
+ })
+ if err != nil {
+ t.Fatalf("runAgentLoop failed: %v", err)
+ }
+
+ expected := "Tool execution denied by approval hook: blocked by ipc hook"
+ if resp != expected {
+ t.Fatalf("expected %q, got %q", expected, resp)
+ }
+
+ events := collectEventStream(sub.C)
+ skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
+ if !ok {
+ t.Fatal("expected tool skipped event")
+ }
+ payload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
+ if !ok {
+ t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
+ }
+ if payload.Reason != expected {
+ t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
+ }
+}
+
+func processHookHelperCommand() []string {
+ return []string{os.Args[0], "-test.run=TestProcessHook_HelperProcess", "--"}
+}
+
+func processHookHelperEnv(mode, eventLog string) []string {
+ env := []string{
+ "PICOCLAW_HOOK_HELPER=1",
+ "PICOCLAW_HOOK_MODE=" + mode,
+ }
+ if eventLog != "" {
+ env = append(env, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
+ }
+ return env
+}
+
+func waitForFileContains(t *testing.T, path, substring string) {
+ t.Helper()
+
+ deadline := time.Now().Add(3 * time.Second)
+ for time.Now().Before(deadline) {
+ data, err := os.ReadFile(path)
+ if err == nil && strings.Contains(string(data), substring) {
+ return
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+
+ data, _ := os.ReadFile(path)
+ t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(data))
+}
+
+func runProcessHookHelper() error {
+ mode := os.Getenv("PICOCLAW_HOOK_MODE")
+ eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")
+
+ scanner := bufio.NewScanner(os.Stdin)
+ scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
+ encoder := json.NewEncoder(os.Stdout)
+
+ for scanner.Scan() {
+ var msg processHookRPCMessage
+ if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
+ return err
+ }
+
+ if msg.ID == 0 {
+ if msg.Method == "hook.event" && eventLog != "" {
+ var evt map[string]any
+ if err := json.Unmarshal(msg.Params, &evt); err == nil {
+ if rawKind, ok := evt["Kind"].(float64); ok {
+ kind := EventKind(rawKind)
+ _ = os.WriteFile(eventLog, []byte(kind.String()+"\n"), 0o644)
+ }
+ }
+ }
+ continue
+ }
+
+ result, rpcErr := handleProcessHookRequest(mode, msg)
+ resp := processHookRPCMessage{
+ JSONRPC: processHookJSONRPCVersion,
+ ID: msg.ID,
+ }
+ if rpcErr != nil {
+ resp.Error = rpcErr
+ } else if result != nil {
+ body, err := json.Marshal(result)
+ if err != nil {
+ return err
+ }
+ resp.Result = body
+ } else {
+ resp.Result = []byte("{}")
+ }
+
+ if err := encoder.Encode(resp); err != nil {
+ return err
+ }
+ }
+
+ return scanner.Err()
+}
+
+func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
+ switch msg.Method {
+ case "hook.hello":
+ return map[string]any{"ok": true}, nil
+ case "hook.before_llm":
+ if mode != "rewrite" {
+ return map[string]any{"action": HookActionContinue}, nil
+ }
+ var req map[string]any
+ _ = json.Unmarshal(msg.Params, &req)
+ req["model"] = "process-model"
+ return map[string]any{
+ "action": HookActionModify,
+ "request": req,
+ }, nil
+ case "hook.after_llm":
+ if mode != "rewrite" {
+ return map[string]any{"action": HookActionContinue}, nil
+ }
+ var resp map[string]any
+ _ = json.Unmarshal(msg.Params, &resp)
+ if rawResponse, ok := resp["response"].(map[string]any); ok {
+ if content, ok := rawResponse["content"].(string); ok {
+ rawResponse["content"] = content + "|ipc"
+ }
+ }
+ return map[string]any{
+ "action": HookActionModify,
+ "response": resp,
+ }, nil
+ case "hook.before_tool":
+ if mode != "rewrite" {
+ return map[string]any{"action": HookActionContinue}, nil
+ }
+ var call map[string]any
+ _ = json.Unmarshal(msg.Params, &call)
+ rawArgs, ok := call["arguments"].(map[string]any)
+ if !ok || rawArgs == nil {
+ rawArgs = map[string]any{}
+ }
+ rawArgs["text"] = "ipc"
+ call["arguments"] = rawArgs
+ return map[string]any{
+ "action": HookActionModify,
+ "call": call,
+ }, nil
+ case "hook.after_tool":
+ if mode != "rewrite" {
+ return map[string]any{"action": HookActionContinue}, nil
+ }
+ var result map[string]any
+ _ = json.Unmarshal(msg.Params, &result)
+ if rawResult, ok := result["result"].(map[string]any); ok {
+ if forLLM, ok := rawResult["for_llm"].(string); ok {
+ rawResult["for_llm"] = "ipc:" + forLLM
+ }
+ }
+ return map[string]any{
+ "action": HookActionModify,
+ "result": result,
+ }, nil
+ case "hook.approve_tool":
+ if mode == "deny" {
+ return ApprovalDecision{
+ Approved: false,
+ Reason: "blocked by ipc hook",
+ }, nil
+ }
+ return ApprovalDecision{Approved: true}, nil
+ default:
+ return nil, &processHookRPCError{
+ Code: -32601,
+ Message: "method not found",
+ }
+ }
+}
diff --git a/pkg/agent/hooks.go b/pkg/agent/hooks.go
index 74af542fa..c1ef58ffd 100644
--- a/pkg/agent/hooks.go
+++ b/pkg/agent/hooks.go
@@ -3,6 +3,7 @@ package agent
import (
"context"
"fmt"
+ "io"
"sort"
"sync"
"time"
@@ -30,8 +31,8 @@ const (
)
type HookDecision struct {
- Action HookAction
- Reason string
+ Action HookAction `json:"action"`
+ Reason string `json:"reason,omitempty"`
}
func (d HookDecision) normalizedAction() HookAction {
@@ -42,20 +43,29 @@ func (d HookDecision) normalizedAction() HookAction {
}
type ApprovalDecision struct {
- Approved bool
- Reason string
+ Approved bool `json:"approved"`
+ Reason string `json:"reason,omitempty"`
}
+type HookSource uint8
+
+const (
+ HookSourceInProcess HookSource = iota
+ HookSourceProcess
+)
+
type HookRegistration struct {
Name string
Priority int
+ Source HookSource
Hook any
}
func NamedHook(name string, hook any) HookRegistration {
return HookRegistration{
- Name: name,
- Hook: hook,
+ Name: name,
+ Source: HookSourceInProcess,
+ Hook: hook,
}
}
@@ -78,14 +88,14 @@ type ToolApprover interface {
}
type LLMHookRequest struct {
- Meta EventMeta
- Model string
- Messages []providers.Message
- Tools []providers.ToolDefinition
- Options map[string]any
- Channel string
- ChatID string
- GracefulTerminal bool
+ Meta EventMeta `json:"meta"`
+ Model string `json:"model"`
+ Messages []providers.Message `json:"messages,omitempty"`
+ Tools []providers.ToolDefinition `json:"tools,omitempty"`
+ Options map[string]any `json:"options,omitempty"`
+ Channel string `json:"channel,omitempty"`
+ ChatID string `json:"chat_id,omitempty"`
+ GracefulTerminal bool `json:"graceful_terminal,omitempty"`
}
func (r *LLMHookRequest) Clone() *LLMHookRequest {
@@ -100,11 +110,11 @@ func (r *LLMHookRequest) Clone() *LLMHookRequest {
}
type LLMHookResponse struct {
- Meta EventMeta
- Model string
- Response *providers.LLMResponse
- Channel string
- ChatID string
+ Meta EventMeta `json:"meta"`
+ Model string `json:"model"`
+ Response *providers.LLMResponse `json:"response,omitempty"`
+ Channel string `json:"channel,omitempty"`
+ ChatID string `json:"chat_id,omitempty"`
}
func (r *LLMHookResponse) Clone() *LLMHookResponse {
@@ -117,11 +127,11 @@ func (r *LLMHookResponse) Clone() *LLMHookResponse {
}
type ToolCallHookRequest struct {
- Meta EventMeta
- Tool string
- Arguments map[string]any
- Channel string
- ChatID string
+ Meta EventMeta `json:"meta"`
+ Tool string `json:"tool"`
+ Arguments map[string]any `json:"arguments,omitempty"`
+ Channel string `json:"channel,omitempty"`
+ ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
@@ -134,11 +144,11 @@ func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
}
type ToolApprovalRequest struct {
- Meta EventMeta
- Tool string
- Arguments map[string]any
- Channel string
- ChatID string
+ Meta EventMeta `json:"meta"`
+ Tool string `json:"tool"`
+ Arguments map[string]any `json:"arguments,omitempty"`
+ Channel string `json:"channel,omitempty"`
+ ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
@@ -151,13 +161,13 @@ func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
}
type ToolResultHookResponse struct {
- Meta EventMeta
- Tool string
- Arguments map[string]any
- Result *tools.ToolResult
- Duration time.Duration
- Channel string
- ChatID string
+ Meta EventMeta `json:"meta"`
+ Tool string `json:"tool"`
+ Arguments map[string]any `json:"arguments,omitempty"`
+ Result *tools.ToolResult `json:"result,omitempty"`
+ Duration time.Duration `json:"duration"`
+ Channel string `json:"channel,omitempty"`
+ ChatID string `json:"chat_id,omitempty"`
}
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
@@ -215,9 +225,25 @@ func (hm *HookManager) Close() {
hm.eventBus.Unsubscribe(hm.sub.ID)
}
<-hm.done
+ hm.closeAllHooks()
})
}
+func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
+ if hm == nil {
+ return
+ }
+ if observer > 0 {
+ hm.observerTimeout = observer
+ }
+ if interceptor > 0 {
+ hm.interceptorTimeout = interceptor
+ }
+ if approval > 0 {
+ hm.approvalTimeout = approval
+ }
+}
+
func (hm *HookManager) Mount(reg HookRegistration) error {
if hm == nil {
return fmt.Errorf("hook manager is nil")
@@ -232,6 +258,9 @@ func (hm *HookManager) Mount(reg HookRegistration) error {
hm.mu.Lock()
defer hm.mu.Unlock()
+ if existing, ok := hm.hooks[reg.Name]; ok {
+ closeHookIfPossible(existing.Hook)
+ }
hm.hooks[reg.Name] = reg
hm.rebuildOrdered()
return nil
@@ -245,6 +274,9 @@ func (hm *HookManager) Unmount(name string) {
hm.mu.Lock()
defer hm.mu.Unlock()
+ if existing, ok := hm.hooks[name]; ok {
+ closeHookIfPossible(existing.Hook)
+ }
delete(hm.hooks, name)
hm.rebuildOrdered()
}
@@ -425,6 +457,9 @@ func (hm *HookManager) rebuildOrdered() {
hm.ordered = append(hm.ordered, reg)
}
sort.SliceStable(hm.ordered, func(i, j int) bool {
+ if hm.ordered[i].Source != hm.ordered[j].Source {
+ return hm.ordered[i].Source < hm.ordered[j].Source
+ }
if hm.ordered[i].Priority == hm.ordered[j].Priority {
return hm.ordered[i].Name < hm.ordered[j].Name
}
@@ -441,6 +476,17 @@ func (hm *HookManager) snapshotHooks() []HookRegistration {
return snapshot
}
+func (hm *HookManager) closeAllHooks() {
+ hm.mu.Lock()
+ defer hm.mu.Unlock()
+
+ for name, reg := range hm.hooks {
+ closeHookIfPossible(reg.Hook)
+ delete(hm.hooks, name)
+ }
+ hm.ordered = nil
+}
+
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
defer cancel()
@@ -749,3 +795,15 @@ func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
}
return &cloned
}
+
+func closeHookIfPossible(hook any) {
+ closer, ok := hook.(io.Closer)
+ if !ok {
+ return
+ }
+ if err := closer.Close(); err != nil {
+ logger.WarnCF("hooks", "Failed to close hook", map[string]any{
+ "error": err.Error(),
+ })
+ }
+}
diff --git a/pkg/agent/hooks_test.go b/pkg/agent/hooks_test.go
index 6607b5fe7..e6471e9cc 100644
--- a/pkg/agent/hooks_test.go
+++ b/pkg/agent/hooks_test.go
@@ -47,6 +47,39 @@ func newHookTestLoop(
}
}
+func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
+ hm := NewHookManager(nil)
+ defer hm.Close()
+
+ if err := hm.Mount(HookRegistration{
+ Name: "process",
+ Priority: -10,
+ Source: HookSourceProcess,
+ Hook: struct{}{},
+ }); err != nil {
+ t.Fatalf("mount process hook: %v", err)
+ }
+ if err := hm.Mount(HookRegistration{
+ Name: "in-process",
+ Priority: 100,
+ Source: HookSourceInProcess,
+ Hook: struct{}{},
+ }); err != nil {
+ t.Fatalf("mount in-process hook: %v", err)
+ }
+
+ ordered := hm.snapshotHooks()
+ if len(ordered) != 2 {
+ t.Fatalf("expected 2 hooks, got %d", len(ordered))
+ }
+ if ordered[0].Name != "in-process" {
+ t.Fatalf("expected in-process hook first, got %q", ordered[0].Name)
+ }
+ if ordered[1].Name != "process" {
+ t.Fatalf("expected process hook second, got %q", ordered[1].Name)
+ }
+}
+
type llmHookTestProvider struct {
mu sync.Mutex
lastModel string
diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go
index a85abcb60..41dfdff5f 100644
--- a/pkg/agent/loop.go
+++ b/pkg/agent/loop.go
@@ -49,6 +49,7 @@ type AgentLoop struct {
transcriber voice.Transcriber
cmdRegistry *commands.Registry
mcp mcpRuntime
+ hookRuntime hookRuntime
steering *steeringQueue
mu sync.RWMutex
activeTurnMu sync.RWMutex
@@ -122,6 +123,7 @@ func NewAgentLoop(
steering: newSteeringQueue(parseSteeringMode(cfg.Agents.Defaults.SteeringMode)),
}
al.hooks = NewHookManager(eventBus)
+ configureHookManagerFromConfig(al.hooks, cfg)
return al
}
@@ -259,6 +261,9 @@ func registerSharedTools(
func (al *AgentLoop) Run(ctx context.Context) error {
al.running.Store(true)
+ if err := al.ensureHooksInitialized(ctx); err != nil {
+ return err
+ }
if err := al.ensureMCPInitialized(ctx); err != nil {
return err
}
@@ -773,6 +778,9 @@ func (al *AgentLoop) ReloadProviderAndConfig(
al.mu.Unlock()
+ al.hookRuntime.reset(al)
+ configureHookManagerFromConfig(al.hooks, cfg)
+
// Close old provider after releasing the lock
// This prevents blocking readers while closing
if oldProvider, ok := extractProvider(oldRegistry); ok {
@@ -987,6 +995,9 @@ func (al *AgentLoop) ProcessDirectWithChannel(
ctx context.Context,
content, sessionKey, channel, chatID string,
) (string, error) {
+ if err := al.ensureHooksInitialized(ctx); err != nil {
+ return "", err
+ }
if err := al.ensureMCPInitialized(ctx); err != nil {
return "", err
}
@@ -1008,6 +1019,13 @@ func (al *AgentLoop) ProcessHeartbeat(
ctx context.Context,
content, channel, chatID string,
) (string, error) {
+ if err := al.ensureHooksInitialized(ctx); err != nil {
+ return "", err
+ }
+ if err := al.ensureMCPInitialized(ctx); err != nil {
+ return "", err
+ }
+
agent := al.GetRegistry().GetDefaultAgent()
if agent == nil {
return "", fmt.Errorf("no default agent for heartbeat")
diff --git a/pkg/agent/steering.go b/pkg/agent/steering.go
index 77c2e0c17..55ee45ad1 100644
--- a/pkg/agent/steering.go
+++ b/pkg/agent/steering.go
@@ -183,6 +183,12 @@ func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID s
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
+ if err := al.ensureHooksInitialized(ctx); err != nil {
+ return "", err
+ }
+ if err := al.ensureMCPInitialized(ctx); err != nil {
+ return "", err
+ }
steeringMsgs := al.dequeueSteeringMessages()
if len(steeringMsgs) == 0 {
diff --git a/pkg/config/config.go b/pkg/config/config.go
index a3720b656..a7c44c825 100644
--- a/pkg/config/config.go
+++ b/pkg/config/config.go
@@ -82,6 +82,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers,omitempty"`
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway"`
+ Hooks HooksConfig `json:"hooks,omitempty"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
Devices DevicesConfig `json:"devices"`
@@ -90,6 +91,36 @@ type Config struct {
BuildInfo BuildInfo `json:"build_info,omitempty"`
}
+type HooksConfig struct {
+ Enabled bool `json:"enabled"`
+ Defaults HookDefaultsConfig `json:"defaults,omitempty"`
+ Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"`
+ Processes map[string]ProcessHookConfig `json:"processes,omitempty"`
+}
+
+type HookDefaultsConfig struct {
+ ObserverTimeoutMS int `json:"observer_timeout_ms,omitempty"`
+ InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"`
+ ApprovalTimeoutMS int `json:"approval_timeout_ms,omitempty"`
+}
+
+type BuiltinHookConfig struct {
+ Enabled bool `json:"enabled"`
+ Priority int `json:"priority,omitempty"`
+ Config json.RawMessage `json:"config,omitempty"`
+}
+
+type ProcessHookConfig struct {
+ Enabled bool `json:"enabled"`
+ Priority int `json:"priority,omitempty"`
+ Transport string `json:"transport,omitempty"`
+ Command []string `json:"command,omitempty"`
+ Dir string `json:"dir,omitempty"`
+ Env map[string]string `json:"env,omitempty"`
+ Observe []string `json:"observe,omitempty"`
+ Intercept []string `json:"intercept,omitempty"`
+}
+
// BuildInfo contains build-time version information
type BuildInfo struct {
Version string `json:"version"`
diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go
index c5bdbf3c3..caab8a152 100644
--- a/pkg/config/config_test.go
+++ b/pkg/config/config_test.go
@@ -391,6 +391,22 @@ func TestDefaultConfig_ExecAllowRemoteEnabled(t *testing.T) {
}
}
+func TestDefaultConfig_HooksDefaults(t *testing.T) {
+ cfg := DefaultConfig()
+ if !cfg.Hooks.Enabled {
+ t.Fatal("DefaultConfig().Hooks.Enabled should be true")
+ }
+ if cfg.Hooks.Defaults.ObserverTimeoutMS != 500 {
+ t.Fatalf("ObserverTimeoutMS = %d, want 500", cfg.Hooks.Defaults.ObserverTimeoutMS)
+ }
+ if cfg.Hooks.Defaults.InterceptorTimeoutMS != 5000 {
+ t.Fatalf("InterceptorTimeoutMS = %d, want 5000", cfg.Hooks.Defaults.InterceptorTimeoutMS)
+ }
+ if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
+ t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
+ }
+}
+
func TestLoadConfig_OpenAIWebSearchDefaultsTrueWhenUnset(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.json")
@@ -460,6 +476,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
}
+func TestLoadConfig_HooksProcessConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+ configPath := filepath.Join(tmpDir, "config.json")
+ configJSON := `{
+ "hooks": {
+ "processes": {
+ "review-gate": {
+ "enabled": true,
+ "transport": "stdio",
+ "command": ["uvx", "picoclaw-hook-reviewer"],
+ "dir": "/tmp/hooks",
+ "env": {
+ "HOOK_MODE": "rewrite"
+ },
+ "observe": ["turn_start", "turn_end"],
+ "intercept": ["before_tool", "approve_tool"]
+ }
+ },
+ "builtins": {
+ "audit": {
+ "enabled": true,
+ "priority": 5,
+ "config": {
+ "label": "audit"
+ }
+ }
+ }
+ }
+}`
+ if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
+ t.Fatalf("os.WriteFile() error: %v", err)
+ }
+
+ cfg, err := LoadConfig(configPath)
+ if err != nil {
+ t.Fatalf("LoadConfig() error: %v", err)
+ }
+
+ processCfg, ok := cfg.Hooks.Processes["review-gate"]
+ if !ok {
+ t.Fatal("expected review-gate process hook")
+ }
+ if !processCfg.Enabled {
+ t.Fatal("expected review-gate process hook to be enabled")
+ }
+ if processCfg.Transport != "stdio" {
+ t.Fatalf("Transport = %q, want stdio", processCfg.Transport)
+ }
+ if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" {
+ t.Fatalf("Command = %v", processCfg.Command)
+ }
+ if processCfg.Dir != "/tmp/hooks" {
+ t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir)
+ }
+ if processCfg.Env["HOOK_MODE"] != "rewrite" {
+ t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"])
+ }
+ if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" {
+ t.Fatalf("Observe = %v", processCfg.Observe)
+ }
+ if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" {
+ t.Fatalf("Intercept = %v", processCfg.Intercept)
+ }
+
+ builtinCfg, ok := cfg.Hooks.Builtins["audit"]
+ if !ok {
+ t.Fatal("expected audit builtin hook")
+ }
+ if !builtinCfg.Enabled {
+ t.Fatal("expected audit builtin hook to be enabled")
+ }
+ if builtinCfg.Priority != 5 {
+ t.Fatalf("Priority = %d, want 5", builtinCfg.Priority)
+ }
+ if !strings.Contains(string(builtinCfg.Config), `"audit"`) {
+ t.Fatalf("Config = %s", string(builtinCfg.Config))
+ }
+ if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
+ t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
+ }
+}
+
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go
index 5e6b89a4c..bfb54fb97 100644
--- a/pkg/config/defaults.go
+++ b/pkg/config/defaults.go
@@ -177,6 +177,14 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
},
+ Hooks: HooksConfig{
+ Enabled: true,
+ Defaults: HookDefaultsConfig{
+ ObserverTimeoutMS: 500,
+ InterceptorTimeoutMS: 5000,
+ ApprovalTimeoutMS: 60000,
+ },
+ },
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{WebSearch: true},
},
From 9978c9550bc03f70e17dbbac5256263cc7fd1fed Mon Sep 17 00:00:00 2001
From: Hoshina
Date: Sat, 21 Mar 2026 23:18:29 +0800
Subject: [PATCH 26/26] docs(hooks): inline and translate hook examples
---
config/config.example.json | 8 +
docs/hooks/README.md | 679 +++++++++++++++++++++++++++++++++++++
docs/hooks/README.zh.md | 679 +++++++++++++++++++++++++++++++++++++
3 files changed, 1366 insertions(+)
create mode 100644 docs/hooks/README.md
create mode 100644 docs/hooks/README.zh.md
diff --git a/config/config.example.json b/config/config.example.json
index 20c10e60d..3c149c744 100644
--- a/config/config.example.json
+++ b/config/config.example.json
@@ -511,6 +511,14 @@
"voice": {
"echo_transcription": false
},
+ "hooks": {
+ "enabled": true,
+ "defaults": {
+ "observer_timeout_ms": 500,
+ "interceptor_timeout_ms": 5000,
+ "approval_timeout_ms": 60000
+ }
+ },
"gateway": {
"host": "127.0.0.1",
"port": 18790
diff --git a/docs/hooks/README.md b/docs/hooks/README.md
new file mode 100644
index 000000000..ec3bbc46a
--- /dev/null
+++ b/docs/hooks/README.md
@@ -0,0 +1,679 @@
+# Hook System Guide
+
+This document describes the hook system that is implemented in the current repository, not the older design draft.
+
+The current implementation supports two mounting modes:
+
+1. In-process hooks
+2. Out-of-process process hooks (`JSON-RPC over stdio`)
+
+The repository no longer ships standalone example source files. The Go and Python examples below are embedded directly in this document. If you want to use them, copy them into your own local files first.
+
+## Supported Hook Types
+
+| Type | Interface | Stage | Can modify data |
+| --- | --- | --- | --- |
+| Observer | `EventObserver` | EventBus broadcast | No |
+| LLM interceptor | `LLMInterceptor` | `before_llm` / `after_llm` | Yes |
+| Tool interceptor | `ToolInterceptor` | `before_tool` / `after_tool` | Yes |
+| Tool approver | `ToolApprover` | `approve_tool` | No, returns allow/deny |
+
+The currently exposed synchronous hook points are:
+
+- `before_llm`
+- `after_llm`
+- `before_tool`
+- `after_tool`
+- `approve_tool`
+
+Everything else is exposed as read-only events.
+
+## Execution Order
+
+`HookManager` sorts hooks like this:
+
+1. In-process hooks first
+2. Process hooks second
+3. Lower `priority` first within the same source
+4. Name order as the final tie-breaker
+
+## Timeouts
+
+Global defaults live under `hooks.defaults`:
+
+- `observer_timeout_ms`
+- `interceptor_timeout_ms`
+- `approval_timeout_ms`
+
+Note: the current implementation does not support per-process-hook `timeout_ms`. Timeouts are global defaults.
+
+## Quick Start
+
+If your first goal is simply to prove that the hook flow works and observe real requests, the easiest path is the Python process-hook example below:
+
+1. Enable `hooks.enabled`
+2. Save the Python example from this document to a local file, for example `/tmp/review_gate.py`
+3. Set `PICOCLAW_HOOK_LOG_FILE`
+4. Restart the gateway
+5. Watch the log file with `tail -f`
+
+Example:
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "processes": {
+ "py_review_gate": {
+ "enabled": true,
+ "priority": 100,
+ "transport": "stdio",
+ "command": [
+ "python3",
+ "/tmp/review_gate.py"
+ ],
+ "observe": [
+ "tool_exec_start",
+ "tool_exec_end",
+ "tool_exec_skipped"
+ ],
+ "intercept": [
+ "before_tool",
+ "approve_tool"
+ ],
+ "env": {
+ "PICOCLAW_HOOK_LOG_FILE": "/tmp/picoclaw-hook-review-gate.log"
+ }
+ }
+ }
+ }
+}
+```
+
+Watch it with:
+
+```bash
+tail -f /tmp/picoclaw-hook-review-gate.log
+```
+
+If you are developing PicoClaw itself rather than only validating the protocol, continue with the Go in-process example as well.
+
+## What The Two Examples Are For
+
+- Go in-process example
+ Best for validating the host-side hook chain and understanding `MountHook()` plus the synchronous stages
+- Python process example
+ Best for understanding the `JSON-RPC over stdio` protocol and verifying the message flow between PicoClaw and an external process
+
+Both examples are intentionally safe: they only log, never rewrite, and never deny.
+
+## Go In-Process Example
+
+The following is a minimal logging hook for in-process use. It implements:
+
+1. `EventObserver`
+2. `LLMInterceptor`
+3. `ToolInterceptor`
+4. `ToolApprover`
+
+It only records activity. It does not rewrite requests or reject tools.
+
+You can save it as your own Go file, for example `pkg/myhooks/example_logger.go`:
+
+```go
+package myhooks
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/agent"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+type ExampleLoggerHookOptions struct {
+ LogFile string `json:"log_file,omitempty"`
+ LogEvents bool `json:"log_events,omitempty"`
+}
+
+type ExampleLoggerHook struct {
+ logFile string
+ logEvents bool
+ mu sync.Mutex
+}
+
+func NewExampleLoggerHook(opts ExampleLoggerHookOptions) *ExampleLoggerHook {
+ return &ExampleLoggerHook{
+ logFile: strings.TrimSpace(opts.LogFile),
+ logEvents: opts.LogEvents,
+ }
+}
+
+func (h *ExampleLoggerHook) OnEvent(ctx context.Context, evt agent.Event) error {
+ _ = ctx
+ if h == nil || !h.logEvents {
+ return nil
+ }
+ h.record("event", evt.Meta, map[string]any{
+ "event": evt.Kind.String(),
+ "payload": evt.Payload,
+ }, nil)
+ return nil
+}
+
+func (h *ExampleLoggerHook) BeforeLLM(
+ ctx context.Context,
+ req *agent.LLMHookRequest,
+) (*agent.LLMHookRequest, agent.HookDecision, error) {
+ _ = ctx
+ h.record("before_llm", req.Meta, req, agent.HookDecision{Action: agent.HookActionContinue})
+ return req, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) AfterLLM(
+ ctx context.Context,
+ resp *agent.LLMHookResponse,
+) (*agent.LLMHookResponse, agent.HookDecision, error) {
+ _ = ctx
+ h.record("after_llm", resp.Meta, resp, agent.HookDecision{Action: agent.HookActionContinue})
+ return resp, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) BeforeTool(
+ ctx context.Context,
+ call *agent.ToolCallHookRequest,
+) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
+ _ = ctx
+ h.record("before_tool", call.Meta, call, agent.HookDecision{Action: agent.HookActionContinue})
+ return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) AfterTool(
+ ctx context.Context,
+ result *agent.ToolResultHookResponse,
+) (*agent.ToolResultHookResponse, agent.HookDecision, error) {
+ _ = ctx
+ h.record("after_tool", result.Meta, result, agent.HookDecision{Action: agent.HookActionContinue})
+ return result, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) ApproveTool(
+ ctx context.Context,
+ req *agent.ToolApprovalRequest,
+) (agent.ApprovalDecision, error) {
+ _ = ctx
+ decision := agent.ApprovalDecision{Approved: true}
+ h.record("approve_tool", req.Meta, req, decision)
+ return decision, nil
+}
+
+func (h *ExampleLoggerHook) record(stage string, meta agent.EventMeta, payload any, decision any) {
+ logger.InfoCF("hooks", "Example hook observed", map[string]any{
+ "stage": stage,
+ })
+ if h == nil || h.logFile == "" {
+ return
+ }
+
+ entry := map[string]any{
+ "ts": time.Now().UTC(),
+ "stage": stage,
+ "meta": meta,
+ "payload": payload,
+ "decision": decision,
+ }
+
+ body, err := json.Marshal(entry)
+ if err != nil {
+ logger.WarnCF("hooks", "Example hook log encode failed", map[string]any{
+ "stage": stage,
+ "error": err.Error(),
+ })
+ return
+ }
+
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if dir := filepath.Dir(h.logFile); dir != "" && dir != "." {
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ logger.WarnCF("hooks", "Example hook log mkdir failed", map[string]any{
+ "stage": stage,
+ "path": h.logFile,
+ "error": err.Error(),
+ })
+ return
+ }
+ }
+
+ file, err := os.OpenFile(h.logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ logger.WarnCF("hooks", "Example hook log open failed", map[string]any{
+ "stage": stage,
+ "path": h.logFile,
+ "error": err.Error(),
+ })
+ return
+ }
+ defer func() { _ = file.Close() }()
+
+ if _, err := file.Write(append(body, '\n')); err != nil {
+ logger.WarnCF("hooks", "Example hook log write failed", map[string]any{
+ "stage": stage,
+ "path": h.logFile,
+ "error": err.Error(),
+ })
+ }
+}
+```
+
+### Mounting It In Code
+
+If code mounting is enough, call this after `AgentLoop` is initialized:
+
+```go
+hook := myhooks.NewExampleLoggerHook(myhooks.ExampleLoggerHookOptions{
+ LogFile: "/tmp/picoclaw-hook-example-logger.log",
+ LogEvents: true,
+})
+
+if err := al.MountHook(agent.NamedHook("example-logger", hook)); err != nil {
+ panic(err)
+}
+```
+
+### If You Also Want Config Mounting
+
+The hook system supports builtin hooks, but that requires you to compile the factory into your binary. In practice, that means you need registration code like this alongside the hook definition above:
+
+```go
+package myhooks
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/agent"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ if err := agent.RegisterBuiltinHook("example_logger", func(
+ ctx context.Context,
+ spec config.BuiltinHookConfig,
+ ) (any, error) {
+ _ = ctx
+
+ var opts ExampleLoggerHookOptions
+ if len(spec.Config) > 0 {
+ if err := json.Unmarshal(spec.Config, &opts); err != nil {
+ return nil, fmt.Errorf("decode example_logger config: %w", err)
+ }
+ }
+ return NewExampleLoggerHook(opts), nil
+ }); err != nil {
+ panic(err)
+ }
+}
+```
+
+Only after you register that builtin will the following config work:
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "builtins": {
+ "example_logger": {
+ "enabled": true,
+ "priority": 10,
+ "config": {
+ "log_file": "/tmp/picoclaw-hook-example-logger.log",
+ "log_events": true
+ }
+ }
+ }
+ }
+}
+```
+
+### How To Observe It
+
+- If `log_file` is set, each hook call is appended as JSON Lines
+- If `log_file` is not set, the hook still writes summaries to the gateway log
+- Requests that only hit the LLM path usually show `before_llm` and `after_llm`
+- Requests that trigger tools usually also show `before_tool`, `approve_tool`, and `after_tool`
+- If `log_events=true`, you will also see `event`
+
+Typical log lines:
+
+```json
+{"ts":"2026-03-21T14:10:00Z","stage":"before_tool","meta":{"session_key":"session-1"},"payload":{"tool":"echo_text","arguments":{"text":"hello"}},"decision":{"action":"continue"}}
+{"ts":"2026-03-21T14:10:00Z","stage":"approve_tool","meta":{"session_key":"session-1"},"payload":{"tool":"echo_text","arguments":{"text":"hello"}},"decision":{"approved":true}}
+```
+
+If you only see `before_llm` and `after_llm`, that usually means the request did not trigger any tool call, not that the hook failed to mount.
+
+## Python Process-Hook Example
+
+The following script is a minimal process-hook example. It uses only the Python standard library and supports:
+
+1. `hook.hello`
+2. `hook.event`
+3. `hook.before_tool`
+4. `hook.approve_tool`
+
+It only records activity. It does not rewrite or deny anything.
+
+Save it to any local path, for example `/tmp/review_gate.py`:
+
+```python
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import json
+import os
+import signal
+import sys
+from datetime import datetime, timezone
+from typing import Any
+
+LOG_EVENTS = os.getenv("PICOCLAW_HOOK_LOG_EVENTS", "1").lower() not in {"0", "false", "no"}
+LOG_FILE = os.getenv("PICOCLAW_HOOK_LOG_FILE", "").strip()
+
+
+def append_log(entry: dict[str, Any]) -> None:
+ if not LOG_FILE:
+ return
+
+ payload = {
+ "ts": datetime.now(timezone.utc).isoformat(),
+ **entry,
+ }
+ try:
+ log_dir = os.path.dirname(LOG_FILE)
+ if log_dir:
+ os.makedirs(log_dir, exist_ok=True)
+ with open(LOG_FILE, "a", encoding="utf-8") as handle:
+ handle.write(json.dumps(payload, ensure_ascii=True) + "\n")
+ except OSError as exc:
+ log_stderr(f"failed to write hook log file {LOG_FILE}: {exc}")
+
+
+def send_response(message_id: int, result: Any | None = None, error: str | None = None) -> None:
+ payload: dict[str, Any] = {
+ "jsonrpc": "2.0",
+ "id": message_id,
+ }
+ if error is not None:
+ payload["error"] = {"code": -32000, "message": error}
+ else:
+ payload["result"] = result if result is not None else {}
+
+ append_log({
+ "direction": "out",
+ "id": message_id,
+ "response": payload.get("result"),
+ "error": payload.get("error"),
+ })
+
+ try:
+ sys.stdout.write(json.dumps(payload, ensure_ascii=True) + "\n")
+ sys.stdout.flush()
+ except BrokenPipeError:
+ raise SystemExit(0) from None
+
+
+def log_stderr(message: str) -> None:
+ try:
+ sys.stderr.write(message + "\n")
+ sys.stderr.flush()
+ except BrokenPipeError:
+ raise SystemExit(0) from None
+
+
+def handle_shutdown_signal(signum: int, _frame: Any) -> None:
+ raise KeyboardInterrupt(f"received signal {signum}")
+
+
+def handle_before_tool(params: dict[str, Any]) -> dict[str, Any]:
+ _ = params
+ return {"action": "continue"}
+
+
+def handle_approve_tool(params: dict[str, Any]) -> dict[str, Any]:
+ _ = params
+ return {"approved": True}
+
+
+def handle_request(method: str, params: dict[str, Any]) -> dict[str, Any]:
+ if method == "hook.hello":
+ return {"ok": True, "name": "python-review-gate"}
+ if method == "hook.before_tool":
+ return handle_before_tool(params)
+ if method == "hook.approve_tool":
+ return handle_approve_tool(params)
+ if method == "hook.before_llm":
+ return {"action": "continue"}
+ if method == "hook.after_llm":
+ return {"action": "continue"}
+ if method == "hook.after_tool":
+ return {"action": "continue"}
+ raise KeyError(f"method not found: {method}")
+
+
+def main() -> int:
+ try:
+ for raw_line in sys.stdin:
+ line = raw_line.strip()
+ if not line:
+ continue
+
+ try:
+ message = json.loads(line)
+ except json.JSONDecodeError as exc:
+ log_stderr(f"failed to decode request: {exc}")
+ append_log({
+ "direction": "in",
+ "decode_error": str(exc),
+ "raw": line,
+ })
+ continue
+
+ method = message.get("method")
+ message_id = message.get("id", 0)
+ params = message.get("params") or {}
+ if not isinstance(params, dict):
+ params = {}
+
+ append_log({
+ "direction": "in",
+ "id": message_id,
+ "method": method,
+ "params": params,
+ "notification": not bool(message_id),
+ })
+
+ if not message_id:
+ if method == "hook.event" and LOG_EVENTS:
+ log_stderr(f"observed event: {params.get('Kind')}")
+ continue
+
+ try:
+ result = handle_request(str(method or ""), params)
+ except KeyError as exc:
+ send_response(int(message_id), error=str(exc))
+ continue
+ except Exception as exc:
+ send_response(int(message_id), error=f"unexpected error: {exc}")
+ continue
+
+ send_response(int(message_id), result=result)
+ except KeyboardInterrupt:
+ return 0
+
+ return 0
+
+
+if __name__ == "__main__":
+ signal.signal(signal.SIGINT, handle_shutdown_signal)
+ signal.signal(signal.SIGTERM, handle_shutdown_signal)
+ raise SystemExit(main())
+```
+
+### Configuration
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "processes": {
+ "py_review_gate": {
+ "enabled": true,
+ "priority": 100,
+ "transport": "stdio",
+ "command": [
+ "python3",
+ "/abs/path/to/review_gate.py"
+ ],
+ "observe": [
+ "tool_exec_start",
+ "tool_exec_end",
+ "tool_exec_skipped"
+ ],
+ "intercept": [
+ "before_tool",
+ "approve_tool"
+ ],
+ "env": {
+ "PICOCLAW_HOOK_LOG_FILE": "/tmp/picoclaw-hook-review-gate.log"
+ }
+ }
+ }
+ }
+}
+```
+
+### Environment Variables
+
+- `PICOCLAW_HOOK_LOG_EVENTS`
+ Whether to write `hook.event` summaries to `stderr`, enabled by default
+- `PICOCLAW_HOOK_LOG_FILE`
+ Path to an external log file. When set, the script appends inbound hook requests, notifications, and outbound responses as JSON Lines
+
+Note: `PICOCLAW_HOOK_LOG_FILE` has no default. If you do not set it, the script does not write any file logs.
+
+### How To Confirm It Received Hooks
+
+Watch two places:
+
+- Gateway logs
+ Useful for confirming that the host successfully started the process and for seeing event summaries written to `stderr`
+- `PICOCLAW_HOOK_LOG_FILE`
+ Useful for seeing the exact requests the script received and the exact responses it returned
+
+Typical interpretation:
+
+- Only `hook.hello`
+ The process started and completed the handshake, but no business hook request has arrived yet
+- `hook.event`
+ The `observe` configuration is working
+- `hook.before_tool`
+ The `intercept: ["before_tool", ...]` configuration is working
+- `hook.approve_tool`
+ The approval hook path is working
+
+Because this example never rewrites or denies, the expected responses look like:
+
+```json
+{"direction":"out","id":7,"response":{"action":"continue"},"error":null}
+{"direction":"out","id":8,"response":{"approved":true},"error":null}
+```
+
+A complete sample:
+
+```json
+{"ts":"2026-03-21T14:12:00+00:00","direction":"in","id":1,"method":"hook.hello","params":{"name":"py_review_gate","version":1,"modes":["observe","tool","approve"]},"notification":false}
+{"ts":"2026-03-21T14:12:00+00:00","direction":"out","id":1,"response":{"ok":true,"name":"python-review-gate"},"error":null}
+{"ts":"2026-03-21T14:12:05+00:00","direction":"in","id":0,"method":"hook.event","params":{"Kind":"tool_exec_start"},"notification":true}
+{"ts":"2026-03-21T14:12:05+00:00","direction":"in","id":7,"method":"hook.before_tool","params":{"tool":"echo_text","arguments":{"text":"hello"}},"notification":false}
+{"ts":"2026-03-21T14:12:05+00:00","direction":"out","id":7,"response":{"action":"continue"},"error":null}
+```
+
+Additional notes:
+
+- Timestamps are UTC
+- `notification=true` means it was a notification such as `hook.event`, which does not expect a response
+- `id` increases within a single hook process; if the process restarts, the counter starts over
+
+## Process-Hook Protocol
+
+Current process hooks use `JSON-RPC over stdio`:
+
+- PicoClaw starts the external process
+- Requests and responses are exchanged as one JSON message per line
+- `hook.event` is a notification and does not need a response
+- `hook.before_llm`, `hook.after_llm`, `hook.before_tool`, `hook.after_tool`, and `hook.approve_tool` are request/response calls
+
+The host does not currently accept new RPCs initiated by the process hook. In practice, that means an external hook can only respond to PicoClaw calls; it cannot call back into the host to send channel messages.
+
+## Configuration Fields
+
+### `hooks.builtins.`
+
+- `enabled`
+- `priority`
+- `config`
+
+### `hooks.processes.`
+
+- `enabled`
+- `priority`
+- `transport`
+ Currently only `stdio` is supported
+- `command`
+- `dir`
+- `env`
+- `observe`
+- `intercept`
+
+## Troubleshooting
+
+If a hook looks like it is not firing, check these in order:
+
+1. `hooks.enabled`
+2. Whether the target builtin or process hook is `enabled`
+3. Whether the process-hook `command` path is correct
+4. Whether you are watching the correct log file
+5. Whether the current request actually reached the stage you care about
+6. Whether `observe` or `intercept` contains the hook point you want
+
+A practical minimal troubleshooting pair is:
+
+- Use the Python process-hook example from this document to validate the external protocol
+- Use the Go in-process example from this document to validate the host-side chain
+
+If the Python side shows `hook.hello` but no business hook requests, the protocol is usually fine; the current request simply did not trigger the stage you expected.
+
+## Scope And Limits
+
+The current hook system is best suited for:
+
+- LLM request rewriting
+- Tool argument normalization
+- Pre-execution tool approval
+- Auditing and observability
+
+It is not yet well suited for:
+
+- External hooks actively sending channel messages
+- Suspending a turn and waiting for human approval replies
+- Full inbound/outbound message interception across the whole platform
+
+If you want a real human approval workflow, use hooks as the approval entry point and keep the state machine plus channel interaction in a separate `ApprovalManager`.
diff --git a/docs/hooks/README.zh.md b/docs/hooks/README.zh.md
new file mode 100644
index 000000000..46c7c9392
--- /dev/null
+++ b/docs/hooks/README.zh.md
@@ -0,0 +1,679 @@
+# Hook 系统使用说明
+
+这份文档对应当前仓库里已经实现的 hook 系统,而不是设计草案。
+
+当前实现支持两类挂载方式:
+
+1. 进程内 hook
+2. 进程外 process hook(`JSON-RPC over stdio`)
+
+当前仓库不再内置示例代码文件。下面的 Go / Python 示例都直接写在本文档里;如果你要使用它们,需要先复制到你自己的文件路径。
+
+## 支持的 hook 类型
+
+| 类型 | 接口 | 作用阶段 | 能否改写 |
+| --- | --- | --- | --- |
+| 观察型 | `EventObserver` | EventBus 广播事件时 | 否 |
+| LLM 拦截型 | `LLMInterceptor` | `before_llm` / `after_llm` | 是 |
+| Tool 拦截型 | `ToolInterceptor` | `before_tool` / `after_tool` | 是 |
+| Tool 审批型 | `ToolApprover` | `approve_tool` | 否,返回批准/拒绝 |
+
+当前公开的同步点位只有:
+
+- `before_llm`
+- `after_llm`
+- `before_tool`
+- `after_tool`
+- `approve_tool`
+
+其余 lifecycle 通过事件形式只读暴露。
+
+## 执行顺序
+
+HookManager 的排序规则是:
+
+1. 先执行进程内 hook
+2. 再执行 process hook
+3. 同一来源内按 `priority` 从小到大
+4. 若 `priority` 相同,再按名字排序
+
+## 超时
+
+当前配置在 `hooks.defaults` 中统一设置:
+
+- `observer_timeout_ms`
+- `interceptor_timeout_ms`
+- `approval_timeout_ms`
+
+注意:当前实现还没有单个 process hook 自己的 `timeout_ms` 字段,超时配置是全局默认值。
+
+## 快速开始
+
+如果你的目标只是先把当前 hook 流程跑通并观察到实际请求,最省事的是先用下面的 Python process hook 示例:
+
+1. 打开 `hooks.enabled`
+2. 把下面文档里的 Python 示例保存到本地文件,例如 `/tmp/review_gate.py`
+3. 给它配置 `PICOCLAW_HOOK_LOG_FILE`
+4. 重启 gateway
+5. 用 `tail -f` 观察日志文件
+
+例如:
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "processes": {
+ "py_review_gate": {
+ "enabled": true,
+ "priority": 100,
+ "transport": "stdio",
+ "command": [
+ "python3",
+ "/tmp/review_gate.py"
+ ],
+ "observe": [
+ "tool_exec_start",
+ "tool_exec_end",
+ "tool_exec_skipped"
+ ],
+ "intercept": [
+ "before_tool",
+ "approve_tool"
+ ],
+ "env": {
+ "PICOCLAW_HOOK_LOG_FILE": "/tmp/picoclaw-hook-review-gate.log"
+ }
+ }
+ }
+ }
+}
+```
+
+观察方式:
+
+```bash
+tail -f /tmp/picoclaw-hook-review-gate.log
+```
+
+如果你是在开发 PicoClaw 本体,而不是只想验证协议,那么再看后面的 Go in-process 示例。
+
+## 两个示例的定位
+
+- Go in-process 示例
+ 适合验证宿主内的 hook 链路、理解 `MountHook()` 和各个同步点位
+- Python process 示例
+ 适合理解 `JSON-RPC over stdio` 协议、确认宿主和外部进程之间的消息来回是否正常
+
+这两个示例都刻意保持为“只记录、不改写、不拒绝”的安全模式。它们的目的不是提供策略能力,而是帮你观察当前 hook 系统。
+
+## Go 进程内示例
+
+下面这段代码是一个最小的“记录型” in-process hook。它实现了:
+
+1. `EventObserver`
+2. `LLMInterceptor`
+3. `ToolInterceptor`
+4. `ToolApprover`
+
+它只记录,不改写请求,也不拒绝工具。
+
+你可以把它保存成你自己的 Go 文件,例如 `pkg/myhooks/example_logger.go`:
+
+```go
+package myhooks
+
+import (
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/sipeed/picoclaw/pkg/agent"
+ "github.com/sipeed/picoclaw/pkg/logger"
+)
+
+type ExampleLoggerHookOptions struct {
+ LogFile string `json:"log_file,omitempty"`
+ LogEvents bool `json:"log_events,omitempty"`
+}
+
+type ExampleLoggerHook struct {
+ logFile string
+ logEvents bool
+ mu sync.Mutex
+}
+
+func NewExampleLoggerHook(opts ExampleLoggerHookOptions) *ExampleLoggerHook {
+ return &ExampleLoggerHook{
+ logFile: strings.TrimSpace(opts.LogFile),
+ logEvents: opts.LogEvents,
+ }
+}
+
+func (h *ExampleLoggerHook) OnEvent(ctx context.Context, evt agent.Event) error {
+ _ = ctx
+ if h == nil || !h.logEvents {
+ return nil
+ }
+ h.record("event", evt.Meta, map[string]any{
+ "event": evt.Kind.String(),
+ "payload": evt.Payload,
+ }, nil)
+ return nil
+}
+
+func (h *ExampleLoggerHook) BeforeLLM(
+ ctx context.Context,
+ req *agent.LLMHookRequest,
+) (*agent.LLMHookRequest, agent.HookDecision, error) {
+ _ = ctx
+ h.record("before_llm", req.Meta, req, agent.HookDecision{Action: agent.HookActionContinue})
+ return req, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) AfterLLM(
+ ctx context.Context,
+ resp *agent.LLMHookResponse,
+) (*agent.LLMHookResponse, agent.HookDecision, error) {
+ _ = ctx
+ h.record("after_llm", resp.Meta, resp, agent.HookDecision{Action: agent.HookActionContinue})
+ return resp, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) BeforeTool(
+ ctx context.Context,
+ call *agent.ToolCallHookRequest,
+) (*agent.ToolCallHookRequest, agent.HookDecision, error) {
+ _ = ctx
+ h.record("before_tool", call.Meta, call, agent.HookDecision{Action: agent.HookActionContinue})
+ return call, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) AfterTool(
+ ctx context.Context,
+ result *agent.ToolResultHookResponse,
+) (*agent.ToolResultHookResponse, agent.HookDecision, error) {
+ _ = ctx
+ h.record("after_tool", result.Meta, result, agent.HookDecision{Action: agent.HookActionContinue})
+ return result, agent.HookDecision{Action: agent.HookActionContinue}, nil
+}
+
+func (h *ExampleLoggerHook) ApproveTool(
+ ctx context.Context,
+ req *agent.ToolApprovalRequest,
+) (agent.ApprovalDecision, error) {
+ _ = ctx
+ decision := agent.ApprovalDecision{Approved: true}
+ h.record("approve_tool", req.Meta, req, decision)
+ return decision, nil
+}
+
+func (h *ExampleLoggerHook) record(stage string, meta agent.EventMeta, payload any, decision any) {
+ logger.InfoCF("hooks", "Example hook observed", map[string]any{
+ "stage": stage,
+ })
+ if h == nil || h.logFile == "" {
+ return
+ }
+
+ entry := map[string]any{
+ "ts": time.Now().UTC(),
+ "stage": stage,
+ "meta": meta,
+ "payload": payload,
+ "decision": decision,
+ }
+
+ body, err := json.Marshal(entry)
+ if err != nil {
+ logger.WarnCF("hooks", "Example hook log encode failed", map[string]any{
+ "stage": stage,
+ "error": err.Error(),
+ })
+ return
+ }
+
+ h.mu.Lock()
+ defer h.mu.Unlock()
+
+ if dir := filepath.Dir(h.logFile); dir != "" && dir != "." {
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ logger.WarnCF("hooks", "Example hook log mkdir failed", map[string]any{
+ "stage": stage,
+ "path": h.logFile,
+ "error": err.Error(),
+ })
+ return
+ }
+ }
+
+ file, err := os.OpenFile(h.logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o644)
+ if err != nil {
+ logger.WarnCF("hooks", "Example hook log open failed", map[string]any{
+ "stage": stage,
+ "path": h.logFile,
+ "error": err.Error(),
+ })
+ return
+ }
+ defer func() { _ = file.Close() }()
+
+ if _, err := file.Write(append(body, '\n')); err != nil {
+ logger.WarnCF("hooks", "Example hook log write failed", map[string]any{
+ "stage": stage,
+ "path": h.logFile,
+ "error": err.Error(),
+ })
+ }
+}
+```
+
+### 如何挂载
+
+如果你只需要代码挂载,直接在 `AgentLoop` 初始化后调用:
+
+```go
+hook := myhooks.NewExampleLoggerHook(myhooks.ExampleLoggerHookOptions{
+ LogFile: "/tmp/picoclaw-hook-example-logger.log",
+ LogEvents: true,
+})
+
+if err := al.MountHook(agent.NamedHook("example-logger", hook)); err != nil {
+ panic(err)
+}
+```
+
+### 如果你还想用配置挂载
+
+当前 hook 系统支持 builtin hook,但这要求你自己把 factory 编进二进制。也就是说,下面这段注册代码需要和上面的 hook 定义一起放进你的工程里:
+
+```go
+package myhooks
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/sipeed/picoclaw/pkg/agent"
+ "github.com/sipeed/picoclaw/pkg/config"
+)
+
+func init() {
+ if err := agent.RegisterBuiltinHook("example_logger", func(
+ ctx context.Context,
+ spec config.BuiltinHookConfig,
+ ) (any, error) {
+ _ = ctx
+
+ var opts ExampleLoggerHookOptions
+ if len(spec.Config) > 0 {
+ if err := json.Unmarshal(spec.Config, &opts); err != nil {
+ return nil, fmt.Errorf("decode example_logger config: %w", err)
+ }
+ }
+ return NewExampleLoggerHook(opts), nil
+ }); err != nil {
+ panic(err)
+ }
+}
+```
+
+只有在你自己注册了 builtin 之后,下面的配置才会生效:
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "builtins": {
+ "example_logger": {
+ "enabled": true,
+ "priority": 10,
+ "config": {
+ "log_file": "/tmp/picoclaw-hook-example-logger.log",
+ "log_events": true
+ }
+ }
+ }
+ }
+}
+```
+
+### 如何观察它是否生效
+
+- 如果设置了 `log_file`,它会把每次 hook 调用按 JSON Lines 写入文件
+- 如果没有设置 `log_file`,它仍然会把摘要写到 gateway 日志
+- 普通只走 LLM 的请求,通常会看到 `before_llm` 和 `after_llm`
+- 触发工具调用的请求,通常还会看到 `before_tool`、`approve_tool`、`after_tool`
+- 如果 `log_events=true`,还会额外看到 `event`
+
+典型日志:
+
+```json
+{"ts":"2026-03-21T14:10:00Z","stage":"before_tool","meta":{"session_key":"session-1"},"payload":{"tool":"echo_text","arguments":{"text":"hello"}},"decision":{"action":"continue"}}
+{"ts":"2026-03-21T14:10:00Z","stage":"approve_tool","meta":{"session_key":"session-1"},"payload":{"tool":"echo_text","arguments":{"text":"hello"}},"decision":{"approved":true}}
+```
+
+如果你只看到了 `before_llm` / `after_llm`,没有看到 tool 相关阶段,通常不是 hook 没挂上,而是这次请求本身没有触发工具调用。
+
+## Python process hook 示例
+
+下面这段脚本是一个最小的 `process hook` 示例。它只使用 Python 标准库,支持:
+
+1. `hook.hello`
+2. `hook.event`
+3. `hook.before_tool`
+4. `hook.approve_tool`
+
+它默认只记录,不改写,也不拒绝。
+
+你可以把它保存到任意本地路径,例如 `/tmp/review_gate.py`:
+
+```python
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import json
+import os
+import signal
+import sys
+from datetime import datetime, timezone
+from typing import Any
+
+LOG_EVENTS = os.getenv("PICOCLAW_HOOK_LOG_EVENTS", "1").lower() not in {"0", "false", "no"}
+LOG_FILE = os.getenv("PICOCLAW_HOOK_LOG_FILE", "").strip()
+
+
+def append_log(entry: dict[str, Any]) -> None:
+ if not LOG_FILE:
+ return
+
+ payload = {
+ "ts": datetime.now(timezone.utc).isoformat(),
+ **entry,
+ }
+ try:
+ log_dir = os.path.dirname(LOG_FILE)
+ if log_dir:
+ os.makedirs(log_dir, exist_ok=True)
+ with open(LOG_FILE, "a", encoding="utf-8") as handle:
+ handle.write(json.dumps(payload, ensure_ascii=True) + "\n")
+ except OSError as exc:
+ log_stderr(f"failed to write hook log file {LOG_FILE}: {exc}")
+
+
+def send_response(message_id: int, result: Any | None = None, error: str | None = None) -> None:
+ payload: dict[str, Any] = {
+ "jsonrpc": "2.0",
+ "id": message_id,
+ }
+ if error is not None:
+ payload["error"] = {"code": -32000, "message": error}
+ else:
+ payload["result"] = result if result is not None else {}
+
+ append_log({
+ "direction": "out",
+ "id": message_id,
+ "response": payload.get("result"),
+ "error": payload.get("error"),
+ })
+
+ try:
+ sys.stdout.write(json.dumps(payload, ensure_ascii=True) + "\n")
+ sys.stdout.flush()
+ except BrokenPipeError:
+ raise SystemExit(0) from None
+
+
+def log_stderr(message: str) -> None:
+ try:
+ sys.stderr.write(message + "\n")
+ sys.stderr.flush()
+ except BrokenPipeError:
+ raise SystemExit(0) from None
+
+
+def handle_shutdown_signal(signum: int, _frame: Any) -> None:
+ raise KeyboardInterrupt(f"received signal {signum}")
+
+
+def handle_before_tool(params: dict[str, Any]) -> dict[str, Any]:
+ _ = params
+ return {"action": "continue"}
+
+
+def handle_approve_tool(params: dict[str, Any]) -> dict[str, Any]:
+ _ = params
+ return {"approved": True}
+
+
+def handle_request(method: str, params: dict[str, Any]) -> dict[str, Any]:
+ if method == "hook.hello":
+ return {"ok": True, "name": "python-review-gate"}
+ if method == "hook.before_tool":
+ return handle_before_tool(params)
+ if method == "hook.approve_tool":
+ return handle_approve_tool(params)
+ if method == "hook.before_llm":
+ return {"action": "continue"}
+ if method == "hook.after_llm":
+ return {"action": "continue"}
+ if method == "hook.after_tool":
+ return {"action": "continue"}
+ raise KeyError(f"method not found: {method}")
+
+
+def main() -> int:
+ try:
+ for raw_line in sys.stdin:
+ line = raw_line.strip()
+ if not line:
+ continue
+
+ try:
+ message = json.loads(line)
+ except json.JSONDecodeError as exc:
+ log_stderr(f"failed to decode request: {exc}")
+ append_log({
+ "direction": "in",
+ "decode_error": str(exc),
+ "raw": line,
+ })
+ continue
+
+ method = message.get("method")
+ message_id = message.get("id", 0)
+ params = message.get("params") or {}
+ if not isinstance(params, dict):
+ params = {}
+
+ append_log({
+ "direction": "in",
+ "id": message_id,
+ "method": method,
+ "params": params,
+ "notification": not bool(message_id),
+ })
+
+ if not message_id:
+ if method == "hook.event" and LOG_EVENTS:
+ log_stderr(f"observed event: {params.get('Kind')}")
+ continue
+
+ try:
+ result = handle_request(str(method or ""), params)
+ except KeyError as exc:
+ send_response(int(message_id), error=str(exc))
+ continue
+ except Exception as exc:
+ send_response(int(message_id), error=f"unexpected error: {exc}")
+ continue
+
+ send_response(int(message_id), result=result)
+ except KeyboardInterrupt:
+ return 0
+
+ return 0
+
+
+if __name__ == "__main__":
+ signal.signal(signal.SIGINT, handle_shutdown_signal)
+ signal.signal(signal.SIGTERM, handle_shutdown_signal)
+ raise SystemExit(main())
+```
+
+### 如何配置
+
+```json
+{
+ "hooks": {
+ "enabled": true,
+ "processes": {
+ "py_review_gate": {
+ "enabled": true,
+ "priority": 100,
+ "transport": "stdio",
+ "command": [
+ "python3",
+ "/abs/path/to/review_gate.py"
+ ],
+ "observe": [
+ "tool_exec_start",
+ "tool_exec_end",
+ "tool_exec_skipped"
+ ],
+ "intercept": [
+ "before_tool",
+ "approve_tool"
+ ],
+ "env": {
+ "PICOCLAW_HOOK_LOG_FILE": "/tmp/picoclaw-hook-review-gate.log"
+ }
+ }
+ }
+ }
+}
+```
+
+### 环境变量
+
+- `PICOCLAW_HOOK_LOG_EVENTS`
+ 是否把 `hook.event` 写到 `stderr`,默认开启
+- `PICOCLAW_HOOK_LOG_FILE`
+ 外部日志文件路径。设置后,脚本会把收到的 hook 请求、notification 和返回结果按 JSON Lines 追加到该文件
+
+注意:`PICOCLAW_HOOK_LOG_FILE` 没有默认值。不设置时,脚本不会自动落盘日志。
+
+### 如何确认它收到了 hook
+
+推荐同时看两个地方:
+
+- gateway 日志
+ 用来观察宿主是否成功启动了外部进程,以及脚本写到 `stderr` 的事件摘要
+- `PICOCLAW_HOOK_LOG_FILE`
+ 用来观察脚本实际收到了什么请求、返回了什么响应
+
+典型判断方式:
+
+- 只看到 `hook.hello`
+ 说明进程启动并完成握手了,但还没有新的业务 hook 请求真正打进来
+- 看到 `hook.event`
+ 说明 `observe` 配置生效了
+- 看到 `hook.before_tool`
+ 说明 `intercept: ["before_tool", ...]` 生效了
+- 看到 `hook.approve_tool`
+ 说明审批 hook 生效了
+
+这份示例脚本不会改写任何参数,也不会拒绝工具,所以你应该看到的典型返回是:
+
+```json
+{"direction":"out","id":7,"response":{"action":"continue"},"error":null}
+{"direction":"out","id":8,"response":{"approved":true},"error":null}
+```
+
+一组完整样例:
+
+```json
+{"ts":"2026-03-21T14:12:00+00:00","direction":"in","id":1,"method":"hook.hello","params":{"name":"py_review_gate","version":1,"modes":["observe","tool","approve"]},"notification":false}
+{"ts":"2026-03-21T14:12:00+00:00","direction":"out","id":1,"response":{"ok":true,"name":"python-review-gate"},"error":null}
+{"ts":"2026-03-21T14:12:05+00:00","direction":"in","id":0,"method":"hook.event","params":{"Kind":"tool_exec_start"},"notification":true}
+{"ts":"2026-03-21T14:12:05+00:00","direction":"in","id":7,"method":"hook.before_tool","params":{"tool":"echo_text","arguments":{"text":"hello"}},"notification":false}
+{"ts":"2026-03-21T14:12:05+00:00","direction":"out","id":7,"response":{"action":"continue"},"error":null}
+```
+
+补充说明:
+
+- 时间戳是 UTC,不是本地时区
+- `notification=true` 表示这是 `hook.event` 这类不需要响应的通知
+- `id` 会随着当前进程内的请求递增;如果 hook 进程重启,计数会重新开始
+
+## Process Hook 协议约定
+
+当前 process hook 使用 `JSON-RPC over stdio`:
+
+- PicoClaw 启动外部进程
+- 请求和响应都按“一行一个 JSON 消息”传输
+- `hook.event` 是 notification,不需要响应
+- `hook.before_llm` / `hook.after_llm` / `hook.before_tool` / `hook.after_tool` / `hook.approve_tool` 是 request/response
+
+当前宿主不会接受 process hook 主动发起的新 RPC。也就是说,外部 hook 现在只能“响应 PicoClaw 的调用”,不能反向调用宿主去发送 channel 消息。
+
+## 配置字段
+
+### `hooks.builtins.`
+
+- `enabled`
+- `priority`
+- `config`
+
+### `hooks.processes.`
+
+- `enabled`
+- `priority`
+- `transport`
+ 当前只支持 `stdio`
+- `command`
+- `dir`
+- `env`
+- `observe`
+- `intercept`
+
+## 排查建议
+
+当你觉得“hook 没触发”时,优先按这个顺序排查:
+
+1. `hooks.enabled` 是否为 `true`
+2. 对应的 builtin/process hook 是否 `enabled`
+3. process hook 的 `command` 路径是否正确
+4. 你看的是否是正确的日志文件
+5. 当前请求是否真的走到了对应阶段
+6. `observe` / `intercept` 是否包含了你想看的点位
+
+一个很实用的最小排查组合是:
+
+- 先用文档里的 Python process 示例确认外部协议没问题
+- 再用文档里的 Go in-process 示例确认宿主内的 hook 链路没问题
+
+如果前者有 `hook.hello` 但没有业务请求,通常不是协议挂了,而是当前这次请求没有真正触发对应的 hook 点位。
+
+## 适用边界
+
+当前 hook 系统最适合做这些事:
+
+- LLM 请求改写
+- 工具参数规范化
+- 工具执行前审批
+- 审计和观测
+
+当前还不适合直接承载这些需求:
+
+- 外部 hook 主动发 channel 消息
+- 挂起 turn 并等待人工审批回复
+- inbound/outbound 全链路消息拦截
+
+如果你要做人审流转,推荐把 hook 作为审批入口,把审批状态机和 channel 交互放到独立的 `ApprovalManager`。