Merge branch 'main' into feat/skill-channel-commands

# Conflicts:
#	pkg/agent/loop.go
This commit is contained in:
afjcjsbx
2026-03-22 20:51:16 +01:00
62 changed files with 16762 additions and 708 deletions
+27 -16
View File
@@ -222,13 +222,10 @@ func (cb *ContextBuilder) InvalidateCache() {
// invalidation (bootstrap files + memory). Skill roots are handled separately
// because they require both directory-level and recursive file-level checks.
func (cb *ContextBuilder) sourcePaths() []string {
return []string{
filepath.Join(cb.workspace, "AGENTS.md"),
filepath.Join(cb.workspace, "SOUL.md"),
filepath.Join(cb.workspace, "USER.md"),
filepath.Join(cb.workspace, "IDENTITY.md"),
filepath.Join(cb.workspace, "memory", "MEMORY.md"),
}
agentDefinition := cb.LoadAgentDefinition()
paths := agentDefinition.trackedPaths(cb.workspace)
paths = append(paths, filepath.Join(cb.workspace, "memory", "MEMORY.md"))
return uniquePaths(paths)
}
// skillRoots returns all skill root directories that can affect
@@ -432,18 +429,32 @@ func skillFilesChangedSince(skillRoots []string, filesAtCache map[string]time.Ti
}
func (cb *ContextBuilder) LoadBootstrapFiles() string {
bootstrapFiles := []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"IDENTITY.md",
var sb strings.Builder
agentDefinition := cb.LoadAgentDefinition()
if agentDefinition.Agent != nil {
label := string(agentDefinition.Source)
if label == "" {
label = relativeWorkspacePath(cb.workspace, agentDefinition.Agent.Path)
}
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", label, agentDefinition.Agent.Body)
}
if agentDefinition.Soul != nil {
fmt.Fprintf(
&sb,
"## %s\n\n%s\n\n",
relativeWorkspacePath(cb.workspace, agentDefinition.Soul.Path),
agentDefinition.Soul.Content,
)
}
if agentDefinition.User != nil {
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "USER.md", agentDefinition.User.Content)
}
var sb strings.Builder
for _, filename := range bootstrapFiles {
filePath := filepath.Join(cb.workspace, filename)
if agentDefinition.Source != AgentDefinitionSourceAgent {
filePath := filepath.Join(cb.workspace, "IDENTITY.md")
if data, err := os.ReadFile(filePath); err == nil {
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", filename, data)
fmt.Fprintf(&sb, "## %s\n\n%s\n\n", "IDENTITY.md", data)
}
}
+176
View File
@@ -0,0 +1,176 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package agent
import (
"encoding/json"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// parseTurnBoundaries collects, in ascending order, the index of every
// message in history that opens a Turn.
//
// A Turn (see #1316) is one "user input → LLM iterations → final response"
// cycle: it opens with a user message and runs through all assistant/tool
// messages that follow, up to the next user message. Because tool-call
// sequences (assistant+ToolCalls → tool results) live inside a Turn, a cut
// placed on one of the returned indices never splits such a sequence.
//
// The result is nil when history contains no user message.
func parseTurnBoundaries(history []providers.Message) []int {
	var boundaries []int
	for idx := range history {
		if history[idx].Role != "user" {
			continue
		}
		boundaries = append(boundaries, idx)
	}
	return boundaries
}
// isSafeBoundary tells whether history may be cut at index so that the kept
// tail, history[index:], starts on a user message and therefore leaves every
// tool-call sequence intact.
//
// Indices outside the open interval (0, len(history)) — keeping everything
// or keeping nothing — are always reported as safe.
func isSafeBoundary(history []providers.Message, index int) bool {
	inside := index > 0 && index < len(history)
	// Short-circuit keeps the slice access in range.
	return !inside || history[index].Role == "user"
}
// findSafeBoundary snaps targetIndex onto a Turn boundary of history.
//
// Resolution order:
//  1. targetIndex <= 0 (or empty history) yields 0; targetIndex >= len(history)
//     yields len(history).
//  2. If history has no user message at all, targetIndex is returned unchanged.
//  3. The last Turn start at or before targetIndex wins, provided it is not
//     index 0 (cutting at 0 would keep everything). Snapping backward keeps
//     more recent messages.
//  4. Otherwise the first Turn start after targetIndex is used.
//  5. Otherwise the only boundary is index 0 — the whole history is a single
//     Turn — and 0 is returned so callers that test for mid <= 0 skip
//     compression.
func findSafeBoundary(history []providers.Message, targetIndex int) int {
	switch {
	case len(history) == 0, targetIndex <= 0:
		return 0
	case targetIndex >= len(history):
		return len(history)
	}

	starts := parseTurnBoundaries(history)
	if len(starts) == 0 {
		return targetIndex
	}

	// starts is ascending, so one pass yields both the last boundary at or
	// before the target and the first boundary after it.
	before, after := -1, -1
	for _, start := range starts {
		if start > targetIndex {
			after = start
			break
		}
		before = start
	}

	switch {
	case before > 0:
		return before
	case after >= 0:
		return after
	default:
		// Only boundary is index 0: the history is one indivisible Turn.
		return 0
	}
}
// estimateMessageTokens returns a rough token count for one message.
//
// Text-like parts are summed as characters and converted with a heuristic of
// 2.5 characters per token (chars*2/5): Content and ReasoningContent are
// counted in runes; tool-call IDs, types, names, arguments and the
// ToolCallID are counted in bytes; a fixed overhead covers the role label and
// JSON framing. Media items bypass the character heuristic and add a flat
// per-item token cost, since their real cost depends on resolution and on
// provider-specific image tokenization.
func estimateMessageTokens(msg providers.Message) int {
	const (
		// Role label, JSON structure, separators.
		messageOverhead = 12
		// Flat estimate per attached image/file.
		mediaTokensPerItem = 256
	)

	// ReasoningContent (extended thinking) is stored in session history and
	// can be large, so it is counted alongside Content.
	charCount := messageOverhead +
		utf8.RuneCountInString(msg.Content) +
		utf8.RuneCountInString(msg.ReasoningContent) +
		len(msg.ToolCallID)

	for i := range msg.ToolCalls {
		call := &msg.ToolCalls[i]
		charCount += len(call.ID) + len(call.Type)
		if call.Function == nil {
			// Some provider formats carry only the top-level Name.
			charCount += len(call.Name)
			continue
		}
		// call.Name mirrors call.Function.Name; count the name once.
		charCount += len(call.Function.Name) + len(call.Function.Arguments)
	}

	return charCount*2/5 + len(msg.Media)*mediaTokensPerItem
}
// estimateToolDefsTokens approximates how many tokens the tool definitions
// occupy in an LLM request. For every tool it adds up the name, the
// description, the JSON-encoded parameter schema (when present and
// encodable) and a fixed structural overhead, then applies the same
// 2.5-characters-per-token heuristic (chars*2/5). An empty list costs 0.
func estimateToolDefsTokens(defs []providers.ToolDefinition) int {
	// Type field, JSON structure, separators.
	const perToolOverhead = 20

	chars := 0
	for _, def := range defs {
		fn := def.Function
		chars += perToolOverhead + len(fn.Name) + len(fn.Description)
		if fn.Parameters == nil {
			continue
		}
		encoded, err := json.Marshal(fn.Parameters)
		if err != nil {
			// Best effort: a schema that cannot be encoded adds nothing.
			continue
		}
		chars += len(encoded)
	}
	return chars * 2 / 5
}
// isOverContextBudget reports whether the estimated size of the request —
// all messages, the tool definitions, and the maxTokens reserved for the
// model's output — is strictly greater than contextWindow. Callers can use
// it to compress proactively before the LLM call instead of reacting to a
// 400 error afterwards.
func isOverContextBudget(
	contextWindow int,
	messages []providers.Message,
	toolDefs []providers.ToolDefinition,
	maxTokens int,
) bool {
	required := maxTokens + estimateToolDefsTokens(toolDefs)
	for i := range messages {
		required += estimateMessageTokens(messages[i])
	}
	return required > contextWindow
}
+826
View File
@@ -0,0 +1,826 @@
package agent
import (
"fmt"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// msgUser builds a message with role "user" carrying content.
func msgUser(content string) providers.Message {
	var m providers.Message
	m.Role = "user"
	m.Content = content
	return m
}
// msgAssistant builds a text-only assistant message (no tool calls attached).
func msgAssistant(content string) providers.Message {
	var m providers.Message
	m.Role = "assistant"
	m.Content = content
	return m
}
// msgAssistantTC builds an assistant message holding one function tool call
// per given ID. Each call is named "tool_<id>" (both at the top level and in
// Function) and carries a fixed JSON argument payload.
func msgAssistantTC(toolIDs ...string) providers.Message {
	calls := make([]providers.ToolCall, 0, len(toolIDs))
	for _, id := range toolIDs {
		name := "tool_" + id
		calls = append(calls, providers.ToolCall{
			ID:   id,
			Type: "function",
			Name: name,
			Function: &providers.FunctionCall{
				Name:      name,
				Arguments: `{"key":"value"}`,
			},
		})
	}
	return providers.Message{Role: "assistant", ToolCalls: calls}
}
// msgTool builds a tool-result message answering the call identified by callID.
func msgTool(callID, content string) providers.Message {
	var m providers.Message
	m.Role = "tool"
	m.ToolCallID = callID
	m.Content = content
	return m
}
// TestParseTurnBoundaries checks that Turn starts are reported exactly at the
// indices of user messages, for plain exchanges, tool-call Turns, histories
// without any user message, and histories that open with non-user messages.
func TestParseTurnBoundaries(t *testing.T) {
	type testCase struct {
		name    string
		history []providers.Message
		want    []int
	}

	cases := []testCase{
		{
			name:    "empty history",
			history: nil,
			want:    nil,
		},
		{
			name: "simple exchange",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistant("a2"),
			},
			want: []int{0, 2},
		},
		{
			name: "tool-call Turn",
			history: []providers.Message{
				msgUser("search"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "result"),
				msgAssistant("found it"),
				msgUser("thanks"),
				msgAssistant("welcome"),
			},
			want: []int{0, 4},
		},
		{
			name: "chained tool calls in single Turn",
			history: []providers.Message{
				msgUser("save and notify"),
				msgAssistantTC("tc_save"),
				msgTool("tc_save", "saved"),
				msgAssistantTC("tc_notify"),
				msgTool("tc_notify", "notified"),
				msgAssistant("done"),
			},
			want: []int{0},
		},
		{
			name: "no user messages",
			history: []providers.Message{
				msgAssistant("a1"),
				msgAssistant("a2"),
			},
			want: nil,
		},
		{
			name: "leading non-user messages",
			history: []providers.Message{
				msgAssistantTC("tc1"),
				msgTool("tc1", "r1"),
				msgAssistant("greeting"),
				msgUser("hello"),
				msgAssistant("hi"),
			},
			want: []int{3},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := parseTurnBoundaries(tc.history)
			// A length mismatch is reported once; element checks would be noise.
			if len(got) != len(tc.want) {
				t.Errorf("parseTurnBoundaries() = %v, want %v", got, tc.want)
				return
			}
			for i, idx := range got {
				if idx != tc.want[i] {
					t.Errorf("parseTurnBoundaries()[%d] = %d, want %d", i, idx, tc.want[i])
				}
			}
		})
	}
}
// TestIsSafeBoundary verifies that only user-message positions count as safe
// cut points inside the history, while out-of-range indices (<= 0 or
// >= len) are always accepted.
func TestIsSafeBoundary(t *testing.T) {
	type testCase struct {
		name    string
		history []providers.Message
		index   int
		want    bool
	}

	cases := []testCase{
		{
			name:    "empty history, index 0",
			history: nil,
			index:   0,
			want:    true,
		},
		{
			name:    "single user message, index 0",
			history: []providers.Message{msgUser("hi")},
			index:   0,
			want:    true,
		},
		{
			name:    "single user message, index 1 (end)",
			history: []providers.Message{msgUser("hi")},
			index:   1,
			want:    true,
		},
		{
			name: "at user message",
			history: []providers.Message{
				msgAssistant("hello"),
				msgUser("how are you"),
				msgAssistant("fine"),
			},
			index: 1,
			want:  true,
		},
		{
			name: "at assistant without tool calls",
			history: []providers.Message{
				msgUser("hello"),
				msgAssistant("response"),
				msgUser("follow up"),
			},
			index: 1,
			want:  false,
		},
		{
			name: "at assistant with tool calls",
			history: []providers.Message{
				msgUser("search something"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "result"),
				msgAssistant("here is what I found"),
			},
			index: 1,
			want:  false,
		},
		{
			name: "at tool result",
			history: []providers.Message{
				msgUser("do something"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "done"),
				msgAssistant("completed"),
			},
			index: 2,
			want:  false,
		},
		{
			name: "negative index",
			history: []providers.Message{
				msgUser("hello"),
			},
			index: -1,
			want:  true,
		},
		{
			name: "index beyond length",
			history: []providers.Message{
				msgUser("hello"),
			},
			index: 5,
			want:  true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := isSafeBoundary(tc.history, tc.index); got != tc.want {
				t.Errorf("isSafeBoundary(history, %d) = %v, want %v", tc.index, got, tc.want)
			}
		})
	}
}
// TestFindSafeBoundary exercises the snapping rules of findSafeBoundary:
// clamping of out-of-range targets, preference for the Turn start at or
// before the target, the forward fallback, and the pass-through when the
// history has no user message.
func TestFindSafeBoundary(t *testing.T) {
	type testCase struct {
		name        string
		history     []providers.Message
		targetIndex int
		want        int
	}

	cases := []testCase{
		{
			name:        "empty history",
			history:     nil,
			targetIndex: 0,
			want:        0,
		},
		{
			name:        "target at 0",
			history:     []providers.Message{msgUser("hi")},
			targetIndex: 0,
			want:        0,
		},
		{
			name:        "target beyond length",
			history:     []providers.Message{msgUser("hi")},
			targetIndex: 5,
			want:        1,
		},
		{
			name: "target already at user message",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistant("a2"),
			},
			targetIndex: 2,
			want:        2,
		},
		{
			name: "target at assistant, scan backward finds user",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistant("a2"),
				msgUser("q3"),
			},
			targetIndex: 3, // assistant "a2"
			want:        2, // Turn start "q2" precedes it
		},
		{
			name: "target inside tool sequence, scan backward finds user",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistantTC("tc1", "tc2"),
				msgTool("tc1", "r1"),
				msgTool("tc2", "r2"),
				msgAssistant("summary"),
				msgUser("q3"),
			},
			targetIndex: 4, // tool result "r1"
			want:        2, // Turn containing the tool sequence starts at "q2"
		},
		{
			name: "target inside tool sequence, backward finds user before chain",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistantTC("tc1", "tc2"),
				msgTool("tc1", "r1"),
				msgTool("tc2", "r2"),
				msgAssistant("summary"),
				msgUser("q3"),
			},
			targetIndex: 5, // tool result "r2"
			want:        2, // same Turn, still starts at "q2"
		},
		{
			name: "no backward user, scan forward finds one",
			history: []providers.Message{
				msgAssistantTC("tc1"),
				msgTool("tc1", "r1"),
				msgAssistant("a1"),
				msgUser("q1"),
			},
			targetIndex: 1, // tool result
			want:        3, // first Turn start after the target
		},
		{
			name: "multi-step tool chain preserves atomicity",
			history: []providers.Message{
				msgUser("q1"),
				msgAssistant("a1"),
				msgUser("q2"),
				msgAssistantTC("tc1"),
				msgTool("tc1", "r1"),
				msgAssistantTC("tc2"),
				msgTool("tc2", "r2"),
				msgAssistant("final"),
				msgUser("q3"),
				msgAssistant("a3"),
			},
			targetIndex: 5, // second assistant+TC
			want:        2, // whole chain belongs to the Turn opened by "q2"
		},
		{
			name: "all non-user messages returns target unchanged",
			history: []providers.Message{
				msgAssistant("a1"),
				msgAssistant("a2"),
				msgAssistant("a3"),
			},
			targetIndex: 1,
			want:        1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := findSafeBoundary(tc.history, tc.targetIndex); got != tc.want {
				t.Errorf("findSafeBoundary(history, %d) = %d, want %d",
					tc.targetIndex, got, tc.want)
			}
		})
	}
}
// TestFindSafeBoundary_SingleTurnReturnsZero covers a history made of one
// Turn only. Its sole boundary is index 0, and any other cut would tear the
// tool sequence apart, so findSafeBoundary must answer 0 (callers then skip
// compression).
func TestFindSafeBoundary_SingleTurnReturnsZero(t *testing.T) {
	singleTurn := []providers.Message{
		msgUser("do everything"), // 0 ← the only Turn boundary
		msgAssistantTC("tc1"),    // 1
		msgTool("tc1", "result"), // 2
		msgAssistant("all done"), // 3
	}

	if got := findSafeBoundary(singleTurn, 2); got != 0 {
		t.Errorf("findSafeBoundary(single_turn, 2) = %d, want 0 (cannot split single Turn)", got)
	}
}
// TestFindSafeBoundary_BackwardScanSkipsToolSequence places the target in the
// middle of a long tool-call chain and expects the boundary to land on the
// user message that opened the chain's Turn, skipping the whole chain.
func TestFindSafeBoundary_BackwardScanSkipsToolSequence(t *testing.T) {
	const target = 6 // in the middle of the tool results

	chain := []providers.Message{
		msgUser("start"),                 // 0
		msgAssistant("before chain"),     // 1
		msgUser("trigger"),               // 2 ← expected safe boundary
		msgAssistantTC("t1", "t2", "t3"), // 3
		msgTool("t1", "r1"),              // 4
		msgTool("t2", "r2"),              // 5
		msgTool("t3", "r3"),              // 6
		msgAssistantTC("t4"),             // 7
		msgTool("t4", "r4"),              // 8
		msgAssistant("chain done"),       // 9
		msgUser("next"),                  // 10
	}

	if got := findSafeBoundary(chain, target); got != 2 {
		t.Errorf("findSafeBoundary(history, 6) = %d, want 2 (user before chain)", got)
	}
}
// TestEstimateMessageTokens is a smoke test: every message shape must yield
// at least one token, because the fixed per-message overhead alone is
// non-zero. Exact values are intentionally not pinned here.
func TestEstimateMessageTokens(t *testing.T) {
	cases := []struct {
		name      string
		msg       providers.Message
		minTokens int // lower bound; the exact value depends on overhead
	}{
		{
			name:      "plain user message",
			msg:       msgUser("Hello, world!"),
			minTokens: 1,
		},
		{
			name:      "empty message still has overhead",
			msg:       providers.Message{Role: "user"},
			minTokens: 1, // overhead only
		},
		{
			name:      "assistant with tool calls",
			msg:       msgAssistantTC("tc_123"),
			minTokens: 1,
		},
		{
			name:      "tool result with ID",
			msg:       msgTool("call_abc", "Here is the search result with lots of content"),
			minTokens: 1,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := estimateMessageTokens(tc.msg); got < tc.minTokens {
				t.Errorf("estimateMessageTokens() = %d, want >= %d", got, tc.minTokens)
			}
		})
	}
}
// TestEstimateMessageTokens_ToolCallsContribute checks that attaching a tool
// call to an otherwise identical assistant message raises the estimate.
func TestEstimateMessageTokens_ToolCallsContribute(t *testing.T) {
	baseline := msgAssistant("thinking")

	// Same message, plus one tool call.
	enriched := baseline
	enriched.ToolCalls = []providers.ToolCall{
		{
			ID:   "call_1",
			Type: "function",
			Name: "web_search",
			Function: &providers.FunctionCall{
				Name:      "web_search",
				Arguments: `{"query":"picoclaw agent framework","max_results":5}`,
			},
		},
	}

	plainTokens := estimateMessageTokens(baseline)
	withTCTokens := estimateMessageTokens(enriched)
	if withTCTokens <= plainTokens {
		t.Errorf("message with ToolCalls (%d tokens) should exceed plain message (%d tokens)",
			withTCTokens, plainTokens)
	}
}
func TestEstimateMessageTokens_MultibyteContent(t *testing.T) {
// Multi-byte characters (e.g. emoji, accented letters) are single runes
// but may map to different token counts. The heuristic should still produce
// reasonable estimates via RuneCountInString.
msg := msgUser("caf\u00e9 na\u00efve r\u00e9sum\u00e9 \u00fcber stra\u00dfe")
tokens := estimateMessageTokens(msg)
if tokens <= 0 {
t.Errorf("multibyte message should produce positive token count, got %d", tokens)
}
}
// TestEstimateMessageTokens_LargeArguments makes sure a tool call with a big
// JSON argument payload dominates the estimate: more than 5000 characters
// must translate to at least 2000 tokens at 2.5 characters per token.
func TestEstimateMessageTokens_LargeArguments(t *testing.T) {
	payload := strings.Repeat("x", 5000)
	call := providers.ToolCall{
		ID:   "call_large",
		Type: "function",
		Name: "write_file",
		Function: &providers.FunctionCall{
			Name:      "write_file",
			Arguments: fmt.Sprintf(`{"content":"%s"}`, payload),
		},
	}
	msg := providers.Message{
		Role:      "assistant",
		ToolCalls: []providers.ToolCall{call},
	}

	if tokens := estimateMessageTokens(msg); tokens < 2000 {
		t.Errorf("large tool call arguments should produce significant token count, got %d", tokens)
	}
}
// TestEstimateMessageTokens_ReasoningContent checks that ReasoningContent is
// part of the estimate: adding it to an otherwise identical message must
// raise the token count.
func TestEstimateMessageTokens_ReasoningContent(t *testing.T) {
	baseline := msgAssistant("result")

	thoughtful := baseline
	thoughtful.ReasoningContent = strings.Repeat("thinking step ", 200)

	plainTokens := estimateMessageTokens(baseline)
	reasoningTokens := estimateMessageTokens(thoughtful)
	if reasoningTokens <= plainTokens {
		t.Errorf("message with ReasoningContent (%d tokens) should exceed plain message (%d tokens)",
			reasoningTokens, plainTokens)
	}
}
// TestEstimateMessageTokens_MediaItems pins the media accounting: every media
// item adds a flat 256 tokens on top of the text estimate, without passing
// through the chars*2/5 conversion.
func TestEstimateMessageTokens_MediaItems(t *testing.T) {
	textOnly := msgUser("describe this")

	illustrated := textOnly
	illustrated.Media = []string{"media://img1.png", "media://img2.png"}

	plainTokens := estimateMessageTokens(textOnly)
	mediaTokens := estimateMessageTokens(illustrated)
	if mediaTokens <= plainTokens {
		t.Errorf("message with Media (%d tokens) should exceed plain message (%d tokens)",
			mediaTokens, plainTokens)
	}

	// Two items, 256 tokens each, added verbatim.
	const expectedDelta = 2 * 256
	if actualDelta := mediaTokens - plainTokens; actualDelta != expectedDelta {
		t.Errorf("2 media items should add %d tokens, got delta %d", expectedDelta, actualDelta)
	}
}
// --- estimateToolDefsTokens tests ---
// TestEstimateToolDefsTokens checks lower bounds only: an empty tool list
// costs nothing, and any single tool — with or without a parameter schema —
// costs at least one token.
func TestEstimateToolDefsTokens(t *testing.T) {
	searchTool := providers.ToolDefinition{
		Type: "function",
		Function: providers.ToolFunctionDefinition{
			Name:        "web_search",
			Description: "Search the web for information",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"query": map[string]any{"type": "string"},
				},
				"required": []any{"query"},
			},
		},
	}
	listTool := providers.ToolDefinition{
		Type: "function",
		Function: providers.ToolFunctionDefinition{
			Name:        "list_dir",
			Description: "List directory contents",
		},
	}

	cases := []struct {
		name      string
		defs      []providers.ToolDefinition
		minTokens int // lower bound on the estimate
	}{
		{name: "empty tool list", defs: nil, minTokens: 0},
		{name: "single tool with params", defs: []providers.ToolDefinition{searchTool}, minTokens: 1},
		{name: "tool without params", defs: []providers.ToolDefinition{listTool}, minTokens: 1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := estimateToolDefsTokens(tc.defs); got < tc.minTokens {
				t.Errorf("estimateToolDefsTokens() = %d, want >= %d", got, tc.minTokens)
			}
		})
	}
}
// TestEstimateToolDefsTokens_ScalesWithCount checks that the estimate grows
// with the number of tools: three similar tools must cost more than one.
func TestEstimateToolDefsTokens_ScalesWithCount(t *testing.T) {
	// newTool returns a tool definition that differs from its siblings only
	// by name.
	newTool := func(name string) providers.ToolDefinition {
		schema := map[string]any{
			"type": "object",
			"properties": map[string]any{
				"input": map[string]any{"type": "string", "description": "Input value"},
			},
		}
		return providers.ToolDefinition{
			Type: "function",
			Function: providers.ToolFunctionDefinition{
				Name:        name,
				Description: "A test tool that does something useful",
				Parameters:  schema,
			},
		}
	}

	single := []providers.ToolDefinition{newTool("tool_a")}
	triple := []providers.ToolDefinition{newTool("tool_a"), newTool("tool_b"), newTool("tool_c")}

	one := estimateToolDefsTokens(single)
	three := estimateToolDefsTokens(triple)
	if three <= one {
		t.Errorf("3 tools (%d tokens) should exceed 1 tool (%d tokens)", three, one)
	}
}
// --- isOverContextBudget tests ---
// TestIsOverContextBudget runs the budget check against a small fixed
// conversation (one 1000-character system message plus three short messages)
// and one tool, varying the context window and the output reserve.
func TestIsOverContextBudget(t *testing.T) {
	smallHistory := []providers.Message{
		{Role: "system", Content: strings.Repeat("x", 1000)},
		msgUser("q1"),
		msgAssistant("a1"),
		msgUser("hello"),
	}
	tools := []providers.ToolDefinition{
		{
			Type: "function",
			Function: providers.ToolFunctionDefinition{
				Name:        "test_tool",
				Description: "A test tool",
				Parameters:  map[string]any{"type": "object"},
			},
		},
	}

	type testCase struct {
		name          string
		contextWindow int
		messages      []providers.Message
		toolDefs      []providers.ToolDefinition
		maxTokens     int
		want          bool
	}

	cases := []testCase{
		{
			name:          "within budget",
			contextWindow: 100000,
			messages:      smallHistory,
			toolDefs:      tools,
			maxTokens:     4096,
			want:          false,
		},
		{
			name:          "over budget with small window",
			contextWindow: 100, // far smaller than the output reserve alone
			messages:      smallHistory,
			toolDefs:      tools,
			maxTokens:     4096,
			want:          true,
		},
		{
			name:          "large max_tokens eats budget",
			contextWindow: 2000,
			messages:      smallHistory,
			toolDefs:      tools,
			maxTokens:     1800, // only 200 tokens left for the request itself
			want:          true,
		},
		{
			name:          "empty messages within budget",
			contextWindow: 10000,
			messages:      nil,
			toolDefs:      nil,
			maxTokens:     4096,
			want:          false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := isOverContextBudget(tc.contextWindow, tc.messages, tc.toolDefs, tc.maxTokens)
			if got != tc.want {
				t.Errorf("isOverContextBudget() = %v, want %v", got, tc.want)
			}
		})
	}
}
// --- Tests reflecting actual session data shape ---
// Session history never contains system messages. The system prompt is
// built dynamically by BuildMessages. These tests use realistic history
// shapes: user/assistant/tool only, with tool chains and reasoning content.
// TestFindSafeBoundary_SessionHistoryNoSystem uses the shape of a stored
// session: it opens with a user message and has no system message. A target
// on a tool result must snap back to the user message that opened that Turn.
func TestFindSafeBoundary_SessionHistoryNoSystem(t *testing.T) {
	const midpoint = 4 // lands on the tool result

	session := []providers.Message{
		msgUser("hello"),               // 0
		msgAssistant("hi there"),       // 1
		msgUser("search for X"),        // 2 ← Turn start owning the tool call
		msgAssistantTC("tc1"),          // 3
		msgTool("tc1", "found X"),      // 4
		msgAssistant("here is X"),      // 5
		msgUser("thanks"),              // 6
		msgAssistant("you're welcome"), // 7
	}

	if got := findSafeBoundary(session, midpoint); got != 2 {
		t.Errorf("findSafeBoundary(session_history, 4) = %d, want 2", got)
	}
}
// TestFindSafeBoundary_SessionWithChainedTools targets the inside of a
// chained tool sequence (save, then notify) that belongs to the very first
// Turn. The only Turn start at or before the target is index 0, which is
// rejected because cutting there keeps everything; the result is therefore
// the first Turn start after the target.
func TestFindSafeBoundary_SessionWithChainedTools(t *testing.T) {
	const target = 3 // second assistant+ToolCalls message, mid-chain

	session := []providers.Message{
		msgUser("save and notify"),       // 0 ← preceding boundary, unusable
		msgAssistantTC("tc_save"),        // 1
		msgTool("tc_save", "saved"),      // 2
		msgAssistantTC("tc_notify"),      // 3
		msgTool("tc_notify", "notified"), // 4
		msgAssistant("done"),             // 5
		msgUser("check status"),          // 6 ← first boundary after target
		msgAssistant("all good"),         // 7
	}

	if got := findSafeBoundary(session, target); got != 6 {
		t.Errorf("findSafeBoundary(chained_tools, 3) = %d, want 6", got)
	}
}
// TestEstimateMessageTokens_WithReasoningAndMedia estimates a message that
// has content, reasoning, a tool call and a media item all populated, and
// checks that the reasoning and the media each contribute to the total.
//
// The message previously carried no Media despite the test's name, so the
// media contribution was never exercised here.
func TestEstimateMessageTokens_WithReasoningAndMedia(t *testing.T) {
	msg := providers.Message{
		Role:             "assistant",
		Content:          "Here is the analysis.",
		ReasoningContent: strings.Repeat("Let me think about this carefully. ", 50),
		ToolCalls: []providers.ToolCall{
			{
				ID:   "call_1",
				Type: "function",
				Name: "analyze",
				Function: &providers.FunctionCall{
					Name:      "analyze",
					Arguments: `{"data":"sample","depth":3}`,
				},
			},
		},
		Media: []string{"media://chart.png"},
	}

	tokens := estimateMessageTokens(msg)

	// ReasoningContent alone is 1750 chars → 700 tokens at 2.5 chars/token;
	// content, tool call, overhead and media only add to that.
	if tokens < 500 {
		t.Errorf("message with reasoning+toolcalls should have significant tokens, got %d", tokens)
	}

	// Dropping the reasoning must lower the estimate.
	msgNoReasoning := msg
	msgNoReasoning.ReasoningContent = ""
	tokensNoReasoning := estimateMessageTokens(msgNoReasoning)
	if tokens <= tokensNoReasoning {
		t.Errorf("reasoning content should add tokens: with=%d, without=%d", tokens, tokensNoReasoning)
	}

	// Dropping the media must lower the estimate as well.
	msgNoMedia := msg
	msgNoMedia.Media = nil
	tokensNoMedia := estimateMessageTokens(msgNoMedia)
	if tokens <= tokensNoMedia {
		t.Errorf("media items should add tokens: with=%d, without=%d", tokens, tokensNoMedia)
	}
}
// TestIsOverContextBudget_RealisticSession assembles a request in the order
// system prompt → stored session history → current user message (the system
// message is not part of the stored session) and checks the budget verdict
// for a generous and for a tiny context window.
func TestIsOverContextBudget_RealisticSession(t *testing.T) {
	toolCallMsg := providers.Message{
		Role:    "assistant",
		Content: "I'll use tool X",
		ToolCalls: []providers.ToolCall{
			{
				ID: "tc1", Type: "function", Name: "tool_x",
				Function: &providers.FunctionCall{
					Name:      "tool_x",
					Arguments: `{"query":"test","verbose":true}`,
				},
			},
		},
	}
	toolResultMsg := providers.Message{
		Role:       "tool",
		Content:    strings.Repeat("result data ", 200),
		ToolCallID: "tc1",
	}

	// system + session history + current user, in request order.
	messages := []providers.Message{
		{Role: "system", Content: strings.Repeat("system prompt content ", 100)},
		msgUser("first question"),
		msgAssistant("first answer"),
		msgUser("use tool X"),
		toolCallMsg,
		toolResultMsg,
		msgAssistant("Here are the results from tool X."),
		msgUser("follow up question"),
	}

	tools := []providers.ToolDefinition{
		{
			Type: "function",
			Function: providers.ToolFunctionDefinition{
				Name:        "tool_x",
				Description: "A useful tool",
				Parameters:  map[string]any{"type": "object"},
			},
		},
	}

	const outputReserve = 32768

	// Plenty of room in a large window.
	if isOverContextBudget(131072, messages, tools, outputReserve) {
		t.Error("realistic session should be within 131072 context window")
	}

	// A tiny window cannot even hold the output reserve.
	if !isOverContextBudget(500, messages, tools, outputReserve) {
		t.Error("realistic session should exceed 500 context window")
	}
}
+10 -10
View File
@@ -37,7 +37,7 @@ func setupWorkspace(t *testing.T, files map[string]string) string {
// Codex (only reads last system message as instructions).
func TestSingleSystemMessage(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nTest agent.",
"AGENT.md": "# Agent\nTest agent.",
})
defer os.RemoveAll(tmpDir)
@@ -202,10 +202,10 @@ func TestMtimeAutoInvalidation(t *testing.T) {
}{
{
name: "bootstrap file change",
file: "IDENTITY.md",
contentV1: "# Original Identity",
contentV2: "# Updated Identity",
checkField: "Updated Identity",
file: "AGENT.md",
contentV1: "# Original Agent",
contentV2: "# Updated Agent",
checkField: "Updated Agent",
},
{
name: "memory file change",
@@ -280,7 +280,7 @@ func TestMtimeAutoInvalidation(t *testing.T) {
// even when source files haven't changed (useful for tests and reload commands).
func TestExplicitInvalidateCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Test Identity",
"AGENT.md": "# Test Agent",
})
defer os.RemoveAll(tmpDir)
@@ -307,8 +307,8 @@ func TestExplicitInvalidateCache(t *testing.T) {
// when no files change (regression test for issue #607).
func TestCacheStability(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nContent",
"SOUL.md": "# Soul\nContent",
"AGENT.md": "# Agent\nContent",
"SOUL.md": "# Soul\nContent",
})
defer os.RemoveAll(tmpDir)
@@ -607,7 +607,7 @@ description: delete-me-v1
// Run with: go test -race ./pkg/agent/ -run TestConcurrentBuildSystemPromptWithCache
func TestConcurrentBuildSystemPromptWithCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"IDENTITY.md": "# Identity\nConcurrency test agent.",
"AGENT.md": "# Agent\nConcurrency test agent.",
"SOUL.md": "# Soul\nBe helpful.",
"memory/MEMORY.md": "# Memory\nUser prefers Go.",
"skills/demo/SKILL.md": "---\nname: demo\ndescription: \"demo skill\"\n---\n# Demo",
@@ -714,7 +714,7 @@ func BenchmarkBuildMessagesWithCache(b *testing.B) {
os.MkdirAll(filepath.Join(tmpDir, "memory"), 0o755)
os.MkdirAll(filepath.Join(tmpDir, "skills"), 0o755)
for _, name := range []string{"IDENTITY.md", "SOUL.md", "USER.md"} {
for _, name := range []string{"AGENT.md", "SOUL.md"} {
os.WriteFile(filepath.Join(tmpDir, name), []byte(strings.Repeat("Content.\n", 10)), 0o644)
}
+255
View File
@@ -0,0 +1,255 @@
package agent
import (
"os"
"path/filepath"
"slices"
"strings"
"github.com/gomarkdown/markdown/parser"
"gopkg.in/yaml.v3"
"github.com/sipeed/picoclaw/pkg/logger"
)
// AgentDefinitionSource identifies which agent bootstrap file produced the definition.
//
// The constant values double as file names: the loader joins them onto the
// workspace directory to locate the file. The zero value ("") means neither
// bootstrap file was read.
type AgentDefinitionSource string
const (
// AgentDefinitionSourceAgent indicates the new AGENT.md format.
AgentDefinitionSourceAgent AgentDefinitionSource = "AGENT.md"
// AgentDefinitionSourceAgents indicates the legacy AGENTS.md format.
AgentDefinitionSourceAgents AgentDefinitionSource = "AGENTS.md"
)
// AgentFrontmatter holds machine-readable AGENT.md configuration.
//
// Known fields are exposed directly for convenience. Fields keeps the full
// parsed frontmatter so future refactors can read additional keys without
// changing the loader contract again.
//
// NOTE(review): only json tags are declared here; how the frontmatter text is
// decoded into this struct is not visible in this part of the file — confirm
// the key names against the parser.
type AgentFrontmatter struct {
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools,omitempty"`
Model string `json:"model,omitempty"`
// Pointer, presumably so an absent maxTurns can be told apart from an
// explicit value — TODO confirm.
MaxTurns *int `json:"maxTurns,omitempty"`
Skills []string `json:"skills,omitempty"`
MCPServers []string `json:"mcpServers,omitempty"`
// Full frontmatter as key/value pairs, including keys without a dedicated field.
Fields map[string]any `json:"fields,omitempty"`
}
// AgentPromptDefinition represents the parsed AGENT.md or AGENTS.md prompt file.
//
// For the legacy AGENTS.md file the loader sets Raw and Body to the whole
// file content and leaves RawFrontmatter and Frontmatter at their zero
// values. For AGENT.md the value comes from parseAgentPromptDefinition.
type AgentPromptDefinition struct {
// Path is the location of the file that was read (workspace joined with the file name).
Path string `json:"path"`
// Raw is the file content as read from disk.
Raw string `json:"raw"`
// Body is the prompt text; presumably the content with the frontmatter
// removed for AGENT.md — TODO confirm against parseAgentPromptDefinition.
Body string `json:"body"`
RawFrontmatter string `json:"raw_frontmatter,omitempty"`
Frontmatter AgentFrontmatter `json:"frontmatter"`
}
// SoulDefinition represents the resolved SOUL.md file linked to the agent.
type SoulDefinition struct {
// Path is the workspace directory joined with "SOUL.md".
Path string `json:"path"`
// Content is the complete file content, unmodified.
Content string `json:"content"`
}
// UserDefinition represents the resolved USER.md file linked to the workspace.
type UserDefinition struct {
// Path is the workspace directory joined with "USER.md".
Path string `json:"path"`
// Content is the complete file content, unmodified.
Content string `json:"content"`
}
// AgentContextDefinition captures the workspace agent definition in a runtime-friendly shape.
//
// Every pointer field is nil when the corresponding file could not be read.
type AgentContextDefinition struct {
// Source names the file Agent was loaded from; empty when neither AGENT.md
// nor AGENTS.md was read.
Source AgentDefinitionSource `json:"source,omitempty"`
// Agent is the prompt loaded from AGENT.md, or from AGENTS.md as a fallback.
Agent *AgentPromptDefinition `json:"agent,omitempty"`
// Soul is the content of the workspace's SOUL.md.
Soul *SoulDefinition `json:"soul,omitempty"`
// User is the content of the workspace's USER.md.
User *UserDefinition `json:"user,omitempty"`
}
// LoadAgentDefinition parses the workspace agent bootstrap files.
//
// It prefers the new AGENT.md format and its paired SOUL.md file. When the
// structured files are absent, it falls back to the legacy AGENTS.md layout so
// the current runtime can transition incrementally.
//
// USER.md is loaded in either case. The files are read from the builder's
// workspace on every call; a file that cannot be read simply leaves the
// corresponding field of the result nil.
func (cb *ContextBuilder) LoadAgentDefinition() AgentContextDefinition {
return loadAgentDefinition(cb.workspace)
}
// loadAgentDefinition builds the AgentContextDefinition for workspace.
//
// Resolution order:
//  1. USER.md is always attempted first.
//  2. If AGENT.md is readable it is parsed (frontmatter + body) and SOUL.md is
//     attached when readable; the function returns immediately.
//  3. Otherwise a readable AGENTS.md is taken verbatim as the legacy prompt.
//  4. SOUL.md is attached when readable.
//
// Unreadable files are silently treated as absent.
func loadAgentDefinition(workspace string) AgentContextDefinition {
	result := AgentContextDefinition{}
	result.User = loadUserDefinition(workspace)

	structuredFile := filepath.Join(workspace, string(AgentDefinitionSourceAgent))
	if raw, readErr := os.ReadFile(structuredFile); readErr == nil {
		parsed := parseAgentPromptDefinition(structuredFile, string(raw))
		result.Source = AgentDefinitionSourceAgent
		result.Agent = &parsed

		soulFile := filepath.Join(workspace, "SOUL.md")
		if soul, soulErr := os.ReadFile(soulFile); soulErr == nil {
			result.Soul = &SoulDefinition{Path: soulFile, Content: string(soul)}
		}
		return result
	}

	legacyFile := filepath.Join(workspace, string(AgentDefinitionSourceAgents))
	if raw, readErr := os.ReadFile(legacyFile); readErr == nil {
		text := string(raw)
		result.Source = AgentDefinitionSourceAgents
		result.Agent = &AgentPromptDefinition{Path: legacyFile, Raw: text, Body: text}
	}

	soulFile := filepath.Join(workspace, "SOUL.md")
	// Without an agent file, only bother reading SOUL.md when it exists.
	if result.Source == "" && !fileExists(soulFile) {
		return result
	}
	if soul, soulErr := os.ReadFile(soulFile); soulErr == nil {
		result.Soul = &SoulDefinition{Path: soulFile, Content: string(soul)}
	}
	return result
}
// trackedPaths lists the workspace files whose changes should be watched for
// this definition. AGENT.md, SOUL.md and USER.md are always included; the
// legacy AGENTS.md and IDENTITY.md are added unless the definition came from
// the structured AGENT.md file. The result is cleaned and de-duplicated.
func (definition AgentContextDefinition) trackedPaths(workspace string) []string {
	inWorkspace := func(name string) string {
		return filepath.Join(workspace, name)
	}

	tracked := []string{
		inWorkspace(string(AgentDefinitionSourceAgent)),
		inWorkspace("SOUL.md"),
		inWorkspace("USER.md"),
	}
	if definition.Source == AgentDefinitionSourceAgent {
		return uniquePaths(tracked)
	}

	tracked = append(
		tracked,
		inWorkspace(string(AgentDefinitionSourceAgents)),
		inWorkspace("IDENTITY.md"),
	)
	return uniquePaths(tracked)
}
// loadUserDefinition reads USER.md from the workspace root. It returns nil
// when the file cannot be read for any reason.
func loadUserDefinition(workspace string) *UserDefinition {
	location := filepath.Join(workspace, "USER.md")
	data, err := os.ReadFile(location)
	if err != nil {
		return nil
	}
	return &UserDefinition{Path: location, Content: string(data)}
}
// parseAgentPromptDefinition splits content into frontmatter and body and
// decodes the frontmatter. path is recorded on the result and passed on to
// the frontmatter decoder.
func parseAgentPromptDefinition(path, content string) AgentPromptDefinition {
	rawFrontmatter, promptBody := splitAgentFrontmatter(content)

	parsed := AgentPromptDefinition{Path: path, Raw: content}
	parsed.Body = promptBody
	parsed.RawFrontmatter = rawFrontmatter
	parsed.Frontmatter = parseAgentFrontmatter(path, rawFrontmatter)
	return parsed
}
// parseAgentFrontmatter decodes YAML frontmatter into an AgentFrontmatter.
//
// The text is decoded twice: once into a generic map (kept as Fields) and once
// into the known typed keys. Blank input yields the zero value. If either
// decode fails, a warning carrying path and the error is logged and the zero
// value is returned, so a malformed header never yields partial settings.
func parseAgentFrontmatter(path, frontmatter string) AgentFrontmatter {
	source := strings.TrimSpace(frontmatter)
	if source == "" {
		return AgentFrontmatter{}
	}

	generic := make(map[string]any)
	if err := yaml.Unmarshal([]byte(source), &generic); err != nil {
		logger.WarnCF("agent", "Failed to parse AGENT.md frontmatter", map[string]any{
			"path":  path,
			"error": err.Error(),
		})
		return AgentFrontmatter{}
	}

	var known struct {
		Name        string   `yaml:"name"`
		Description string   `yaml:"description"`
		Tools       []string `yaml:"tools"`
		Model       string   `yaml:"model"`
		MaxTurns    *int     `yaml:"maxTurns"`
		Skills      []string `yaml:"skills"`
		MCPServers  []string `yaml:"mcpServers"`
	}
	if err := yaml.Unmarshal([]byte(source), &known); err != nil {
		logger.WarnCF("agent", "Failed to decode AGENT.md frontmatter fields", map[string]any{
			"path":  path,
			"error": err.Error(),
		})
		return AgentFrontmatter{}
	}

	// Detach the lists from the decoder-owned slices.
	detach := func(items []string) []string {
		return append([]string(nil), items...)
	}

	var result AgentFrontmatter
	result.Name = strings.TrimSpace(known.Name)
	result.Description = strings.TrimSpace(known.Description)
	result.Tools = detach(known.Tools)
	result.Model = strings.TrimSpace(known.Model)
	result.MaxTurns = known.MaxTurns
	result.Skills = detach(known.Skills)
	result.MCPServers = detach(known.MCPServers)
	result.Fields = generic
	return result
}
// splitAgentFrontmatter separates a leading frontmatter section from the
// remaining body.
//
// Frontmatter is recognised only when the very first line is exactly "---"
// and a later line is exactly "---". In that case the lines in between are
// returned as frontmatter and everything after the closing delimiter, minus
// leading newlines, as body (both with normalized newlines). Otherwise the
// frontmatter is empty and body is the untouched input.
func splitAgentFrontmatter(content string) (frontmatter, body string) {
	const delimiter = "---"

	rows := strings.Split(string(parser.NormalizeNewlines([]byte(content))), "\n")
	if len(rows) == 0 || rows[0] != delimiter {
		return "", content
	}

	for offset, row := range rows[1:] {
		if row != delimiter {
			continue
		}
		closing := offset + 1 // index of the closing delimiter within rows
		frontmatter = strings.Join(rows[1:closing], "\n")
		body = strings.TrimLeft(strings.Join(rows[closing+1:], "\n"), "\n")
		return frontmatter, body
	}

	// Opening delimiter without a closing one: treat everything as body.
	return "", content
}
// relativeWorkspacePath returns path as a slash-separated path relative to
// workspace, for use as a display label.
//
// A blank path yields "". When path cannot be made relative, is the workspace
// itself, or lies outside the workspace, the cleaned original path is
// returned instead.
func relativeWorkspacePath(workspace, path string) string {
	if strings.TrimSpace(path) == "" {
		return ""
	}

	rel, err := filepath.Rel(workspace, path)
	if err != nil || rel == "." {
		return filepath.Clean(path)
	}

	// Only a leading ".." path element means the target escapes the
	// workspace. A plain prefix test would also reject legitimate names that
	// merely start with two dots, such as "..hidden/AGENT.md".
	if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return filepath.Clean(path)
	}
	return filepath.ToSlash(rel)
}
// uniquePaths cleans every path with filepath.Clean and returns the distinct
// results in first-seen order. Blank entries are dropped. The returned slice
// is always non-nil.
func uniquePaths(paths []string) []string {
	deduped := make([]string, 0, len(paths))
	for i := range paths {
		candidate := paths[i]
		if strings.TrimSpace(candidate) != "" {
			if normalized := filepath.Clean(candidate); !slices.Contains(deduped, normalized) {
				deduped = append(deduped, normalized)
			}
		}
	}
	return deduped
}
// fileExists reports whether os.Stat succeeds for path. Any stat error, not
// only "not found", counts as the path not existing.
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
+302
View File
@@ -0,0 +1,302 @@
package agent
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestLoadAgentDefinitionParsesFrontmatterAndSoul checks the structured path:
// an AGENT.md with YAML frontmatter plus a SOUL.md next to it. It expects the
// typed frontmatter keys, the untyped "metadata" key, the prompt body and the
// SOUL.md content and path to all be populated.
func TestLoadAgentDefinitionParsesFrontmatterAndSoul(t *testing.T) {
// setupWorkspace is defined elsewhere; presumably it writes these files into
// a fresh temporary directory and returns its path.
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": `---
name: pico
description: Structured agent
model: claude-3-7-sonnet
tools:
- shell
- search
maxTurns: 8
skills:
- review
- search-docs
mcpServers:
- github
metadata:
mode: strict
---
# Agent
Act directly and use tools first.
`,
"SOUL.md": "# Soul\nStay precise.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
// AGENT.md must be chosen as the source.
if definition.Source != AgentDefinitionSourceAgent {
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgent, definition.Source)
}
if definition.Agent == nil {
t.Fatal("expected AGENT.md definition to be loaded")
}
// The body is the markdown after the closing frontmatter delimiter.
if definition.Agent.Body == "" || !strings.Contains(definition.Agent.Body, "Act directly") {
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
}
// Typed frontmatter keys.
if definition.Agent.Frontmatter.Name != "pico" {
t.Fatalf("expected name to be parsed, got %q", definition.Agent.Frontmatter.Name)
}
if definition.Agent.Frontmatter.Model != "claude-3-7-sonnet" {
t.Fatalf("expected model to be parsed, got %q", definition.Agent.Frontmatter.Model)
}
if len(definition.Agent.Frontmatter.Tools) != 2 {
t.Fatalf("expected tools to be parsed, got %v", definition.Agent.Frontmatter.Tools)
}
if definition.Agent.Frontmatter.MaxTurns == nil || *definition.Agent.Frontmatter.MaxTurns != 8 {
t.Fatalf("expected maxTurns to be parsed, got %v", definition.Agent.Frontmatter.MaxTurns)
}
if len(definition.Agent.Frontmatter.Skills) != 2 {
t.Fatalf("expected skills to be parsed, got %v", definition.Agent.Frontmatter.Skills)
}
if len(definition.Agent.Frontmatter.MCPServers) != 1 || definition.Agent.Frontmatter.MCPServers[0] != "github" {
t.Fatalf("expected mcpServers to be parsed, got %v", definition.Agent.Frontmatter.MCPServers)
}
// Keys without a dedicated struct field must survive in Fields.
if definition.Agent.Frontmatter.Fields["metadata"] == nil {
t.Fatal("expected arbitrary frontmatter fields to remain available")
}
// SOUL.md is resolved from the workspace root.
if definition.Soul == nil {
t.Fatal("expected SOUL.md to be loaded")
}
if !strings.Contains(definition.Soul.Content, "Stay precise") {
t.Fatalf("expected soul content to be loaded, got %q", definition.Soul.Content)
}
if definition.Soul.Path != filepath.Join(tmpDir, "SOUL.md") {
t.Fatalf("expected default SOUL.md path, got %q", definition.Soul.Path)
}
}
// TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown checks that a
// workspace with only AGENTS.md (no AGENT.md) is loaded through the legacy
// path: source is AGENTS.md, no frontmatter is extracted, the body is the
// file content, and SOUL.md is still attached.
func TestLoadAgentDefinitionFallsBackToLegacyAgentsMarkdown(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENTS.md": "# Legacy Agent\nKeep compatibility.",
"SOUL.md": "# Soul\nLegacy soul.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.Source != AgentDefinitionSourceAgents {
t.Fatalf("expected source %q, got %q", AgentDefinitionSourceAgents, definition.Source)
}
if definition.Agent == nil {
t.Fatal("expected AGENTS.md to be loaded")
}
// Legacy files are taken verbatim; no frontmatter parsing happens.
if definition.Agent.RawFrontmatter != "" {
t.Fatalf("legacy AGENTS.md should not have frontmatter, got %q", definition.Agent.RawFrontmatter)
}
if !strings.Contains(definition.Agent.Body, "Keep compatibility") {
t.Fatalf("expected legacy body to be preserved, got %q", definition.Agent.Body)
}
if definition.Soul == nil || !strings.Contains(definition.Soul.Content, "Legacy soul") {
t.Fatal("expected default SOUL.md to be loaded for legacy format")
}
}
// TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown checks that USER.md in the
// workspace root is loaded alongside a structured AGENT.md, with the expected
// path and content.
func TestLoadAgentDefinitionLoadsWorkspaceUserMarkdown(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nStructured agent.",
"USER.md": "# User\nWorkspace preferences.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.User == nil {
t.Fatal("expected USER.md to be loaded")
}
if definition.User.Path != filepath.Join(tmpDir, "USER.md") {
t.Fatalf("expected workspace USER.md path, got %q", definition.User.Path)
}
if !strings.Contains(definition.User.Content, "Workspace preferences") {
t.Fatalf("expected workspace USER.md content, got %q", definition.User.Content)
}
}
// TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields
// checks that malformed YAML frontmatter does not prevent AGENT.md from
// loading: the body is still returned, while every structured frontmatter
// field (including Fields) is left at its zero value.
func TestLoadAgentDefinitionInvalidFrontmatterFallsBackToEmptyStructuredFields(t *testing.T) {
// The bare "broken" line makes the frontmatter invalid YAML.
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": `---
name: pico
tools:
- shell
broken
---
# Agent
Keep going.
`,
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
definition := cb.LoadAgentDefinition()
if definition.Agent == nil {
t.Fatal("expected AGENT.md definition to be loaded")
}
if !strings.Contains(definition.Agent.Body, "Keep going.") {
t.Fatalf("expected AGENT.md body to be preserved, got %q", definition.Agent.Body)
}
// No partially decoded values may leak through, not even "name".
if definition.Agent.Frontmatter.Name != "" ||
definition.Agent.Frontmatter.Description != "" ||
definition.Agent.Frontmatter.Model != "" ||
definition.Agent.Frontmatter.MaxTurns != nil ||
len(definition.Agent.Frontmatter.Tools) != 0 ||
len(definition.Agent.Frontmatter.Skills) != 0 ||
len(definition.Agent.Frontmatter.MCPServers) != 0 ||
len(definition.Agent.Frontmatter.Fields) != 0 {
t.Fatalf("expected invalid frontmatter to decode as empty struct, got %+v", definition.Agent.Frontmatter)
}
}
// TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter checks the bootstrap text
// built for a structured workspace: it must contain the AGENT.md body and the
// SOUL.md content (with a SOUL.md label), must not leak raw frontmatter lines,
// and must ignore IDENTITY.md.
func TestLoadBootstrapFilesUsesAgentBodyNotFrontmatter(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": `---
name: pico
model: codex-mini
---
# Agent
Follow the body prompt.
`,
"SOUL.md": "# Soul\nSpeak plainly.",
"IDENTITY.md": "# Identity\nWorkspace identity.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
bootstrap := cb.LoadBootstrapFiles()
if !strings.Contains(bootstrap, "Follow the body prompt") {
t.Fatalf("expected AGENT.md body in bootstrap, got %q", bootstrap)
}
if !strings.Contains(bootstrap, "Speak plainly") {
t.Fatalf("expected resolved soul content in bootstrap, got %q", bootstrap)
}
// Frontmatter is machine configuration and must stay out of the prompt.
if strings.Contains(bootstrap, "name: pico") {
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
}
if strings.Contains(bootstrap, "model: codex-mini") {
t.Fatalf("bootstrap should not expose raw frontmatter, got %q", bootstrap)
}
if !strings.Contains(bootstrap, "SOUL.md") {
t.Fatalf("expected bootstrap to label SOUL.md, got %q", bootstrap)
}
// IDENTITY.md belongs to the legacy layout only.
if strings.Contains(bootstrap, "Workspace identity") {
t.Fatalf("structured bootstrap should ignore IDENTITY.md, got %q", bootstrap)
}
}
// TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown checks that the
// bootstrap text contains the USER.md content under a "## USER.md" heading.
func TestLoadBootstrapFilesIncludesWorkspaceUserMarkdown(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nFollow the new structure.",
"SOUL.md": "# Soul\nSpeak plainly.",
"USER.md": "# User\nShared profile.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
bootstrap := cb.LoadBootstrapFiles()
if !strings.Contains(bootstrap, "Shared profile") {
t.Fatalf("expected workspace USER.md in bootstrap, got %q", bootstrap)
}
if !strings.Contains(bootstrap, "## USER.md") {
t.Fatalf("expected USER.md heading in bootstrap, got %q", bootstrap)
}
}
// TestStructuredAgentIgnoresIdentityChanges checks that, for a structured
// AGENT.md workspace, IDENTITY.md neither appears in the system prompt nor
// invalidates the cached prompt when it is modified.
func TestStructuredAgentIgnoresIdentityChanges(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nFollow the new structure.",
"SOUL.md": "# Soul\nVersion one.",
"IDENTITY.md": "# Identity\nLegacy identity.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
// First build populates the cache.
promptV1 := cb.BuildSystemPromptWithCache()
if strings.Contains(promptV1, "Legacy identity") {
t.Fatalf("structured prompt should not include IDENTITY.md, got %q", promptV1)
}
identityPath := filepath.Join(tmpDir, "IDENTITY.md")
if err := os.WriteFile(identityPath, []byte("# Identity\nVersion two."), 0o644); err != nil {
t.Fatal(err)
}
// Push the modification time into the future so the rewrite cannot share a
// timestamp with the cache; presumably change detection is mtime-based.
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(identityPath, future, future); err != nil {
t.Fatal(err)
}
// The change check is called with the prompt mutex held for reading.
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if changed {
t.Fatal("IDENTITY.md should not invalidate cache for structured agent definitions")
}
promptV2 := cb.BuildSystemPromptWithCache()
if promptV1 != promptV2 {
t.Fatal("structured prompt should remain stable after IDENTITY.md changes")
}
}
// TestStructuredAgentUserChangesInvalidateCache checks that modifying USER.md
// in a structured workspace is detected as a source change and that the next
// cached prompt build picks up the new content.
func TestStructuredAgentUserChangesInvalidateCache(t *testing.T) {
tmpDir := setupWorkspace(t, map[string]string{
"AGENT.md": "# Agent\nFollow the new structure.",
"SOUL.md": "# Soul\nVersion one.",
"USER.md": "# User\nInitial workspace preferences.",
})
defer cleanupWorkspace(t, tmpDir)
cb := NewContextBuilder(tmpDir)
// First build populates the cache.
promptV1 := cb.BuildSystemPromptWithCache()
if !strings.Contains(promptV1, "Initial workspace preferences") {
t.Fatalf("expected workspace USER.md in prompt, got %q", promptV1)
}
userPath := filepath.Join(tmpDir, "USER.md")
if err := os.WriteFile(userPath, []byte("# User\nUpdated workspace preferences."), 0o644); err != nil {
t.Fatal(err)
}
// Push the modification time into the future so the rewrite cannot share a
// timestamp with the cache; presumably change detection is mtime-based.
future := time.Now().Add(2 * time.Second)
if err := os.Chtimes(userPath, future, future); err != nil {
t.Fatal(err)
}
// The change check is called with the prompt mutex held for reading.
cb.systemPromptMutex.RLock()
changed := cb.sourceFilesChangedLocked()
cb.systemPromptMutex.RUnlock()
if !changed {
t.Fatal("workspace USER.md changes should invalidate cache")
}
promptV2 := cb.BuildSystemPromptWithCache()
if !strings.Contains(promptV2, "Updated workspace preferences") {
t.Fatalf("expected updated workspace USER.md in prompt, got %q", promptV2)
}
}
// cleanupWorkspace removes the test workspace at path together with all of
// its contents and fails the test if removal reports an error.
func cleanupWorkspace(t *testing.T, path string) {
	t.Helper()

	removeErr := os.RemoveAll(path)
	if removeErr == nil {
		return
	}
	t.Fatalf("failed to clean up workspace %s: %v", path, removeErr)
}
+121
View File
@@ -0,0 +1,121 @@
package agent
import (
"sync"
"sync/atomic"
"time"
)
// defaultEventSubscriberBuffer is the channel capacity Subscribe uses when it
// is called with a non-positive buffer size.
const defaultEventSubscriberBuffer = 16
// EventSubscription identifies a subscriber channel returned by EventBus.Subscribe.
type EventSubscription struct {
// ID is the handle to pass to EventBus.Unsubscribe. It is left at zero when
// the subscription was requested from an already closed bus.
ID uint64
// C delivers emitted events. It is closed by Unsubscribe or Close, and is
// already closed when the bus was closed at subscription time.
C <-chan Event
}
// eventSubscriber is the bus-side record of one subscription.
type eventSubscriber struct {
// ch is the sending side of the channel exposed as EventSubscription.C.
ch chan Event
}
// EventBus is a lightweight multi-subscriber broadcaster for agent-loop events.
//
// Create instances with NewEventBus so that the subscriber map is initialized.
type EventBus struct {
// mu guards subs, nextID and closed. Emit takes it for reading; Subscribe,
// Unsubscribe and Close take it for writing.
mu sync.RWMutex
// subs holds the active subscribers keyed by subscription ID.
subs map[uint64]eventSubscriber
// nextID is the most recently issued subscription ID; the first issued ID is 1.
nextID uint64
// closed is set by Close; afterwards Emit is a no-op and Subscribe returns
// an already closed channel.
closed bool
// dropped counts, per event kind, deliveries skipped because a subscriber
// channel was full.
dropped [eventKindCount]atomic.Int64
}
// NewEventBus returns an open, in-process event broadcaster with no
// subscribers.
func NewEventBus() *EventBus {
	eb := new(EventBus)
	eb.subs = make(map[uint64]eventSubscriber)
	return eb
}
// Subscribe adds a subscriber whose channel can buffer up to buffer events;
// values below one select the default capacity.
//
// On a closed bus nothing is registered: the returned subscription has a zero
// ID and a channel that is already closed.
func (b *EventBus) Subscribe(buffer int) EventSubscription {
	capacity := buffer
	if capacity < 1 {
		capacity = defaultEventSubscriberBuffer
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	if !b.closed {
		b.nextID++
		stream := make(chan Event, capacity)
		b.subs[b.nextID] = eventSubscriber{ch: stream}
		return EventSubscription{ID: b.nextID, C: stream}
	}

	// Hand back a closed channel so that receivers return immediately.
	finished := make(chan Event)
	close(finished)
	return EventSubscription{C: finished}
}
// Unsubscribe drops the subscriber registered under id and closes its
// channel. Unknown IDs (including ones already unsubscribed) are ignored.
func (b *EventBus) Unsubscribe(id uint64) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if entry, found := b.subs[id]; found {
		delete(b.subs, id)
		close(entry.ch)
	}
}
// Emit delivers evt to every current subscriber and never blocks: a
// subscriber whose channel is full simply misses the event, and the miss is
// added to the per-kind drop counter when the kind is within range.
//
// Events without a timestamp are stamped with the current time. Emitting on
// a closed bus does nothing.
func (b *EventBus) Emit(evt Event) {
	if evt.Time.IsZero() {
		evt.Time = time.Now()
	}

	b.mu.RLock()
	defer b.mu.RUnlock()

	if b.closed {
		return
	}

	countable := evt.Kind < eventKindCount
	for _, receiver := range b.subs {
		select {
		case receiver.ch <- evt:
			continue
		default:
		}
		// Channel was full: record the drop.
		if countable {
			b.dropped[evt.Kind].Add(1)
		}
	}
}
// Dropped reports how many deliveries of the given kind were skipped because
// a subscriber channel was full. Kinds outside the counted range report zero.
func (b *EventBus) Dropped(kind EventKind) int64 {
	if kind < eventKindCount {
		return b.dropped[kind].Load()
	}
	return 0
}
// Close shuts the bus down: every subscriber channel is closed, all
// subscribers are forgotten, and later Emit calls become no-ops. Calling
// Close again has no effect.
func (b *EventBus) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if !b.closed {
		b.closed = true
		for key := range b.subs {
			close(b.subs[key].ch)
			delete(b.subs, key)
		}
	}
}
+684
View File
@@ -0,0 +1,684 @@
package agent
import (
"context"
"os"
"slices"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// TestEventBus_SubscribeEmitUnsubscribeClose walks one subscriber through the
// bus lifecycle: an emitted event is received intact, Unsubscribe closes the
// channel, and subscribing after Close yields an already closed channel.
func TestEventBus_SubscribeEmitUnsubscribeClose(t *testing.T) {
eventBus := NewEventBus()
sub := eventBus.Subscribe(1)
eventBus.Emit(Event{
Kind: EventKindTurnStart,
Meta: EventMeta{TurnID: "turn-1"},
})
// The buffer holds the event, so the receive should succeed immediately;
// the timeout only guards against a hang.
select {
case evt := <-sub.C:
if evt.Kind != EventKindTurnStart {
t.Fatalf("expected %v, got %v", EventKindTurnStart, evt.Kind)
}
if evt.Meta.TurnID != "turn-1" {
t.Fatalf("expected turn id turn-1, got %q", evt.Meta.TurnID)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for event")
}
eventBus.Unsubscribe(sub.ID)
// A receive on the closed channel returns ok == false.
if _, ok := <-sub.C; ok {
t.Fatal("expected subscriber channel to be closed after unsubscribe")
}
eventBus.Close()
closedSub := eventBus.Subscribe(1)
if _, ok := <-closedSub.C; ok {
t.Fatal("expected closed bus to return a closed subscriber channel")
}
}
// TestEventBus_DropsWhenSubscriberIsFull checks that Emit does not block on a
// subscriber that never reads and that the overflow is counted as drops.
func TestEventBus_DropsWhenSubscriberIsFull(t *testing.T) {
eventBus := NewEventBus()
// Capacity 1 and no reader: only the first event fits.
sub := eventBus.Subscribe(1)
defer eventBus.Unsubscribe(sub.ID)
start := time.Now()
for i := 0; i < 1000; i++ {
eventBus.Emit(Event{Kind: EventKindLLMRequest})
}
// NOTE(review): wall-clock bound; could be sensitive to a heavily loaded
// machine.
if elapsed := time.Since(start); elapsed > 100*time.Millisecond {
t.Fatalf("Emit took too long with a blocked subscriber: %s", elapsed)
}
// 1000 emitted minus the one buffered event.
if got := eventBus.Dropped(EventKindLLMRequest); got != 999 {
t.Fatalf("expected 999 dropped events, got %d", got)
}
}
// scriptedToolProvider is a test provider with a fixed two-step script: the
// first Chat call requests the "mock_custom" tool, every later call answers
// with the plain text "done".
type scriptedToolProvider struct {
	// calls counts Chat invocations so far.
	calls int
}

// Chat ignores all of its arguments and returns the next scripted response.
// It never returns an error.
func (p *scriptedToolProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	toolDefs []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.calls++
	if p.calls != 1 {
		return &providers.LLMResponse{Content: "done"}, nil
	}

	request := providers.ToolCall{
		ID:        "call-1",
		Name:      "mock_custom",
		Arguments: map[string]any{"task": "ping"},
	}
	return &providers.LLMResponse{ToolCalls: []providers.ToolCall{request}}, nil
}

// GetDefaultModel returns the fixed model name of this test provider.
func (p *scriptedToolProvider) GetDefaultModel() string {
	const modelName = "scripted-tool-model"
	return modelName
}
// TestAgentLoop_EmitsMinimalTurnEvents runs one turn that performs a single
// tool call followed by a final answer and checks the exact sequence of eight
// events, their shared turn/session metadata, and the key payload fields.
func TestAgentLoop_EmitsMinimalTurnEvents(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
// The scripted provider asks for mock_custom first, then answers "done".
provider := &scriptedToolProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(&mockCustomTool{})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Buffer of 16 comfortably holds the 8 expected events, so none are dropped.
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
response, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if response != "done" {
t.Fatalf("expected final response 'done', got %q", response)
}
// Drain whatever is buffered now that the loop has returned.
events := collectEventStream(sub.C)
if len(events) != 8 {
t.Fatalf("expected 8 events, got %d", len(events))
}
kinds := make([]EventKind, 0, len(events))
for _, evt := range events {
kinds = append(kinds, evt.Kind)
}
// Two LLM round trips with one tool execution in between.
expectedKinds := []EventKind{
EventKindTurnStart,
EventKindLLMRequest,
EventKindLLMResponse,
EventKindToolExecStart,
EventKindToolExecEnd,
EventKindLLMRequest,
EventKindLLMResponse,
EventKindTurnEnd,
}
if !slices.Equal(kinds, expectedKinds) {
t.Fatalf("unexpected event sequence: got %v want %v", kinds, expectedKinds)
}
// Every event of the turn must carry the same turn ID and session key.
turnID := events[0].Meta.TurnID
for i, evt := range events {
if evt.Meta.TurnID != turnID {
t.Fatalf("event %d has mismatched turn id %q, want %q", i, evt.Meta.TurnID, turnID)
}
if evt.Meta.SessionKey != "session-1" {
t.Fatalf("event %d has session key %q, want session-1", i, evt.Meta.SessionKey)
}
}
startPayload, ok := events[0].Payload.(TurnStartPayload)
if !ok {
t.Fatalf("expected TurnStartPayload, got %T", events[0].Payload)
}
if startPayload.UserMessage != "run tool" {
t.Fatalf("expected user message 'run tool', got %q", startPayload.UserMessage)
}
// events[3] and events[4] are the tool start/end pair per expectedKinds.
toolStartPayload, ok := events[3].Payload.(ToolExecStartPayload)
if !ok {
t.Fatalf("expected ToolExecStartPayload, got %T", events[3].Payload)
}
if toolStartPayload.Tool != "mock_custom" {
t.Fatalf("expected tool name mock_custom, got %q", toolStartPayload.Tool)
}
toolEndPayload, ok := events[4].Payload.(ToolExecEndPayload)
if !ok {
t.Fatalf("expected ToolExecEndPayload, got %T", events[4].Payload)
}
if toolEndPayload.Tool != "mock_custom" {
t.Fatalf("expected tool end payload for mock_custom, got %q", toolEndPayload.Tool)
}
if toolEndPayload.IsError {
t.Fatal("expected mock_custom tool to succeed")
}
turnEndPayload, ok := events[len(events)-1].Payload.(TurnEndPayload)
if !ok {
t.Fatalf("expected TurnEndPayload, got %T", events[len(events)-1].Payload)
}
if turnEndPayload.Status != TurnEndStatusCompleted {
t.Fatalf("expected completed turn, got %q", turnEndPayload.Status)
}
if turnEndPayload.Iterations != 2 {
t.Fatalf("expected 2 iterations, got %d", turnEndPayload.Iterations)
}
}
// TestAgentLoop_EmitsSteeringAndSkippedToolEvents steers the agent while the
// first of two requested tools is running and checks that the loop emits a
// steering-injected event, a skipped-tool event for the second tool, and an
// interrupt-received event describing the steering message.
func TestAgentLoop_EmitsSteeringAndSkippedToolEvents(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-steering-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// slowTool and toolCallProvider are defined elsewhere; tool1ExecCh is
// presumably signalled when tool_one starts executing.
tool1ExecCh := make(chan struct{})
tool1 := &slowTool{name: "tool_one", duration: 50 * time.Millisecond, execCh: tool1ExecCh}
tool2 := &slowTool{name: "tool_two", duration: 50 * time.Millisecond}
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_1",
Type: "function",
Name: "tool_one",
Function: &providers.FunctionCall{
Name: "tool_one",
Arguments: "{}",
},
Arguments: map[string]any{},
},
{
ID: "call_2",
Type: "function",
Name: "tool_two",
Function: &providers.FunctionCall{
Name: "tool_two",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "steered response",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
al.RegisterTool(tool1)
al.RegisterTool(tool2)
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
// Run the turn in the background so the test can steer it mid-flight.
resultCh := make(chan string, 1)
go func() {
resp, _ := al.ProcessDirectWithChannel(context.Background(), "do something", "test-session", "test", "chat1")
resultCh <- resp
}()
// Wait until tool_one is running before steering.
select {
case <-tool1ExecCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for tool_one to start")
}
if err := al.Steer(providers.Message{Role: "user", Content: "change course"}); err != nil {
t.Fatalf("Steer failed: %v", err)
}
select {
case resp := <-resultCh:
if resp != "steered response" {
t.Fatalf("expected steered response, got %q", resp)
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for steered response")
}
// The turn has finished; inspect everything that was buffered.
events := collectEventStream(sub.C)
steeringEvt, ok := findEvent(events, EventKindSteeringInjected)
if !ok {
t.Fatal("expected steering injected event")
}
steeringPayload, ok := steeringEvt.Payload.(SteeringInjectedPayload)
if !ok {
t.Fatalf("expected SteeringInjectedPayload, got %T", steeringEvt.Payload)
}
if steeringPayload.Count != 1 {
t.Fatalf("expected 1 steering message, got %d", steeringPayload.Count)
}
// tool_two was requested but must be skipped after the steering message.
skippedEvt, ok := findEvent(events, EventKindToolExecSkipped)
if !ok {
t.Fatal("expected skipped tool event")
}
skippedPayload, ok := skippedEvt.Payload.(ToolExecSkippedPayload)
if !ok {
t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
}
if skippedPayload.Tool != "tool_two" {
t.Fatalf("expected skipped tool_two, got %q", skippedPayload.Tool)
}
interruptEvt, ok := findEvent(events, EventKindInterruptReceived)
if !ok {
t.Fatal("expected interrupt received event")
}
interruptPayload, ok := interruptEvt.Payload.(InterruptReceivedPayload)
if !ok {
t.Fatalf("expected InterruptReceivedPayload, got %T", interruptEvt.Payload)
}
if interruptPayload.Role != "user" {
t.Fatalf("expected interrupt role user, got %q", interruptPayload.Role)
}
if interruptPayload.Kind != InterruptKindSteering {
t.Fatalf("expected steering interrupt kind, got %q", interruptPayload.Kind)
}
// The payload carries the content length rather than the content itself.
if interruptPayload.ContentLen != len("change course") {
t.Fatalf("expected interrupt content len %d, got %d", len("change course"), interruptPayload.ContentLen)
}
}
// TestAgentLoop_EmitsContextCompressEventOnRetry makes the provider fail once
// with a context-size error and checks that the loop recovers and emits both
// an LLM retry event (reason "context_limit", attempt 1) and a context
// compress event that records dropped messages.
func TestAgentLoop_EmitsContextCompressEventOnRetry(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-compress-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// failFirstMockProvider is defined elsewhere; presumably it returns
// failError for the first `failures` calls and successResp afterwards.
contextErr := stringError("InvalidParameter: Total tokens of image and text exceed max message tokens")
provider := &failFirstMockProvider{
failures: 1,
failError: contextErr,
successResp: "Recovered from context error",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
// Seed some history so that compression has messages it can drop.
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
{Role: "assistant", Content: "Old response 2"},
{Role: "user", Content: "Trigger message"},
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "Trigger message",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "Recovered from context error" {
t.Fatalf("expected retry success, got %q", resp)
}
events := collectEventStream(sub.C)
retryEvt, ok := findEvent(events, EventKindLLMRetry)
if !ok {
t.Fatal("expected llm retry event")
}
retryPayload, ok := retryEvt.Payload.(LLMRetryPayload)
if !ok {
t.Fatalf("expected LLMRetryPayload, got %T", retryEvt.Payload)
}
if retryPayload.Reason != "context_limit" {
t.Fatalf("expected context_limit retry reason, got %q", retryPayload.Reason)
}
if retryPayload.Attempt != 1 {
t.Fatalf("expected retry attempt 1, got %d", retryPayload.Attempt)
}
compressEvt, ok := findEvent(events, EventKindContextCompress)
if !ok {
t.Fatal("expected context compress event")
}
payload, ok := compressEvt.Payload.(ContextCompressPayload)
if !ok {
t.Fatalf("expected ContextCompressPayload, got %T", compressEvt.Payload)
}
if payload.Reason != ContextCompressReasonRetry {
t.Fatalf("expected retry compress reason, got %q", payload.Reason)
}
if payload.DroppedMessages == 0 {
t.Fatal("expected dropped messages to be recorded")
}
}
// TestAgentLoop_EmitsSessionSummarizeEvent calls summarizeSession directly on
// a session with six history messages and checks that a session summarize
// event with a non-zero summary length is emitted.
func TestAgentLoop_EmitsSessionSummarizeEvent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-summary-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
ContextWindow: 8000,
SummarizeMessageThreshold: 2,
SummarizeTokenPercent: 75,
},
},
}
msgBus := bus.NewMessageBus()
// simpleMockProvider is defined elsewhere; presumably it answers every
// request with the configured response, which then serves as the summary.
al := NewAgentLoop(cfg, msgBus, &simpleMockProvider{response: "summary text"})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
defaultAgent.Sessions.SetHistory("session-1", []providers.Message{
{Role: "user", Content: "Question one"},
{Role: "assistant", Content: "Answer one"},
{Role: "user", Content: "Question two"},
{Role: "assistant", Content: "Answer two"},
{Role: "user", Content: "Question three"},
{Role: "assistant", Content: "Answer three"},
})
sub := al.SubscribeEvents(16)
defer al.UnsubscribeEvents(sub.ID)
// Summarization is invoked directly instead of through a full turn.
turnScope := al.newTurnEventScope(defaultAgent.ID, "session-1")
al.summarizeSession(defaultAgent, "session-1", turnScope)
events := collectEventStream(sub.C)
summaryEvt, ok := findEvent(events, EventKindSessionSummarize)
if !ok {
t.Fatal("expected session summarize event")
}
payload, ok := summaryEvt.Payload.(SessionSummarizePayload)
if !ok {
t.Fatalf("expected SessionSummarizePayload, got %T", summaryEvt.Payload)
}
if payload.SummaryLen == 0 {
t.Fatal("expected non-empty summary length")
}
}
// TestAgentLoop_EmitsFollowUpQueuedEvent runs a turn whose only tool call is
// an asynchronous tool that reports a follow-up result from a goroutine, and
// checks that a follow-up queued event arrives with the expected source tool,
// channel, chat ID, content length, and turn/session metadata.
func TestAgentLoop_EmitsFollowUpQueuedEvent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-eventbus-followup-*")
if err != nil {
t.Fatalf("failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
provider := &toolCallProvider{
toolCalls: []providers.ToolCall{
{
ID: "call_async_1",
Type: "function",
Name: "async_followup",
Function: &providers.FunctionCall{
Name: "async_followup",
Arguments: "{}",
},
Arguments: map[string]any{},
},
},
finalResp: "async launched",
}
msgBus := bus.NewMessageBus()
al := NewAgentLoop(cfg, msgBus, provider)
// doneCh is closed by the tool once its async callback has returned.
doneCh := make(chan struct{})
al.RegisterTool(&asyncFollowUpTool{
name: "async_followup",
followUpText: "background result",
completionSig: doneCh,
})
defaultAgent := al.registry.GetDefaultAgent()
if defaultAgent == nil {
t.Fatal("expected default agent")
}
sub := al.SubscribeEvents(32)
defer al.UnsubscribeEvents(sub.ID)
resp, err := al.runAgentLoop(context.Background(), defaultAgent, processOptions{
SessionKey: "session-1",
Channel: "cli",
ChatID: "direct",
UserMessage: "run async tool",
DefaultResponse: defaultResponse,
EnableSummary: false,
SendResponse: false,
})
if err != nil {
t.Fatalf("runAgentLoop failed: %v", err)
}
if resp != "async launched" {
t.Fatalf("expected final response 'async launched', got %q", resp)
}
// The follow-up is produced in the background; wait for the callback first.
select {
case <-doneCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for async tool completion")
}
// Skip over the regular turn events until the follow-up event shows up.
followUpEvt := waitForEvent(t, sub.C, 2*time.Second, func(evt Event) bool {
return evt.Kind == EventKindFollowUpQueued
})
payload, ok := followUpEvt.Payload.(FollowUpQueuedPayload)
if !ok {
t.Fatalf("expected FollowUpQueuedPayload, got %T", followUpEvt.Payload)
}
if payload.SourceTool != "async_followup" {
t.Fatalf("expected source tool async_followup, got %q", payload.SourceTool)
}
if payload.Channel != "cli" {
t.Fatalf("expected channel cli, got %q", payload.Channel)
}
if payload.ChatID != "direct" {
t.Fatalf("expected chat id direct, got %q", payload.ChatID)
}
if payload.ContentLen != len("background result") {
t.Fatalf("expected content len %d, got %d", len("background result"), payload.ContentLen)
}
if followUpEvt.Meta.SessionKey != "session-1" {
t.Fatalf("expected session key session-1, got %q", followUpEvt.Meta.SessionKey)
}
if followUpEvt.Meta.TurnID == "" {
t.Fatal("expected follow-up event to include turn id")
}
}
// collectEventStream drains the events that are immediately available on ch
// without ever blocking. It stops as soon as the channel has nothing ready or
// has been closed, and returns what it gathered (nil if nothing).
func collectEventStream(ch <-chan Event) []Event {
	var drained []Event
	for {
		select {
		case evt, open := <-ch:
			if open {
				drained = append(drained, evt)
				continue
			}
			return drained
		default:
			return drained
		}
	}
}
// waitForEvent receives from ch until match accepts an event and returns that
// event; non-matching events are discarded. The test is failed if the channel
// is closed first or if timeout elapses (measured once, across all receives).
func waitForEvent(t *testing.T, ch <-chan Event, timeout time.Duration, match func(Event) bool) Event {
	t.Helper()

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			t.Fatal("timed out waiting for expected event")
		case candidate, open := <-ch:
			if !open {
				t.Fatal("event stream closed before expected event arrived")
			}
			if match(candidate) {
				return candidate
			}
		}
	}
}
// findEvent returns the first event in events whose Kind equals kind. The
// boolean result is false (with a zero Event) when no such event exists.
func findEvent(events []Event, kind EventKind) (Event, bool) {
	for i := range events {
		if events[i].Kind != kind {
			continue
		}
		return events[i], true
	}
	return Event{}, false
}
// stringError is a string that can be used directly as an error value; its
// message is the string itself.
type stringError string

// Error implements the error interface.
func (e stringError) Error() string { return string(e) }
// asyncFollowUpTool is a test tool that reports an async result immediately
// and delivers followUpText through the async callback from a goroutine.
type asyncFollowUpTool struct {
	name          string
	followUpText  string
	completionSig chan struct{} // closed after the callback returns; may be nil
}

// Name returns the configured tool name.
func (t *asyncFollowUpTool) Name() string { return t.name }

// Description returns a fixed description string.
func (t *asyncFollowUpTool) Description() string {
	return "async follow-up tool for testing"
}

// Parameters returns an object schema with no properties.
func (t *asyncFollowUpTool) Parameters() map[string]any {
	schema := map[string]any{"type": "object"}
	schema["properties"] = map[string]any{}
	return schema
}

// Execute is the synchronous entry point; it only reports that work was scheduled.
func (t *asyncFollowUpTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	return tools.AsyncResult("async follow-up scheduled")
}

// ExecuteAsync invokes cb with the follow-up text on a new goroutine, closes
// completionSig (when set) once cb has returned, and immediately returns an
// async result to the caller.
func (t *asyncFollowUpTool) ExecuteAsync(
	ctx context.Context,
	args map[string]any,
	cb tools.AsyncCallback,
) *tools.ToolResult {
	deliver := func() {
		cb(ctx, &tools.ToolResult{ForLLM: t.followUpText})
		if t.completionSig == nil {
			return
		}
		close(t.completionSig)
	}
	go deliver()
	return tools.AsyncResult("async follow-up scheduled")
}
// Compile-time checks that the test doubles satisfy the tools interfaces:
// mockCustomTool must implement tools.Tool and asyncFollowUpTool must
// implement tools.AsyncExecutor.
var (
	_ tools.Tool = (*mockCustomTool)(nil)
	_ tools.AsyncExecutor = (*asyncFollowUpTool)(nil)
)
+271
View File
@@ -0,0 +1,271 @@
package agent
import (
"fmt"
"time"
)
// EventKind identifies a structured agent-loop event.
type EventKind uint8

// The constants below are assigned by iota, so their declaration order defines
// their numeric values. eventKindNames is indexed by these values and must be
// kept in the same order.
const (
	// EventKindTurnStart is emitted when a turn begins processing.
	EventKindTurnStart EventKind = iota
	// EventKindTurnEnd is emitted when a turn finishes, successfully or with an error.
	EventKindTurnEnd
	// EventKindLLMRequest is emitted before a provider chat request is made.
	EventKindLLMRequest
	// EventKindLLMDelta is emitted when a streaming provider yields a partial delta.
	EventKindLLMDelta
	// EventKindLLMResponse is emitted after a provider chat response is received.
	EventKindLLMResponse
	// EventKindLLMRetry is emitted when an LLM request is retried.
	EventKindLLMRetry
	// EventKindContextCompress is emitted when session history is forcibly compressed.
	EventKindContextCompress
	// EventKindSessionSummarize is emitted when asynchronous summarization completes.
	EventKindSessionSummarize
	// EventKindToolExecStart is emitted immediately before a tool executes.
	EventKindToolExecStart
	// EventKindToolExecEnd is emitted immediately after a tool finishes executing.
	EventKindToolExecEnd
	// EventKindToolExecSkipped is emitted when a queued tool call is skipped.
	EventKindToolExecSkipped
	// EventKindSteeringInjected is emitted when queued steering is injected into context.
	EventKindSteeringInjected
	// EventKindFollowUpQueued is emitted when an async tool queues a follow-up system message.
	EventKindFollowUpQueued
	// EventKindInterruptReceived is emitted when a soft interrupt message is accepted.
	EventKindInterruptReceived
	// EventKindSubTurnSpawn is emitted when a sub-turn is spawned.
	EventKindSubTurnSpawn
	// EventKindSubTurnEnd is emitted when a sub-turn finishes.
	EventKindSubTurnEnd
	// EventKindSubTurnResultDelivered is emitted when a sub-turn result is delivered.
	EventKindSubTurnResultDelivered
	// EventKindSubTurnOrphan is emitted when a sub-turn result cannot be delivered.
	EventKindSubTurnOrphan
	// EventKindError is emitted when a turn encounters an execution error.
	EventKindError
	// eventKindCount is a sentinel equal to the number of kinds declared above.
	// It is used as the exclusive upper bound when validating or enumerating
	// kinds, so it must remain the last entry in this block.
	eventKindCount
)
var eventKindNames = [...]string{
"turn_start",
"turn_end",
"llm_request",
"llm_delta",
"llm_response",
"llm_retry",
"context_compress",
"session_summarize",
"tool_exec_start",
"tool_exec_end",
"tool_exec_skipped",
"steering_injected",
"follow_up_queued",
"interrupt_received",
"subturn_spawn",
"subturn_end",
"subturn_result_delivered",
"subturn_orphan",
"error",
}
// String returns the stable string form of an EventKind. Values at or beyond
// eventKindCount have no name and are rendered as "event_kind(N)".
func (k EventKind) String() string {
	if k < eventKindCount {
		return eventKindNames[k]
	}
	return fmt.Sprintf("event_kind(%d)", k)
}
// Event is the structured envelope broadcast by the agent EventBus.
type Event struct {
	// Kind selects which event this is.
	Kind EventKind
	// Time is the timestamp attached to the event.
	Time time.Time
	// Meta carries the correlation fields shared by all events.
	Meta EventMeta
	// Payload holds the kind-specific payload struct; consumers type-assert it
	// (for example to FollowUpQueuedPayload for EventKindFollowUpQueued).
	Payload any
}
// EventMeta contains correlation fields shared by all agent-loop events.
type EventMeta struct {
AgentID string
TurnID string
ParentTurnID string
SessionKey string
Iteration int
TracePath string
Source string
}
// TurnEndStatus describes the terminal state of a turn.
type TurnEndStatus string
const (
// TurnEndStatusCompleted indicates the turn finished normally.
TurnEndStatusCompleted TurnEndStatus = "completed"
// TurnEndStatusError indicates the turn ended because of an error.
TurnEndStatusError TurnEndStatus = "error"
// TurnEndStatusAborted indicates the turn was hard-aborted and rolled back.
TurnEndStatusAborted TurnEndStatus = "aborted"
)
// TurnStartPayload describes the start of a turn.
type TurnStartPayload struct {
Channel string
ChatID string
UserMessage string
MediaCount int
}
// TurnEndPayload describes the completion of a turn.
type TurnEndPayload struct {
Status TurnEndStatus
Iterations int
Duration time.Duration
FinalContentLen int
}
// LLMRequestPayload describes an outbound LLM request.
type LLMRequestPayload struct {
Model string
MessagesCount int
ToolsCount int
MaxTokens int
Temperature float64
}
// LLMResponsePayload describes an inbound LLM response.
type LLMResponsePayload struct {
ContentLen int
ToolCalls int
HasReasoning bool
}
// LLMDeltaPayload describes a streamed LLM delta.
type LLMDeltaPayload struct {
ContentDeltaLen int
ReasoningDeltaLen int
}
// LLMRetryPayload describes a retry of an LLM request.
type LLMRetryPayload struct {
Attempt int
MaxRetries int
Reason string
Error string
Backoff time.Duration
}
// ContextCompressReason identifies why emergency compression ran.
type ContextCompressReason string
const (
// ContextCompressReasonProactive indicates compression before the first LLM call.
ContextCompressReasonProactive ContextCompressReason = "proactive_budget"
// ContextCompressReasonRetry indicates compression during context-error retry handling.
ContextCompressReasonRetry ContextCompressReason = "llm_retry"
)
// ContextCompressPayload describes a forced history compression.
type ContextCompressPayload struct {
Reason ContextCompressReason
DroppedMessages int
RemainingMessages int
}
// SessionSummarizePayload describes a completed async session summarization.
type SessionSummarizePayload struct {
SummarizedMessages int
KeptMessages int
SummaryLen int
OmittedOversized bool
}
// ToolExecStartPayload describes a tool execution request.
type ToolExecStartPayload struct {
Tool string
Arguments map[string]any
}
// ToolExecEndPayload describes the outcome of a tool execution.
type ToolExecEndPayload struct {
Tool string
Duration time.Duration
ForLLMLen int
ForUserLen int
IsError bool
Async bool
}
// ToolExecSkippedPayload describes a skipped tool call.
type ToolExecSkippedPayload struct {
Tool string
Reason string
}
// SteeringInjectedPayload describes steering messages appended before the next LLM call.
type SteeringInjectedPayload struct {
Count int
TotalContentLen int
}
// FollowUpQueuedPayload describes an async follow-up queued back into the inbound bus.
type FollowUpQueuedPayload struct {
	// SourceTool is the name of the tool that produced the follow-up.
	SourceTool string
	// Channel and ChatID identify the conversation the follow-up is queued for.
	Channel string
	ChatID string
	// ContentLen is the length of the follow-up content as reported by len().
	ContentLen int
}
// InterruptKind identifies the kind of turn-control input reported in
// InterruptReceivedPayload.
type InterruptKind string

// Known InterruptKind values. The string forms are part of the emitted event
// payload.
// NOTE(review): the precise semantics of each kind are defined by the agent
// loop, which is not visible here — confirm before relying on the names alone.
const (
	InterruptKindSteering InterruptKind = "steering"
	InterruptKindGraceful InterruptKind = "graceful"
	InterruptKindHard InterruptKind = "hard_abort"
)
// InterruptReceivedPayload describes accepted turn-control input.
type InterruptReceivedPayload struct {
Kind InterruptKind
Role string
ContentLen int
QueueDepth int
HintLen int
}
// SubTurnSpawnPayload describes the creation of a child turn.
type SubTurnSpawnPayload struct {
AgentID string
Label string
ParentTurnID string
}
// SubTurnEndPayload describes the completion of a child turn.
type SubTurnEndPayload struct {
AgentID string
Status string
}
// SubTurnResultDeliveredPayload describes delivery of a sub-turn result.
type SubTurnResultDeliveredPayload struct {
TargetChannel string
TargetChatID string
ContentLen int
}
// SubTurnOrphanPayload describes a sub-turn result that could not be delivered.
type SubTurnOrphanPayload struct {
ParentTurnID string
ChildTurnID string
Reason string
}
// ErrorPayload describes an execution error inside the agent loop.
type ErrorPayload struct {
Stage string
Message string
}
+317
View File
@@ -0,0 +1,317 @@
package agent
import (
"context"
"fmt"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
// hookRuntime tracks the one-time, config-driven hook initialization of an
// AgentLoop: whether it has run, the error it produced, and which hooks it
// mounted so they can be unmounted again on reset.
type hookRuntime struct {
	// initOnce ensures configured hooks are loaded at most once until reset.
	initOnce sync.Once
	// mu guards initErr and mounted.
	mu sync.Mutex
	// initErr is the result of the most recent configured-hook load.
	initErr error
	// mounted lists the names of hooks mounted by the configured-hook load.
	mounted []string
}
// setInitErr records the outcome of hook initialization under the mutex.
func (r *hookRuntime) setInitErr(err error) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.initErr = err
}
// getInitErr returns the recorded hook-initialization error, if any.
func (r *hookRuntime) getInitErr() error {
	r.mu.Lock()
	recorded := r.initErr
	r.mu.Unlock()
	return recorded
}
// setMounted stores a private copy of names as the set of mounted hooks, so
// later changes to the caller's slice do not affect the runtime.
func (r *hookRuntime) setMounted(names []string) {
	snapshot := append([]string(nil), names...)

	r.mu.Lock()
	defer r.mu.Unlock()
	r.mounted = snapshot
}
// reset clears the recorded initialization state (mounted names, error and the
// once guard) and then unmounts every hook that had been mounted. Unmounting
// happens after the mutex is released.
//
// NOTE(review): initOnce is replaced here while holding mu, but
// ensureHooksInitialized calls initOnce.Do without taking mu — presumably
// callers never run reset concurrently with initialization; confirm.
func (r *hookRuntime) reset(al *AgentLoop) {
	r.mu.Lock()
	stale := append([]string(nil), r.mounted...)
	r.mounted, r.initErr = nil, nil
	r.initOnce = sync.Once{}
	r.mu.Unlock()

	for i := range stale {
		al.UnmountHook(stale[i])
	}
}
// BuiltinHookFactory constructs an in-process hook from config.
type BuiltinHookFactory func(ctx context.Context, spec config.BuiltinHookConfig) (any, error)
// Process-wide registry of builtin hook factories.
var (
	// builtinHookRegistryMu guards builtinHookRegistry.
	builtinHookRegistryMu sync.RWMutex
	// builtinHookRegistry maps a builtin hook name to the factory that builds it.
	builtinHookRegistry = map[string]BuiltinHookFactory{}
)
// RegisterBuiltinHook registers a named in-process hook factory for
// config-driven mounting. It returns an error when name is empty, factory is
// nil, or a factory is already registered under name.
func RegisterBuiltinHook(name string, factory BuiltinHookFactory) error {
	switch {
	case name == "":
		return fmt.Errorf("builtin hook name is required")
	case factory == nil:
		return fmt.Errorf("builtin hook %q factory is nil", name)
	}

	builtinHookRegistryMu.Lock()
	defer builtinHookRegistryMu.Unlock()

	_, taken := builtinHookRegistry[name]
	if taken {
		return fmt.Errorf("builtin hook %q is already registered", name)
	}
	builtinHookRegistry[name] = factory
	return nil
}
// unregisterBuiltinHook removes the factory registered under name. An empty
// name or an unknown name is a no-op.
func unregisterBuiltinHook(name string) {
	if name != "" {
		builtinHookRegistryMu.Lock()
		defer builtinHookRegistryMu.Unlock()
		delete(builtinHookRegistry, name)
	}
}
// lookupBuiltinHook returns the factory registered under name and whether one
// was found.
func lookupBuiltinHook(name string) (BuiltinHookFactory, bool) {
	builtinHookRegistryMu.RLock()
	factory, found := builtinHookRegistry[name]
	builtinHookRegistryMu.RUnlock()
	return factory, found
}
// configureHookManagerFromConfig applies the hook timeout defaults from cfg to
// hm. It does nothing when either argument is nil.
func configureHookManagerFromConfig(hm *HookManager, cfg *config.Config) {
	if hm != nil && cfg != nil {
		observer := hookTimeoutFromMS(cfg.Hooks.Defaults.ObserverTimeoutMS)
		interceptor := hookTimeoutFromMS(cfg.Hooks.Defaults.InterceptorTimeoutMS)
		approval := hookTimeoutFromMS(cfg.Hooks.Defaults.ApprovalTimeoutMS)
		hm.ConfigureTimeouts(observer, interceptor, approval)
	}
}
// hookTimeoutFromMS converts a millisecond count into a time.Duration.
// Zero and negative inputs yield a zero duration.
func hookTimeoutFromMS(ms int) time.Duration {
	if ms > 0 {
		return time.Millisecond * time.Duration(ms)
	}
	return 0
}
// ensureHooksInitialized loads the configured hooks the first time it is
// called and returns the recorded result of that load on every call. It is a
// no-op (nil) when the loop, its config, or its hook manager is nil.
func (al *AgentLoop) ensureHooksInitialized(ctx context.Context) error {
	if al == nil || al.cfg == nil || al.hooks == nil {
		return nil
	}

	al.hookRuntime.initOnce.Do(func() {
		loadErr := al.loadConfiguredHooks(ctx)
		al.hookRuntime.setInitErr(loadErr)
	})

	return al.hookRuntime.getInitErr()
}
// loadConfiguredHooks mounts every enabled hook from the loop's config:
// builtin (in-process) hooks first, then process hooks, each group in sorted
// name order. It returns nil without doing anything when the loop or config is
// nil or hooks are disabled.
//
// The result is a named return so the deferred function can observe it: on
// error every hook mounted so far is unmounted again; on success the mounted
// names are recorded in hookRuntime.
func (al *AgentLoop) loadConfiguredHooks(ctx context.Context) (err error) {
	if al == nil || al.cfg == nil || !al.cfg.Hooks.Enabled {
		return nil
	}
	mounted := make([]string, 0)
	defer func() {
		if err != nil {
			// Roll back: leave no partially-mounted configuration behind.
			for _, name := range mounted {
				al.UnmountHook(name)
			}
			return
		}
		al.hookRuntime.setMounted(mounted)
	}()
	builtinNames := enabledBuiltinHookNames(al.cfg.Hooks.Builtins)
	for _, name := range builtinNames {
		spec := al.cfg.Hooks.Builtins[name]
		factory, ok := lookupBuiltinHook(name)
		if !ok {
			return fmt.Errorf("builtin hook %q is not registered", name)
		}
		hook, factoryErr := factory(ctx, spec)
		if factoryErr != nil {
			return fmt.Errorf("build builtin hook %q: %w", name, factoryErr)
		}
		// The inner err shadows the named result; the return statement below
		// assigns the wrapped error to the named result for the defer.
		if err := al.MountHook(HookRegistration{
			Name: name,
			Priority: spec.Priority,
			Source: HookSourceInProcess,
			Hook: hook,
		}); err != nil {
			return fmt.Errorf("mount builtin hook %q: %w", name, err)
		}
		mounted = append(mounted, name)
	}
	processNames := enabledProcessHookNames(al.cfg.Hooks.Processes)
	for _, name := range processNames {
		spec := al.cfg.Hooks.Processes[name]
		opts, buildErr := processHookOptionsFromConfig(spec)
		if buildErr != nil {
			return fmt.Errorf("configure process hook %q: %w", name, buildErr)
		}
		processHook, buildErr := NewProcessHook(ctx, name, opts)
		if buildErr != nil {
			return fmt.Errorf("start process hook %q: %w", name, buildErr)
		}
		if err := al.MountHook(HookRegistration{
			Name: name,
			Priority: spec.Priority,
			Source: HookSourceProcess,
			Hook: processHook,
		}); err != nil {
			// The process was started but never mounted, so it is not in
			// mounted; close it here directly.
			_ = processHook.Close()
			return fmt.Errorf("mount process hook %q: %w", name, err)
		}
		mounted = append(mounted, name)
	}
	return nil
}
// enabledBuiltinHookNames returns the names of all enabled builtin hook specs
// in ascending order. It returns nil for an empty map.
func enabledBuiltinHookNames(specs map[string]config.BuiltinHookConfig) []string {
	if len(specs) == 0 {
		return nil
	}

	enabled := make([]string, 0, len(specs))
	for name := range specs {
		if !specs[name].Enabled {
			continue
		}
		enabled = append(enabled, name)
	}
	sort.Strings(enabled)
	return enabled
}
// enabledProcessHookNames returns the names of all enabled process hook specs
// in ascending order. It returns nil for an empty map.
func enabledProcessHookNames(specs map[string]config.ProcessHookConfig) []string {
	if len(specs) == 0 {
		return nil
	}

	enabled := make([]string, 0, len(specs))
	for name := range specs {
		if !specs[name].Enabled {
			continue
		}
		enabled = append(enabled, name)
	}
	sort.Strings(enabled)
	return enabled
}
// processHookOptionsFromConfig validates a process hook spec and converts it
// into ProcessHookOptions.
//
// Checks are applied in this order, each producing an error: the transport
// must be empty or "stdio"; a command must be given; every observe entry must
// be valid; every intercept entry must be a known mode (blank entries are
// ignored); and at least one hook mode must end up enabled.
func processHookOptionsFromConfig(spec config.ProcessHookConfig) (ProcessHookOptions, error) {
	if spec.Transport != "" && spec.Transport != "stdio" {
		return ProcessHookOptions{}, fmt.Errorf("unsupported transport %q", spec.Transport)
	}
	if len(spec.Command) == 0 {
		return ProcessHookOptions{}, fmt.Errorf("command is required")
	}

	kinds, observe, err := processHookObserveKindsFromConfig(spec.Observe)
	if err != nil {
		return ProcessHookOptions{}, err
	}

	opts := ProcessHookOptions{
		// Copy the command so the options do not alias the config slice.
		Command:      append([]string(nil), spec.Command...),
		Dir:          spec.Dir,
		Env:          processHookEnvFromMap(spec.Env),
		Observe:      observe,
		ObserveKinds: kinds,
	}

	for _, mode := range spec.Intercept {
		switch mode {
		case "":
			// Blank entries are ignored.
		case "before_llm", "after_llm":
			opts.InterceptLLM = true
		case "before_tool", "after_tool":
			opts.InterceptTool = true
		case "approve_tool":
			opts.ApproveTool = true
		default:
			return ProcessHookOptions{}, fmt.Errorf("unsupported intercept %q", mode)
		}
	}

	if opts.Observe || opts.InterceptLLM || opts.InterceptTool || opts.ApproveTool {
		return opts, nil
	}
	return ProcessHookOptions{}, fmt.Errorf("no hook modes enabled")
}
// processHookEnvFromMap renders envMap as "KEY=VALUE" strings ordered by key,
// so the result is deterministic. It returns nil for an empty map.
func processHookEnvFromMap(envMap map[string]string) []string {
	if len(envMap) == 0 {
		return nil
	}

	names := make([]string, 0, len(envMap))
	for name := range envMap {
		names = append(names, name)
	}
	sort.Strings(names)

	env := make([]string, len(names))
	for i, name := range names {
		env[i] = name + "=" + envMap[name]
	}
	return env
}
// processHookObserveKindsFromConfig interprets the "observe" list of a process
// hook spec. It returns the explicit event kinds to observe, whether observing
// is enabled at all, and an error for an unknown kind.
//
//   - An empty list, or a list containing only blank entries, disables observing.
//   - "*" or "all" enables observing of every kind (nil kind list).
//   - Any other entry must be the string form of a known EventKind.
//
// Blank entries are skipped rather than treated as a wildcard. This matches
// how blank intercept entries and newProcessHookObserveKinds handle blanks,
// and keeps a stray empty string in the config from silently subscribing the
// external process to every event.
func processHookObserveKindsFromConfig(observe []string) ([]string, bool, error) {
	if len(observe) == 0 {
		return nil, false, nil
	}
	validKinds := validHookEventKinds()
	normalized := make([]string, 0, len(observe))
	for _, kind := range observe {
		switch kind {
		case "":
			// Ignore blank entries.
			continue
		case "*", "all":
			return nil, true, nil
		default:
			if _, ok := validKinds[kind]; !ok {
				return nil, false, fmt.Errorf("unsupported observe event %q", kind)
			}
			normalized = append(normalized, kind)
		}
	}
	if len(normalized) == 0 {
		// Only blank entries were given: nothing to observe.
		return nil, false, nil
	}
	return normalized, true, nil
}
// validHookEventKinds returns the set of string names of every defined
// EventKind (all values below eventKindCount).
func validHookEventKinds() map[string]struct{} {
	total := int(eventKindCount)
	names := make(map[string]struct{}, total)
	for i := 0; i < total; i++ {
		names[EventKind(i).String()] = struct{}{}
	}
	return names
}
+179
View File
@@ -0,0 +1,179 @@
package agent
import (
"context"
"encoding/json"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
// builtinAutoHookConfig is the JSON configuration decoded by the test builtin
// hook factory and used to populate a builtinAutoHook.
type builtinAutoHookConfig struct {
	Model string `json:"model"`
	Suffix string `json:"suffix"`
}
// builtinAutoHook is a test LLM interceptor: it overrides the request model
// before the call and appends a suffix to the response content afterwards.
type builtinAutoHook struct {
	model  string
	suffix string
}

// BeforeLLM returns a clone of req with its model replaced by h.model.
func (h *builtinAutoHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	modified := req.Clone()
	modified.Model = h.model
	decision := HookDecision{Action: HookActionModify}
	return modified, decision, nil
}

// AfterLLM returns a clone of resp with h.suffix appended to the response
// content, when a response is present.
func (h *builtinAutoHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	modified := resp.Clone()
	decision := HookDecision{Action: HookActionModify}
	if modified.Response == nil {
		return modified, decision, nil
	}
	modified.Response.Content += h.suffix
	return modified, decision, nil
}
// newConfiguredHookLoop builds an AgentLoop for tests with a temporary
// workspace, fixed agent defaults, and the given hooks configuration.
func newConfiguredHookLoop(t *testing.T, provider *llmHookTestProvider, hooks config.HooksConfig) *AgentLoop {
	t.Helper()

	cfg := &config.Config{Hooks: hooks}
	cfg.Agents.Defaults = config.AgentDefaults{
		Workspace:         t.TempDir(),
		Model:             "test-model",
		MaxTokens:         4096,
		MaxToolIterations: 10,
	}

	messageBus := bus.NewMessageBus()
	return NewAgentLoop(cfg, messageBus, provider)
}
// TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook registers a
// builtin hook factory, enables it through config only, and checks that a
// direct request is processed with the hook applied: the provider sees the
// hook's model and the response carries the hook's suffix.
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsBuiltinHook(t *testing.T) {
	const hookName = "test-auto-builtin-hook"
	if err := RegisterBuiltinHook(hookName, func(
		ctx context.Context,
		spec config.BuiltinHookConfig,
	) (any, error) {
		var hookCfg builtinAutoHookConfig
		if len(spec.Config) > 0 {
			if err := json.Unmarshal(spec.Config, &hookCfg); err != nil {
				return nil, err
			}
		}
		return &builtinAutoHook{
			model: hookCfg.Model,
			suffix: hookCfg.Suffix,
		}, nil
	}); err != nil {
		t.Fatalf("RegisterBuiltinHook failed: %v", err)
	}
	// The registry is package-global, so remove the factory when the test ends.
	t.Cleanup(func() {
		unregisterBuiltinHook(hookName)
	})
	rawCfg, err := json.Marshal(builtinAutoHookConfig{
		Model: "builtin-model",
		Suffix: "|builtin",
	})
	if err != nil {
		t.Fatalf("json.Marshal failed: %v", err)
	}
	provider := &llmHookTestProvider{}
	al := newConfiguredHookLoop(t, provider, config.HooksConfig{
		Enabled: true,
		Builtins: map[string]config.BuiltinHookConfig{
			hookName: {
				Enabled: true,
				Config: rawCfg,
			},
		},
	})
	defer al.Close()
	resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
	if err != nil {
		t.Fatalf("ProcessDirectWithChannel failed: %v", err)
	}
	if resp != "provider content|builtin" {
		t.Fatalf("expected builtin-hooked content, got %q", resp)
	}
	provider.mu.Lock()
	lastModel := provider.lastModel
	provider.mu.Unlock()
	if lastModel != "builtin-model" {
		t.Fatalf("expected builtin model, got %q", lastModel)
	}
}
// TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook enables a
// process hook through config only (using the helper command in "rewrite"
// mode) and checks that a direct request is rewritten by it and that the
// observed "turn_end" event reaches the helper's event log.
func TestAgentLoop_ProcessDirectWithChannel_AutoMountsProcessHook(t *testing.T) {
	provider := &llmHookTestProvider{}
	eventLog := filepath.Join(t.TempDir(), "events.log")
	al := newConfiguredHookLoop(t, provider, config.HooksConfig{
		Enabled: true,
		Processes: map[string]config.ProcessHookConfig{
			"ipc-auto": {
				Enabled: true,
				Command: processHookHelperCommand(),
				Env: map[string]string{
					"PICOCLAW_HOOK_HELPER": "1",
					"PICOCLAW_HOOK_MODE": "rewrite",
					"PICOCLAW_HOOK_EVENT_LOG": eventLog,
				},
				Observe: []string{"turn_end"},
				Intercept: []string{"before_llm", "after_llm"},
			},
		},
	})
	defer al.Close()
	resp, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
	if err != nil {
		t.Fatalf("ProcessDirectWithChannel failed: %v", err)
	}
	if resp != "provider content|ipc" {
		t.Fatalf("expected process-hooked content, got %q", resp)
	}
	provider.mu.Lock()
	lastModel := provider.lastModel
	provider.mu.Unlock()
	if lastModel != "process-model" {
		t.Fatalf("expected process model, got %q", lastModel)
	}
	waitForFileContains(t, eventLog, "turn_end")
}
// TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails checks
// that a configured process hook with an unsupported intercept mode makes a
// direct request fail with an error.
func TestAgentLoop_ProcessDirectWithChannel_InvalidConfiguredHookFails(t *testing.T) {
	provider := &llmHookTestProvider{}
	al := newConfiguredHookLoop(t, provider, config.HooksConfig{
		Enabled: true,
		Processes: map[string]config.ProcessHookConfig{
			"bad-hook": {
				Enabled: true,
				Command: processHookHelperCommand(),
				Intercept: []string{"not_supported"},
			},
		},
	})
	defer al.Close()
	_, err := al.ProcessDirectWithChannel(context.Background(), "hello", "session-1", "cli", "direct")
	if err == nil {
		t.Fatal("expected invalid configured hook error")
	}
}
+511
View File
@@ -0,0 +1,511 @@
package agent
import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"sync/atomic"
	"time"

	"github.com/sipeed/picoclaw/pkg/logger"
)
const (
	// processHookJSONRPCVersion is the "jsonrpc" value set on every outgoing message.
	processHookJSONRPCVersion = "2.0"
	// processHookReadBufferSize is the maximum size of a single line accepted
	// by the stdout and stderr scanners.
	processHookReadBufferSize = 1024 * 1024
	// processHookCloseTimeout is how long Close waits for the process to exit
	// after closing its stdin before killing it.
	processHookCloseTimeout = 2 * time.Second
)
// ProcessHookOptions configures an external hook process and which hook
// capabilities are forwarded to it.
type ProcessHookOptions struct {
	// Command is the program to run followed by its arguments; it must be non-empty.
	Command []string
	// Dir is the working directory assigned to the process.
	Dir string
	// Env holds extra "KEY=VALUE" entries appended to the current environment.
	Env []string
	// Observe enables forwarding of events to the process.
	Observe bool
	// ObserveKinds restricts forwarded events to these kind names; when empty,
	// every event kind is forwarded.
	ObserveKinds []string
	// InterceptLLM enables the before/after LLM calls.
	InterceptLLM bool
	// InterceptTool enables the before/after tool calls.
	InterceptTool bool
	// ApproveTool enables tool approval requests.
	ApproveTool bool
}
// ProcessHook is a hook backed by an external process. It exchanges
// newline-delimited JSON-RPC messages with the process over stdin/stdout.
type ProcessHook struct {
	name string
	opts ProcessHookOptions
	cmd *exec.Cmd
	// stdin is the write end used to send messages to the process.
	stdin io.WriteCloser
	// observeKinds is the set form of opts.ObserveKinds; nil means no filtering.
	observeKinds map[string]struct{}
	// writeMu serializes senders writing to stdin.
	writeMu sync.Mutex
	// pendingMu guards pending.
	pendingMu sync.Mutex
	// pending maps an in-flight request id to the channel awaiting its response.
	pending map[uint64]chan processHookRPCMessage
	// nextID generates request ids; the first id issued is 1.
	nextID atomic.Uint64
	// closed is set once Close has been called.
	closed atomic.Bool
	// done is closed by waitLoop after the process has exited.
	done chan struct{}
	// closeErr is the result of waiting on the process; guarded by closeMu.
	closeErr error
	closeMu sync.Mutex
	// closeOnce makes the shutdown sequence in Close run only once.
	closeOnce sync.Once
}
// processHookRPCMessage is the wire format for both directions of the hook
// protocol. Requests carry ID, Method and Params; notifications leave ID at
// zero (omitted); responses carry ID with either Result or Error. Incoming
// messages with ID zero are ignored by the read loop.
type processHookRPCMessage struct {
	JSONRPC string `json:"jsonrpc,omitempty"`
	ID uint64 `json:"id,omitempty"`
	Method string `json:"method,omitempty"`
	Params json.RawMessage `json:"params,omitempty"`
	Result json.RawMessage `json:"result,omitempty"`
	Error *processHookRPCError `json:"error,omitempty"`
}
type processHookRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type processHookHelloParams struct {
Name string `json:"name"`
Version int `json:"version"`
Modes []string `json:"modes,omitempty"`
}
type processHookDecisionResponse struct {
Action HookAction `json:"action"`
Reason string `json:"reason,omitempty"`
}
type processHookBeforeLLMResponse struct {
processHookDecisionResponse
Request *LLMHookRequest `json:"request,omitempty"`
}
type processHookAfterLLMResponse struct {
processHookDecisionResponse
Response *LLMHookResponse `json:"response,omitempty"`
}
type processHookBeforeToolResponse struct {
processHookDecisionResponse
Call *ToolCallHookRequest `json:"call,omitempty"`
}
type processHookAfterToolResponse struct {
processHookDecisionResponse
Result *ToolResultHookResponse `json:"result,omitempty"`
}
// NewProcessHook starts the hook process described by opts, wires up its
// stdin/stdout/stderr, and performs the "hook.hello" handshake before
// returning.
//
// The process itself is started with exec.Command and is therefore not tied
// to ctx; ctx only bounds the handshake. A nil ctx is replaced by a 5 second
// timeout. If the handshake fails the process is closed and the error is
// returned.
//
// NOTE(review): with a non-nil ctx that has no deadline, the handshake waits
// until the process replies or exits — confirm callers pass a bounded ctx
// where that matters.
func NewProcessHook(ctx context.Context, name string, opts ProcessHookOptions) (*ProcessHook, error) {
	if len(opts.Command) == 0 {
		return nil, fmt.Errorf("process hook command is required")
	}
	cmd := exec.Command(opts.Command[0], opts.Command[1:]...)
	cmd.Dir = opts.Dir
	// Only override the environment when extra variables were requested.
	if len(opts.Env) > 0 {
		cmd.Env = append(os.Environ(), opts.Env...)
	}
	stdin, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("create process hook stdin: %w", err)
	}
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("create process hook stdout: %w", err)
	}
	stderr, err := cmd.StderrPipe()
	if err != nil {
		return nil, fmt.Errorf("create process hook stderr: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("start process hook: %w", err)
	}
	ph := &ProcessHook{
		name: name,
		opts: opts,
		cmd: cmd,
		stdin: stdin,
		observeKinds: newProcessHookObserveKinds(opts.ObserveKinds),
		pending: make(map[uint64]chan processHookRPCMessage),
		done: make(chan struct{}),
	}
	// Background goroutines: response reader, stderr logger, and exit watcher.
	go ph.readLoop(stdout)
	go ph.readStderr(stderr)
	go ph.waitLoop()
	helloCtx := ctx
	if helloCtx == nil {
		var cancel context.CancelFunc
		helloCtx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
	}
	if err := ph.hello(helloCtx); err != nil {
		_ = ph.Close()
		return nil, err
	}
	return ph, nil
}
// Close shuts the hook process down and returns the error recorded from
// waiting on it. The shutdown sequence runs only once: it marks the hook
// closed, closes stdin, and waits up to processHookCloseTimeout for the
// process to exit; if it has not exited by then the process is killed and
// Close waits for the exit to be observed. Calling Close on a nil hook
// returns nil.
func (ph *ProcessHook) Close() error {
	if ph == nil {
		return nil
	}
	ph.closeOnce.Do(func() {
		// Reject new sends before tearing down the pipe.
		ph.closed.Store(true)
		if ph.stdin != nil {
			_ = ph.stdin.Close()
		}
		select {
		case <-ph.done:
		case <-time.After(processHookCloseTimeout):
			if ph.cmd != nil && ph.cmd.Process != nil {
				_ = ph.cmd.Process.Kill()
			}
			// waitLoop closes done once the killed process has been reaped.
			<-ph.done
		}
	})
	ph.closeMu.Lock()
	defer ph.closeMu.Unlock()
	return ph.closeErr
}
// OnEvent forwards evt to the process as a "hook.event" notification. Nothing
// is sent when the hook is nil, observing is disabled, or a kind filter is
// configured and evt's kind is not in it.
func (ph *ProcessHook) OnEvent(ctx context.Context, evt Event) error {
	if ph == nil || !ph.opts.Observe {
		return nil
	}

	if filtered := len(ph.observeKinds) > 0; filtered {
		if _, wanted := ph.observeKinds[evt.Kind.String()]; !wanted {
			return nil
		}
	}

	return ph.notify(ctx, "hook.event", evt)
}
// BeforeLLM asks the process to inspect or rewrite req via "hook.before_llm".
// When LLM interception is disabled (or ph is nil) req is passed through with
// a continue decision. If the process returns no request, the original req is
// used together with the process's decision.
func (ph *ProcessHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptLLM {
		return req, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookBeforeLLMResponse
	err := ph.call(ctx, "hook.before_llm", req, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	next := reply.Request
	if next == nil {
		next = req
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return next, decision, nil
}
// AfterLLM asks the process to inspect or rewrite resp via "hook.after_llm".
// When LLM interception is disabled (or ph is nil) resp is passed through with
// a continue decision. If the process returns no response, the original resp
// is used together with the process's decision.
func (ph *ProcessHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptLLM {
		return resp, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookAfterLLMResponse
	err := ph.call(ctx, "hook.after_llm", resp, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	next := reply.Response
	if next == nil {
		next = resp
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return next, decision, nil
}
// BeforeTool asks the process to inspect or rewrite a tool call via
// "hook.before_tool". When tool interception is disabled (or ph is nil) call
// is passed through with a continue decision. If the process returns no call,
// the original is used together with the process's decision.
func (ph *ProcessHook) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptTool {
		return call, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookBeforeToolResponse
	err := ph.call(ctx, "hook.before_tool", call, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	next := reply.Call
	if next == nil {
		next = call
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return next, decision, nil
}
// AfterTool asks the process to inspect or rewrite a tool result via
// "hook.after_tool". When tool interception is disabled (or ph is nil) result
// is passed through with a continue decision. If the process returns no
// result, the original is used together with the process's decision.
func (ph *ProcessHook) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
	if ph == nil || !ph.opts.InterceptTool {
		return result, HookDecision{Action: HookActionContinue}, nil
	}

	var reply processHookAfterToolResponse
	err := ph.call(ctx, "hook.after_tool", result, &reply)
	if err != nil {
		return nil, HookDecision{}, err
	}

	next := reply.Result
	if next == nil {
		next = result
	}
	decision := HookDecision{Action: reply.Action, Reason: reply.Reason}
	return next, decision, nil
}
// ApproveTool asks the process to approve a tool call via "hook.approve_tool".
// When approval is disabled (or ph is nil) the call is approved without
// contacting the process. A failed call yields a zero decision and the error.
func (ph *ProcessHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
	var decision ApprovalDecision
	if ph == nil || !ph.opts.ApproveTool {
		decision.Approved = true
		return decision, nil
	}

	if err := ph.call(ctx, "hook.approve_tool", req, &decision); err != nil {
		return ApprovalDecision{}, err
	}
	return decision, nil
}
// hello performs the "hook.hello" handshake, announcing the hook's name,
// protocol version 1, and the list of enabled modes ("observe", "llm", "tool",
// "approve", in that order).
func (ph *ProcessHook) hello(ctx context.Context) error {
	candidates := []struct {
		enabled bool
		mode    string
	}{
		{ph.opts.Observe, "observe"},
		{ph.opts.InterceptLLM, "llm"},
		{ph.opts.InterceptTool, "tool"},
		{ph.opts.ApproveTool, "approve"},
	}

	modes := make([]string, 0, len(candidates))
	for _, candidate := range candidates {
		if candidate.enabled {
			modes = append(modes, candidate.mode)
		}
	}

	params := processHookHelloParams{
		Name:    ph.name,
		Version: 1,
		Modes:   modes,
	}
	var result map[string]any
	return ph.call(ctx, "hook.hello", params, &result)
}
// notify sends a one-way message (no id, no response expected) with the given
// method. params is JSON-encoded into the message when non-nil.
func (ph *ProcessHook) notify(ctx context.Context, method string, params any) error {
	msg := processHookRPCMessage{
		JSONRPC: processHookJSONRPCVersion,
		Method:  method,
	}
	if params == nil {
		return ph.send(ctx, msg)
	}

	encoded, err := json.Marshal(params)
	if err != nil {
		return err
	}
	msg.Params = encoded
	return ph.send(ctx, msg)
}
// call sends a request with a fresh id and waits for the matching response,
// decoding its result into out when out is non-nil and a result is present.
//
// It returns an error when the hook is already closed, params cannot be
// encoded, the message cannot be sent, the response carries an error, the
// result cannot be decoded, the response channel is closed without a message,
// or ctx is done first.
func (ph *ProcessHook) call(ctx context.Context, method string, params any, out any) error {
	if ph.closed.Load() {
		return fmt.Errorf("process hook %q is closed", ph.name)
	}
	id := ph.nextID.Add(1)
	// Buffered so the reader (or failPending) can deliver without blocking
	// even if this caller has already given up.
	respCh := make(chan processHookRPCMessage, 1)
	// Register before sending so a fast response cannot be missed.
	ph.pendingMu.Lock()
	ph.pending[id] = respCh
	ph.pendingMu.Unlock()
	msg := processHookRPCMessage{
		JSONRPC: processHookJSONRPCVersion,
		ID: id,
		Method: method,
	}
	if params != nil {
		body, err := json.Marshal(params)
		if err != nil {
			ph.removePending(id)
			return err
		}
		msg.Params = body
	}
	if err := ph.send(ctx, msg); err != nil {
		ph.removePending(id)
		return err
	}
	select {
	case resp, ok := <-respCh:
		if !ok {
			return fmt.Errorf("process hook %q closed while waiting for %s", ph.name, method)
		}
		if resp.Error != nil {
			return fmt.Errorf("process hook %q %s failed: %s", ph.name, method, resp.Error.Message)
		}
		if out != nil && len(resp.Result) > 0 {
			if err := json.Unmarshal(resp.Result, out); err != nil {
				return fmt.Errorf("decode process hook %q %s result: %w", ph.name, method, err)
			}
		}
		return nil
	case <-ctx.Done():
		// Drop the registration so a late response is discarded by the reader.
		ph.removePending(id)
		return ctx.Err()
	}
}
// send encodes msg as a single JSON line and writes it to the process's
// stdin. Senders are serialized by writeMu. It fails if the hook is closed,
// the write fails, or ctx is done before the write completes.
//
// NOTE(review): the write runs in its own goroutine so that ctx can interrupt
// the wait; when ctx wins, that goroutine may still be writing after writeMu
// is released — confirm this is acceptable for the underlying pipe.
func (ph *ProcessHook) send(ctx context.Context, msg processHookRPCMessage) error {
	body, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	// Messages are newline-delimited.
	body = append(body, '\n')
	ph.writeMu.Lock()
	defer ph.writeMu.Unlock()
	if ph.closed.Load() {
		return fmt.Errorf("process hook %q is closed", ph.name)
	}
	// Buffered so the writer goroutine can always deliver its result and exit.
	done := make(chan error, 1)
	go func() {
		_, writeErr := ph.stdin.Write(body)
		done <- writeErr
	}()
	select {
	case err := <-done:
		if err != nil {
			return fmt.Errorf("write process hook %q message: %w", ph.name, err)
		}
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// readLoop reads newline-delimited JSON messages from the process's stdout and
// delivers each response to the caller waiting on its id. Lines that fail to
// decode are logged and skipped; messages with id zero, or with an id nobody
// is waiting for, are ignored.
//
// The loop ends when stdout reaches EOF or the scanner fails (for example on a
// line longer than processHookReadBufferSize). A scanner failure is logged so
// that a dead response reader does not go unnoticed; the "file already closed"
// error produced when the pipe is closed during shutdown is not reported.
func (ph *ProcessHook) readLoop(stdout io.Reader) {
	scanner := bufio.NewScanner(stdout)
	scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
	for scanner.Scan() {
		var msg processHookRPCMessage
		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
			logger.WarnCF("hooks", "Failed to decode process hook message", map[string]any{
				"hook":  ph.name,
				"error": err.Error(),
			})
			continue
		}
		if msg.ID == 0 {
			continue
		}
		// Claim the pending entry under the lock, then deliver outside it.
		// Removing the entry first guarantees the channel has a single owner.
		ph.pendingMu.Lock()
		respCh, ok := ph.pending[msg.ID]
		if ok {
			delete(ph.pending, msg.ID)
		}
		ph.pendingMu.Unlock()
		if ok {
			respCh <- msg
			close(respCh)
		}
	}
	if err := scanner.Err(); err != nil && !errors.Is(err, os.ErrClosed) {
		logger.WarnCF("hooks", "Process hook stdout reader stopped", map[string]any{
			"hook":  ph.name,
			"error": err.Error(),
		})
	}
}
// readStderr logs every line the process writes to stderr as a warning.
//
// The loop ends when stderr reaches EOF or the scanner fails (for example on a
// line longer than processHookReadBufferSize). A scanner failure is logged
// because stderr is no longer drained afterwards; the "file already closed"
// error produced when the pipe is closed during shutdown is not reported.
func (ph *ProcessHook) readStderr(stderr io.Reader) {
	scanner := bufio.NewScanner(stderr)
	scanner.Buffer(make([]byte, 0, 16*1024), processHookReadBufferSize)
	for scanner.Scan() {
		logger.WarnCF("hooks", "Process hook stderr", map[string]any{
			"hook":   ph.name,
			"stderr": scanner.Text(),
		})
	}
	if err := scanner.Err(); err != nil && !errors.Is(err, os.ErrClosed) {
		logger.WarnCF("hooks", "Process hook stderr reader stopped", map[string]any{
			"hook":  ph.name,
			"error": err.Error(),
		})
	}
}
// waitLoop blocks until the process exits, records the wait result as the
// close error, fails every request still awaiting a response, and finally
// closes done to signal that the process is gone.
//
// NOTE(review): cmd.Wait runs concurrently with the goroutines reading the
// stdout/stderr pipes; os/exec documents that Wait closes those pipes, so
// output written just before exit may be lost — confirm this is acceptable.
func (ph *ProcessHook) waitLoop() {
	err := ph.cmd.Wait()
	ph.closeMu.Lock()
	ph.closeErr = err
	ph.closeMu.Unlock()
	// Unblock waiters before announcing completion.
	ph.failPending(err)
	close(ph.done)
}
// failPending completes every in-flight request with an error response (code
// -32000). The message is err's text, or "process exited" when err is nil.
// Each pending channel receives the response and is then closed.
func (ph *ProcessHook) failPending(err error) {
	reason := "process exited"
	if err != nil {
		reason = err.Error()
	}
	failure := processHookRPCMessage{
		Error: &processHookRPCError{Code: -32000, Message: reason},
	}

	ph.pendingMu.Lock()
	defer ph.pendingMu.Unlock()
	for id, waiter := range ph.pending {
		delete(ph.pending, id)
		waiter <- failure
		close(waiter)
	}
}
// removePending drops the registration for request id and closes its response
// channel. It does nothing if the id is no longer registered.
func (ph *ProcessHook) removePending(id uint64) {
	ph.pendingMu.Lock()
	waiter, registered := ph.pending[id]
	if registered {
		delete(ph.pending, id)
		close(waiter)
	}
	ph.pendingMu.Unlock()
}
// MountProcessHook starts a process hook with the given options and mounts it
// on the agent loop under name. If mounting fails the freshly started process
// is closed again and the mount error is returned.
func (al *AgentLoop) MountProcessHook(ctx context.Context, name string, opts ProcessHookOptions) error {
	if al == nil {
		return fmt.Errorf("agent loop is nil")
	}

	hook, err := NewProcessHook(ctx, name, opts)
	if err != nil {
		return err
	}

	mountErr := al.MountHook(HookRegistration{
		Name:   name,
		Source: HookSourceProcess,
		Hook:   hook,
	})
	if mountErr == nil {
		return nil
	}
	_ = hook.Close()
	return mountErr
}
// newProcessHookObserveKinds converts a list of event kind names into a set,
// ignoring blank entries. It returns nil when the list contains no non-blank
// names.
func newProcessHookObserveKinds(kinds []string) map[string]struct{} {
	var set map[string]struct{}
	for _, kind := range kinds {
		if kind == "" {
			continue
		}
		if set == nil {
			// Allocate lazily so an all-blank list yields nil.
			set = make(map[string]struct{}, len(kinds))
		}
		set[kind] = struct{}{}
	}
	return set
}
+339
View File
@@ -0,0 +1,339 @@
package agent
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// TestProcessHook_HelperProcess is not a real test: it is the entry point used
// when the test binary is re-executed as a hook process. It returns
// immediately unless PICOCLAW_HOOK_HELPER=1 is set; otherwise it runs the
// helper and exits the process with status 0, or 1 after printing the error
// to stderr.
func TestProcessHook_HelperProcess(t *testing.T) {
	if os.Getenv("PICOCLAW_HOOK_HELPER") != "1" {
		return
	}
	if err := runProcessHookHelper(); err != nil {
		fmt.Fprintln(os.Stderr, err.Error())
		os.Exit(1)
	}
	os.Exit(0)
}
// TestAgentLoop_MountProcessHook_LLMAndObserver mounts a process hook with
// observing and LLM interception enabled and checks that the hook rewrites
// both the request model and the response content, and that a "turn_end"
// event is written to the helper's event log.
func TestAgentLoop_MountProcessHook_LLMAndObserver(t *testing.T) {
	provider := &llmHookTestProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()
	eventLog := filepath.Join(t.TempDir(), "events.log")
	if err := al.MountProcessHook(context.Background(), "ipc-llm", ProcessHookOptions{
		Command: processHookHelperCommand(),
		Env: processHookHelperEnv("rewrite", eventLog),
		Observe: true,
		InterceptLLM: true,
	}); err != nil {
		t.Fatalf("MountProcessHook failed: %v", err)
	}
	resp, err := al.runAgentLoop(context.Background(), agent, processOptions{
		SessionKey: "session-1",
		Channel: "cli",
		ChatID: "direct",
		UserMessage: "hello",
		DefaultResponse: defaultResponse,
		EnableSummary: false,
		SendResponse: false,
	})
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if resp != "provider content|ipc" {
		t.Fatalf("expected process-hooked llm content, got %q", resp)
	}
	provider.mu.Lock()
	lastModel := provider.lastModel
	provider.mu.Unlock()
	if lastModel != "process-model" {
		t.Fatalf("expected process model, got %q", lastModel)
	}
	// Events are delivered asynchronously, so poll the log file.
	waitForFileContains(t, eventLog, "turn_end")
}
// TestAgentLoop_MountProcessHook_ToolRewrite mounts the helper process as a
// tool interceptor in "rewrite" mode and checks that both the tool arguments
// (text -> "ipc") and the tool result ("ipc:" prefix) were rewritten.
func TestAgentLoop_MountProcessHook_ToolRewrite(t *testing.T) {
	provider := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})

	ctx := context.Background()
	hookOpts := ProcessHookOptions{
		Command:       processHookHelperCommand(),
		Env:           processHookHelperEnv("rewrite", ""),
		InterceptTool: true,
	}
	if err := al.MountProcessHook(ctx, "ipc-tool", hookOpts); err != nil {
		t.Fatalf("MountProcessHook failed: %v", err)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	reply, err := al.runAgentLoop(ctx, agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if reply != "ipc:ipc" {
		t.Fatalf("expected rewritten process-hook tool result, got %q", reply)
	}
}
// blockedToolProvider is a scripted LLM provider for the approval tests. Its
// first Chat call requests the "blocked_tool" tool; every later call echoes
// the content of the last message it was given.
type blockedToolProvider struct {
	// calls counts Chat invocations.
	calls int
}

// Chat returns a single "blocked_tool" tool call on the first invocation and
// otherwise replies with the content of the final entry in messages.
func (p *blockedToolProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.calls++
	if p.calls != 1 {
		last := messages[len(messages)-1]
		return &providers.LLMResponse{Content: last.Content}, nil
	}

	toolCall := providers.ToolCall{
		ID:        "call-1",
		Name:      "blocked_tool",
		Arguments: map[string]any{},
	}
	return &providers.LLMResponse{
		ToolCalls: []providers.ToolCall{toolCall},
	}, nil
}

// GetDefaultModel reports the fixed model name of this test provider.
func (p *blockedToolProvider) GetDefaultModel() string {
	const modelName = "blocked-tool-provider"
	return modelName
}
// TestAgentLoop_MountProcessHook_ApprovalDeny mounts the helper process as a
// tool approver in "deny" mode and checks that the denial reason reaches both
// the turn's response and the ToolExecSkipped event payload.
func TestAgentLoop_MountProcessHook_ApprovalDeny(t *testing.T) {
	provider := &blockedToolProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	ctx := context.Background()
	hookOpts := ProcessHookOptions{
		Command:     processHookHelperCommand(),
		Env:         processHookHelperEnv("deny", ""),
		ApproveTool: true,
	}
	if err := al.MountProcessHook(ctx, "ipc-approval", hookOpts); err != nil {
		t.Fatalf("MountProcessHook failed: %v", err)
	}

	// Subscribe before running the turn so the skipped event is captured.
	sub := al.SubscribeEvents(16)
	defer al.UnsubscribeEvents(sub.ID)

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run blocked tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	reply, err := al.runAgentLoop(ctx, agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}

	expected := "Tool execution denied by approval hook: blocked by ipc hook"
	if reply != expected {
		t.Fatalf("expected %q, got %q", expected, reply)
	}

	skippedEvt, found := findEvent(collectEventStream(sub.C), EventKindToolExecSkipped)
	if !found {
		t.Fatal("expected tool skipped event")
	}
	payload, isSkipped := skippedEvt.Payload.(ToolExecSkippedPayload)
	if !isSkipped {
		t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
	}
	if payload.Reason != expected {
		t.Fatalf("expected reason %q, got %q", expected, payload.Reason)
	}
}
// processHookHelperCommand builds the argv that re-executes the current test
// binary restricted to TestProcessHook_HelperProcess, which acts as the hook
// subprocess.
func processHookHelperCommand() []string {
	testBinary := os.Args[0]
	return []string{
		testBinary,
		"-test.run=TestProcessHook_HelperProcess",
		"--",
	}
}
// processHookHelperEnv returns the environment entries that switch the test
// binary into helper mode and select its behavior. The event-log variable is
// only included when eventLog is non-empty.
func processHookHelperEnv(mode, eventLog string) []string {
	vars := make([]string, 0, 3)
	vars = append(vars, "PICOCLAW_HOOK_HELPER=1", "PICOCLAW_HOOK_MODE="+mode)
	if eventLog == "" {
		return vars
	}
	return append(vars, "PICOCLAW_HOOK_EVENT_LOG="+eventLog)
}
// waitForFileContains polls path every 20ms for up to 3 seconds and returns as
// soon as the file can be read and contains substring. If the deadline passes
// first, the test fails with the file's current content in the message.
func waitForFileContains(t *testing.T, path, substring string) {
	t.Helper()

	const (
		maxWait      = 3 * time.Second
		pollInterval = 20 * time.Millisecond
	)
	for stopAt := time.Now().Add(maxWait); time.Now().Before(stopAt); time.Sleep(pollInterval) {
		content, readErr := os.ReadFile(path)
		if readErr == nil && strings.Contains(string(content), substring) {
			return
		}
	}

	// Read errors are ignored here; the content is only used for diagnostics.
	lastContent, _ := os.ReadFile(path)
	t.Fatalf("timed out waiting for %q in %s; current content: %q", substring, path, string(lastContent))
}
// runProcessHookHelper implements the subprocess side of the process-hook
// protocol for these tests. It reads one JSON message per line from stdin
// until EOF. Messages without an ID are notifications: "hook.event"
// notifications are recorded in the file named by PICOCLAW_HOOK_EVENT_LOG
// (when set) and never answered. Messages with an ID are dispatched to
// handleProcessHookRequest, using the mode from PICOCLAW_HOOK_MODE, and
// answered on stdout.
//
// It returns the first decode, encode or scanner error.
func runProcessHookHelper() error {
	mode := os.Getenv("PICOCLAW_HOOK_MODE")
	eventLog := os.Getenv("PICOCLAW_HOOK_EVENT_LOG")

	scanner := bufio.NewScanner(os.Stdin)
	scanner.Buffer(make([]byte, 0, 64*1024), processHookReadBufferSize)
	encoder := json.NewEncoder(os.Stdout)

	for scanner.Scan() {
		var msg processHookRPCMessage
		if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil {
			return err
		}

		if msg.ID == 0 {
			if msg.Method == "hook.event" && eventLog != "" {
				// Logging is best effort: a failed write is reported on stderr
				// but does not stop the helper.
				if err := appendProcessHookEvent(eventLog, msg.Params); err != nil {
					fmt.Fprintln(os.Stderr, "event log:", err.Error())
				}
			}
			continue
		}

		result, rpcErr := handleProcessHookRequest(mode, msg)
		resp := processHookRPCMessage{
			JSONRPC: processHookJSONRPCVersion,
			ID:      msg.ID,
		}
		switch {
		case rpcErr != nil:
			resp.Error = rpcErr
		case result != nil:
			body, err := json.Marshal(result)
			if err != nil {
				return err
			}
			resp.Result = body
		default:
			// A handler with nothing to say still returns an empty object.
			resp.Result = []byte("{}")
		}
		if err := encoder.Encode(resp); err != nil {
			return err
		}
	}
	return scanner.Err()
}

// appendProcessHookEvent appends the name of the event kind found in params to
// the log file at path, one name per line. Appending (instead of rewriting the
// file) keeps earlier events visible, so a reader polling for a specific kind
// cannot miss it because a later event replaced it. Payloads that are not
// valid JSON or carry no numeric "Kind" are ignored.
func appendProcessHookEvent(path string, params []byte) error {
	var evt map[string]any
	if err := json.Unmarshal(params, &evt); err != nil {
		return nil
	}
	rawKind, ok := evt["Kind"].(float64)
	if !ok {
		return nil
	}

	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
	if err != nil {
		return err
	}
	if _, err := f.WriteString(EventKind(rawKind).String() + "\n"); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return err
	}
	return f.Close()
}
// handleProcessHookRequest produces the helper's answer to one hook request.
//
// In "rewrite" mode the before/after LLM and tool stages answer with a modify
// action: the model is set to "process-model", "|ipc" is appended to the
// response content, the tool argument "text" is set to "ipc", and "ipc:" is
// prepended to the tool result's for_llm text. In any other mode those stages
// answer with a continue action. "hook.approve_tool" denies only in "deny"
// mode. Unknown methods yield a JSON-RPC "method not found" error (-32601),
// and params that cannot be decoded as a JSON object yield "invalid params"
// (-32602) instead of panicking on a nil map.
func handleProcessHookRequest(mode string, msg processHookRPCMessage) (any, *processHookRPCError) {
	// decodeParams parses the request params into a map that is always safe to
	// write to: JSON null decodes to an empty, non-nil map.
	decodeParams := func() (map[string]any, *processHookRPCError) {
		var params map[string]any
		if err := json.Unmarshal(msg.Params, &params); err != nil {
			return nil, &processHookRPCError{
				Code:    -32602,
				Message: "invalid params: " + err.Error(),
			}
		}
		if params == nil {
			params = map[string]any{}
		}
		return params, nil
	}
	rewriting := mode == "rewrite"

	switch msg.Method {
	case "hook.hello":
		return map[string]any{"ok": true}, nil

	case "hook.before_llm":
		if !rewriting {
			return map[string]any{"action": HookActionContinue}, nil
		}
		req, rpcErr := decodeParams()
		if rpcErr != nil {
			return nil, rpcErr
		}
		req["model"] = "process-model"
		return map[string]any{
			"action":  HookActionModify,
			"request": req,
		}, nil

	case "hook.after_llm":
		if !rewriting {
			return map[string]any{"action": HookActionContinue}, nil
		}
		resp, rpcErr := decodeParams()
		if rpcErr != nil {
			return nil, rpcErr
		}
		if rawResponse, ok := resp["response"].(map[string]any); ok {
			if content, ok := rawResponse["content"].(string); ok {
				rawResponse["content"] = content + "|ipc"
			}
		}
		return map[string]any{
			"action":   HookActionModify,
			"response": resp,
		}, nil

	case "hook.before_tool":
		if !rewriting {
			return map[string]any{"action": HookActionContinue}, nil
		}
		call, rpcErr := decodeParams()
		if rpcErr != nil {
			return nil, rpcErr
		}
		rawArgs, ok := call["arguments"].(map[string]any)
		if !ok || rawArgs == nil {
			rawArgs = map[string]any{}
		}
		rawArgs["text"] = "ipc"
		call["arguments"] = rawArgs
		return map[string]any{
			"action": HookActionModify,
			"call":   call,
		}, nil

	case "hook.after_tool":
		if !rewriting {
			return map[string]any{"action": HookActionContinue}, nil
		}
		result, rpcErr := decodeParams()
		if rpcErr != nil {
			return nil, rpcErr
		}
		if rawResult, ok := result["result"].(map[string]any); ok {
			if forLLM, ok := rawResult["for_llm"].(string); ok {
				rawResult["for_llm"] = "ipc:" + forLLM
			}
		}
		return map[string]any{
			"action": HookActionModify,
			"result": result,
		}, nil

	case "hook.approve_tool":
		if mode == "deny" {
			return ApprovalDecision{
				Approved: false,
				Reason:   "blocked by ipc hook",
			}, nil
		}
		return ApprovalDecision{Approved: true}, nil

	default:
		return nil, &processHookRPCError{
			Code:    -32601,
			Message: "method not found",
		}
	}
}
+809
View File
@@ -0,0 +1,809 @@
package agent
import (
"context"
"fmt"
"io"
"sort"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// Defaults used by NewHookManager.
const (
// defaultHookObserverTimeout is the initial time budget for one EventObserver call.
defaultHookObserverTimeout = 500 * time.Millisecond
// defaultHookInterceptorTimeout is the initial time budget for one LLM or tool interceptor call.
defaultHookInterceptorTimeout = 5 * time.Second
// defaultHookApprovalTimeout is the initial time budget for one ToolApprover call.
defaultHookApprovalTimeout = 60 * time.Second
// hookObserverBufferSize is the buffer size requested when the manager subscribes to the event bus.
hookObserverBufferSize = 64
)
// HookAction names what a hook asks the agent loop to do after a stage.
type HookAction string

// Actions a hook may return. Which of them a stage honors is decided by the
// HookManager method for that stage.
const (
	HookActionContinue  HookAction = "continue"
	HookActionModify    HookAction = "modify"
	HookActionDenyTool  HookAction = "deny_tool"
	HookActionAbortTurn HookAction = "abort_turn"
	HookActionHardAbort HookAction = "hard_abort"
)

// HookDecision is the verdict an interceptor returns alongside its (possibly
// modified) payload.
type HookDecision struct {
	Action HookAction `json:"action"`
	Reason string     `json:"reason,omitempty"`
}

// normalizedAction returns the decision's action, treating an unset (empty)
// action as HookActionContinue.
func (d HookDecision) normalizedAction() HookAction {
	switch d.Action {
	case "":
		return HookActionContinue
	default:
		return d.Action
	}
}
// ApprovalDecision is the answer of a ToolApprover: whether the tool call may
// run and, optionally, why.
type ApprovalDecision struct {
// Approved is true when the tool call is allowed.
Approved bool `json:"approved"`
// Reason is an optional explanation; omitted from JSON when empty.
Reason string `json:"reason,omitempty"`
}
// HookSource identifies where a hook implementation lives. The numeric order
// matters: HookManager sorts hooks by Source first, so in-process hooks come
// before process hooks.
type HookSource uint8

const (
	HookSourceInProcess HookSource = iota
	HookSourceProcess
)

// HookRegistration describes one hook to mount on a HookManager.
type HookRegistration struct {
	// Name identifies the hook; mounting under an existing name replaces it.
	Name string
	// Priority orders hooks of the same Source; lower values sort first.
	Priority int
	// Source tells whether the hook is in-process or a process hook.
	Source HookSource
	// Hook is the implementation. The manager type-asserts it against
	// EventObserver, LLMInterceptor, ToolInterceptor and ToolApprover.
	Hook any
}

// NamedHook wraps hook in an in-process registration called name, with the
// default (zero) priority.
func NamedHook(name string, hook any) HookRegistration {
	var reg HookRegistration
	reg.Name = name
	reg.Source = HookSourceInProcess
	reg.Hook = hook
	return reg
}
// EventObserver receives agent events. A hook implementing it is called from
// the manager's event dispatch goroutine; a returned error is only logged.
type EventObserver interface {
OnEvent(ctx context.Context, evt Event) error
}
// LLMInterceptor can inspect or rewrite an LLM request before it is sent and
// the response after it returns. Each method returns the payload to continue
// with, a decision, and an error.
type LLMInterceptor interface {
BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision, error)
AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision, error)
}
// ToolInterceptor can inspect or rewrite a tool call before execution and its
// result afterwards.
type ToolInterceptor interface {
BeforeTool(ctx context.Context, call *ToolCallHookRequest) (*ToolCallHookRequest, HookDecision, error)
AfterTool(ctx context.Context, result *ToolResultHookResponse) (*ToolResultHookResponse, HookDecision, error)
}
// ToolApprover decides whether a tool call may run.
type ToolApprover interface {
ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error)
}
// LLMHookRequest is the payload handed to LLMInterceptor.BeforeLLM.
type LLMHookRequest struct {
	Meta             EventMeta                  `json:"meta"`
	Model            string                     `json:"model"`
	Messages         []providers.Message        `json:"messages,omitempty"`
	Tools            []providers.ToolDefinition `json:"tools,omitempty"`
	Options          map[string]any             `json:"options,omitempty"`
	Channel          string                     `json:"channel,omitempty"`
	ChatID           string                     `json:"chat_id,omitempty"`
	GracefulTerminal bool                       `json:"graceful_terminal,omitempty"`
}

// Clone returns a copy of the request whose Messages, Tools and Options are
// duplicated through the package's clone helpers, so the copy can be handed to
// a hook separately from the original. A nil receiver yields nil.
func (r *LLMHookRequest) Clone() *LLMHookRequest {
	if r == nil {
		return nil
	}
	dup := new(LLMHookRequest)
	*dup = *r
	dup.Messages = cloneProviderMessages(r.Messages)
	dup.Tools = cloneToolDefinitions(r.Tools)
	dup.Options = cloneStringAnyMap(r.Options)
	return dup
}
// LLMHookResponse is the payload handed to LLMInterceptor.AfterLLM.
type LLMHookResponse struct {
	Meta     EventMeta              `json:"meta"`
	Model    string                 `json:"model"`
	Response *providers.LLMResponse `json:"response,omitempty"`
	Channel  string                 `json:"channel,omitempty"`
	ChatID   string                 `json:"chat_id,omitempty"`
}

// Clone returns a copy of the response wrapper whose Response is duplicated
// with cloneLLMResponse. A nil receiver yields nil.
func (r *LLMHookResponse) Clone() *LLMHookResponse {
	if r == nil {
		return nil
	}
	dup := new(LLMHookResponse)
	*dup = *r
	dup.Response = cloneLLMResponse(r.Response)
	return dup
}
// ToolCallHookRequest is the payload handed to ToolInterceptor.BeforeTool.
type ToolCallHookRequest struct {
	Meta      EventMeta      `json:"meta"`
	Tool      string         `json:"tool"`
	Arguments map[string]any `json:"arguments,omitempty"`
	Channel   string         `json:"channel,omitempty"`
	ChatID    string         `json:"chat_id,omitempty"`
}

// Clone returns a copy of the request with its Arguments map duplicated via
// cloneStringAnyMap. A nil receiver yields nil.
func (r *ToolCallHookRequest) Clone() *ToolCallHookRequest {
	if r == nil {
		return nil
	}
	dup := new(ToolCallHookRequest)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	return dup
}
// ToolApprovalRequest is the payload handed to ToolApprover.ApproveTool.
type ToolApprovalRequest struct {
	Meta      EventMeta      `json:"meta"`
	Tool      string         `json:"tool"`
	Arguments map[string]any `json:"arguments,omitempty"`
	Channel   string         `json:"channel,omitempty"`
	ChatID    string         `json:"chat_id,omitempty"`
}

// Clone returns a copy of the request with its Arguments map duplicated via
// cloneStringAnyMap. A nil receiver yields nil.
func (r *ToolApprovalRequest) Clone() *ToolApprovalRequest {
	if r == nil {
		return nil
	}
	dup := new(ToolApprovalRequest)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	return dup
}
// ToolResultHookResponse is the payload handed to ToolInterceptor.AfterTool.
type ToolResultHookResponse struct {
	Meta      EventMeta         `json:"meta"`
	Tool      string            `json:"tool"`
	Arguments map[string]any    `json:"arguments,omitempty"`
	Result    *tools.ToolResult `json:"result,omitempty"`
	Duration  time.Duration     `json:"duration"`
	Channel   string            `json:"channel,omitempty"`
	ChatID    string            `json:"chat_id,omitempty"`
}

// Clone returns a copy of the response with Arguments duplicated via
// cloneStringAnyMap and Result duplicated via cloneToolResult. A nil receiver
// yields nil.
func (r *ToolResultHookResponse) Clone() *ToolResultHookResponse {
	if r == nil {
		return nil
	}
	dup := new(ToolResultHookResponse)
	*dup = *r
	dup.Arguments = cloneStringAnyMap(r.Arguments)
	dup.Result = cloneToolResult(r.Result)
	return dup
}
// HookManager holds the mounted hooks, forwards event-bus events to observer
// hooks, and runs interceptor and approval hooks with per-stage timeouts.
type HookManager struct {
// eventBus is the bus the manager subscribes to; it may be nil.
eventBus *EventBus
// Per-call time budgets; set to the defaults by NewHookManager and
// adjustable through ConfigureTimeouts.
observerTimeout time.Duration
interceptorTimeout time.Duration
approvalTimeout time.Duration
// mu is held by Mount, Unmount, snapshotHooks and closeAllHooks while
// they access hooks and ordered.
mu sync.RWMutex
// hooks maps hook name to its registration.
hooks map[string]HookRegistration
// ordered is the sorted view of hooks, rebuilt by rebuildOrdered.
ordered []HookRegistration
// sub is the manager's event-bus subscription (zero value without a bus).
sub EventSubscription
// done is closed when dispatchEvents returns, or immediately by
// NewHookManager when there is no event bus.
done chan struct{}
// closeOnce makes Close run its shutdown sequence only once.
closeOnce sync.Once
}
// NewHookManager creates a manager with the default timeouts and no hooks.
// With a non-nil event bus it subscribes to the bus and starts the goroutine
// that forwards events to observer hooks. Without a bus no goroutine is
// started and the done channel is closed right away, so Close does not block.
func NewHookManager(eventBus *EventBus) *HookManager {
	hm := &HookManager{
		eventBus:           eventBus,
		observerTimeout:    defaultHookObserverTimeout,
		interceptorTimeout: defaultHookInterceptorTimeout,
		approvalTimeout:    defaultHookApprovalTimeout,
		hooks:              make(map[string]HookRegistration),
		done:               make(chan struct{}),
	}

	if eventBus != nil {
		hm.sub = eventBus.Subscribe(hookObserverBufferSize)
		go hm.dispatchEvents()
		return hm
	}

	close(hm.done)
	return hm
}
// Close shuts the manager down. The shutdown sequence runs at most once
// (guarded by closeOnce): unsubscribe from the event bus if there is one, wait
// for done to be closed, then close and remove every mounted hook. Calling
// Close on a nil manager is a no-op.
func (hm *HookManager) Close() {
if hm == nil {
return
}
hm.closeOnce.Do(func() {
if hm.eventBus != nil {
// NOTE(review): presumably Unsubscribe closes hm.sub.C, which is what
// lets dispatchEvents leave its range loop — confirm in EventBus.
hm.eventBus.Unsubscribe(hm.sub.ID)
}
// done is closed by dispatchEvents on return, or up front by
// NewHookManager when no event bus was supplied.
<-hm.done
hm.closeAllHooks()
})
}
// ConfigureTimeouts overrides the observer, interceptor and approval
// timeouts. Values that are zero or negative leave the corresponding timeout
// unchanged. Calling it on a nil manager is a no-op.
func (hm *HookManager) ConfigureTimeouts(observer, interceptor, approval time.Duration) {
	if hm == nil {
		return
	}

	overrides := []struct {
		target *time.Duration
		value  time.Duration
	}{
		{&hm.observerTimeout, observer},
		{&hm.interceptorTimeout, interceptor},
		{&hm.approvalTimeout, approval},
	}
	for _, o := range overrides {
		if o.value > 0 {
			*o.target = o.value
		}
	}
}
// Mount registers reg under reg.Name and rebuilds the hook order. A hook
// already mounted under the same name is closed (when it implements io.Closer)
// and replaced. It returns an error for a nil manager, an empty name or a nil
// hook.
func (hm *HookManager) Mount(reg HookRegistration) error {
	switch {
	case hm == nil:
		return fmt.Errorf("hook manager is nil")
	case reg.Name == "":
		return fmt.Errorf("hook name is required")
	case reg.Hook == nil:
		return fmt.Errorf("hook %q is nil", reg.Name)
	}

	hm.mu.Lock()
	defer hm.mu.Unlock()

	if previous, replaced := hm.hooks[reg.Name]; replaced {
		closeHookIfPossible(previous.Hook)
	}
	hm.hooks[reg.Name] = reg
	hm.rebuildOrdered()
	return nil
}
// Unmount removes the hook registered under name, closing it first when it
// implements io.Closer, and rebuilds the hook order. A nil manager or an empty
// name is a no-op.
func (hm *HookManager) Unmount(name string) {
	if hm == nil || name == "" {
		return
	}

	hm.mu.Lock()
	defer hm.mu.Unlock()

	if mounted, found := hm.hooks[name]; found {
		closeHookIfPossible(mounted.Hook)
		delete(hm.hooks, name)
	}
	hm.rebuildOrdered()
}
// dispatchEvents forwards every event received on the manager's subscription
// to each mounted hook that implements EventObserver, in hook order. It runs
// until the subscription channel is closed and closes done on return.
func (hm *HookManager) dispatchEvents() {
	defer close(hm.done)

	for evt := range hm.sub.C {
		// Take a fresh snapshot per event so mounts and unmounts take effect.
		for _, reg := range hm.snapshotHooks() {
			if observer, isObserver := reg.Hook.(EventObserver); isObserver {
				hm.runObserver(reg.Name, observer, evt)
			}
		}
	}
}
// BeforeLLM runs the BeforeLLM stage of every mounted LLMInterceptor in hook
// order. Each interceptor receives its own clone of the current request. With
// a continue or modify decision, a non-nil returned request becomes the input
// of the next interceptor. An abort_turn or hard_abort decision stops the
// chain and is returned with the request as it stood before that interceptor.
// Interceptors whose call did not succeed are skipped, and any other action is
// logged as unsupported. A nil manager or request returns req unchanged with a
// continue decision.
func (hm *HookManager) BeforeLLM(ctx context.Context, req *LLMHookRequest) (*LLMHookRequest, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || req == nil {
		return req, proceed
	}

	current := req.Clone()
	for _, reg := range hm.snapshotHooks() {
		interceptor, isInterceptor := reg.Hook.(LLMInterceptor)
		if !isInterceptor {
			continue
		}
		updated, decision, succeeded := hm.callBeforeLLM(ctx, reg.Name, interceptor, current.Clone())
		if !succeeded {
			continue
		}
		switch decision.normalizedAction() {
		case HookActionAbortTurn, HookActionHardAbort:
			return current, decision
		case HookActionContinue, HookActionModify:
			if updated != nil {
				current = updated
			}
		default:
			hm.logUnsupportedAction(reg.Name, "before_llm", decision.Action)
		}
	}
	return current, proceed
}
// AfterLLM runs the AfterLLM stage of every mounted LLMInterceptor in hook
// order. Each interceptor receives its own clone of the current response. With
// a continue or modify decision, a non-nil returned response becomes the input
// of the next interceptor. An abort_turn or hard_abort decision stops the
// chain and is returned with the response as it stood before that interceptor.
// Interceptors whose call did not succeed are skipped, and any other action is
// logged as unsupported. A nil manager or response returns resp unchanged with
// a continue decision.
func (hm *HookManager) AfterLLM(ctx context.Context, resp *LLMHookResponse) (*LLMHookResponse, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || resp == nil {
		return resp, proceed
	}

	current := resp.Clone()
	for _, reg := range hm.snapshotHooks() {
		interceptor, isInterceptor := reg.Hook.(LLMInterceptor)
		if !isInterceptor {
			continue
		}
		updated, decision, succeeded := hm.callAfterLLM(ctx, reg.Name, interceptor, current.Clone())
		if !succeeded {
			continue
		}
		switch decision.normalizedAction() {
		case HookActionAbortTurn, HookActionHardAbort:
			return current, decision
		case HookActionContinue, HookActionModify:
			if updated != nil {
				current = updated
			}
		default:
			hm.logUnsupportedAction(reg.Name, "after_llm", decision.Action)
		}
	}
	return current, proceed
}
// BeforeTool runs the BeforeTool stage of every mounted ToolInterceptor in
// hook order. Each interceptor receives its own clone of the current call.
// With a continue or modify decision, a non-nil returned call becomes the
// input of the next interceptor. A deny_tool, abort_turn or hard_abort
// decision stops the chain and is returned with the call as it stood before
// that interceptor. Interceptors whose call did not succeed are skipped, and
// any other action is logged as unsupported. A nil manager or call returns
// call unchanged with a continue decision.
func (hm *HookManager) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || call == nil {
		return call, proceed
	}

	current := call.Clone()
	for _, reg := range hm.snapshotHooks() {
		interceptor, isInterceptor := reg.Hook.(ToolInterceptor)
		if !isInterceptor {
			continue
		}
		updated, decision, succeeded := hm.callBeforeTool(ctx, reg.Name, interceptor, current.Clone())
		if !succeeded {
			continue
		}
		switch decision.normalizedAction() {
		case HookActionDenyTool, HookActionAbortTurn, HookActionHardAbort:
			return current, decision
		case HookActionContinue, HookActionModify:
			if updated != nil {
				current = updated
			}
		default:
			hm.logUnsupportedAction(reg.Name, "before_tool", decision.Action)
		}
	}
	return current, proceed
}
// AfterTool runs the AfterTool stage of every mounted ToolInterceptor in hook
// order. Each interceptor receives its own clone of the current result. With a
// continue or modify decision, a non-nil returned result becomes the input of
// the next interceptor. An abort_turn or hard_abort decision stops the chain
// and is returned with the result as it stood before that interceptor.
// Interceptors whose call did not succeed are skipped. Any other action,
// including deny_tool, is logged as unsupported for this stage. A nil manager
// or result returns result unchanged with a continue decision.
func (hm *HookManager) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision) {
	proceed := HookDecision{Action: HookActionContinue}
	if hm == nil || result == nil {
		return result, proceed
	}

	current := result.Clone()
	for _, reg := range hm.snapshotHooks() {
		interceptor, isInterceptor := reg.Hook.(ToolInterceptor)
		if !isInterceptor {
			continue
		}
		updated, decision, succeeded := hm.callAfterTool(ctx, reg.Name, interceptor, current.Clone())
		if !succeeded {
			continue
		}
		switch decision.normalizedAction() {
		case HookActionAbortTurn, HookActionHardAbort:
			return current, decision
		case HookActionContinue, HookActionModify:
			if updated != nil {
				current = updated
			}
		default:
			hm.logUnsupportedAction(reg.Name, "after_tool", decision.Action)
		}
	}
	return current, proceed
}
// ApproveTool asks every mounted ToolApprover, in hook order, whether the tool
// call may run. The first denial is returned as is. An approver whose call did
// not complete counts as a denial with a "failed" reason. The call is approved
// when there is no manager, no request, no approver, or every approver agrees.
func (hm *HookManager) ApproveTool(ctx context.Context, req *ToolApprovalRequest) ApprovalDecision {
	if hm != nil && req != nil {
		for _, reg := range hm.snapshotHooks() {
			approver, isApprover := reg.Hook.(ToolApprover)
			if !isApprover {
				continue
			}
			verdict, completed := hm.callApproveTool(ctx, reg.Name, approver, req.Clone())
			switch {
			case !completed:
				return ApprovalDecision{
					Approved: false,
					Reason:   fmt.Sprintf("tool approval hook %q failed", reg.Name),
				}
			case !verdict.Approved:
				return verdict
			}
		}
	}
	return ApprovalDecision{Approved: true}
}
// rebuildOrdered refills hm.ordered from hm.hooks and sorts it by Source, then
// Priority (lower first), then Name. Mount and Unmount call it while holding
// hm.mu.
func (hm *HookManager) rebuildOrdered() {
	hm.ordered = hm.ordered[:0]
	for name := range hm.hooks {
		hm.ordered = append(hm.ordered, hm.hooks[name])
	}

	sort.SliceStable(hm.ordered, func(i, j int) bool {
		left, right := hm.ordered[i], hm.ordered[j]
		switch {
		case left.Source != right.Source:
			return left.Source < right.Source
		case left.Priority != right.Priority:
			return left.Priority < right.Priority
		default:
			return left.Name < right.Name
		}
	})
}
// snapshotHooks returns a copy of the ordered hook list taken under the read
// lock, so callers can iterate it without holding hm.mu.
func (hm *HookManager) snapshotHooks() []HookRegistration {
	hm.mu.RLock()
	defer hm.mu.RUnlock()

	out := make([]HookRegistration, 0, len(hm.ordered))
	out = append(out, hm.ordered...)
	return out
}
// closeAllHooks closes every mounted hook that implements io.Closer, removes
// all registrations and clears the ordered list, all under the write lock.
func (hm *HookManager) closeAllHooks() {
	hm.mu.Lock()
	defer hm.mu.Unlock()

	for name := range hm.hooks {
		closeHookIfPossible(hm.hooks[name].Hook)
		delete(hm.hooks, name)
	}
	hm.ordered = nil
}
// runObserver delivers evt to one observer with a time budget of
// hm.observerTimeout. The observer runs in its own goroutine; this method
// returns as soon as the observer finishes or the budget expires, whichever
// comes first, and logs a warning on error or timeout.
func (hm *HookManager) runObserver(name string, observer EventObserver, evt Event) {
	ctx, cancel := context.WithTimeout(context.Background(), hm.observerTimeout)
	defer cancel()

	// Buffered so the goroutine can finish even after we stopped waiting.
	errCh := make(chan error, 1)
	go func() {
		errCh <- observer.OnEvent(ctx, evt)
	}()

	select {
	case <-ctx.Done():
		logger.WarnCF("hooks", "Event observer timed out", map[string]any{
			"hook":       name,
			"event":      evt.Kind.String(),
			"timeout_ms": hm.observerTimeout.Milliseconds(),
		})
	case observerErr := <-errCh:
		if observerErr == nil {
			return
		}
		logger.WarnCF("hooks", "Event observer failed", map[string]any{
			"hook":  name,
			"event": evt.Kind.String(),
			"error": observerErr.Error(),
		})
	}
}
// callBeforeLLM invokes interceptor.BeforeLLM through runInterceptorHook with
// the manager's interceptor timeout. The boolean result is false when the call
// failed or timed out.
func (hm *HookManager) callBeforeLLM(
	parent context.Context,
	name string,
	interceptor LLMInterceptor,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, bool) {
	invoke := func(ctx context.Context) (*LLMHookRequest, HookDecision, error) {
		return interceptor.BeforeLLM(ctx, req)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "before_llm", invoke)
}
// callAfterLLM invokes interceptor.AfterLLM through runInterceptorHook with
// the manager's interceptor timeout. The boolean result is false when the call
// failed or timed out.
func (hm *HookManager) callAfterLLM(
	parent context.Context,
	name string,
	interceptor LLMInterceptor,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, bool) {
	invoke := func(ctx context.Context) (*LLMHookResponse, HookDecision, error) {
		return interceptor.AfterLLM(ctx, resp)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "after_llm", invoke)
}
// callBeforeTool invokes interceptor.BeforeTool through runInterceptorHook
// with the manager's interceptor timeout. The boolean result is false when the
// call failed or timed out.
func (hm *HookManager) callBeforeTool(
	parent context.Context,
	name string,
	interceptor ToolInterceptor,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, bool) {
	invoke := func(ctx context.Context) (*ToolCallHookRequest, HookDecision, error) {
		return interceptor.BeforeTool(ctx, call)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "before_tool", invoke)
}
// callAfterTool invokes interceptor.AfterTool through runInterceptorHook with
// the manager's interceptor timeout. The boolean result is false when the call
// failed or timed out.
func (hm *HookManager) callAfterTool(
	parent context.Context,
	name string,
	interceptor ToolInterceptor,
	resultView *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, bool) {
	invoke := func(ctx context.Context) (*ToolResultHookResponse, HookDecision, error) {
		return interceptor.AfterTool(ctx, resultView)
	}
	return runInterceptorHook(parent, hm.interceptorTimeout, name, "after_tool", invoke)
}
// callApproveTool invokes approver.ApproveTool through runApprovalHook with
// the manager's approval timeout.
func (hm *HookManager) callApproveTool(
	parent context.Context,
	name string,
	approver ToolApprover,
	req *ToolApprovalRequest,
) (ApprovalDecision, bool) {
	invoke := func(ctx context.Context) (ApprovalDecision, error) {
		return approver.ApproveTool(ctx, req)
	}
	return runApprovalHook(parent, hm.approvalTimeout, name, "approve_tool", invoke)
}
// runInterceptorHook runs fn in a goroutine under a context derived from
// parent with the given timeout and waits for whichever comes first: fn's
// result or the context being done.
//
// It returns (value, decision, true) when fn returned without error. When fn
// returned an error or the context finished first, it logs a warning and
// returns (zero, HookDecision{}, false). fn is not waited for after the
// context is done; the buffered channel lets it finish without blocking.
func runInterceptorHook[T any](
	parent context.Context,
	timeout time.Duration,
	name string,
	stage string,
	fn func(ctx context.Context) (T, HookDecision, error),
) (T, HookDecision, bool) {
	type outcome struct {
		value    T
		decision HookDecision
		err      error
	}

	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	outcomes := make(chan outcome, 1)
	go func() {
		var out outcome
		out.value, out.decision, out.err = fn(ctx)
		outcomes <- out
	}()

	var zero T
	select {
	case <-ctx.Done():
		logger.WarnCF("hooks", "Interceptor hook timed out", map[string]any{
			"hook":       name,
			"stage":      stage,
			"timeout_ms": timeout.Milliseconds(),
		})
		return zero, HookDecision{}, false
	case out := <-outcomes:
		if out.err == nil {
			return out.value, out.decision, true
		}
		logger.WarnCF("hooks", "Interceptor hook failed", map[string]any{
			"hook":  name,
			"stage": stage,
			"error": out.err.Error(),
		})
		return zero, HookDecision{}, false
	}
}
// runApprovalHook runs fn in a goroutine under a context derived from parent
// with the given timeout and waits for fn's result or the context being done.
//
// The outcomes differ on purpose:
//   - fn succeeded: its decision and true.
//   - fn returned an error: a zero decision and false (after a warning).
//   - the context finished first: a denial with a "timed out" reason and true
//     (after a warning).
func runApprovalHook(
	parent context.Context,
	timeout time.Duration,
	name string,
	stage string,
	fn func(ctx context.Context) (ApprovalDecision, error),
) (ApprovalDecision, bool) {
	type outcome struct {
		decision ApprovalDecision
		err      error
	}

	ctx, cancel := context.WithTimeout(parent, timeout)
	defer cancel()

	// Buffered so fn's goroutine can finish even after we stopped waiting.
	outcomes := make(chan outcome, 1)
	go func() {
		var out outcome
		out.decision, out.err = fn(ctx)
		outcomes <- out
	}()

	select {
	case <-ctx.Done():
		logger.WarnCF("hooks", "Approval hook timed out", map[string]any{
			"hook":       name,
			"stage":      stage,
			"timeout_ms": timeout.Milliseconds(),
		})
		return ApprovalDecision{
			Approved: false,
			Reason:   fmt.Sprintf("tool approval hook %q timed out", name),
		}, true
	case out := <-outcomes:
		if out.err == nil {
			return out.decision, true
		}
		logger.WarnCF("hooks", "Approval hook failed", map[string]any{
			"hook":  name,
			"stage": stage,
			"error": out.err.Error(),
		})
		return ApprovalDecision{}, false
	}
}
// logUnsupportedAction warns that the named hook returned an action the given
// stage does not handle.
func (hm *HookManager) logUnsupportedAction(name, stage string, action HookAction) {
	fields := map[string]any{
		"hook":   name,
		"stage":  stage,
		"action": action,
	}
	logger.WarnCF("hooks", "Hook returned unsupported action for stage", fields)
}
// cloneProviderMessages copies a message slice. Each message is copied by
// value, and its Media, SystemParts and ToolCalls are given fresh backing
// storage when non-empty. An empty or nil input yields nil.
func cloneProviderMessages(messages []providers.Message) []providers.Message {
	if len(messages) == 0 {
		return nil
	}

	out := make([]providers.Message, len(messages))
	copy(out, messages)
	for i := range out {
		src := &messages[i]
		if len(src.Media) > 0 {
			out[i].Media = append([]string(nil), src.Media...)
		}
		if len(src.SystemParts) > 0 {
			out[i].SystemParts = append([]providers.ContentBlock(nil), src.SystemParts...)
		}
		if len(src.ToolCalls) > 0 {
			out[i].ToolCalls = cloneProviderToolCalls(src.ToolCalls)
		}
	}
	return out
}
// cloneProviderToolCalls copies a tool-call slice. Each call is copied by
// value; its Function, Arguments and ExtraContent (including the nested Google
// part) are duplicated when present so the copy does not share those pointers
// or the arguments map with the original. An empty or nil input yields nil.
func cloneProviderToolCalls(calls []providers.ToolCall) []providers.ToolCall {
	if len(calls) == 0 {
		return nil
	}

	out := make([]providers.ToolCall, 0, len(calls))
	for _, call := range calls {
		dup := call
		if call.Function != nil {
			fn := *call.Function
			dup.Function = &fn
		}
		if call.Arguments != nil {
			dup.Arguments = cloneStringAnyMap(call.Arguments)
		}
		if src := call.ExtraContent; src != nil {
			extra := *src
			if src.Google != nil {
				google := *src.Google
				extra.Google = &google
			}
			dup.ExtraContent = &extra
		}
		out = append(out, dup)
	}
	return out
}
// cloneToolDefinitions copies a tool-definition slice. Each definition is
// copied by value and its Function.Parameters map is duplicated via
// cloneStringAnyMap. An empty or nil input yields nil.
func cloneToolDefinitions(defs []providers.ToolDefinition) []providers.ToolDefinition {
	if len(defs) == 0 {
		return nil
	}

	out := make([]providers.ToolDefinition, len(defs))
	copy(out, defs)
	for i := range out {
		out[i].Function.Parameters = cloneStringAnyMap(defs[i].Function.Parameters)
	}
	return out
}
// cloneLLMResponse copies an LLM response. The struct is copied by value;
// ToolCalls, ReasoningDetails (when non-empty) and Usage (when set) get their
// own storage. A nil input yields nil.
func cloneLLMResponse(resp *providers.LLMResponse) *providers.LLMResponse {
	if resp == nil {
		return nil
	}

	dup := new(providers.LLMResponse)
	*dup = *resp
	dup.ToolCalls = cloneProviderToolCalls(resp.ToolCalls)
	if len(resp.ReasoningDetails) > 0 {
		// The zero-length, zero-capacity slice forces append to allocate.
		dup.ReasoningDetails = append(resp.ReasoningDetails[:0:0], resp.ReasoningDetails...)
	}
	if original := resp.Usage; original != nil {
		usage := *original
		dup.Usage = &usage
	}
	return dup
}
// cloneStringAnyMap returns a deep copy of src for JSON-like content: nested
// map[string]any, []any, []map[string]any and []string values are copied
// recursively, so mutating the copy (for example inside a hook) cannot change
// the original's nested arguments or schema. Values of any other type are
// copied by assignment. An empty or nil src yields nil, as before.
//
// The input must not contain reference cycles.
func cloneStringAnyMap(src map[string]any) map[string]any {
	if len(src) == 0 {
		return nil
	}
	cloned := make(map[string]any, len(src))
	for k, v := range src {
		cloned[k] = cloneHookAnyValue(v)
	}
	return cloned
}

// cloneHookAnyValue deep-copies the container types commonly found in decoded
// JSON and hand-built schemas. Nested empty containers stay empty and non-nil
// (a JSON-schema "properties": {} must not turn into null), and nil containers
// stay nil.
func cloneHookAnyValue(v any) any {
	switch typed := v.(type) {
	case map[string]any:
		if typed == nil {
			return typed
		}
		out := make(map[string]any, len(typed))
		for k, elem := range typed {
			out[k] = cloneHookAnyValue(elem)
		}
		return out
	case []any:
		if typed == nil {
			return typed
		}
		out := make([]any, len(typed))
		for i, elem := range typed {
			out[i] = cloneHookAnyValue(elem)
		}
		return out
	case []map[string]any:
		if typed == nil {
			return typed
		}
		out := make([]map[string]any, len(typed))
		for i, elem := range typed {
			if elem == nil {
				continue
			}
			copied := make(map[string]any, len(elem))
			for k, inner := range elem {
				copied[k] = cloneHookAnyValue(inner)
			}
			out[i] = copied
		}
		return out
	case []string:
		if typed == nil {
			return typed
		}
		out := make([]string, len(typed))
		copy(out, typed)
		return out
	default:
		return v
	}
}
// cloneToolResult copies a tool result by value and gives a non-empty Media
// slice its own backing array. A nil input yields nil.
func cloneToolResult(result *tools.ToolResult) *tools.ToolResult {
	if result == nil {
		return nil
	}

	dup := new(tools.ToolResult)
	*dup = *result
	if len(result.Media) > 0 {
		dup.Media = append([]string(nil), result.Media...)
	}
	return dup
}
// closeHookIfPossible closes hook when it implements io.Closer and logs a
// warning if Close fails. Hooks that are not closers are left untouched.
func closeHookIfPossible(hook any) {
	if closer, isCloser := hook.(io.Closer); isCloser {
		if closeErr := closer.Close(); closeErr != nil {
			logger.WarnCF("hooks", "Failed to close hook", map[string]any{
				"error": closeErr.Error(),
			})
		}
	}
}
+345
View File
@@ -0,0 +1,345 @@
package agent
import (
"context"
"os"
"sync"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// newHookTestLoop builds an AgentLoop backed by provider with a fresh
// temporary workspace and returns it together with its default agent and a
// cleanup function. The cleanup function closes the loop and removes the
// workspace; callers are expected to defer it.
//
// If the loop has no default agent the test fails, and the loop and workspace
// are released first so nothing leaks on that path.
func newHookTestLoop(
	t *testing.T,
	provider providers.LLMProvider,
) (*AgentLoop, *AgentInstance, func()) {
	t.Helper()

	tmpDir, err := os.MkdirTemp("", "agent-hooks-*")
	if err != nil {
		t.Fatalf("failed to create temp dir: %v", err)
	}

	cfg := &config.Config{
		Agents: config.AgentsConfig{
			Defaults: config.AgentDefaults{
				Workspace:         tmpDir,
				Model:             "test-model",
				MaxTokens:         4096,
				MaxToolIterations: 10,
			},
		},
	}
	al := NewAgentLoop(cfg, bus.NewMessageBus(), provider)
	cleanup := func() {
		al.Close()
		_ = os.RemoveAll(tmpDir) // best effort: a leftover temp dir must not fail the test
	}

	agent := al.registry.GetDefaultAgent()
	if agent == nil {
		// t.Fatal stops the test before the caller could defer cleanup.
		cleanup()
		t.Fatal("expected default agent")
	}
	return al, agent, cleanup
}
// TestHookManager_SortsInProcessBeforeProcess checks that hook source takes
// precedence over priority: an in-process hook with a high priority value
// still sorts before a process hook with a low one.
func TestHookManager_SortsInProcessBeforeProcess(t *testing.T) {
	hm := NewHookManager(nil)
	defer hm.Close()

	processReg := HookRegistration{
		Name:     "process",
		Priority: -10,
		Source:   HookSourceProcess,
		Hook:     struct{}{},
	}
	if err := hm.Mount(processReg); err != nil {
		t.Fatalf("mount process hook: %v", err)
	}

	inProcessReg := HookRegistration{
		Name:     "in-process",
		Priority: 100,
		Source:   HookSourceInProcess,
		Hook:     struct{}{},
	}
	if err := hm.Mount(inProcessReg); err != nil {
		t.Fatalf("mount in-process hook: %v", err)
	}

	ordered := hm.snapshotHooks()
	if count := len(ordered); count != 2 {
		t.Fatalf("expected 2 hooks, got %d", count)
	}
	if first := ordered[0].Name; first != "in-process" {
		t.Fatalf("expected in-process hook first, got %q", first)
	}
	if second := ordered[1].Name; second != "process" {
		t.Fatalf("expected process hook second, got %q", second)
	}
}
// llmHookTestProvider is a test LLM provider that always answers
// "provider content" and remembers the model name of the latest request.
type llmHookTestProvider struct {
	// mu guards lastModel.
	mu        sync.Mutex
	lastModel string
}

// Chat records model as the last requested model and returns the fixed
// content "provider content".
func (p *llmHookTestProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	p.lastModel = model
	p.mu.Unlock()

	reply := &providers.LLMResponse{Content: "provider content"}
	return reply, nil
}

// GetDefaultModel reports the fixed model name of this test provider.
func (p *llmHookTestProvider) GetDefaultModel() string {
	const modelName = "llm-hook-provider"
	return modelName
}
// llmObserverHook is a test hook that is both an event observer and an LLM
// interceptor: it forwards turn-end events to eventCh and rewrites the model
// name and the response content.
type llmObserverHook struct {
	eventCh chan Event
}

// OnEvent forwards turn-end events to eventCh without blocking; the event is
// dropped when the channel cannot accept it. All other events are ignored.
func (h *llmObserverHook) OnEvent(ctx context.Context, evt Event) error {
	if evt.Kind != EventKindTurnEnd {
		return nil
	}
	select {
	case h.eventCh <- evt:
	default:
	}
	return nil
}

// BeforeLLM returns a copy of the request with the model set to "hook-model".
func (h *llmObserverHook) BeforeLLM(
	ctx context.Context,
	req *LLMHookRequest,
) (*LLMHookRequest, HookDecision, error) {
	rewritten := req.Clone()
	rewritten.Model = "hook-model"
	return rewritten, HookDecision{Action: HookActionModify}, nil
}

// AfterLLM returns a copy of the response with the content replaced by
// "hooked content".
func (h *llmObserverHook) AfterLLM(
	ctx context.Context,
	resp *LLMHookResponse,
) (*LLMHookResponse, HookDecision, error) {
	rewritten := resp.Clone()
	rewritten.Response.Content = "hooked content"
	return rewritten, HookDecision{Action: HookActionModify}, nil
}
// TestAgentLoop_Hooks_ObserverAndLLMInterceptor mounts an in-process hook that
// observes events and intercepts LLM calls. It checks that the provider saw
// the rewritten model, that the response content was replaced, and that the
// observer received a turn-end event within two seconds.
func TestAgentLoop_Hooks_ObserverAndLLMInterceptor(t *testing.T) {
	provider := &llmHookTestProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	hook := &llmObserverHook{eventCh: make(chan Event, 1)}
	if err := al.MountHook(NamedHook("llm-observer", hook)); err != nil {
		t.Fatalf("MountHook failed: %v", err)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "hello",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	reply, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if reply != "hooked content" {
		t.Fatalf("expected hooked content, got %q", reply)
	}

	// Read the recorded model under the provider's lock.
	provider.mu.Lock()
	seenModel := provider.lastModel
	provider.mu.Unlock()
	if seenModel != "hook-model" {
		t.Fatalf("expected model hook-model, got %q", seenModel)
	}

	// Events are delivered asynchronously, so allow some time for turn end.
	select {
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for hook observer event")
	case evt := <-hook.eventCh:
		if evt.Kind != EventKindTurnEnd {
			t.Fatalf("expected turn end event, got %v", evt.Kind)
		}
	}
}
// toolHookProvider is a scripted LLM provider for the tool-hook tests. Its
// first Chat call requests the "echo_text" tool with text "original"; every
// later call echoes the content of the last message it was given.
type toolHookProvider struct {
	// mu guards calls.
	mu    sync.Mutex
	calls int
}

// Chat returns a single "echo_text" tool call on the first invocation and
// otherwise replies with the content of the final entry in messages.
func (p *toolHookProvider) Chat(
	ctx context.Context,
	messages []providers.Message,
	tools []providers.ToolDefinition,
	model string,
	opts map[string]any,
) (*providers.LLMResponse, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	p.calls++
	if p.calls != 1 {
		last := messages[len(messages)-1]
		return &providers.LLMResponse{Content: last.Content}, nil
	}

	toolCall := providers.ToolCall{
		ID:        "call-1",
		Name:      "echo_text",
		Arguments: map[string]any{"text": "original"},
	}
	return &providers.LLMResponse{
		ToolCalls: []providers.ToolCall{toolCall},
	}, nil
}

// GetDefaultModel reports the fixed model name of this test provider.
func (p *toolHookProvider) GetDefaultModel() string {
	const modelName = "tool-hook-provider"
	return modelName
}
// echoTextTool is a test tool that returns its "text" argument unchanged.
type echoTextTool struct{}

// Name returns the tool's registered name.
func (t *echoTextTool) Name() string {
	const toolName = "echo_text"
	return toolName
}

// Description returns a short description of the tool.
func (t *echoTextTool) Description() string {
	const summary = "echo a text argument"
	return summary
}

// Parameters returns the JSON-schema-style description of the tool's single
// string parameter "text".
func (t *echoTextTool) Parameters() map[string]any {
	textParam := map[string]any{
		"type": "string",
	}
	properties := map[string]any{
		"text": textParam,
	}
	return map[string]any{
		"type":       "object",
		"properties": properties,
	}
}

// Execute returns a silent result carrying the "text" argument, or the empty
// string when the argument is missing or not a string.
func (t *echoTextTool) Execute(ctx context.Context, args map[string]any) *tools.ToolResult {
	var text string
	if s, isString := args["text"].(string); isString {
		text = s
	}
	return tools.SilentResult(text)
}
// toolRewriteHook is a test tool interceptor that replaces the "text" argument
// with "modified" before execution and prefixes the result's ForLLM text with
// "after:" afterwards.
type toolRewriteHook struct{}

// BeforeTool returns a copy of the call with the "text" argument set to
// "modified".
func (h *toolRewriteHook) BeforeTool(
	ctx context.Context,
	call *ToolCallHookRequest,
) (*ToolCallHookRequest, HookDecision, error) {
	next := call.Clone()
	// Clone yields a nil Arguments map when the call carried no arguments;
	// writing to a nil map would panic.
	if next.Arguments == nil {
		next.Arguments = make(map[string]any, 1)
	}
	next.Arguments["text"] = "modified"
	return next, HookDecision{Action: HookActionModify}, nil
}

// AfterTool returns a copy of the result whose ForLLM text is prefixed with
// "after:". A response without a Result is passed through unchanged.
func (h *toolRewriteHook) AfterTool(
	ctx context.Context,
	result *ToolResultHookResponse,
) (*ToolResultHookResponse, HookDecision, error) {
	next := result.Clone()
	if next.Result == nil {
		// Nothing to rewrite; avoid dereferencing a nil result.
		return next, HookDecision{Action: HookActionContinue}, nil
	}
	next.Result.ForLLM = "after:" + next.Result.ForLLM
	return next, HookDecision{Action: HookActionModify}, nil
}
// TestAgentLoop_Hooks_ToolInterceptorCanRewrite mounts an in-process tool
// interceptor and checks that both its argument rewrite ("modified") and its
// result rewrite ("after:" prefix) show up in the final response.
func TestAgentLoop_Hooks_ToolInterceptorCanRewrite(t *testing.T) {
	provider := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})
	if err := al.MountHook(NamedHook("tool-rewrite", &toolRewriteHook{})); err != nil {
		t.Fatalf("MountHook failed: %v", err)
	}

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	reply, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}
	if reply != "after:modified" {
		t.Fatalf("expected rewritten tool result, got %q", reply)
	}
}
// denyApprovalHook is a test approver that rejects every tool call.
type denyApprovalHook struct{}

// ApproveTool always denies with the reason "blocked".
func (h *denyApprovalHook) ApproveTool(ctx context.Context, req *ToolApprovalRequest) (ApprovalDecision, error) {
	var denial ApprovalDecision
	denial.Approved = false
	denial.Reason = "blocked"
	return denial, nil
}
// TestAgentLoop_Hooks_ToolApproverCanDeny mounts an in-process approver that
// rejects every tool call and checks that the denial reason reaches both the
// turn's response and the ToolExecSkipped event payload.
func TestAgentLoop_Hooks_ToolApproverCanDeny(t *testing.T) {
	provider := &toolHookProvider{}
	al, agent, cleanup := newHookTestLoop(t, provider)
	defer cleanup()

	al.RegisterTool(&echoTextTool{})
	if err := al.MountHook(NamedHook("deny-approval", &denyApprovalHook{})); err != nil {
		t.Fatalf("MountHook failed: %v", err)
	}

	// Subscribe before running the turn so the skipped event is captured.
	sub := al.SubscribeEvents(16)
	defer al.UnsubscribeEvents(sub.ID)

	turn := processOptions{
		SessionKey:      "session-1",
		Channel:         "cli",
		ChatID:          "direct",
		UserMessage:     "run tool",
		DefaultResponse: defaultResponse,
		EnableSummary:   false,
		SendResponse:    false,
	}
	reply, err := al.runAgentLoop(context.Background(), agent, turn)
	if err != nil {
		t.Fatalf("runAgentLoop failed: %v", err)
	}

	expected := "Tool execution denied by approval hook: blocked"
	if reply != expected {
		t.Fatalf("expected %q, got %q", expected, reply)
	}

	skippedEvt, found := findEvent(collectEventStream(sub.C), EventKindToolExecSkipped)
	if !found {
		t.Fatal("expected tool skipped event")
	}
	payload, isSkipped := skippedEvt.Payload.(ToolExecSkippedPayload)
	if !isSkipped {
		t.Fatalf("expected ToolExecSkippedPayload, got %T", skippedEvt.Payload)
	}
	if payload.Reason != expected {
		t.Fatalf("expected skipped reason %q, got %q", expected, payload.Reason)
	}
}
+12 -1
View File
@@ -130,6 +130,17 @@ func NewAgentInstance(
maxTokens = 8192
}
contextWindow := defaults.ContextWindow
if contextWindow == 0 {
// Default heuristic: 4x the output token limit.
// Most models have context windows well above their output limits
// (e.g., GPT-4o 128k ctx / 16k out, Claude 200k ctx / 8k out).
// 4x is a conservative lower bound that avoids premature
// summarization while remaining safe — the reactive
// forceCompression handles any overshoot.
contextWindow = maxTokens * 4
}
temperature := 0.7
if defaults.Temperature != nil {
temperature = *defaults.Temperature
@@ -182,7 +193,7 @@ func NewAgentInstance(
MaxTokens: maxTokens,
Temperature: temperature,
ThinkingLevel: thinkingLevel,
ContextWindow: maxTokens,
ContextWindow: contextWindow,
SummarizeMessageThreshold: summarizeMessageThreshold,
SummarizeTokenPercent: summarizeTokenPercent,
Provider: provider,
+1559 -419
View File
File diff suppressed because it is too large Load Diff
+8 -9
View File
@@ -1235,11 +1235,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
al := NewAgentLoop(cfg, msgBus, provider)
// Inject some history to simulate a full context
// Inject some history to simulate a full context.
// Session history only stores user/assistant/tool messages — the system
// prompt is built dynamically by BuildMessages and is NOT stored here.
sessionKey := "test-session-context"
// Create dummy history
history := []providers.Message{
{Role: "system", Content: "System prompt"},
{Role: "user", Content: "Old message 1"},
{Role: "assistant", Content: "Old response 1"},
{Role: "user", Content: "Old message 2"},
@@ -1277,12 +1277,11 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
// Check final history length
finalHistory := defaultAgent.Sessions.GetHistory(sessionKey)
// We verify that the history has been modified (compressed)
// Original length: 6
// Expected behavior: compression drops ~50% of history (mid slice)
// We can assert that the length is NOT what it would be without compression.
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
if len(finalHistory) >= 8 {
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
// Original length: 5
// Expected behavior: compression drops ~50% of Turns
// Without compression: 5 + 1 (new user msg) + 1 (assistant msg) = 7
if len(finalHistory) >= 7 {
t.Errorf("Expected history to be compressed (len < 7), got %d", len(finalHistory))
}
}
+503
View File
@@ -0,0 +1,503 @@
package agent
import (
"context"
"fmt"
"strings"
"sync"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/routing"
"github.com/sipeed/picoclaw/pkg/tools"
)
// SteeringMode controls how queued steering messages are dequeued.
type SteeringMode string

const (
	// SteeringOneAtATime dequeues only the first queued message per poll.
	SteeringOneAtATime SteeringMode = "one-at-a-time"
	// SteeringAll drains the entire queue in a single poll.
	SteeringAll SteeringMode = "all"
	// MaxQueueSize is the maximum number of messages that may be queued in a
	// single steering scope; further pushes to that scope are rejected.
	MaxQueueSize = 10
	// manualSteeringScope is the legacy fallback queue used when no active
	// turn/session scope is available.
	manualSteeringScope = "__manual__"
)

// parseSteeringMode normalizes a config string into a SteeringMode.
//
// Matching ignores surrounding whitespace and letter case, so "all", "All"
// and " ALL " all select SteeringAll. Every other value, including the empty
// string, falls back to SteeringOneAtATime.
func parseSteeringMode(s string) SteeringMode {
	switch strings.ToLower(strings.TrimSpace(s)) {
	case "all":
		return SteeringAll
	default:
		return SteeringOneAtATime
	}
}
// steeringQueue is a thread-safe queue of user messages that can be injected
// into a running agent loop to interrupt it between tool calls.
//
// Messages are kept in one FIFO slice per scope key. A scope with no pending
// messages has no map entry.
type steeringQueue struct {
	// mu guards queues and mode.
	mu sync.Mutex
	// queues maps a scope key to its pending messages, oldest first.
	queues map[string][]providers.Message
	// mode selects how many messages a single dequeue returns.
	mode SteeringMode
}
// newSteeringQueue returns an empty steeringQueue that dequeues according to
// mode. The scope map is allocated up front so pushes never hit a nil map.
func newSteeringQueue(mode SteeringMode) *steeringQueue {
	sq := new(steeringQueue)
	sq.mode = mode
	sq.queues = map[string][]providers.Message{}
	return sq
}
// normalizeSteeringScope trims surrounding whitespace from scope and maps a
// blank result to manualSteeringScope, the legacy fallback queue key.
func normalizeSteeringScope(scope string) string {
	if trimmed := strings.TrimSpace(scope); trimmed != "" {
		return trimmed
	}
	return manualSteeringScope
}
// push enqueues msg in the legacy fallback scope (manualSteeringScope).
// It returns pushScope's error, which is non-nil once that scope already
// holds MaxQueueSize messages.
func (sq *steeringQueue) push(msg providers.Message) error {
	return sq.pushScope(manualSteeringScope, msg)
}
// pushScope appends msg to the queue for scope. A blank scope is stored under
// the legacy fallback scope. It returns an error, leaving the queue
// unchanged, when that scope already holds MaxQueueSize messages.
func (sq *steeringQueue) pushScope(scope string, msg providers.Message) error {
	key := normalizeSteeringScope(scope)

	sq.mu.Lock()
	defer sq.mu.Unlock()

	pending := sq.queues[key]
	if len(pending) < MaxQueueSize {
		sq.queues[key] = append(pending, msg)
		return nil
	}
	return fmt.Errorf("steering queue is full")
}
// dequeue removes and returns pending steering messages from the legacy
// fallback scope according to the configured mode. It returns nil when that
// scope has nothing queued.
func (sq *steeringQueue) dequeue() []providers.Message {
	return sq.dequeueScope(manualSteeringScope)
}
// dequeueScope removes and returns pending steering messages for scope
// according to the configured mode. A blank scope addresses the legacy
// fallback scope. The result is nil when nothing is queued.
func (sq *steeringQueue) dequeueScope(scope string) []providers.Message {
	key := normalizeSteeringScope(scope)

	sq.mu.Lock()
	defer sq.mu.Unlock()

	return sq.dequeueLocked(key)
}
// dequeueScopeWithFallback dequeues from the given scope first. If that scope
// is blank or yields nothing, it dequeues from the legacy manual scope
// instead, for backwards compatibility. Both lookups happen under one lock.
func (sq *steeringQueue) dequeueScopeWithFallback(scope string) []providers.Message {
	scoped := strings.TrimSpace(scope)

	sq.mu.Lock()
	defer sq.mu.Unlock()

	if scoped == "" {
		return sq.dequeueLocked(manualSteeringScope)
	}

	msgs := sq.dequeueLocked(scoped)
	if len(msgs) == 0 {
		msgs = sq.dequeueLocked(manualSteeringScope)
	}
	return msgs
}
// dequeueLocked pops messages for the exact scope key; the key is not
// normalized here. The caller must hold sq.mu.
//
// In SteeringAll mode it returns a copy of every pending message and removes
// the scope. In any other mode it returns only the oldest message. A scope
// that becomes empty is deleted from the map. It returns nil when the scope
// has nothing queued.
func (sq *steeringQueue) dequeueLocked(scope string) []providers.Message {
	pending := sq.queues[scope]
	if len(pending) == 0 {
		return nil
	}

	if sq.mode == SteeringAll {
		drained := make([]providers.Message, len(pending))
		copy(drained, pending)
		delete(sq.queues, scope)
		return drained
	}

	head := pending[0]
	// Zero the vacated slot so the backing array does not keep the message alive.
	pending[0] = providers.Message{}
	if rest := pending[1:]; len(rest) > 0 {
		sq.queues[scope] = rest
	} else {
		delete(sq.queues, scope)
	}
	return []providers.Message{head}
}
// len reports how many messages are queued in total, summed over every scope.
func (sq *steeringQueue) len() int {
	sq.mu.Lock()
	defer sq.mu.Unlock()

	var count int
	for scope := range sq.queues {
		count += len(sq.queues[scope])
	}
	return count
}
// lenScope reports how many messages are queued for scope. A blank scope
// addresses the legacy fallback scope.
func (sq *steeringQueue) lenScope(scope string) int {
	key := normalizeSteeringScope(scope)

	sq.mu.Lock()
	defer sq.mu.Unlock()

	return len(sq.queues[key])
}
// setMode replaces the dequeue mode under the queue lock. Messages that are
// already queued stay in place.
func (sq *steeringQueue) setMode(mode SteeringMode) {
	sq.mu.Lock()
	sq.mode = mode
	sq.mu.Unlock()
}
// getMode reads the current dequeue mode under the queue lock.
func (sq *steeringQueue) getMode() SteeringMode {
	sq.mu.Lock()
	current := sq.mode
	sq.mu.Unlock()
	return current
}
// Steer enqueues a user message to be injected into the currently running
// agent loop. The message will be picked up after the current tool finishes
// executing, causing any remaining tool calls in the batch to be skipped.
//
// When a turn is active, the message is queued under that turn's session key
// and attributed to its agent. Otherwise it goes to the legacy fallback scope.
func (al *AgentLoop) Steer(msg providers.Message) error {
	var scope, agentID string
	if active := al.getAnyActiveTurnState(); active != nil {
		scope, agentID = active.sessionKey, active.agentID
	}
	return al.enqueueSteeringMessage(scope, agentID, msg)
}
// enqueueSteeringMessage pushes msg onto the steering queue for scope, logs
// the outcome, and emits an EventKindInterruptReceived event.
//
// It returns an error when the steering queue is not initialized or when the
// push is rejected; no event is emitted in either case.
//
// Event metadata comes from the active turn when one exists. Otherwise it is
// assembled from agentID and scope, falling back to the registry's default
// agent when agentID is blank.
func (al *AgentLoop) enqueueSteeringMessage(scope, agentID string, msg providers.Message) error {
	if al.steering == nil {
		return fmt.Errorf("steering queue is not initialized")
	}

	scopeLabel := normalizeSteeringScope(scope)

	if pushErr := al.steering.pushScope(scope, msg); pushErr != nil {
		logger.WarnCF("agent", "Failed to enqueue steering message", map[string]any{
			"error": pushErr.Error(),
			"role":  msg.Role,
			"scope": scopeLabel,
		})
		return pushErr
	}

	// Read after the push, under a separate lock acquisition.
	depth := al.steering.lenScope(scope)
	logger.DebugCF("agent", "Steering message enqueued", map[string]any{
		"role":        msg.Role,
		"content_len": len(msg.Content),
		"media_count": len(msg.Media),
		"queue_len":   depth,
		"scope":       scopeLabel,
	})

	var meta EventMeta
	if active := al.getAnyActiveTurnState(); active != nil {
		meta = active.eventMeta("Steer", "turn.interrupt.received")
	} else {
		meta = EventMeta{
			Source:    "Steer",
			TracePath: "turn.interrupt.received",
		}
		if strings.TrimSpace(agentID) != "" {
			meta.AgentID = agentID
		}
		// The legacy fallback scope is not a real session key, so leave it unset.
		if scopeLabel != manualSteeringScope {
			meta.SessionKey = scopeLabel
		}
		if meta.AgentID == "" {
			if registry := al.GetRegistry(); registry != nil {
				if fallback := registry.GetDefaultAgent(); fallback != nil {
					meta.AgentID = fallback.ID
				}
			}
		}
	}

	payload := InterruptReceivedPayload{
		Kind:       InterruptKindSteering,
		Role:       msg.Role,
		ContentLen: len(msg.Content),
		QueueDepth: depth,
	}
	al.emitEvent(EventKindInterruptReceived, meta, payload)

	return nil
}
// SteeringMode returns the current steering mode, or SteeringOneAtATime when
// the steering queue has not been initialized.
func (al *AgentLoop) SteeringMode() SteeringMode {
	if q := al.steering; q != nil {
		return q.getMode()
	}
	return SteeringOneAtATime
}
// SetSteeringMode updates the steering mode. It is a no-op when the steering
// queue has not been initialized.
func (al *AgentLoop) SetSteeringMode(mode SteeringMode) {
	if q := al.steering; q != nil {
		q.setMode(mode)
	}
}
// dequeueSteeringMessages is the internal method called by the agent loop to
// poll the legacy fallback scope for steering messages. It returns nil when
// the steering queue is not initialized.
func (al *AgentLoop) dequeueSteeringMessages() []providers.Message {
	if q := al.steering; q != nil {
		return q.dequeue()
	}
	return nil
}
// dequeueSteeringMessagesForScope polls the steering queue for the given
// scope only. It returns nil when the steering queue is not initialized.
func (al *AgentLoop) dequeueSteeringMessagesForScope(scope string) []providers.Message {
	if q := al.steering; q != nil {
		return q.dequeueScope(scope)
	}
	return nil
}
// dequeueSteeringMessagesForScopeWithFallback polls the given scope and, if
// it yields nothing, the legacy fallback scope. It returns nil when the
// steering queue is not initialized.
func (al *AgentLoop) dequeueSteeringMessagesForScopeWithFallback(scope string) []providers.Message {
	if q := al.steering; q != nil {
		return q.dequeueScopeWithFallback(scope)
	}
	return nil
}
// pendingSteeringCountForScope reports how many steering messages are queued
// for scope, or 0 when the steering queue is not initialized.
func (al *AgentLoop) pendingSteeringCountForScope(scope string) int {
	if q := al.steering; q != nil {
		return q.lenScope(scope)
	}
	return 0
}
// continueWithSteeringMessages runs the agent loop for the given session,
// seeded with steeringMsgs as the initial steering messages. The run enables
// summarization, does not send the response itself, and skips the initial
// steering poll.
func (al *AgentLoop) continueWithSteeringMessages(
	ctx context.Context,
	agent *AgentInstance,
	sessionKey, channel, chatID string,
	steeringMsgs []providers.Message,
) (string, error) {
	opts := processOptions{
		SessionKey:              sessionKey,
		Channel:                 channel,
		ChatID:                  chatID,
		DefaultResponse:         defaultResponse,
		EnableSummary:           true,
		SendResponse:            false,
		InitialSteeringMessages: steeringMsgs,
		SkipInitialSteeringPoll: true,
	}
	return al.runAgentLoop(ctx, agent, opts)
}
// agentForSession resolves the agent that owns sessionKey. If the key parses
// as an agent session key and names a registered agent, that agent is
// returned; otherwise the registry's default agent is returned. The result is
// nil when no registry is available.
func (al *AgentLoop) agentForSession(sessionKey string) *AgentInstance {
	registry := al.GetRegistry()
	if registry == nil {
		return nil
	}

	parsed := routing.ParseAgentSessionKey(sessionKey)
	if parsed == nil {
		return registry.GetDefaultAgent()
	}
	if owner, found := registry.GetAgent(parsed.AgentID); found {
		return owner
	}
	return registry.GetDefaultAgent()
}
// Continue resumes an idle agent by dequeuing any pending steering messages
// and running them through the agent loop. This is used when the agent's last
// message was from the assistant (i.e., it has stopped processing) and the
// user has since enqueued steering messages.
//
// It returns an error when a turn is still active, when hook or MCP
// initialization fails, or when messages are pending but no agent can be
// resolved for sessionKey. In the last case the pending messages are left in
// the queue, not discarded, so a later call can still deliver them.
//
// If no steering messages are pending, it returns an empty string.
func (al *AgentLoop) Continue(ctx context.Context, sessionKey, channel, chatID string) (string, error) {
	if active := al.GetActiveTurn(); active != nil {
		return "", fmt.Errorf("turn %s is still active", active.TurnID)
	}

	if err := al.ensureHooksInitialized(ctx); err != nil {
		return "", err
	}
	if err := al.ensureMCPInitialized(ctx); err != nil {
		return "", err
	}

	// Resolve the agent BEFORE dequeuing. Dequeuing is destructive, so doing
	// it first would silently drop the messages whenever the lookup fails.
	agent := al.agentForSession(sessionKey)
	if agent == nil {
		pending := al.pendingSteeringCountForScope(sessionKey) +
			al.pendingSteeringCountForScope(manualSteeringScope)
		if pending == 0 {
			// Nothing to run: keep the "no messages" contract.
			return "", nil
		}
		return "", fmt.Errorf("no agent available for session %q", sessionKey)
	}

	steeringMsgs := al.dequeueSteeringMessagesForScopeWithFallback(sessionKey)
	if len(steeringMsgs) == 0 {
		return "", nil
	}

	// Call the message tool's ResetSentInRound hook, if it has one, before
	// starting the new run.
	if tool, ok := agent.Tools.Get("message"); ok {
		if resetter, ok := tool.(interface{ ResetSentInRound() }); ok {
			resetter.ResetSentInRound()
		}
	}

	return al.continueWithSteeringMessages(ctx, agent, sessionKey, channel, chatID, steeringMsgs)
}
// InterruptGraceful asks the active turn to stop gracefully, passing hint
// along with the request. It returns an error when no turn is active or when
// the turn refuses the request. On success it emits an
// EventKindInterruptReceived event of kind InterruptKindGraceful.
func (al *AgentLoop) InterruptGraceful(hint string) error {
	active := al.getAnyActiveTurnState()
	switch {
	case active == nil:
		return fmt.Errorf("no active turn")
	case !active.requestGracefulInterrupt(hint):
		return fmt.Errorf("turn %s cannot accept graceful interrupt", active.turnID)
	}

	payload := InterruptReceivedPayload{
		Kind:    InterruptKindGraceful,
		HintLen: len(hint),
	}
	al.emitEvent(
		EventKindInterruptReceived,
		active.eventMeta("InterruptGraceful", "turn.interrupt.received"),
		payload,
	)
	return nil
}
// InterruptHard requests a hard abort of the active turn. It returns an error
// when no turn is active or when the turn is already aborting. On success it
// emits an EventKindInterruptReceived event of kind InterruptKindHard.
func (al *AgentLoop) InterruptHard() error {
	active := al.getAnyActiveTurnState()
	switch {
	case active == nil:
		return fmt.Errorf("no active turn")
	case !active.requestHardAbort():
		return fmt.Errorf("turn %s is already aborting", active.turnID)
	}

	payload := InterruptReceivedPayload{Kind: InterruptKindHard}
	al.emitEvent(
		EventKindInterruptReceived,
		active.eventMeta("InterruptHard", "turn.interrupt.received"),
		payload,
	)
	return nil
}
// ====================== SubTurn Result Polling ======================
// dequeuePendingSubTurnResults drains the SubTurn result channel of the turn
// registered under sessionKey without blocking, and returns every non-nil
// result that was immediately available.
//
// It returns nil when no turn state is registered for the session or when
// the registered value is not a *turnState.
func (al *AgentLoop) dequeuePendingSubTurnResults(sessionKey string) []*tools.ToolResult {
	raw, found := al.activeTurnStates.Load(sessionKey)
	if !found {
		return nil
	}
	ts, isTurnState := raw.(*turnState)
	if !isTurnState {
		return nil
	}

	var drained []*tools.ToolResult
	for {
		select {
		case res, open := <-ts.pendingResults:
			if !open {
				// Channel closed: nothing more will arrive.
				return drained
			}
			if res == nil {
				continue
			}
			drained = append(drained, res)
		default:
			// Nothing ready right now; do not wait.
			return drained
		}
	}
}
// ====================== Hard Abort ======================
// HardAbort immediately cancels the running agent loop for the given session,
// cascading the cancellation to all child SubTurns. This is a destructive
// operation that terminates execution without waiting for graceful cleanup.
// After cancelling, it truncates the session history back to the length
// recorded when the turn started.
//
// Use this when the user explicitly requests immediate termination (e.g.,
// "stop now", "abort"). For graceful interruption that allows the agent to
// finish the current tool and summarize, use Steer() instead.
//
// It returns an error when no turn state is registered for sessionKey or the
// registered value has an unexpected type.
func (al *AgentLoop) HardAbort(sessionKey string) error {
	raw, found := al.activeTurnStates.Load(sessionKey)
	if !found {
		return fmt.Errorf("no active turn state found for session %s", sessionKey)
	}
	ts, isTurnState := raw.(*turnState)
	if !isTurnState {
		return fmt.Errorf("invalid turn state type for session %s", sessionKey)
	}

	logger.InfoCF("agent", "Hard abort triggered", map[string]any{
		"session_key":            sessionKey,
		"turn_id":                ts.turnID,
		"depth":                  ts.depth,
		"initial_history_length": ts.initialHistoryLength,
	})

	// IMPORTANT: cancel first. Finish(true) is the hard-abort form and stops
	// child SubTurns from adding more messages to the session; rolling back
	// before that would race with children that are still writing.
	ts.Finish(true)

	// Restore the history to what it was before the turn started.
	if store := ts.session; store != nil {
		if current := store.GetHistory(sessionKey); len(current) > ts.initialHistoryLength {
			store.SetHistory(sessionKey, current[:ts.initialHistoryLength])
		}
	}

	return nil
}
// ====================== Follow-Up Injection ======================
// InjectFollowUp enqueues a message intended to be processed after the current
// turn completes, e.g. for:
//   - Automated workflows that need to chain multiple turns
//   - Background tasks that should run after the main task completes
//   - Scheduled follow-up actions
//
// The intended difference from Steer() is semantic: Steer() is meant to be
// called during active execution to interrupt it, while InjectFollowUp() is
// meant for planning future work that Continue() picks up once the agent is
// idle.
//
// NOTE(review): the implementation is a plain delegation to Steer(), so the
// message lands in the same steering queue and scope a Steer() call would use.
// If a turn is active when this is called, nothing here prevents that turn
// from consuming the message as a steering interrupt — confirm this matches
// the intended "wait for the turn to finish" contract.
func (al *AgentLoop) InjectFollowUp(msg providers.Message) error {
	// Same queue and same enqueue path as Steer(); see the note above.
	return al.Steer(msg)
}
// ====================== API Aliases for Design Document Compatibility ======================
// InjectSteering is an alias for Steer() to match the design document naming.
// It enqueues msg exactly as Steer() does and returns Steer()'s error.
func (al *AgentLoop) InjectSteering(msg providers.Message) error {
	return al.Steer(msg)
}
File diff suppressed because it is too large Load Diff
+671
View File
@@ -0,0 +1,671 @@
package agent
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// ====================== Config & Constants ======================
const (
	// Default values for SubTurn configuration. getSubTurnConfig applies each
	// one when the corresponding config value is zero or negative.

	// defaultMaxSubTurnDepth is the default nesting limit for sub-turns.
	defaultMaxSubTurnDepth = 3
	// defaultMaxConcurrentSubTurns is the default size of the concurrency
	// semaphore given to each child turn.
	defaultMaxConcurrentSubTurns = 5
	// defaultConcurrencyTimeout is the default wait for a concurrency slot.
	defaultConcurrencyTimeout = 30 * time.Second
	// defaultSubTurnTimeout is the default lifetime of a child turn's context.
	defaultSubTurnTimeout = 5 * time.Minute
	// maxEphemeralHistorySize limits the number of messages stored in ephemeral sessions.
	// This prevents memory accumulation in long-running sub-turns.
	maxEphemeralHistorySize = 50
)
// Sentinel errors returned by spawnSubTurn; compare with errors.Is.
var (
	// ErrDepthLimitExceeded is returned when the parent turn is already at
	// the configured maximum sub-turn depth.
	ErrDepthLimitExceeded = errors.New("sub-turn depth limit exceeded")
	// ErrInvalidSubTurnConfig is returned when the SubTurnConfig has an
	// empty Model.
	ErrInvalidSubTurnConfig = errors.New("invalid sub-turn config")
	// ErrConcurrencyTimeout is wrapped into the returned error when no
	// concurrency slot frees up within the configured timeout.
	ErrConcurrencyTimeout = errors.New("timeout waiting for concurrency slot")
)
// getSubTurnConfig returns the effective SubTurn configuration. It starts from
// the package defaults and overrides each one with the configured value only
// when that value is positive. The default token budget is copied through
// unchanged.
func (al *AgentLoop) getSubTurnConfig() subTurnRuntimeConfig {
	raw := al.cfg.Agents.Defaults.SubTurn

	effective := subTurnRuntimeConfig{
		maxDepth:           defaultMaxSubTurnDepth,
		maxConcurrent:      defaultMaxConcurrentSubTurns,
		concurrencyTimeout: defaultConcurrencyTimeout,
		defaultTimeout:     defaultSubTurnTimeout,
		defaultTokenBudget: raw.DefaultTokenBudget,
	}

	if raw.MaxDepth > 0 {
		effective.maxDepth = raw.MaxDepth
	}
	if raw.MaxConcurrent > 0 {
		effective.maxConcurrent = raw.MaxConcurrent
	}
	// The config stores the two timeouts as whole seconds / whole minutes.
	if wait := time.Duration(raw.ConcurrencyTimeoutSec) * time.Second; wait > 0 {
		effective.concurrencyTimeout = wait
	}
	if lifetime := time.Duration(raw.DefaultTimeoutMinutes) * time.Minute; lifetime > 0 {
		effective.defaultTimeout = lifetime
	}

	return effective
}
// subTurnRuntimeConfig holds the effective runtime configuration for SubTurn execution.
// Values are produced by getSubTurnConfig with defaults already applied.
type subTurnRuntimeConfig struct {
	// maxDepth is the deepest parent depth that may still spawn a child.
	maxDepth int
	// maxConcurrent sizes the concurrency semaphore created for each child.
	maxConcurrent int
	// concurrencyTimeout bounds the wait for a free concurrency slot.
	concurrencyTimeout time.Duration
	// defaultTimeout is the child lifetime used when SubTurnConfig.Timeout is unset.
	defaultTimeout time.Duration
	// defaultTokenBudget seeds a child's token budget when neither the config
	// nor the parent supplies one; values <= 0 mean no default budget.
	defaultTokenBudget int
}
// ====================== SubTurn Config ======================
// SubTurnConfig configures the execution of a child sub-turn.
//
// Usage Examples:
//
// Synchronous sub-turn (Async=false):
//
//	cfg := SubTurnConfig{
//	    Model: "gpt-4o-mini",
//	    SystemPrompt: "Analyze this code",
//	    Async: false,  // Result returned immediately
//	}
//	result, err := SpawnSubTurn(ctx, cfg)
//	// Use result directly here
//	processResult(result)
//
// Asynchronous sub-turn (Async=true):
//
//	cfg := SubTurnConfig{
//	    Model: "gpt-4o-mini",
//	    SystemPrompt: "Background analysis",
//	    Async: true,  // Result delivered to channel
//	}
//	result, err := SpawnSubTurn(ctx, cfg)
//	// Result also available in parent's pendingResults channel
//	// Parent turn will poll and process it in a later iteration
type SubTurnConfig struct {
	// Model must be non-empty; spawnSubTurn rejects an empty Model with
	// ErrInvalidSubTurnConfig.
	// NOTE(review): spawnSubTurn only validates this field and does not
	// visibly apply it to the child agent — confirm where it is consumed.
	Model string
	// Tools lists tools for the sub-turn.
	// NOTE(review): not read by spawnSubTurn in this file — confirm usage.
	Tools []tools.Tool
	// SystemPrompt is the task description. Despite its name it is sent as
	// the child's first user message; see ActualSystemPrompt for the real
	// system prompt.
	SystemPrompt string
	// MaxTokens is the requested output token limit.
	// NOTE(review): not read by spawnSubTurn in this file — confirm usage.
	MaxTokens int

	// Async controls the result delivery mechanism:
	//
	// When Async = false (synchronous sub-turn):
	//   - The caller blocks until the sub-turn completes
	//   - The result is ONLY returned via the function return value
	//   - The result is NOT delivered to the parent's pendingResults channel
	//   - This prevents double delivery: caller gets result immediately, no need for channel
	//   - Use case: When the caller needs the result immediately to continue execution
	//   - Example: A tool that needs to process the sub-turn result before returning
	//
	// When Async = true (asynchronous sub-turn):
	//   - The call still blocks the caller until the sub-turn completes
	//   - The result is delivered to the parent's pendingResults channel
	//   - The result is ALSO returned via the function return value (for consistency)
	//   - The parent turn can poll pendingResults in later iterations to process results
	//   - Use case: Fire-and-forget operations, or when results are processed in batches
	//   - Example: Spawning multiple sub-turns in parallel and collecting results later
	//
	// IMPORTANT: The Async flag does NOT make the call non-blocking. It only controls
	// whether the result is delivered via the channel. For true non-blocking execution,
	// the caller must spawn the sub-turn in a separate goroutine.
	Async bool

	// Critical indicates this SubTurn's result is important and should continue
	// running even after the parent turn finishes gracefully.
	//
	// When parent finishes gracefully (Finish(false)):
	//   - Critical=true: SubTurn continues running, delivers result as orphan
	//   - Critical=false: SubTurn exits gracefully without error
	//
	// When parent finishes with hard abort (Finish(true)):
	//   - All SubTurns are canceled regardless of Critical flag
	Critical bool

	// Timeout is the maximum duration for this SubTurn.
	// If the SubTurn runs longer than this, it will be canceled.
	// When zero or negative, the configured default timeout is used, which
	// itself falls back to 5 minutes (defaultSubTurnTimeout).
	Timeout time.Duration

	// MaxContextRunes limits the context size (in runes) passed to the SubTurn.
	// This prevents context window overflow by truncating message history before LLM calls.
	//
	// Values:
	//   0  = Auto-calculate based on model's ContextWindow * 0.75 (default, recommended)
	//   -1 = No limit (disable soft truncation, rely only on hard context errors)
	//   >0 = Use specified rune limit
	//
	// The soft limit acts as a first line of defense before hitting the provider's
	// hard context window limit. When exceeded, older messages are intelligently
	// truncated while preserving system messages and recent context.
	//
	// NOTE(review): not read by spawnSubTurn in this file — confirm where the
	// truncation described above is implemented.
	MaxContextRunes int

	// ActualSystemPrompt is injected as the true 'system' role message for the childAgent.
	// The legacy SystemPrompt field is actually used as the first 'user' message (task description).
	ActualSystemPrompt string

	// InitialMessages preloads the ephemeral session history before the agent loop starts.
	// Used by evaluator-optimizer patterns to pass the full worker context across multiple iterations.
	InitialMessages []providers.Message

	// InitialTokenBudget is a shared atomic counter for tracking remaining tokens.
	// If set, the SubTurn will inherit this budget and deduct tokens after each LLM call.
	// If nil, the SubTurn will inherit the parent's tokenBudget (if any); failing
	// that, a fresh budget is created from the configured default when it is positive.
	// Used by team tool to enforce token limits across all team members.
	InitialTokenBudget *atomic.Int64

	// Can be extended with temperature, topP, etc.
}
// ====================== Context Keys ======================
// agentLoopKeyType is an unexported key type, so the context value stored
// under agentLoopKey cannot collide with keys defined by other packages.
type agentLoopKeyType struct{}

// agentLoopKey is the context key under which WithAgentLoop stores the *AgentLoop.
var agentLoopKey = agentLoopKeyType{}
// WithAgentLoop returns a copy of ctx that carries al, so tools invoked with
// the derived context can retrieve it via AgentLoopFromContext.
func WithAgentLoop(ctx context.Context, al *AgentLoop) context.Context {
	return context.WithValue(ctx, agentLoopKey, al)
}
// AgentLoopFromContext returns the *AgentLoop stored in ctx by WithAgentLoop,
// or nil when ctx carries none.
func AgentLoopFromContext(ctx context.Context) *AgentLoop {
	if loop, ok := ctx.Value(agentLoopKey).(*AgentLoop); ok {
		return loop
	}
	return nil
}
// ====================== Helper Functions ======================
// generateSubTurnID returns the next sub-turn identifier of the form
// "subturn-<n>", where n comes from incrementing the loop's sub-turn counter.
func (al *AgentLoop) generateSubTurnID() string {
	seq := al.subTurnCounter.Add(1)
	return fmt.Sprintf("subturn-%d", seq)
}
// ====================== Core Function: spawnSubTurn ======================
// AgentLoopSpawner implements tools.SubTurnSpawner interface.
// This allows tools to spawn sub-turns without circular dependency.
type AgentLoopSpawner struct {
	// al is the agent loop on which sub-turns are spawned.
	al *AgentLoop
}
// SpawnSubTurn implements tools.SubTurnSpawner interface.
//
// It looks up the parent turn state in ctx, copies cfg field by field into
// the agent package's SubTurnConfig, and delegates to spawnSubTurn. An error
// is returned when ctx carries no parent turn state.
func (s *AgentLoopSpawner) SpawnSubTurn(
	ctx context.Context,
	cfg tools.SubTurnConfig,
) (*tools.ToolResult, error) {
	parentTS := turnStateFromContext(ctx)
	if parentTS != nil {
		// Convert tools.SubTurnConfig to agent.SubTurnConfig.
		converted := SubTurnConfig{
			Model:              cfg.Model,
			Tools:              cfg.Tools,
			SystemPrompt:       cfg.SystemPrompt,
			ActualSystemPrompt: cfg.ActualSystemPrompt,
			InitialMessages:    cfg.InitialMessages,
			InitialTokenBudget: cfg.InitialTokenBudget,
			MaxTokens:          cfg.MaxTokens,
			Async:              cfg.Async,
			Critical:           cfg.Critical,
			Timeout:            cfg.Timeout,
			MaxContextRunes:    cfg.MaxContextRunes,
		}
		return spawnSubTurn(ctx, s.al, parentTS, converted)
	}

	return nil, errors.New(
		"parent turnState not found in context - cannot spawn sub-turn outside of a turn",
	)
}
// NewSubTurnSpawner creates a SubTurnSpawner for the given AgentLoop.
// The returned spawner keeps a reference to al and spawns sub-turns on it.
func NewSubTurnSpawner(al *AgentLoop) *AgentLoopSpawner {
	return &AgentLoopSpawner{al: al}
}
// SpawnSubTurn is the exported entry point for tools to spawn sub-turns.
// It retrieves the AgentLoop and the parent turn state from ctx and delegates
// to spawnSubTurn. An error is returned when either one is missing from ctx;
// the AgentLoop is checked first.
func SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*tools.ToolResult, error) {
	loop := AgentLoopFromContext(ctx)
	if loop == nil {
		return nil, errors.New(
			"AgentLoop not found in context - ensure context is properly initialized",
		)
	}

	if parent := turnStateFromContext(ctx); parent != nil {
		return spawnSubTurn(ctx, loop, parent, cfg)
	}

	return nil, errors.New(
		"parent turnState not found in context - cannot spawn sub-turn outside of a turn",
	)
}
// spawnSubTurn runs a child turn under parentTS and returns its result.
//
// The call blocks until the child turn finishes. The child runs on a shallow
// copy of the parent's agent with an ephemeral session store and a cloned tool
// registry, and under its own timeout context that is NOT derived from ctx.
//
// Errors:
//   - ErrDepthLimitExceeded when the parent is already at the depth limit.
//   - ErrInvalidSubTurnConfig when cfg.Model is empty.
//   - ctx.Err() when ctx is canceled while waiting for a concurrency slot.
//   - an error wrapping ErrConcurrencyTimeout when no slot frees up in time.
//   - the error returned by runTurn, or a "subturn panicked" error if the
//     child turn panics.
//
// Validation runs before the concurrency slot is requested, so an invalid
// request fails immediately instead of first waiting (up to the concurrency
// timeout) for a slot it would never use.
//
// The named results are required: the deferred cleanup overwrites them when
// it recovers from a panic.
func spawnSubTurn(
	ctx context.Context,
	al *AgentLoop,
	parentTS *turnState,
	cfg SubTurnConfig,
) (result *tools.ToolResult, err error) {
	// Get effective SubTurn configuration
	rtCfg := al.getSubTurnConfig()

	// 1. Depth limit check
	if parentTS.depth >= rtCfg.maxDepth {
		logger.WarnCF("subturn", "Depth limit exceeded", map[string]any{
			"parent_id": parentTS.turnID,
			"depth":     parentTS.depth,
			"max_depth": rtCfg.maxDepth,
		})
		return nil, ErrDepthLimitExceeded
	}

	// 2. Config validation
	if cfg.Model == "" {
		return nil, ErrInvalidSubTurnConfig
	}

	// 3. Acquire a concurrency slot.
	// Blocks while the parent's semaphore is full, bounded by a timeout, and
	// also respects cancellation of ctx so an aborted parent never waits forever.
	// NOTE: on the normal path the slot is released right after runTurn returns
	// (see step 9), not by this defer, so it is never held during the cleanup
	// phase (async result delivery). The deferred release only covers early
	// returns and panics.
	var semAcquired bool
	if parentTS.concurrencySem != nil {
		waitCtx, cancelWait := context.WithTimeout(ctx, rtCfg.concurrencyTimeout)
		defer cancelWait()

		select {
		case parentTS.concurrencySem <- struct{}{}:
			semAcquired = true
			defer func() {
				if semAcquired {
					<-parentTS.concurrencySem
				}
			}()
		case <-waitCtx.Done():
			// Check parent context first - if it was canceled, propagate that error
			if ctx.Err() != nil {
				return nil, ctx.Err()
			}
			// Otherwise it's our timeout
			return nil, fmt.Errorf("%w: all %d slots occupied for %v",
				ErrConcurrencyTimeout, rtCfg.maxConcurrent, rtCfg.concurrencyTimeout)
		}
	}

	// 4. Determine timeout for child SubTurn
	timeout := cfg.Timeout
	if timeout <= 0 {
		timeout = rtCfg.defaultTimeout
	}

	// 5. Create INDEPENDENT child context (not derived from parent ctx).
	// This allows the child to continue running after parent finishes gracefully.
	// The child has its own timeout for self-protection.
	childCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	childID := al.generateSubTurnID()

	// Get the agent instance from parent, falling back to the default agent.
	// Wrap it in a shallow copy that uses an ephemeral (in-memory only) session store
	// so that child turns never pollute or persist to the parent's session history.
	baseAgent := parentTS.agent
	if baseAgent == nil {
		baseAgent = al.registry.GetDefaultAgent()
	}
	if baseAgent == nil {
		return nil, errors.New("parent turnState has no agent instance")
	}
	ephemeralStore := newEphemeralSession(nil)
	agent := *baseAgent // shallow copy
	agent.Sessions = ephemeralStore

	// Clone the tool registry so child turn's tool registrations
	// don't pollute the parent's registry.
	if baseAgent.Tools != nil {
		agent.Tools = baseAgent.Tools.Clone()
	}

	// Create processOptions for the child turn
	opts := processOptions{
		SessionKey:              childID,
		Channel:                 parentTS.channel,
		ChatID:                  parentTS.chatID,
		SenderID:                parentTS.opts.SenderID,
		SenderDisplayName:       parentTS.opts.SenderDisplayName,
		UserMessage:             cfg.SystemPrompt, // Task description becomes the first user message
		SystemPromptOverride:    cfg.ActualSystemPrompt,
		Media:                   nil,
		InitialSteeringMessages: cfg.InitialMessages,
		DefaultResponse:         "",
		EnableSummary:           false,
		SendResponse:            false,
		NoHistory:               true, // SubTurns don't use session history
		SkipInitialSteeringPoll: true,
	}

	// Create event scope for the child turn
	scope := al.newTurnEventScope(agent.ID, childID)

	// Create child turnState using the new API
	childTS := newTurnState(&agent, opts, scope)

	// Set SubTurn-specific fields
	childTS.cancelFunc = cancel
	childTS.critical = cfg.Critical
	childTS.depth = parentTS.depth + 1
	childTS.parentTurnID = parentTS.turnID
	childTS.parentTurnState = parentTS
	childTS.pendingResults = make(chan *tools.ToolResult, 16)
	childTS.concurrencySem = make(chan struct{}, rtCfg.maxConcurrent)
	childTS.al = al                  // back-ref for hard abort cascade
	childTS.session = ephemeralStore // same store as agent.Sessions

	// Token budget initialization/inheritance
	// If InitialTokenBudget is explicitly provided (e.g., by team tool), use it.
	// Otherwise, inherit from parent's tokenBudget (for nested SubTurns).
	if cfg.InitialTokenBudget != nil {
		childTS.tokenBudget = cfg.InitialTokenBudget
	} else if parentTS.tokenBudget != nil {
		childTS.tokenBudget = parentTS.tokenBudget
	} else if rtCfg.defaultTokenBudget > 0 {
		// Apply default token budget from config if no budget is set
		budget := &atomic.Int64{}
		budget.Store(int64(rtCfg.defaultTokenBudget))
		childTS.tokenBudget = budget
	}

	// IMPORTANT: Put childTS into childCtx so that code inside runTurn can retrieve it
	childCtx = withTurnState(childCtx, childTS)
	childCtx = WithAgentLoop(childCtx, al) // Propagate AgentLoop to child turn
	childTS.ctx = childCtx

	// Register child turn state so GetAllActiveTurns/Subagents can find it
	al.activeTurnStates.Store(childID, childTS)
	defer al.activeTurnStates.Delete(childID)

	// 6. Establish parent-child relationship (thread-safe)
	parentTS.mu.Lock()
	parentTS.childTurnIDs = append(parentTS.childTurnIDs, childID)
	parentTS.mu.Unlock()

	// 7. Emit Spawn event
	al.emitEvent(EventKindSubTurnSpawn,
		childTS.eventMeta("spawnSubTurn", "subturn.spawn"),
		SubTurnSpawnPayload{
			AgentID:      childTS.agentID,
			Label:        childID,
			ParentTurnID: parentTS.turnID,
		},
	)

	// 8. Defer cleanup: deliver result (for async), emit End event, and recover from panics
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("subturn panicked: %v", r)
			result = nil
			logger.ErrorCF("subturn", "SubTurn panicked", map[string]any{
				"child_id":  childID,
				"parent_id": parentTS.turnID,
				"panic":     r,
			})
		}

		// Result Delivery Strategy (Async vs Sync): only async sub-turns push
		// their result to the parent; sync callers use the return value.
		if cfg.Async {
			deliverSubTurnResult(al, parentTS, childID, result)
		}

		status := "completed"
		if err != nil {
			status = "error"
		}
		al.emitEvent(EventKindSubTurnEnd,
			childTS.eventMeta("spawnSubTurn", "subturn.end"),
			SubTurnEndPayload{
				AgentID: childTS.agentID,
				Status:  status,
			},
		)
	}()

	// 9. Execute sub-turn via the real agent loop.
	turnRes, turnErr := al.runTurn(childCtx, childTS)

	// Release the concurrency semaphore immediately after runTurn completes,
	// before the cleanup defer runs. This prevents a deadlock where:
	//   - All semaphore slots are held by sub-turns in their cleanup phase
	//   - Cleanup blocks on a full pendingResults channel
	//   - The parent goroutine is blocked waiting for a semaphore slot
	//   - The parent cannot consume pendingResults because it is blocked on the semaphore
	if semAcquired {
		<-parentTS.concurrencySem
		semAcquired = false // prevent the defer from double-releasing
	}

	// Convert turnResult to tools.ToolResult
	if turnErr != nil {
		err = turnErr
		result = &tools.ToolResult{
			Err:    turnErr,
			ForLLM: fmt.Sprintf("SubTurn failed: %v", turnErr),
		}
	} else {
		result = &tools.ToolResult{
			ForLLM:  turnRes.finalContent,
			ForUser: turnRes.finalContent,
		}
	}

	return result, err
}
// ====================== Result Delivery ======================
// deliverSubTurnResult delivers a sub-turn result to the parent turn's
// pendingResults channel.
//
// IMPORTANT: This function is ONLY called for asynchronous sub-turns (Async=true).
// For synchronous sub-turns (Async=false), results are returned directly via the
// function return value to avoid double delivery.
//
// Delivery behavior:
//   - Parent already finished, or it has no pendingResults channel: the result
//     is reported as an orphan (reason "parent_finished").
//   - Parent still running: the send blocks until either the channel accepts the
//     result or the parent's Finished() channel is closed. There is no
//     non-blocking "channel full" branch; a full channel simply waits for the
//     parent. If the parent finishes first the result is reported as an orphan
//     (reason "parent_finished_waiting").
//   - The send panics (turnState.Finish closes pendingResults, and sending on a
//     closed channel panics): the panic is recovered and the result is reported
//     as an orphan (reason "panic").
//
// Orphan events are only emitted when both result and al are non-nil.
//
// Thread safety:
//   - Parent state is read under the parent's lock, which is released before the
//     channel send. The remaining race window is acceptable: the worst case is
//     that the result becomes an orphan.
//
// Parameters:
//   - al: owning agent loop used to emit events; may be nil, in which case no
//     events are emitted.
//   - parentTS: state of the parent turn that should receive the result.
//   - childID: turn ID of the sub-turn that produced the result.
//   - result: the sub-turn result; may be nil (for example after the sub-turn
//     panicked), in which case nil is still sent to the parent.
func deliverSubTurnResult(al *AgentLoop, parentTS *turnState, childID string, result *tools.ToolResult) {
	// emitOrphan reports a result that could not be handed to the parent.
	emitOrphan := func(reason string) {
		if result == nil || al == nil {
			return
		}
		al.emitEvent(EventKindSubTurnOrphan,
			parentTS.eventMeta("deliverSubTurnResult", "subturn.orphan"),
			SubTurnOrphanPayload{ParentTurnID: parentTS.turnID, ChildTurnID: childID, Reason: reason},
		)
	}

	// Guard against a send on a closed pendingResults channel (or any other
	// unexpected panic) so a late sub-turn can never crash the process.
	defer func() {
		if r := recover(); r != nil {
			logger.WarnCF("subturn", "recovered panic sending to pendingResults", map[string]any{
				"parent_id": parentTS.turnID,
				"child_id":  childID,
				"recover":   r,
			})
			emitOrphan("panic")
		}
	}()

	parentTS.mu.RLock()
	isFinished := parentTS.isFinished.Load()
	resultChan := parentTS.pendingResults
	parentTS.mu.RUnlock()

	// If parent turn has already finished, treat this as an orphan result.
	if isFinished || resultChan == nil {
		emitOrphan("parent_finished")
		return
	}

	// Parent turn is still running: attempt to deliver the result. Selecting on
	// parentTS.Finished() ensures that a full channel cannot block this
	// goroutine forever once the parent is gone.
	select {
	case resultChan <- result:
		if al != nil {
			// result may legitimately be nil; never dereference it blindly.
			contentLen := 0
			if result != nil {
				contentLen = len(result.ForLLM)
			}
			al.emitEvent(EventKindSubTurnResultDelivered,
				parentTS.eventMeta("deliverSubTurnResult", "subturn.result_delivered"),
				SubTurnResultDeliveredPayload{ContentLen: contentLen},
			)
		}
	case <-parentTS.Finished():
		// Parent finished while we were waiting to deliver. The result cannot
		// reach the LLM anymore, so it becomes an orphan.
		logger.WarnCF("subturn", "parent finished before result could be delivered", map[string]any{
			"parent_id": parentTS.turnID,
			"child_id":  childID,
		})
		emitOrphan("parent_finished_waiting")
	}
}
// ====================== Other Types ======================
// ephemeralSessionStore is an in-memory session.SessionStore used by SubTurns.
// It does not persist to disk and auto-truncates history to maxEphemeralHistorySize.
// All methods ignore the session key argument: one store holds exactly one
// history and one summary.
type ephemeralSessionStore struct {
mu sync.Mutex // guards history and summary
history []providers.Message // message log, capped by truncateLocked
summary string // last summary set via SetSummary
}
// newEphemeralSession builds an in-memory session store seeded with a copy of
// initial. The seed is copied as-is; the history size cap is not applied here.
func newEphemeralSession(initial []providers.Message) ephemeralSessionStoreIface {
	store := new(ephemeralSessionStore)
	if len(initial) == 0 {
		return store
	}
	// Copy so later mutations of the caller's slice do not leak into the store.
	store.history = append(store.history, initial...)
	return store
}
// ephemeralSessionStoreIface is satisfied by *ephemeralSessionStore.
// Declared so newEphemeralSession can return a typed interface.
// In the *ephemeralSessionStore implementation every key/sessionKey argument
// is ignored and Save/Close are no-ops.
type ephemeralSessionStoreIface interface {
AddMessage(sessionKey, role, content string)
AddFullMessage(sessionKey string, msg providers.Message)
GetHistory(key string) []providers.Message
GetSummary(key string) string
SetSummary(key, summary string)
SetHistory(key string, history []providers.Message)
TruncateHistory(key string, keepLast int)
Save(key string) error
Close() error
}
// AddMessage appends a message built from role and content to the history and
// then enforces the ephemeral history cap. The session key is ignored.
func (e *ephemeralSessionStore) AddMessage(_, role, content string) {
	msg := providers.Message{Role: role, Content: content}

	e.mu.Lock()
	defer e.mu.Unlock()

	e.history = append(e.history, msg)
	e.truncateLocked()
}
// AddFullMessage appends msg unchanged to the history and then enforces the
// ephemeral history cap. The session key is ignored.
func (e *ephemeralSessionStore) AddFullMessage(_ string, msg providers.Message) {
	e.mu.Lock()
	defer e.mu.Unlock()

	extended := append(e.history, msg)
	e.history = extended
	e.truncateLocked()
}
// GetHistory returns a copy of the current history, so callers can modify the
// returned slice without affecting the store. The result is never nil.
func (e *ephemeralSessionStore) GetHistory(_ string) []providers.Message {
	e.mu.Lock()
	defer e.mu.Unlock()

	snapshot := make([]providers.Message, len(e.history))
	copy(snapshot, e.history)
	return snapshot
}
// GetSummary returns the summary most recently stored with SetSummary.
// The session key is ignored.
func (e *ephemeralSessionStore) GetSummary(_ string) string {
	e.mu.Lock()
	current := e.summary
	e.mu.Unlock()
	return current
}
// SetSummary replaces the stored summary. The session key is ignored.
func (e *ephemeralSessionStore) SetSummary(_, summary string) {
	e.mu.Lock()
	e.summary = summary
	e.mu.Unlock()
}
// SetHistory replaces the stored history with a copy of history and then
// enforces the ephemeral history cap. The session key is ignored.
func (e *ephemeralSessionStore) SetHistory(_ string, history []providers.Message) {
	replacement := make([]providers.Message, len(history))
	copy(replacement, history)

	e.mu.Lock()
	defer e.mu.Unlock()

	e.history = replacement
	e.truncateLocked()
}
// TruncateHistory keeps only the last keepLast messages. A keepLast of zero or
// less clears the history; a keepLast at or above the current length leaves
// the history untouched. The session key is ignored.
func (e *ephemeralSessionStore) TruncateHistory(_ string, keepLast int) {
	e.mu.Lock()
	defer e.mu.Unlock()

	switch total := len(e.history); {
	case keepLast <= 0:
		e.history = nil
	case keepLast < total:
		e.history = e.history[total-keepLast:]
	}
}
// Save is a no-op: the ephemeral store keeps nothing outside memory.
func (e *ephemeralSessionStore) Save(_ string) error {
	return nil
}
// Close is a no-op: the ephemeral store holds no resources to release.
func (e *ephemeralSessionStore) Close() error {
	return nil
}
// truncateLocked drops the oldest messages so that at most
// maxEphemeralHistorySize remain. The caller must hold e.mu.
func (e *ephemeralSessionStore) truncateLocked() {
	excess := len(e.history) - maxEphemeralHistorySize
	if excess <= 0 {
		return
	}
	e.history = e.history[excess:]
}
File diff suppressed because it is too large Load Diff
+481
View File
@@ -0,0 +1,481 @@
package agent
import (
"context"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
)
// TurnPhase names the stage a turn is currently in. It is stored on turnState
// and exposed through ActiveTurnInfo.Phase.
type TurnPhase string
// Known turn phases. newTurnState starts every turn in TurnPhaseSetup; the
// remaining transitions are driven by callers of turnState.setPhase.
const (
TurnPhaseSetup TurnPhase = "setup"
TurnPhaseRunning TurnPhase = "running"
TurnPhaseTools TurnPhase = "tools"
TurnPhaseFinalizing TurnPhase = "finalizing"
TurnPhaseCompleted TurnPhase = "completed"
TurnPhaseAborted TurnPhase = "aborted"
)
// ActiveTurnInfo is an exported, point-in-time copy of a turn's state as
// produced by turnState.snapshot. It shares no memory with the live turn:
// ChildTurnIDs is a fresh copy of the turn's child list.
type ActiveTurnInfo struct {
TurnID string
AgentID string
SessionKey string
Channel string
ChatID string
UserMessage string
Phase TurnPhase
Iteration int
StartedAt time.Time
Depth int // SubTurn depth (0 for root turn)
ParentTurnID string // empty for root turn
ChildTurnIDs []string
}
// turnResult is the outcome of running a single turn.
type turnResult struct {
finalContent string // final content produced by the turn
status TurnEndStatus // how the turn ended
followUps []bus.InboundMessage // NOTE(review): presumably inbound messages to process after this turn — confirm with runTurn
}
// turnState holds the mutable, per-turn state of an agent turn (root turn or
// SubTurn). Methods in this file guard most fields with mu; isFinished and
// parentEnded are atomics and are accessed without holding mu.
// A turnState must not be copied after first use because it embeds a mutex,
// a sync.Once and atomics.
type turnState struct {
mu sync.RWMutex
agent *AgentInstance
opts processOptions
scope turnEventScope
turnID string
agentID string
sessionKey string
channel string
chatID string
userMessage string
media []string
phase TurnPhase
iteration int
startedAt time.Time
finalContent string
followUps []bus.InboundMessage
// Interrupt / abort bookkeeping (see requestGracefulInterrupt and requestHardAbort).
gracefulInterrupt bool
gracefulInterruptHint string
gracefulTerminalUsed bool
hardAbort bool
providerCancel context.CancelFunc
turnCancel context.CancelFunc
// Session restore point (see captureRestorePoint and restoreSession).
restorePointHistory []providers.Message
restorePointSummary string
persistedMessages []providers.Message
// SubTurn support
depth int // SubTurn depth (0 for root turn)
parentTurnID string // Parent turn ID (empty for root turn)
childTurnIDs []string // Child turn IDs
pendingResults chan *tools.ToolResult // Channel for SubTurn results
concurrencySem chan struct{} // Semaphore for limiting concurrent SubTurns
isFinished atomic.Bool // Whether this turn has finished
session session.SessionStore // Session store reference
initialHistoryLength int // Snapshot of history length at turn start
// Additional SubTurn fields
ctx context.Context // Context for this turn
cancelFunc context.CancelFunc // Cancel function for this turn's context
critical bool // Whether this SubTurn should continue after parent ends
parentTurnState *turnState // Reference to parent turnState
parentEnded atomic.Bool // Set by Finish on a root turn that finishes without a hard abort; read by children via IsParentEnded
closeOnce sync.Once // Ensures pendingResults and finishedChan are closed at most once
finishedChan chan struct{} // Closed when turn finishes
// Token budget tracking
tokenBudget *atomic.Int64 // Shared token budget counter
lastFinishReason string // Last LLM finish_reason
lastUsage *providers.UsageInfo // Last LLM usage info
// Back-reference to the owning AgentLoop (set for SubTurns only, used for hard abort cascade)
al *AgentLoop
}
// newTurnState creates the state for a new turn in TurnPhaseSetup.
//
// The identifying fields are copied from opts and scope, and opts.Media is
// copied so later changes to the caller's slice do not affect the turn.
//
// agent may be nil: in that case agentID stays empty and no session store is
// bound. When the agent has a session store, it is bound to the turn and the
// current history length for opts.SessionKey is recorded for rollback logic.
func newTurnState(agent *AgentInstance, opts processOptions, scope turnEventScope) *turnState {
	ts := &turnState{
		agent:       agent,
		opts:        opts,
		scope:       scope,
		turnID:      scope.turnID,
		sessionKey:  opts.SessionKey,
		channel:     opts.Channel,
		chatID:      opts.ChatID,
		userMessage: opts.UserMessage,
		media:       append([]string(nil), opts.Media...),
		phase:       TurnPhaseSetup,
		startedAt:   time.Now(),
	}
	// Everything below needs a real agent. Reading agent.ID before this guard
	// would panic on a nil agent and make the nil check pointless.
	if agent == nil {
		return ts
	}
	ts.agentID = agent.ID
	// Bind session store and capture initial history length for rollback logic.
	if agent.Sessions != nil {
		ts.session = agent.Sessions
		ts.initialHistoryLength = len(agent.Sessions.GetHistory(opts.SessionKey))
	}
	return ts
}
// registerActiveTurn records ts as the active turn for its session key,
// replacing any turn previously stored under that key.
// NOTE(review): the sub-turn spawn path stores child turn states in the same
// map keyed by child turn ID — confirm the two key spaces cannot collide.
func (al *AgentLoop) registerActiveTurn(ts *turnState) {
al.activeTurnStates.Store(ts.sessionKey, ts)
}
// clearActiveTurn removes whatever turn is stored under ts.sessionKey.
// The delete is unconditional: it does not check that the stored turn is ts.
func (al *AgentLoop) clearActiveTurn(ts *turnState) {
al.activeTurnStates.Delete(ts.sessionKey)
}
// getActiveTurnState looks up the turn stored under sessionKey and returns
// nil when there is none.
func (al *AgentLoop) getActiveTurnState(sessionKey string) *turnState {
	stored, found := al.activeTurnStates.Load(sessionKey)
	if !found {
		return nil
	}
	return stored.(*turnState)
}
// getAnyActiveTurnState returns an arbitrary active turn state, or nil when no
// turn is registered. Kept for backward compatibility with callers that
// assume a single active turn.
func (al *AgentLoop) getAnyActiveTurnState() *turnState {
	var picked *turnState
	al.activeTurnStates.Range(func(_, entry any) bool {
		picked = entry.(*turnState)
		// One entry is enough; returning false ends the iteration.
		return false
	})
	return picked
}
// GetActiveTurn returns a snapshot of one active turn, or nil when no turn is
// active.
//
// It exists for backward compatibility: several turns can be active at the
// same time, and which one is returned is unspecified. Use
// GetActiveTurnBySession to query a specific session.
func (al *AgentLoop) GetActiveTurn() *ActiveTurnInfo {
	// Reuse the shared lookup helper instead of duplicating the Range logic.
	ts := al.getAnyActiveTurnState()
	if ts == nil {
		return nil
	}
	info := ts.snapshot()
	return &info
}
// GetActiveTurnBySession returns a snapshot of the turn registered under
// sessionKey, or nil when that session has no active turn.
func (al *AgentLoop) GetActiveTurnBySession(sessionKey string) *ActiveTurnInfo {
	if ts := al.getActiveTurnState(sessionKey); ts != nil {
		snap := ts.snapshot()
		return &snap
	}
	return nil
}
// snapshot returns a copy of the turn's externally visible fields, taken
// under the read lock. ChildTurnIDs is copied so the result shares no slice
// with the live turn.
func (ts *turnState) snapshot() ActiveTurnInfo {
	ts.mu.RLock()
	defer ts.mu.RUnlock()

	var info ActiveTurnInfo
	info.TurnID = ts.turnID
	info.AgentID = ts.agentID
	info.SessionKey = ts.sessionKey
	info.Channel = ts.channel
	info.ChatID = ts.chatID
	info.UserMessage = ts.userMessage
	info.Phase = ts.phase
	info.Iteration = ts.iteration
	info.StartedAt = ts.startedAt
	info.Depth = ts.depth
	info.ParentTurnID = ts.parentTurnID
	info.ChildTurnIDs = append([]string(nil), ts.childTurnIDs...)
	return info
}
// setPhase records the turn's current phase.
func (ts *turnState) setPhase(phase TurnPhase) {
	ts.mu.Lock()
	ts.phase = phase
	ts.mu.Unlock()
}
// setIteration records the turn's current iteration number.
func (ts *turnState) setIteration(iteration int) {
	ts.mu.Lock()
	ts.iteration = iteration
	ts.mu.Unlock()
}
// currentIteration returns the iteration number last stored by setIteration.
func (ts *turnState) currentIteration() int {
	ts.mu.RLock()
	current := ts.iteration
	ts.mu.RUnlock()
	return current
}
// setFinalContent stores the turn's final content.
func (ts *turnState) setFinalContent(content string) {
	ts.mu.Lock()
	ts.finalContent = content
	ts.mu.Unlock()
}
// finalContentLen returns the length in bytes of the stored final content.
func (ts *turnState) finalContentLen() int {
	ts.mu.RLock()
	size := len(ts.finalContent)
	ts.mu.RUnlock()
	return size
}
// setTurnCancel stores the cancel function that requestHardAbort invokes to
// cancel the whole turn.
func (ts *turnState) setTurnCancel(cancel context.CancelFunc) {
	ts.mu.Lock()
	ts.turnCancel = cancel
	ts.mu.Unlock()
}
// setProviderCancel stores the cancel function that requestHardAbort invokes
// to cancel the in-flight provider call.
func (ts *turnState) setProviderCancel(cancel context.CancelFunc) {
	ts.mu.Lock()
	ts.providerCancel = cancel
	ts.mu.Unlock()
}
// clearProviderCancel forgets the stored provider cancel function. The
// argument is accepted but ignored; whatever is stored is cleared.
func (ts *turnState) clearProviderCancel(_ context.CancelFunc) {
	ts.mu.Lock()
	ts.providerCancel = nil
	ts.mu.Unlock()
}
// requestGracefulInterrupt flags the turn for a graceful interrupt and stores
// hint alongside it. It returns false, changing nothing, when a hard abort
// has already been requested.
func (ts *turnState) requestGracefulInterrupt(hint string) bool {
	ts.mu.Lock()
	defer ts.mu.Unlock()

	accepted := !ts.hardAbort
	if accepted {
		ts.gracefulInterrupt = true
		ts.gracefulInterruptHint = hint
	}
	return accepted
}
// gracefulInterruptRequested reports whether a graceful interrupt is pending,
// meaning one was requested and markGracefulTerminalUsed has not been called.
// The stored hint is returned regardless of the boolean.
func (ts *turnState) gracefulInterruptRequested() (bool, string) {
	ts.mu.RLock()
	defer ts.mu.RUnlock()

	pending := ts.gracefulInterrupt && !ts.gracefulTerminalUsed
	return pending, ts.gracefulInterruptHint
}
// markGracefulTerminalUsed records that the graceful interrupt has been acted
// on, so gracefulInterruptRequested stops reporting it as pending.
func (ts *turnState) markGracefulTerminalUsed() {
	ts.mu.Lock()
	ts.gracefulTerminalUsed = true
	ts.mu.Unlock()
}
// requestHardAbort marks the turn as hard-aborted and invokes the stored
// cancel functions, provider first and then turn. It returns false without
// cancelling anything when a hard abort was already requested.
func (ts *turnState) requestHardAbort() bool {
	ts.mu.Lock()
	alreadyAborted := ts.hardAbort
	ts.hardAbort = true
	// Capture the cancel functions under the lock, call them after releasing it.
	cancels := [...]context.CancelFunc{ts.providerCancel, ts.turnCancel}
	ts.mu.Unlock()

	if alreadyAborted {
		return false
	}
	for _, cancel := range cancels {
		if cancel != nil {
			cancel()
		}
	}
	return true
}
// hardAbortRequested reports whether requestHardAbort has been called.
func (ts *turnState) hardAbortRequested() bool {
	ts.mu.RLock()
	aborted := ts.hardAbort
	ts.mu.RUnlock()
	return aborted
}
// eventMeta builds the EventMeta for an event emitted on behalf of this turn.
// The turn's identifying fields come from a fresh snapshot; source and
// tracePath are passed through unchanged.
func (ts *turnState) eventMeta(source, tracePath string) EventMeta {
	current := ts.snapshot()

	var meta EventMeta
	meta.AgentID = current.AgentID
	meta.TurnID = current.TurnID
	meta.SessionKey = current.SessionKey
	meta.Iteration = current.Iteration
	meta.Source = source
	meta.TracePath = tracePath
	return meta
}
// captureRestorePoint stores a copy of history together with summary as the
// state restoreSession will later write back to the session store.
func (ts *turnState) captureRestorePoint(history []providers.Message, summary string) {
	// Copy before locking; the caller's slice is not protected by ts.mu.
	saved := append([]providers.Message(nil), history...)

	ts.mu.Lock()
	ts.restorePointHistory = saved
	ts.restorePointSummary = summary
	ts.mu.Unlock()
}
// recordPersistedMessage appends msg to the list of messages recorded for
// this turn. refreshRestorePointFromSession matches that list against the
// tail of the session history.
func (ts *turnState) recordPersistedMessage(msg providers.Message) {
	ts.mu.Lock()
	ts.persistedMessages = append(ts.persistedMessages, msg)
	ts.mu.Unlock()
}
// refreshRestorePointFromSession re-reads the session's history and summary
// and captures them as the restore point. If the history ends with messages
// this turn recorded as persisted, that tail is removed first.
func (ts *turnState) refreshRestorePointFromSession(agent *AgentInstance) {
	history := agent.Sessions.GetHistory(ts.sessionKey)
	summary := agent.Sessions.GetSummary(ts.sessionKey)

	ts.mu.RLock()
	ownMessages := append([]providers.Message(nil), ts.persistedMessages...)
	ts.mu.RUnlock()

	tail := matchingTurnMessageTail(history, ownMessages)
	if tail > 0 {
		// Drop this turn's own messages so the restore point predates them.
		keep := history[:len(history)-tail]
		history = append([]providers.Message(nil), keep...)
	}
	ts.captureRestorePoint(history, summary)
}
// restoreSession writes the captured restore point (history and summary) back
// into the agent's session store and saves it. It returns the error from Save.
func (ts *turnState) restoreSession(agent *AgentInstance) error {
	ts.mu.RLock()
	savedHistory := append([]providers.Message(nil), ts.restorePointHistory...)
	savedSummary := ts.restorePointSummary
	ts.mu.RUnlock()

	key := ts.sessionKey
	agent.Sessions.SetHistory(key, savedHistory)
	agent.Sessions.SetSummary(key, savedSummary)
	return agent.Sessions.Save(key)
}
// matchingTurnMessageTail returns the largest n such that the last n entries
// of history are deeply equal to the last n entries of persisted, or 0 when
// the two tails do not match for any n.
func matchingTurnMessageTail(history, persisted []providers.Message) int {
	hLen, pLen := len(history), len(persisted)
	// Try the longest candidate first so the first hit is the maximum.
	for n := min(hLen, pLen); n > 0; n-- {
		if reflect.DeepEqual(history[hLen-n:], persisted[pLen-n:]) {
			return n
		}
	}
	return 0
}
// interruptHintMessage builds the user-role message that asks the model to
// stop scheduling tools and summarize. The stored interrupt hint, if any, is
// appended to the message.
func (ts *turnState) interruptHintMessage() providers.Message {
	msg := providers.Message{
		Role:    "user",
		Content: "Interrupt requested. Stop scheduling tools and provide a short final summary.",
	}
	if _, hint := ts.gracefulInterruptRequested(); hint != "" {
		msg.Content += "\n\nInterrupt hint: " + hint
	}
	return msg
}
// SubTurn-related methods
// Finish marks the turn as finished, closes the pendingResults and
// finishedChan channels (at most once across all calls), cancels the turn's
// context, and on a hard abort cascades Finish(true) to the turn's children.
//
// isHardAbort selects the shutdown flavor:
//   - false: if this is a root turn (no parentTurnState), parentEnded is set so
//     children can observe it through IsParentEnded. Children are not finished.
//   - true: every child turn ID that is still registered in the owning
//     AgentLoop's activeTurnStates is finished with a hard abort as well. This
//     only happens when the al back-reference is set.
func (ts *turnState) Finish(isHardAbort bool) {
ts.isFinished.Store(true)
// Close pendingResults channel exactly once
ts.closeOnce.Do(func() {
// NOTE(review): after this close, a concurrent send on pendingResults panics;
// deliverSubTurnResult recovers from that panic — confirm any other sender does too.
if ts.pendingResults != nil {
close(ts.pendingResults)
}
// finishedChan is created lazily (see Finished), so it may still be nil here.
ts.mu.Lock()
if ts.finishedChan == nil {
ts.finishedChan = make(chan struct{})
}
close(ts.finishedChan)
ts.mu.Unlock()
})
// If this is a graceful finish (not hard abort), signal to children
if !isHardAbort && ts.parentTurnState == nil {
// This is a root turn finishing gracefully
ts.parentEnded.Store(true)
}
// Cancel the turn context
if ts.cancelFunc != nil {
ts.cancelFunc()
}
// Hard abort cascades to all child turns
if isHardAbort && ts.al != nil {
// Copy the child IDs under the read lock and release it before recursing,
// so child Finish calls never run while this turn's lock is held.
ts.mu.RLock()
children := append([]string(nil), ts.childTurnIDs...)
ts.mu.RUnlock()
for _, childID := range children {
if val, ok := ts.al.activeTurnStates.Load(childID); ok {
val.(*turnState).Finish(true)
}
}
}
}
// Finished returns a channel that is closed once Finish has run, so callers
// can wait for or select on the end of the turn. The channel is created
// lazily on first use.
func (ts *turnState) Finished() chan struct{} {
	ts.mu.Lock()
	defer ts.mu.Unlock()

	if done := ts.finishedChan; done != nil {
		return done
	}
	ts.finishedChan = make(chan struct{})
	return ts.finishedChan
}
// IsParentEnded reports whether this turn's parent has set its parentEnded
// flag. It always returns false for a turn without a parent.
func (ts *turnState) IsParentEnded() bool {
	parent := ts.parentTurnState
	return parent != nil && parent.parentEnded.Load()
}
// GetLastFinishReason returns the finish_reason most recently stored with
// SetLastFinishReason.
func (ts *turnState) GetLastFinishReason() string {
	ts.mu.RLock()
	reason := ts.lastFinishReason
	ts.mu.RUnlock()
	return reason
}
// SetLastFinishReason records the finish_reason of the latest LLM response.
func (ts *turnState) SetLastFinishReason(reason string) {
	ts.mu.Lock()
	ts.lastFinishReason = reason
	ts.mu.Unlock()
}
// GetLastUsage returns the usage info most recently stored with SetLastUsage.
// The stored pointer is returned as-is, not a copy.
func (ts *turnState) GetLastUsage() *providers.UsageInfo {
	ts.mu.RLock()
	usage := ts.lastUsage
	ts.mu.RUnlock()
	return usage
}
// SetLastUsage records the usage info of the latest LLM response. The pointer
// is stored as-is, not copied.
func (ts *turnState) SetLastUsage(usage *providers.UsageInfo) {
	ts.mu.Lock()
	ts.lastUsage = usage
	ts.mu.Unlock()
}
// Context helper functions for SubTurn
// turnStateKeyType is a private context key type, so the key below cannot
// collide with context keys defined by other packages.
type turnStateKeyType struct{}
// turnStateKey is the context key under which withTurnState stores a *turnState.
var turnStateKey = turnStateKeyType{}
// withTurnState returns a child of ctx that carries ts, retrievable with
// turnStateFromContext.
func withTurnState(ctx context.Context, ts *turnState) context.Context {
return context.WithValue(ctx, turnStateKey, ts)
}
// turnStateFromContext returns the *turnState stored in ctx by withTurnState,
// or nil when ctx carries none.
func turnStateFromContext(ctx context.Context) *turnState {
	if stored, ok := ctx.Value(turnStateKey).(*turnState); ok {
		return stored
	}
	return nil
}
// TurnStateFromContext retrieves turnState from context (exported for tools).
// It returns nil when ctx carries no turn state. The return type is
// unexported, so callers outside this package can hold the value but cannot
// name its type.
func TurnStateFromContext(ctx context.Context) *turnState {
return turnStateFromContext(ctx)
}
+1
View File
@@ -14,6 +14,7 @@ func BuiltinDefinitions() []Definition {
switchCommand(),
checkCommand(),
clearCommand(),
subagentsCommand(),
reloadCommand(),
}
}
+42
View File
@@ -0,0 +1,42 @@
package commands
import (
"context"
"fmt"
)
// TurnInfo is a mirrored struct from agent.TurnInfo to avoid circular dependencies.
// NOTE(review): subagentsCommand in this file does not reference TurnInfo; it
// only handles a string or an arbitrary value from Runtime.GetActiveTurn —
// confirm the type is still used elsewhere.
type TurnInfo struct {
TurnID string
ParentTurnID string // presumably empty for a root turn — confirm against the agent package
Depth int
ChildTurnIDs []string
IsFinished bool
}
// subagentsCommand builds the "subagents" command, which replies with the
// currently running subagents as reported by Runtime.GetActiveTurn.
//
// The handler replies with:
//   - a "not supported" message when the runtime has no GetActiveTurn hook;
//   - a "no active tasks" message when the hook returns nil or an empty string;
//   - the tree text when the hook returns a non-empty string;
//   - a %+v dump of the value for any other type.
func subagentsCommand() Definition {
	const noActiveTasks = "No active tasks running in this session."

	handler := func(ctx context.Context, req Request, rt *Runtime) error {
		query := rt.GetActiveTurn
		if query == nil {
			return req.Reply("Runtime does not support querying active turns.")
		}

		switch turn := query().(type) {
		case nil:
			return req.Reply(noActiveTasks)
		case string:
			if turn == "" {
				return req.Reply(noActiveTasks)
			}
			return req.Reply(fmt.Sprintf("🤖 **Active Subagents Tree**\n```text\n%s\n```", turn))
		default:
			return req.Reply(fmt.Sprintf("🤖 **Active Subagents List**\n```text\n%+v\n```", turn))
		}
	}

	return Definition{
		Name:        "subagents",
		Description: "Show running subagents and task tree",
		Handler:     handler,
	}
}
+1
View File
@@ -12,6 +12,7 @@ type Runtime struct {
ListDefinitions func() []Definition
ListSkillNames func() []string
GetEnabledChannels func() []string
GetActiveTurn func() any // Returning any to avoid circular dependency with agent package
SwitchModel func(value string) (oldModel string, err error)
SwitchChannel func(value string) error
ClearHistory func() error
+58 -8
View File
@@ -84,6 +84,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers,omitempty"`
ModelList []ModelConfig `json:"model_list"` // New model-centric provider configuration
Gateway GatewayConfig `json:"gateway"`
Hooks HooksConfig `json:"hooks,omitempty"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
Devices DevicesConfig `json:"devices"`
@@ -92,6 +93,36 @@ type Config struct {
BuildInfo BuildInfo `json:"build_info,omitempty"`
}
// HooksConfig is the "hooks" section of Config.
type HooksConfig struct {
Enabled bool `json:"enabled"` // DefaultConfig sets this to true
Defaults HookDefaultsConfig `json:"defaults,omitempty"` // timeouts shared by hooks
Builtins map[string]BuiltinHookConfig `json:"builtins,omitempty"` // presumably keyed by builtin hook name — confirm with the hook loader
Processes map[string]ProcessHookConfig `json:"processes,omitempty"` // presumably keyed by process hook name — confirm with the hook loader
}
// HookDefaultsConfig holds the default hook timeouts (presumably in
// milliseconds, per the MS suffix — confirm with the hook runner).
// DefaultConfig sets them to 500, 5000 and 60000 respectively.
type HookDefaultsConfig struct {
ObserverTimeoutMS int `json:"observer_timeout_ms,omitempty"`
InterceptorTimeoutMS int `json:"interceptor_timeout_ms,omitempty"`
ApprovalTimeoutMS int `json:"approval_timeout_ms,omitempty"`
}
// BuiltinHookConfig configures one entry of HooksConfig.Builtins.
type BuiltinHookConfig struct {
Enabled bool `json:"enabled"`
Priority int `json:"priority,omitempty"` // NOTE(review): ordering semantics are not visible here — confirm with the hook manager
Config json.RawMessage `json:"config,omitempty"` // hook-specific settings, kept as raw JSON and decoded later
}
// ProcessHookConfig configures one entry of HooksConfig.Processes.
// NOTE(review): the field meanings below are inferred from the field names and
// from the values used in TestLoadConfig_HooksProcessConfig — confirm against
// the process hook runner.
type ProcessHookConfig struct {
Enabled bool `json:"enabled"`
Priority int `json:"priority,omitempty"`
Transport string `json:"transport,omitempty"` // e.g. "stdio"
Command []string `json:"command,omitempty"` // presumably executable followed by its arguments
Dir string `json:"dir,omitempty"` // presumably the working directory for the command
Env map[string]string `json:"env,omitempty"` // presumably extra environment variables for the command
Observe []string `json:"observe,omitempty"` // event names, e.g. "turn_start", "turn_end"
Intercept []string `json:"intercept,omitempty"` // event names, e.g. "before_tool", "approve_tool"
}
// BuildInfo contains build-time version information
type BuildInfo struct {
Version string `json:"version"`
@@ -219,9 +250,15 @@ type RoutingConfig struct {
Threshold float64 `json:"threshold"` // complexity score in [0,1]; score >= threshold → primary model
}
// ToolFeedbackConfig controls whether tool execution details are sent to the
// chat channel as real-time feedback messages. When enabled, every tool call
// produces a short notification with the tool name and its parameters.
// SubTurnConfig configures the SubTurn execution system.
// NOTE(review): the AgentDefaults.SubTurn field carries
// envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_" while the env tags below
// already spell out the full variable names. If the env loader applies
// envPrefix to nested fields, the effective names would contain the prefix
// twice — confirm against the env library in use.
type SubTurnConfig struct {
MaxDepth int `json:"max_depth" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_MAX_DEPTH"`
MaxConcurrent int `json:"max_concurrent" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_MAX_CONCURRENT"`
DefaultTimeoutMinutes int `json:"default_timeout_minutes" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_DEFAULT_TIMEOUT_MINUTES"`
DefaultTokenBudget int `json:"default_token_budget" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_DEFAULT_TOKEN_BUDGET"`
ConcurrencyTimeoutSec int `json:"concurrency_timeout_sec" env:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_CONCURRENCY_TIMEOUT_SEC"`
}
type ToolFeedbackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_ENABLED"`
MaxArgsLength int `json:"max_args_length" env:"PICOCLAW_AGENTS_DEFAULTS_TOOL_FEEDBACK_MAX_ARGS_LENGTH"`
@@ -238,12 +275,15 @@ type AgentDefaults struct {
ImageModel string `json:"image_model,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_IMAGE_MODEL"`
ImageModelFallbacks []string `json:"image_model_fallbacks,omitempty"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
ContextWindow int `json:"context_window,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_CONTEXT_WINDOW"`
Temperature *float64 `json:"temperature,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
SummarizeMessageThreshold int `json:"summarize_message_threshold" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_MESSAGE_THRESHOLD"`
SummarizeTokenPercent int `json:"summarize_token_percent" env:"PICOCLAW_AGENTS_DEFAULTS_SUMMARIZE_TOKEN_PERCENT"`
MaxMediaSize int `json:"max_media_size,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_MEDIA_SIZE"`
Routing *RoutingConfig `json:"routing,omitempty"`
SteeringMode string `json:"steering_mode,omitempty" env:"PICOCLAW_AGENTS_DEFAULTS_STEERING_MODE"` // "one-at-a-time" (default) or "all"
SubTurn SubTurnConfig `json:"subturn" envPrefix:"PICOCLAW_AGENTS_DEFAULTS_SUBTURN_"`
ToolFeedback ToolFeedbackConfig `json:"tool_feedback,omitempty"`
LogLevel string `json:"log_level,omitempty" env:"PICOCLAW_LOG_LEVEL"`
}
@@ -923,10 +963,13 @@ func LoadConfig(path string) (*Config, error) {
if passphrase := credential.PassphraseProvider(); passphrase != "" {
for _, m := range cfg.ModelList {
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") && !strings.HasPrefix(m.APIKey, "file://") {
fmt.Fprintf(os.Stderr,
if m.APIKey != "" && !strings.HasPrefix(m.APIKey, "enc://") &&
!strings.HasPrefix(m.APIKey, "file://") {
fmt.Fprintf(
os.Stderr,
"picoclaw: warning: model %q has a plaintext api_key; call SaveConfig to encrypt it\n",
m.ModelName)
m.ModelName,
)
}
}
}
@@ -979,7 +1022,8 @@ func encryptPlaintextAPIKeys(models []ModelConfig, passphrase string) ([]ModelCo
changed := false
for i := range sealed {
m := &sealed[i]
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") || strings.HasPrefix(m.APIKey, "file://") {
if m.APIKey == "" || strings.HasPrefix(m.APIKey, "enc://") ||
strings.HasPrefix(m.APIKey, "file://") {
continue
}
encrypted, err := credential.Encrypt(passphrase, "", m.APIKey)
@@ -1012,7 +1056,13 @@ func resolveAPIKeys(models []ModelConfig, configDir string) error {
for j, key := range models[i].APIKeys {
resolved, err := cr.Resolve(key)
if err != nil {
return fmt.Errorf("model_list[%d] (%s): api_keys[%d]: %w", i, models[i].ModelName, j, err)
return fmt.Errorf(
"model_list[%d] (%s): api_keys[%d]: %w",
i,
models[i].ModelName,
j,
err,
)
}
models[i].APIKeys[j] = resolved
}
+98
View File
@@ -470,6 +470,22 @@ func TestDefaultConfig_CronAllowCommandEnabled(t *testing.T) {
}
}
// TestDefaultConfig_HooksDefaults verifies that DefaultConfig enables hooks
// and sets the three default hook timeouts.
func TestDefaultConfig_HooksDefaults(t *testing.T) {
	hooks := DefaultConfig().Hooks
	if !hooks.Enabled {
		t.Fatal("DefaultConfig().Hooks.Enabled should be true")
	}

	timeouts := hooks.Defaults
	if got := timeouts.ObserverTimeoutMS; got != 500 {
		t.Fatalf("ObserverTimeoutMS = %d, want 500", got)
	}
	if got := timeouts.InterceptorTimeoutMS; got != 5000 {
		t.Fatalf("InterceptorTimeoutMS = %d, want 5000", got)
	}
	if got := timeouts.ApprovalTimeoutMS; got != 60000 {
		t.Fatalf("ApprovalTimeoutMS = %d, want 60000", got)
	}
}
func TestDefaultConfig_LogLevel(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.LogLevel != "fatal" {
@@ -562,6 +578,88 @@ func TestLoadConfig_WebToolsProxy(t *testing.T) {
}
}
// TestLoadConfig_HooksProcessConfig writes a config file containing one
// process hook and one builtin hook, loads it with LoadConfig, and checks that
// every configured field round-trips. It also checks that a default the file
// does not mention (ApprovalTimeoutMS) is still present after loading.
func TestLoadConfig_HooksProcessConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
configJSON := `{
"hooks": {
"processes": {
"review-gate": {
"enabled": true,
"transport": "stdio",
"command": ["uvx", "picoclaw-hook-reviewer"],
"dir": "/tmp/hooks",
"env": {
"HOOK_MODE": "rewrite"
},
"observe": ["turn_start", "turn_end"],
"intercept": ["before_tool", "approve_tool"]
}
},
"builtins": {
"audit": {
"enabled": true,
"priority": 5,
"config": {
"label": "audit"
}
}
}
}
}`
if err := os.WriteFile(configPath, []byte(configJSON), 0o600); err != nil {
t.Fatalf("os.WriteFile() error: %v", err)
}
cfg, err := LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error: %v", err)
}
// Process hook: every field from the JSON must be present.
processCfg, ok := cfg.Hooks.Processes["review-gate"]
if !ok {
t.Fatal("expected review-gate process hook")
}
if !processCfg.Enabled {
t.Fatal("expected review-gate process hook to be enabled")
}
if processCfg.Transport != "stdio" {
t.Fatalf("Transport = %q, want stdio", processCfg.Transport)
}
if len(processCfg.Command) != 2 || processCfg.Command[0] != "uvx" {
t.Fatalf("Command = %v", processCfg.Command)
}
if processCfg.Dir != "/tmp/hooks" {
t.Fatalf("Dir = %q, want /tmp/hooks", processCfg.Dir)
}
if processCfg.Env["HOOK_MODE"] != "rewrite" {
t.Fatalf("HOOK_MODE = %q, want rewrite", processCfg.Env["HOOK_MODE"])
}
if len(processCfg.Observe) != 2 || processCfg.Observe[1] != "turn_end" {
t.Fatalf("Observe = %v", processCfg.Observe)
}
if len(processCfg.Intercept) != 2 || processCfg.Intercept[1] != "approve_tool" {
t.Fatalf("Intercept = %v", processCfg.Intercept)
}
// Builtin hook: the nested config object is kept as raw JSON.
builtinCfg, ok := cfg.Hooks.Builtins["audit"]
if !ok {
t.Fatal("expected audit builtin hook")
}
if !builtinCfg.Enabled {
t.Fatal("expected audit builtin hook to be enabled")
}
if builtinCfg.Priority != 5 {
t.Fatalf("Priority = %d, want 5", builtinCfg.Priority)
}
if !strings.Contains(string(builtinCfg.Config), `"audit"`) {
t.Fatalf("Config = %s", string(builtinCfg.Config))
}
// The file sets no "defaults" block, so this value must come from the defaults.
if cfg.Hooks.Defaults.ApprovalTimeoutMS != 60000 {
t.Fatalf("ApprovalTimeoutMS = %d, want 60000", cfg.Hooks.Defaults.ApprovalTimeoutMS)
}
}
// TestDefaultConfig_DMScope verifies the default dm_scope value
// TestDefaultConfig_SummarizationThresholds verifies summarization defaults
func TestDefaultConfig_SummarizationThresholds(t *testing.T) {
+9
View File
@@ -36,6 +36,7 @@ func DefaultConfig() *Config {
MaxToolIterations: 50,
SummarizeMessageThreshold: 20,
SummarizeTokenPercent: 75,
SteeringMode: "one-at-a-time",
ToolFeedback: ToolFeedbackConfig{
Enabled: true,
MaxArgsLength: 300,
@@ -193,6 +194,14 @@ func DefaultConfig() *Config {
AllowFrom: FlexibleStringSlice{},
},
},
Hooks: HooksConfig{
Enabled: true,
Defaults: HookDefaultsConfig{
ObserverTimeoutMS: 500,
InterceptorTimeoutMS: 5000,
ApprovalTimeoutMS: 60000,
},
},
Providers: ProvidersConfig{
OpenAI: OpenAIProviderConfig{WebSearch: true},
},
+10 -1
View File
@@ -214,11 +214,20 @@ func ParseResponse(body io.Reader) (*LLMResponse, error) {
Reasoning: choice.Message.Reasoning,
ReasoningDetails: choice.Message.ReasoningDetails,
ToolCalls: toolCalls,
FinishReason: choice.FinishReason,
FinishReason: normalizeFinishReason(choice.FinishReason),
Usage: apiResponse.Usage,
}, nil
}
// normalizeFinishReason maps provider-specific finish_reason values onto the
// names used by the rest of the code. The only rewrite is "length", which
// becomes "truncated"; every other value, including the empty string, is
// returned unchanged. The comparison is case-sensitive.
func normalizeFinishReason(reason string) string {
	switch reason {
	case "length":
		return "truncated"
	default:
		return reason
	}
}
// DecodeToolCallArguments decodes a tool call's arguments from raw JSON.
func DecodeToolCallArguments(raw json.RawMessage, name string) map[string]any {
arguments := make(map[string]any)
+19
View File
@@ -384,3 +384,22 @@ func (r *ToolRegistry) GetSummaries() []string {
}
return summaries
}
// GetAll returns the registered tools that are currently usable: every core
// tool, plus every non-core tool whose TTL is still positive. Tools are
// returned in the order given by sortedToolNames.
// Used by SubTurn to inherit the parent's tool set.
func (r *ToolRegistry) GetAll() []Tool {
	r.mu.RLock()
	defer r.mu.RUnlock()

	names := r.sortedToolNames()
	active := make([]Tool, 0, len(names))
	for _, name := range names {
		entry := r.tools[name]
		// Skip non-core tools whose TTL has run out.
		if !entry.IsCore && entry.TTL <= 0 {
			continue
		}
		active = append(active, entry.Tool)
	}
	return active
}
+10 -1
View File
@@ -1,6 +1,10 @@
package tools
import "encoding/json"
import (
"encoding/json"
"github.com/sipeed/picoclaw/pkg/providers"
)
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
@@ -34,6 +38,11 @@ type ToolResult struct {
// Media contains media store refs produced by this tool.
// When non-empty, the agent will publish these as OutboundMediaMessage.
Media []string `json:"media,omitempty"`
// Messages holds the ephemeral session history after execution.
// Only populated by SubTurn executions; used by evaluator_optimizer
// to carry stateful worker context across evaluation iterations.
Messages []providers.Message `json:"-"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
+71 -25
View File
@@ -7,7 +7,10 @@ import (
)
type SpawnTool struct {
manager *SubagentManager
spawner SubTurnSpawner
defaultModel string
maxTokens int
temperature float64
allowlistCheck func(targetAgentID string) bool
}
@@ -15,9 +18,19 @@ type SpawnTool struct {
var _ AsyncExecutor = (*SpawnTool)(nil)
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
return &SpawnTool{
manager: manager,
if manager == nil {
return &SpawnTool{}
}
return &SpawnTool{
defaultModel: manager.defaultModel,
maxTokens: manager.maxTokens,
temperature: manager.temperature,
}
}
// SetSpawner sets the SubTurnSpawner for direct sub-turn execution.
// The execute path only launches a sub-turn when a spawner has been set.
func (t *SpawnTool) SetSpawner(spawner SubTurnSpawner) {
t.spawner = spawner
}
func (t *SpawnTool) Name() string {
@@ -59,11 +72,19 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]any) *ToolResul
// ExecuteAsync implements AsyncExecutor. The callback is passed through to the
// subagent manager as a call parameter — never stored on the SpawnTool instance.
func (t *SpawnTool) ExecuteAsync(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
func (t *SpawnTool) ExecuteAsync(
ctx context.Context,
args map[string]any,
cb AsyncCallback,
) *ToolResult {
return t.execute(ctx, args, cb)
}
func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCallback) *ToolResult {
func (t *SpawnTool) execute(
ctx context.Context,
args map[string]any,
cb AsyncCallback,
) *ToolResult {
task, ok := args["task"].(string)
if !ok || strings.TrimSpace(task) == "" {
return ErrorResult("task is required and must be a non-empty string")
@@ -79,28 +100,53 @@ func (t *SpawnTool) execute(ctx context.Context, args map[string]any, cb AsyncCa
}
}
if t.manager == nil {
return ErrorResult("Subagent manager not configured")
// Build system prompt for spawned subagent
systemPrompt := fmt.Sprintf(
`You are a spawned subagent running in the background. Complete the given task independently and report back when done.
Task: %s`,
task,
)
if label != "" {
systemPrompt = fmt.Sprintf(
`You are a spawned subagent labeled "%s" running in the background. Complete the given task independently and report back when done.
Task: %s`,
label,
task,
)
}
// Read channel/chatID from context (injected by registry).
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSpawnTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
// Use spawner if available (direct SpawnSubTurn call)
if t.spawner != nil {
// Launch async sub-turn in goroutine
go func() {
result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{
Model: t.defaultModel,
Tools: nil, // Will inherit from parent via context
SystemPrompt: systemPrompt,
MaxTokens: t.maxTokens,
Temperature: t.temperature,
Async: true, // Async execution
})
if err != nil {
result = ErrorResult(fmt.Sprintf("Spawn failed: %v", err)).WithError(err)
}
// Call callback if provided
if cb != nil {
cb(ctx, result)
}
}()
// Return immediate acknowledgment
if label != "" {
return AsyncResult(fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task))
}
return AsyncResult(fmt.Sprintf("Spawned subagent for task: %s", task))
}
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, agentID, channel, chatID, cb)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
// Return AsyncResult since the task runs in background
return AsyncResult(result)
// Fallback: spawner not configured
return ErrorResult("Subagent manager not configured")
}
+19
View File
@@ -6,6 +6,24 @@ import (
"testing"
)
// mockSpawner implements SubTurnSpawner for testing. It performs no real
// work: it echoes the task embedded in the system prompt back to the caller.
type mockSpawner struct{}

// SpawnSubTurn returns a successful ToolResult whose ForLLM field is
// "Task completed: " followed by the task text. The task is everything after
// the first "Task: " marker in cfg.SystemPrompt; when the marker is absent the
// whole prompt is echoed. It never returns an error.
func (m *mockSpawner) SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error) {
	task := cfg.SystemPrompt
	// Cut splits once, so a task that itself contains "Task: " is echoed in
	// full instead of being truncated at the second marker.
	if _, after, found := strings.Cut(task, "Task: "); found {
		task = after
	}
	return &ToolResult{
		ForLLM:  "Task completed: " + task,
		ForUser: "Task completed",
	}, nil
}
func TestSpawnTool_Execute_EmptyTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
@@ -44,6 +62,7 @@ func TestSpawnTool_Execute_ValidTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSpawnTool(manager)
tool.SetSpawner(&mockSpawner{})
ctx := context.Background()
args := map[string]any{
+180 -127
View File
@@ -4,11 +4,34 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/providers"
)
// SubTurnSpawner is an interface for spawning sub-turns.
// This avoids circular dependency between tools and agent packages.
//
// SpawnSubTurn receives the sub-turn configuration and returns the sub-turn's
// ToolResult, or an error if the sub-turn could not be run. The tools in this
// package call it with cfg.Async set to true (SpawnTool) or false
// (SubagentTool).
type SubTurnSpawner interface {
SpawnSubTurn(ctx context.Context, cfg SubTurnConfig) (*ToolResult, error)
}
// SubTurnConfig holds configuration for spawning a sub-turn.
// It is passed by value to SubTurnSpawner.SpawnSubTurn.
type SubTurnConfig struct {
// Model is the model name for the sub-turn; SpawnTool and SubagentTool fill
// it with the default model captured from the manager at construction.
Model string
// Tools is left nil by SpawnTool and SubagentTool, whose call sites note
// that tools are inherited from the parent via context.
Tools []Tool
// SystemPrompt is the prompt text; the tools in this package embed the task
// (and optional label) directly in it.
SystemPrompt string
// MaxTokens and Temperature are the LLM options forwarded to the sub-turn.
MaxTokens int
Temperature float64
Async bool // true for async (spawn), false for sync (subagent)
Critical bool // continue running after parent finishes gracefully
Timeout time.Duration // 0 = use default (5 minutes)
MaxContextRunes int // 0 = auto, -1 = no limit, >0 = explicit limit
// NOTE(review): ActualSystemPrompt and InitialMessages are not set by any
// code visible in this file; confirm their semantics with the implementer
// of SubTurnSpawner before relying on them.
ActualSystemPrompt string
InitialMessages []providers.Message
InitialTokenBudget *atomic.Int64 // Shared token budget for team members; nil if no budget
}
type SubagentTask struct {
ID string
Task string
@@ -21,6 +44,15 @@ type SubagentTask struct {
Created int64
}
// SpawnSubTurnFunc is the callback a SubagentManager uses to run a task as a
// sub-turn. When one is installed via SubagentManager.SetSpawner, runTask
// calls it with the task text, label, agent ID, the manager's tool registry
// and its LLM options; hasMaxTokens and hasTemperature report whether the
// corresponding option was explicitly configured on the manager.
type SpawnSubTurnFunc func(
ctx context.Context,
task, label, agentID string,
tools *ToolRegistry,
maxTokens int,
temperature float64,
hasMaxTokens, hasTemperature bool,
) (*ToolResult, error)
type SubagentManager struct {
tasks map[string]*SubagentTask
mu sync.RWMutex
@@ -34,6 +66,7 @@ type SubagentManager struct {
hasMaxTokens bool
hasTemperature bool
nextID int
spawner SpawnSubTurnFunc
}
func NewSubagentManager(
@@ -51,6 +84,12 @@ func NewSubagentManager(
}
}
// SetSpawner installs the callback that runTask uses to execute tasks.
// When the callback is nil, runTask falls back to the legacy RunToolLoop path.
// The assignment is made while holding the manager's write lock.
func (sm *SubagentManager) SetSpawner(spawner SpawnSubTurnFunc) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.spawner = spawner
}
// SetLLMOptions sets max tokens and temperature for subagent LLM calls.
func (sm *SubagentManager) SetLLMOptions(maxTokens int, temperature float64) {
sm.mu.Lock()
@@ -108,22 +147,16 @@ func (sm *SubagentManager) Spawn(
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
}
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{
{
Role: "system",
Content: systemPrompt,
},
{
Role: "user",
Content: task.Task,
},
}
func (sm *SubagentManager) runTask(
ctx context.Context,
task *SubagentTask,
callback AsyncCallback,
) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
// TODO(eventbus): once subagents are modeled as child turns inside
// pkg/agent, emit SubTurnEnd and SubTurnResultDelivered from the parent
// AgentLoop instead of this legacy manager.
// Check if context is already canceled before starting
select {
@@ -136,8 +169,8 @@ After completing the task, provide a clear summary of what was done.`
default:
}
// Run tool loop with access to tools
sm.mu.RLock()
spawner := sm.spawner
tools := sm.tools
maxIter := sm.maxIterations
maxTokens := sm.maxTokens
@@ -146,27 +179,69 @@ After completing the task, provide a clear summary of what was done.`
hasTemperature := sm.hasTemperature
sm.mu.RUnlock()
var llmOptions map[string]any
if hasMaxTokens || hasTemperature {
llmOptions = map[string]any{}
if hasMaxTokens {
llmOptions["max_tokens"] = maxTokens
var result *ToolResult
var err error
if spawner != nil {
result, err = spawner(
ctx,
task.Task,
task.Label,
task.AgentID,
tools,
maxTokens,
temperature,
hasMaxTokens,
hasTemperature,
)
} else {
// Fallback to legacy RunToolLoop
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: task.Task},
}
if hasTemperature {
llmOptions["temperature"] = temperature
var llmOptions map[string]any
if hasMaxTokens || hasTemperature {
llmOptions = map[string]any{}
if hasMaxTokens {
llmOptions["max_tokens"] = maxTokens
}
if hasTemperature {
llmOptions["temperature"] = temperature
}
}
var loopResult *ToolLoopResult
loopResult, err = RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
}, messages, task.OriginChannel, task.OriginChatID)
if err == nil {
result = &ToolResult{
ForLLM: fmt.Sprintf(
"Subagent '%s' completed (iterations: %d): %s",
task.Label,
loopResult.Iterations,
loopResult.Content,
),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
}
}
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
}, messages, task.OriginChannel, task.OriginChatID)
sm.mu.Lock()
var result *ToolResult
defer func() {
sm.mu.Unlock()
// Call callback if provided and result is set
@@ -193,19 +268,7 @@ After completing the task, provide a clear summary of what was done.`
}
} else {
task.Status = "completed"
task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf(
"Subagent '%s' completed (iterations: %d): %s",
task.Label,
loopResult.Iterations,
loopResult.Content,
),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
task.Result = result.ForLLM
}
}
@@ -253,16 +316,28 @@ func (sm *SubagentManager) ListTaskCopies() []SubagentTask {
}
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
// It directly calls SubTurnSpawner with Async=false for synchronous execution.
type SubagentTool struct {
manager *SubagentManager
spawner SubTurnSpawner
defaultModel string
maxTokens int
temperature float64
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
if manager == nil {
return &SubagentTool{}
}
return &SubagentTool{
defaultModel: manager.defaultModel,
maxTokens: manager.maxTokens,
temperature: manager.temperature,
}
}
// SetSpawner sets the SubTurnSpawner for direct sub-turn execution.
// The spawner is stored on the tool as-is; no locking is performed here.
// Execute returns an error result ("Subagent manager not configured") when no
// spawner has been set.
func (t *SubagentTool) SetSpawner(spawner SubTurnSpawner) {
t.spawner = spawner
}
func (t *SubagentTool) Name() string {
@@ -298,86 +373,64 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]any) *ToolRe
label, _ := args["label"].(string)
if t.manager == nil {
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
// Build system prompt for subagent
systemPrompt := fmt.Sprintf(
`You are a subagent. Complete the given task independently and provide a clear, concise result.
Task: %s`,
task,
)
if label != "" {
systemPrompt = fmt.Sprintf(
`You are a subagent labeled "%s". Complete the given task independently and provide a clear, concise result.
Task: %s`,
label,
task,
)
}
// Build messages for subagent
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
},
{
Role: "user",
Content: task,
},
}
// Use RunToolLoop to execute with tools (same as async SpawnTool)
sm := t.manager
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
maxTokens := sm.maxTokens
temperature := sm.temperature
hasMaxTokens := sm.hasMaxTokens
hasTemperature := sm.hasTemperature
sm.mu.RUnlock()
var llmOptions map[string]any
if hasMaxTokens || hasTemperature {
llmOptions = map[string]any{}
if hasMaxTokens {
llmOptions["max_tokens"] = maxTokens
// Use spawner if available (direct SpawnSubTurn call)
if t.spawner != nil {
result, err := t.spawner.SpawnSubTurn(ctx, SubTurnConfig{
Model: t.defaultModel,
Tools: nil, // Will inherit from parent via context
SystemPrompt: systemPrompt,
MaxTokens: t.maxTokens,
Temperature: t.temperature,
Async: false, // Synchronous execution
})
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
if hasTemperature {
llmOptions["temperature"] = temperature
// Format result for display
userContent := result.ForLLM
if result.ForUser != "" {
userContent = result.ForUser
}
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nResult: %s",
labelStr, result.ForLLM)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: result.IsError,
Async: false,
}
}
// Fall back to "cli"/"direct" for non-conversation callers (e.g., CLI, tests)
// to preserve the same defaults as the original NewSubagentTool constructor.
channel := ToolChannel(ctx)
if channel == "" {
channel = "cli"
}
chatID := ToolChatID(ctx)
if chatID == "" {
chatID = "direct"
}
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: llmOptions,
}, messages, channel, chatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
// ForUser: Brief summary for user (truncated if too long)
userContent := loopResult.Content
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
// ForLLM: Full execution details
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
labelStr, loopResult.Iterations, loopResult.Content)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: false,
Async: false,
}
// Fallback: spawner not configured
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("spawner not set"))
}
+13 -14
View File
@@ -48,24 +48,19 @@ func TestSubagentManager_SetLLMOptions_AppliesToRunToolLoop(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
manager.SetLLMOptions(2048, 0.6)
tool := NewSubagentTool(manager)
ctx := WithToolContext(context.Background(), "cli", "direct")
args := map[string]any{"task": "Do something"}
result := tool.Execute(ctx, args)
if result == nil || result.IsError {
t.Fatalf("Expected successful result, got: %+v", result)
// Verify options are set on manager
if manager.maxTokens != 2048 {
t.Errorf("manager.maxTokens = %d, want 2048", manager.maxTokens)
}
if provider.lastOptions == nil {
t.Fatal("Expected LLM options to be passed, got nil")
if manager.temperature != 0.6 {
t.Errorf("manager.temperature = %f, want 0.6", manager.temperature)
}
if provider.lastOptions["max_tokens"] != 2048 {
t.Fatalf("max_tokens = %v, want %d", provider.lastOptions["max_tokens"], 2048)
if !manager.hasMaxTokens {
t.Error("manager.hasMaxTokens should be true")
}
if provider.lastOptions["temperature"] != 0.6 {
t.Fatalf("temperature = %v, want %v", provider.lastOptions["temperature"], 0.6)
if !manager.hasTemperature {
t.Error("manager.hasTemperature should be true")
}
}
@@ -150,6 +145,7 @@ func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
tool.SetSpawner(&mockSpawner{})
ctx := WithToolContext(context.Background(), "telegram", "chat-123")
args := map[string]any{
@@ -204,6 +200,7 @@ func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
tool.SetSpawner(&mockSpawner{})
ctx := context.Background()
args := map[string]any{
@@ -277,6 +274,7 @@ func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
tool.SetSpawner(&mockSpawner{})
channel := "test-channel"
chatID := "test-chat"
@@ -302,6 +300,7 @@ func TestSubagentTool_ForUserTruncation(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test")
tool := NewSubagentTool(manager)
tool.SetSpawner(&mockSpawner{})
ctx := context.Background()
+173
View File
@@ -0,0 +1,173 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package utils
import (
"encoding/json"
"fmt"
"unicode/utf8"
"github.com/sipeed/picoclaw/pkg/providers"
)
// CalculateDefaultMaxContextRunes derives a default rune budget from a model's
// context window, given in tokens.
//
// The budget targets 75% of the window (leaving headroom) and converts tokens
// to runes at a fixed 3 runes per token. That ratio is a conservative middle
// ground between typical estimates for English (~4 chars per token) and
// Chinese (~1.5-2 chars per token).
//
// A non-positive contextWindow means the window is unknown; a fixed fallback
// of 8000 runes is returned in that case.
func CalculateDefaultMaxContextRunes(contextWindow int) int {
	const (
		fallbackRunes = 8000 // used when the context window is unknown
		windowShare   = 0.75 // fraction of the window to target
		runesPerToken = 3    // conservative tokens-to-runes ratio
	)

	if contextWindow > 0 {
		tokenBudget := int(float64(contextWindow) * windowShare)
		return tokenBudget * runesPerToken
	}
	return fallbackRunes
}
// ResolveMaxContextRunes picks the effective MaxContextRunes value.
//
//   - configValue > 0: the explicit limit is returned unchanged.
//   - configValue == -1: limiting is explicitly disabled; -1 is returned.
//   - anything else (0, unset, or other negatives): the limit is derived from
//     contextWindow via CalculateDefaultMaxContextRunes.
func ResolveMaxContextRunes(configValue, contextWindow int) int {
	if configValue > 0 || configValue == -1 {
		// Both the explicit limit and the "disabled" sentinel pass through.
		return configValue
	}
	return CalculateDefaultMaxContextRunes(contextWindow)
}
// MeasureContextRunes returns the total rune count of a message list.
//
// For every message it counts the runes of Content, ReasoningContent and
// ToolCallID. Each tool call adds the runes of its Name plus the runes of its
// Arguments serialized as JSON; if serialization fails, a flat estimate of 100
// runes is used for that call's arguments instead.
func MeasureContextRunes(messages []providers.Message) int {
	total := 0
	for i := range messages {
		m := &messages[i]

		total += utf8.RuneCountInString(m.Content) +
			utf8.RuneCountInString(m.ReasoningContent) +
			utf8.RuneCountInString(m.ToolCallID)

		for _, call := range m.ToolCalls {
			total += utf8.RuneCountInString(call.Name)

			encoded, err := json.Marshal(call.Arguments)
			if err != nil {
				// Arguments could not be serialized: fall back to an estimate.
				total += 100
				continue
			}
			total += utf8.RuneCount(encoded)
		}
	}
	return total
}
// TruncateContextSmart truncates message history to fit within maxRunes.
//
// Strategy:
//  1. Always preserve system messages (they define the agent's behavior);
//     they are moved to the front of the result in their original order.
//  2. Keep the longest suffix of the remaining messages that fits the budget
//     (the most recent messages carry the current context).
//  3. Drop everything older than that suffix.
//  4. If anything was dropped, insert a system-role truncation notice between
//     the system messages and the kept suffix.
//
// The budget for non-system messages is maxRunes minus the size of the system
// messages minus a fixed 80-rune reserve for the notice. The reserve is always
// subtracted, so a history slightly under maxRunes can still be truncated.
// When that budget is not positive, only the system messages are returned.
//
// A negative maxRunes means "no limit" (the value ResolveMaxContextRunes
// returns when limiting is disabled); messages is then returned unchanged.
// An empty input is also returned unchanged.
//
// NOTE(review): the cut point is chosen purely by size, so it may separate an
// assistant tool call from its tool result; confirm providers tolerate that.
func TruncateContextSmart(messages []providers.Message, maxRunes int) []providers.Message {
	if len(messages) == 0 || maxRunes < 0 {
		return messages
	}

	// Separate system messages from the rest and size them as we go.
	var systemMsgs, otherMsgs []providers.Message
	systemRunes := 0
	for _, msg := range messages {
		if msg.Role != "system" {
			otherMsgs = append(otherMsgs, msg)
			continue
		}
		systemMsgs = append(systemMsgs, msg)
		systemRunes += utf8.RuneCountInString(msg.Content) +
			utf8.RuneCountInString(msg.ReasoningContent)
	}

	// Reserve space for the truncation notice (estimate ~80 runes).
	const truncationNoticeEstimate = 80

	remainingRunes := maxRunes - systemRunes - truncationNoticeEstimate
	if remainingRunes <= 0 {
		// System messages already consume the budget: keep only them.
		return systemMsgs
	}

	// Walk backwards from the newest message to find where the kept suffix
	// starts. The kept messages are always a contiguous tail of otherMsgs, so
	// tracking an index avoids rebuilding a slice on every iteration.
	keepFrom := len(otherMsgs)
	currentRunes := 0
	for i := len(otherMsgs) - 1; i >= 0; i-- {
		// Measure one message without allocating a temporary slice.
		msgRunes := MeasureContextRunes(otherMsgs[i : i+1])
		if currentRunes+msgRunes > remainingRunes {
			break
		}
		currentRunes += msgRunes
		keepFrom = i
	}

	result := systemMsgs
	if droppedCount := keepFrom; droppedCount > 0 {
		// Tell the LLM that earlier context is missing.
		result = append(result, providers.Message{
			Role: "system",
			Content: fmt.Sprintf(
				"[Context truncated: %d earlier messages omitted to stay within context limits]",
				droppedCount,
			),
		})
	}

	return append(result, otherMsgs[keepFrom:]...)
}
+450
View File
@@ -0,0 +1,450 @@
// PicoClaw - Ultra-lightweight personal AI agent
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package utils
import (
"testing"
"github.com/sipeed/picoclaw/pkg/providers"
)
// TestCalculateDefaultMaxContextRunes checks the 8000-rune fallback for
// unknown windows and the 75% x 3-runes-per-token conversion for known ones.
func TestCalculateDefaultMaxContextRunes(t *testing.T) {
	type scenario struct {
		name          string
		contextWindow int
		want          int
	}

	scenarios := []scenario{
		{name: "zero context window uses fallback", contextWindow: 0, want: 8000},
		{name: "negative context window uses fallback", contextWindow: -1, want: 8000},
		// 4000 * 0.75 * 3 = 9000
		{name: "small context window (4k tokens)", contextWindow: 4000, want: 9000},
		// 128000 * 0.75 * 3 = 288000
		{name: "medium context window (128k tokens)", contextWindow: 128000, want: 288000},
		// 1000000 * 0.75 * 3 = 2250000
		{name: "large context window (1M tokens)", contextWindow: 1000000, want: 2250000},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			if got := CalculateDefaultMaxContextRunes(sc.contextWindow); got != sc.want {
				t.Errorf("CalculateDefaultMaxContextRunes(%d) = %d, want %d",
					sc.contextWindow, got, sc.want)
			}
		})
	}
}
// TestResolveMaxContextRunes covers the three resolution branches: an explicit
// positive limit, the -1 "disabled" sentinel, and auto-calculation for 0.
func TestResolveMaxContextRunes(t *testing.T) {
	type scenario struct {
		name          string
		configValue   int
		contextWindow int
		want          int
	}

	scenarios := []scenario{
		{name: "explicit positive value", configValue: 12000, contextWindow: 4000, want: 12000},
		{name: "explicit disable (-1)", configValue: -1, contextWindow: 4000, want: -1},
		// 4000 * 0.75 * 3
		{name: "zero uses auto-calculate", configValue: 0, contextWindow: 4000, want: 9000},
		// fallback
		{name: "unset (0) with unknown context window", configValue: 0, contextWindow: 0, want: 8000},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			if got := ResolveMaxContextRunes(sc.configValue, sc.contextWindow); got != sc.want {
				t.Errorf("ResolveMaxContextRunes(%d, %d) = %d, want %d",
					sc.configValue, sc.contextWindow, got, sc.want)
			}
		})
	}
}
// TestMeasureContextRunes verifies rune counting over plain content, reasoning
// content, tool calls (name plus JSON-encoded arguments) and multi-byte text.
func TestMeasureContextRunes(t *testing.T) {
	type scenario struct {
		name     string
		messages []providers.Message
		want     int
	}

	toolCallMsg := providers.Message{
		Role:    "assistant",
		Content: "Using tool",
		ToolCalls: []providers.ToolCall{
			{
				Name:      "test_tool",
				Arguments: map[string]any{"key": "value"},
			},
		},
	}

	scenarios := []scenario{
		{
			name:     "empty messages",
			messages: []providers.Message{},
			want:     0,
		},
		{
			name:     "single simple message",
			messages: []providers.Message{{Role: "user", Content: "Hello"}},
			want:     5, // "Hello" = 5 runes
		},
		{
			name: "message with reasoning",
			messages: []providers.Message{
				{Role: "assistant", Content: "Answer", ReasoningContent: "Thinking"},
			},
			want: 14, // "Answer" (6) + "Thinking" (8)
		},
		{
			name:     "message with tool call",
			messages: []providers.Message{toolCallMsg},
			want:     10 + 9 + 15, // "Using tool" + "test_tool" + {"key":"value"}
		},
		{
			name: "multiple messages",
			messages: []providers.Message{
				{Role: "system", Content: "You are helpful"},
				{Role: "user", Content: "Hi"},
				{Role: "assistant", Content: "Hello!"},
			},
			want: 15 + 2 + 6, // = 23
		},
		{
			name: "unicode characters",
			messages: []providers.Message{
				// 4 Chinese characters: counted as runes, not bytes
				{Role: "user", Content: "\u4f60\u597d\u4e16\u754c"},
			},
			want: 4,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			if got := MeasureContextRunes(sc.messages); got != sc.want {
				t.Errorf("MeasureContextRunes() = %d, want %d", got, sc.want)
			}
		})
	}
}
// TestTruncateContextSmart is a table-driven test of TruncateContextSmart.
// Each case may pin the exact result length (wantLen >= 0), list content that
// must survive (wantHas) and list content that must be gone (wantNot).
// Content checks are substring matches against each message's Content.
func TestTruncateContextSmart(t *testing.T) {
	tests := []struct {
		name     string
		messages []providers.Message
		maxRunes int
		wantLen  int      // exact result length; negative skips the check
		wantHas  []string // Content strings that should be present
		wantNot  []string // Content strings that should be absent
	}{
		{
			name:     "empty messages",
			messages: []providers.Message{},
			maxRunes: 100,
			wantLen:  0,
		},
		{
			name: "no truncation needed",
			messages: []providers.Message{
				{Role: "system", Content: "System"},
				{Role: "user", Content: "Hello"},
			},
			maxRunes: 100,
			wantLen:  2,
			wantHas:  []string{"System", "Hello"},
		},
		{
			name: "truncate when limit is tight",
			messages: []providers.Message{
				{Role: "system", Content: "System"},
				{Role: "user", Content: "Message 1 with some content here"},
				{Role: "assistant", Content: "Response 1 with some content here"},
				{Role: "user", Content: "Message 2 with some content here"},
				{Role: "assistant", Content: "Response 2 with some content here"},
				{Role: "user", Content: "Latest"},
			},
			maxRunes: 120, // Tight limit to force truncation
			wantLen:  -1,  // Don't check exact length, just verify truncation occurred
			wantHas:  []string{"System", "Latest"},
			wantNot:  []string{"Message 1", "Response 1"},
		},
		{
			name: "system messages exceed limit",
			messages: []providers.Message{
				{Role: "system", Content: "Very long system message"},
				{Role: "user", Content: "User message"},
			},
			maxRunes: 10, // Less than system message
			wantLen:  1,  // Only system message
			wantHas:  []string{"Very long system message"},
			wantNot:  []string{"User message"},
		},
		{
			name: "preserve multiple system messages",
			messages: []providers.Message{
				{Role: "system", Content: "Sys1"},
				{Role: "system", Content: "Sys2"},
				{Role: "user", Content: "Old"},
				{Role: "user", Content: "New"},
			},
			maxRunes: 200, // Generous limit
			wantLen:  4,   // Everything fits: Sys1, Sys2, Old, New (no notice inserted)
			wantHas:  []string{"Sys1", "Sys2", "New"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := TruncateContextSmart(tt.messages, tt.maxRunes)

			if tt.wantLen >= 0 && len(got) != tt.wantLen {
				t.Errorf("TruncateContextSmart() returned %d messages, want %d",
					len(got), tt.wantLen)
			}

			// hasContent reports whether any surviving message contains needle.
			// An exact match is a special case of a substring match, so a
			// single check covers both.
			hasContent := func(needle string) bool {
				for _, msg := range got {
					if containsSubstring(msg.Content, needle) {
						return true
					}
				}
				return false
			}

			for _, want := range tt.wantHas {
				if !hasContent(want) {
					t.Errorf("Expected content %q not found in truncated messages", want)
				}
			}

			for _, notWant := range tt.wantNot {
				if hasContent(notWant) {
					t.Errorf("Unexpected content %q found in truncated messages", notWant)
				}
			}
		})
	}
}
// containsSubstring reports whether substr occurs anywhere in s.
// A substr longer than s is rejected up front; otherwise the scan is
// delegated to findSubstring.
func containsSubstring(s, substr string) bool {
	if len(substr) > len(s) {
		return false
	}
	return findSubstring(s, substr)
}
// findSubstring reports whether substr occurs in s, comparing byte windows of
// len(substr) at every offset. An empty substr always matches; a substr longer
// than s never does, because the scan range is then empty.
func findSubstring(s, substr string) bool {
	width := len(substr)
	for start, last := 0, len(s)-width; start <= last; start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
// TestSubTurnConfigMaxContextRunes exercises ResolveMaxContextRunes with the
// value conventions documented for a sub-turn's MaxContextRunes setting:
// 0 auto-calculates, a positive value is used as-is, and -1 disables limiting.
func TestSubTurnConfigMaxContextRunes(t *testing.T) {
	type scenario struct {
		name            string
		maxContextRunes int
		contextWindow   int
		wantResolved    int
	}

	scenarios := []scenario{
		{
			name:            "default (0) auto-calculates from context window",
			maxContextRunes: 0,
			contextWindow:   4000,
			wantResolved:    9000, // 4000 * 0.75 * 3
		},
		{
			name:            "explicit value is used",
			maxContextRunes: 12000,
			contextWindow:   4000,
			wantResolved:    12000,
		},
		{
			name:            "disabled (-1) returns -1",
			maxContextRunes: -1,
			contextWindow:   4000,
			wantResolved:    -1,
		},
		{
			name:            "fallback when context window unknown",
			maxContextRunes: 0,
			contextWindow:   0,
			wantResolved:    8000, // conservative fallback
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			resolved := ResolveMaxContextRunes(sc.maxContextRunes, sc.contextWindow)
			if resolved == sc.wantResolved {
				return
			}
			t.Errorf("utils.ResolveMaxContextRunes(%d, %d) = %d, want %d",
				sc.maxContextRunes, sc.contextWindow, resolved, sc.wantResolved)
		})
	}
}
// TestContextTruncationFlow verifies the end-to-end truncation behavior:
//  1. older messages are dropped,
//  2. the system message is preserved,
//  3. the most recent message is kept,
//  4. a truncation notice is inserted,
//  5. the result stays close to the requested limit.
//
// The history is about 122 runes and the limit is 150. Truncation still
// happens because TruncateContextSmart always subtracts the system messages
// and an 80-rune notice reserve from the limit before fitting the rest.
func TestContextTruncationFlow(t *testing.T) {
	history := []providers.Message{
		{Role: "system", Content: "You are a helpful assistant"}, // 27 runes
		{Role: "user", Content: "First question"},                // 14 runes
		{Role: "assistant", Content: "First answer"},             // 12 runes
		{Role: "user", Content: "Second question"},               // 15 runes
		{Role: "assistant", Content: "Second answer"},            // 13 runes
		{Role: "user", Content: "Third question"},                // 14 runes
		{Role: "assistant", Content: "Third answer"},             // 12 runes
		{Role: "user", Content: "Latest question"},               // 15 runes
	}

	// Sanity check on the fixture size (about 122 runes in total).
	if total := MeasureContextRunes(history); total < 100 {
		t.Errorf("Expected total runes > 100, got %d", total)
	}

	const limit = 150
	got := TruncateContextSmart(history, limit)

	// Something must have been dropped.
	if len(got) >= len(history) {
		t.Errorf("Expected truncation, but got %d messages (original: %d)",
			len(got), len(history))
	}

	// anyMessage reports whether at least one surviving message satisfies match.
	anyMessage := func(match func(providers.Message) bool) bool {
		for _, m := range got {
			if match(m) {
				return true
			}
		}
		return false
	}

	systemKept := anyMessage(func(m providers.Message) bool {
		return m.Role == "system" && m.Content == "You are a helpful assistant"
	})
	if !systemKept {
		t.Error("System message was not preserved after truncation")
	}

	latestKept := anyMessage(func(m providers.Message) bool {
		return m.Content == "Latest question"
	})
	if !latestKept {
		t.Error("Latest message was not preserved after truncation")
	}

	noticeAdded := anyMessage(func(m providers.Message) bool {
		return m.Role == "system" && containsSubstring(m.Content, "truncated")
	})
	if !noticeAdded {
		t.Error("Truncation notice was not added")
	}

	// The notice size is only estimated, so allow a 20-rune tolerance.
	if size := MeasureContextRunes(got); size > limit+20 {
		t.Errorf("Truncated context (%d runes) significantly exceeds limit (%d runes)",
			size, limit)
	}
}
// TestContextTruncationPreservesToolCalls verifies that a recent assistant
// message carrying a tool call survives TruncateContextSmart with its
// ToolCalls intact.
//
// Only the survival of the tool call is asserted; the test does not check
// whether the older user message is dropped at this limit.
func TestContextTruncationPreservesToolCalls(t *testing.T) {
	toolUse := providers.Message{
		Role:    "assistant",
		Content: "Recent tool use",
		ToolCalls: []providers.ToolCall{
			{
				Name:      "important_tool",
				Arguments: map[string]any{"key": "value"},
			},
		},
	}
	history := []providers.Message{
		{Role: "system", Content: "System"},
		{Role: "user", Content: "Old message that should be dropped"},
		toolUse,
	}

	// A generous limit, so the tool call message is expected to fit.
	got := TruncateContextSmart(history, 200)

	preserved := false
	for i := range got {
		calls := got[i].ToolCalls
		if len(calls) > 0 && calls[0].Name == "important_tool" {
			preserved = true
			break
		}
	}

	if !preserved {
		t.Error("Tool call message was not preserved during truncation")
	}
}