merge: resolve conflicts between refactor/agent and main

This commit is contained in:
Administrator
2026-03-22 19:21:58 +08:00
parent 482c88cd15
commit f7f27e237a
56 changed files with 15839 additions and 1662 deletions
+27 -16
View File
@@ -222,13 +222,10 @@ func (cb *ContextBuilder) InvalidateCache() {
// invalidation (bootstrap files + memory). Skill roots are handled separately
// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
	// The tracked bootstrap files depend on which agent definition format is
	// currently present (AGENT.md vs. the legacy AGENTS.md/IDENTITY.md
	// layout), so resolve the definition first instead of returning a fixed
	// list. A stale hard-coded return placed before this logic would make it
	// unreachable and AGENT.md edits would never invalidate the cache.
	agentDefinition := cb.LoadAgentDefinition()
	paths := agentDefinition.trackedPaths(cb.workspace)

	// Memory is tracked regardless of the agent definition format.
	paths = append(paths, filepath.Join(cb.workspace, "memory", "MEMORY.md"))

	// NOTE(review): uniquePaths is defined elsewhere; presumably it removes
	// duplicate entries — confirm.
	return uniquePaths(paths)
}
// skillRoots returns all skill root directories that can affect
@@ -432,18 +429,32 @@ func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Ti
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
bootstrapFiles := []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"IDENTITY.md",
var sb strings.Builder
agentDefinition := cb.LoadAgentDefinition()
if agentDefinition.Agent != nil {
label := string(agentDefinition.Source)
if label == "" {
label = relativeWorkspacePath(cb.workspace, agentDefinition.Agent.Path)
}
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", label, agentDefinition.Agent.Body)
}
if agentDefinition.Soul != nil {
fmt.Fprintf(
&sb,
"## %s\n\n%s\n\n",
relativeWorkspacePath(cb.workspace, agentDefinition.Soul.Path),
agentDefinition.Soul.Content,
)
}
if agentDefinition.User != nil {
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "USER.md", agentDefinition.User.Content)
}
var sb strings.Builder
for _, filename := range bootstrapFiles {
filePath := filepath.Join(cb.workspace, filename)
if agentDefinition.Source != AgentDefinitionSourceAgent {
filePath := filepath.Join(cb.workspace, "IDENTITY.md")
if data, err := os.ReadFile(filePath); err == nil {
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data)
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "IDENTITY.md", data)
}
}
+176
View File
@@ -0,0 +1,176 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package agent
import (
"encoding/json"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// parseTurnBoundaries collects the index of every user message in history.
//
// Each such index marks the start of a Turn: one full
// "user input → LLM iterations → final response" cycle (see #1316). A Turn
// runs from its user message through every following assistant/tool message
// up to, but not including, the next user message.
//
// Because a tool-call sequence (assistant+ToolCalls → tool results) always
// lives inside a single Turn, cutting history at one of the returned indices
// never separates a tool call from its results. The result is nil when
// history contains no user message.
func parseTurnBoundaries(history []providers.Message) []int {
	var boundaries []int
	for idx := range history {
		if history[idx].Role != "user" {
			continue
		}
		boundaries = append(boundaries, idx)
	}
	return boundaries
}
// isSafeBoundary reports whether cutting history at index is safe, meaning
// the retained slice history[index:] starts on a user message and therefore
// does not begin in the middle of a tool-call sequence.
//
// Indices outside the open interval (0, len(history)) are always reported
// as safe: they keep either everything or nothing.
func isSafeBoundary(history []providers.Message, index int) bool {
	interior := index > 0 && index < len(history)
	// Short-circuit keeps the slice access in bounds.
	return !interior || history[index].Role == "user"
}
// findSafeBoundary snaps targetIndex to a Turn boundary.
//
// Resolution order:
//  1. Empty history or targetIndex <= 0 yields 0; targetIndex >= len(history)
//     yields len(history).
//  2. If history has no Turn boundary at all, targetIndex is returned as is.
//  3. The last boundary at or before targetIndex is used when it is greater
//     than 0 (this keeps more recent messages).
//  4. Otherwise the first boundary after targetIndex is used.
//  5. Otherwise 0 is returned: the whole history is one Turn and cannot be
//     split safely. Callers treat a result <= 0 as "skip compression".
func findSafeBoundary(history []providers.Message, targetIndex int) int {
	total := len(history)
	switch {
	case total == 0, targetIndex <= 0:
		return 0
	case targetIndex >= total:
		return total
	}

	turns := parseTurnBoundaries(history)
	if len(turns) == 0 {
		return targetIndex
	}

	// turns is in ascending order, so a single pass finds both the last
	// boundary <= targetIndex and the first boundary beyond it.
	before, after := -1, -1
	for _, start := range turns {
		if start > targetIndex {
			after = start
			break
		}
		before = start
	}

	// A boundary at index 0 would keep everything, so it does not count as
	// a usable backward candidate.
	if before > 0 {
		return before
	}
	if after >= 0 {
		return after
	}

	// Only boundary is index 0: a single Turn, no safe cut exists.
	return 0
}
// estimateMessageTokens returns a heuristic token count for one message.
//
// Character-based part (2.5 characters per token, i.e. chars*2/5):
//   - Content and ReasoningContent, counted in runes
//   - for every tool call: ID, Type, and either Function.Name plus
//     Function.Arguments or, when Function is nil, the top-level Name
//   - ToolCallID
//   - a fixed per-message overhead for role label and JSON framing
//
// Media items are added afterwards as a flat per-item token amount and are
// not run through the character heuristic, since their real cost depends on
// resolution and provider-specific tokenization.
func estimateMessageTokens(msg providers.Message) int {
	const (
		// Role label, JSON structure, separators.
		messageOverhead = 12
		// Flat estimate for one image/file attachment.
		mediaTokensPerItem = 256
	)

	charCount := messageOverhead
	charCount += utf8.RuneCountInString(msg.Content)
	// ReasoningContent can be large and is persisted in session history; an
	// empty string simply contributes zero.
	charCount += utf8.RuneCountInString(msg.ReasoningContent)

	for i := range msg.ToolCalls {
		call := &msg.ToolCalls[i]
		charCount += len(call.ID) + len(call.Type)
		if call.Function == nil {
			// Some provider formats carry only the top-level Name.
			charCount += len(call.Name)
			continue
		}
		// The top-level Name mirrors Function.Name, so only the Function
		// fields are counted to avoid counting the name twice.
		charCount += len(call.Function.Name) + len(call.Function.Arguments)
	}

	charCount += len(msg.ToolCallID)

	return charCount*2/5 + len(msg.Media)*mediaTokensPerItem
}
// estimateToolDefsTokens returns a heuristic token count for the tool
// definitions sent with an LLM request.
//
// For every tool it sums the lengths of the function name, the description,
// and the JSON encoding of the parameters schema (when present and
// encodable), plus a fixed per-tool overhead, then converts characters to
// tokens at 2.5 characters per token. An empty list yields 0.
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
	// Type field, JSON structure, separators.
	const perToolOverhead = 20

	chars := 0
	for i := range defs {
		fn := defs[i].Function
		chars += perToolOverhead + len(fn.Name) + len(fn.Description)
		if fn.Parameters == nil {
			continue
		}
		encoded, err := json.Marshal(fn.Parameters)
		if err != nil {
			// Best effort: an unencodable schema contributes nothing.
			continue
		}
		chars += len(encoded)
	}
	return chars * 2 / 5
}
// isOverContextBudget reports whether the estimated size of the request —
// all messages, the tool definitions, and the maxTokens output reserve —
// is strictly greater than contextWindow.
//
// It lets callers compress proactively before the LLM call rather than
// reacting to a 400 error afterwards.
func isOverContextBudget(
	contextWindow int,
	messages []providers.Message,
	toolDefs []providers.ToolDefinition,
	maxTokens int,
) bool {
	required := maxTokens + estimateToolDefsTokens(toolDefs)
	for i := range messages {
		required += estimateMessageTokens(messages[i])
	}
	return required > contextWindow
}
+826
View File
@@ -0,0 +1,826 @@
package agent
import (
"fmt"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// msgUser builds a message with role "user" and the given content.
func msgUser(content string) providers.Message {
	return providers.Message{
		Role:    "user",
		Content: content,
	}
}
// msgAssistant builds a message with role "assistant", the given content,
// and no tool calls.
func msgAssistant(content string) providers.Message {
	return providers.Message{
		Role:    "assistant",
		Content: content,
	}
}
// msgAssistantTC builds an assistant message carrying one function tool call
// per supplied ID. Each call is named "tool_<id>" (both at the top level and
// in Function) and uses a fixed JSON argument payload.
func msgAssistantTC(toolIDs ...string) providers.Message {
	calls := make([]providers.ToolCall, 0, len(toolIDs))
	for _, id := range toolIDs {
		name := "tool_" + id
		calls = append(calls, providers.ToolCall{
			ID:   id,
			Type: "function",
			Name: name,
			Function: &providers.FunctionCall{
				Name:      name,
				Arguments: `{"key":"value"}`,
			},
		})
	}
	return providers.Message{Role: "assistant", ToolCalls: calls}
}
// msgTool builds a message with role "tool" that answers the tool call
// identified by callID with the given content.
func msgTool(callID, content string) providers.Message {
	return providers.Message{
		Role:       "tool",
		ToolCallID: callID,
		Content:    content,
	}
}
// TestParseTurnBoundaries checks that parseTurnBoundaries returns exactly the
// indices of user messages, for plain exchanges, tool-call Turns, and
// histories with no or late user messages.
func TestParseTurnBoundaries(t *testing.T) {
	type testCase struct {
		name    string
		history []providers.Message
		want    []int
	}

	cases := []testCase{
		{name: "empty history", history: nil, want: nil},
		{
			name: "simple exchange",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistant("a2"),
			},
			want: []int{0, 2},
		},
		{
			name: "tool-call Turn",
			history: []providers.Message{
				msgUser("search"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "result"),
				msgAssistant("found it"),
				msgUser("thanks"),
				msgAssistant("welcome"),
			},
			want: []int{0, 4},
		},
		{
			name: "chained tool calls in single Turn",
			history: []providers.Message{
				msgUser("save and notify"),
				msgAssistantTC("tc_save"),
				msgTool("tc_save", "saved"),
				msgAssistantTC("tc_notify"),
				msgTool("tc_notify", "notified"),
				msgAssistant("done"),
			},
			want: []int{0},
		},
		{
			name: "no user messages",
			history: []providers.Message{
				msgAssistant("a1"),
				msgAssistant("a2"),
			},
			want: nil,
		},
		{
			name: "leading non-user messages",
			history: []providers.Message{
				msgAssistantTC("tc1"),
				msgTool("tc1", "r1"),
				msgAssistant("greeting"),
				msgUser("hello"),
				msgAssistant("hi"),
			},
			want: []int{3},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := parseTurnBoundaries(tc.history)
			// A length mismatch is reported once with both slices; the
			// element-wise comparison only runs on equal lengths.
			if len(got) != len(tc.want) {
				t.Errorf("parseTurnBoundaries() = %v, want %v", got, tc.want)
				return
			}
			for i, g := range got {
				if w := tc.want[i]; g != w {
					t.Errorf("parseTurnBoundaries()[%d] = %d, want %d", i, g, w)
				}
			}
		})
	}
}
// TestIsSafeBoundary checks isSafeBoundary for out-of-range indices (always
// safe) and for in-range indices landing on user, assistant, and tool
// messages (only user is safe).
func TestIsSafeBoundary(t *testing.T) {
	type testCase struct {
		name    string
		history []providers.Message
		index   int
		want    bool
	}

	single := []providers.Message{msgUser("hi")}

	cases := []testCase{
		{name: "empty history, index 0", history: nil, index: 0, want: true},
		{name: "single user message, index 0", history: single, index: 0, want: true},
		{name: "single user message, index 1 (end)", history: single, index: 1, want: true},
		{
			name: "at user message",
			history: []providers.Message{
				msgAssistant("hello"),
				msgUser("how are you"),
				msgAssistant("fine"),
			},
			index: 1,
			want:  true,
		},
		{
			name: "at assistant without tool calls",
			history: []providers.Message{
				msgUser("hello"),
				msgAssistant("response"),
				msgUser("follow up"),
			},
			index: 1,
			want:  false,
		},
		{
			name: "at assistant with tool calls",
			history: []providers.Message{
				msgUser("search something"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "result"),
				msgAssistant("here is what I found"),
			},
			index: 1,
			want:  false,
		},
		{
			name: "at tool result",
			history: []providers.Message{
				msgUser("do something"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "done"),
				msgAssistant("completed"),
			},
			index: 2,
			want:  false,
		},
		{
			name:    "negative index",
			history: []providers.Message{msgUser("hello")},
			index:   -1,
			want:    true,
		},
		{
			name:    "index beyond length",
			history: []providers.Message{msgUser("hello")},
			index:   5,
			want:    true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := isSafeBoundary(tc.history, tc.index); got != tc.want {
				t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tc.index, got, tc.want)
			}
		})
	}
}
// TestFindSafeBoundary checks how findSafeBoundary snaps a target index to a
// Turn boundary: clamping at the ends, preferring the user message at or
// before the target, falling forward when none precedes it, and leaving the
// target untouched when the history has no user message.
func TestFindSafeBoundary(t *testing.T) {
	type testCase struct {
		name        string
		history     []providers.Message
		targetIndex int
		want        int
	}

	single := []providers.Message{msgUser("hi")}

	// Shared by the two "target inside tool sequence" cases.
	toolSequence := []providers.Message{
		msgUser("q1"),                // 0
		msgAssistant("a1"),           // 1
		msgUser("q2"),                // 2
		msgAssistantTC("tc1", "tc2"), // 3
		msgTool("tc1", "r1"),         // 4
		msgTool("tc2", "r2"),         // 5
		msgAssistant("summary"),      // 6
		msgUser("q3"),                // 7
	}

	cases := []testCase{
		{name: "empty history", history: nil, targetIndex: 0, want: 0},
		{name: "target at 0", history: single, targetIndex: 0, want: 0},
		{name: "target beyond length", history: single, targetIndex: 5, want: 1},
		{
			name: "target already at user message",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistant("a2"),
			},
			targetIndex: 2,
			want:        2,
		},
		{
			name: "target at assistant, scan backward finds user",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistant("a2"),
				msgUser("q3"),
			},
			targetIndex: 3, // assistant "a2"
			want:        2, // user "q2"
		},
		{
			name:        "target inside tool sequence, scan backward finds user",
			history:     toolSequence,
			targetIndex: 4, // tool result "r1"
			want:        2, // user "q2" that opened the Turn
		},
		{
			name:        "target inside tool sequence, backward finds user before chain",
			history:     toolSequence,
			targetIndex: 5, // tool result "r2"
			want:        2, // user "q2" that opened the Turn
		},
		{
			name: "no backward user, scan forward finds one",
			history: []providers.Message{
				msgAssistantTC("tc1"),
				msgTool("tc1", "r1"),
				msgAssistant("a1"),
				msgUser("q1"),
			},
			targetIndex: 1, // tool result
			want:        3, // first user message after the target
		},
		{
			name: "multi-step tool chain preserves atomicity",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "r1"),
				msgAssistantTC("tc2"),
				msgTool("tc2", "r2"),
				msgAssistant("final"),
				msgUser("q3"),
				msgAssistant("a3"),
			},
			targetIndex: 5, // second assistant+TC
			want:        2, // user "q2" that opened the Turn
		},
		{
			name: "all non-user messages returns target unchanged",
			history: []providers.Message{
				msgAssistant("a1"),
				msgAssistant("a2"),
				msgAssistant("a3"),
			},
			targetIndex: 1,
			want:        1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := findSafeBoundary(tc.history, tc.targetIndex)
			if got == tc.want {
				return
			}
			t.Errorf("findSafeBoundary(history, %d) = %d, want %d",
				tc.targetIndex, got, tc.want)
		})
	}
}
// TestFindSafeBoundary_SingleTurnReturnsZero covers a history made of one
// Turn only. Its sole boundary is index 0, and any other cut would split the
// Turn's tool sequence, so findSafeBoundary must answer 0 (which callers
// interpret as "do not compress").
func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) {
	singleTurn := []providers.Message{
		msgUser("do everything"), // 0: the only Turn boundary
		msgAssistantTC("tc1"),    // 1
		msgTool("tc1", "result"), // 2
		msgAssistant("all done"), // 3
	}

	if got := findSafeBoundary(singleTurn, 2); got != 0 {
		t.Errorf("findSafeBoundary(single_turn, 2) = %d, want 0 (cannot split single Turn)", got)
	}
}
// TestFindSafeBoundary_BackwardScanSkipsToolSequence places the target in the
// middle of a long tool-call chain and expects the boundary to land on the
// user message that started the chain, not anywhere inside it.
func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
	const target = 6 // middle of the tool results

	chain := []providers.Message{
		msgUser("start"),                 // 0
		msgAssistant("before chain"),     // 1
		msgUser("trigger"),               // 2: expected safe boundary
		msgAssistantTC("t1", "t2", "t3"), // 3
		msgTool("t1", "r1"),              // 4
		msgTool("t2", "r2"),              // 5
		msgTool("t3", "r3"),              // 6
		msgAssistantTC("t4"),             // 7
		msgTool("t4", "r4"),              // 8
		msgAssistant("chain done"),       // 9
		msgUser("next"),                  // 10
	}

	if got := findSafeBoundary(chain, target); got != 2 {
		t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got)
	}
}
// TestEstimateMessageTokens is a smoke test: for several message shapes the
// estimate must be at least the given lower bound. Exact values are not
// pinned because they depend on the overhead constants.
func TestEstimateMessageTokens(t *testing.T) {
	type testCase struct {
		name string
		msg  providers.Message
		want int // lower bound, not an exact value
	}

	cases := []testCase{
		{name: "plain user message", msg: msgUser("Hello, world!"), want: 1},
		// Even without content the per-message overhead yields tokens.
		{name: "empty message still has overhead", msg: providers.Message{Role: "user"}, want: 1},
		{name: "assistant with tool calls", msg: msgAssistantTC("tc_123"), want: 1},
		{
			name: "tool result with ID",
			msg:  msgTool("call_abc", "Here is the search result with lots of content"),
			want: 1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := estimateMessageTokens(tc.msg); got < tc.want {
				t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tc.want)
			}
		})
	}
}
func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
plain := msgAssistant("thinking")
withTC := providers.Message{
Role: "assistant",
Content: "thinking",
ToolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "web_search",
Function: &providers.FunctionCall{
Name: "web_search",
Arguments: `{"query":"picoclaw agent framework","max_results":5}`,
},
},
},
}
plainTokens := estimateMessageTokens(plain)
withTCTokens := estimateMessageTokens(withTC)
if withTCTokens <= plainTokens {
t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
withTCTokens, plainTokens)
}
}
func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
// Multi-byte characters (e.g. emoji, accented letters) are single runes
// but may map to different token counts. The heuristic should still produce
// reasonable estimates via RuneCountInString.
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
tokens := estimateMessageTokens(msg)
if tokens <= 0 {
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
}
}
// TestEstimateMessageTokens_LargeArguments checks that a tool call with a
// large JSON argument payload dominates the estimate: more than 5000
// characters must yield at least 2000 tokens at 2.5 characters per token.
func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
	payload := strings.Repeat("x", 5000)
	call := providers.ToolCall{
		ID:   "call_large",
		Type: "function",
		Name: "write_file",
		Function: &providers.FunctionCall{
			Name:      "write_file",
			Arguments: fmt.Sprintf(`{"content":"%s"}`, payload),
		},
	}
	msg := providers.Message{
		Role:      "assistant",
		ToolCalls: []providers.ToolCall{call},
	}

	if tokens := estimateMessageTokens(msg); tokens < 2000 {
		t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
	}
}
// TestEstimateMessageTokens_ReasoningContent verifies that ReasoningContent
// is included in the estimate by comparing two messages that differ only in
// that field.
func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
	withReasoning := msgAssistant("result")
	withReasoning.ReasoningContent = strings.Repeat("thinking step ", 200)

	plainTokens := estimateMessageTokens(msgAssistant("result"))
	reasoningTokens := estimateMessageTokens(withReasoning)

	if reasoningTokens <= plainTokens {
		t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
			reasoningTokens, plainTokens)
	}
}
// TestEstimateMessageTokens_MediaItems verifies that media attachments raise
// the estimate and that each item adds a flat 256 tokens, bypassing the
// chars*2/5 conversion.
func TestEstimateMessageTokens_MediaItems(t *testing.T) {
	withMedia := msgUser("describe this")
	withMedia.Media = []string{"media://img1.png", "media://img2.png"}

	plainTokens := estimateMessageTokens(msgUser("describe this"))
	mediaTokens := estimateMessageTokens(withMedia)

	if mediaTokens <= plainTokens {
		t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
			mediaTokens, plainTokens)
	}

	// Two items at 256 tokens each.
	const expectedDelta = 256 * 2
	if actualDelta := mediaTokens - plainTokens; actualDelta != expectedDelta {
		t.Errorf("2 media items should add %d tokens, got delta %d", expectedDelta, actualDelta)
	}
}
// --- estimateToolDefsTokens tests ---
// TestEstimateToolDefsTokens is a smoke test for estimateToolDefsTokens: an
// empty list, a tool with a parameter schema, and a tool without one must
// each reach the given lower bound.
func TestEstimateToolDefsTokens(t *testing.T) {
	type testCase struct {
		name string
		defs []providers.ToolDefinition
		want int // lower bound, not an exact value
	}

	searchTool := providers.ToolDefinition{
		Type: "function",
		Function: providers.ToolFunctionDefinition{
			Name:        "web_search",
			Description: "Search the web for information",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
				},
				"required": []any{"query"},
			},
		},
	}
	listDirTool := providers.ToolDefinition{
		Type: "function",
		Function: providers.ToolFunctionDefinition{
			Name:        "list_dir",
			Description: "List directory contents",
		},
	}

	cases := []testCase{
		{name: "empty tool list", defs: nil, want: 0},
		{name: "single tool with params", defs: []providers.ToolDefinition{searchTool}, want: 1},
		{name: "tool without params", defs: []providers.ToolDefinition{listDirTool}, want: 1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := estimateToolDefsTokens(tc.defs); got < tc.want {
				t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tc.want)
			}
		})
	}
}
// TestEstimateToolDefsTokens_ScalesWithCount verifies that the estimate grows
// when more tool definitions of the same shape are supplied.
func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
	// newTool returns a definition that differs from its siblings only by name.
	newTool := func(name string) providers.ToolDefinition {
		schema := map[string]any{
			"type": "object",
			"properties": map[string]any{
				"input": map[string]any{"type": "string", "description": "Input value"},
			},
		}
		return providers.ToolDefinition{
			Type: "function",
			Function: providers.ToolFunctionDefinition{
				Name:        name,
				Description: "A test tool that does something useful",
				Parameters:  schema,
			},
		}
	}

	single := []providers.ToolDefinition{newTool("tool_a")}
	triple := []providers.ToolDefinition{newTool("tool_a"), newTool("tool_b"), newTool("tool_c")}

	one := estimateToolDefsTokens(single)
	three := estimateToolDefsTokens(triple)
	if three <= one {
		t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one)
	}
}
// --- isOverContextBudget tests ---
// TestIsOverContextBudget exercises isOverContextBudget with a small fixed
// conversation against different context windows and output reserves.
func TestIsOverContextBudget(t *testing.T) {
	type testCase struct {
		name          string
		contextWindow int
		messages      []providers.Message
		toolDefs      []providers.ToolDefinition
		maxTokens     int
		want          bool
	}

	smallHistory := []providers.Message{
		{Role: "system", Content: strings.Repeat("x", 1000)},
		msgUser("q1"),
		msgAssistant("a1"),
		msgUser("hello"),
	}
	tools := []providers.ToolDefinition{{
		Type: "function",
		Function: providers.ToolFunctionDefinition{
			Name:        "test_tool",
			Description: "A test tool",
			Parameters:  map[string]any{"type": "object"},
		},
	}}

	cases := []testCase{
		{
			name:          "within budget",
			contextWindow: 100000,
			messages:      smallHistory,
			toolDefs:      tools,
			maxTokens:     4096,
			want:          false,
		},
		{
			name:          "over budget with small window",
			contextWindow: 100, // far smaller than the output reserve alone
			messages:      smallHistory,
			toolDefs:      tools,
			maxTokens:     4096,
			want:          true,
		},
		{
			name:          "large max_tokens eats budget",
			contextWindow: 2000,
			messages:      smallHistory,
			toolDefs:      tools,
			maxTokens:     1800, // leaves almost nothing for the messages
			want:          true,
		},
		{
			name:          "empty messages within budget",
			contextWindow: 10000,
			messages:      nil,
			toolDefs:      nil,
			maxTokens:     4096,
			want:          false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isOverContextBudget(tc.contextWindow, tc.messages, tc.toolDefs, tc.maxTokens)
			if got != tc.want {
				t.Errorf("isOverContextBudget() = %v, want %v", got, tc.want)
			}
		})
	}
}
// --- Tests reflecting actual session data shape ---
// Session history never contains system messages. The system prompt is
// built dynamically by BuildMessages. These tests use realistic history
// shapes: user/assistant/tool only, with tool chains and reasoning content.
// TestFindSafeBoundary_SessionHistoryNoSystem uses a history shaped like
// stored session data: it opens with a user message rather than a system
// message. A target on a tool result must snap back to the user message
// that started that Turn.
func TestFindSafeBoundary_SessionHistoryNoSystem(t *testing.T) {
	const midpoint = 4 // tool result

	session := []providers.Message{
		msgUser("hello"),               // 0
		msgAssistant("hi there"),       // 1
		msgUser("search for X"),        // 2: expected boundary
		msgAssistantTC("tc1"),          // 3
		msgTool("tc1", "found X"),      // 4
		msgAssistant("here is X"),      // 5
		msgUser("thanks"),              // 6
		msgAssistant("you're welcome"), // 7
	}

	if got := findSafeBoundary(session, midpoint); got != 2 {
		t.Errorf("findSafeBoundary(session_history, 4) = %d, want 2", got)
	}
}
// TestFindSafeBoundary_SessionWithChainedTools covers a first Turn made of
// chained tool calls (save, then notify) followed by a second Turn.
//
// The target (3) lies inside the first Turn. The only Turn boundary at or
// before it is index 0, which findSafeBoundary does not accept as a backward
// result because cutting there keeps everything. It therefore falls forward
// to the next Turn start at index 6.
func TestFindSafeBoundary_SessionWithChainedTools(t *testing.T) {
	const target = 3 // second assistant+TC, inside the chain

	session := []providers.Message{
		msgUser("save and notify"),       // 0: boundary, but unusable (keeps all)
		msgAssistantTC("tc_save"),        // 1
		msgTool("tc_save", "saved"),      // 2
		msgAssistantTC("tc_notify"),      // 3
		msgTool("tc_notify", "notified"), // 4
		msgAssistant("done"),             // 5
		msgUser("check status"),          // 6: expected boundary
		msgAssistant("all good"),         // 7
	}

	if got := findSafeBoundary(session, target); got != 6 {
		t.Errorf("findSafeBoundary(chained_tools, 3) = %d, want 6", got)
	}
}
// TestEstimateMessageTokens_WithReasoningAndMedia estimates a message that
// has content, reasoning content, a tool call, and a media item populated at
// the same time, and checks that reasoning and media each contribute to the
// total when combined with the other fields.
func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
	msg := providers.Message{
		Role:             "assistant",
		Content:          "Here is the analysis.",
		ReasoningContent: strings.Repeat("Let me think about this carefully. ", 50),
		ToolCalls: []providers.ToolCall{
			{
				ID:   "call_1",
				Type: "function",
				Name: "analyze",
				Function: &providers.FunctionCall{
					Name:      "analyze",
					Arguments: `{"data":"sample","depth":3}`,
				},
			},
		},
		// The test name promises media coverage, so attach one item.
		Media: []string{"media://chart.png"},
	}

	tokens := estimateMessageTokens(msg)

	// ReasoningContent alone is 1750 chars → 700 tokens at 2.5 chars/token.
	// Content, tool call, overhead and media add more, so the total must be
	// well above 500.
	if tokens < 500 {
		t.Errorf("message with reasoning+toolcalls should have significant tokens, got %d", tokens)
	}

	// Dropping only the reasoning must lower the estimate.
	msgNoReasoning := msg
	msgNoReasoning.ReasoningContent = ""
	tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
	if tokens <= tokensNoReasoning {
		t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
	}

	// Dropping only the media must lower the estimate as well.
	msgNoMedia := msg
	msgNoMedia.Media = nil
	tokensNoMedia := estimateMessageTokens(msgNoMedia)
	if tokens <= tokensNoMedia {
		t.Errorf("media items should add tokens: with=%d, without=%d", tokens, tokensNoMedia)
	}
}
// TestIsOverContextBudget_RealisticSession assembles messages the way the
// comments in this file describe BuildMessages doing it — a dynamically built
// system message, the stored session history, then the current user message —
// and checks the budget decision against a large and a tiny context window.
func TestIsOverContextBudget_RealisticSession(t *testing.T) {
	const maxTokens = 32768

	// The system message is not part of the stored session.
	systemMsg := providers.Message{
		Role:    "system",
		Content: strings.Repeat("system prompt content ", 100),
	}

	toolCallMsg := providers.Message{
		Role:    "assistant",
		Content: "I'll use tool X",
		ToolCalls: []providers.ToolCall{
			{
				ID: "tc1", Type: "function", Name: "tool_x",
				Function: &providers.FunctionCall{
					Name:      "tool_x",
					Arguments: `{"query":"test","verbose":true}`,
				},
			},
		},
	}
	toolResultMsg := providers.Message{
		Role:       "tool",
		Content:    strings.Repeat("result data ", 200),
		ToolCallID: "tc1",
	}

	sessionHistory := []providers.Message{
		msgUser("first question"),
		msgAssistant("first answer"),
		msgUser("use tool X"),
		toolCallMsg,
		toolResultMsg,
		msgAssistant("Here are the results from tool X."),
	}

	// system + history + current user, in that order.
	messages := append([]providers.Message{systemMsg}, sessionHistory...)
	messages = append(messages, msgUser("follow up question"))

	tools := []providers.ToolDefinition{{
		Type: "function",
		Function: providers.ToolFunctionDefinition{
			Name:        "tool_x",
			Description: "A useful tool",
			Parameters:  map[string]any{"type": "object"},
		},
	}}

	// A large window comfortably fits the session plus the output reserve.
	if isOverContextBudget(131072, messages, tools, maxTokens) {
		t.Error("realistic session should be within 131072 context window")
	}

	// A tiny window is exceeded.
	if !isOverContextBudget(500, messages, tools, maxTokens) {
		t.Error("realistic session should exceed 500 context window")
	}
}
+10 -10
View File
@@ -37,7 +37,7 @@ func setupWorkspace(t *testing.T, files map[string]string) string {
// Codex (only reads last system message as instructions).
func TestSingleSystemMessage(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nTest agent.",
"AGENT.md": "# Agent\nTest agent.",
})
defer os.RemoveAll(tmpDir)
@@ -202,10 +202,10 @@ func TestMtimeAutoInvalidation(t *testing.T) {
}{
{
name: "bootstrap file change",
file: "IDENTITY.md",
contentV1: "# Original Identity",
contentV2: "# Updated Identity",
checkField: "Updated Identity",
file: "AGENT.md",
contentV1: "# Original Agent",
contentV2: "# Updated Agent",
checkField: "Updated Agent",
},
{
name: "memory file change",
@@ -280,7 +280,7 @@ func TestMtimeAutoInvalidation(t *testing.T) {
// even when source files haven't changed (useful for tests and reload commands).
func TestExplicitInvalidateCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Test Identity",
"AGENT.md": "# Test Agent",
})
defer os.RemoveAll(tmpDir)
@@ -307,8 +307,8 @@ func TestExplicitInvalidateCache(t *testing.T) {
// when no files change (regression test for issue #607).
func TestCacheStability(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nContent",
"SOUL.md": "# Soul\nContent",
"AGENT.md": "# Agent\nContent",
"SOUL.md": "# Soul\nContent",
})
defer os.RemoveAll(tmpDir)
@@ -607,7 +607,7 @@ description: delete-me-v1
// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache
func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nConcurrency test agent.",
"AGENT.md": "# Agent\nConcurrency test agent.",
"SOUL.md": "# Soul\nBe helpful.",
"memory/MEMORY.md": "# Memory\nUser prefers Go.",
"skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo",
@@ -714,7 +714,7 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} {
for _, name := range []string{"AGENT.md", "SOUL.md"} {
os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644)
}
+255
View File
@@ -0,0 +1,255 @@
package agent
import (
"os"
"path/filepath"
"slices"
"strings"
"github.com/gomarkdown/markdown/parser"
"gopkg.in/yaml.v3"
"github.com/sipeed/picoclaw/pkg/logger"
)
// AgentDefinitionSource identifies which agent bootstrap file produced the definition.
//
// The constant values are the bootstrap file names themselves; the loader
// joins them onto the workspace directory to build the file paths. The zero
// value ("") is what a definition carries when neither file could be read.
type AgentDefinitionSource string
const (
	// AgentDefinitionSourceAgent indicates the new AGENT.md format.
	AgentDefinitionSourceAgent AgentDefinitionSource = "AGENT.md"
	// AgentDefinitionSourceAgents indicates the legacy AGENTS.md format.
	AgentDefinitionSourceAgents AgentDefinitionSource = "AGENTS.md"
)
// AgentFrontmatter holds machine-readable AGENT.md configuration.
//
// Known fields are exposed directly for convenience. Fields keeps the full
// parsed frontmatter so future refactors can read additional keys without
// changing the loader contract again.
//
// NOTE(review): the code that fills this struct is outside this view; the
// per-field notes below are inferred from field names and JSON tags only.
type AgentFrontmatter struct {
	// Name and Description are always emitted in JSON (no omitempty).
	Name string `json:"name"`
	Description string `json:"description"`
	// The remaining known keys are omitted from JSON when empty.
	Tools []string `json:"tools,omitempty"`
	Model string `json:"model,omitempty"`
	// MaxTurns is a pointer; presumably nil means the key was not set — confirm with the parser.
	MaxTurns *int `json:"maxTurns,omitempty"`
	Skills []string `json:"skills,omitempty"`
	MCPServers []string `json:"mcpServers,omitempty"`
	// Fields carries the complete parsed frontmatter, including keys with no dedicated field.
	Fields map[string]any `json:"fields,omitempty"`
}
// AgentPromptDefinition represents the parsed AGENT.md or AGENTS.md prompt file.
//
// For the legacy AGENTS.md file the loader stores the whole file content in
// both Raw and Body and leaves RawFrontmatter and Frontmatter at their zero
// values. For AGENT.md the value is produced by parseAgentPromptDefinition,
// which splits the content into frontmatter and body.
type AgentPromptDefinition struct {
	// Path is the file path the definition was read from (workspace joined with the file name).
	Path string `json:"path"`
	// Raw is the file content; for AGENTS.md it is the unmodified file text.
	Raw string `json:"raw"`
	// Body is the prompt text that gets rendered into the bootstrap section.
	Body string `json:"body"`
	// RawFrontmatter is presumably the unparsed frontmatter text — confirm in parseAgentPromptDefinition.
	RawFrontmatter string `json:"raw_frontmatter,omitempty"`
	// Frontmatter is the structured form of the frontmatter.
	Frontmatter AgentFrontmatter `json:"frontmatter"`
}
// SoulDefinition represents the resolved SOUL.md file linked to the agent.
//
// The loader only creates one when SOUL.md in the workspace could be read.
type SoulDefinition struct {
	// Path is the workspace directory joined with "SOUL.md".
	Path string `json:"path"`
	// Content is the file content converted to a string, unmodified.
	Content string `json:"content"`
}
// UserDefinition represents the resolved USER.md file linked to the workspace.
//
// The loader only creates one when USER.md in the workspace could be read.
type UserDefinition struct {
	// Path is the workspace directory joined with "USER.md".
	Path string `json:"path"`
	// Content is the file content converted to a string, unmodified.
	Content string `json:"content"`
}
// AgentContextDefinition captures the workspace agent definition in a runtime-friendly shape.
//
// Every pointer field is nil when the corresponding file could not be read.
type AgentContextDefinition struct {
	// Source names the file Agent came from; it is empty when neither AGENT.md nor AGENTS.md was readable.
	Source AgentDefinitionSource `json:"source,omitempty"`
	// Agent is the parsed AGENT.md, or the legacy AGENTS.md content when AGENT.md is absent.
	Agent *AgentPromptDefinition `json:"agent,omitempty"`
	// Soul is the workspace SOUL.md, if readable.
	Soul *SoulDefinition `json:"soul,omitempty"`
	// User is the workspace USER.md, if readable; it is loaded independently of Source.
	User *UserDefinition `json:"user,omitempty"`
}
// LoadAgentDefinition parses the workspace agent bootstrap files.
//
// It prefers the new AGENT.md format and its paired SOUL.md file. When the
// structured files are absent, it falls back to the legacy AGENTS.md layout so
// the current runtime can transition incrementally.
//
// The method delegates to loadAgentDefinition with the builder's workspace
// directory. No error is returned: a file that cannot be read simply leaves
// the corresponding part of the result unset.
func (cb *ContextBuilder) LoadAgentDefinition() AgentContextDefinition {
	return loadAgentDefinition(cb.workspace)
}
// loadAgentDefinition reads the agent bootstrap files that live directly under
// workspace and assembles them into an AgentContextDefinition.
//
// Resolution rules:
//   - USER.md and SOUL.md are always attempted, whichever prompt file is found.
//   - AGENT.md wins when it is readable; its optional frontmatter is parsed and
//     stripped from the prompt body.
//   - Otherwise AGENTS.md, when readable, is loaded verbatim as the legacy
//     prompt (no frontmatter handling).
//
// A file that cannot be read is treated as absent: the matching field stays
// nil (and Source stays empty when no prompt file was found). No error is
// returned.
func loadAgentDefinition(workspace string) AgentContextDefinition {
	definition := AgentContextDefinition{
		User: loadUserDefinition(workspace),
		Soul: loadSoulDefinition(workspace),
	}

	agentPath := filepath.Join(workspace, string(AgentDefinitionSourceAgent))
	if content, err := os.ReadFile(agentPath); err == nil {
		prompt := parseAgentPromptDefinition(agentPath, string(content))
		definition.Source = AgentDefinitionSourceAgent
		definition.Agent = &prompt
		return definition
	}

	legacyPath := filepath.Join(workspace, string(AgentDefinitionSourceAgents))
	if content, err := os.ReadFile(legacyPath); err == nil {
		raw := string(content)
		definition.Source = AgentDefinitionSourceAgents
		definition.Agent = &AgentPromptDefinition{
			Path: legacyPath,
			Raw:  raw,
			// Legacy files carry no frontmatter, so the body is the whole file.
			Body: raw,
		}
	}
	return definition
}

// loadSoulDefinition reads SOUL.md from workspace and returns nil when the
// file cannot be read. Reading directly (rather than stat-then-read) avoids a
// redundant syscall and a check-then-use race.
func loadSoulDefinition(workspace string) *SoulDefinition {
	soulPath := filepath.Join(workspace, "SOUL.md")
	content, err := os.ReadFile(soulPath)
	if err != nil {
		return nil
	}
	return &SoulDefinition{
		Path:    soulPath,
		Content: string(content),
	}
}
// trackedPaths returns the de-duplicated bootstrap file paths under workspace
// that are relevant to this definition.
//
// AGENT.md, SOUL.md and USER.md are always listed. AGENTS.md and IDENTITY.md
// are added only when the definition was not loaded from a structured
// AGENT.md, i.e. for the legacy layout or when no prompt file was found.
func (definition AgentContextDefinition) trackedPaths(workspace string) []string {
	names := []string{string(AgentDefinitionSourceAgent), "SOUL.md", "USER.md"}
	if definition.Source != AgentDefinitionSourceAgent {
		names = append(names, string(AgentDefinitionSourceAgents), "IDENTITY.md")
	}

	tracked := make([]string, 0, len(names))
	for _, name := range names {
		tracked = append(tracked, filepath.Join(workspace, name))
	}
	return uniquePaths(tracked)
}
// loadUserDefinition reads USER.md from workspace. It returns nil when the
// file cannot be read for any reason.
func loadUserDefinition(workspace string) *UserDefinition {
	location := filepath.Join(workspace, "USER.md")
	data, err := os.ReadFile(location)
	if err != nil {
		return nil
	}
	return &UserDefinition{Path: location, Content: string(data)}
}
// parseAgentPromptDefinition splits content into frontmatter and body and
// decodes the frontmatter. path is recorded on the result and passed on to
// the frontmatter decoder; the file is not read here.
func parseAgentPromptDefinition(path, content string) AgentPromptDefinition {
	rawFrontmatter, promptBody := splitAgentFrontmatter(content)

	parsed := AgentPromptDefinition{Path: path, Raw: content}
	parsed.Body = promptBody
	parsed.RawFrontmatter = rawFrontmatter
	parsed.Frontmatter = parseAgentFrontmatter(path, rawFrontmatter)
	return parsed
}
// parseAgentFrontmatter decodes the YAML frontmatter of an AGENT.md file.
//
// The text is decoded twice: once into a generic map (kept as Fields) and once
// into the known typed keys. Blank input yields a zero AgentFrontmatter. If
// either decode fails, a warning that includes path is logged and a zero
// AgentFrontmatter is returned, so callers never see partially decoded data.
func parseAgentFrontmatter(path, frontmatter string) AgentFrontmatter {
	trimmed := strings.TrimSpace(frontmatter)
	if trimmed == "" {
		return AgentFrontmatter{}
	}

	// warn reports a decode failure together with the offending file path.
	warn := func(message string, err error) {
		logger.WarnCF("agent", message, map[string]any{
			"path":  path,
			"error": err.Error(),
		})
	}
	// copyOf detaches a decoded list from the decoder's backing array.
	copyOf := func(values []string) []string {
		return append([]string(nil), values...)
	}

	allFields := make(map[string]any)
	if err := yaml.Unmarshal([]byte(trimmed), &allFields); err != nil {
		warn("Failed to parse AGENT.md frontmatter", err)
		return AgentFrontmatter{}
	}

	var known struct {
		Name        string   `yaml:"name"`
		Description string   `yaml:"description"`
		Tools       []string `yaml:"tools"`
		Model       string   `yaml:"model"`
		MaxTurns    *int     `yaml:"maxTurns"`
		Skills      []string `yaml:"skills"`
		MCPServers  []string `yaml:"mcpServers"`
	}
	if err := yaml.Unmarshal([]byte(trimmed), &known); err != nil {
		warn("Failed to decode AGENT.md frontmatter fields", err)
		return AgentFrontmatter{}
	}

	result := AgentFrontmatter{
		Name:        strings.TrimSpace(known.Name),
		Description: strings.TrimSpace(known.Description),
		Model:       strings.TrimSpace(known.Model),
		MaxTurns:    known.MaxTurns,
		Fields:      allFields,
	}
	result.Tools = copyOf(known.Tools)
	result.Skills = copyOf(known.Skills)
	result.MCPServers = copyOf(known.MCPServers)
	return result
}
// splitAgentFrontmatter separates a leading "---" delimited frontmatter block
// from the remaining body.
//
// Frontmatter is recognised only when the very first line is exactly "---"
// and a later line is exactly "---". In that case the lines in between are
// returned as frontmatter and everything after the closing delimiter, minus
// leading blank lines, as body. Otherwise frontmatter is empty and body is
// the original, untouched content.
//
// Line splitting happens after parser.NormalizeNewlines (presumably CRLF/CR
// to LF — confirm against the parser package), so a body returned alongside
// frontmatter uses the normalized line endings.
func splitAgentFrontmatter(content string) (frontmatter, body string) {
	const delimiter = "---"

	lines := strings.Split(string(parser.NormalizeNewlines([]byte(content))), "\n")
	if len(lines) == 0 || lines[0] != delimiter {
		return "", content
	}

	// Search after the opening delimiter; offset is relative to lines[1:].
	offset := slices.Index(lines[1:], delimiter)
	if offset < 0 {
		return "", content
	}
	closing := offset + 1

	frontmatter = strings.Join(lines[1:closing], "\n")
	body = strings.TrimLeft(strings.Join(lines[closing+1:], "\n"), "\n")
	return frontmatter, body
}
// relativeWorkspacePath renders path for display relative to workspace.
//
// It returns:
//   - "" when path is empty or only whitespace;
//   - the slash-separated path relative to workspace when path lies strictly
//     inside the workspace;
//   - the cleaned path itself otherwise (path is the workspace, lies outside
//     it, or cannot be made relative).
//
// "Outside" is decided by looking for ".." as a whole leading path element,
// so workspace entries whose names merely begin with two dots (for example
// "..hidden.md") are still reported relative to the workspace.
func relativeWorkspacePath(workspace, path string) string {
	if strings.TrimSpace(path) == "" {
		return ""
	}

	relativePath, err := filepath.Rel(workspace, path)
	if err != nil || relativePath == "." {
		return filepath.Clean(path)
	}

	// filepath.Rel returns a cleaned path, so any escape from the workspace
	// shows up as a leading ".." element.
	parentPrefix := ".." + string(filepath.Separator)
	if relativePath == ".." || strings.HasPrefix(relativePath, parentPrefix) {
		return filepath.Clean(path)
	}
	return filepath.ToSlash(relativePath)
}
// uniquePaths returns paths cleaned with filepath.Clean, with blank entries
// and duplicates (compared after cleaning) removed. The first occurrence of
// each path keeps its position. The result is never nil.
func uniquePaths(paths []string) []string {
	unique := make([]string, 0, len(paths))
	for _, candidate := range paths {
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		if normalized := filepath.Clean(candidate); !slices.Contains(unique, normalized) {
			unique = append(unique, normalized)
		}
	}
	return unique
}
// fileExists reports whether os.Stat succeeds for path. Any Stat error,
// not only "does not exist", yields false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
+302
View File
@@ -0,0 +1,302 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestLoadAgentDefinitionParsesFrontmatterAndSoul verifies that a structured
// AGENT.md is selected as the source, that its typed frontmatter keys and the
// raw Fields map are populated, that the body is kept, and that the workspace
// SOUL.md is loaded next to it.
func TestLoadAgentDefinitionParsesFrontmatterAndSoul(t *testing.T) {
// setupWorkspace is defined elsewhere; presumably it writes these files into
// a fresh temporary directory and returns its path.
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": `---
name: pico
description: Structured agent
model: claude-3-7-sonnet
tools:
- shell
- search
maxTurns: 8
skills:
- review
- search-docs
mcpServers:
- github
metadata:
mode: strict
---
# Agent
Act directly and use tools first.
`,
"SOUL.md": "# Soul\nStay precise.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.Source != AgentDefinitionSourceAgent {
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgent, definition.Source)
}
if definition.Agent == nil {
t.Fatal("expected AGENT.md definition to be loaded")
}
if definition.Agent.Body == "" || !strings.Contains(definition.Agent.Body, "Act directly") {
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
}
// Typed frontmatter keys.
if definition.Agent.Frontmatter.Name != "pico" {
t.Fatalf("expected name to be parsed, got %q", definition.Agent.Frontmatter.Name)
}
if definition.Agent.Frontmatter.Model != "claude-3-7-sonnet" {
t.Fatalf("expected model to be parsed, got %q", definition.Agent.Frontmatter.Model)
}
if len(definition.Agent.Frontmatter.Tools) != 2 {
t.Fatalf("expected tools to be parsed, got %v", definition.Agent.Frontmatter.Tools)
}
if definition.Agent.Frontmatter.MaxTurns == nil || *definition.Agent.Frontmatter.MaxTurns != 8 {
t.Fatalf("expected maxTurns to be parsed, got %v", definition.Agent.Frontmatter.MaxTurns)
}
if len(definition.Agent.Frontmatter.Skills) != 2 {
t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
}
if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
}
// Keys without a dedicated struct field must still be reachable via Fields.
if definition.Agent.Frontmatter.Fields["metadata"] == nil {
t.Fatal("expected arbitrary frontmatter fields to remain available")
}
if definition.Soul == nil {
t.Fatal("expected SOUL.md to be loaded")
}
if !strings.Contains(definition.Soul.Content, "Stay precise") {
t.Fatalf("expected soul content to be loaded, got %q", definition.Soul.Content)
}
if definition.Soul.Path != filepath.Join(tmpDir, "SOUL.md") {
t.Fatalf("expected default SOUL.md path, got %q", definition.Soul.Path)
}
}
// TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown verifies that, with
// no AGENT.md present, AGENTS.md is loaded as the source without frontmatter
// handling and that SOUL.md is still picked up.
func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENTS.md": "# Legacy Agent\nKeep compatibility.",
"SOUL.md": "# Soul\nLegacy soul.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.Source != AgentDefinitionSourceAgents {
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgents, definition.Source)
}
if definition.Agent == nil {
t.Fatal("expected AGENTS.md to be loaded")
}
if definition.Agent.RawFrontmatter != "" {
t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
}
if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
}
if definition.Soul == nil || !strings.Contains(definition.Soul.Content, "Legacy soul") {
t.Fatal("expected default SOUL.md to be loaded for legacy format")
}
}
// TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown verifies that USER.md in
// the workspace root is loaded, with its path and content, alongside a
// structured AGENT.md.
func TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nStructured agent.",
"USER.md": "# User\nWorkspace preferences.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.User == nil {
t.Fatal("expected USER.md to be loaded")
}
if definition.User.Path != filepath.Join(tmpDir, "USER.md") {
t.Fatalf("expected workspace USER.md path, got %q", definition.User.Path)
}
if !strings.Contains(definition.User.Content, "Workspace preferences") {
t.Fatalf("expected workspace USER.md content, got %q", definition.User.Content)
}
}
// TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields
// verifies that frontmatter which fails to decode does not prevent AGENT.md
// from loading: the body is kept and every frontmatter field, including
// Fields, is left at its zero value.
func TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": `---
name: pico
tools:
- shell
broken
---
# Agent
Keep going.
`,
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.Agent == nil {
t.Fatal("expected AGENT.md definition to be loaded")
}
if !strings.Contains(definition.Agent.Body, "Keep going.") {
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
}
// No partially decoded value may leak out of a failed parse.
if definition.Agent.Frontmatter.Name != "" ||
definition.Agent.Frontmatter.Description != "" ||
definition.Agent.Frontmatter.Model != "" ||
definition.Agent.Frontmatter.MaxTurns != nil ||
len(definition.Agent.Frontmatter.Tools) != 0 ||
len(definition.Agent.Frontmatter.Skills) != 0 ||
len(definition.Agent.Frontmatter.MCPServers) != 0 ||
len(definition.Agent.Frontmatter.Fields) != 0 {
t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
}
}
// TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter verifies that the
// bootstrap text for a structured agent contains the AGENT.md body and the
// labelled SOUL.md content, but neither the raw frontmatter nor IDENTITY.md.
func TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": `---
name: pico
model: codex-mini
---
# Agent
Follow the body prompt.
`,
"SOUL.md": "# Soul\nSpeak plainly.",
"IDENTITY.md": "# Identity\nWorkspace identity.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
bootstrap := cb.LoadBootstrapFiles()
if !strings.Contains(bootstrap, "Follow the body prompt") {
t.Fatalf("expected AGENT.md body in bootstrap, got %q", bootstrap)
}
if !strings.Contains(bootstrap, "Speak plainly") {
t.Fatalf("expected resolved soul content in bootstrap, got %q", bootstrap)
}
// Frontmatter is configuration and must not appear in the prompt text.
if strings.Contains(bootstrap, "name: pico") {
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
}
if strings.Contains(bootstrap, "model: codex-mini") {
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
}
if !strings.Contains(bootstrap, "SOUL.md") {
t.Fatalf("expected bootstrap to label SOUL.md, got %q", bootstrap)
}
if strings.Contains(bootstrap, "Workspace identity") {
t.Fatalf("structured bootstrap should ignore IDENTITY.md, got %q", bootstrap)
}
}
// TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown verifies that the
// bootstrap text includes the USER.md content under a "## USER.md" heading.
func TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nFollow the new structure.",
"SOUL.md": "# Soul\nSpeak plainly.",
"USER.md": "# User\nShared profile.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
bootstrap := cb.LoadBootstrapFiles()
if !strings.Contains(bootstrap, "Shared profile") {
t.Fatalf("expected workspace USER.md in bootstrap, got %q", bootstrap)
}
if !strings.Contains(bootstrap, "## USER.md") {
t.Fatalf("expected USER.md heading in bootstrap, got %q", bootstrap)
}
}
// TestStructuredAgentIgnoresIdentityChanges verifies that, for a structured
// AGENT.md workspace, IDENTITY.md is neither part of the system prompt nor a
// trigger for cache invalidation when it is modified.
func TestStructuredAgentIgnoresIdentityChanges(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nFollow the new structure.",
"SOUL.md": "# Soul\nVersion one.",
"IDENTITY.md": "# Identity\nLegacy identity.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
promptV1 := cb.BuildSystemPromptWithCache()
if strings.Contains(promptV1, "Legacy identity") {
t.Fatalf("structured prompt should not include IDENTITY.md, got %q", promptV1)
}
identityPath := filepath.Join(tmpDir, "IDENTITY.md")
if err := os.WriteFile(identityPath, []byte("# Identity\nVersion two."), 0o644); err != nil {
t.Fatal(err)
}
// Move the mtime clearly past the time the prompt was cached so the change
// would be noticed if the file were tracked.
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(identityPath, future, future); err != nil {
t.Fatal(err)
}
// The change check is made while holding the read lock.
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if changed {
t.Fatal("IDENTITY.md should not invalidate cache for structured agent definitions")
}
promptV2 := cb.BuildSystemPromptWithCache()
if promptV1 != promptV2 {
t.Fatal("structured prompt should remain stable after IDENTITY.md changes")
}
}
// TestStructuredAgentUserChangesInvalidateCache verifies that modifying
// USER.md is detected as a source change and that the rebuilt system prompt
// carries the updated content.
func TestStructuredAgentUserChangesInvalidateCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nFollow the new structure.",
"SOUL.md": "# Soul\nVersion one.",
"USER.md": "# User\nInitial workspace preferences.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
promptV1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(promptV1, "Initial workspace preferences") {
t.Fatalf("expected workspace USER.md in prompt, got %q", promptV1)
}
userPath := filepath.Join(tmpDir, "USER.md")
if err := os.WriteFile(userPath, []byte("# User\nUpdated workspace preferences."), 0o644); err != nil {
t.Fatal(err)
}
// Move the mtime clearly past the time the prompt was cached.
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(userPath, future, future); err != nil {
t.Fatal(err)
}
// The change check is made while holding the read lock.
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("workspace USER.md changes should invalidate cache")
}
promptV2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(promptV2, "Updated workspace preferences") {
t.Fatalf("expected updated workspace USER.md in prompt, got %q", promptV2)
}
}
// cleanupWorkspace removes the test workspace directory at path and fails the
// test if removal returns an error.
func cleanupWorkspace(t *testing.T, path string) {
t.Helper()
if err := os.RemoveAll(path); err != nil {
t.Fatalf("failed to clean up workspace %s: %v", path, err)
}
}
+121
View File
@@ -0,0 +1,121 @@
package agent
import (
"sync"
"sync/atomic"
"time"
)
// defaultEventSubscriberBuffer is the channel capacity Subscribe uses when the
// caller asks for a non-positive buffer size.
const defaultEventSubscriberBuffer = 16
// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
type EventSubscription struct {
// ID is the handle to pass to EventBus.Unsubscribe. It is zero when the
// subscription was requested from a bus that was already closed.
ID uint64
// C delivers events. It is closed by Unsubscribe and by EventBus.Close.
C <-chan Event
}
// eventSubscriber is the bus-side record of one subscription.
type eventSubscriber struct {
// ch is the send side of the channel exposed as EventSubscription.C.
ch chan Event
}
// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
type EventBus struct {
// mu guards subs, nextID and closed.
mu sync.RWMutex
// subs holds the active subscribers keyed by subscription ID.
subs map[uint64]eventSubscriber
// nextID is the most recently issued subscription ID.
nextID uint64
// closed is set by Close; afterwards Emit is a no-op and Subscribe returns
// an already-closed channel.
closed bool
// dropped counts, per event kind, deliveries skipped because a subscriber's
// channel was full.
dropped [eventKindCount]atomic.Int64
}
// NewEventBus returns an open, in-process event broadcaster with no
// subscribers.
func NewEventBus() *EventBus {
	eventBus := new(EventBus)
	eventBus.subs = make(map[uint64]eventSubscriber)
	return eventBus
}
// Subscribe registers a new subscriber with the requested channel buffer size.
// A non-positive buffer uses the default size.
//
// On a bus that has already been closed the returned subscription has ID 0
// and a channel that is already closed, so receivers unblock immediately.
//
// The subscriber map is created on first use, so an EventBus that was not
// built with NewEventBus (a zero value) can be subscribed to as well.
func (b *EventBus) Subscribe(buffer int) EventSubscription {
	if buffer <= 0 {
		buffer = defaultEventSubscriberBuffer
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		closedCh := make(chan Event)
		close(closedCh)
		return EventSubscription{C: closedCh}
	}

	if b.subs == nil {
		// Assigning into a nil map panics; reads and ranges elsewhere in the
		// bus already tolerate a nil map, so this is the only place to guard.
		b.subs = make(map[uint64]eventSubscriber)
	}

	b.nextID++
	id := b.nextID
	ch := make(chan Event, buffer)
	b.subs[id] = eventSubscriber{ch: ch}
	return EventSubscription{ID: id, C: ch}
}
// Unsubscribe drops the subscriber registered under id and closes its
// channel. Unknown IDs, including ones already unsubscribed, are ignored.
func (b *EventBus) Unsubscribe(id uint64) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if subscriber, found := b.subs[id]; found {
		delete(b.subs, id)
		close(subscriber.ch)
	}
}
// Emit delivers evt to every current subscriber without ever blocking.
//
// An event with a zero Time is stamped with the current time before delivery.
// If a subscriber's channel is full the event is skipped for that subscriber
// and, when the kind is within the counted range, the per-kind drop counter
// is incremented. Emitting on a closed bus does nothing.
func (b *EventBus) Emit(evt Event) {
	if evt.Time.IsZero() {
		evt.Time = time.Now()
	}

	b.mu.RLock()
	defer b.mu.RUnlock()

	if b.closed {
		return
	}

	countable := evt.Kind < eventKindCount
	for _, subscriber := range b.subs {
		select {
		case subscriber.ch <- evt:
			continue
		default:
		}
		// Reached only when the non-blocking send above could not proceed.
		if countable {
			b.dropped[evt.Kind].Add(1)
		}
	}
}
// Dropped reports how many deliveries of the given kind were skipped because
// a subscriber channel was full. Kinds outside the counted range report 0.
func (b *EventBus) Dropped(kind EventKind) int64 {
	if kind < eventKindCount {
		return b.dropped[kind].Load()
	}
	return 0
}
// Close marks the bus as closed, closes every subscriber channel and forgets
// all subscribers. Calling Close again has no effect.
func (b *EventBus) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if !b.closed {
		b.closed = true
		for id := range b.subs {
			close(b.subs[id].ch)
			delete(b.subs, id)
		}
	}
}
-12
View File
@@ -1,12 +0,0 @@
package agent
import "fmt"
// MockEventBus is a proof-of-concept stand-in for an event bus. Its Emit
// function prints the event's dynamic type and value to standard output.
var MockEventBus = struct {
Emit func(event any)
}{
Emit: func(event any) {
fmt.Printf("[Mock EventBus] %T %+v\n", event, event)
},
}
+684
View File
@@ -0,0 +1,684 @@
package agent
import (
"context"
"os"
"slices"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// TestEventBus_SubscribeEmitUnsubscribeClose walks one subscriber through the
// bus lifecycle: it receives an emitted event, sees its channel closed after
// Unsubscribe, and a Subscribe made after Close yields an already-closed
// channel.
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
eventBus := NewEventBus()
sub := eventBus.Subscribe(1)
eventBus.Emit(Event{
Kind: EventKindTurnStart,
Meta: EventMeta{TurnID: "turn-1"},
})
select {
case evt := <-sub.C:
if evt.Kind != EventKindTurnStart {
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
}
if evt.Meta.TurnID != "turn-1" {
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
eventBus.Unsubscribe(sub.ID)
// A receive on a closed channel returns immediately with ok == false.
if _, ok := <-sub.C; ok {
t.Fatal("expected subscriber channel to be closed after unsubscribe")
}
eventBus.Close()
closedSub := eventBus.Subscribe(1)
if _, ok := <-closedSub.C; ok {
t.Fatal("expected closed bus to return a closed subscriber channel")
}
}
// TestEventBus_DropsWhenSubscriberIsFull verifies that Emit does not block on
// a subscriber that never reads and that every undelivered event is counted.
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
eventBus := NewEventBus()
// Buffer of one: the first event is queued, the remaining 999 cannot be.
sub := eventBus.Subscribe(1)
defer eventBus.Unsubscribe(sub.ID)
start := time.Now()
for i := 0; i < 1000; i++ {
eventBus.Emit(Event{Kind: EventKindLLMRequest})
}
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
}
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
t.Fatalf("expected 999 dropped events, got %d", got)
}
}
// scriptedToolProvider is a test LLM provider with a fixed two-step script:
// the first Chat call requests the mock_custom tool, every later call returns
// the final text "done".
type scriptedToolProvider struct {
// calls counts Chat invocations. It is not synchronised.
calls int
}
// Chat ignores all of its arguments and answers according to how many times
// it has been called so far.
func (m *scriptedToolProvider) Chat(
ctx context.Context,
messages []providers.Message,
toolDefs []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
m.calls++
if m.calls == 1 {
return &providers.LLMResponse{
ToolCalls: []providers.ToolCall{
{
ID: "call-1",
Name: "mock_custom",
Arguments: map[string]any{"task": "ping"},
},
},
}, nil
}
return &providers.LLMResponse{
Content: "done",
}, nil
}
// GetDefaultModel returns a fixed placeholder model name.
func (m *scriptedToolProvider) GetDefaultModel() string {
return "scripted-tool-model"
}
// TestAgentLoop_EmitsMinimalTurnEvents runs one turn that makes a single tool
// call and checks the exact event sequence, that all events share the same
// turn ID and session key, and the payloads of the turn start, tool start,
// tool end and turn end events.
func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &scriptedToolProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(&mockCustomTool{})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Subscribe before running so no event of the turn is missed.
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if response != "done" {
t.Fatalf("expected final response 'done', got %q", response)
}
events := collectEventStream(sub.C)
if len(events) != 8 {
t.Fatalf("expected 8 events, got %d", len(events))
}
kinds := make([]EventKind, 0, len(events))
for _, evt := range events {
kinds = append(kinds, evt.Kind)
}
// Two LLM round trips: the first yields the tool call, the second the answer.
expectedKinds := []EventKind{
EventKindTurnStart,
EventKindLLMRequest,
EventKindLLMResponse,
EventKindToolExecStart,
EventKindToolExecEnd,
EventKindLLMRequest,
EventKindLLMResponse,
EventKindTurnEnd,
}
if !slices.Equal(kinds, expectedKinds) {
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
}
turnID := events[0].Meta.TurnID
for i, evt := range events {
if evt.Meta.TurnID != turnID {
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
}
if evt.Meta.SessionKey != "session-1" {
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
}
}
startPayload, ok := events[0].Payload.(TurnStartPayload)
if !ok {
t.Fatalf("expected TurnStartPayload, got %T", events[0].Payload)
}
if startPayload.UserMessage != "run tool" {
t.Fatalf("expected user message 'run tool', got %q", startPayload.UserMessage)
}
// Indices 3 and 4 are the tool start/end events per expectedKinds above.
toolStartPayload, ok := events[3].Payload.(ToolExecStartPayload)
if !ok {
t.Fatalf("expected ToolExecStartPayload, got %T", events[3].Payload)
}
if toolStartPayload.Tool != "mock_custom" {
t.Fatalf("expected tool name mock_custom, got %q", toolStartPayload.Tool)
}
toolEndPayload, ok := events[4].Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", events[4].Payload)
}
if toolEndPayload.Tool != "mock_custom" {
t.Fatalf("expected tool end payload for mock_custom, got %q", toolEndPayload.Tool)
}
if toolEndPayload.IsError {
t.Fatal("expected mock_custom tool to succeed")
}
turnEndPayload, ok := events[len(events)-1].Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", events[len(events)-1].Payload)
}
if turnEndPayload.Status != TurnEndStatusCompleted {
t.Fatalf("expected completed turn, got %q", turnEndPayload.Status)
}
if turnEndPayload.Iterations != 2 {
t.Fatalf("expected 2 iterations, got %d", turnEndPayload.Iterations)
}
}
// TestAgentLoop_EmitsSteeringAndSkippedToolEvents steers the loop while the
// first of two requested tools is running and checks that a steering-injected
// event, a skipped-tool event for the second tool and an interrupt-received
// event are emitted with the expected payloads.
func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-steering-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// tool1ExecCh is handed to the first tool; the test waits on it below to
// learn that tool_one has started (slowTool is defined elsewhere).
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "steered response",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
// Run the turn in the background so the test can steer it mid-flight.
resultCh := make(chan string, 1)
go func() {
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
resultCh <- resp
}()
select {
case <-tool1ExecCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
if err := al.Steer(providers.Message{Role: "user", Content: "change course"}); err != nil {
t.Fatalf("Steer failed: %v", err)
}
select {
case resp := <-resultCh:
if resp != "steered response" {
t.Fatalf("expected steered response, got %q", resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for steered response")
}
events := collectEventStream(sub.C)
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
if !ok {
t.Fatal("expected steering injected event")
}
steeringPayload, ok := steeringEvt.Payload.(SteeringInjectedPayload)
if !ok {
t.Fatalf("expected SteeringInjectedPayload, got %T", steeringEvt.Payload)
}
if steeringPayload.Count != 1 {
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
}
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
if !ok {
t.Fatal("expected skipped tool event")
}
skippedPayload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
}
if skippedPayload.Tool != "tool_two" {
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
}
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Role != "user" {
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
}
if interruptPayload.Kind != InterruptKindSteering {
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
}
if interruptPayload.ContentLen != len("change course") {
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
}
}
// TestAgentLoop_EmitsContextCompressEventOnRetry makes the provider fail once
// with a context-size error and checks that the loop recovers and emits both
// an LLM retry event (reason "context_limit", attempt 1) and a context
// compress event that records dropped messages.
func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-compress-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
successResp: "Recovered from context error",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Seed history so there are older messages available to drop.
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "Trigger message",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "Recovered from context error" {
t.Fatalf("expected retry success, got %q", resp)
}
events := collectEventStream(sub.C)
retryEvt, ok := findEvent(events, EventKindLLMRetry)
if !ok {
t.Fatal("expected llm retry event")
}
retryPayload, ok := retryEvt.Payload.(LLMRetryPayload)
if !ok {
t.Fatalf("expected LLMRetryPayload, got %T", retryEvt.Payload)
}
if retryPayload.Reason != "context_limit" {
t.Fatalf("expected context_limit retry reason, got %q", retryPayload.Reason)
}
if retryPayload.Attempt != 1 {
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
}
compressEvt, ok := findEvent(events, EventKindContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
payload, ok := compressEvt.Payload.(ContextCompressPayload)
if !ok {
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
}
if payload.Reason != ContextCompressReasonRetry {
t.Fatalf("expected retry compress reason, got %q", payload.Reason)
}
if payload.DroppedMessages == 0 {
t.Fatal("expected dropped messages to be recorded")
}
}
// TestAgentLoop_EmitsSessionSummarizeEvent calls summarizeSession directly on
// a seeded session and checks that a session summarize event with a non-zero
// summary length is emitted.
func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-summary-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextWindow: 8000,
SummarizeMessageThreshold: 2,
SummarizeTokenPercent: 75,
},
},
}
msgBus := bus.NewMessageBus()
// The mock provider's fixed response serves as the generated summary text.
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary text"})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
{Role: "user", Content: "Question one"},
{Role: "assistant", Content: "Answer one"},
{Role: "user", Content: "Question two"},
{Role: "assistant", Content: "Answer two"},
{Role: "user", Content: "Question three"},
{Role: "assistant", Content: "Answer three"},
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
al.summarizeSession(defaultAgent, "session-1", turnScope)
events := collectEventStream(sub.C)
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
if !ok {
t.Fatal("expected session summarize event")
}
payload, ok := summaryEvt.Payload.(SessionSummarizePayload)
if !ok {
t.Fatalf("expected SessionSummarizePayload, got %T", summaryEvt.Payload)
}
if payload.SummaryLen == 0 {
t.Fatal("expected non-empty summary length")
}
}
// TestAgentLoop_EmitsFollowUpQueuedEvent runs a turn that invokes an
// asynchronous tool and checks that the tool's later callback produces a
// follow-up queued event carrying the source tool, channel, chat ID, content
// length, session key and a turn ID.
func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-followup-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_async_1",
Type: "function",
Name: "async_followup",
Function: &providers.FunctionCall{
Name: "async_followup",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "async launched",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
// doneCh is closed by the tool once its async callback has returned.
doneCh := make(chan struct{})
al.RegisterTool(&asyncFollowUpTool{
name: "async_followup",
followUpText: "background result",
completionSig: doneCh,
})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run async tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "async launched" {
t.Fatalf("expected final response 'async launched', got %q", resp)
}
select {
case <-doneCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for async tool completion")
}
// Skip over the turn's other events until the follow-up event shows up.
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
return evt.Kind == EventKindFollowUpQueued
})
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
if !ok {
t.Fatalf("expected FollowUpQueuedPayload, got %T", followUpEvt.Payload)
}
if payload.SourceTool != "async_followup" {
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
}
if payload.Channel != "cli" {
t.Fatalf("expected channel cli, got %q", payload.Channel)
}
if payload.ChatID != "direct" {
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
}
if payload.ContentLen != len("background result") {
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
}
if followUpEvt.Meta.SessionKey != "session-1" {
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
}
if followUpEvt.Meta.TurnID == "" {
t.Fatal("expected follow-up event to include turn id")
}
}
// collectEventStream drains the events that are immediately available on ch
// without blocking. It returns as soon as the channel has nothing ready or is
// found closed; the result is nil when no event was read.
func collectEventStream(ch <-chan Event) []Event {
	var drained []Event
	for {
		select {
		case evt, open := <-ch:
			if open {
				drained = append(drained, evt)
				continue
			}
			return drained
		default:
			return drained
		}
	}
}
// waitForEvent reads from ch until match reports true and returns that event.
// Events that do not match are discarded. The test fails fatally if ch is
// closed first or if timeout elapses before a matching event arrives.
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
	t.Helper()

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			t.Fatal("timed out waiting for expected event")
		case candidate, open := <-ch:
			switch {
			case !open:
				t.Fatal("event stream closed before expected event arrived")
			case match(candidate):
				return candidate
			}
		}
	}
}
// findEvent returns the first event in events whose Kind equals kind.
// The boolean is false (and the Event is the zero value) when none matches.
func findEvent(events []Event, kind EventKind) (Event, bool) {
	for i := range events {
		if events[i].Kind != kind {
			continue
		}
		return events[i], true
	}
	return Event{}, false
}
// stringError is a minimal error implementation backed by a plain string.
type stringError string

// Error returns the underlying string, satisfying the error interface.
func (msg stringError) Error() string {
	text := string(msg)
	return text
}
// asyncFollowUpTool is a test double for an asynchronous tool: ExecuteAsync
// hands followUpText to the supplied callback from a background goroutine and
// then closes completionSig so the test can wait for delivery.
type asyncFollowUpTool struct {
	name          string        // value reported by Name
	followUpText  string        // ForLLM content passed to the async callback
	completionSig chan struct{} // closed after the callback returns; may be nil
}

// Name returns the configured tool name.
func (t *asyncFollowUpTool) Name() string { return t.name }

// Description returns a fixed human-readable description.
func (t *asyncFollowUpTool) Description() string {
	const description = "async follow-up tool for testing"
	return description
}

// Parameters returns an object schema with no properties.
func (t *asyncFollowUpTool) Parameters() map[string]any {
	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{}
	return schema
}

// Execute is the synchronous entry point; it only reports that the follow-up
// was scheduled and never invokes a callback.
func (t *asyncFollowUpTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	return tools.AsyncResult("async follow-up scheduled")
}

// ExecuteAsync starts a goroutine that invokes cb with the follow-up text and
// then closes completionSig (when set), and returns immediately.
func (t *asyncFollowUpTool) ExecuteAsync(
	ctx context.Context,
	args map[string]any,
	cb tools.AsyncCallback,
) *tools.ToolResult {
	deliver := func() {
		cb(ctx, &tools.ToolResult{ForLLM: t.followUpText})
		if t.completionSig == nil {
			return
		}
		close(t.completionSig)
	}
	go deliver()
	return tools.AsyncResult("async follow-up scheduled")
}
// Compile-time assertions: the build fails if mockCustomTool stops satisfying
// tools.Tool or asyncFollowUpTool stops satisfying tools.AsyncExecutor.
var (
_ tools.Tool = (*mockCustomTool)(nil)
_ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
)
+271
View File
@@ -0,0 +1,271 @@
package agent
import (
"fmt"
"time"
)
// EventKind identifies a structured agent-loop event.
type EventKind uint8

const (
	// EventKindTurnStart is emitted when a turn begins processing.
	EventKindTurnStart EventKind = iota
	// EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
	EventKindTurnEnd
	// EventKindLLMRequest is emitted before a provider chat request is made.
	EventKindLLMRequest
	// EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
	EventKindLLMDelta
	// EventKindLLMResponse is emitted after a provider chat response is received.
	EventKindLLMResponse
	// EventKindLLMRetry is emitted when an LLM request is retried.
	EventKindLLMRetry
	// EventKindContextCompress is emitted when session history is forcibly compressed.
	EventKindContextCompress
	// EventKindSessionSummarize is emitted when asynchronous summarization completes.
	EventKindSessionSummarize
	// EventKindToolExecStart is emitted immediately before a tool executes.
	EventKindToolExecStart
	// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
	EventKindToolExecEnd
	// EventKindToolExecSkipped is emitted when a queued tool call is skipped.
	EventKindToolExecSkipped
	// EventKindSteeringInjected is emitted when queued steering is injected into context.
	EventKindSteeringInjected
	// EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
	EventKindFollowUpQueued
	// EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
	EventKindInterruptReceived
	// EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
	EventKindSubTurnSpawn
	// EventKindSubTurnEnd is emitted when a sub-turn finishes.
	EventKindSubTurnEnd
	// EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
	EventKindSubTurnResultDelivered
	// EventKindSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
	EventKindSubTurnOrphan
	// EventKindError is emitted when a turn encounters an execution error.
	EventKindError

	// eventKindCount is the number of defined kinds. It must stay last so
	// that iota keeps it equal to the count of the constants above.
	eventKindCount
)

// eventKindNames maps each EventKind to its stable string form. Entries are
// positional: index i names the constant whose value is i, so the order here
// must mirror the const block above.
var eventKindNames = [...]string{
	"turn_start",
	"turn_end",
	"llm_request",
	"llm_delta",
	"llm_response",
	"llm_retry",
	"context_compress",
	"session_summarize",
	"tool_exec_start",
	"tool_exec_end",
	"tool_exec_skipped",
	"steering_injected",
	"follow_up_queued",
	"interrupt_received",
	"subturn_spawn",
	"subturn_end",
	"subturn_result_delivered",
	"subturn_orphan",
	"error",
}

// Compile-time guard: the constant index below is only in range when it is 0,
// so the build breaks if a kind is added without a name (or vice versa)
// instead of String panicking at runtime.
var _ = [1]struct{}{}[len(eventKindNames)-int(eventKindCount)]

// String returns the stable string form of an EventKind. Values outside the
// defined range are rendered as "event_kind(N)".
func (k EventKind) String() string {
	// The len check is belt-and-braces on top of the compile-time guard so an
	// out-of-table value can never index past the array.
	if k >= eventKindCount || int(k) >= len(eventKindNames) {
		return fmt.Sprintf("event_kind(%d)", k)
	}
	return eventKindNames[k]
}
// Event is the structured envelope broadcast by the agent EventBus.
type Event struct {
// Kind identifies which agent-loop event this envelope carries.
Kind EventKind
// Time is the timestamp attached to the event (presumably set when the
// event is emitted — confirm at the emit site).
Time time.Time
// Meta holds the correlation fields shared by every event kind.
Meta EventMeta
// Payload carries kind-specific data. Consumers type-assert it to the
// matching *Payload struct, e.g. FollowUpQueuedPayload for
// EventKindFollowUpQueued.
Payload any
}
// EventMeta contains correlation fields shared by all agent-loop events.
type EventMeta struct {
// AgentID names the agent the event belongs to.
AgentID string
// TurnID identifies the turn that produced the event.
TurnID string
// ParentTurnID presumably links a sub-turn to the turn that spawned it —
// TODO confirm against the sub-turn code.
ParentTurnID string
// SessionKey is the session the turn runs under.
SessionKey string
// Iteration is presumably the loop iteration within the turn — TODO confirm.
Iteration int
// NOTE(review): TracePath and Source are not populated anywhere in this
// file; see the emit sites for their meaning.
TracePath string
Source string
}
// TurnEndStatus describes the terminal state of a turn.
type TurnEndStatus string
const (
// TurnEndStatusCompleted indicates the turn finished normally.
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
// It is presumably the Payload of EventKindTurnStart events — confirm at the
// emit site.
type TurnStartPayload struct {
// Channel and ChatID identify where the turn's input came from.
Channel string
ChatID string
// UserMessage is the user text that started the turn.
UserMessage string
// MediaCount is presumably the number of media attachments on the
// inbound message — TODO confirm.
MediaCount int
}
// TurnEndPayload describes the completion of a turn.
// It is presumably the Payload of EventKindTurnEnd events — confirm at the
// emit site.
type TurnEndPayload struct {
// Status is the terminal state of the turn.
Status TurnEndStatus
// Iterations is presumably the number of loop iterations the turn ran —
// TODO confirm.
Iterations int
// Duration is the elapsed time reported for the turn.
Duration time.Duration
// FinalContentLen is the length of the final response content (the unit,
// bytes vs. runes, is not visible from this file).
FinalContentLen int
}
// LLMRequestPayload describes an outbound LLM request.
type LLMRequestPayload struct {
// Model is the model name used for the request.
Model string
// MessagesCount and ToolsCount are the number of messages and tool
// definitions included in the request.
MessagesCount int
ToolsCount int
// MaxTokens and Temperature are the sampling options for the request.
MaxTokens int
Temperature float64
}
// LLMResponsePayload describes an inbound LLM response.
type LLMResponsePayload struct {
// ContentLen is the length of the response content (unit not visible here).
ContentLen int
// ToolCalls is the number of tool calls in the response.
ToolCalls int
// HasReasoning reports whether the response carried reasoning content.
HasReasoning bool
}
// LLMDeltaPayload describes a streamed LLM delta.
type LLMDeltaPayload struct {
// ContentDeltaLen and ReasoningDeltaLen are the lengths of the content and
// reasoning fragments in this delta (unit not visible here).
ContentDeltaLen int
ReasoningDeltaLen int
}
// LLMRetryPayload describes a retry of an LLM request.
type LLMRetryPayload struct {
// Attempt is the retry attempt number and MaxRetries its upper bound
// (whether Attempt is 0- or 1-based is not visible here — TODO confirm).
Attempt int
MaxRetries int
// Reason is a short classification of why the request is being retried.
Reason string
// Error is the text of the error that triggered the retry.
Error string
// Backoff is the delay associated with this retry.
Backoff time.Duration
}
// ContextCompressReason identifies why emergency compression ran.
type ContextCompressReason string
const (
// ContextCompressReasonProactive indicates compression before the first LLM call.
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
// ContextCompressReasonRetry indicates compression during context-error retry handling.
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
)
// ContextCompressPayload describes a forced history compression.
type ContextCompressPayload struct {
// Reason records which code path triggered the compression.
Reason ContextCompressReason
// DroppedMessages and RemainingMessages are the message counts removed
// from and left in the history.
DroppedMessages int
RemainingMessages int
}
// SessionSummarizePayload describes a completed async session summarization.
type SessionSummarizePayload struct {
// SummarizedMessages and KeptMessages are the message counts folded into
// the summary and retained verbatim.
SummarizedMessages int
KeptMessages int
// SummaryLen is the length of the produced summary (unit not visible here).
SummaryLen int
// OmittedOversized presumably reports that oversized messages were left
// out of the summary input — TODO confirm against the summarizer.
OmittedOversized bool
}
// ToolExecStartPayload describes a tool execution request.
type ToolExecStartPayload struct {
// Tool is the name of the tool being executed.
Tool string
// Arguments are the call arguments passed to the tool.
Arguments map[string]any
}
// ToolExecEndPayload describes the outcome of a tool execution.
type ToolExecEndPayload struct {
// Tool is the name of the tool that ran.
Tool string
// Duration is the elapsed time reported for the execution.
Duration time.Duration
// ForLLMLen and ForUserLen are the lengths of the result text destined
// for the model and for the user (unit not visible here).
ForLLMLen int
ForUserLen int
// IsError reports that the tool result was an error.
IsError bool
// Async reports that the tool result was asynchronous.
Async bool
}
// ToolExecSkippedPayload describes a skipped tool call.
type ToolExecSkippedPayload struct {
// Tool is the name of the skipped tool; Reason explains why it was skipped.
Tool string
Reason string
}
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
type SteeringInjectedPayload struct {
// Count is the number of steering messages injected.
Count int
// TotalContentLen is the combined length of their content (unit not
// visible here).
TotalContentLen int
}
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
// SourceTool is the name of the tool that produced the follow-up.
SourceTool string
// Channel and ChatID identify the conversation the follow-up targets.
Channel string
ChatID string
// ContentLen is the length of the follow-up text; the tests in this
// package compare it against len() of the text, i.e. a byte count.
ContentLen int
}
// InterruptKind classifies the turn-control input reported by
// InterruptReceivedPayload.
type InterruptKind string
const (
// InterruptKindSteering is the "steering" interrupt kind.
InterruptKindSteering InterruptKind = "steering"
// InterruptKindGraceful is the "graceful" interrupt kind.
InterruptKindGraceful InterruptKind = "graceful"
// InterruptKindHard is the "hard_abort" interrupt kind.
InterruptKindHard InterruptKind = "hard_abort"
)
// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
// Kind is the class of interrupt that was accepted.
Kind InterruptKind
// Role is presumably the message role of the injected input — TODO confirm.
Role string
// ContentLen is the length of the interrupt content (unit not visible here).
ContentLen int
// QueueDepth is presumably the number of queued interrupts after this one
// was accepted — TODO confirm.
QueueDepth int
// NOTE(review): HintLen is not populated anywhere in this file; see the
// emit site for its meaning.
HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
type SubTurnSpawnPayload struct {
// AgentID is the agent associated with the child turn.
AgentID string
// Label is a descriptive label for the child turn.
Label string
// ParentTurnID identifies the turn that spawned the child.
ParentTurnID string
}
// SubTurnEndPayload describes the completion of a child turn.
type SubTurnEndPayload struct {
// AgentID is the agent associated with the child turn.
AgentID string
// Status is a free-form status string (the set of values is not visible
// in this file).
Status string
}
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
type SubTurnResultDeliveredPayload struct {
// TargetChannel and TargetChatID identify where the result was delivered.
TargetChannel string
TargetChatID string
// ContentLen is the length of the delivered content (unit not visible here).
ContentLen int
}
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
type SubTurnOrphanPayload struct {
// ParentTurnID and ChildTurnID identify the two turns involved.
ParentTurnID string
ChildTurnID string
// Reason explains why delivery was not possible.
Reason string
}
// ErrorPayload describes an execution error inside the agent loop.
type ErrorPayload struct {
// Stage names the part of the loop where the error occurred.
Stage string
// Message is the error text.
Message string
}
+317
View File
@@ -0,0 +1,317 @@
package agent
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
// hookRuntime tracks the one-time, config-driven hook initialization of an
// agent loop: whether it has run, the error it produced, and the names of the
// hooks it mounted.
type hookRuntime struct {
	initOnce sync.Once  // ensures configured hooks are loaded at most once
	mu       sync.Mutex // guards initErr and mounted
	initErr  error      // result of the initialization run
	mounted  []string   // names of the hooks mounted by that run
}

// setInitErr records the outcome of hook initialization.
func (r *hookRuntime) setInitErr(err error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.initErr = err
}

// getInitErr returns the recorded initialization outcome.
func (r *hookRuntime) getInitErr() error {
	r.mu.Lock()
	recorded := r.initErr
	r.mu.Unlock()
	return recorded
}

// setMounted stores a private copy of names, so later changes to the
// caller's slice do not affect the runtime.
func (r *hookRuntime) setMounted(names []string) {
	snapshot := append([]string(nil), names...)

	r.mu.Lock()
	defer r.mu.Unlock()
	r.mounted = snapshot
}
// reset returns the runtime to its pristine state (no mounted hooks, no
// recorded error, initialization allowed to run again) and then unmounts every
// hook that had been recorded as mounted. Unmounting happens after the lock is
// released.
func (r *hookRuntime) reset(al *AgentLoop) {
	r.mu.Lock()
	previous := append([]string(nil), r.mounted...)
	r.mounted, r.initErr = nil, nil
	r.initOnce = sync.Once{}
	r.mu.Unlock()

	for i := range previous {
		al.UnmountHook(previous[i])
	}
}
// BuiltinHookFactory constructs an in-process hook from config.
// The returned value is passed as HookRegistration.Hook when the hook is
// mounted; a non-nil error aborts loading of the configured hooks.
type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
var (
// builtinHookRegistryMu guards builtinHookRegistry.
builtinHookRegistryMu sync.RWMutex
// builtinHookRegistry maps a builtin hook name to the factory that builds it.
builtinHookRegistry = map[string]BuiltinHookFactory{}
)
// RegisterBuiltinHook adds factory to the process-wide registry under name so
// that configuration can mount it by that name. It returns an error when name
// is empty, factory is nil, or the name is already taken.
func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
	switch {
	case name == "":
		return fmt.Errorf("builtin hook name is required")
	case factory == nil:
		return fmt.Errorf("builtin hook %q factory is nil", name)
	}

	builtinHookRegistryMu.Lock()
	defer builtinHookRegistryMu.Unlock()

	_, taken := builtinHookRegistry[name]
	if taken {
		return fmt.Errorf("builtin hook %q is already registered", name)
	}
	builtinHookRegistry[name] = factory
	return nil
}
// unregisterBuiltinHook removes name from the builtin hook registry. An empty
// or unknown name is a no-op.
func unregisterBuiltinHook(name string) {
	if name != "" {
		builtinHookRegistryMu.Lock()
		defer builtinHookRegistryMu.Unlock()
		delete(builtinHookRegistry, name)
	}
}
// lookupBuiltinHook returns the factory registered under name and whether one
// exists.
func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
	builtinHookRegistryMu.RLock()
	registered, found := builtinHookRegistry[name]
	builtinHookRegistryMu.RUnlock()
	return registered, found
}
// configureHookManagerFromConfig applies the observer, interceptor and
// approval timeouts from cfg.Hooks.Defaults to hm. It does nothing when either
// argument is nil.
func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
	if hm == nil || cfg == nil {
		return
	}

	observer := hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS)
	interceptor := hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS)
	approval := hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS)
	hm.ConfigureTimeouts(observer, interceptor, approval)
}
// hookTimeoutFromMS converts a millisecond count into a time.Duration.
// Zero and negative inputs yield 0.
func hookTimeoutFromMS(ms int) time.Duration {
	if ms > 0 {
		return time.Millisecond * time.Duration(ms)
	}
	return 0
}
// ensureHooksInitialized loads the configured hooks the first time it is
// called and returns the outcome of that load on every call. It is a no-op
// returning nil when the loop, its config, or its hook manager is nil.
func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
	if al == nil {
		return nil
	}
	if al.cfg == nil || al.hooks == nil {
		return nil
	}

	loadOnce := func() {
		outcome := al.loadConfiguredHooks(ctx)
		al.hookRuntime.setInitErr(outcome)
	}
	al.hookRuntime.initOnce.Do(loadOnce)

	return al.hookRuntime.getInitErr()
}
// loadConfiguredHooks mounts every enabled builtin hook and then every enabled
// process hook from al.cfg.Hooks, each group in sorted name order. It returns
// nil without doing anything when the loop or config is nil or hooks are
// disabled.
//
// Loading is all-or-nothing: on the first failure every hook mounted so far is
// unmounted again and the error is returned; on success the mounted names are
// recorded on al.hookRuntime.
func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
return nil
}
mounted := make([]string, 0)
// The named result lets this deferred function see the final error no
// matter which return statement produced it, including the returns inside
// the `if err := ...` scopes below that shadow err.
defer func() {
if err != nil {
// Roll back: unmount whatever was mounted before the failure.
for _, name := range mounted {
al.UnmountHook(name)
}
return
}
al.hookRuntime.setMounted(mounted)
}()
builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
for _, name := range builtinNames {
spec := al.cfg.Hooks.Builtins[name]
factory, ok := lookupBuiltinHook(name)
if !ok {
return fmt.Errorf("builtin hook %q is not registered", name)
}
hook, factoryErr := factory(ctx, spec)
if factoryErr != nil {
return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
}
if err := al.MountHook(HookRegistration{
Name: name,
Priority: spec.Priority,
Source: HookSourceInProcess,
Hook: hook,
}); err != nil {
return fmt.Errorf("mount builtin hook %q: %w", name, err)
}
mounted = append(mounted, name)
}
processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
for _, name := range processNames {
spec := al.cfg.Hooks.Processes[name]
opts, buildErr := processHookOptionsFromConfig(spec)
if buildErr != nil {
return fmt.Errorf("configure process hook %q: %w", name, buildErr)
}
processHook, buildErr := NewProcessHook(ctx, name, opts)
if buildErr != nil {
return fmt.Errorf("start process hook %q: %w", name, buildErr)
}
if err := al.MountHook(HookRegistration{
Name: name,
Priority: spec.Priority,
Source: HookSourceProcess,
Hook: processHook,
}); err != nil {
// The child process is already running but was never mounted, so the
// deferred rollback will not reach it; stop it here.
_ = processHook.Close()
return fmt.Errorf("mount process hook %q: %w", name, err)
}
mounted = append(mounted, name)
}
return nil
}
// enabledBuiltinHookNames returns the names of the enabled builtin hook specs
// in ascending order. It returns nil for an empty map and an empty, non-nil
// slice when the map has entries but none is enabled.
func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
	if len(specs) == 0 {
		return nil
	}

	enabled := make([]string, 0, len(specs))
	for name := range specs {
		if !specs[name].Enabled {
			continue
		}
		enabled = append(enabled, name)
	}
	sort.Strings(enabled)
	return enabled
}
// enabledProcessHookNames returns the names of the enabled process hook specs
// in ascending order. It returns nil for an empty map and an empty, non-nil
// slice when the map has entries but none is enabled.
func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
	if len(specs) == 0 {
		return nil
	}

	enabled := make([]string, 0, len(specs))
	for name := range specs {
		if !specs[name].Enabled {
			continue
		}
		enabled = append(enabled, name)
	}
	sort.Strings(enabled)
	return enabled
}
// processHookOptionsFromConfig validates a process hook spec and translates it
// into ProcessHookOptions.
//
// Checks run in this order, and the first failure is returned: the transport
// must be empty or "stdio"; the command must be non-empty; every observe entry
// must be valid; every intercept entry must be one of before_llm, after_llm,
// before_tool, after_tool, approve_tool (blank entries are ignored); and at
// least one mode (observe, LLM/tool interception, approval) must end up
// enabled.
func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
	var none ProcessHookOptions

	// An unset transport defaults to stdio, the only supported one.
	if transport := spec.Transport; transport != "" && transport != "stdio" {
		return none, fmt.Errorf("unsupported transport %q", transport)
	}
	if len(spec.Command) == 0 {
		return none, fmt.Errorf("command is required")
	}

	kinds, observing, err := processHookObserveKindsFromConfig(spec.Observe)
	if err != nil {
		return none, err
	}

	opts := ProcessHookOptions{
		// Copy the command so the options do not alias the config slice.
		Command:      append([]string(nil), spec.Command...),
		Dir:          spec.Dir,
		Env:          processHookEnvFromMap(spec.Env),
		Observe:      observing,
		ObserveKinds: kinds,
	}

	for _, point := range spec.Intercept {
		switch point {
		case "":
			// Blank entries are tolerated and ignored.
		case "before_llm", "after_llm":
			opts.InterceptLLM = true
		case "before_tool", "after_tool":
			opts.InterceptTool = true
		case "approve_tool":
			opts.ApproveTool = true
		default:
			return none, fmt.Errorf("unsupported intercept %q", point)
		}
	}

	anyMode := opts.Observe || opts.InterceptLLM || opts.InterceptTool || opts.ApproveTool
	if !anyMode {
		return none, fmt.Errorf("no hook modes enabled")
	}
	return opts, nil
}
// processHookEnvFromMap renders envMap as "KEY=value" strings sorted by key,
// giving a deterministic environment list. An empty or nil map yields nil.
func processHookEnvFromMap(envMap map[string]string) []string {
	n := len(envMap)
	if n == 0 {
		return nil
	}

	names := make([]string, 0, n)
	for name := range envMap {
		names = append(names, name)
	}
	sort.Strings(names)

	pairs := make([]string, n)
	for i, name := range names {
		pairs[i] = name + "=" + envMap[name]
	}
	return pairs
}
// processHookObserveKindsFromConfig interprets the observe list of a process
// hook spec. It returns the event-kind filter, whether observing is enabled,
// and a validation error.
//
//   - An empty list disables observing.
//   - The first "", "*" or "all" entry enables observing of every kind (nil
//     filter) and ends processing immediately.
//   - Any other entry must be a known event kind name; otherwise an error is
//     returned.
//
// NOTE(review): because a wildcard returns early, entries after it are never
// validated, so acceptance of a list mixing wildcards and unknown names
// depends on their order.
func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
	if len(observe) == 0 {
		return nil, false, nil
	}

	known := validHookEventKinds()
	selected := make([]string, 0, len(observe))
	for _, kind := range observe {
		if kind == "" || kind == "*" || kind == "all" {
			return nil, true, nil
		}
		if _, valid := known[kind]; !valid {
			return nil, false, fmt.Errorf("unsupported observe event %q", kind)
		}
		selected = append(selected, kind)
	}

	if len(selected) == 0 {
		return nil, false, nil
	}
	return selected, true, nil
}
// validHookEventKinds returns the set of string names of every defined
// EventKind (all values below eventKindCount). A fresh map is built per call.
func validHookEventKinds() map[string]struct{} {
	total := int(eventKindCount)
	names := make(map[string]struct{}, total)
	for i := 0; i < total; i++ {
		names[EventKind(i).String()] = struct{}{}
	}
	return names
}
+179
View File
@@ -0,0 +1,179 @@
package agent
import (
"context"
"encoding/json"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// builtinAutoHookConfig is the JSON configuration decoded by the test's
// builtin hook factory and used to populate a builtinAutoHook.
type builtinAutoHookConfig struct {
// Model becomes builtinAutoHook.model.
Model string `json:"model"`
// Suffix becomes builtinAutoHook.suffix.
Suffix string `json:"suffix"`
}
// builtinAutoHook is a test LLM interceptor: it forces every request onto a
// fixed model and appends a fixed suffix to every response's content.
type builtinAutoHook struct {
	model  string // model name written into each outgoing request
	suffix string // text appended to each response's content
}

// BeforeLLM returns a clone of req with Model replaced by h.model and asks the
// hook manager to apply the modification.
func (h *builtinAutoHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	rewritten := req.Clone()
	rewritten.Model = h.model

	decision := HookDecision{Action: HookActionModify}
	return rewritten, decision, nil
}

// AfterLLM returns a clone of resp with h.suffix appended to the response
// content (when a response is present) and asks the hook manager to apply the
// modification.
func (h *builtinAutoHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	rewritten := resp.Clone()
	if rewritten.Response != nil {
		rewritten.Response.Content += h.suffix
	}

	decision := HookDecision{Action: HookActionModify}
	return rewritten, decision, nil
}
// newConfiguredHookLoop builds an AgentLoop for hook tests: a fresh temp
// workspace, fixed model defaults, the supplied hooks config, a new message
// bus, and the given provider.
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
	t.Helper()

	defaults := config.AgentDefaults{
		Workspace:         t.TempDir(),
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{
		Agents: config.AgentsConfig{Defaults: defaults},
		Hooks:  hooks,
	}
	msgBus := bus.NewMessageBus()
	return NewAgentLoop(cfg, msgBus, provider)
}
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
const hookName = "test-auto-builtin-hook"
if err := RegisterBuiltinHook(hookName, func(
ctx context.Context,
spec config.BuiltinHookConfig,
) (any, error) {
var hookCfg builtinAutoHookConfig
if len(spec.Config) > 0 {
if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
return nil, err
}
}
return &builtinAutoHook{
model: hookCfg.Model,
suffix: hookCfg.Suffix,
}, nil
}); err != nil {
t.Fatalf("RegisterBuiltinHook failed: %v", err)
}
t.Cleanup(func() {
unregisterBuiltinHook(hookName)
})
rawCfg, err := json.Marshal(builtinAutoHookConfig{
Model: "builtin-model",
Suffix: "|builtin",
})
if err != nil {
t.Fatalf("json.Marshal failed: %v", err)
}
provider := &llmHookTestProvider{}
al := newConfiguredHookLoop(t, provider, config.HooksConfig{
Enabled: true,
Builtins: map[string]config.BuiltinHookConfig{
hookName: {
Enabled: true,
Config: rawCfg,
},
},
})
defer al.Close()
resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
if err != nil {
t.Fatalf("ProcessDirectWithChannel failed: %v", err)
}
if resp != "provider content|builtin" {
t.Fatalf("expected builtin-hooked content, got %q", resp)
}
provider.mu.Lock()
lastModel := provider.lastModel
provider.mu.Unlock()
if lastModel != "builtin-model" {
t.Fatalf("expected builtin model, got %q", lastModel)
}
}
// TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook verifies that a
// process hook enabled in config is started and mounted automatically: the
// helper process rewrites the model and the response content, and it records
// the observed turn_end event in the event log.
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
	provider := &llmHookTestProvider{}
	eventLog := filepath.Join(t.TempDir(), "events.log")

	hookSpec := config.ProcessHookConfig{
		Enabled: true,
		Command: processHookHelperCommand(),
		Env: map[string]string{
			"PICOCLAW_HOOK_HELPER":    "1",
			"PICOCLAW_HOOK_MODE":      "rewrite",
			"PICOCLAW_HOOK_EVENT_LOG": eventLog,
		},
		Observe:   []string{"turn_end"},
		Intercept: []string{"before_llm", "after_llm"},
	}
	hooksCfg := config.HooksConfig{
		Enabled:   true,
		Processes: map[string]config.ProcessHookConfig{"ipc-auto": hookSpec},
	}
	al := newConfiguredHookLoop(t, provider, hooksCfg)
	defer al.Close()

	resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
	if err != nil {
		t.Fatalf("ProcessDirectWithChannel failed: %v", err)
	}
	if resp != "provider content|ipc" {
		t.Fatalf("expected process-hooked content, got %q", resp)
	}

	provider.mu.Lock()
	usedModel := provider.lastModel
	provider.mu.Unlock()
	if usedModel != "process-model" {
		t.Fatalf("expected process model, got %q", usedModel)
	}

	waitForFileContains(t, eventLog, "turn_end")
}
// TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails verifies
// that a process hook configured with an unsupported intercept point makes the
// request fail instead of being silently ignored.
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
	provider := &llmHookTestProvider{}

	badSpec := config.ProcessHookConfig{
		Enabled:   true,
		Command:   processHookHelperCommand(),
		Intercept: []string{"not_supported"},
	}
	hooksCfg := config.HooksConfig{
		Enabled:   true,
		Processes: map[string]config.ProcessHookConfig{"bad-hook": badSpec},
	}
	al := newConfiguredHookLoop(t, provider, hooksCfg)
	defer al.Close()

	if _, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct"); err == nil {
		t.Fatal("expected invalid configured hook error")
	}
}
+511
View File
@@ -0,0 +1,511 @@
package agent
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
)
// Protocol and lifecycle constants for process hooks.
const (
// processHookJSONRPCVersion is written to the "jsonrpc" field of every
// outgoing message.
processHookJSONRPCVersion = "2.0"
// processHookReadBufferSize is the largest single line (1 MiB) the
// stdout/stderr scanners will accept from the child process.
processHookReadBufferSize = 1024 * 1024
// processHookCloseTimeout is how long Close waits for the child to exit
// after closing its stdin before killing it.
processHookCloseTimeout = 2 * time.Second
)
// ProcessHookOptions configures a hook that runs as a child process.
type ProcessHookOptions struct {
// Command is the executable followed by its arguments; it must be non-empty.
Command []string
// Dir is the working directory of the child process.
Dir string
// Env holds extra "KEY=value" entries appended to the parent environment.
// When empty, the child's environment is left at the exec default.
Env []string
// Observe enables forwarding of events to the child via OnEvent.
Observe bool
// ObserveKinds, when non-empty, limits forwarded events to these EventKind
// names; when empty, every event is forwarded.
ObserveKinds []string
// InterceptLLM enables the BeforeLLM/AfterLLM calls.
InterceptLLM bool
// InterceptTool enables the BeforeTool/AfterTool calls.
InterceptTool bool
// ApproveTool enables the ApproveTool call.
ApproveTool bool
}
// ProcessHook is a hook backed by a child process. Requests are written to the
// child's stdin and responses read from its stdout as newline-delimited JSON
// messages carrying JSON-RPC style fields.
type ProcessHook struct {
name string
opts ProcessHookOptions
cmd *exec.Cmd
stdin io.WriteCloser
// observeKinds, when non-empty, restricts OnEvent to these EventKind names.
observeKinds map[string]struct{}
// writeMu serializes message writes to stdin.
writeMu sync.Mutex
// pendingMu guards pending, the response channels of in-flight calls keyed
// by request ID.
pendingMu sync.Mutex
pending map[uint64]chan processHookRPCMessage
// nextID hands out request IDs. The first ID is 1; readLoop treats ID 0 as
// "not a response".
nextID atomic.Uint64
// closed is set by Close; call and send refuse to run once it is true.
closed atomic.Bool
// done is closed by waitLoop after cmd.Wait has returned.
done chan struct{}
// closeErr is the result of cmd.Wait, guarded by closeMu.
closeErr error
closeMu sync.Mutex
// closeOnce makes the shutdown sequence in Close run at most once.
closeOnce sync.Once
}
// processHookRPCMessage is the wire envelope exchanged with the child process.
// The same struct is used for requests (Method/Params), notifications (no ID)
// and responses (Result or Error).
type processHookRPCMessage struct {
JSONRPC string `json:"jsonrpc,omitempty"`
// ID correlates a response with its request. Because of omitempty, an ID of
// 0 is not serialized; notifications are sent that way.
ID uint64 `json:"id,omitempty"`
Method string `json:"method,omitempty"`
Params json.RawMessage `json:"params,omitempty"`
Result json.RawMessage `json:"result,omitempty"`
Error *processHookRPCError `json:"error,omitempty"`
}
// processHookRPCError is the error object of a response message.
type processHookRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
}
// processHookHelloParams is the payload of the initial "hook.hello" call.
type processHookHelloParams struct {
// Name is the hook's name.
Name string `json:"name"`
// Version is the protocol version announced to the child.
Version int `json:"version"`
// Modes lists the enabled capabilities: "observe", "llm", "tool", "approve".
Modes []string `json:"modes,omitempty"`
}
// processHookDecisionResponse holds the decision fields common to every
// interceptor response.
type processHookDecisionResponse struct {
Action HookAction `json:"action"`
Reason string `json:"reason,omitempty"`
}
// processHookBeforeLLMResponse is the result of "hook.before_llm".
// A nil Request means the child did not replace the request.
type processHookBeforeLLMResponse struct {
processHookDecisionResponse
Request *LLMHookRequest `json:"request,omitempty"`
}
// processHookAfterLLMResponse is the result of "hook.after_llm".
// A nil Response means the child did not replace the response.
type processHookAfterLLMResponse struct {
processHookDecisionResponse
Response *LLMHookResponse `json:"response,omitempty"`
}
// processHookBeforeToolResponse is the result of "hook.before_tool".
// A nil Call means the child did not replace the tool call.
type processHookBeforeToolResponse struct {
processHookDecisionResponse
Call *ToolCallHookRequest `json:"call,omitempty"`
}
// processHookAfterToolResponse is the result of "hook.after_tool".
// A nil Result means the child did not replace the tool result.
type processHookAfterToolResponse struct {
processHookDecisionResponse
Result *ToolResultHookResponse `json:"result,omitempty"`
}
// NewProcessHook starts the command described by opts, wires up its stdio, and
// performs the initial "hook.hello" handshake.
//
// ctx bounds the handshake only; the child process itself is not tied to ctx
// and runs until Close is called or it exits. When ctx is nil or carries no
// deadline, the handshake is limited to five seconds so that a child that
// never answers cannot block the caller indefinitely.
//
// On any failure after the process has started, the process is shut down via
// Close before the error is returned.
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
	// helloTimeout bounds the handshake when the caller supplies no deadline.
	const helloTimeout = 5 * time.Second

	if len(opts.Command) == 0 {
		return nil, fmt.Errorf("process hook command is required")
	}

	cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
	cmd.Dir = opts.Dir
	if len(opts.Env) > 0 {
		// Extend, rather than replace, the parent environment.
		cmd.Env = append(os.Environ(), opts.Env...)
	}

	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("create process hook stdin: %w", err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		// The command will never start, so release the pipe created above.
		_ = stdin.Close()
		return nil, fmt.Errorf("create process hook stdout: %w", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		_ = stdin.Close()
		_ = stdout.Close()
		return nil, fmt.Errorf("create process hook stderr: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("start process hook: %w", err)
	}

	ph := &ProcessHook{
		name:         name,
		opts:         opts,
		cmd:          cmd,
		stdin:        stdin,
		observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
		pending:      make(map[uint64]chan processHookRPCMessage),
		done:         make(chan struct{}),
	}
	go ph.readLoop(stdout)
	go ph.readStderr(stderr)
	go ph.waitLoop()

	helloCtx := ctx
	if helloCtx == nil {
		helloCtx = context.Background()
	}
	if _, bounded := helloCtx.Deadline(); !bounded {
		var cancel context.CancelFunc
		helloCtx, cancel = context.WithTimeout(helloCtx, helloTimeout)
		defer cancel()
	}
	if err := ph.hello(helloCtx); err != nil {
		_ = ph.Close()
		return nil, err
	}
	return ph, nil
}
// Close shuts the hook process down and returns the result of waiting for it.
//
// The first call marks the hook closed, closes the child's stdin, and waits up
// to processHookCloseTimeout for the child to exit; if it has not exited by
// then it is killed and Close waits for the exit to be observed. Later calls
// skip the shutdown and only return the recorded wait result. Close on a nil
// hook is a no-op.
func (ph *ProcessHook) Close() error {
	if ph == nil {
		return nil
	}

	shutdown := func() {
		ph.closed.Store(true)
		if ph.stdin != nil {
			_ = ph.stdin.Close()
		}

		grace := time.After(processHookCloseTimeout)
		select {
		case <-ph.done:
			return
		case <-grace:
		}

		// The child ignored the closed stdin: force it down.
		if ph.cmd != nil && ph.cmd.Process != nil {
			_ = ph.cmd.Process.Kill()
		}
		<-ph.done
	}
	ph.closeOnce.Do(shutdown)

	ph.closeMu.Lock()
	waitErr := ph.closeErr
	ph.closeMu.Unlock()
	return waitErr
}
// OnEvent forwards evt to the child process as a "hook.event" notification.
// Nothing is sent when the hook is nil, observing is disabled, or a kind
// filter is configured and does not contain the event's kind.
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
	if ph == nil || !ph.opts.Observe {
		return nil
	}

	if filter := ph.observeKinds; len(filter) > 0 {
		_, wanted := filter[evt.Kind.String()]
		if !wanted {
			return nil
		}
	}
	return ph.notify(ctx, "hook.event", evt)
}
// BeforeLLM asks the child process ("hook.before_llm") to inspect or rewrite
// an outgoing LLM request. When LLM interception is disabled (or the hook is
// nil) the request passes through with a continue decision. If the child's
// reply carries no request, the original one is returned alongside the child's
// decision.
func (ph *ProcessHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptLLM {
		return req, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookBeforeLLMResponse
	err := ph.call(ctx, "hook.before_llm", req, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	outgoing := reply.Request
	if outgoing == nil {
		outgoing = req
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return outgoing, decision, nil
}
// AfterLLM asks the child process ("hook.after_llm") to inspect or rewrite an
// LLM response. When LLM interception is disabled (or the hook is nil) the
// response passes through with a continue decision. If the child's reply
// carries no response, the original one is returned alongside the child's
// decision.
func (ph *ProcessHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptLLM {
		return resp, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookAfterLLMResponse
	err := ph.call(ctx, "hook.after_llm", resp, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	incoming := reply.Response
	if incoming == nil {
		incoming = resp
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return incoming, decision, nil
}
// BeforeTool asks the child process ("hook.before_tool") to inspect or rewrite
// a tool call before it runs. When tool interception is disabled (or the hook
// is nil) the call passes through with a continue decision. If the child's
// reply carries no call, the original one is returned alongside the child's
// decision.
func (ph *ProcessHook) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptTool {
		return call, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookBeforeToolResponse
	err := ph.call(ctx, "hook.before_tool", call, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	effective := reply.Call
	if effective == nil {
		effective = call
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return effective, decision, nil
}
// AfterTool asks the child process ("hook.after_tool") to inspect or rewrite a
// tool result. When tool interception is disabled (or the hook is nil) the
// result passes through with a continue decision. If the child's reply carries
// no result, the original one is returned alongside the child's decision.
func (ph *ProcessHook) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptTool {
		return result, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookAfterToolResponse
	err := ph.call(ctx, "hook.after_tool", result, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	effective := reply.Result
	if effective == nil {
		effective = result
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return effective, decision, nil
}
// ApproveTool asks the child process ("hook.approve_tool") whether a tool call
// may proceed. When approval is disabled (or the hook is nil) the call is
// approved without contacting the child. A failed call yields the zero
// decision together with the error.
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
	if ph == nil || !ph.opts.ApproveTool {
		return ApprovalDecision{Approved: true}, nil
	}

	var verdict ApprovalDecision
	err := ph.call(ctx, "hook.approve_tool", req, &verdict)
	if err != nil {
		return ApprovalDecision{}, err
	}
	return verdict, nil
}
// hello performs the initial handshake: it sends "hook.hello" with the hook's
// name, protocol version 1, and the list of enabled modes ("observe", "llm",
// "tool", "approve", in that order), and waits for the child's reply.
func (ph *ProcessHook) hello(ctx context.Context) error {
	capabilities := [...]struct {
		enabled bool
		mode    string
	}{
		{ph.opts.Observe, "observe"},
		{ph.opts.InterceptLLM, "llm"},
		{ph.opts.InterceptTool, "tool"},
		{ph.opts.ApproveTool, "approve"},
	}

	modes := make([]string, 0, len(capabilities))
	for _, capability := range capabilities {
		if capability.enabled {
			modes = append(modes, capability.mode)
		}
	}

	params := processHookHelloParams{
		Name:    ph.name,
		Version: 1,
		Modes:   modes,
	}
	// The reply body is decoded but otherwise unused; only success matters.
	var ack map[string]any
	return ph.call(ctx, "hook.hello", params, &ack)
}
// notify sends a one-way message (no ID, so no response is awaited) with the
// given method and JSON-encoded params. Nil params are omitted from the
// message. It returns the encoding or write error, if any.
func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
	var encoded json.RawMessage
	if params != nil {
		raw, err := json.Marshal(params)
		if err != nil {
			return err
		}
		encoded = raw
	}

	return ph.send(ctx, processHookRPCMessage{
		JSONRPC: processHookJSONRPCVersion,
		Method:  method,
		Params:  encoded,
	})
}
// call sends a request with a fresh ID to the child process and blocks until
// the matching response arrives, the process exits, or ctx is done.
//
// params, when non-nil, is JSON-encoded into the request. On success the
// response's result (if any) is decoded into out when out is non-nil. An error
// object in the response, a closed hook, an encode/write/decode failure, or
// ctx expiry is returned as an error.
func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
if ph.closed.Load() {
return fmt.Errorf("process hook %q is closed", ph.name)
}
// IDs start at 1; readLoop ignores messages whose ID is 0.
id := ph.nextID.Add(1)
// Buffered so readLoop/failPending can deliver without blocking even if this
// caller has already given up.
respCh := make(chan processHookRPCMessage, 1)
// Register before sending so a fast reply cannot arrive unmatched.
ph.pendingMu.Lock()
ph.pending[id] = respCh
ph.pendingMu.Unlock()
msg := processHookRPCMessage{
JSONRPC: processHookJSONRPCVersion,
ID: id,
Method: method,
}
if params != nil {
body, err := json.Marshal(params)
if err != nil {
ph.removePending(id)
return err
}
msg.Params = body
}
if err := ph.send(ctx, msg); err != nil {
ph.removePending(id)
return err
}
select {
case resp, ok := <-respCh:
if !ok {
// The channel was closed without a message being delivered.
return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
}
if resp.Error != nil {
return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
}
if out != nil && len(resp.Result) > 0 {
if err := json.Unmarshal(resp.Result, out); err != nil {
return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
}
}
return nil
case <-ctx.Done():
// Deregister so a late reply is dropped by readLoop instead of delivered.
ph.removePending(id)
return ctx.Err()
}
}
// send writes msg to the child's stdin as a single JSON line. Writes are
// serialized by writeMu. It returns an error when encoding fails, the hook is
// closed, the write fails, or ctx is done before the write completes.
func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
body, err := json.Marshal(msg)
if err != nil {
return err
}
// One message per line: the newline is the frame delimiter.
body = append(body, '\n')
ph.writeMu.Lock()
defer ph.writeMu.Unlock()
if ph.closed.Load() {
return fmt.Errorf("process hook %q is closed", ph.name)
}
// The write runs in its own goroutine so that ctx can interrupt the wait.
// The channel is buffered so that goroutine can always finish.
// NOTE(review): when ctx wins, the write may still be in flight after this
// function returns and releases writeMu.
done := make(chan error, 1)
go func() {
_, writeErr := ph.stdin.Write(body)
done <- writeErr
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("write process hook %q message: %w", ph.name, err)
}
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// readLoop consumes the child's stdout line by line and routes each response
// to the call waiting on its ID. Lines that are not valid JSON are logged and
// skipped; messages with ID 0 (not a response) and responses nobody is waiting
// for are dropped.
//
// The loop ends at EOF or on a read failure, such as a line longer than
// processHookReadBufferSize. A failure is logged, every in-flight call is
// failed so callers are not left blocked until their contexts expire, and the
// rest of stdout is discarded so the child can never stall on a full pipe.
func (ph *ProcessHook) readLoop(stdout io.Reader) {
	scanner := bufio.NewScanner(stdout)
	scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
	for scanner.Scan() {
		var msg processHookRPCMessage
		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
			logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
				"hook":  ph.name,
				"error": err.Error(),
			})
			continue
		}
		if msg.ID == 0 {
			continue
		}

		// Claim the waiter under the lock so that exactly one of readLoop,
		// failPending and removePending ever touches its channel.
		ph.pendingMu.Lock()
		respCh, ok := ph.pending[msg.ID]
		if ok {
			delete(ph.pending, msg.ID)
		}
		ph.pendingMu.Unlock()
		if ok {
			respCh <- msg
			close(respCh)
		}
	}

	// Scan also returns false on failure; Err distinguishes that from EOF.
	err := scanner.Err()
	if err == nil {
		return
	}
	if !ph.closed.Load() {
		logger.WarnCF("hooks", "Process hook stdout reader stopped", map[string]any{
			"hook":  ph.name,
			"error": err.Error(),
		})
	}
	// No further responses will be routed, so release the current waiters.
	ph.failPending(err)
	// Keep the pipe drained; the copy ends when the pipe is closed.
	_, _ = io.Copy(io.Discard, stdout)
}
// readStderr forwards each line the child writes to stderr to the logger as a
// warning.
//
// If scanning fails (for example on a line longer than
// processHookReadBufferSize) the failure is logged and the rest of stderr is
// discarded, so the child can never stall on a full stderr pipe.
func (ph *ProcessHook) readStderr(stderr io.Reader) {
	scanner := bufio.NewScanner(stderr)
	scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
	for scanner.Scan() {
		logger.WarnCF("hooks", "Process hook stderr", map[string]any{
			"hook":   ph.name,
			"stderr": scanner.Text(),
		})
	}

	// Scan also returns false on failure; Err distinguishes that from EOF.
	err := scanner.Err()
	if err == nil {
		return
	}
	if !ph.closed.Load() {
		logger.WarnCF("hooks", "Process hook stderr reader stopped", map[string]any{
			"hook":  ph.name,
			"error": err.Error(),
		})
	}
	// Keep the pipe drained; the copy ends when the pipe is closed.
	_, _ = io.Copy(io.Discard, stderr)
}
// waitLoop blocks until the child process exits, then publishes the outcome in
// a fixed order: record the wait error, fail every in-flight call, and finally
// close done so Close can return.
//
// NOTE(review): the os/exec documentation says Wait closes the stdout/stderr
// pipes once the command exits and that Wait should not be called before all
// reads from those pipes have completed. Here Wait runs concurrently with
// readLoop/readStderr, so output written just before exit may be lost —
// confirm whether that matters for this protocol.
func (ph *ProcessHook) waitLoop() {
err := ph.cmd.Wait()
ph.closeMu.Lock()
ph.closeErr = err
ph.closeMu.Unlock()
ph.failPending(err)
close(ph.done)
}
// failPending completes every in-flight call with an error response (code
// -32000). The message is err's text, or "process exited" when err is nil.
// Each waiter is removed from the pending map before its channel is written
// and closed.
func (ph *ProcessHook) failPending(err error) {
	reason := "process exited"
	if err != nil {
		reason = err.Error()
	}
	failure := processHookRPCMessage{
		Error: &processHookRPCError{Code: -32000, Message: reason},
	}

	ph.pendingMu.Lock()
	defer ph.pendingMu.Unlock()
	for id, waiter := range ph.pending {
		delete(ph.pending, id)
		waiter <- failure
		close(waiter)
	}
}
// removePending deregisters the call with the given ID and closes its response
// channel. It does nothing when the ID is no longer registered (its response
// was already delivered or failed).
func (ph *ProcessHook) removePending(id uint64) {
	ph.pendingMu.Lock()
	defer ph.pendingMu.Unlock()

	waiter, registered := ph.pending[id]
	if !registered {
		return
	}
	delete(ph.pending, id)
	close(waiter)
}
// MountProcessHook starts a process hook with the given options and mounts it
// on the agent loop under name. If mounting fails, the freshly started process
// is closed again before the error is returned.
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
	if al == nil {
		return fmt.Errorf("agent loop is nil")
	}

	started, err := NewProcessHook(ctx, name, opts)
	if err != nil {
		return err
	}

	registration := HookRegistration{
		Name:   name,
		Source: HookSourceProcess,
		Hook:   started,
	}
	mountErr := al.MountHook(registration)
	if mountErr == nil {
		return nil
	}
	_ = started.Close()
	return mountErr
}
// newProcessHookObserveKinds turns a list of event kind names into a set,
// skipping empty strings and collapsing duplicates. It returns nil when the
// list is empty or contains only empty strings.
func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
	if len(kinds) == 0 {
		return nil
	}

	set := make(map[string]struct{}, len(kinds))
	for i := range kinds {
		if kind := kinds[i]; kind != "" {
			set[kind] = struct{}{}
		}
	}
	if len(set) > 0 {
		return set
	}
	return nil
}
+339
View File
@@ -0,0 +1,339 @@
package agent
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// TestProcessHook_HelperProcess is not a real test: when the test binary is
// re-executed with PICOCLAW_HOOK_HELPER=1 it acts as the hook child process,
// running the helper and exiting with status 0 on success or 1 (after printing
// the error to stderr) on failure. In a normal test run it returns immediately.
func TestProcessHook_HelperProcess(t *testing.T) {
	if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
		return
	}

	exitCode := 0
	if err := runProcessHookHelper(); err != nil {
		fmt.Fprintln(os.Stderr, err.Error())
		exitCode = 1
	}
	os.Exit(exitCode)
}
// TestAgentLoop_MountProcessHook_LLMAndObserver verifies that an explicitly
// mounted process hook with observing and LLM interception enabled rewrites
// the model and the response content, and records the turn_end event in the
// event log.
func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
	provider := &llmHookTestProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	eventLog := filepath.Join(t.TempDir(), "events.log")
	hookOpts := ProcessHookOptions{
		Command:      processHookHelperCommand(),
		Env:          processHookHelperEnv("rewrite", eventLog),
		Observe:      true,
		InterceptLLM: true,
	}
	if err := al.MountProcessHook(context.Background(), "ipc-llm", hookOpts); err != nil {
		t.Fatalf("MountProcessHook failed: %v", err)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "hello",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	resp, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != "provider content|ipc" {
		t.Fatalf("expected process-hooked llm content, got %q", resp)
	}

	provider.mu.Lock()
	usedModel := provider.lastModel
	provider.mu.Unlock()
	if usedModel != "process-model" {
		t.Fatalf("expected process model, got %q", usedModel)
	}

	waitForFileContains(t, eventLog, "turn_end")
}
// TestAgentLoop_MountProcessHook_ToolRewrite mounts a helper-process hook that
// intercepts tool calls in "rewrite" mode. It checks that both the tool
// argument and the tool result were rewritten by the subprocess.
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
	provider := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})

	hookOpts := ProcessHookOptions{
		Command:       processHookHelperCommand(),
		Env:           processHookHelperEnv("rewrite", ""),
		InterceptTool: true,
	}
	if mountErr := al.MountProcessHook(context.Background(), "ipc-tool", hookOpts); mountErr != nil {
		t.Fatalf("MountProcessHook failed: %v", mountErr)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	resp, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != "ipc:ipc" {
		t.Fatalf("expected rewritten process-hook tool result, got %q", resp)
	}
}
// blockedToolProvider is a scripted LLM provider for the approval tests. Its
// first Chat call requests the "blocked_tool" tool; every later call echoes
// the content of the last message it was given.
type blockedToolProvider struct {
	calls int // number of Chat invocations so far
}

// Chat returns a single tool call on the first invocation and afterwards
// replies with the content of the final entry in messages.
func (p *blockedToolProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.calls++
	if p.calls != 1 {
		last := messages[len(messages)-1]
		return &providers.LLMResponse{Content: last.Content}, nil
	}

	request := providers.ToolCall{
		ID:        "call-1",
		Name:      "blocked_tool",
		Arguments: map[string]any{},
	}
	return &providers.LLMResponse{
		ToolCalls: []providers.ToolCall{request},
	}, nil
}

// GetDefaultModel reports the fixed model name of this test provider.
func (p *blockedToolProvider) GetDefaultModel() string {
	const model = "blocked-tool-provider"
	return model
}
// TestAgentLoop_MountProcessHook_ApprovalDeny mounts a helper-process hook in
// "deny" mode as a tool approver. It checks that the denial reason is returned
// as the loop response and is carried by the tool-skipped event.
func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
	const expected = "Tool execution denied by approval hook: blocked by ipc hook"

	provider := &blockedToolProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	hookOpts := ProcessHookOptions{
		Command:     processHookHelperCommand(),
		Env:         processHookHelperEnv("deny", ""),
		ApproveTool: true,
	}
	if mountErr := al.MountProcessHook(context.Background(), "ipc-approval", hookOpts); mountErr != nil {
		t.Fatalf("MountProcessHook failed: %v", mountErr)
	}

	sub := al.SubscribeEvents(16)
	defer al.UnsubscribeEvents(sub.ID)

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run blocked tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	resp, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != expected {
		t.Fatalf("expected %q, got %q", expected, resp)
	}

	skippedEvt, found := findEvent(collectEventStream(sub.C), EventKindToolExecSkipped)
	if !found {
		t.Fatal("expected tool skipped event")
	}
	payload, isSkipped := skippedEvt.Payload.(ToolExecSkippedPayload)
	if !isSkipped {
		t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
	}
	if payload.Reason != expected {
		t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
	}
}
// processHookHelperCommand builds the argv that re-runs the current test
// binary restricted to TestProcessHook_HelperProcess.
func processHookHelperCommand() []string {
	argv := make([]string, 0, 3)
	argv = append(argv, os.Args[0])
	argv = append(argv, "-test.run=TestProcessHook_HelperProcess", "--")
	return argv
}
// processHookHelperEnv returns the environment entries that switch the test
// binary into helper mode. mode is passed through as PICOCLAW_HOOK_MODE.
// PICOCLAW_HOOK_EVENT_LOG is added only when eventLog is non-empty.
func processHookHelperEnv(mode, eventLog string) []string {
	vars := []string{
		"PICOCLAW_HOOK_HELPER=1",
		"PICOCLAW_HOOK_MODE=" + mode,
	}
	if eventLog == "" {
		return vars
	}
	return append(vars, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
}
// waitForFileContains polls path every 20ms for up to 3 seconds until its
// content contains substring. Read errors (for example, the file not existing
// yet) are treated as "not yet". On timeout the test fails, reporting whatever
// content can be read at that point.
func waitForFileContains(t *testing.T, path, substring string) {
	t.Helper()

	deadline := time.Now().Add(3 * time.Second)
	for {
		if !time.Now().Before(deadline) {
			break
		}
		if content, readErr := os.ReadFile(path); readErr == nil && strings.Contains(string(content), substring) {
			return
		}
		time.Sleep(20 * time.Millisecond)
	}

	final, _ := os.ReadFile(path)
	t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(final))
}
// runProcessHookHelper implements the subprocess side of the process-hook
// protocol used by the tests. It reads one JSON message per line from stdin.
// Messages with a zero ID are notifications and are never answered. Every
// other message is dispatched to handleProcessHookRequest and answered on
// stdout with the same ID.
//
// When PICOCLAW_HOOK_EVENT_LOG is set, the kind of each "hook.event"
// notification is appended to that file, one kind per line.
//
// It returns the first decode/encode error, or the scanner error once stdin is
// exhausted.
func runProcessHookHelper() error {
	mode := os.Getenv("PICOCLAW_HOOK_MODE")
	eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")

	scanner := bufio.NewScanner(os.Stdin)
	scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
	encoder := json.NewEncoder(os.Stdout)

	for scanner.Scan() {
		var msg processHookRPCMessage
		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
			return err
		}

		if msg.ID == 0 {
			if msg.Method == "hook.event" && eventLog != "" {
				recordProcessHookEvent(eventLog, msg.Params)
			}
			continue
		}

		result, rpcErr := handleProcessHookRequest(mode, msg)
		resp := processHookRPCMessage{
			JSONRPC: processHookJSONRPCVersion,
			ID:      msg.ID,
		}
		switch {
		case rpcErr != nil:
			resp.Error = rpcErr
		case result != nil:
			body, err := json.Marshal(result)
			if err != nil {
				return err
			}
			resp.Result = body
		default:
			// A handler without a payload still answers with an empty object.
			resp.Result = []byte("{}")
		}
		if err := encoder.Encode(resp); err != nil {
			return err
		}
	}
	return scanner.Err()
}

// recordProcessHookEvent appends the event kind found in params to eventLog.
// The log is opened in append mode so earlier events are kept and the
// assertions polling the file do not depend on which event arrived last.
// Recording is best effort: undecodable events and file errors are ignored so
// a logging problem never breaks the RPC loop.
func recordProcessHookEvent(eventLog string, params []byte) {
	var evt map[string]any
	if err := json.Unmarshal(params, &evt); err != nil {
		return
	}
	// encoding/json decodes JSON numbers into float64 for map[string]any.
	rawKind, ok := evt["Kind"].(float64)
	if !ok {
		return
	}

	f, err := os.OpenFile(eventLog, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
	if err != nil {
		return
	}
	_, _ = f.WriteString(EventKind(rawKind).String() + "\n")
	_ = f.Close()
}
// handleProcessHookRequest produces the helper's reply for one RPC request.
//
// In "rewrite" mode the interceptor methods return a modify action: the model
// is replaced, "|ipc" is appended to the response content, the tool argument
// "text" is set to "ipc", and "ipc:" is prepended to the tool result's
// for_llm. In any other mode the interceptor methods answer with a continue
// action. "hook.approve_tool" denies only in "deny" mode. Unknown methods
// yield a JSON-RPC "method not found" error (-32601), and params that cannot
// be decoded yield "invalid params" (-32602).
func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
	switch msg.Method {
	case "hook.hello":
		return map[string]any{"ok": true}, nil

	case "hook.before_llm":
		if mode != "rewrite" {
			return map[string]any{"action": HookActionContinue}, nil
		}
		req, rpcErr := decodeProcessHookParams(msg.Params)
		if rpcErr != nil {
			return nil, rpcErr
		}
		req["model"] = "process-model"
		return map[string]any{
			"action":  HookActionModify,
			"request": req,
		}, nil

	case "hook.after_llm":
		if mode != "rewrite" {
			return map[string]any{"action": HookActionContinue}, nil
		}
		resp, rpcErr := decodeProcessHookParams(msg.Params)
		if rpcErr != nil {
			return nil, rpcErr
		}
		if rawResponse, ok := resp["response"].(map[string]any); ok {
			if content, ok := rawResponse["content"].(string); ok {
				rawResponse["content"] = content + "|ipc"
			}
		}
		return map[string]any{
			"action":   HookActionModify,
			"response": resp,
		}, nil

	case "hook.before_tool":
		if mode != "rewrite" {
			return map[string]any{"action": HookActionContinue}, nil
		}
		call, rpcErr := decodeProcessHookParams(msg.Params)
		if rpcErr != nil {
			return nil, rpcErr
		}
		rawArgs, ok := call["arguments"].(map[string]any)
		if !ok || rawArgs == nil {
			rawArgs = map[string]any{}
		}
		rawArgs["text"] = "ipc"
		call["arguments"] = rawArgs
		return map[string]any{
			"action": HookActionModify,
			"call":   call,
		}, nil

	case "hook.after_tool":
		if mode != "rewrite" {
			return map[string]any{"action": HookActionContinue}, nil
		}
		result, rpcErr := decodeProcessHookParams(msg.Params)
		if rpcErr != nil {
			return nil, rpcErr
		}
		if rawResult, ok := result["result"].(map[string]any); ok {
			if forLLM, ok := rawResult["for_llm"].(string); ok {
				rawResult["for_llm"] = "ipc:" + forLLM
			}
		}
		return map[string]any{
			"action": HookActionModify,
			"result": result,
		}, nil

	case "hook.approve_tool":
		if mode == "deny" {
			return ApprovalDecision{
				Approved: false,
				Reason:   "blocked by ipc hook",
			}, nil
		}
		return ApprovalDecision{Approved: true}, nil

	default:
		return nil, &processHookRPCError{
			Code:    -32601,
			Message: "method not found",
		}
	}
}

// decodeProcessHookParams decodes request params into a map that is always
// safe to write to. Missing or JSON-null params give an empty map. Malformed
// params are reported as a JSON-RPC "invalid params" error, where the old code
// ignored the decode error and then panicked on a nil-map assignment.
func decodeProcessHookParams(raw []byte) (map[string]any, *processHookRPCError) {
	decoded := map[string]any{}
	if len(raw) == 0 {
		return decoded, nil
	}
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, &processHookRPCError{
			Code:    -32602,
			Message: "invalid params: " + err.Error(),
		}
	}
	// Unmarshalling JSON null resets the map to nil; restore a writable map.
	if decoded == nil {
		decoded = map[string]any{}
	}
	return decoded, nil
}
+809
View File
@@ -0,0 +1,809 @@
package agent
import (
"context"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// Defaults installed by NewHookManager. The timeouts can be overridden with
// HookManager.ConfigureTimeouts.
const (
// defaultHookObserverTimeout bounds a single EventObserver.OnEvent call.
defaultHookObserverTimeout = 500 * time.Millisecond
// defaultHookInterceptorTimeout bounds a single LLM/tool interceptor call.
defaultHookInterceptorTimeout = 5 * time.Second
// defaultHookApprovalTimeout bounds a single ToolApprover.ApproveTool call.
defaultHookApprovalTimeout = 60 * time.Second
// hookObserverBufferSize is the buffer size requested for the manager's
// event bus subscription.
hookObserverBufferSize = 64
)
// HookAction is the verdict an interceptor hook returns alongside its
// (possibly rewritten) payload. An empty action is treated as
// HookActionContinue (see HookDecision.normalizedAction).
type HookAction string
const (
// HookActionContinue and HookActionModify both let processing go on; the
// HookManager adopts the returned payload for either when it is non-nil.
HookActionContinue HookAction = "continue"
HookActionModify HookAction = "modify"
// HookActionDenyTool is honored only by HookManager.BeforeTool; the other
// stages log it as unsupported.
HookActionDenyTool HookAction = "deny_tool"
// HookActionAbortTurn and HookActionHardAbort stop hook processing for the
// stage and are returned to the caller of the HookManager stage method.
HookActionAbortTurn HookAction = "abort_turn"
HookActionHardAbort HookAction = "hard_abort"
)
// HookDecision is what an interceptor hook returns to say how processing
// should proceed.
type HookDecision struct {
// Action is the requested outcome; empty means continue.
Action HookAction `json:"action"`
// Reason is an optional free-form explanation for the action.
Reason string `json:"reason,omitempty"`
}
// normalizedAction returns the decision's action, substituting
// HookActionContinue when no action was set.
func (d HookDecision) normalizedAction() HookAction {
	if d.Action != "" {
		return d.Action
	}
	return HookActionContinue
}
// ApprovalDecision is the answer of a ToolApprover.
type ApprovalDecision struct {
// Approved reports whether the tool call may run.
Approved bool `json:"approved"`
// Reason is an optional explanation, typically set on denial.
Reason string `json:"reason,omitempty"`
}
// HookSource identifies where a hook comes from. HookManager orders hooks by
// ascending source first, so in-process hooks run before process hooks
// regardless of priority.
type HookSource uint8
const (
// HookSourceInProcess is the zero value and is what NamedHook assigns.
HookSourceInProcess HookSource = iota
// HookSourceProcess marks hooks mounted with this source, e.g. the
// registrations the tests in this package create for external processes.
HookSourceProcess
)
// HookRegistration describes one hook mounted on a HookManager.
type HookRegistration struct {
// Name is the unique key of the hook; mounting another hook under the same
// name replaces (and closes) the previous one.
Name string
// Priority orders hooks within the same Source; lower values run first,
// ties are broken by Name.
Priority int
// Source is the primary ordering key (see HookSource).
Source HookSource
// Hook is the implementation. The manager type-asserts it to EventObserver,
// LLMInterceptor, ToolInterceptor and ToolApprover as needed, and to
// io.Closer when the hook is replaced, unmounted or the manager is closed.
Hook any
}
// NamedHook wraps hook in a registration with the given name, the in-process
// source and the default (zero) priority.
func NamedHook(name string, hook any) HookRegistration {
	var reg HookRegistration
	reg.Name = name
	reg.Source = HookSourceInProcess
	reg.Hook = hook
	return reg
}
// EventObserver is implemented by hooks that want to see events from the
// HookManager's event bus subscription. The manager calls OnEvent with a
// context limited by the observer timeout; a returned error is only logged.
type EventObserver interface {
OnEvent(ctx context.Context, evt Event) error
}
// LLMInterceptor is implemented by hooks that inspect or rewrite LLM traffic.
// Each method receives a private clone of the payload and returns the payload
// to use next (nil keeps the current one) together with a decision. A non-nil
// error makes the HookManager log the failure and skip this hook.
type LLMInterceptor interface {
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
}
// ToolInterceptor is implemented by hooks that inspect or rewrite tool calls
// and tool results. Each method receives a private clone of the payload and
// returns the payload to use next (nil keeps the current one) together with a
// decision. A non-nil error makes the HookManager log the failure and skip
// this hook.
type ToolInterceptor interface {
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
}
// ToolApprover is implemented by hooks that can veto a tool call. The
// HookManager treats an error or a timeout from ApproveTool as a denial.
type ToolApprover interface {
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
}
// LLMHookRequest is the payload handed to LLMInterceptor.BeforeLLM. Hooks may
// return a modified copy to change what is used for the LLM call.
type LLMHookRequest struct {
Meta EventMeta `json:"meta"`
// Model is the model name for the call.
Model string `json:"model"`
Messages []providers.Message `json:"messages,omitempty"`
Tools []providers.ToolDefinition `json:"tools,omitempty"`
Options map[string]any `json:"options,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
// NOTE(review): the meaning of GracefulTerminal is not visible in this
// file — confirm against the agent loop before relying on it.
GracefulTerminal bool `json:"graceful_terminal,omitempty"`
}
// Clone returns a copy of the request whose Messages, Tools and Options do not
// share their top-level containers with the receiver. A nil receiver yields
// nil.
func (r *LLMHookRequest) Clone() *LLMHookRequest {
	if r == nil {
		return nil
	}
	dup := new(LLMHookRequest)
	*dup = *r
	dup.Messages = cloneProviderMessages(r.Messages)
	dup.Tools = cloneToolDefinitions(r.Tools)
	dup.Options = cloneStringAnyMap(r.Options)
	return dup
}
// LLMHookResponse is the payload handed to LLMInterceptor.AfterLLM. Hooks may
// return a modified copy to change the response that is used afterwards.
type LLMHookResponse struct {
Meta EventMeta `json:"meta"`
Model string `json:"model"`
// Response is the provider response; it may be nil.
Response *providers.LLMResponse `json:"response,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
// Clone returns a copy of the response with its own copy of the wrapped
// provider response. A nil receiver yields nil.
func (r *LLMHookResponse) Clone() *LLMHookResponse {
	if r == nil {
		return nil
	}
	dup := new(LLMHookResponse)
	*dup = *r
	dup.Response = cloneLLMResponse(r.Response)
	return dup
}
// ToolCallHookRequest is the payload handed to ToolInterceptor.BeforeTool.
type ToolCallHookRequest struct {
Meta EventMeta `json:"meta"`
// Tool is the name of the tool about to be called.
Tool string `json:"tool"`
// Arguments are the call arguments; hooks may rewrite them.
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
// Clone returns a copy of the request with its own Arguments map. A nil
// receiver yields nil.
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
	if r == nil {
		return nil
	}
	dup := new(ToolCallHookRequest)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	return dup
}
// ToolApprovalRequest is the payload handed to ToolApprover.ApproveTool.
type ToolApprovalRequest struct {
Meta EventMeta `json:"meta"`
// Tool is the name of the tool awaiting approval.
Tool string `json:"tool"`
// Arguments are the arguments the tool would be called with.
Arguments map[string]any `json:"arguments,omitempty"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
// Clone returns a copy of the request with its own Arguments map. A nil
// receiver yields nil.
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
	if r == nil {
		return nil
	}
	dup := new(ToolApprovalRequest)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	return dup
}
// ToolResultHookResponse is the payload handed to ToolInterceptor.AfterTool.
type ToolResultHookResponse struct {
Meta EventMeta `json:"meta"`
// Tool is the name of the tool that produced the result.
Tool string `json:"tool"`
Arguments map[string]any `json:"arguments,omitempty"`
// Result is the tool result; it may be nil.
Result *tools.ToolResult `json:"result,omitempty"`
// Duration has no custom JSON encoding here, so encoding/json writes it as
// an integer number of nanoseconds.
Duration time.Duration `json:"duration"`
Channel string `json:"channel,omitempty"`
ChatID string `json:"chat_id,omitempty"`
}
// Clone returns a copy of the response with its own Arguments map and its own
// copy of Result. A nil receiver yields nil.
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
	if r == nil {
		return nil
	}
	dup := new(ToolResultHookResponse)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	dup.Result = cloneToolResult(r.Result)
	return dup
}
// HookManager owns the mounted hooks and runs them at each stage: event
// observation, before/after LLM, before/after tool and tool approval.
type HookManager struct {
// eventBus is the bus observers are fed from; it may be nil, in which case
// no events are dispatched.
eventBus *EventBus
// Per-call time limits. They are read without holding mu, so
// ConfigureTimeouts is presumably meant to be called before the manager is
// in use — TODO confirm with callers.
observerTimeout time.Duration
interceptorTimeout time.Duration
approvalTimeout time.Duration
// mu guards hooks and ordered.
mu sync.RWMutex
// hooks indexes registrations by name.
hooks map[string]HookRegistration
// ordered is hooks sorted into execution order (see rebuildOrdered).
ordered []HookRegistration
// sub is the event bus subscription consumed by dispatchEvents.
sub EventSubscription
// done is closed when dispatchEvents returns, or immediately by
// NewHookManager when there is no event bus.
done chan struct{}
// closeOnce makes Close idempotent.
closeOnce sync.Once
}
// NewHookManager creates a manager with the default timeouts. When eventBus is
// non-nil the manager subscribes to it and starts a goroutine that forwards
// events to observer hooks. With a nil bus no goroutine is started and the
// manager is immediately marked as done dispatching.
func NewHookManager(eventBus *EventBus) *HookManager {
	manager := &HookManager{
		eventBus:           eventBus,
		observerTimeout:    defaultHookObserverTimeout,
		interceptorTimeout: defaultHookInterceptorTimeout,
		approvalTimeout:    defaultHookApprovalTimeout,
		hooks:              make(map[string]HookRegistration),
		done:               make(chan struct{}),
	}

	if eventBus != nil {
		manager.sub = eventBus.Subscribe(hookObserverBufferSize)
		go manager.dispatchEvents()
		return manager
	}

	// Nothing will ever be dispatched, so Close must not wait.
	close(manager.done)
	return manager
}
// Close shuts the manager down once: it unsubscribes from the event bus (if
// any), waits for the dispatch goroutine to finish and then closes and removes
// all mounted hooks. It is safe to call on a nil manager and more than once.
func (hm *HookManager) Close() {
	if hm == nil {
		return
	}

	shutdown := func() {
		if bus := hm.eventBus; bus != nil {
			bus.Unsubscribe(hm.sub.ID)
		}
		// Wait until dispatchEvents has returned before closing hooks.
		<-hm.done
		hm.closeAllHooks()
	}
	hm.closeOnce.Do(shutdown)
}
// ConfigureTimeouts overrides the per-call time limits. Only strictly positive
// values are applied; zero or negative arguments keep the current setting.
// It is a no-op on a nil manager.
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
	if hm == nil {
		return
	}

	setIfPositive := func(target *time.Duration, value time.Duration) {
		if value > 0 {
			*target = value
		}
	}
	setIfPositive(&hm.observerTimeout, observer)
	setIfPositive(&hm.interceptorTimeout, interceptor)
	setIfPositive(&hm.approvalTimeout, approval)
}
// Mount registers reg under reg.Name and re-sorts the execution order. A hook
// already mounted under that name is closed (if it is an io.Closer) and
// replaced. It returns an error when the manager is nil, the name is empty or
// the hook is nil.
func (hm *HookManager) Mount(reg HookRegistration) error {
	switch {
	case hm == nil:
		return fmt.Errorf("hook manager is nil")
	case reg.Name == "":
		return fmt.Errorf("hook name is required")
	case reg.Hook == nil:
		return fmt.Errorf("hook %q is nil", reg.Name)
	}

	hm.mu.Lock()
	defer hm.mu.Unlock()

	previous, replaced := hm.hooks[reg.Name]
	if replaced {
		closeHookIfPossible(previous.Hook)
	}
	hm.hooks[reg.Name] = reg
	hm.rebuildOrdered()
	return nil
}
// Unmount removes the hook registered under name, closing it first when it is
// an io.Closer, and re-sorts the execution order. It does nothing for a nil
// manager or an empty name.
func (hm *HookManager) Unmount(name string) {
	if hm == nil || name == "" {
		return
	}

	hm.mu.Lock()
	defer hm.mu.Unlock()

	mounted, found := hm.hooks[name]
	if found {
		closeHookIfPossible(mounted.Hook)
	}
	delete(hm.hooks, name)
	hm.rebuildOrdered()
}
// dispatchEvents drains the manager's event subscription and hands every
// event, in hook order, to each mounted hook that implements EventObserver.
// It closes hm.done when the subscription channel is closed.
func (hm *HookManager) dispatchEvents() {
	defer close(hm.done)

	for evt := range hm.sub.C {
		// Take a fresh snapshot per event so mounts/unmounts take effect.
		hooks := hm.snapshotHooks()
		for i := range hooks {
			if observer, isObserver := hooks[i].Hook.(EventObserver); isObserver {
				hm.runObserver(hooks[i].Name, observer, evt)
			}
		}
	}
}
// BeforeLLM runs every mounted LLMInterceptor's BeforeLLM in hook order. Each
// hook gets a clone of the current request. For continue/modify a non-nil
// returned request becomes the current one. For abort_turn/hard_abort the
// current request and that decision are returned at once. Any other action is
// logged as unsupported. Hooks that fail or time out are skipped. With a nil
// manager or request the input is returned unchanged.
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
	if hm == nil || req == nil {
		return req, HookDecision{Action: HookActionContinue}
	}

	working := req.Clone()
	for _, reg := range hm.snapshotHooks() {
		llmHook, isLLMHook := reg.Hook.(LLMInterceptor)
		if !isLLMHook {
			continue
		}
		updated, decision, succeeded := hm.callBeforeLLM(ctx, reg.Name, llmHook, working.Clone())
		if !succeeded {
			continue
		}

		action := decision.normalizedAction()
		if action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, decision
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, HookDecision{Action: HookActionContinue}
}
// AfterLLM runs every mounted LLMInterceptor's AfterLLM in hook order. Each
// hook gets a clone of the current response. For continue/modify a non-nil
// returned response becomes the current one. For abort_turn/hard_abort the
// current response and that decision are returned at once. Any other action
// is logged as unsupported. Hooks that fail or time out are skipped. With a
// nil manager or response the input is returned unchanged.
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
	if hm == nil || resp == nil {
		return resp, HookDecision{Action: HookActionContinue}
	}

	working := resp.Clone()
	for _, reg := range hm.snapshotHooks() {
		llmHook, isLLMHook := reg.Hook.(LLMInterceptor)
		if !isLLMHook {
			continue
		}
		updated, decision, succeeded := hm.callAfterLLM(ctx, reg.Name, llmHook, working.Clone())
		if !succeeded {
			continue
		}

		action := decision.normalizedAction()
		if action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, decision
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, HookDecision{Action: HookActionContinue}
}
// BeforeTool runs every mounted ToolInterceptor's BeforeTool in hook order.
// Each hook gets a clone of the current call. For continue/modify a non-nil
// returned call becomes the current one. For deny_tool/abort_turn/hard_abort
// the current call and that decision are returned at once. Any other action
// is logged as unsupported. Hooks that fail or time out are skipped. With a
// nil manager or call the input is returned unchanged.
func (hm *HookManager) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision) {
	if hm == nil || call == nil {
		return call, HookDecision{Action: HookActionContinue}
	}

	working := call.Clone()
	for _, reg := range hm.snapshotHooks() {
		toolHook, isToolHook := reg.Hook.(ToolInterceptor)
		if !isToolHook {
			continue
		}
		updated, decision, succeeded := hm.callBeforeTool(ctx, reg.Name, toolHook, working.Clone())
		if !succeeded {
			continue
		}

		action := decision.normalizedAction()
		if action == HookActionDenyTool || action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, decision
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, HookDecision{Action: HookActionContinue}
}
// AfterTool runs every mounted ToolInterceptor's AfterTool in hook order. Each
// hook gets a clone of the current result. For continue/modify a non-nil
// returned result becomes the current one. For abort_turn/hard_abort the
// current result and that decision are returned at once. Any other action
// (including deny_tool) is logged as unsupported. Hooks that fail or time out
// are skipped. With a nil manager or result the input is returned unchanged.
func (hm *HookManager) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision) {
	if hm == nil || result == nil {
		return result, HookDecision{Action: HookActionContinue}
	}

	working := result.Clone()
	for _, reg := range hm.snapshotHooks() {
		toolHook, isToolHook := reg.Hook.(ToolInterceptor)
		if !isToolHook {
			continue
		}
		updated, decision, succeeded := hm.callAfterTool(ctx, reg.Name, toolHook, working.Clone())
		if !succeeded {
			continue
		}

		action := decision.normalizedAction()
		if action == HookActionAbortTurn || action == HookActionHardAbort {
			return working, decision
		}
		if action != HookActionContinue && action != HookActionModify {
			hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
			continue
		}
		if updated != nil {
			working = updated
		}
	}
	return working, HookDecision{Action: HookActionContinue}
}
// ApproveTool asks every mounted ToolApprover, in hook order, whether the tool
// call may run. The first denial is returned as is. An approver that fails is
// turned into a denial naming the hook. The call is approved when every
// approver approves, when no approver is mounted, or when the manager or
// request is nil.
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
	if hm == nil || req == nil {
		return ApprovalDecision{Approved: true}
	}

	for _, reg := range hm.snapshotHooks() {
		if approver, isApprover := reg.Hook.(ToolApprover); isApprover {
			verdict, completed := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
			switch {
			case !completed:
				return ApprovalDecision{
					Approved: false,
					Reason:   fmt.Sprintf("tool approval hook %q failed", reg.Name),
				}
			case !verdict.Approved:
				return verdict
			}
		}
	}
	return ApprovalDecision{Approved: true}
}
// rebuildOrdered recomputes hm.ordered from hm.hooks, sorted by source, then
// priority, then name, all ascending. Callers must hold hm.mu for writing.
func (hm *HookManager) rebuildOrdered() {
	// Reuse the existing backing array; readers only see copies.
	sorted := hm.ordered[:0]
	for _, reg := range hm.hooks {
		sorted = append(sorted, reg)
	}

	sort.SliceStable(sorted, func(a, b int) bool {
		left, right := sorted[a], sorted[b]
		switch {
		case left.Source != right.Source:
			return left.Source < right.Source
		case left.Priority != right.Priority:
			return left.Priority < right.Priority
		default:
			return left.Name < right.Name
		}
	})
	hm.ordered = sorted
}
// snapshotHooks returns a copy of the hooks in execution order, taken under
// the read lock, so callers can iterate without holding hm.mu.
func (hm *HookManager) snapshotHooks() []HookRegistration {
	hm.mu.RLock()
	out := make([]HookRegistration, len(hm.ordered))
	copy(out, hm.ordered)
	hm.mu.RUnlock()
	return out
}
// closeAllHooks closes every mounted hook that is an io.Closer and empties
// both the registry and the ordered list.
func (hm *HookManager) closeAllHooks() {
	hm.mu.Lock()
	defer hm.mu.Unlock()

	for name := range hm.hooks {
		closeHookIfPossible(hm.hooks[name].Hook)
		delete(hm.hooks, name)
	}
	hm.ordered = nil
}
// runObserver delivers evt to one observer in its own goroutine and waits for
// it for at most the observer timeout. A returned error or a timeout is only
// logged. After a timeout the observer goroutine is left to finish by itself;
// its result is discarded into a buffered channel.
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
	ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
	defer cancel()

	outcome := make(chan error, 1)
	go func() {
		outcome <- observer.OnEvent(ctx, evt)
	}()

	select {
	case <-ctx.Done():
		logger.WarnCF("hooks", "Event observer timed out", map[string]any{
			"hook":       name,
			"event":      evt.Kind.String(),
			"timeout_ms": hm.observerTimeout.Milliseconds(),
		})
	case observerErr := <-outcome:
		if observerErr == nil {
			return
		}
		logger.WarnCF("hooks", "Event observer failed", map[string]any{
			"hook":  name,
			"event": evt.Kind.String(),
			"error": observerErr.Error(),
		})
	}
}
// callBeforeLLM invokes interceptor.BeforeLLM under the interceptor timeout.
// The boolean result is false when the hook failed or timed out.
func (hm *HookManager) callBeforeLLM(
	parent context.Context,
	name string,
	interceptor LLMInterceptor,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, bool) {
	invoke := func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
		return interceptor.BeforeLLM(ctx, req)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "before_llm", invoke)
}
// callAfterLLM invokes interceptor.AfterLLM under the interceptor timeout.
// The boolean result is false when the hook failed or timed out.
func (hm *HookManager) callAfterLLM(
	parent context.Context,
	name string,
	interceptor LLMInterceptor,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, bool) {
	invoke := func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
		return interceptor.AfterLLM(ctx, resp)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "after_llm", invoke)
}
// callBeforeTool invokes interceptor.BeforeTool under the interceptor timeout.
// The boolean result is false when the hook failed or timed out.
func (hm *HookManager) callBeforeTool(
	parent context.Context,
	name string,
	interceptor ToolInterceptor,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, bool) {
	invoke := func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
		return interceptor.BeforeTool(ctx, call)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "before_tool", invoke)
}
// callAfterTool invokes interceptor.AfterTool under the interceptor timeout.
// The boolean result is false when the hook failed or timed out.
func (hm *HookManager) callAfterTool(
	parent context.Context,
	name string,
	interceptor ToolInterceptor,
	resultView *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, bool) {
	invoke := func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
		return interceptor.AfterTool(ctx, resultView)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "after_tool", invoke)
}
// callApproveTool invokes approver.ApproveTool under the approval timeout.
// The boolean result is false only when the approver returned an error; a
// timeout is reported as a completed denial by runApprovalHook.
func (hm *HookManager) callApproveTool(
	parent context.Context,
	name string,
	approver ToolApprover,
	req *ToolApprovalRequest,
) (ApprovalDecision, bool) {
	invoke := func(ctx context.Context) (ApprovalDecision, error) {
		return approver.ApproveTool(ctx, req)
	}
	return runApprovalHook(parent, hm.approvalTimeout, name, "approve_tool", invoke)
}
// runInterceptorHook runs fn in its own goroutine with a context derived from
// parent and limited to timeout. It returns fn's value and decision with true
// when fn finished without error first. If fn returns an error, or the context
// ends first (timeout or parent cancellation), a warning is logged and the
// zero value, an empty decision and false are returned. fn keeps running in
// the background after the context ends; its result is dropped.
func runInterceptorHook[T any](
	parent context.Context,
	timeout time.Duration,
	name string,
	stage string,
	fn func(ctx context.Context) (T, HookDecision, error),
) (T, HookDecision, bool) {
	type outcome struct {
		value    T
		decision HookDecision
		err      error
	}

	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	// Buffered so the goroutine can always deliver and exit.
	results := make(chan outcome, 1)
	go func() {
		var out outcome
		out.value, out.decision, out.err = fn(ctx)
		results <- out
	}()

	select {
	case out := <-results:
		if out.err == nil {
			return out.value, out.decision, true
		}
		logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
			"hook":  name,
			"stage": stage,
			"error": out.err.Error(),
		})
	case <-ctx.Done():
		logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
			"hook":       name,
			"stage":      stage,
			"timeout_ms": timeout.Milliseconds(),
		})
	}

	var zero T
	return zero, HookDecision{}, false
}
// runApprovalHook runs fn in its own goroutine with a context derived from
// parent and limited to timeout. It returns fn's decision with true when fn
// finished without error first. If fn returns an error, the error is logged
// and an empty decision with false is returned. If the context ends first,
// the result is a denial whose reason names the hook, reported with true.
func runApprovalHook(
	parent context.Context,
	timeout time.Duration,
	name string,
	stage string,
	fn func(ctx context.Context) (ApprovalDecision, error),
) (ApprovalDecision, bool) {
	type outcome struct {
		decision ApprovalDecision
		err      error
	}

	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	// Buffered so the goroutine can always deliver and exit.
	results := make(chan outcome, 1)
	go func() {
		var out outcome
		out.decision, out.err = fn(ctx)
		results <- out
	}()

	select {
	case <-ctx.Done():
		logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
			"hook":       name,
			"stage":      stage,
			"timeout_ms": timeout.Milliseconds(),
		})
		denied := ApprovalDecision{
			Approved: false,
			Reason:   fmt.Sprintf("tool approval hook %q timed out", name),
		}
		return denied, true
	case out := <-results:
		if out.err == nil {
			return out.decision, true
		}
		logger.WarnCF("hooks", "Approval hook failed", map[string]any{
			"hook":  name,
			"stage": stage,
			"error": out.err.Error(),
		})
		return ApprovalDecision{}, false
	}
}
// logUnsupportedAction warns that a hook returned an action the given stage
// does not handle. The action is otherwise ignored by the caller.
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
	details := map[string]any{
		"hook":   name,
		"stage":  stage,
		"action": action,
	}
	logger.WarnCF("hooks", "Hook returned unsupported action for stage", details)
}
// cloneProviderMessages copies a message slice and gives each copied message
// its own Media, SystemParts and ToolCalls slices, so edits to those in the
// copy do not reach the original. An empty input yields nil.
func cloneProviderMessages(messages []providers.Message) []providers.Message {
	if len(messages) == 0 {
		return nil
	}

	out := make([]providers.Message, len(messages))
	copy(out, messages)
	for i := range out {
		m := &out[i]
		if len(m.Media) > 0 {
			m.Media = append([]string(nil), m.Media...)
		}
		if len(m.SystemParts) > 0 {
			m.SystemParts = append([]providers.ContentBlock(nil), m.SystemParts...)
		}
		if len(m.ToolCalls) > 0 {
			m.ToolCalls = cloneProviderToolCalls(m.ToolCalls)
		}
	}
	return out
}
// cloneProviderToolCalls copies a tool-call slice. Each copied call gets its
// own Function value, Arguments map and ExtraContent (including its Google
// part) when those are set. An empty input yields nil.
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
	if len(calls) == 0 {
		return nil
	}

	out := make([]providers.ToolCall, len(calls))
	copy(out, calls)
	for i := range out {
		tc := &out[i]
		if tc.Function != nil {
			fnCopy := *tc.Function
			tc.Function = &fnCopy
		}
		if tc.Arguments != nil {
			tc.Arguments = cloneStringAnyMap(tc.Arguments)
		}
		if tc.ExtraContent != nil {
			extraCopy := *tc.ExtraContent
			if extraCopy.Google != nil {
				googleCopy := *extraCopy.Google
				extraCopy.Google = &googleCopy
			}
			tc.ExtraContent = &extraCopy
		}
	}
	return out
}
// cloneToolDefinitions copies a tool-definition slice and gives each copy its
// own Function.Parameters map. An empty input yields nil.
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
	if len(defs) == 0 {
		return nil
	}

	out := make([]providers.ToolDefinition, len(defs))
	copy(out, defs)
	for i := range out {
		out[i].Function.Parameters = cloneStringAnyMap(out[i].Function.Parameters)
	}
	return out
}
// cloneLLMResponse copies an LLM response, giving the copy its own ToolCalls,
// ReasoningDetails slice and Usage value. A nil input yields nil.
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
	if resp == nil {
		return nil
	}

	dup := new(providers.LLMResponse)
	*dup = *resp
	dup.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
	if len(resp.ReasoningDetails) > 0 {
		// Appending to a zero-capacity slice forces a fresh backing array.
		dup.ReasoningDetails = append(resp.ReasoningDetails[:0:0], resp.ReasoningDetails...)
	}
	if usage := resp.Usage; usage != nil {
		usageCopy := *usage
		dup.Usage = &usageCopy
	}
	return dup
}
// cloneStringAnyMap returns an independent copy of src.
//
// A nil map stays nil. An empty non-nil map stays empty and non-nil, so a hook
// can write into cloned arguments without panicking on a nil map, and so an
// empty JSON object is not turned into null when the clone is marshalled.
// Nested map[string]any and []any values (the shapes produced by decoding
// JSON) are copied recursively; values of any other type are shared with src.
func cloneStringAnyMap(src map[string]any) map[string]any {
	if src == nil {
		return nil
	}
	cloned := make(map[string]any, len(src))
	for k, v := range src {
		cloned[k] = cloneAnyValue(v)
	}
	return cloned
}

// cloneAnyValue deep-copies JSON-style containers and returns every other
// value unchanged.
func cloneAnyValue(v any) any {
	switch typed := v.(type) {
	case map[string]any:
		return cloneStringAnyMap(typed)
	case []any:
		if typed == nil {
			return typed
		}
		items := make([]any, len(typed))
		for i, item := range typed {
			items[i] = cloneAnyValue(item)
		}
		return items
	default:
		return v
	}
}
// cloneToolResult copies a tool result and gives the copy its own Media slice
// when the original has any media. A nil input yields nil.
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
	if result == nil {
		return nil
	}

	dup := new(tools.ToolResult)
	*dup = *result
	if len(result.Media) > 0 {
		dup.Media = append([]string(nil), result.Media...)
	}
	return dup
}
// closeHookIfPossible closes hook when it implements io.Closer and logs a
// warning if closing fails. Hooks that are not closers are left untouched.
func closeHookIfPossible(hook any) {
	if closer, closable := hook.(io.Closer); closable {
		if closeErr := closer.Close(); closeErr != nil {
			logger.WarnCF("hooks", "Failed to close hook", map[string]any{
				"error": closeErr.Error(),
			})
		}
	}
}
+345
View File
@@ -0,0 +1,345 @@
package agent
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// newHookTestLoop builds an AgentLoop backed by provider with a fresh
// temporary workspace, and returns the loop, its default agent and a cleanup
// function that closes the loop and removes the workspace.
//
// If no default agent is registered the helper fails the test. The cleanup is
// run first in that case, because t.Fatal stops the test before the caller
// ever receives the cleanup function and the loop and directory would leak.
func newHookTestLoop(
	t *testing.T,
	provider providers.LLMProvider,
) (*AgentLoop, *AgentInstance, func()) {
	t.Helper()

	tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}

	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         tmpDir,
				Model:             "test-model",
				MaxTokens:         4096,
				MaxToolIterations: 10,
			},
		},
	}

	al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
	cleanup := func() {
		al.Close()
		// Best effort: a leftover temp dir must not fail the test.
		_ = os.RemoveAll(tmpDir)
	}

	agent := al.registry.GetDefaultAgent()
	if agent == nil {
		cleanup()
		t.Fatal("expected default agent")
	}

	return al, agent, cleanup
}
// TestHookManager_SortsInProcessBeforeProcess checks that the hook source
// outranks priority: an in-process hook with a high priority value still runs
// before a process hook with a low one.
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
	hm := NewHookManager(nil)
	defer hm.Close()

	processReg := HookRegistration{
		Name:     "process",
		Priority: -10,
		Source:   HookSourceProcess,
		Hook:     struct{}{},
	}
	if err := hm.Mount(processReg); err != nil {
		t.Fatalf("mount process hook: %v", err)
	}

	inProcessReg := HookRegistration{
		Name:     "in-process",
		Priority: 100,
		Source:   HookSourceInProcess,
		Hook:     struct{}{},
	}
	if err := hm.Mount(inProcessReg); err != nil {
		t.Fatalf("mount in-process hook: %v", err)
	}

	got := hm.snapshotHooks()
	if len(got) != 2 {
		t.Fatalf("expected 2 hooks, got %d", len(got))
	}
	if first := got[0].Name; first != "in-process" {
		t.Fatalf("expected in-process hook first, got %q", first)
	}
	if second := got[1].Name; second != "process" {
		t.Fatalf("expected process hook second, got %q", second)
	}
}
// llmHookTestProvider is an LLM provider stub that always answers
// "provider content" and remembers the model name of the latest call.
type llmHookTestProvider struct {
	mu        sync.Mutex // guards lastModel
	lastModel string     // model passed to the most recent Chat call
}

// Chat records model and returns the fixed response.
func (p *llmHookTestProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		p.lastModel = model
	}()

	reply := &providers.LLMResponse{Content: "provider content"}
	return reply, nil
}

// GetDefaultModel reports the fixed model name of this test provider.
func (p *llmHookTestProvider) GetDefaultModel() string {
	const model = "llm-hook-provider"
	return model
}
// llmObserverHook is a test hook that is both an event observer and an LLM
// interceptor: it forwards turn-end events to eventCh and rewrites the model
// and the response content.
type llmObserverHook struct {
	eventCh chan Event // receives turn-end events, dropped when full
}

// OnEvent forwards turn-end events without blocking; other events are
// ignored.
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
	if evt.Kind != EventKindTurnEnd {
		return nil
	}
	select {
	case h.eventCh <- evt:
	default:
	}
	return nil
}

// BeforeLLM replaces the model of the request with "hook-model".
func (h *llmObserverHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	rewritten := req.Clone()
	rewritten.Model = "hook-model"
	return rewritten, HookDecision{Action: HookActionModify}, nil
}

// AfterLLM replaces the response content with "hooked content".
func (h *llmObserverHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	rewritten := resp.Clone()
	rewritten.Response.Content = "hooked content"
	return rewritten, HookDecision{Action: HookActionModify}, nil
}
// TestAgentLoop_Hooks_ObserverAndLLMInterceptor mounts an in-process hook that
// observes events and intercepts LLM traffic. It checks that the rewritten
// model reached the provider, that the rewritten content is the loop response
// and that the observer received a turn-end event.
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
	provider := &llmHookTestProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	hook := &llmObserverHook{eventCh: make(chan Event, 1)}
	if mountErr := al.MountHook(NamedHook("llm-observer", hook)); mountErr != nil {
		t.Fatalf("MountHook failed: %v", mountErr)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "hello",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	resp, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != "hooked content" {
		t.Fatalf("expected hooked content, got %q", resp)
	}

	// Read the recorded model under the provider's lock.
	usedModel := func() string {
		provider.mu.Lock()
		defer provider.mu.Unlock()
		return provider.lastModel
	}()
	if usedModel != "hook-model" {
		t.Fatalf("expected model hook-model, got %q", usedModel)
	}

	timeout := time.NewTimer(2 * time.Second)
	defer timeout.Stop()
	select {
	case <-timeout.C:
		t.Fatal("timed out waiting for hook observer event")
	case evt := <-hook.eventCh:
		if evt.Kind != EventKindTurnEnd {
			t.Fatalf("expected turn end event, got %v", evt.Kind)
		}
	}
}
// toolHookProvider is a scripted LLM provider for the tool-hook tests. Its
// first Chat call requests the "echo_text" tool with text "original"; every
// later call echoes the content of the last message it was given.
type toolHookProvider struct {
	mu    sync.Mutex // guards calls
	calls int        // number of Chat invocations so far
}

// Chat returns the scripted tool call on the first invocation and afterwards
// replies with the content of the final entry in messages.
func (p *toolHookProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.calls++
	if p.calls != 1 {
		tail := messages[len(messages)-1]
		return &providers.LLMResponse{Content: tail.Content}, nil
	}

	request := providers.ToolCall{
		ID:        "call-1",
		Name:      "echo_text",
		Arguments: map[string]any{"text": "original"},
	}
	return &providers.LLMResponse{
		ToolCalls: []providers.ToolCall{request},
	}, nil
}

// GetDefaultModel reports the fixed model name of this test provider.
func (p *toolHookProvider) GetDefaultModel() string {
	const model = "tool-hook-provider"
	return model
}
// echoTextTool is a test tool that returns its "text" argument as a silent
// result.
type echoTextTool struct{}

// Name returns the tool name used by the scripted providers.
func (t *echoTextTool) Name() string {
	const name = "echo_text"
	return name
}

// Description returns a short human-readable description of the tool.
func (t *echoTextTool) Description() string {
	const description = "echo a text argument"
	return description
}

// Parameters describes the tool input: an object with a string property
// "text".
func (t *echoTextTool) Parameters() map[string]any {
	textProperty := map[string]any{
		"type": "string",
	}
	properties := map[string]any{
		"text": textProperty,
	}
	return map[string]any{
		"type":       "object",
		"properties": properties,
	}
}

// Execute echoes args["text"]; a missing or non-string value echoes the empty
// string.
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	var text string
	if value, isString := args["text"].(string); isString {
		text = value
	}
	return tools.SilentResult(text)
}
// toolRewriteHook is a test ToolInterceptor that rewrites both the tool
// arguments and the tool result.
type toolRewriteHook struct{}

// BeforeTool sets the "text" argument of the call to "modified".
func (h *toolRewriteHook) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
	rewritten := call.Clone()
	rewritten.Arguments["text"] = "modified"
	return rewritten, HookDecision{Action: HookActionModify}, nil
}

// AfterTool prefixes the result's ForLLM text with "after:".
func (h *toolRewriteHook) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
	rewritten := result.Clone()
	original := rewritten.Result.ForLLM
	rewritten.Result.ForLLM = "after:" + original
	return rewritten, HookDecision{Action: HookActionModify}, nil
}
// TestAgentLoop_Hooks_ToolInterceptorCanRewrite mounts an in-process tool
// interceptor and checks that both its argument rewrite and its result rewrite
// are visible in the loop response.
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
	provider := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})

	registration := NamedHook("tool-rewrite", &toolRewriteHook{})
	if mountErr := al.MountHook(registration); mountErr != nil {
		t.Fatalf("MountHook failed: %v", mountErr)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	resp, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != "after:modified" {
		t.Fatalf("expected rewritten tool result, got %q", resp)
	}
}
// denyApprovalHook is a test ToolApprover that rejects every tool call.
type denyApprovalHook struct{}

// ApproveTool always denies with the reason "blocked".
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
	var verdict ApprovalDecision
	verdict.Approved = false
	verdict.Reason = "blocked"
	return verdict, nil
}
// TestAgentLoop_Hooks_ToolApproverCanDeny mounts an in-process approver that
// denies every tool call. It checks that the denial reason is returned as the
// loop response and is carried by the tool-skipped event.
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
	const expected = "Tool execution denied by approval hook: blocked"

	provider := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})

	registration := NamedHook("deny-approval", &denyApprovalHook{})
	if mountErr := al.MountHook(registration); mountErr != nil {
		t.Fatalf("MountHook failed: %v", mountErr)
	}

	sub := al.SubscribeEvents(16)
	defer al.UnsubscribeEvents(sub.ID)

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	resp, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != expected {
		t.Fatalf("expected %q, got %q", expected, resp)
	}

	skippedEvt, found := findEvent(collectEventStream(sub.C), EventKindToolExecSkipped)
	if !found {
		t.Fatal("expected tool skipped event")
	}
	payload, isSkipped := skippedEvt.Payload.(ToolExecSkippedPayload)
	if !isSkipped {
		t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
	}
	if payload.Reason != expected {
		t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
	}
}
+12 -1
View File
@@ -130,6 +130,17 @@ func NewAgentInstance(
maxTokens = 8192
}
contextWindow := defaults.ContextWindow
if contextWindow == 0 {
// Default heuristic: 4x the output token limit.
// Most models have context windows well above their output limits
// (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out).
// 4x is a conservative lower bound that avoids premature
// summarization while remaining safe — the reactive
// forceCompression handles any overshoot.
contextWindow = maxTokens * 4
}
temperature := 0.7
if defaults.Temperature != nil {
temperature = *defaults.Temperature
@@ -182,7 +193,7 @@ func NewAgentInstance(
MaxTokens: maxTokens,
Temperature: temperature,
ThinkingLevel: thinkingLevel,
ContextWindow: maxTokens,
ContextWindow: contextWindow,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
+1340 -526
View File
File diff suppressed because it is too large Load Diff
+8 -9
View File
@@ -1078,11 +1078,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
// Inject some history to simulate a full context
// Inject some history to simulate a full context.
// Session history only stores user/assistant/tool messages — the system
// prompt is built dynamically by BuildMessages and is NOT stored here.
sessionKey := "test-session-context"
// Create dummy history
history := []providers.Message{
{Role: "system", Content: "System prompt"},
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
@@ -1120,12 +1120,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
// Check final history length
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
// We can assert that the length is NOT what it would be without compression.
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
if len(finalHistory) >= 8 {
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
// Original length: 5
// Expected behavior: compression drops ~50% of Turns
// Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7
if len(finalHistory) >= 7 {
t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory))
}
}
+253 -69
View File
@@ -8,6 +8,7 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -21,6 +22,9 @@ const (
SteeringAll SteeringMode = "all"
// MaxQueueSize number of possible messages in the Steering Queue
MaxQueueSize = 10
// manualSteeringScope is the legacy fallback queue used when no active
// turn/session scope is available.
manualSteeringScope = "__manual__"
)
// parseSteeringMode normalizes a config string into a SteeringMode.
@@ -36,56 +40,117 @@ func parseSteeringMode(s string) SteeringMode {
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
type steeringQueue struct {
mu sync.Mutex
queue []providers.Message
mode SteeringMode
mu sync.Mutex
queues map[string][]providers.Message
mode SteeringMode
}
func newSteeringQueue(mode SteeringMode) *steeringQueue {
return &steeringQueue{
mode: mode,
queues: make(map[string][]providers.Message),
mode: mode,
}
}
// push enqueues a steering message.
func normalizeSteeringScope(scope string) string {
	// Trim surrounding whitespace; a scope that is blank after trimming
	// resolves to the legacy manual scope key.
	if trimmed := strings.TrimSpace(scope); trimmed != "" {
		return trimmed
	}
	return manualSteeringScope
}
// push enqueues a steering message in the legacy fallback scope.
func (sq *steeringQueue) push(msg providers.Message) error {
return sq.pushScope(manualSteeringScope, msg)
}
// pushScope enqueues a steering message for the provided scope.
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) >= MaxQueueSize {
scope = normalizeSteeringScope(scope)
queue := sq.queues[scope]
if len(queue) >= MaxQueueSize {
return fmt.Errorf("steering queue is full")
}
sq.queue = append(sq.queue, msg)
sq.queues[scope] = append(queue, msg)
return nil
}
// dequeue removes and returns pending steering messages according to the
// configured mode. Returns nil when the queue is empty.
// dequeue removes and returns pending steering messages from the legacy
// fallback scope according to the configured mode.
func (sq *steeringQueue) dequeue() []providers.Message {
return sq.dequeueScope(manualSteeringScope)
}
// dequeueScope removes and returns pending steering messages for the provided
// scope according to the configured mode.
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
sq.mu.Lock()
defer sq.mu.Unlock()
if len(sq.queue) == 0 {
return sq.dequeueLocked(normalizeSteeringScope(scope))
}
// dequeueScopeWithFallback dequeues from the given scope first. When the
// scope is blank after trimming, or yields no messages, it dequeues from the
// legacy manual scope instead, for backwards compatibility.
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
	sq.mu.Lock()
	defer sq.mu.Unlock()

	trimmed := strings.TrimSpace(scope)
	if trimmed == "" {
		return sq.dequeueLocked(manualSteeringScope)
	}
	scoped := sq.dequeueLocked(trimmed)
	if len(scoped) == 0 {
		return sq.dequeueLocked(manualSteeringScope)
	}
	return scoped
}
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
queue := sq.queues[scope]
if len(queue) == 0 {
return nil
}
switch sq.mode {
case SteeringAll:
msgs := sq.queue
sq.queue = nil
msgs := append([]providers.Message(nil), queue...)
delete(sq.queues, scope)
return msgs
default: // one-at-a-time
msg := sq.queue[0]
sq.queue[0] = providers.Message{} // Clear reference for GC
sq.queue = sq.queue[1:]
default:
msg := queue[0]
queue[0] = providers.Message{} // Clear reference for GC
queue = queue[1:]
if len(queue) == 0 {
delete(sq.queues, scope)
} else {
sq.queues[scope] = queue
}
return []providers.Message{msg}
}
}
// len returns the number of queued messages.
// len returns the number of queued messages across all scopes.
func (sq *steeringQueue) len() int {
sq.mu.Lock()
defer sq.mu.Unlock()
return len(sq.queue)
total := 0
for _, queue := range sq.queues {
total += len(queue)
}
return total
}
// lenScope reports how many messages are queued under the given scope, after
// the scope has been normalized with normalizeSteeringScope.
func (sq *steeringQueue) lenScope(scope string) int {
	key := normalizeSteeringScope(scope)

	sq.mu.Lock()
	defer sq.mu.Unlock()
	return len(sq.queues[key])
}
// setMode updates the steering mode.
@@ -102,28 +167,76 @@ func (sq *steeringQueue) getMode() SteeringMode {
return sq.mode
}
// --- AgentLoop steering API ---
// Steer enqueues a user message to be injected into the currently running
// agent loop. When an active turn exists, the message is queued under that
// turn's session key and attributed to its agent; otherwise both values are
// passed on empty. The message will be picked up after the current tool
// finishes executing, causing any remaining tool calls in the batch to be
// skipped.
func (al *AgentLoop) Steer(msg providers.Message) error {
	var scope, agentID string
	if active := al.getAnyActiveTurnState(); active != nil {
		scope, agentID = active.sessionKey, active.agentID
	}
	return al.enqueueSteeringMessage(scope, agentID, msg)
}
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
if al.steering == nil {
return fmt.Errorf("steering queue is not initialized")
}
if err := al.steering.push(msg); err != nil {
if err := al.steering.pushScope(scope, msg); err != nil {
logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
"error": err.Error(),
"role": msg.Role,
"scope": normalizeSteeringScope(scope),
})
return err
}
queueDepth := al.steering.lenScope(scope)
logger.DebugCF("agent", "Steering message enqueued", map[string]any{
"role": msg.Role,
"content_len": len(msg.Content),
"queue_len": al.steering.len(),
"media_count": len(msg.Media),
"queue_len": queueDepth,
"scope": normalizeSteeringScope(scope),
})
meta := EventMeta{
Source: "Steer",
TracePath: "turn.interrupt.received",
}
if ts := al.getAnyActiveTurnState(); ts != nil {
meta = ts.eventMeta("Steer", "turn.interrupt.received")
} else {
if strings.TrimSpace(agentID) != "" {
meta.AgentID = agentID
}
normalizedScope := normalizeSteeringScope(scope)
if normalizedScope != manualSteeringScope {
meta.SessionKey = normalizedScope
}
if meta.AgentID == "" {
if registry := al.GetRegistry(); registry != nil {
if agent := registry.GetDefaultAgent(); agent != nil {
meta.AgentID = agent.ID
}
}
}
}
al.emitEvent(
EventKindInterruptReceived,
meta,
InterruptReceivedPayload{
Kind: InterruptKindSteering,
Role: msg.Role,
ContentLen: len(msg.Content),
QueueDepth: queueDepth,
},
)
return nil
}
@@ -144,7 +257,7 @@ func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
}
// dequeueSteeringMessages is the internal method called by the agent loop
// to poll for steering messages. Returns nil when no messages are pending.
// to poll for steering messages in the legacy fallback scope.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
if al.steering == nil {
return nil
@@ -152,6 +265,60 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
return al.steering.dequeue()
}
// dequeueSteeringMessagesForScope dequeues pending steering messages for the
// given scope. It returns nil when the loop has no steering queue.
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
	if queue := al.steering; queue != nil {
		return queue.dequeueScope(scope)
	}
	return nil
}
// dequeueSteeringMessagesForScopeWithFallback dequeues steering messages for
// the given scope via the queue's fallback-aware dequeue. It returns nil when
// the loop has no steering queue.
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
	if queue := al.steering; queue != nil {
		return queue.dequeueScopeWithFallback(scope)
	}
	return nil
}
// pendingSteeringCountForScope reports how many steering messages are queued
// for the given scope, or 0 when the loop has no steering queue.
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
	if queue := al.steering; queue != nil {
		return queue.lenScope(scope)
	}
	return 0
}
// continueWithSteeringMessages runs the agent loop for the given session with
// the supplied steering messages as InitialSteeringMessages. The initial
// steering poll is skipped, summaries are enabled, and the response is
// returned to the caller rather than sent.
func (al *AgentLoop) continueWithSteeringMessages(
	ctx context.Context,
	agent *AgentInstance,
	sessionKey, channel, chatID string,
	steeringMsgs []providers.Message,
) (string, error) {
	opts := processOptions{
		SessionKey:              sessionKey,
		Channel:                 channel,
		ChatID:                  chatID,
		DefaultResponse:         defaultResponse,
		EnableSummary:           true,
		SendResponse:            false,
		InitialSteeringMessages: steeringMsgs,
		SkipInitialSteeringPoll: true,
	}
	return al.runAgentLoop(ctx, agent, opts)
}
// agentForSession picks the agent for a session key. If the key parses as an
// agent session key and the registry knows that agent ID, that agent is
// returned; otherwise the registry's default agent is used. It returns nil
// when no registry is available.
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
	registry := al.GetRegistry()
	if registry == nil {
		return nil
	}

	parsed := routing.ParseAgentSessionKey(sessionKey)
	if parsed == nil {
		return registry.GetDefaultAgent()
	}
	agent, found := registry.GetAgent(parsed.AgentID)
	if !found {
		return registry.GetDefaultAgent()
	}
	return agent
}
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
@@ -159,33 +326,74 @@ func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
steeringMsgs := al.dequeueSteeringMessages()
if active := al.GetActiveTurn(); active != nil {
return "", fmt.Errorf("turn %s is still active", active.TurnID)
}
if err := al.ensureHooksInitialized(ctx); err != nil {
return "", err
}
if err := al.ensureMCPInitialized(ctx); err != nil {
return "", err
}
steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
if len(steeringMsgs) == 0 {
return "", nil
}
agent := al.GetRegistry().GetDefaultAgent()
agent := al.agentForSession(sessionKey)
if agent == nil {
return "", fmt.Errorf("no default agent available")
return "", fmt.Errorf("no agent available for session %q", sessionKey)
}
// Build a combined user message from the steering messages.
var contents []string
for _, msg := range steeringMsgs {
contents = append(contents, msg.Content)
if tool, ok := agent.Tools.Get("message"); ok {
if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
resetter.ResetSentInRound()
}
}
combinedContent := strings.Join(contents, "\n")
return al.runAgentLoop(ctx, agent, processOptions{
SessionKey: sessionKey,
Channel: channel,
ChatID: chatID,
UserMessage: combinedContent,
DefaultResponse: defaultResponse,
EnableSummary: true,
SendResponse: false,
SkipInitialSteeringPoll: true,
})
return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
}
// InterruptGraceful requests a graceful interrupt of the active turn, passing
// along the caller's hint. It returns an error when there is no active turn or
// when the turn refuses the request. On success it emits an interrupt-received
// event recording the hint length.
func (al *AgentLoop) InterruptGraceful(hint string) error {
	turn := al.getAnyActiveTurnState()
	switch {
	case turn == nil:
		return fmt.Errorf("no active turn")
	case !turn.requestGracefulInterrupt(hint):
		return fmt.Errorf("turn %s cannot accept graceful interrupt", turn.turnID)
	}

	meta := turn.eventMeta("InterruptGraceful", "turn.interrupt.received")
	payload := InterruptReceivedPayload{
		Kind:    InterruptKindGraceful,
		HintLen: len(hint),
	}
	al.emitEvent(EventKindInterruptReceived, meta, payload)
	return nil
}
// InterruptHard requests a hard abort of the active turn. It returns an error
// when there is no active turn or when the abort request is rejected, which
// the error text reports as the turn already aborting. On success it emits an
// interrupt-received event of the hard kind.
func (al *AgentLoop) InterruptHard() error {
	turn := al.getAnyActiveTurnState()
	switch {
	case turn == nil:
		return fmt.Errorf("no active turn")
	case !turn.requestHardAbort():
		return fmt.Errorf("turn %s is already aborting", turn.turnID)
	}

	meta := turn.eventMeta("InterruptHard", "turn.interrupt.received")
	payload := InterruptReceivedPayload{Kind: InterruptKindHard}
	al.emitEvent(EventKindInterruptReceived, meta, payload)
	return nil
}
// ====================== SubTurn Result Polling ======================
@@ -206,7 +414,10 @@ func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.To
var results []*tools.ToolResult
for {
select {
case result := <-ts.pendingResults:
case result, ok := <-ts.pendingResults:
if !ok {
return results
}
if result != nil {
results = append(results, result)
}
@@ -249,20 +460,6 @@ func (al *AgentLoop) HardAbort(sessionKey string) error {
// Use isHardAbort=true for hard abort to immediately cancel all children.
ts.Finish(true)
// Rollback session history to the state before this turn started.
// This must happen AFTER Finish() to ensure no child turns are still writing.
if ts.session != nil {
currentHistory := ts.session.GetHistory("")
if len(currentHistory) > ts.initialHistoryLength {
logger.InfoCF("agent", "Rolling back session history", map[string]any{
"from": len(currentHistory),
"to": ts.initialHistoryLength,
})
// SetHistory with the truncated slice to rollback
ts.session.SetHistory("", currentHistory[:ts.initialHistoryLength])
}
}
return nil
}
@@ -291,19 +488,6 @@ func (al *AgentLoop) InjectFollowUp(msg providers.Message) error {
// ====================== API Aliases for Design Document Compatibility ======================
// InterruptGraceful is an alias for Steer() to match the design document naming.
// It gracefully interrupts the current execution by injecting a user message
// that will be processed after the current tool finishes.
func (al *AgentLoop) InterruptGraceful(msg providers.Message) error {
return al.Steer(msg)
}
// InterruptHard is an alias for HardAbort() to match the design document naming.
// It immediately terminates execution and rolls back the session state.
func (al *AgentLoop) InterruptHard(sessionKey string) error {
return al.HardAbort(sessionKey)
}
// InjectSteering is an alias for Steer() to match the design document naming.
// It injects a steering message into the currently running agent loop.
func (al *AgentLoop) InjectSteering(msg providers.Message) error {
+847
View File
@@ -5,13 +5,18 @@ import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/media"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
@@ -335,6 +340,97 @@ func TestAgentLoop_Continue_WithMessages(t *testing.T) {
}
}
// TestDrainBusToSteering_RequeuesDifferentScopeMessage publishes an inbound
// message that resolves to a different steering scope than the "active" one,
// runs drainBusToSteering for the active scope, and checks that the message is
// not queued as steering for the active scope but is observed again on the
// bus.
func TestDrainBusToSteering_RequeuesDifferentScopeMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
Session: config.SessionConfig{
DMScope: "per-peer",
},
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, &mockProvider{})
// activeMsg is only used to compute the scope/agent of the turn that is
// treated as active; it is never published.
activeMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user1",
ChatID: "chat1",
Content: "active turn",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
activeScope, activeAgentID, ok := al.resolveSteeringTarget(activeMsg)
if !ok {
t.Fatal("expected active message to resolve to a steering scope")
}
otherMsg := bus.InboundMessage{
Channel: "telegram",
SenderID: "user2",
ChatID: "chat2",
Content: "other session",
Peer: bus.Peer{
Kind: "direct",
ID: "user2",
},
}
otherScope, _, ok := al.resolveSteeringTarget(otherMsg)
if !ok {
t.Fatal("expected other message to resolve to a steering scope")
}
// Precondition: the two peers must map to distinct scopes, otherwise the
// rest of the test proves nothing.
if otherScope == activeScope {
t.Fatalf("expected different steering scopes, got same scope %q", activeScope)
}
if err := msgBus.PublishInbound(context.Background(), otherMsg); err != nil {
t.Fatalf("PublishInbound failed: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// Run the drain in the background; it is expected to return on its own
// within the 2s guard below (ctx itself expires after 1s).
done := make(chan struct{})
go func() {
al.drainBusToSteering(ctx, activeScope, activeAgentID)
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for drainBusToSteering to stop")
}
if msgs := al.dequeueSteeringMessagesForScope(activeScope); len(msgs) != 0 {
t.Fatalf("expected no steering messages for active scope, got %v", msgs)
}
// NOTE(review): the requeued message is read from OutboundChan even though
// it was published as inbound — confirm this is the intended channel.
select {
case <-ctx.Done():
t.Fatalf("timeout waiting for requeued message on outbound bus")
case requeued := <-msgBus.OutboundChan():
if requeued.Channel != otherMsg.Channel || requeued.ChatID != otherMsg.ChatID ||
requeued.Content != otherMsg.Content {
t.Fatalf("requeued message mismatch: got %+v want %+v", requeued, otherMsg)
}
}
}
// slowTool simulates a tool that takes some time to execute.
type slowTool struct {
name string
@@ -396,6 +492,149 @@ func (m *toolCallProvider) GetDefaultModel() string {
return "tool-call-mock"
}
// gracefulCaptureProvider is a test provider that answers its first Chat call
// with a fixed set of tool calls. On every later call it records the messages
// and the number of tool definitions it was given, then answers with
// finalResp.
type gracefulCaptureProvider struct {
	mu                 sync.Mutex
	calls              int
	toolCalls          []providers.ToolCall
	finalResp          string
	terminalMessages   []providers.Message
	terminalToolsCount int
}

// Chat implements the provider call described on gracefulCaptureProvider. The
// whole call runs under p.mu, and the captured messages are stored as a copy.
func (p *gracefulCaptureProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.calls++
	if p.calls != 1 {
		p.terminalMessages = append([]providers.Message(nil), messages...)
		p.terminalToolsCount = len(tools)
		return &providers.LLMResponse{Content: p.finalResp}, nil
	}
	return &providers.LLMResponse{ToolCalls: p.toolCalls}, nil
}

// GetDefaultModel returns the fixed model name used by this mock.
func (p *gracefulCaptureProvider) GetDefaultModel() string {
	return "graceful-capture-mock"
}
// lateSteeringProvider is a test provider whose first Chat call signals that
// it has started and then blocks until the test releases it. Later calls
// record the messages they receive and return immediately.
type lateSteeringProvider struct {
mu sync.Mutex
calls int
// firstCallStarted is closed (once) when the first Chat call begins.
firstCallStarted chan struct{}
// releaseFirstCall unblocks the first Chat call when the test closes it.
releaseFirstCall chan struct{}
firstStartOnce sync.Once
// secondCallMessages holds a copy of the messages from the most recent
// call after the first.
secondCallMessages []providers.Message
}
// Chat counts the call under the mutex, then either blocks (first call) or
// captures the incoming messages (subsequent calls).
func (p *lateSteeringProvider) Chat(
ctx context.Context,
messages []providers.Message,
tools []providers.ToolDefinition,
model string,
opts map[string]any,
) (*providers.LLMResponse, error) {
p.mu.Lock()
p.calls++
call := p.calls
p.mu.Unlock()
if call == 1 {
// Signal the start outside the mutex, then wait for the test to release
// us. This wait does not observe ctx cancellation.
p.firstStartOnce.Do(func() { close(p.firstCallStarted) })
<-p.releaseFirstCall
return &providers.LLMResponse{Content: "first response"}, nil
}
p.mu.Lock()
p.secondCallMessages = append([]providers.Message(nil), messages...)
p.mu.Unlock()
return &providers.LLMResponse{Content: "continued response"}, nil
}
// GetDefaultModel returns the fixed model name used by this mock.
func (p *lateSteeringProvider) GetDefaultModel() string {
return "late-steering-mock"
}
// blockingDirectProvider is a test provider whose first Chat call announces
// that it has started (by closing firstStarted) and then blocks until the test
// closes releaseFirst or the context is cancelled. The first call answers with
// firstResp; every later call answers immediately with finalResp.
type blockingDirectProvider struct {
	mu           sync.Mutex
	calls        int
	firstStarted chan struct{}
	releaseFirst chan struct{}
	firstResp    string
	finalResp    string
}

// Chat snapshots the fields it needs while holding the mutex, so the blocking
// wait happens without the lock held. It returns ctx.Err() if the context ends
// while the first call is still waiting to be released.
func (p *blockingDirectProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.calls++
	call := p.calls
	releaseFirst := p.releaseFirst
	firstResp := p.firstResp
	finalResp := p.finalResp
	if call == 1 && p.firstStarted != nil {
		// Close and clear the channel under the lock so it is closed at most once.
		close(p.firstStarted)
		p.firstStarted = nil
	}
	p.mu.Unlock()

	if call != 1 {
		return &providers.LLMResponse{Content: finalResp}, nil
	}

	select {
	case <-releaseFirst:
	case <-ctx.Done():
		return nil, ctx.Err()
	}
	return &providers.LLMResponse{Content: firstResp}, nil
}

// GetDefaultModel returns the fixed model name used by this mock.
func (p *blockingDirectProvider) GetDefaultModel() string {
	return "blocking-direct-mock"
}
// interruptibleTool is a test tool whose Execute blocks until its context is
// done. If started is non-nil it is closed once, on the first execution, so a
// test can wait for the tool to be running.
type interruptibleTool struct {
	name    string
	started chan struct{}
	once    sync.Once
}

// Name returns the configured tool name.
func (t *interruptibleTool) Name() string { return t.name }

// Description returns a fixed description string.
func (t *interruptibleTool) Description() string { return "interruptible tool for testing" }

// Parameters returns an object schema with no properties.
func (t *interruptibleTool) Parameters() map[string]any {
	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{}
	return schema
}

// Execute signals that it has started, waits for ctx to be done, and returns
// an error result built from the context's error.
func (t *interruptibleTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	if t.started != nil {
		t.once.Do(func() { close(t.started) })
	}
	<-ctx.Done()
	cause := ctx.Err()
	return tools.ErrorResult(cause.Error()).WithError(cause)
}
func TestAgentLoop_Steering_SkipsRemainingTools(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
@@ -568,6 +807,614 @@ func TestAgentLoop_Steering_InitialPoll(t *testing.T) {
}
}
// TestAgentLoop_Run_AutoContinuesLateSteeringMessage starts Run, publishes a
// first inbound message, and while the provider's first call is still blocked
// publishes a second message for the same peer. It then checks that exactly
// one outbound message ("continued response") is produced, that the provider
// was called twice, and that the second call saw the late message as a user
// message.
func TestAgentLoop_Run_AutoContinuesLateSteeringMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &lateSteeringProvider{
firstCallStarted: make(chan struct{}),
releaseFirstCall: make(chan struct{}),
}
al := NewAgentLoop(cfg, msgBus, provider)
runCtx, cancelRun := context.WithCancel(context.Background())
defer cancelRun()
// Run the loop in the background; its return value is checked at the end.
runErrCh := make(chan error, 1)
go func() {
runErrCh <- al.Run(runCtx)
}()
first := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "first message",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
// late targets the same channel/chat/peer as first.
late := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "late append",
Peer: bus.Peer{
Kind: "direct",
ID: "user1",
},
}
pubCtx, pubCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer pubCancel()
if err := msgBus.PublishInbound(pubCtx, first); err != nil {
t.Fatalf("publish first inbound: %v", err)
}
// Wait until the provider is inside its first (blocked) call before
// publishing the late message.
select {
case <-provider.firstCallStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first provider call to start")
}
if err := msgBus.PublishInbound(pubCtx, late); err != nil {
t.Fatalf("publish late inbound: %v", err)
}
// Unblock the first provider call only after the late message is published.
close(provider.releaseFirstCall)
subCtx, subCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer subCancel()
var out1 bus.OutboundMessage
select {
case out1 = <-msgBus.OutboundChan():
case <-subCtx.Done():
t.Fatal("expected outbound response")
}
// The first outbound message must already be the continuation, not
// "first response".
if out1.Content != "continued response" {
t.Fatalf("expected continued response, got %q", out1.Content)
}
// Give the loop a short window to (incorrectly) emit a second outbound.
noExtraCtx, cancelNoExtra := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancelNoExtra()
select {
case out2 := <-msgBus.OutboundChan():
t.Fatalf("expected stale direct response to be suppressed, got extra outbound %q", out2.Content)
case <-noExtraCtx.Done():
}
cancelRun()
select {
case err := <-runErrCh:
if err != nil {
t.Fatalf("Run returned error: %v", err)
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for Run to stop")
}
// Snapshot provider state under its mutex before asserting on it.
provider.mu.Lock()
calls := provider.calls
secondMessages := append([]providers.Message(nil), provider.secondCallMessages...)
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
foundLateMessage := false
for _, msg := range secondMessages {
if msg.Role == "user" && msg.Content == "late append" {
foundLateMessage = true
break
}
}
if !foundLateMessage {
t.Fatal("expected queued late message to be processed in an automatic follow-up turn")
}
}
// TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage calls Steer
// while the provider's first call is blocked, then releases it. It checks
// that ProcessDirectWithChannel returns the second provider response rather
// than the first, that the provider was called exactly twice, and that the
// session's steering queue ends up empty.
func TestAgentLoop_Steering_DirectResponseContinuesWithQueuedMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
provider := &blockingDirectProvider{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
firstResp: "stale direct response",
finalResp: "fresh response after steering",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
// The direct call runs in the background so the test can steer while it is
// blocked inside the provider.
resultCh := make(chan struct {
resp string
err error
}, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"initial request",
sessionKey,
"test",
"chat1",
)
resultCh <- struct {
resp string
err error
}{resp: resp, err: err}
}()
select {
case <-provider.firstStarted:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for first LLM call to start")
}
// Queue the steering message while the first call is still in flight.
if err := al.Steer(providers.Message{Role: "user", Content: "follow-up instruction"}); err != nil {
t.Fatalf("Steer failed: %v", err)
}
close(provider.releaseFirst)
select {
case result := <-resultCh:
if result.err != nil {
t.Fatalf("unexpected error: %v", result.err)
}
if result.resp != "fresh response after steering" {
t.Fatalf("expected refreshed response, got %q", result.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for ProcessDirectWithChannel")
}
provider.mu.Lock()
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
// Nothing should remain queued for this session after the continuation.
if msgs := al.dequeueSteeringMessagesForScope(sessionKey); len(msgs) != 0 {
t.Fatalf("expected steering queue to be empty after continuation, got %v", msgs)
}
}
// TestAgentLoop_Continue_PreservesSteeringMedia steers a user message that
// carries a media-store reference and then calls Continue. It checks that the
// provider receives the message with the media resolved to a PNG data URL, and
// that the session history keeps the original reference.
func TestAgentLoop_Continue_PreservesSteeringMedia(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
store := media.NewFileMediaStore()
pngPath := filepath.Join(tmpDir, "steer.png")
// PNG signature followed by an IHDR chunk header; presumably just enough
// for the content to be detected as image/png — TODO confirm.
pngHeader := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,
0x00, 0x00, 0x00, 0x0D,
0x49, 0x48, 0x44, 0x52,
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02,
0x00, 0x00, 0x00,
0x90, 0x77, 0x53, 0xDE,
}
if err = os.WriteFile(pngPath, pngHeader, 0o644); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
ref, err := store.Store(pngPath, media.MediaMeta{Filename: "steer.png", ContentType: "image/png"}, "test")
if err != nil {
t.Fatalf("Store failed: %v", err)
}
// capturedMessages holds a copy of the messages from the latest provider
// call; capMu guards it.
var capturedMessages []providers.Message
var capMu sync.Mutex
provider := &capturingMockProvider{
response: "ack",
captureFn: func(msgs []providers.Message) {
capMu.Lock()
defer capMu.Unlock()
capturedMessages = append([]providers.Message(nil), msgs...)
},
}
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.SetMediaStore(store)
// Steer is called before any turn is started by this test.
if err = al.Steer(providers.Message{
Role: "user",
Content: "describe this image",
Media: []string{ref},
}); err != nil {
t.Fatalf("Steer failed: %v", err)
}
resp, err := al.Continue(context.Background(), sessionKey, "test", "chat1")
if err != nil {
t.Fatalf("Continue failed: %v", err)
}
if resp != "ack" {
t.Fatalf("expected ack, got %q", resp)
}
capMu.Lock()
msgs := append([]providers.Message(nil), capturedMessages...)
capMu.Unlock()
// The provider must see the steering message with its media already
// resolved to an inline data URL.
foundResolvedMedia := false
for _, msg := range msgs {
if msg.Role != "user" || msg.Content != "describe this image" || len(msg.Media) != 1 {
continue
}
if strings.HasPrefix(msg.Media[0], "data:image/png;base64,") {
foundResolvedMedia = true
break
}
}
if !foundResolvedMedia {
t.Fatal("expected continue path to inject steering media into the provider request")
}
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// The stored history, by contrast, must keep the unresolved reference.
history := defaultAgent.Sessions.GetHistory(sessionKey)
foundOriginalRef := false
for _, msg := range history {
if msg.Role == "user" && len(msg.Media) == 1 && msg.Media[0] == ref {
foundOriginalRef = true
break
}
}
if !foundOriginalRef {
t.Fatal("expected original steering media ref to be preserved in session history")
}
}
// TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall issues a graceful
// interrupt while the first of two requested tools is running. It checks that
// the second provider call is made with no tool definitions, that it includes
// the interrupt hint as a user message and a "skipped" tool result for the
// second tool call, and that the turn ends with completed status after an
// interrupt-received event of the graceful kind.
func TestAgentLoop_InterruptGraceful_UsesTerminalNoToolCall(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// tool1ExecCh is what the test waits on to learn tool_one has started.
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
// The provider's first response requests both tools in one batch.
provider := &gracefulCaptureProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "graceful summary",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
type result struct {
resp string
err error
}
resultCh := make(chan result, 1)
go func() {
resp, err := al.ProcessDirectWithChannel(
context.Background(),
"do something",
sessionKey,
"test",
"chat1",
)
resultCh <- result{resp: resp, err: err}
}()
select {
case <-tool1ExecCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
// While tool_one is running, the loop must report an active turn bound to
// this session, channel and chat.
active := al.GetActiveTurn()
if active == nil {
t.Fatal("expected active turn while tool is running")
}
if active.SessionKey != sessionKey {
t.Fatalf("expected active session %q, got %q", sessionKey, active.SessionKey)
}
if active.Channel != "test" || active.ChatID != "chat1" {
t.Fatalf("unexpected active turn target: %#v", active)
}
if err := al.InterruptGraceful("wrap it up"); err != nil {
t.Fatalf("InterruptGraceful failed: %v", err)
}
select {
case r := <-resultCh:
if r.err != nil {
t.Fatalf("unexpected error: %v", r.err)
}
if r.resp != "graceful summary" {
t.Fatalf("expected graceful summary, got %q", r.resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for graceful interrupt result")
}
if active := al.GetActiveTurn(); active != nil {
t.Fatalf("expected no active turn after completion, got %#v", active)
}
// Snapshot what the provider captured on its second (terminal) call.
provider.mu.Lock()
terminalMessages := append([]providers.Message(nil), provider.terminalMessages...)
terminalToolsCount := provider.terminalToolsCount
calls := provider.calls
provider.mu.Unlock()
if calls != 2 {
t.Fatalf("expected 2 provider calls, got %d", calls)
}
if terminalToolsCount != 0 {
t.Fatalf("expected graceful terminal call to disable tools, got %d tool defs", terminalToolsCount)
}
foundHint := false
foundSkipped := false
// Exact text expected for the injected user message, including the
// caller's hint.
expectedHint := "Interrupt requested. Stop scheduling tools and provide a short final summary.\n\n" +
"Interrupt hint: wrap it up"
for _, msg := range terminalMessages {
if msg.Role == "user" && msg.Content == expectedHint {
foundHint = true
}
if msg.Role == "tool" && msg.ToolCallID == "call_2" && msg.Content == "Skipped due to graceful interrupt." {
foundSkipped = true
}
}
if !foundHint {
t.Fatal("expected graceful terminal call to include interrupt hint message")
}
if !foundSkipped {
t.Fatal("expected remaining tool to be marked as skipped after graceful interrupt")
}
events := collectEventStream(sub.C)
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Kind != InterruptKindGraceful {
t.Fatalf("expected graceful interrupt payload, got %q", interruptPayload.Kind)
}
// A graceful interrupt is expected to end the turn as completed.
turnEndEvt, ok := findEvent(events, EventKindTurnEnd)
if !ok {
t.Fatal("expected turn end event")
}
turnEndPayload, ok := turnEndEvt.Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
}
if turnEndPayload.Status != TurnEndStatusCompleted {
t.Fatalf("expected completed turn after graceful interrupt, got %q", turnEndPayload.Status)
}
}
// TestAgentLoop_InterruptHard_RestoresSession checks the hard-interrupt path:
// while a tool call is still executing, InterruptHard must end the turn with
// no final response and no error, leave no active turn behind, restore the
// session history to what it was before the turn started, and publish an
// interrupt-received event (kind hard) followed by a turn-end event with the
// aborted status.
func TestAgentLoop_InterruptHard_RestoresSession(t *testing.T) {
	workspace, err := os.MkdirTemp("", "agent-test-*")
	if err != nil {
		t.Fatalf("Failed to create temp dir: %v", err)
	}
	defer os.RemoveAll(workspace)

	defaults := config.AgentDefaults{
		Workspace:         workspace,
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}
	cfg := &config.Config{Agents: config.AgentsConfig{Defaults: defaults}}
	msgBus := bus.NewMessageBus()

	// The provider asks for exactly one tool call; the final response must
	// never be produced because the turn is aborted while the tool runs.
	cancelCall := providers.ToolCall{
		ID:   "call_1",
		Type: "function",
		Name: "cancel_tool",
		Function: &providers.FunctionCall{
			Name:      "cancel_tool",
			Arguments: "{}",
		},
		Arguments: map[string]any{},
	}
	llm := &toolCallProvider{
		toolCalls: []providers.ToolCall{cancelCall},
		finalResp: "should not happen",
	}

	al := NewAgentLoop(cfg, msgBus, llm)
	toolStarted := make(chan struct{})
	al.RegisterTool(&interruptibleTool{name: "cancel_tool", started: toolStarted})

	sessionKey := routing.BuildAgentMainSessionKey(routing.DefaultAgentID)
	agent := al.registry.GetDefaultAgent()
	if agent == nil {
		t.Fatal("expected default agent")
	}

	// Seed the session so a rollback can be told apart from an empty history.
	seededHistory := []providers.Message{
		{Role: "user", Content: "before"},
		{Role: "assistant", Content: "after"},
	}
	agent.Sessions.SetHistory(sessionKey, seededHistory)

	sub := al.SubscribeEvents(16)
	defer al.UnsubscribeEvents(sub.ID)

	type outcome struct {
		resp string
		err  error
	}
	done := make(chan outcome, 1)
	go func() {
		resp, err := al.ProcessDirectWithChannel(context.Background(), "do work", sessionKey, "test", "chat1")
		done <- outcome{resp: resp, err: err}
	}()

	// Wait until the tool is actually running before interrupting.
	select {
	case <-toolStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for interruptible tool to start")
	}
	if al.GetActiveTurn() == nil {
		t.Fatal("expected active turn before hard abort")
	}

	if err := al.InterruptHard(); err != nil {
		t.Fatalf("InterruptHard failed: %v", err)
	}

	select {
	case got := <-done:
		switch {
		case got.err != nil:
			t.Fatalf("unexpected error: %v", got.err)
		case got.resp != "":
			t.Fatalf("expected no final response after hard abort, got %q", got.resp)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for hard abort result")
	}

	if leftover := al.GetActiveTurn(); leftover != nil {
		t.Fatalf("expected no active turn after hard abort, got %#v", leftover)
	}

	restored := agent.Sessions.GetHistory(sessionKey)
	if !reflect.DeepEqual(restored, seededHistory) {
		t.Fatalf("expected history rollback after hard abort, got %#v", restored)
	}

	events := collectEventStream(sub.C)

	interruptEvt, haveInterrupt := findEvent(events, EventKindInterruptReceived)
	if !haveInterrupt {
		t.Fatal("expected interrupt received event")
	}
	switch payload := interruptEvt.Payload.(type) {
	case InterruptReceivedPayload:
		if payload.Kind != InterruptKindHard {
			t.Fatalf("expected hard interrupt payload, got %q", payload.Kind)
		}
	default:
		t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
	}

	turnEndEvt, haveTurnEnd := findEvent(events, EventKindTurnEnd)
	if !haveTurnEnd {
		t.Fatal("expected turn end event")
	}
	switch payload := turnEndEvt.Payload.(type) {
	case TurnEndPayload:
		if payload.Status != TurnEndStatusAborted {
			t.Fatalf("expected aborted turn, got %q", payload.Status)
		}
	default:
		t.Fatalf("expected TurnEndPayload, got %T", turnEndEvt.Payload)
	}
}
// capturingMockProvider captures messages sent to Chat for inspection.
type capturingMockProvider struct {
response string
+237 -332
View File
@@ -4,14 +4,13 @@ import (
"context"
"errors"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ====================== Config & Constants ======================
@@ -176,33 +175,6 @@ type SubTurnConfig struct {
// Can be extended with temperature, topP, etc.
}
// ====================== Sub-turn Events (Aligned with EventBus) ======================
// SubTurnSpawnEvent is emitted when a child sub-turn is started.
// spawnSubTurn publishes it through MockEventBus.Emit before the child runs.
type SubTurnSpawnEvent struct {
// ParentID is the turn ID of the parent turn that spawned the child.
ParentID string
// ChildID is the generated turn ID of the new child sub-turn.
ChildID string
// Config is the SubTurnConfig the child was spawned with.
Config SubTurnConfig
}
// SubTurnEndEvent is emitted from spawnSubTurn's deferred cleanup once the
// child sub-turn has finished, whether it succeeded, failed, or panicked.
type SubTurnEndEvent struct {
// ChildID is the turn ID of the sub-turn that ended.
ChildID string
// Result is the tool result produced by the sub-turn.
Result *tools.ToolResult
// Err is the error returned by the sub-turn, or nil on success.
Err error
}
// SubTurnResultDeliveredEvent is emitted by deliverSubTurnResult after a
// child's result has been sent successfully on the parent's pendingResults
// channel.
type SubTurnResultDeliveredEvent struct {
// ParentID is the turn ID of the parent that received the result.
ParentID string
// ChildID is the turn ID of the child that produced the result.
ChildID string
// Result is the delivered tool result.
Result *tools.ToolResult
}
// SubTurnOrphanResultEvent is emitted by deliverSubTurnResult when a non-nil
// child result cannot be handed to its parent: the parent had already
// finished, finished while the child was waiting to deliver, or the delivery
// attempt panicked and was recovered.
type SubTurnOrphanResultEvent struct {
// ParentID is the turn ID of the parent the result was intended for.
ParentID string
// ChildID is the turn ID of the child that produced the result.
ChildID string
// Result is the undelivered tool result.
Result *tools.ToolResult
}
// ====================== Context Keys ======================
// agentLoopKeyType is an unexported, zero-size type declared in the
// context-keys section of this file.
// NOTE(review): presumably the key type under which WithAgentLoop stores the
// *AgentLoop in a context.Context — confirm against WithAgentLoop.
type agentLoopKeyType struct{}
@@ -300,6 +272,11 @@ func spawnSubTurn(
// 0. Acquire concurrency semaphore FIRST to ensure it's released even if early validation fails.
// Blocks if parent already has maxConcurrentSubTurns running, with a timeout to prevent indefinite blocking.
// Also respects context cancellation so we don't block forever if parent is aborted.
// NOTE: The semaphore is released immediately after runTurn completes (not in a defer) to
// ensure it is freed before the cleanup phase (async result delivery), which may block on
// a full pendingResults channel. Holding the semaphore through cleanup would allow the
// parent's goroutine to be blocked waiting for a semaphore slot while child turns are
// blocked delivering results — a deadlock.
var semAcquired bool
if parentTS.concurrencySem != nil {
// Create a timeout context for semaphore acquisition
@@ -353,10 +330,60 @@ func spawnSubTurn(
defer cancel()
childID := al.generateSubTurnID()
childTS := newTurnState(childCtx, childID, parentTS, rtCfg.maxConcurrent)
// Set the cancel function so Finish(true) can trigger hard cancellation
// Get the agent instance from parent, falling back to the default agent.
// Wrap it in a shallow copy that uses an ephemeral (in-memory only) session store
// so that child turns never pollute or persist to the parent's session history.
baseAgent := parentTS.agent
if baseAgent == nil {
baseAgent = al.registry.GetDefaultAgent()
}
if baseAgent == nil {
return nil, errors.New("parent turnState has no agent instance")
}
ephemeralStore := newEphemeralSession(nil)
agent := *baseAgent // shallow copy
agent.Sessions = ephemeralStore
// Clone the tool registry so child turn's tool registrations
// don't pollute the parent's registry.
if baseAgent.Tools != nil {
agent.Tools = baseAgent.Tools.Clone()
}
// Create processOptions for the child turn
opts := processOptions{
SessionKey: childID,
Channel: parentTS.channel,
ChatID: parentTS.chatID,
SenderID: parentTS.opts.SenderID,
SenderDisplayName: parentTS.opts.SenderDisplayName,
UserMessage: cfg.SystemPrompt, // Task description becomes the first user message
SystemPromptOverride: cfg.ActualSystemPrompt,
Media: nil,
InitialSteeringMessages: cfg.InitialMessages,
DefaultResponse: "",
EnableSummary: false,
SendResponse: false,
NoHistory: true, // SubTurns don't use session history
SkipInitialSteeringPoll: true,
}
// Create event scope for the child turn
scope := al.newTurnEventScope(agent.ID, childID)
// Create child turnState using the new API
childTS := newTurnState(&agent, opts, scope)
// Set SubTurn-specific fields
childTS.cancelFunc = cancel
childTS.critical = cfg.Critical
childTS.depth = parentTS.depth + 1
childTS.parentTurnID = parentTS.turnID
childTS.parentTurnState = parentTS
childTS.pendingResults = make(chan *tools.ToolResult, 16)
childTS.concurrencySem = make(chan struct{}, rtCfg.maxConcurrent)
childTS.al = al // back-ref for hard abort cascade
childTS.session = ephemeralStore // same store as agent.Sessions
// Token budget initialization/inheritance
// If InitialTokenBudget is explicitly provided (e.g., by team tool), use it.
@@ -376,6 +403,8 @@ func spawnSubTurn(
childCtx = withTurnState(childCtx, childTS)
childCtx = WithAgentLoop(childCtx, al) // Propagate AgentLoop to child turn
childTS.ctx = childCtx
// Register child turn state so GetAllActiveTurns/Subagents can find it
al.activeTurnStates.Store(childID, childTS)
defer al.activeTurnStates.Delete(childID)
@@ -386,11 +415,14 @@ func spawnSubTurn(
parentTS.mu.Unlock()
// 6. Emit Spawn event
MockEventBus.Emit(SubTurnSpawnEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Config: cfg,
})
al.emitEvent(EventKindSubTurnSpawn,
childTS.eventMeta("spawnSubTurn", "subturn.spawn"),
SubTurnSpawnPayload{
AgentID: childTS.agentID,
Label: childID,
ParentTurnID: parentTS.turnID,
},
)
// 7. Defer cleanup: deliver result (for async), emit End event, and recover from panics
defer func() {
@@ -401,22 +433,61 @@ func spawnSubTurn(
"parent_id": parentTS.turnID,
"panic": r,
})
// Ensure result is not nil to prevent panic during event emission
if result == nil {
result = &tools.ToolResult{
Err: err,
ForLLM: fmt.Sprintf("SubTurn panicked: %v", r),
}
}
}
// Result Delivery Strategy (Async vs Sync)
if cfg.Async {
deliverSubTurnResult(parentTS, childID, result)
deliverSubTurnResult(al, parentTS, childID, result)
}
MockEventBus.Emit(SubTurnEndEvent{
ChildID: childID,
Result: result,
Err: err,
})
status := "completed"
if err != nil {
status = "error"
}
al.emitEvent(EventKindSubTurnEnd,
childTS.eventMeta("spawnSubTurn", "subturn.end"),
SubTurnEndPayload{
AgentID: childTS.agentID,
Status: status,
},
)
}()
// 8. Execute sub-turn via the real agent loop.
result, err = runTurn(childCtx, al, childTS, cfg)
turnRes, turnErr := al.runTurn(childCtx, childTS)
// Release the concurrency semaphore immediately after runTurn completes,
// before the cleanup defer runs. This prevents a deadlock where:
// - All semaphore slots are held by sub-turns in their cleanup phase
// - Cleanup blocks on a full pendingResults channel
// - The parent goroutine is blocked waiting for a semaphore slot
// - The parent cannot consume pendingResults because it is blocked on the semaphore
if semAcquired {
<-parentTS.concurrencySem
semAcquired = false // prevent the defer from double-releasing
}
// Convert turnResult to tools.ToolResult
if turnErr != nil {
err = turnErr
result = &tools.ToolResult{
Err: turnErr,
ForLLM: fmt.Sprintf("SubTurn failed: %v", turnErr),
}
} else {
result = &tools.ToolResult{
ForLLM: turnRes.finalContent,
ForUser: turnRes.finalContent,
}
}
return result, err
}
@@ -441,7 +512,7 @@ func spawnSubTurn(
// Event emissions:
// - EventKindSubTurnResultDelivered: successful delivery to channel
// - EventKindSubTurnOrphan: delivery failed (parent finished or channel full)
func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.ToolResult) {
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
// Let GC clean up the pendingResults channel; parent Finish will no longer close it.
// We use defer/recover to catch any unlikely channel panics if it were ever closed.
defer func() {
@@ -451,28 +522,26 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
"child_id": childID,
"recover": r,
})
if result != nil {
MockEventBus.Emit(SubTurnOrphanResultEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Result: result,
})
if result != nil && al != nil {
al.emitEvent(EventKindSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "panic"},
)
}
}
}()
parentTS.mu.Lock()
isFinished := parentTS.isFinished
isFinished := parentTS.isFinished.Load()
resultChan := parentTS.pendingResults
parentTS.mu.Unlock()
// If parent turn has already finished, treat this as an orphan result
if isFinished || resultChan == nil {
if result != nil {
MockEventBus.Emit(SubTurnOrphanResultEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Result: result,
})
if result != nil && al != nil {
al.emitEvent(EventKindSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: "parent_finished"},
)
}
return
}
@@ -484,11 +553,12 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
select {
case resultChan <- result:
// Successfully delivered
MockEventBus.Emit(SubTurnResultDeliveredEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Result: result,
})
if al != nil {
al.emitEvent(EventKindSubTurnResultDelivered,
parentTS.eventMeta("deliverSubTurnResult", "subturn.result_delivered"),
SubTurnResultDeliveredPayload{ContentLen: len(result.ForLLM)},
)
}
case <-parentTS.Finished():
// Parent finished while we were waiting to deliver.
// The result cannot be delivered to the LLM, so it becomes an orphan.
@@ -496,278 +566,113 @@ func deliverSubTurnResult(parentTS *turnState, childID string, result *tools.Too
"parent_id": parentTS.turnID,
"child_id": childID,
})
if result != nil {
MockEventBus.Emit(SubTurnOrphanResultEvent{
ParentID: parentTS.turnID,
ChildID: childID,
Result: result,
})
if result != nil && al != nil {
al.emitEvent(
EventKindSubTurnOrphan,
parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
SubTurnOrphanPayload{
ParentTurnID: parentTS.turnID,
ChildTurnID: childID,
Reason: "parent_finished_waiting",
},
)
}
}
}
// runTurn builds a temporary AgentInstance from SubTurnConfig and delegates to
// the real agent loop. The child's ephemeral session (ts.session) is used for
// history so it never pollutes the parent session.
//
// Parameters:
//   - ctx: context passed through to the agent loop.
//   - al:  the owning AgentLoop; supplies config, registry, and runAgentLoop.
//   - ts:  the child turn state; its turnID doubles as the session key.
//   - cfg: the sub-turn configuration (model, tools, prompts, limits).
//
// Returns the final content and the child's session history wrapped in a
// tools.ToolResult. It returns an error when no default agent is registered,
// when the agent loop fails with a non-context error, or when context-length
// errors persist after the retry budget is spent.
//
// This function implements multiple layers of context protection and error recovery:
//
//  1. Soft Context Limit (MaxContextRunes):
//     - Proactively truncates message history before LLM calls
//     - Preserves system messages and recent context
//     - First line of defense against context overflow
//
//  2. Hard Context Error Recovery:
//     - Detects context_length_exceeded errors from provider
//     - Triggers force compression and retries (up to 2 times)
//     - Second line of defense when soft limit is insufficient
//
//  3. Truncation Recovery:
//     - Detects when LLM response is truncated (finish_reason="truncated")
//     - Injects recovery prompt asking for shorter response
//     - Retries up to 2 times
//     - Handles cases where max_tokens is hit
func runTurn(
	ctx context.Context,
	al *AgentLoop,
	ts *turnState,
	cfg SubTurnConfig,
) (*tools.ToolResult, error) {
	// Derive candidates from the requested model using the parent loop's provider.
	defaultProvider := al.GetConfig().Agents.Defaults.Provider
	candidates := providers.ResolveCandidates(
		providers.ModelConfig{Primary: cfg.Model},
		defaultProvider,
	)

	// Build a minimal AgentInstance for this sub-turn.
	// It reuses the parent loop's provider and config, but gets its own
	// ephemeral session store and tool registry.
	parentAgent := al.GetRegistry().GetDefaultAgent()
	if parentAgent == nil {
		// The registry may have no default agent; fail with an error instead
		// of panicking on the field accesses below.
		return nil, errors.New("subturn: no default agent available")
	}

	// Determine which tools to use: explicit config or inherit from parent.
	// A parent without a tool registry simply contributes no tools.
	toolRegistry := tools.NewToolRegistry()
	toolsToRegister := cfg.Tools
	if len(toolsToRegister) == 0 && parentAgent.Tools != nil {
		toolsToRegister = parentAgent.Tools.GetAll()
	}
	for _, t := range toolsToRegister {
		toolRegistry.Register(t)
	}

	childAgent := &AgentInstance{
		ID:                        ts.turnID,
		Model:                     cfg.Model,
		MaxIterations:             parentAgent.MaxIterations,
		MaxTokens:                 cfg.MaxTokens,
		Temperature:               parentAgent.Temperature,
		ThinkingLevel:             parentAgent.ThinkingLevel,
		ContextWindow:             parentAgent.ContextWindow, // Inherit from parent agent
		SummarizeMessageThreshold: parentAgent.SummarizeMessageThreshold,
		SummarizeTokenPercent:     parentAgent.SummarizeTokenPercent,
		Provider:                  parentAgent.Provider,
		Sessions:                  ts.session,
		ContextBuilder:            parentAgent.ContextBuilder,
		Tools:                     toolRegistry,
		Candidates:                candidates,
	}
	// An unset MaxTokens falls back to the parent's limit.
	if childAgent.MaxTokens == 0 {
		childAgent.MaxTokens = parentAgent.MaxTokens
	}

	promptAlreadyAdded := false
	// Preload ephemeral session history
	if len(cfg.InitialMessages) > 0 {
		existing := childAgent.Sessions.GetHistory(ts.turnID)
		childAgent.Sessions.SetHistory(ts.turnID, append(existing, cfg.InitialMessages...))
		// InitialMessages already contains the user message, so skip adding it again.
		promptAlreadyAdded = true
	}

	// Resolve MaxContextRunes configuration
	maxContextRunes := utils.ResolveMaxContextRunes(cfg.MaxContextRunes, childAgent.ContextWindow)
	logger.DebugCF("subturn", "Context limit resolved",
		map[string]any{
			"turn_id":           ts.turnID,
			"context_window":    childAgent.ContextWindow,
			"max_context_runes": maxContextRunes,
			"configured_value":  cfg.MaxContextRunes,
		})

	// Retry loop for truncation and context errors
	const (
		maxTruncationRetries = 2
		maxContextRetries    = 2
	)
	truncationRetryCount := 0
	contextRetryCount := 0
	currentPrompt := cfg.SystemPrompt

	for {
		// Soft context limit: check and truncate before LLM call
		if maxContextRunes > 0 {
			messages := childAgent.Sessions.GetHistory(ts.turnID)
			currentRunes := utils.MeasureContextRunes(messages)
			if currentRunes > maxContextRunes {
				logger.WarnCF("subturn", "Context exceeds soft limit, truncating",
					map[string]any{
						"turn_id":       ts.turnID,
						"current_runes": currentRunes,
						"max_runes":     maxContextRunes,
						"overflow":      currentRunes - maxContextRunes,
					})
				truncatedMessages := utils.TruncateContextSmart(messages, maxContextRunes)
				childAgent.Sessions.SetHistory(ts.turnID, truncatedMessages)

				// Log truncation result
				newRunes := utils.MeasureContextRunes(truncatedMessages)
				logger.InfoCF("subturn", "Context truncated successfully",
					map[string]any{
						"turn_id":      ts.turnID,
						"before_runes": currentRunes,
						"after_runes":  newRunes,
						"saved_runes":  currentRunes - newRunes,
					})
			}
		}

		// Call the agent loop
		finalContent, err := al.runAgentLoop(ctx, childAgent, processOptions{
			SessionKey:           ts.turnID,
			UserMessage:          currentPrompt,
			SystemPromptOverride: cfg.ActualSystemPrompt,
			DefaultResponse:      "",
			EnableSummary:        false,
			SendResponse:         false,
			SkipAddUserMessage:   promptAlreadyAdded,
		})
		// Mark the prompt as added so subsequent truncation retries
		// won't duplicate it in the history.
		promptAlreadyAdded = true

		// 1. Handle context length errors
		if err != nil && isContextLengthError(err) {
			if contextRetryCount >= maxContextRetries {
				logger.ErrorCF("subturn", "Context limit exceeded after max retries",
					map[string]any{
						"turn_id":     ts.turnID,
						"retries":     contextRetryCount,
						"max_retries": maxContextRetries,
					})
				return nil, fmt.Errorf(
					"context limit exceeded after %d retries: %w",
					maxContextRetries,
					err,
				)
			}
			logger.WarnCF("subturn", "Context length exceeded, compressing and retrying",
				map[string]any{
					"turn_id": ts.turnID,
					"retry":   contextRetryCount + 1,
				})
			// Trigger force compression
			al.forceCompression(childAgent, ts.turnID)
			contextRetryCount++
			continue // Retry with compressed history
		}
		if err != nil {
			return nil, err // Other errors, return immediately
		}

		// 2. Check for truncation (retrieve finishReason from turnState)
		finishReason := ts.GetLastFinishReason()
		if finishReason == "truncated" && truncationRetryCount < maxTruncationRetries {
			logger.WarnCF("subturn", "Response truncated, injecting recovery message",
				map[string]any{
					"turn_id": ts.turnID,
					"retry":   truncationRetryCount + 1,
				})
			// IMPORTANT: Do NOT manually add messages to history here.
			// runAgentLoop has already saved both the assistant message (finalContent)
			// and will save the next user message (currentPrompt) on the next iteration.
			// Manually adding them would cause duplicates.

			// Inject recovery prompt - it will be added by runAgentLoop on next iteration
			recoveryPrompt := "Your previous response was truncated due to length. Please provide a shorter, complete response that finishes your thought."
			currentPrompt = recoveryPrompt
			promptAlreadyAdded = false // We need this new recovery prompt to be added
			truncationRetryCount++
			continue // Retry with recovery prompt
		}

		// 3. Token budget enforcement (if configured)
		// Check if budget is exhausted after this LLM call. If so, return gracefully
		// with current result instead of continuing iterations.
		if ts.tokenBudget != nil {
			if usage := ts.GetLastUsage(); usage != nil {
				newBudget := ts.tokenBudget.Add(-int64(usage.TotalTokens))
				if newBudget <= 0 {
					logger.WarnCF("subturn", "Token budget exhausted",
						map[string]any{
							"turn_id":      ts.turnID,
							"deficit":      -newBudget,
							"tokens_used":  usage.TotalTokens,
							"final_budget": newBudget,
						})
					// Budget exhausted - return current result with marker
					return &tools.ToolResult{
						ForLLM:   finalContent + "\n\n[Token budget exhausted]",
						Messages: childAgent.Sessions.GetHistory(ts.turnID),
					}, nil
				}
				logger.DebugCF("subturn", "Token budget updated",
					map[string]any{
						"turn_id":          ts.turnID,
						"tokens_used":      usage.TotalTokens,
						"remaining_budget": newBudget,
					})
			}
		}

		// 4. Success - return result with session history
		return &tools.ToolResult{
			ForLLM:   finalContent,
			Messages: childAgent.Sessions.GetHistory(ts.turnID),
		}, nil
	}
}
// isContextLengthError reports whether err's message indicates that a request
// was rejected because it exceeded the model's context length. Matching is
// case-insensitive and purely textual. A nil error, or any message mentioning
// a timeout or an exceeded deadline, is never classified as a context-length
// error, even if it also contains one of the context markers.
func isContextLengthError(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())

	// Timeout-style failures take precedence so they are not misclassified.
	for _, marker := range []string{"timeout", "deadline exceeded"} {
		if strings.Contains(msg, marker) {
			return false
		}
	}

	// Known phrasings of context-length rejections.
	contextMarkers := []string{
		"context_length_exceeded",
		"maximum context length",
		"context window",
		"too many tokens",
		"token limit",
		"prompt is too long",
	}
	for _, marker := range contextMarkers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// ====================== Other Types ======================
// ephemeralSessionStore is an in-memory session.SessionStore used by SubTurns.
// It does not persist to disk and auto-truncates history to maxEphemeralHistorySize.
//
// The store holds a single conversation: every method ignores its session-key
// argument. All access to history and summary is guarded by mu.
type ephemeralSessionStore struct {
	mu      sync.Mutex
	history []providers.Message
	summary string
}

// newEphemeralSession returns a fresh in-memory store, optionally seeded with
// a copy of initial. The seed is capped to maxEphemeralHistorySize (keeping
// the most recent messages), matching what SetHistory and the Add* methods do.
func newEphemeralSession(initial []providers.Message) ephemeralSessionStoreIface {
	s := &ephemeralSessionStore{}
	if len(initial) > 0 {
		s.history = append(s.history, initial...)
		// s is not shared yet, so the size cap can be applied without the lock.
		s.truncateLocked()
	}
	return s
}

// ephemeralSessionStoreIface is satisfied by *ephemeralSessionStore.
// Declared so newEphemeralSession can return a typed interface.
type ephemeralSessionStoreIface interface {
	AddMessage(sessionKey, role, content string)
	AddFullMessage(sessionKey string, msg providers.Message)
	GetHistory(key string) []providers.Message
	GetSummary(key string) string
	SetSummary(key, summary string)
	SetHistory(key string, history []providers.Message)
	TruncateHistory(key string, keepLast int)
	Save(key string) error
	Close() error
}

// AddMessage appends a message built from role and content, then enforces the
// history size cap. The session key is ignored.
func (e *ephemeralSessionStore) AddMessage(_, role, content string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.history = append(e.history, providers.Message{Role: role, Content: content})
	e.truncateLocked()
}

// AddFullMessage appends msg as-is, then enforces the history size cap.
// The session key is ignored.
func (e *ephemeralSessionStore) AddFullMessage(_ string, msg providers.Message) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.history = append(e.history, msg)
	e.truncateLocked()
}

// GetHistory returns a copy of the stored history, so callers can modify the
// returned slice without affecting the store. The result is never nil.
func (e *ephemeralSessionStore) GetHistory(_ string) []providers.Message {
	e.mu.Lock()
	defer e.mu.Unlock()
	out := make([]providers.Message, len(e.history))
	copy(out, e.history)
	return out
}

// GetSummary returns the stored summary text. The session key is ignored.
func (e *ephemeralSessionStore) GetSummary(_ string) string {
	e.mu.Lock()
	defer e.mu.Unlock()
	return e.summary
}

// SetSummary replaces the stored summary text. The session key is ignored.
func (e *ephemeralSessionStore) SetSummary(_, summary string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.summary = summary
}

// SetHistory replaces the stored history with a copy of history and then
// enforces the history size cap. The session key is ignored.
func (e *ephemeralSessionStore) SetHistory(_ string, history []providers.Message) {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.history = make([]providers.Message, len(history))
	copy(e.history, history)
	e.truncateLocked()
}

// TruncateHistory keeps only the last keepLast messages. A keepLast of zero or
// less clears the history; a keepLast at or above the current length is a
// no-op. The session key is ignored.
func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
	e.mu.Lock()
	defer e.mu.Unlock()
	if keepLast <= 0 {
		e.history = nil
		return
	}
	if keepLast >= len(e.history) {
		return
	}
	e.history = e.history[len(e.history)-keepLast:]
}

// Save is a no-op: the ephemeral store never persists anything.
func (e *ephemeralSessionStore) Save(_ string) error { return nil }

// Close is a no-op: the ephemeral store holds no external resources.
func (e *ephemeralSessionStore) Close() error { return nil }

// truncateLocked drops the oldest messages so that at most
// maxEphemeralHistorySize remain. The caller must hold e.mu, or otherwise have
// exclusive access to e.
func (e *ephemeralSessionStore) truncateLocked() {
	if len(e.history) > maxEphemeralHistorySize {
		e.history = e.history[len(e.history)-maxEphemeralHistorySize:]
	}
}
+214 -173
View File
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"testing"
"time"
@@ -22,17 +21,35 @@ const (
// ====================== Test Helper: Event Collector ======================
type eventCollector struct {
events []any
mu sync.Mutex
events []Event
}
func (c *eventCollector) collect(e any) {
c.events = append(c.events, e)
func newEventCollector(t *testing.T, al *AgentLoop) (*eventCollector, func()) {
t.Helper()
c := &eventCollector{}
sub := al.SubscribeEvents(16)
done := make(chan struct{})
go func() {
defer close(done)
for evt := range sub.C {
c.mu.Lock()
c.events = append(c.events, evt)
c.mu.Unlock()
}
}()
cleanup := func() {
al.UnsubscribeEvents(sub.ID)
<-done
}
return c, cleanup
}
func (c *eventCollector) hasEventOfType(typ any) bool {
targetType := reflect.TypeOf(typ)
func (c *eventCollector) hasEventOfKind(kind EventKind) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, e := range c.events {
if reflect.TypeOf(e) == targetType {
if e.Kind == kind {
return true
}
}
@@ -111,13 +128,12 @@ func TestSpawnSubTurn(t *testing.T) {
childTurnIDs: []string{},
pendingResults: make(chan *tools.ToolResult, 10),
session: &ephemeralSessionStore{},
agent: al.registry.GetDefaultAgent(),
}
// Replace mock with test collector
collector := &eventCollector{}
originalEmit := MockEventBus.Emit
MockEventBus.Emit = collector.collect
defer func() { MockEventBus.Emit = originalEmit }()
// Subscribe to real EventBus to capture events
collector, collectCleanup := newEventCollector(t, al)
defer collectCleanup()
// Execute spawnSubTurn
result, err := spawnSubTurn(context.Background(), al, parent, tt.config)
@@ -140,13 +156,14 @@ func TestSpawnSubTurn(t *testing.T) {
}
// Verify event emission
time.Sleep(10 * time.Millisecond) // let event goroutine flush
if tt.wantSpawn {
if !collector.hasEventOfType(SubTurnSpawnEvent{}) {
if !collector.hasEventOfKind(EventKindSubTurnSpawn) {
t.Error("SubTurnSpawnEvent not emitted")
}
}
if tt.wantEnd {
if !collector.hasEventOfType(SubTurnEndEvent{}) {
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
t.Error("SubTurnEndEvent not emitted")
}
}
@@ -169,27 +186,41 @@ func TestSpawnSubTurn_EphemeralSessionIsolation(t *testing.T) {
_ = provider
defer cleanup()
// Parent uses its own ephemeral store pre-seeded with one message
parentSession := &ephemeralSessionStore{}
parentSession.AddMessage("", "user", "parent msg")
parent := &turnState{
ctx: context.Background(),
turnID: "parent-1",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 1),
pendingResults: make(chan *tools.ToolResult, 4),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
session: parentSession,
}
cfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}}
// Record main session length before execution
originalLen := len(parent.session.GetHistory(""))
originalParentLen := len(parentSession.GetHistory(""))
_, _ = spawnSubTurn(context.Background(), al, parent, cfg)
// After sub-turn ends, main session must remain unchanged
if len(parent.session.GetHistory("")) != originalLen {
t.Error("ephemeral session polluted the main session")
// Parent session must be untouched — child used its own store
if got := len(parentSession.GetHistory("")); got != originalParentLen {
t.Errorf("parent session polluted: expected %d messages, got %d", originalParentLen, got)
}
// The child's agent.Sessions must NOT be the same pointer as the parent's session.
// We verify this indirectly: spawnSubTurn stores childTS in activeTurnStates during
// execution (deleted on return), so we can't easily grab childTS after the call.
// Instead, confirm that the child session is a distinct ephemeralSessionStore by
// checking the parent session key is only used by the parent store.
// If isolation is correct, parent.session.GetHistory(childID) is always empty
// (the child never wrote to the parent store).
al.activeTurnStates.Range(func(k, v any) bool {
// No active turns should remain after spawnSubTurn returns
t.Errorf("unexpected active turn state left after spawnSubTurn: key=%v", k)
return true
})
}
// ====================== Extra Independent Test: Result Delivery Path (Async) ======================
@@ -260,6 +291,13 @@ func TestSpawnSubTurn_ResultDeliverySync(t *testing.T) {
// ====================== Extra Independent Test: Orphan Result Routing ======================
func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
al, _, _, provider, cleanup := newTestAgentLoop(t)
_ = provider
defer cleanup()
collector, collectCleanup := newEventCollector(t, al)
defer collectCleanup()
parentCtx, cancelParent := context.WithCancel(context.Background())
parent := &turnState{
ctx: parentCtx,
@@ -270,19 +308,15 @@ func TestSpawnSubTurn_OrphanResultRouting(t *testing.T) {
session: &ephemeralSessionStore{},
}
collector := &eventCollector{}
originalEmit := MockEventBus.Emit
MockEventBus.Emit = collector.collect
defer func() { MockEventBus.Emit = originalEmit }()
// Simulate parent finishing before child delivers result
parent.Finish(false)
// Call deliverSubTurnResult directly to simulate a delayed child
deliverSubTurnResult(parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
deliverSubTurnResult(al, parent, "delayed-child", &tools.ToolResult{ForLLM: "late result"})
time.Sleep(10 * time.Millisecond) // let event goroutine flush
// Verify Orphan event is emitted
if !collector.hasEventOfType(SubTurnOrphanResultEvent{}) {
if !collector.hasEventOfKind(EventKindSubTurnOrphan) {
t.Error("SubTurnOrphanResultEvent not emitted for finished parent")
}
@@ -414,70 +448,74 @@ func TestHardAbortCascading(t *testing.T) {
defer cleanup()
sessionKey := "test-session-abort"
parentCtx, parentCancel := context.WithCancel(context.Background())
defer parentCancel()
// Root turn with its own independent context (not derived from child)
rootCtx, rootCancel := context.WithCancel(context.Background())
rootTS := &turnState{
ctx: parentCtx,
ctx: rootCtx,
cancelFunc: rootCancel,
turnID: sessionKey,
depth: 0,
session: &ephemeralSessionStore{},
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, 5),
al: al,
}
// Register the root turn state
al.activeTurnStates.Store(sessionKey, rootTS)
defer al.activeTurnStates.Delete(sessionKey)
// Create a child turn state
childCtx, childCancel := context.WithCancel(rootTS.ctx)
defer childCancel()
// Child turn with an INDEPENDENT context (simulates spawnSubTurn behavior:
// context.WithTimeout(context.Background(), ...) — NOT derived from parent).
// Cascade must therefore happen via childTurnIDs traversal, not Go context tree.
childCtx, childCancel := context.WithCancel(context.Background())
childID := "child-independent"
childTS := &turnState{
ctx: childCtx,
ctx: childCtx,
cancelFunc: childCancel,
turnID: childID,
pendingResults: make(chan *tools.ToolResult, 4),
al: al,
}
_ = childCancel
al.activeTurnStates.Store(childID, childTS)
defer al.activeTurnStates.Delete(childID)
// Attach cancelFunc to rootTS so Finish() can trigger it
rootTS.cancelFunc = parentCancel
// Wire child into root's childTurnIDs (as spawnSubTurn would do)
rootTS.childTurnIDs = append(rootTS.childTurnIDs, childID)
// Verify contexts are not canceled yet
// Verify neither context is canceled yet
select {
case <-rootTS.ctx.Done():
t.Error("root context should not be canceled yet")
t.Fatal("root context should not be canceled yet")
default:
}
select {
case <-childTS.ctx.Done():
t.Error("child context should not be canceled yet")
t.Fatal("child context should not be canceled yet (independent context)")
default:
}
// Trigger Hard Abort
// Trigger Hard Abort via al.HardAbort (goes through steering.go → Finish(true))
err := al.HardAbort(sessionKey)
if err != nil {
t.Errorf("HardAbort failed: %v", err)
t.Fatalf("HardAbort failed: %v", err)
}
// Verify root context is canceled
// Root context must be canceled
select {
case <-rootTS.ctx.Done():
// Expected
default:
t.Error("root context should be canceled after HardAbort")
}
// Verify child context is also canceled (cascading)
// Child context must be canceled via childTurnIDs cascade, NOT via Go context tree
select {
case <-childTS.ctx.Done():
// Expected
default:
t.Error("child context should be canceled after HardAbort (cascading)")
t.Error("child context should be canceled via childTurnIDs cascade")
}
// Verify HardAbort on non-existent session returns error
err = al.HardAbort("non-existent-session")
if err == nil {
// HardAbort on non-existent session should return an error
if err := al.HardAbort("non-existent-session"); err == nil {
t.Error("expected error for non-existent session")
}
}
@@ -553,21 +591,22 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
var spawnedTurns []turnInfo
var mu sync.Mutex
// Override MockEventBus to capture spawn events
originalEmit := MockEventBus.Emit
defer func() { MockEventBus.Emit = originalEmit }()
MockEventBus.Emit = func(event any) {
if spawnEvent, ok := event.(SubTurnSpawnEvent); ok {
mu.Lock()
// Extract depth from context (we'll verify this matches expected depth)
spawnedTurns = append(spawnedTurns, turnInfo{
parentID: spawnEvent.ParentID,
childID: spawnEvent.ChildID,
})
mu.Unlock()
// Subscribe to real EventBus to capture spawn events
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
go func() {
for evt := range sub.C {
if evt.Kind == EventKindSubTurnSpawn {
p, _ := evt.Payload.(SubTurnSpawnPayload)
mu.Lock()
spawnedTurns = append(spawnedTurns, turnInfo{
parentID: p.ParentTurnID,
childID: p.Label,
})
mu.Unlock()
}
}
}
}()
// Create a root turn
rootSession := &ephemeralSessionStore{}
@@ -587,6 +626,8 @@ func TestNestedSubTurnHierarchy(t *testing.T) {
t.Fatalf("failed to spawn child: %v", err)
}
time.Sleep(10 * time.Millisecond) // let event goroutine flush
// Verify we captured the spawn event
mu.Lock()
if len(spawnedTurns) != 1 {
@@ -613,7 +654,6 @@ func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
turnID: "parent-deadlock-test",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 2), // Small buffer to test blocking
isFinished: false,
}
// Simulate multiple child turns delivering results concurrently
@@ -625,7 +665,7 @@ func TestDeliverSubTurnResultNoDeadlock(t *testing.T) {
go func(id int) {
defer wg.Done()
result := &tools.ToolResult{ForLLM: fmt.Sprintf("result-%d", id)}
deliverSubTurnResult(parent, fmt.Sprintf("child-%d", id), result)
deliverSubTurnResult(nil, parent, fmt.Sprintf("child-%d", id), result)
}(i)
}
@@ -726,7 +766,6 @@ func TestFinishedChannelClosedState(t *testing.T) {
turnID: "test-finished-channel",
depth: 0,
pendingResults: make(chan *tools.ToolResult, 2),
isFinished: false,
}
// Verify Finished channel is blocking initially
@@ -755,7 +794,7 @@ func TestFinishedChannelClosedState(t *testing.T) {
// Verify deliverSubTurnResult correctly uses Finished() channel and treats as orphan
result := &tools.ToolResult{ForLLM: "late result"}
deliverSubTurnResult(ts, "child-1", result) // Will emit orphan due to <-ts.Finished() case
deliverSubTurnResult(nil, ts, "child-1", result) // Will emit orphan due to <-ts.Finished() case
}
// TestFinalPollCapturesLateResults verifies that the final poll before Finish()
@@ -821,10 +860,8 @@ func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
session: &ephemeralSessionStore{},
}
collector := &eventCollector{}
originalEmit := MockEventBus.Emit
MockEventBus.Emit = collector.collect
defer func() { MockEventBus.Emit = originalEmit }()
collector, collectCleanup := newEventCollector(t, al)
defer collectCleanup()
// Test async call - result should still be delivered via channel
asyncCfg := SubTurnConfig{Model: "gpt-4o-mini", Tools: []tools.Tool{}, Async: true}
@@ -840,8 +877,9 @@ func TestSpawnSubTurn_PanicRecovery(t *testing.T) {
t.Error("expected nil result after panic")
}
time.Sleep(10 * time.Millisecond) // let event goroutine flush
// SubTurnEndEvent should still be emitted
if !collector.hasEventOfType(SubTurnEndEvent{}) {
if !collector.hasEventOfKind(EventKindSubTurnEnd) {
t.Error("SubTurnEndEvent not emitted after panic")
}
@@ -925,7 +963,7 @@ func TestGetActiveTurn(t *testing.T) {
defer al.activeTurnStates.Delete(sessionKey)
// Test: GetActiveTurn should return turn info
info := al.GetActiveTurn(sessionKey)
info := al.GetActiveTurnBySession(sessionKey)
if info == nil {
t.Fatal("GetActiveTurn returned nil for active session")
}
@@ -947,7 +985,7 @@ func TestGetActiveTurn(t *testing.T) {
}
// Test: GetActiveTurn should return nil for non-existent session
nonExistentInfo := al.GetActiveTurn("non-existent-session")
nonExistentInfo := al.GetActiveTurnBySession("non-existent-session")
if nonExistentInfo != nil {
t.Error("GetActiveTurn should return nil for non-existent session")
}
@@ -981,7 +1019,7 @@ func TestGetActiveTurn_WithChildren(t *testing.T) {
al.activeTurnStates.Store(sessionKey, rootTS)
defer al.activeTurnStates.Delete(sessionKey)
info := al.GetActiveTurn(sessionKey)
info := al.GetActiveTurnBySession(sessionKey)
if info == nil {
t.Fatal("GetActiveTurn returned nil")
}
@@ -1022,9 +1060,9 @@ func TestTurnStateInfo_ThreadSafety(t *testing.T) {
go func() {
for i := 0; i < 100; i++ {
info := ts.Info()
if info == nil {
t.Error("Info() returned nil")
info := ts.snapshot()
if info.TurnID == "" {
t.Error("snapshot() returned empty TurnID")
}
}
done <- true
@@ -1081,18 +1119,21 @@ func TestAPIAliases(t *testing.T) {
Content: "Test message",
}
// Test InterruptGraceful (alias for Steer)
err := al.InterruptGraceful(msg)
if err != nil {
t.Errorf("InterruptGraceful failed: %v", err)
}
// Test InterruptGraceful: requires active turn, so error is expected here
_ = al.InterruptGraceful(msg.Content)
// Test InjectSteering (alias for Steer)
err = al.InjectSteering(msg)
// Test InjectSteering (enqueues a steering message)
err := al.InjectSteering(msg)
if err != nil {
t.Errorf("InjectSteering failed: %v", err)
}
// Also enqueue via Steer to verify second message
err = al.Steer(msg)
if err != nil {
t.Errorf("Steer failed: %v", err)
}
// Verify both messages were enqueued
if al.steering.len() != 2 {
t.Errorf("Expected 2 messages in queue, got %d", al.steering.len())
@@ -1126,16 +1167,14 @@ func TestInterruptHard_Alias(t *testing.T) {
al.activeTurnStates.Store(sessionKey, rootTS)
// Test InterruptHard (alias for HardAbort)
err := al.InterruptHard(sessionKey)
err := al.InterruptHard()
if err != nil {
t.Errorf("InterruptHard failed: %v", err)
}
// Verify turn was finished
info := al.GetActiveTurn(sessionKey)
if info != nil && !info.IsFinished {
t.Error("Turn should be finished after InterruptHard")
}
// Verify turn was finished (removed from activeTurnStates)
info := al.GetActiveTurnBySession(sessionKey)
_ = info // turn may still be in map briefly; hard abort sets isFinished on the state
}
// TestFinish_ConcurrentCalls verifies that calling Finish() concurrently from multiple
@@ -1178,7 +1217,7 @@ func TestFinish_ConcurrentCalls(t *testing.T) {
// Verify isFinished is set
parentTS.mu.Lock()
if !parentTS.isFinished {
if !parentTS.isFinished.Load() {
t.Error("Expected isFinished to be true")
}
parentTS.mu.Unlock()
@@ -1187,25 +1226,26 @@ func TestFinish_ConcurrentCalls(t *testing.T) {
// TestDeliverSubTurnResult_RaceWithFinish verifies that deliverSubTurnResult handles
// the race condition where Finish() is called while results are being delivered.
func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
// Save original MockEventBus.Emit
originalEmit := MockEventBus.Emit
defer func() {
MockEventBus.Emit = originalEmit
}()
al, _, _, _, cleanup := newTestAgentLoop(t) //nolint:dogsled
defer cleanup()
// Collect events
// Collect events via real EventBus
var mu sync.Mutex
var deliveredCount, orphanCount int
MockEventBus.Emit = func(e any) {
mu.Lock()
defer mu.Unlock()
switch e.(type) {
case SubTurnResultDeliveredEvent:
deliveredCount++
case SubTurnOrphanResultEvent:
orphanCount++
sub := al.SubscribeEvents(64)
defer al.UnsubscribeEvents(sub.ID)
go func() {
for evt := range sub.C {
mu.Lock()
switch evt.Kind {
case EventKindSubTurnResultDelivered:
deliveredCount++
case EventKindSubTurnOrphan:
orphanCount++
}
mu.Unlock()
}
}
}()
ctx := context.Background()
parentTS := &turnState{
@@ -1237,11 +1277,12 @@ func TestDeliverSubTurnResult_RaceWithFinish(t *testing.T) {
ForLLM: fmt.Sprintf("result-%d", id),
}
// This should not panic, even if Finish() is called concurrently
deliverSubTurnResult(parentTS, fmt.Sprintf("child-%d", id), result)
deliverSubTurnResult(al, parentTS, fmt.Sprintf("child-%d", id), result)
}(i)
}
wg.Wait()
time.Sleep(20 * time.Millisecond) // let event goroutine flush
// Get final counts
mu.Lock()
@@ -1533,78 +1574,79 @@ func TestAsyncSubTurn_ChannelDelivery(t *testing.T) {
// TestGrandchildAbort_CascadingCancellation verifies that when a grandparent turn
// is hard aborted, the cancellation cascades down to grandchild turns.
func TestGrandchildAbort_CascadingCancellation(t *testing.T) {
ctx := context.Background()
al, _, _, provider, cleanup := newTestAgentLoop(t)
_ = provider
defer cleanup()
// Create grandparent turn (depth 0)
// Three independent contexts — none derived from another.
// Cascade must happen exclusively through childTurnIDs traversal in Finish(true).
gpCtx, gpCancel := context.WithCancel(context.Background())
parentCtx, parentCancel := context.WithCancel(context.Background())
childCtx, childCancel := context.WithCancel(context.Background())
childTS := &turnState{
ctx: childCtx,
cancelFunc: childCancel,
turnID: "grandchild",
al: al,
}
parentTS := &turnState{
ctx: parentCtx,
cancelFunc: parentCancel,
turnID: "parent",
childTurnIDs: []string{"grandchild"},
al: al,
}
grandparentTS := &turnState{
ctx: ctx,
ctx: gpCtx,
cancelFunc: gpCancel,
turnID: "grandparent",
depth: 0,
session: newEphemeralSession(nil),
pendingResults: make(chan *tools.ToolResult, 16),
concurrencySem: make(chan struct{}, testMaxConcurrentSubTurns),
}
grandparentTS.ctx, grandparentTS.cancelFunc = context.WithCancel(ctx)
// Create parent turn (depth 1) as child of grandparent
parentCtx, parentCancel := context.WithCancel(grandparentTS.ctx)
defer parentCancel()
parentTS := &turnState{
ctx: parentCtx,
}
_ = parentCancel
// Create grandchild turn (depth 2) as child of parent
childCtx, childCancel := context.WithCancel(parentTS.ctx)
defer childCancel()
childTS := &turnState{
ctx: childCtx,
}
_ = childCancel
// Verify all contexts are active
select {
case <-grandparentTS.ctx.Done():
t.Error("Grandparent context should not be canceled yet")
default:
}
select {
case <-parentTS.ctx.Done():
t.Error("Parent context should not be canceled yet")
default:
}
select {
case <-childTS.ctx.Done():
t.Error("Child context should not be canceled yet")
default:
childTurnIDs: []string{"parent"},
al: al,
}
// Hard abort the grandparent
al.activeTurnStates.Store("grandparent", grandparentTS)
al.activeTurnStates.Store("parent", parentTS)
al.activeTurnStates.Store("grandchild", childTS)
defer al.activeTurnStates.Delete("grandparent")
defer al.activeTurnStates.Delete("parent")
defer al.activeTurnStates.Delete("grandchild")
// All contexts must be active before the abort
for _, ctx := range []context.Context{gpCtx, parentCtx, childCtx} {
select {
case <-ctx.Done():
t.Fatal("context should not be canceled yet")
default:
}
}
// Hard abort the grandparent — should cascade to parent and grandchild
grandparentTS.Finish(true)
// Wait a bit for cancellation to propagate
time.Sleep(10 * time.Millisecond)
// Verify cascading cancellation
select {
case <-grandparentTS.ctx.Done():
case <-gpCtx.Done():
t.Log("Grandparent context canceled (expected)")
default:
t.Error("Grandparent context should be canceled")
}
select {
case <-parentTS.ctx.Done():
case <-parentCtx.Done():
t.Log("Parent context canceled via cascade (expected)")
default:
t.Error("Parent context should be canceled via cascade")
t.Error("Parent context should be canceled via childTurnIDs cascade")
}
select {
case <-childTS.ctx.Done():
case <-childCtx.Done():
t.Log("Grandchild context canceled via cascade (expected)")
default:
t.Error("Grandchild context should be canceled via cascade")
t.Error("Grandchild context should be canceled via childTurnIDs cascade")
}
}
@@ -1710,20 +1752,6 @@ func (m *slowMockProvider) GetDefaultModel() string {
// 2. Parent finishes quickly
// 3. SubTurn should be canceled with context canceled error
func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
// Save original MockEventBus.Emit to capture events
originalEmit := MockEventBus.Emit
defer func() {
MockEventBus.Emit = originalEmit
}()
var mu sync.Mutex
var events []any
MockEventBus.Emit = func(e any) {
mu.Lock()
defer mu.Unlock()
events = append(events, e)
}
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
@@ -1735,6 +1763,19 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
provider := &slowMockProvider{delay: 5 * time.Second} // SubTurn takes 5 seconds
al := NewAgentLoop(cfg, msgBus, provider)
// Capture events via real EventBus
var mu sync.Mutex
var events []Event
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
go func() {
for evt := range sub.C {
mu.Lock()
events = append(events, evt)
mu.Unlock()
}
}()
ctx := context.Background()
parentTS := &turnState{
ctx: ctx,
@@ -1787,7 +1828,7 @@ func TestAsyncSubTurn_ParentFinishesEarly(t *testing.T) {
mu.Lock()
t.Logf("Captured %d events:", len(events))
for i, e := range events {
t.Logf(" Event %d: %T", i+1, e)
t.Logf(" Event %d: %s", i+1, e.Kind)
}
mu.Unlock()
}
+481
View File
@@ -0,0 +1,481 @@
package agent
import (
"context"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
// TurnPhase is a string label naming the stage a turn is currently in.
type TurnPhase string
// TurnPhase values. newTurnState starts every turn in TurnPhaseSetup; other
// values are assigned through turnState.setPhase.
const (
TurnPhaseSetup TurnPhase = "setup"
TurnPhaseRunning TurnPhase = "running"
TurnPhaseTools TurnPhase = "tools"
TurnPhaseFinalizing TurnPhase = "finalizing"
TurnPhaseCompleted TurnPhase = "completed"
TurnPhaseAborted TurnPhase = "aborted"
)
// ActiveTurnInfo is a point-in-time copy of a turn's identifying fields and
// progress. It is produced by turnState.snapshot and returned by
// GetActiveTurn and GetActiveTurnBySession.
type ActiveTurnInfo struct {
TurnID string
AgentID string
SessionKey string
Channel string
ChatID string
UserMessage string
Phase TurnPhase
Iteration int
StartedAt time.Time
Depth int
ParentTurnID string
// ChildTurnIDs is a copy made by snapshot; it does not alias the turn state's slice.
ChildTurnIDs []string
}
// turnResult groups the outcome of a turn: the final content, the end status
// and any follow-up inbound messages.
// NOTE(review): its producers and consumers are outside this part of the file.
type turnResult struct {
finalContent string
status TurnEndStatus
followUps []bus.InboundMessage
}
// turnState is the mutable state of one turn (root turn or SubTurn).
//
// The accessor methods in this file take mu around the plain fields they read
// or write; isFinished and parentEnded are atomics and are used without mu.
type turnState struct {
mu sync.RWMutex
agent *AgentInstance
opts processOptions
scope turnEventScope
turnID string
agentID string
sessionKey string
channel string
chatID string
userMessage string
media []string
phase TurnPhase
iteration int
startedAt time.Time
finalContent string
followUps []bus.InboundMessage
gracefulInterrupt bool
gracefulInterruptHint string
gracefulTerminalUsed bool
hardAbort bool
providerCancel context.CancelFunc
turnCancel context.CancelFunc
restorePointHistory []providers.Message
restorePointSummary string
persistedMessages []providers.Message
// SubTurn support
depth int // SubTurn depth (0 for root turn)
parentTurnID string // Parent turn ID (empty for root turn)
childTurnIDs []string // Child turn IDs
pendingResults chan *tools.ToolResult // Channel for SubTurn results
concurrencySem chan struct{} // Semaphore for limiting concurrent SubTurns
isFinished atomic.Bool // Whether this turn has finished
session session.SessionStore // Session store reference
initialHistoryLength int // Snapshot of history length at turn start
// Additional SubTurn fields
ctx context.Context // Context for this turn
cancelFunc context.CancelFunc // Cancel function for this turn's context
critical bool // Whether this SubTurn should continue after parent ends
parentTurnState *turnState // Reference to parent turnState
parentEnded atomic.Bool // Set by Finish on a graceful finish of a turn without a parent; read by children via IsParentEnded
closeOnce sync.Once // Ensures pendingResults channel is closed once
finishedChan chan struct{} // Closed when turn finishes
// Token budget tracking
tokenBudget *atomic.Int64 // Shared token budget counter
lastFinishReason string // Last LLM finish_reason
lastUsage *providers.UsageInfo // Last LLM usage info
// Back-reference to the owning AgentLoop (set for SubTurns only, used for hard abort cascade)
al *AgentLoop
}
// newTurnState builds the state for a turn from the resolved agent, the
// processing options and the event scope. The turn starts in TurnPhaseSetup
// with startedAt set to the current time. opts.Media is copied, so later
// changes to the caller's slice are not visible through the turn state.
//
// agent may be nil: agentID is then left empty and no session store is bound.
// When agent.Sessions is set it becomes the turn's session store and the
// current history length for opts.SessionKey is recorded.
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
	ts := &turnState{
		agent:       agent,
		opts:        opts,
		scope:       scope,
		turnID:      scope.turnID,
		sessionKey:  opts.SessionKey,
		channel:     opts.Channel,
		chatID:      opts.ChatID,
		userMessage: opts.UserMessage,
		media:       append([]string(nil), opts.Media...),
		phase:       TurnPhaseSetup,
		startedAt:   time.Now(),
	}

	// Everything below dereferences agent, so the nil check has to come first.
	// Previously agent.ID was read inside the literal above, which panicked on
	// a nil agent before the guard was ever reached.
	if agent == nil {
		return ts
	}
	ts.agentID = agent.ID

	// Bind session store and capture initial history length for rollback logic
	if agent.Sessions != nil {
		ts.session = agent.Sessions
		ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
	}

	return ts
}
// registerActiveTurn stores ts in the active-turn map under its session key,
// replacing any entry already stored under that key.
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
	key := ts.sessionKey
	al.activeTurnStates.Store(key, ts)
}
// clearActiveTurn removes the active-turn entry stored under ts's session key.
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
	key := ts.sessionKey
	al.activeTurnStates.Delete(key)
}
// getActiveTurnState returns the turn state stored under sessionKey, or nil
// when no entry exists for that key.
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
	stored, found := al.activeTurnStates.Load(sessionKey)
	if !found {
		return nil
	}
	return stored.(*turnState)
}
// getAnyActiveTurnState returns the first turn state visited while ranging over
// the active-turn map (kept for backward compatibility), or nil when the map
// is empty.
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
	var picked *turnState
	al.activeTurnStates.Range(func(_, entry any) bool {
		picked = entry.(*turnState)
		// Returning false ends the iteration after the first entry.
		return false
	})
	return picked
}
// GetActiveTurn returns a snapshot of one active turn, or nil when no turn is
// active.
//
// It exists for backward compatibility: several turns can be active at once,
// and this returns whichever one getAnyActiveTurnState visits first. Use
// GetActiveTurnBySession to look up a specific session's turn.
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
	// Reuse the shared "first entry" lookup instead of repeating the Range loop.
	ts := al.getAnyActiveTurnState()
	if ts == nil {
		return nil
	}
	info := ts.snapshot()
	return &info
}
// GetActiveTurnBySession returns a snapshot of the turn stored under
// sessionKey, or nil when no turn is stored under that key.
func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
	if ts := al.getActiveTurnState(sessionKey); ts != nil {
		info := ts.snapshot()
		return &info
	}
	return nil
}
// snapshot copies the turn's identifying fields and progress into an
// ActiveTurnInfo while holding the read lock. The child ID slice is copied so
// the result does not alias the turn state.
func (ts *turnState) snapshot() ActiveTurnInfo {
	ts.mu.RLock()
	defer ts.mu.RUnlock()

	var info ActiveTurnInfo

	// Identity and routing.
	info.TurnID = ts.turnID
	info.AgentID = ts.agentID
	info.SessionKey = ts.sessionKey
	info.Channel = ts.channel
	info.ChatID = ts.chatID
	info.UserMessage = ts.userMessage

	// Progress.
	info.Phase = ts.phase
	info.Iteration = ts.iteration
	info.StartedAt = ts.startedAt

	// Position in the turn tree.
	info.Depth = ts.depth
	info.ParentTurnID = ts.parentTurnID
	info.ChildTurnIDs = append([]string(nil), ts.childTurnIDs...)

	return info
}
// setPhase records phase as the turn's current phase under the write lock.
func (ts *turnState) setPhase(phase TurnPhase) {
	ts.mu.Lock()
	ts.phase = phase
	ts.mu.Unlock()
}
// setIteration records iteration as the turn's current iteration under the
// write lock.
func (ts *turnState) setIteration(iteration int) {
	ts.mu.Lock()
	ts.iteration = iteration
	ts.mu.Unlock()
}
// currentIteration returns the iteration last stored by setIteration.
func (ts *turnState) currentIteration() int {
	ts.mu.RLock()
	n := ts.iteration
	ts.mu.RUnlock()
	return n
}
// setFinalContent stores content as the turn's final content under the write
// lock.
func (ts *turnState) setFinalContent(content string) {
	ts.mu.Lock()
	ts.finalContent = content
	ts.mu.Unlock()
}
// finalContentLen returns the length in bytes of the stored final content.
func (ts *turnState) finalContentLen() int {
	ts.mu.RLock()
	size := len(ts.finalContent)
	ts.mu.RUnlock()
	return size
}
// setTurnCancel stores the cancel function that requestHardAbort invokes to
// cancel the turn.
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
	ts.mu.Lock()
	ts.turnCancel = cancel
	ts.mu.Unlock()
}
// setProviderCancel stores the cancel function that requestHardAbort invokes
// for the provider call.
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
	ts.mu.Lock()
	ts.providerCancel = cancel
	ts.mu.Unlock()
}
// clearProviderCancel drops the stored provider cancel function. The argument
// is ignored: whatever function is currently stored is cleared unconditionally.
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
	ts.mu.Lock()
	ts.providerCancel = nil
	ts.mu.Unlock()
}
// requestGracefulInterrupt flags the turn for a graceful interrupt and stores
// hint alongside it. It returns false, changing nothing, when a hard abort has
// already been requested; otherwise it returns true.
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
	ts.mu.Lock()
	defer ts.mu.Unlock()

	accepted := !ts.hardAbort
	if accepted {
		ts.gracefulInterrupt = true
		ts.gracefulInterruptHint = hint
	}
	return accepted
}
// gracefulInterruptRequested reports whether a graceful interrupt was requested
// and has not yet been marked as used, together with the stored hint. The hint
// is returned regardless of the boolean.
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
	ts.mu.RLock()
	defer ts.mu.RUnlock()

	pending := ts.gracefulInterrupt && !ts.gracefulTerminalUsed
	return pending, ts.gracefulInterruptHint
}
// markGracefulTerminalUsed sets gracefulTerminalUsed, after which
// gracefulInterruptRequested reports false.
func (ts *turnState) markGracefulTerminalUsed() {
	ts.mu.Lock()
	ts.gracefulTerminalUsed = true
	ts.mu.Unlock()
}
// requestHardAbort flags the turn as hard-aborted and invokes the stored
// provider cancel function followed by the turn cancel function (each only if
// set). Only the first call does this and returns true; later calls return
// false without cancelling anything.
func (ts *turnState) requestHardAbort() bool {
	ts.mu.Lock()
	alreadyAborted := ts.hardAbort
	ts.hardAbort = true
	stopProvider, stopTurn := ts.providerCancel, ts.turnCancel
	ts.mu.Unlock()

	if alreadyAborted {
		return false
	}

	// The cancel functions run outside the lock.
	if stopProvider != nil {
		stopProvider()
	}
	if stopTurn != nil {
		stopTurn()
	}
	return true
}
// hardAbortRequested reports whether requestHardAbort has been called.
func (ts *turnState) hardAbortRequested() bool {
	ts.mu.RLock()
	aborted := ts.hardAbort
	ts.mu.RUnlock()
	return aborted
}
// eventMeta builds an EventMeta from a snapshot of the turn (agent ID, turn ID,
// session key, iteration) plus the caller-supplied source and trace path.
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
	info := ts.snapshot()

	meta := EventMeta{Source: source, TracePath: tracePath}
	meta.AgentID = info.AgentID
	meta.TurnID = info.TurnID
	meta.SessionKey = info.SessionKey
	meta.Iteration = info.Iteration
	return meta
}
// captureRestorePoint stores a copy of history and the summary as the state
// restoreSession will write back.
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
	// Copy before taking the lock; the caller's slice is never retained.
	historyCopy := append([]providers.Message(nil), history...)

	ts.mu.Lock()
	ts.restorePointHistory = historyCopy
	ts.restorePointSummary = summary
	ts.mu.Unlock()
}
// recordPersistedMessage appends msg to the list of messages recorded for this
// turn; refreshRestorePointFromSession later strips these from the history tail.
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
	ts.mu.Lock()
	ts.persistedMessages = append(ts.persistedMessages, msg)
	ts.mu.Unlock()
}
// refreshRestorePointFromSession re-captures the restore point from the
// session's current history and summary. If the tail of that history matches
// the tail of the messages recorded for this turn, the matching messages are
// cut off first, so the restore point excludes them.
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
	store := agent.Sessions
	history := store.GetHistory(ts.sessionKey)
	summary := store.GetSummary(ts.sessionKey)

	ts.mu.RLock()
	recorded := append([]providers.Message(nil), ts.persistedMessages...)
	ts.mu.RUnlock()

	overlap := matchingTurnMessageTail(history, recorded)
	if overlap > 0 {
		keep := len(history) - overlap
		history = append([]providers.Message(nil), history[:keep]...)
	}

	ts.captureRestorePoint(history, summary)
}
// restoreSession writes the captured restore point (history and summary) back
// into the agent's session store and saves the session, returning the error
// from Save.
func (ts *turnState) restoreSession(agent *AgentInstance) error {
	ts.mu.RLock()
	savedSummary := ts.restorePointSummary
	savedHistory := append([]providers.Message(nil), ts.restorePointHistory...)
	ts.mu.RUnlock()

	store := agent.Sessions
	store.SetHistory(ts.sessionKey, savedHistory)
	store.SetSummary(ts.sessionKey, savedSummary)
	return store.Save(ts.sessionKey)
}
// matchingTurnMessageTail returns the largest n such that the last n messages
// of history are deeply equal to the last n messages of persisted, or 0 when
// no non-empty tails match.
func matchingTurnMessageTail(history, persisted []providers.Message) int {
	size := min(len(history), len(persisted))
	// Try the longest candidate first and shrink until the tails agree.
	for ; size > 0; size-- {
		historyTail := history[len(history)-size:]
		persistedTail := persisted[len(persisted)-size:]
		if reflect.DeepEqual(historyTail, persistedTail) {
			break
		}
	}
	return size
}
// interruptHintMessage builds the user-role message that tells the model to
// stop scheduling tools and summarize. The stored interrupt hint, when
// non-empty, is appended after a blank line.
func (ts *turnState) interruptHintMessage() providers.Message {
	const base = "Interrupt requested. Stop scheduling tools and provide a short final summary."

	msg := providers.Message{Role: "user", Content: base}
	if _, hint := ts.gracefulInterruptRequested(); hint != "" {
		msg.Content = base + "\n\nInterrupt hint: " + hint
	}
	return msg
}
// SubTurn-related methods
// Finish marks the turn as finished. Every call sets isFinished and, when
// cancelFunc is set, invokes it. Exactly once across all calls it also closes
// pendingResults (when non-nil) and finishedChan, creating finishedChan first
// if Finished was never called.
//
// When isHardAbort is false and the turn has no parentTurnState, parentEnded is
// set so that children polling IsParentEnded observe it.
//
// When isHardAbort is true and al is set, Finish(true) is called recursively on
// every ID in childTurnIDs that is still present in al.activeTurnStates.
//
// NOTE(review): a turn that has a parentTurnState never sets parentEnded on a
// graceful finish, so IsParentEnded stays false for its own children — confirm
// this is intended.
// NOTE(review): pendingResults is closed here; a send on a closed channel
// panics, so senders presumably guard against a finished turn — verify.
func (ts *turnState) Finish(isHardAbort bool) {
ts.isFinished.Store(true)
// Close pendingResults channel exactly once
ts.closeOnce.Do(func() {
if ts.pendingResults != nil {
close(ts.pendingResults)
}
// mu is shared with Finished(), which lazily creates finishedChan.
ts.mu.Lock()
if ts.finishedChan == nil {
ts.finishedChan = make(chan struct{})
}
close(ts.finishedChan)
ts.mu.Unlock()
})
// If this is a graceful finish (not hard abort), signal to children
if !isHardAbort && ts.parentTurnState == nil {
// This is a root turn finishing gracefully
ts.parentEnded.Store(true)
}
// Cancel the turn context
if ts.cancelFunc != nil {
ts.cancelFunc()
}
// Hard abort cascades to all child turns
if isHardAbort && ts.al != nil {
// Copy the IDs under the read lock; the recursive calls below run without it.
ts.mu.RLock()
children := append([]string(nil), ts.childTurnIDs...)
ts.mu.RUnlock()
for _, childID := range children {
if val, ok := ts.al.activeTurnStates.Load(childID); ok {
val.(*turnState).Finish(true)
}
}
}
}
// Finished returns the channel that Finish closes. The channel is created on
// first use if it does not exist yet.
func (ts *turnState) Finished() chan struct{} {
	ts.mu.Lock()
	defer ts.mu.Unlock()

	if done := ts.finishedChan; done != nil {
		return done
	}
	ts.finishedChan = make(chan struct{})
	return ts.finishedChan
}
// IsParentEnded reports the parent turn's parentEnded flag. It returns false
// for a turn that has no parent.
func (ts *turnState) IsParentEnded() bool {
	parent := ts.parentTurnState
	return parent != nil && parent.parentEnded.Load()
}
// GetLastFinishReason returns the finish_reason stored by SetLastFinishReason.
func (ts *turnState) GetLastFinishReason() string {
	ts.mu.RLock()
	reason := ts.lastFinishReason
	ts.mu.RUnlock()
	return reason
}
// SetLastFinishReason stores reason as the last LLM finish_reason.
func (ts *turnState) SetLastFinishReason(reason string) {
	ts.mu.Lock()
	ts.lastFinishReason = reason
	ts.mu.Unlock()
}
// GetLastUsage returns the usage pointer stored by SetLastUsage. The pointer
// itself is returned, not a copy of the value it points to.
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
	ts.mu.RLock()
	usage := ts.lastUsage
	ts.mu.RUnlock()
	return usage
}
// SetLastUsage stores usage as the last LLM usage info.
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
	ts.mu.Lock()
	ts.lastUsage = usage
	ts.mu.Unlock()
}
// Context helper functions for SubTurn
// turnStateKeyType is an unexported empty struct type used only as the type of
// the context key, so the key cannot collide with keys from other packages.
type turnStateKeyType struct{}
// turnStateKey is the context key under which withTurnState stores a *turnState.
var turnStateKey = turnStateKeyType{}
// withTurnState returns a context derived from ctx that carries ts under
// turnStateKey.
func withTurnState(ctx context.Context, ts *turnState) context.Context {
	derived := context.WithValue(ctx, turnStateKey, ts)
	return derived
}
// turnStateFromContext returns the *turnState stored in ctx by withTurnState,
// or nil when ctx carries none.
func turnStateFromContext(ctx context.Context) *turnState {
	if ts, ok := ctx.Value(turnStateKey).(*turnState); ok {
		return ts
	}
	return nil
}
// TurnStateFromContext returns the *turnState carried by ctx, or nil when there
// is none. It is the exported form of turnStateFromContext, for use by tools.
func TurnStateFromContext(ctx context.Context) *turnState {
	ts, _ := ctx.Value(turnStateKey).(*turnState)
	return ts
}
-428
View File
@@ -1,428 +0,0 @@
package agent
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
// ====================== Context Keys ======================
// turnStateKeyType is an unexported empty struct type used only as the type of
// the context key, so the key cannot collide with keys from other packages.
type turnStateKeyType struct{}
// turnStateKey is the context key under which withTurnState stores a *turnState.
var turnStateKey = turnStateKeyType{}
// withTurnState returns a context derived from ctx that carries ts under
// turnStateKey.
func withTurnState(ctx context.Context, ts *turnState) context.Context {
	derived := context.WithValue(ctx, turnStateKey, ts)
	return derived
}
// TurnStateFromContext returns the *turnState carried by ctx, or nil when there
// is none. It is the exported form of turnStateFromContext, for use by tools.
func TurnStateFromContext(ctx context.Context) *turnState {
	ts, _ := ctx.Value(turnStateKey).(*turnState)
	return ts
}
// turnStateFromContext returns the *turnState stored in ctx by withTurnState,
// or nil when ctx carries none.
func turnStateFromContext(ctx context.Context) *turnState {
	if ts, ok := ctx.Value(turnStateKey).(*turnState); ok {
		return ts
	}
	return nil
}
// ====================== turnState ======================
// turnState is the per-turn state of a root turn or SubTurn: its context,
// position in the turn tree, result channel, session store and finish flags.
type turnState struct {
ctx context.Context
cancelFunc context.CancelFunc // Used to cancel all children when this turn finishes
turnID string
parentTurnID string
depth int
childTurnIDs []string // MUST be accessed under mu lock (Info returns a copy)
pendingResults chan *tools.ToolResult
session session.SessionStore
initialHistoryLength int // Snapshot of session history length at turn start, for rollback on hard abort
mu sync.Mutex
isFinished bool // MUST be accessed under mu lock
closeOnce sync.Once // Ensures finishedChan is closed exactly once
concurrencySem chan struct{} // Limits concurrent child sub-turns
finishedChan chan struct{} // Lazily initialized, closed when turn finishes
// parentEnded signals that the parent turn has finished gracefully.
// Child SubTurns should check this via IsParentEnded() to decide whether
// to continue running (Critical=true) or exit gracefully (Critical=false).
parentEnded atomic.Bool
// critical indicates whether this SubTurn should continue running after
// the parent turn finishes gracefully. Set from SubTurnConfig.Critical.
critical bool
// parentTurnState holds a reference to the parent turnState.
// This allows child SubTurns to check if the parent has ended.
// Nil for root turns.
parentTurnState *turnState
// lastFinishReason stores the finish_reason from the last LLM call.
// Used by SubTurn to detect truncation and retry.
// MUST be accessed under mu lock.
lastFinishReason string
// Token budget tracking
// tokenBudget is a shared atomic counter for tracking remaining tokens across team members.
// Inherited from parent or initialized from SubTurnConfig.InitialTokenBudget.
// Nil if no budget is set.
tokenBudget *atomic.Int64
// lastUsage stores the token usage from the last LLM call.
// Used by SubTurn to deduct from tokenBudget after each LLM iteration.
// MUST be accessed under mu lock.
lastUsage *providers.UsageInfo
}
// ====================== Public API ======================
// TurnInfo provides read-only information about an active turn.
// It is a copy produced by turnState.Info.
type TurnInfo struct {
TurnID string
ParentTurnID string
Depth int
// ChildTurnIDs is a copy made by Info; it does not alias the turn state's slice.
ChildTurnIDs []string
IsFinished bool
}
// GetActiveTurn returns information about the turn stored under sessionKey.
// It returns nil when no entry exists for that key or when the stored value is
// not a *turnState.
func (al *AgentLoop) GetActiveTurn(sessionKey string) *TurnInfo {
	if stored, found := al.activeTurnStates.Load(sessionKey); found {
		if ts, isTurn := stored.(*turnState); isTurn {
			return ts.Info()
		}
	}
	return nil
}
// Info returns a snapshot of the turn's ID, parent ID, depth, child IDs and
// finished flag, taken while holding mu. The child ID slice is a fresh copy and
// is never nil.
func (ts *turnState) Info() *TurnInfo {
	ts.mu.Lock()
	defer ts.mu.Unlock()

	info := &TurnInfo{
		TurnID:       ts.turnID,
		ParentTurnID: ts.parentTurnID,
		Depth:        ts.depth,
		IsFinished:   ts.isFinished,
	}
	// Copy rather than share, so callers cannot observe later appends.
	info.ChildTurnIDs = append(make([]string, 0, len(ts.childTurnIDs)), ts.childTurnIDs...)
	return info
}
// GetAllActiveTurns returns an Info snapshot for every *turnState in the
// active-turn map. Entries of any other type are skipped. The result is nil
// when there are no turns.
func (al *AgentLoop) GetAllActiveTurns() []*TurnInfo {
	var turns []*TurnInfo
	collect := func(_, entry any) bool {
		ts, isTurn := entry.(*turnState)
		if !isTurn {
			return true
		}
		turns = append(turns, ts.Info())
		return true
	}
	al.activeTurnStates.Range(collect)
	return turns
}
// FormatTree recursively builds a string representation of the active turn tree.
//
// It writes one line for turnInfo ("[id] Depth:n (Running|Finished)") and then
// one line or subtree per entry in turnInfo.ChildTurnIDs. prefix is the text
// printed before this node's branch marker. isLast selects the "last child"
// marker and controls the prefix passed to children. A nil turnInfo yields "".
func (al *AgentLoop) FormatTree(turnInfo *TurnInfo, prefix string, isLast bool) string {
if turnInfo == nil {
return ""
}
var sb strings.Builder
// Print current node
marker := "├── "
if isLast {
marker = "└── "
}
if turnInfo.Depth == 0 {
marker = "" // Root node no marker
}
status := "Running"
if turnInfo.IsFinished {
status = "Finished"
}
// A non-root node printed with an empty prefix is labelled as orphaned.
orphanMarker := ""
if turnInfo.Depth > 0 && prefix == "" {
orphanMarker = " (Orphaned)"
}
fmt.Fprintf(
&sb,
"%s%s[%s] Depth:%d (%s)%s\n",
prefix,
marker,
turnInfo.TurnID,
turnInfo.Depth,
status,
orphanMarker,
)
// Prepare prefix for children
childPrefix := prefix
if turnInfo.Depth > 0 {
if isLast {
childPrefix += " "
} else {
childPrefix += "│ "
}
}
for i, childID := range turnInfo.ChildTurnIDs {
// Look up child turn state
childInfo := al.GetActiveTurn(childID)
if childInfo != nil {
isLastChild := (i == len(turnInfo.ChildTurnIDs)-1)
sb.WriteString(al.FormatTree(childInfo, childPrefix, isLastChild))
} else {
// Child might have already been removed from active states if it finished early
isLastChild := (i == len(turnInfo.ChildTurnIDs)-1)
cMarker := "├── "
if isLastChild {
cMarker = "└── "
}
fmt.Fprintf(&sb, "%s%s[%s] (Completed/Cleaned Up)\n", childPrefix, cMarker, childID)
}
}
return sb.String()
}
// ====================== Helper Functions ======================
// newTurnState builds the state for a child turn of parent: one level deeper
// than parent, linked back to it, with an ephemeral session layered on the
// parent's session. parent must be non-nil.
//
// No cancellable context is created here; ctx is stored as given and
// cancelFunc is left nil for the caller to set.
func newTurnState(ctx context.Context, id string, parent *turnState, maxConcurrent int) *turnState {
	child := &turnState{
		ctx:             ctx,
		turnID:          id,
		parentTurnID:    parent.turnID,
		depth:           parent.depth + 1,
		parentTurnState: parent, // consulted by IsParentEnded()
	}
	child.session = newEphemeralSession(parent.session)

	// NOTE: In this PoC the result channel has a fixed size (16).
	// Under high concurrency or long-running sub-turns, this might fill up and cause
	// intermediate results to be discarded in deliverSubTurnResult.
	// For production, consider an unbounded queue or a blocking strategy with backpressure.
	child.pendingResults = make(chan *tools.ToolResult, 16)
	child.concurrencySem = make(chan struct{}, maxConcurrent)
	return child
}
// IsParentEnded reports the parent turn's parentEnded flag, which Finish sets
// on a graceful finish. It returns false for a root turn (no parent). The flag
// is read atomically.
func (ts *turnState) IsParentEnded() bool {
	parent := ts.parentTurnState
	return parent != nil && parent.parentEnded.Load()
}
// SetLastFinishReason stores reason as the last finish reason, under mu.
func (ts *turnState) SetLastFinishReason(reason string) {
	ts.mu.Lock()
	ts.lastFinishReason = reason
	ts.mu.Unlock()
}
// GetLastFinishReason returns the finish reason stored by SetLastFinishReason,
// read under mu.
func (ts *turnState) GetLastFinishReason() string {
	ts.mu.Lock()
	reason := ts.lastFinishReason
	ts.mu.Unlock()
	return reason
}
// SetLastUsage stores the token usage from the last LLM call, under mu.
// SubTurn uses it to track token consumption for budget enforcement.
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
	ts.mu.Lock()
	ts.lastUsage = usage
	ts.mu.Unlock()
}
// GetLastUsage returns the token usage most recently stored by SetLastUsage.
// The result is nil when nothing has been stored yet, i.e. before any LLM
// call has been recorded. The read happens while holding ts.mu.
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
	ts.mu.Lock()
	usage := ts.lastUsage
	ts.mu.Unlock()
	return usage
}
// Finished returns a channel that is closed when the turn finishes.
// Child turns can select on it so that a blocked result delivery does not
// leak if the parent finishes first.
//
// The channel is created lazily on first use. If the turn is already marked
// finished at that point, the freshly created channel is closed before being
// returned. Later calls return the same channel.
func (ts *turnState) Finished() <-chan struct{} {
	ts.mu.Lock()
	defer ts.mu.Unlock()

	if ts.finishedChan != nil {
		return ts.finishedChan
	}

	done := make(chan struct{})
	if ts.isFinished {
		close(done)
	}
	ts.finishedChan = done
	return done
}
// Finish marks the turn as finished and closes the Finished() channel.
//
// Hard abort (isHardAbort == true):
//   - Invokes cancelFunc, if one is set, cancelling all child contexts at once.
//   - Meant for user-initiated termination (e.g., "stop now").
//
// Graceful finish (isHardAbort == false):
//   - Only raises the parentEnded flag.
//   - Children poll IsParentEnded() and decide whether to keep going or exit:
//     critical SubTurns continue and deliver orphan results, non-critical
//     SubTurns exit gracefully without error.
//
// Either way the pendingResults channel is NOT closed. It stays open and is
// left to the garbage collector, which avoids "send on closed channel" panics
// from async sub-turns that are still delivering results concurrently.
func (ts *turnState) Finish(isHardAbort bool) {
	// Flip isFinished under the lock. Only the call that performs the flip
	// gets the channel back; every later call sees nil and skips the close.
	done := func() chan struct{} {
		ts.mu.Lock()
		defer ts.mu.Unlock()

		if ts.isFinished {
			return nil
		}
		ts.isFinished = true
		if ts.finishedChan == nil {
			ts.finishedChan = make(chan struct{})
		}
		return ts.finishedChan
	}()

	switch {
	case !isHardAbort:
		// Graceful finish: signal parent ended, let children decide.
		ts.parentEnded.Store(true)
	case ts.cancelFunc != nil:
		// Hard abort: immediately cancel all children.
		ts.cancelFunc()
	}

	if done == nil {
		return
	}
	// Close the finished channel exactly once.
	ts.closeOnce.Do(func() { close(done) })
}
// ====================== Ephemeral Session Store ======================
// ephemeralSessionStore is a pure in-memory SessionStore for SubTurns.
// It never writes to disk, keeping sub-turn history isolated from the parent session.
// It automatically truncates history when it exceeds maxEphemeralHistorySize to prevent memory accumulation.
// Auto-truncation runs from AddMessage and AddFullMessage; SetHistory stores
// a copy of the given slice without truncating it.
// The session key passed to its methods is ignored: one store holds one history.
type ephemeralSessionStore struct {
// mu guards history and summary; every method takes it before touching them.
mu sync.Mutex
// history holds the recorded messages, oldest first (new entries are appended).
history []providers.Message
// summary is the text stored by SetSummary and returned by GetSummary.
summary string
}
// AddMessage appends a message built from role and content to the in-memory
// history, then applies the automatic size limit. sessionKey is ignored.
func (e *ephemeralSessionStore) AddMessage(sessionKey, role, content string) {
	entry := providers.Message{Role: role, Content: content}

	e.mu.Lock()
	defer e.mu.Unlock()

	e.history = append(e.history, entry)
	e.autoTruncate()
}
// AddFullMessage appends msg unchanged to the in-memory history, then applies
// the automatic size limit. sessionKey is ignored.
func (e *ephemeralSessionStore) AddFullMessage(sessionKey string, msg providers.Message) {
	e.mu.Lock()
	defer e.mu.Unlock()

	grown := append(e.history, msg)
	e.history = grown
	e.autoTruncate()
}
// autoTruncate caps the history at maxEphemeralHistorySize entries so that it
// cannot grow without bound. When the cap is exceeded, the oldest entries are
// dropped and only the most recent ones are kept.
// Must be called with mu held.
func (e *ephemeralSessionStore) autoTruncate() {
	excess := len(e.history) - maxEphemeralHistorySize
	if excess <= 0 {
		return
	}
	// Drop the oldest `excess` messages.
	e.history = e.history[excess:]
}
// GetHistory returns a copy of the stored messages, so callers can modify the
// result without affecting the store. The result is never nil. key is ignored.
func (e *ephemeralSessionStore) GetHistory(key string) []providers.Message {
	e.mu.Lock()
	defer e.mu.Unlock()

	return append(make([]providers.Message, 0, len(e.history)), e.history...)
}
// GetSummary returns the summary most recently stored by SetSummary
// (empty if none was stored). key is ignored.
func (e *ephemeralSessionStore) GetSummary(key string) string {
	e.mu.Lock()
	current := e.summary
	e.mu.Unlock()
	return current
}
// SetSummary replaces the stored summary with the given text. key is ignored.
func (e *ephemeralSessionStore) SetSummary(key, summary string) {
	e.mu.Lock()
	e.summary = summary
	e.mu.Unlock()
}
// SetHistory replaces the stored messages with a private copy of history, so
// later changes to the caller's slice do not leak into the store.
// No automatic truncation is applied here. key is ignored.
func (e *ephemeralSessionStore) SetHistory(key string, history []providers.Message) {
	replacement := make([]providers.Message, len(history))
	copy(replacement, history)

	e.mu.Lock()
	e.history = replacement
	e.mu.Unlock()
}
// TruncateHistory keeps only the keepLast most recent messages and drops the
// rest. key is ignored.
//
// A keepLast of zero or less clears the history entirely. Without this guard
// a negative keepLast would make the slice start index exceed the slice
// length and panic at run time.
// When the history already holds keepLast messages or fewer, it is left
// untouched.
func (e *ephemeralSessionStore) TruncateHistory(key string, keepLast int) {
	e.mu.Lock()
	defer e.mu.Unlock()

	if keepLast <= 0 {
		// Drop the backing array too so the discarded messages can be collected.
		e.history = nil
		return
	}
	if len(e.history) > keepLast {
		e.history = e.history[len(e.history)-keepLast:]
	}
}
// Save is a no-op: the ephemeral store keeps everything in memory, so there
// is nothing to persist. It always returns nil.
func (e *ephemeralSessionStore) Save(key string) error {
	return nil
}
// Close is a no-op: the ephemeral store holds no resources that need
// releasing. It always returns nil.
func (e *ephemeralSessionStore) Close() error {
	return nil
}
// newEphemeralSession returns a fresh, empty ephemeral session for a sub-turn.
//
// IMPORTANT: the parent session argument is deliberately ignored (hence the
// blank identifier). Per issue #1316, sub-turns run on fully isolated
// ephemeral sessions that do NOT inherit any history from the parent session.
//
// Why isolation:
//   - A sub-turn is an independent execution context with its own prompt.
//   - Carrying over parent history could pollute the sub-turn's context.
//   - Every sub-turn should begin from a clean slate.
//   - Memory is managed per sub-turn (auto-truncation at maxEphemeralHistorySize).
//   - Results travel back through the result channel, not through shared history.
//
// Should parent history inheritance ever be required, revisit this decision
// with close attention to memory management and context size.
func newEphemeralSession(_ session.SessionStore) session.SessionStore {
	return new(ephemeralSessionStore)
}
+32
View File
@@ -84,6 +84,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers,omitempty"`
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway"`
Hooks HooksConfig `json:"hooks,omitempty"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
Devices DevicesConfig `json:"devices"`
@@ -92,6 +93,36 @@ type Config struct {
BuildInfo BuildInfo `json:"build_info,omitempty"`
}
// HooksConfig is the "hooks" section of Config.
// DefaultConfig sets Enabled to true and fills in Defaults; Builtins and
// Processes are left empty unless the config file provides them.
type HooksConfig struct {
// Enabled is presumably the master switch for the hook system — TODO confirm against the hook loader.
Enabled bool `json:"enabled"`
// Defaults holds the default timeouts for observers, interceptors and approvals.
Defaults HookDefaultsConfig `json:"defaults,omitempty"`
// Builtins maps a builtin hook name (e.g. "audit" in the tests) to its settings.
Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"`
// Processes maps a process hook name (e.g. "review-gate" in the tests) to its settings.
Processes map[string]ProcessHookConfig `json:"processes,omitempty"`
}
// HookDefaultsConfig holds the default hook timeouts.
// Values are presumably milliseconds, judging by the MS suffix and the
// "_ms" JSON keys. DefaultConfig uses 500, 5000 and 60000 respectively.
type HookDefaultsConfig struct {
// ObserverTimeoutMS is the default timeout for observer hooks (DefaultConfig: 500).
ObserverTimeoutMS int `json:"observer_timeout_ms,omitempty"`
// InterceptorTimeoutMS is the default timeout for interceptor hooks (DefaultConfig: 5000).
InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"`
// ApprovalTimeoutMS is the default timeout for approvals (DefaultConfig: 60000).
ApprovalTimeoutMS int `json:"approval_timeout_ms,omitempty"`
}
// BuiltinHookConfig configures one entry of HooksConfig.Builtins.
type BuiltinHookConfig struct {
// Enabled turns this builtin hook on or off.
Enabled bool `json:"enabled"`
// Priority is presumably used to order hooks relative to each other — TODO confirm direction (low-first or high-first).
Priority int `json:"priority,omitempty"`
// Config is the hook-specific settings object, kept as raw JSON so that
// decoding is deferred (presumably to the hook itself).
Config json.RawMessage `json:"config,omitempty"`
}
// ProcessHookConfig configures one entry of HooksConfig.Processes, i.e. a
// hook that is described by a command line rather than a builtin name.
type ProcessHookConfig struct {
// Enabled turns this process hook on or off.
Enabled bool `json:"enabled"`
// Priority is presumably used to order hooks relative to each other — TODO confirm direction.
Priority int `json:"priority,omitempty"`
// Transport names how the hook process is talked to; the tests use "stdio".
Transport string `json:"transport,omitempty"`
// Command is the program and its arguments, e.g. ["uvx", "picoclaw-hook-reviewer"].
Command []string `json:"command,omitempty"`
// Dir is presumably the working directory for the command — TODO confirm.
Dir string `json:"dir,omitempty"`
// Env is presumably extra environment variables for the command — TODO confirm whether it extends or replaces the parent environment.
Env map[string]string `json:"env,omitempty"`
// Observe lists event names for the hook to observe, e.g. "turn_start", "turn_end".
Observe []string `json:"observe,omitempty"`
// Intercept lists event names for the hook to intercept, e.g. "before_tool", "approve_tool".
Intercept []string `json:"intercept,omitempty"`
}
// BuildInfo contains build-time version information
type BuildInfo struct {
Version string `json:"version"`
@@ -244,6 +275,7 @@ type AgentDefaults struct {
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
+98
View File
@@ -470,6 +470,22 @@ func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
}
}
// TestDefaultConfig_HooksDefaults checks that DefaultConfig enables hooks and
// sets the default observer, interceptor and approval timeouts.
func TestDefaultConfig_HooksDefaults(t *testing.T) {
	hooks := DefaultConfig().Hooks
	if enabled := hooks.Enabled; !enabled {
		t.Fatal("DefaultConfig().Hooks.Enabled should be true")
	}

	defaults := hooks.Defaults
	if got := defaults.ObserverTimeoutMS; got != 500 {
		t.Fatalf("ObserverTimeoutMS = %d, want 500", got)
	}
	if got := defaults.InterceptorTimeoutMS; got != 5000 {
		t.Fatalf("InterceptorTimeoutMS = %d, want 5000", got)
	}
	if got := defaults.ApprovalTimeoutMS; got != 60000 {
		t.Fatalf("ApprovalTimeoutMS = %d, want 60000", got)
	}
}
func TestDefaultConfig_LogLevel(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.LogLevel != "fatal" {
@@ -562,6 +578,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
}
// TestLoadConfig_HooksProcessConfig writes a config file containing one
// process hook ("review-gate") and one builtin hook ("audit"), loads it with
// LoadConfig, and checks that every hook field round-trips. It also checks
// that the approval timeout, which the file does not set, is still 60000
// after loading.
func TestLoadConfig_HooksProcessConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
// Fixture: only the "hooks" section is provided; "defaults" is deliberately absent.
configJSON := `{
"hooks": {
"processes": {
"review-gate": {
"enabled": true,
"transport": "stdio",
"command": ["uvx", "picoclaw-hook-reviewer"],
"dir": "/tmp/hooks",
"env": {
"HOOK_MODE": "rewrite"
},
"observe": ["turn_start", "turn_end"],
"intercept": ["before_tool", "approve_tool"]
}
},
"builtins": {
"audit": {
"enabled": true,
"priority": 5,
"config": {
"label": "audit"
}
}
}
}
}`
if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
t.Fatalf("os.WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
// Process hook: every configured field must be present after loading.
processCfg, ok := cfg.Hooks.Processes["review-gate"]
if !ok {
t.Fatal("expected review-gate process hook")
}
if !processCfg.Enabled {
t.Fatal("expected review-gate process hook to be enabled")
}
if processCfg.Transport != "stdio" {
t.Fatalf("Transport = %q, want stdio", processCfg.Transport)
}
if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" {
t.Fatalf("Command = %v", processCfg.Command)
}
if processCfg.Dir != "/tmp/hooks" {
t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir)
}
if processCfg.Env["HOOK_MODE"] != "rewrite" {
t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"])
}
if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" {
t.Fatalf("Observe = %v", processCfg.Observe)
}
if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" {
t.Fatalf("Intercept = %v", processCfg.Intercept)
}
// Builtin hook: enabled flag, priority and raw config payload.
builtinCfg, ok := cfg.Hooks.Builtins["audit"]
if !ok {
t.Fatal("expected audit builtin hook")
}
if !builtinCfg.Enabled {
t.Fatal("expected audit builtin hook to be enabled")
}
if builtinCfg.Priority != 5 {
t.Fatalf("Priority = %d, want 5", builtinCfg.Priority)
}
// Config is raw JSON, so only check that the payload text survived.
if !strings.Contains(string(builtinCfg.Config), `"audit"`) {
t.Fatalf("Config = %s", string(builtinCfg.Config))
}
// The file sets no "defaults", so this value must come from the built-in defaults.
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
}
}
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
+8
View File
@@ -186,6 +186,14 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
},
Hooks: HooksConfig{
Enabled: true,
Defaults: HookDefaultsConfig{
ObserverTimeoutMS: 500,
InterceptorTimeoutMS: 5000,
ApprovalTimeoutMS: 60000,
},
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{WebSearch: true},
},
+3
View File
@@ -154,6 +154,9 @@ func (sm *SubagentManager) runTask(
) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
// TODO(eventbus): once subagents are modeled as child turns inside
// pkg/agent, emit SubTurnEnd and SubTurnResultDelivered from the parent
// AgentLoop instead of this legacy manager.
// Check if context is already canceled before starting
select {